diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000000..cce74a2d901d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,26 @@ +# The default behavior, which overrides 'core.autocrlf', is to use Git's +# built-in heuristics to determine whether a particular file is text or binary. +# Text files are automatically normalized to the user's platforms. +* text=auto + +# Explicitly declare text files that should always be normalized and converted +# to native line endings. +.gitattributes text +.gitignore text +LICENSE text +*.avsc text +*.html text +*.java text +*.md text +*.properties text +*.proto text +*.py text +*.sh text +*.xml text +*.yml text + +# Declare files that will always have CRLF line endings on checkout. +# *.sln text eol=crlf + +# Explicitly denote all files that are truly binary and should not be modified. +# *.jpg binary diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000000..0ba351f6d641 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +target/ + +# Ignore IntelliJ files. +.idea/ +*.iml +*.ipr +*.iws + +# Ignore Eclipse files. +.classpath +.project +.settings/ + +# The build process generates the dependency-reduced POM, but it shouldn't be +# committed. +dependency-reduced-pom.xml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000000..52e1d3a5cbd2 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,35 @@ +language: java + +sudo: false + +notifications: + email: + recipients: + - dataflow-sdk-build-notifications+travis@google.com + on_success: change + on_failure: always + +matrix: + include: + # On OSX, run with default JDK only. + - os: osx + env: MAVEN_OVERRIDE="" + # On Linux, run with specific JDKs only. + - os: linux + env: CUSTOM_JDK="oraclejdk8" MAVEN_OVERRIDE="-DforkCount=0" + - os: linux + env: CUSTOM_JDK="oraclejdk7" MAVEN_OVERRIDE="-DforkCount=0" + - os: linux + env: CUSTOM_JDK="openjdk7" MAVEN_OVERRIDE="-DforkCount=0" + +before_install: + - if [ "$TRAVIS_OS_NAME" == "osx" ]; then export JAVA_HOME=$(/usr/libexec/java_home); fi + - if [ "$TRAVIS_OS_NAME" == "linux" ]; then jdk_switcher use "$CUSTOM_JDK"; fi + +install: + - travis_retry mvn install clean -U -DskipTests=true + +script: + - travis_retry mvn versions:set -DnewVersion=manual_build + - travis_retry mvn $MAVEN_OVERRIDE install -U + - travis_retry travis/test_wordcount.sh diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index f3533818d7aa..db4a13fc8925 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,71 @@ # Apache Beam -[Apache Beam](http://beam.incubator.apache.org) provides a simple, powerful -programming model for building both batch and streaming parallel data processing -pipelines. It also covers the data integration processus. 
+[Apache Beam](http://beam.incubator.apache.org) is a unified model for defining both batch and streaming data-parallel processing pipelines, as well as a set of language-specific SDKs for constructing pipelines and Runners for executing them on distributed processing backends like [Apache Spark](http://spark.apache.org/), [Apache Flink](http://flink.apache.org), and [Google Cloud Dataflow](http://cloud.google.com/dataflow). -[General usage](http://beam.incubator.apache.org/documentation/getting-started) is -a good starting point for Apache Beam. -You can take a look on the [Beam Examples](http://git-wip-us.apache.org/repos/asf/incubator-beam/examples). +## Status + +_**The Apache Beam project is in the process of bootstrapping. This includes the creation of project resources, the refactoring of the initial code submissions, and the formulation of project documentation, planning, and design documents. Please expect a significant amount of churn and breaking changes in the near future.**_ + +[Build Status](http://builds.apache.org/job/beam-master) -## Status [Build Status](http://builds.apache.org/job/beam-master) ## Overview -The key concepts in this programming model are: +Beam provides a general approach to expressing [embarrassingly parallel](https://en.wikipedia.org/wiki/Embarrassingly_parallel) data processing pipelines and supports three categories of users, each of which have relatively disparate backgrounds and needs. + +1. _End Users_: Writing pipelines with an existing SDK, running it on an existing runner. These users want to focus on writing their application logic and have everything else just work. +2. _SDK Writers_: Developing a Beam SDK targeted at a specific user community (Java, Python, Scala, Go, R, graphical, etc). These users are language geeks, and would prefer to be shielded from all the details of various runners and their implementations. +3. _Runner Writers_: Have an execution environment for distributed processing and would like to support programs written against the Beam Model. Would prefer to be shielded from details of multiple SDKs. + + +### The Beam Model + +The model behind Beam evolved from a number of internal Google data processing projects, including [MapReduce](http://research.google.com/archive/mapreduce.html), [FlumeJava](http://research.google.com/pubs/pub35650.html), and [Millwheel](http://research.google.com/pubs/pub41378.html). This model was originally known as the “[Dataflow Model](http://www.vldb.org/pvldb/vol8/p1792-Akidau.pdf)”. + +To learn more about the Beam Model (though still under the original name of Dataflow), see the World Beyond Batch: [Streaming 101](https://wiki.apache.org/incubator/BeamProposal) and [Streaming 102](https://www.oreilly.com/ideas/the-world-beyond-batch-streaming-101) posts on O’Reilly’s Radar site, and the [VLDB 2015 paper](http://www.vldb.org/pvldb/vol8/p1792-Akidau.pdf). + +The key concepts in the Beam programming model are: * `PCollection`: represents a collection of data, which could be bounded or unbounded in size. * `PTransform`: represents a computation that transforms input PCollections into output PCollections. * `Pipeline`: manages a directed acyclic graph of PTransforms and PCollections that is ready for execution. * `PipelineRunner`: specifies where and how the pipeline should execute. -We provide the following PipelineRunners: - 1. The `DirectPipelineRunner` runs the pipeline on your local machine. - 2. 
The `BlockingDataflowPipelineRunner` submits the pipeline to the Dataflow Service via the `DataflowPipelineRunner` -and then prints messages about the job status until the execution is complete. - 3. The `SparkPipelineRunner` runs the pipeline on an Apache Spark cluster. - 4. The `FlinkPipelineRunner` runs the pipeline on an Apache Flink cluster. +### SDKs -## Getting Started +Beam supports multiple language specific SDKs for writing pipelines against the Beam Model. -The following command will build both the `sdk` and `example` modules and -install them in your local Maven repository: +Currently, this repository contains the Beam Java SDK, which is in the process of evolving from the [Dataflow Java SDK](https://github.com/GoogleCloudPlatform/DataflowJavaSDK). The [Dataflow Python SDK](https://github.com/GoogleCloudPlatform/DataflowPythonSDK) will also become part of Beam in the near future. - mvn clean install +Have ideas for new SDKs or DSLs? See the [Jira](https://issues.apache.org/jira/browse/BEAM/component/12328909/). -You can speed up the build and install process by using the following options: - 1. To skip execution of the unit tests, run: +### Runners - mvn install -DskipTests +Beam supports executing programs on multiple distributed processing backends. After the Beam project's initial bootstrapping completes, it will include: + 1. The `DirectPipelineRunner` runs the pipeline on your local machine. + 2. The `DataflowPipelineRunner` submits the pipeline to the [Google Cloud Dataflow](http://cloud.google.com/dataflow/). + 3. The `SparkPipelineRunner` runs the pipeline on an Apache Spark cluster. See the code that will be donated at [cloudera/spark-dataflow](https://github.com/cloudera/spark-dataflow). + 4. The `FlinkPipelineRunner` runs the pipeline on an Apache Flink cluster. See the code that will be donated at [dataArtisans/flink-dataflow](https://github.com/dataArtisans/flink-dataflow). + +Have ideas for new Runners? See the [Jira](https://issues.apache.org/jira/browse/BEAM/component/12328916/). - 2. While iterating on a specific module, use the following command to compile - and reinstall it. For example, to reinstall the `examples` module, run: - mvn install -pl examples +## Getting Started - Be careful, however, as this command will use the most recently installed SDK - from the local repository (or Maven Central) even if you have changed it - locally. +_Coming soon!_ -After building and installing, you can execute the `WordCount` and other -example pipelines by following the instructions in this [README](https://git-wip-us.apache.org/repos/asf/incubator-beam/examples/README.md). ## Contact Us -You can subscribe on the mailing lists to discuss and get involved in Apache Beam: +To get involved in Apache Beam: -* [Subscribe](mailto:user-subscribe@beam.incubator.apache.org) on the [user@beam.incubator.apache.org](mailto:user@beam.incubator.apache.org) -* [Subscribe](mailto:dev-subscribe@beam.incubator.apache.org) on the [dev@beam.incubator.apache.org](mailto:dev@beam.incubator.apache.org) +* [Subscribe](mailto:user-subscribe@beam.incubator.apache.org) or [mail](mailto:user@beam.incubator.apache.org) the [user@beam.incubator.apache.org](http://mail-archives.apache.org/mod_mbox/incubator-beam-user/) list. +* [Subscribe](mailto:dev-subscribe@beam.incubator.apache.org) or [mail](mailto:dev@beam.incubator.apache.org) the [dev@beam.incubator.apache.org](http://mail-archives.apache.org/mod_mbox/incubator-beam-dev/) list. 
+* Report issues on [Jira](https://issues.apache.org/jira/browse/BEAM). -You can report issue on [Jira](https://issues.apache.org/jira/browse/BEAM). ## More Information diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 000000000000..f38dd74b3475 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,413 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/contrib/README.md b/contrib/README.md new file mode 100644 index 000000000000..b99cf46f58bf --- /dev/null +++ b/contrib/README.md @@ -0,0 +1,53 @@ +# Community contributions + +This directory hosts a wide variety of community contributions that may be +useful to other users of +[Google Cloud Dataflow](https://cloud.google.com/dataflow/), +but may not be appropriate or ready yet for inclusion into the +[mainline SDK](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/) or a +separate Google-maintained artifact. + +## Organization + +Each subdirectory represents a logically separate and independent module. +Preferably, the code is hosted directly in this repository. When appropriate, we +are also open to linking external repositories via +[`submodule`](http://git-scm.com/docs/git-submodule/) functionality within Git. + +While we are happy to host individual modules to provide additional value to all +Cloud Dataflow users, the modules are _maintained solely by their respective +authors_. We will make sure that modules are related to Cloud Dataflow, that +they are distributed under the same license as the mainline SDK, and provide +some guidance to the authors to make the quality as high as possible. + +We __cannot__, however, provide _any_ guarantees about correctness, +compatibility, performance, support, test coverage, maintenance or future +availability of individual modules hosted here. + +## Process + +In general, we recommend to get in touch with us through the issue tracker +first. That way we can help out and possibly guide you. Coordinating up front +makes it much easier to avoid frustration later on. + +We welcome pull requests with a new module from everyone. Every module must be +related to Cloud Dataflow and must have an informative README.md file. We will +provide general guidance, but usually won't be reviewing the module in detail. +We reserve the right to refuse acceptance to any module, or remove it at any +time in the future. + +We also welcome improvements to an existing module from everyone. We'll often +wait for comments from the primary author of the module before merging a pull +request from a non-primary author. + +As the module matures, we may choose to pull it directly into the mainline SDK +or promote it to a Google-managed artifact. + +## Licensing + +We require all contributors to sign the Contributor License Agreement, exactly +as we require for any contributions to the mainline SDK. More information is +available in our [CONTRIBUTING.md](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/CONTRIBUTING.md) +file. 
+ +_Thank you for your contribution to the Cloud Dataflow community!_ diff --git a/contrib/hadoop/AUTHORS.md b/contrib/hadoop/AUTHORS.md new file mode 100644 index 000000000000..6effdb917d19 --- /dev/null +++ b/contrib/hadoop/AUTHORS.md @@ -0,0 +1,7 @@ +# Authors of 'hadoop' module + +The following is the official list of authors for copyright purposes of this community-contributed module. + + Cloudera + Tom White, tom [at] cloudera [dot] com + Google Inc. \ No newline at end of file diff --git a/contrib/hadoop/README.md b/contrib/hadoop/README.md new file mode 100644 index 000000000000..49bbf980e80a --- /dev/null +++ b/contrib/hadoop/README.md @@ -0,0 +1,24 @@ +# Hadoop module + +This library provides Dataflow sources and sinks to make it possible to read and +write Apache Hadoop file formats from Dataflow pipelines. + +Currently, only the read path is implemented. A `HadoopFileSource` allows any +Hadoop `FileInputFormat` to be read as a `PCollection`. + +A `HadoopFileSource` can be read from using the +`com.google.cloud.dataflow.sdk.io.Read` transform. For example: + +```java +HadoopFileSource source = HadoopFileSource.from(path, MyInputFormat.class, + MyKey.class, MyValue.class); +PCollection> records = Read.from(mySource); +``` + +Alternatively, the `readFrom` method is a convenience method that returns a read +transform. For example: + +```java +PCollection> records = HadoopFileSource.readFrom(path, + MyInputFormat.class, MyKey.class, MyValue.class); +``` diff --git a/contrib/hadoop/pom.xml b/contrib/hadoop/pom.xml new file mode 100644 index 000000000000..8e5a207d1215 --- /dev/null +++ b/contrib/hadoop/pom.xml @@ -0,0 +1,169 @@ + + + + 4.0.0 + + com.google.cloud.dataflow + google-cloud-dataflow-java-contrib-hadoop + Google Cloud Dataflow Hadoop Library + Library to read and write Hadoop file formats from Dataflow. 
+ 0.0.1-SNAPSHOT + jar + + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + UTF-8 + [1.2.0,2.0.0) + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.2 + + 1.7 + 1.7 + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.12 + + + com.puppycrawl.tools + checkstyle + 6.6 + + + + ../../checkstyle.xml + true + true + true + + + + + check + + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.4 + + + attach-sources + compile + + jar + + + + attach-test-sources + test-compile + + test-jar + + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + Google Cloud Dataflow Hadoop Contrib + Google Cloud Dataflow Hadoop Contrib + + com.google.cloud.dataflow.contrib.hadoop + false + ]]> + + + + https://cloud.google.com/dataflow/java-sdk/JavaDoc/ + ${basedir}/../../javadoc/dataflow-sdk-docs + + + http://docs.guava-libraries.googlecode.com/git-history/release18/javadoc/ + ${basedir}/../../javadoc/guava-docs + + + + + + + jar + + package + + + + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + ${google-cloud-dataflow-version} + + + + + org.apache.hadoop + hadoop-client + 2.7.0 + provided + + + + + org.hamcrest + hamcrest-all + 1.3 + test + + + + junit + junit + 4.11 + test + + + diff --git a/contrib/hadoop/src/main/java/com/google/cloud/dataflow/contrib/hadoop/HadoopFileSource.java b/contrib/hadoop/src/main/java/com/google/cloud/dataflow/contrib/hadoop/HadoopFileSource.java new file mode 100644 index 000000000000..f24c3b7bd823 --- /dev/null +++ b/contrib/hadoop/src/main/java/com/google/cloud/dataflow/contrib/hadoop/HadoopFileSource.java @@ -0,0 +1,485 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow Hadoop Library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.contrib.hadoop; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.TaskAttemptID; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; +import java.util.ListIterator; +import java.util.NoSuchElementException; +import javax.annotation.Nullable; + +/** + * A {@code BoundedSource} for reading files resident in a Hadoop filesystem using a + * Hadoop file-based input format. + * + *

<p>To read a {@link com.google.cloud.dataflow.sdk.values.PCollection} of
+ * {@link com.google.cloud.dataflow.sdk.values.KV} key-value pairs from one or more
+ * Hadoop files, use {@link HadoopFileSource#from} to specify the path(s) of the files to
+ * read, the Hadoop {@link org.apache.hadoop.mapreduce.lib.input.FileInputFormat}, the
+ * key class and the value class.
+ *
+ * <p>A {@code HadoopFileSource} can be read from using the
+ * {@link com.google.cloud.dataflow.sdk.io.Read} transform. For example:
+ *
+ * <pre>
+ * {@code
+ * HadoopFileSource<MyKey, MyValue> source = HadoopFileSource.from(path, MyInputFormat.class,
+ *   MyKey.class, MyValue.class);
+ * PCollection<KV<MyKey, MyValue>> records = Read.from(source);
+ * }
+ * </pre>
+ *
+ * <p>The {@link HadoopFileSource#readFrom} method is a convenience method
+ * that returns a read transform. For example:
+ *
+ * <pre>
+ * {@code
+ * PCollection<KV<MyKey, MyValue>> records = HadoopFileSource.readFrom(path,
+ *   MyInputFormat.class, MyKey.class, MyValue.class);
+ * }
+ * </pre>
+ * + * Implementation note: Since Hadoop's {@link org.apache.hadoop.mapreduce.lib.input.FileInputFormat} + * determines the input splits, this class extends {@link BoundedSource} rather than + * {@link com.google.cloud.dataflow.sdk.io.OffsetBasedSource}, since the latter + * dictates input splits. + + * @param The type of keys to be read from the source. + * @param The type of values to be read from the source. + */ +public class HadoopFileSource extends BoundedSource> { + private static final long serialVersionUID = 0L; + + private final String filepattern; + private final Class> formatClass; + private final Class keyClass; + private final Class valueClass; + private final SerializableSplit serializableSplit; + + /** + * Creates a {@code Read} transform that will read from an {@code HadoopFileSource} + * with the given file name or pattern ("glob") using the given Hadoop + * {@link org.apache.hadoop.mapreduce.lib.input.FileInputFormat}, + * with key-value types specified by the given key class and value class. + */ + public static > Read.Bounded> readFrom( + String filepattern, Class formatClass, Class keyClass, Class valueClass) { + return Read.from(from(filepattern, formatClass, keyClass, valueClass)); + } + + /** + * Creates a {@code HadoopFileSource} that reads from the given file name or pattern ("glob") + * using the given Hadoop {@link org.apache.hadoop.mapreduce.lib.input.FileInputFormat}, + * with key-value types specified by the given key class and value class. + */ + public static > HadoopFileSource from( + String filepattern, Class formatClass, Class keyClass, Class valueClass) { + @SuppressWarnings("unchecked") + HadoopFileSource source = (HadoopFileSource) + new HadoopFileSource(filepattern, formatClass, keyClass, valueClass); + return source; + } + + /** + * Create a {@code HadoopFileSource} based on a file or a file pattern specification. + */ + private HadoopFileSource(String filepattern, + Class> formatClass, Class keyClass, + Class valueClass) { + this(filepattern, formatClass, keyClass, valueClass, null); + } + + /** + * Create a {@code HadoopFileSource} based on a single Hadoop input split, which won't be + * split up further. 
+ */ + private HadoopFileSource(String filepattern, + Class> formatClass, Class keyClass, + Class valueClass, SerializableSplit serializableSplit) { + this.filepattern = filepattern; + this.formatClass = formatClass; + this.keyClass = keyClass; + this.valueClass = valueClass; + this.serializableSplit = serializableSplit; + } + + public String getFilepattern() { + return filepattern; + } + + public Class> getFormatClass() { + return formatClass; + } + + public Class getKeyClass() { + return keyClass; + } + + public Class getValueClass() { + return valueClass; + } + + @Override + public void validate() { + Preconditions.checkNotNull(filepattern, + "need to set the filepattern of a HadoopFileSource"); + Preconditions.checkNotNull(formatClass, + "need to set the format class of a HadoopFileSource"); + Preconditions.checkNotNull(keyClass, + "need to set the key class of a HadoopFileSource"); + Preconditions.checkNotNull(valueClass, + "need to set the value class of a HadoopFileSource"); + } + + @Override + public List>> splitIntoBundles(long desiredBundleSizeBytes, + PipelineOptions options) throws Exception { + if (serializableSplit == null) { + return Lists.transform(computeSplits(desiredBundleSizeBytes), + new Function>>() { + @Nullable @Override + public BoundedSource> apply(@Nullable InputSplit inputSplit) { + return new HadoopFileSource(filepattern, formatClass, keyClass, + valueClass, new SerializableSplit(inputSplit)); + } + }); + } else { + return ImmutableList.of(this); + } + } + + private FileInputFormat createFormat(Job job) throws IOException, IllegalAccessException, + InstantiationException { + Path path = new Path(filepattern); + FileInputFormat.addInputPath(job, path); + return formatClass.newInstance(); + } + + private List computeSplits(long desiredBundleSizeBytes) throws IOException, + IllegalAccessException, InstantiationException { + Job job = Job.getInstance(); + FileInputFormat.setMinInputSplitSize(job, desiredBundleSizeBytes); + FileInputFormat.setMaxInputSplitSize(job, desiredBundleSizeBytes); + return createFormat(job).getSplits(job); + } + + @Override + public BoundedReader> createReader(PipelineOptions options) throws IOException { + this.validate(); + + if (serializableSplit == null) { + return new HadoopFileReader<>(this, filepattern, formatClass); + } else { + return new HadoopFileReader<>(this, filepattern, formatClass, + serializableSplit.getSplit()); + } + } + + @Override + public Coder> getDefaultOutputCoder() { + return KvCoder.of(getDefaultCoder(keyClass), getDefaultCoder(valueClass)); + } + + @SuppressWarnings("unchecked") + private Coder getDefaultCoder(Class c) { + if (Writable.class.isAssignableFrom(c)) { + Class writableClass = (Class) c; + return (Coder) WritableCoder.of(writableClass); + } else if (Void.class.equals(c)) { + return (Coder) VoidCoder.of(); + } + // TODO: how to use registered coders here? 
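+    // A possible direction for the TODO above (untested sketch, not part of this change):
+    // pass the pipeline's CoderRegistry into this source and ask it for a coder for c,
+    // e.g. something like registry.getDefaultCoder(c), falling back to the exception
+    // below only when the registry cannot provide one.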
+ throw new IllegalStateException("Cannot find coder for " + c); + } + + // BoundedSource + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) { + long size = 0; + try { + Job job = Job.getInstance(); // new instance + for (FileStatus st : listStatus(createFormat(job), job)) { + size += st.getLen(); + } + } catch (IOException | NoSuchMethodException | InvocationTargetException + | IllegalAccessException | InstantiationException e) { + // ignore, and return 0 + } + return size; + } + + private List listStatus(FileInputFormat format, + JobContext jobContext) throws NoSuchMethodException, InvocationTargetException, + IllegalAccessException { + // FileInputFormat#listStatus is protected, so call using reflection + Method listStatus = FileInputFormat.class.getDeclaredMethod("listStatus", JobContext.class); + listStatus.setAccessible(true); + @SuppressWarnings("unchecked") + List stat = (List) listStatus.invoke(format, jobContext); + return stat; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + static class HadoopFileReader extends BoundedSource.BoundedReader> { + + private final BoundedSource> source; + private final String filepattern; + private final Class formatClass; + + private FileInputFormat format; + private TaskAttemptContext attemptContext; + private List splits; + private ListIterator splitsIterator; + private Configuration conf; + private RecordReader currentReader; + private KV currentPair; + + /** + * Create a {@code HadoopFileReader} based on a file or a file pattern specification. + */ + public HadoopFileReader(BoundedSource> source, String filepattern, + Class> formatClass) { + this(source, filepattern, formatClass, null); + } + + /** + * Create a {@code HadoopFileReader} based on a single Hadoop input split. 
+ */ + public HadoopFileReader(BoundedSource> source, String filepattern, + Class> formatClass, InputSplit split) { + this.source = source; + this.filepattern = filepattern; + this.formatClass = formatClass; + if (split != null) { + this.splits = ImmutableList.of(split); + this.splitsIterator = splits.listIterator(); + } + } + + @Override + public boolean start() throws IOException { + Job job = Job.getInstance(); // new instance + Path path = new Path(filepattern); + FileInputFormat.addInputPath(job, path); + + try { + @SuppressWarnings("unchecked") + FileInputFormat f = (FileInputFormat) formatClass.newInstance(); + this.format = f; + } catch (InstantiationException | IllegalAccessException e) { + throw new IOException("Cannot instantiate file input format " + formatClass, e); + } + this.attemptContext = new TaskAttemptContextImpl(job.getConfiguration(), + new TaskAttemptID()); + + if (splitsIterator == null) { + this.splits = format.getSplits(job); + this.splitsIterator = splits.listIterator(); + } + this.conf = job.getConfiguration(); + return advance(); + } + + @Override + public boolean advance() throws IOException { + try { + if (currentReader != null && currentReader.nextKeyValue()) { + currentPair = nextPair(); + return true; + } else { + while (splitsIterator.hasNext()) { + // advance the reader and see if it has records + InputSplit nextSplit = splitsIterator.next(); + @SuppressWarnings("unchecked") + RecordReader reader = + (RecordReader) format.createRecordReader(nextSplit, attemptContext); + if (currentReader != null) { + currentReader.close(); + } + currentReader = reader; + currentReader.initialize(nextSplit, attemptContext); + if (currentReader.nextKeyValue()) { + currentPair = nextPair(); + return true; + } + currentReader.close(); + currentReader = null; + } + // either no next split or all readers were empty + currentPair = null; + return false; + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException(e); + } + } + + @SuppressWarnings("unchecked") + private KV nextPair() throws IOException, InterruptedException { + K key = currentReader.getCurrentKey(); + V value = currentReader.getCurrentValue(); + // clone Writable objects since they are reused between calls to RecordReader#nextKeyValue + if (key instanceof Writable) { + key = (K) WritableUtils.clone((Writable) key, conf); + } + if (value instanceof Writable) { + value = (V) WritableUtils.clone((Writable) value, conf); + } + return KV.of(key, value); + } + + @Override + public KV getCurrent() throws NoSuchElementException { + if (currentPair == null) { + throw new NoSuchElementException(); + } + return currentPair; + } + + @Override + public void close() throws IOException { + if (currentReader != null) { + currentReader.close(); + currentReader = null; + } + currentPair = null; + } + + @Override + public BoundedSource> getCurrentSource() { + return source; + } + + // BoundedReader + + @Override + public Double getFractionConsumed() { + if (currentReader == null) { + return 0.0; + } + if (splits.isEmpty()) { + return 1.0; + } + int index = splitsIterator.previousIndex(); + int numReaders = splits.size(); + if (index == numReaders) { + return 1.0; + } + double before = 1.0 * index / numReaders; + double after = 1.0 * (index + 1) / numReaders; + Double fractionOfCurrentReader = getProgress(); + if (fractionOfCurrentReader == null) { + return before; + } + return before + fractionOfCurrentReader * (after - before); + } + + private Double getProgress() { + try { + return 
(double) currentReader.getProgress(); + } catch (IOException | InterruptedException e) { + return null; + } + } + + @Override + public BoundedSource> splitAtFraction(double fraction) { + // Not yet supported. To implement this, the sizes of the splits should be used to + // calculate the remaining splits that constitute the given fraction, then a + // new source backed by those splits should be returned. + return null; + } + } + + /** + * A wrapper to allow Hadoop {@link org.apache.hadoop.mapreduce.InputSplit}s to be + * serialized using Java's standard serialization mechanisms. Note that the InputSplit + * has to be Writable (which most are). + */ + public static class SerializableSplit implements Externalizable { + private static final long serialVersionUID = 0L; + + private InputSplit split; + + public SerializableSplit() { + } + + public SerializableSplit(InputSplit split) { + Preconditions.checkArgument(split instanceof Writable, "Split is not writable: " + + split); + this.split = split; + } + + public InputSplit getSplit() { + return split; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeUTF(split.getClass().getCanonicalName()); + ((Writable) split).write(out); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + String className = in.readUTF(); + try { + split = (InputSplit) Class.forName(className).newInstance(); + ((Writable) split).readFields(in); + } catch (InstantiationException | IllegalAccessException e) { + throw new IOException(e); + } + } + } + + +} diff --git a/contrib/hadoop/src/main/java/com/google/cloud/dataflow/contrib/hadoop/WritableCoder.java b/contrib/hadoop/src/main/java/com/google/cloud/dataflow/contrib/hadoop/WritableCoder.java new file mode 100644 index 000000000000..5dba58d39c19 --- /dev/null +++ b/contrib/hadoop/src/main/java/com/google/cloud/dataflow/contrib/hadoop/WritableCoder.java @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow Hadoop Library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.contrib.hadoop; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.hadoop.io.Writable; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * A {@code WritableCoder} is a {@link com.google.cloud.dataflow.sdk.coders.Coder} for a + * Java class that implements {@link org.apache.hadoop.io.Writable}. + * + *

<p>To use, specify the coder type on a PCollection:
+ * <pre>
+ * {@code
+ *   PCollection<MyRecord> records =
+ *       foo.apply(...).setCoder(WritableCoder.of(MyRecord.class));
+ * }
+ * </pre>
+ * + * @param the type of elements handled by this coder + */ +public class WritableCoder extends StandardCoder { + private static final long serialVersionUID = 0L; + + /** + * Returns a {@code WritableCoder} instance for the provided element class. + * @param the element type + */ + public static WritableCoder of(Class clazz) { + return new WritableCoder<>(clazz); + } + + @JsonCreator + @SuppressWarnings("unchecked") + public static WritableCoder of(@JsonProperty("type") String classType) + throws ClassNotFoundException { + Class clazz = Class.forName(classType); + if (!Writable.class.isAssignableFrom(clazz)) { + throw new ClassNotFoundException( + "Class " + classType + " does not implement Writable"); + } + return of((Class) clazz); + } + + private final Class type; + + public WritableCoder(Class type) { + this.type = type; + } + + @Override + public void encode(T value, OutputStream outStream, Context context) throws IOException { + value.write(new DataOutputStream(outStream)); + } + + @Override + public T decode(InputStream inStream, Context context) throws IOException { + try { + T t = type.newInstance(); + t.readFields(new DataInputStream(inStream)); + return t; + } catch (InstantiationException | IllegalAccessException e) { + throw new CoderException("unable to deserialize record", e); + } + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + result.put("type", type.getName()); + return result; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "Hadoop Writable may be non-deterministic."); + } + +} diff --git a/contrib/hadoop/src/test/java/com/google/cloud/dataflow/contrib/hadoop/HadoopFileSourceTest.java b/contrib/hadoop/src/test/java/com/google/cloud/dataflow/contrib/hadoop/HadoopFileSourceTest.java new file mode 100644 index 000000000000..cef3c0834852 --- /dev/null +++ b/contrib/hadoop/src/test/java/com/google/cloud/dataflow/contrib/hadoop/HadoopFileSourceTest.java @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow Hadoop Library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.contrib.hadoop; + +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.readFromSource; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Source; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.SourceTestUtils; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.Writer; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Tests for HadoopFileSource. + */ +public class HadoopFileSourceTest { + + Random random = new Random(0L); + + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testFullyReadSingleFile() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + List> expectedResults = createRandomRecords(3, 10, 0); + File file = createFileWithData("tmp.seq", expectedResults); + + HadoopFileSource source = + HadoopFileSource.from(file.toString(), SequenceFileInputFormat.class, + IntWritable.class, Text.class); + + assertEquals(file.length(), source.getEstimatedSizeBytes(null)); + + assertThat(expectedResults, containsInAnyOrder(readFromSource(source, options).toArray())); + } + + @Test + public void testFullyReadFilePattern() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List> data1 = createRandomRecords(3, 10, 0); + File file1 = createFileWithData("file1", data1); + + List> data2 = createRandomRecords(3, 10, 10); + createFileWithData("file2", data2); + + List> data3 = createRandomRecords(3, 10, 20); + createFileWithData("file3", data3); + + List> data4 = createRandomRecords(3, 10, 30); + createFileWithData("otherfile", data4); + + HadoopFileSource source = + HadoopFileSource.from(new File(file1.getParent(), "file*").toString(), + SequenceFileInputFormat.class, IntWritable.class, Text.class); + List> expectedResults = new ArrayList<>(); + expectedResults.addAll(data1); + expectedResults.addAll(data2); + expectedResults.addAll(data3); + assertThat(expectedResults, containsInAnyOrder(readFromSource(source, options).toArray())); + } + + @Test + public void testCloseUnstartedFilePatternReader() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List> data1 = createRandomRecords(3, 10, 0); + File file1 = createFileWithData("file1", data1); + + List> data2 = createRandomRecords(3, 10, 10); + createFileWithData("file2", data2); + + List> data3 = createRandomRecords(3, 10, 20); + createFileWithData("file3", data3); + + List> data4 = createRandomRecords(3, 10, 30); + createFileWithData("otherfile", data4); + + HadoopFileSource source = + HadoopFileSource.from(new File(file1.getParent(), "file*").toString(), + 
SequenceFileInputFormat.class, IntWritable.class, Text.class); + Source.Reader> reader = source.createReader(options); + // Closing an unstarted FilePatternReader should not throw an exception. + try { + reader.close(); + } catch (Exception e) { + fail("Closing an unstarted FilePatternReader should not throw an exception"); + } + } + + @Test + public void testSplits() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + + List> expectedResults = createRandomRecords(3, 10000, 0); + File file = createFileWithData("tmp.avro", expectedResults); + + HadoopFileSource source = + HadoopFileSource.from(file.toString(), SequenceFileInputFormat.class, + IntWritable.class, Text.class); + + // Assert that the source produces the expected records + assertEquals(expectedResults, readFromSource(source, options)); + + // Split with a small bundle size (has to be at least size of sync interval) + List>> splits = source + .splitIntoBundles(SequenceFile.SYNC_INTERVAL, options); + assertTrue(splits.size() > 2); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + int nonEmptySplits = 0; + for (BoundedSource> subSource : splits) { + if (readFromSource(subSource, options).size() > 0) { + nonEmptySplits += 1; + } + } + assertTrue(nonEmptySplits > 2); + } + + private File createFileWithData(String filename, List> records) + throws IOException { + File tmpFile = tmpFolder.newFile(filename); + try (Writer writer = SequenceFile.createWriter(new Configuration(), + Writer.keyClass(IntWritable.class), Writer.valueClass(Text.class), + Writer.file(new Path(tmpFile.toURI())))) { + + for (KV record : records) { + writer.append(record.getKey(), record.getValue()); + } + } + return tmpFile; + } + + private List> createRandomRecords(int dataItemLength, + int numItems, int offset) { + List> records = new ArrayList<>(); + for (int i = 0; i < numItems; i++) { + IntWritable key = new IntWritable(i + offset); + Text value = new Text(createRandomString(dataItemLength)); + records.add(KV.of(key, value)); + } + return records; + } + + private String createRandomString(int length) { + char[] chars = "abcdefghijklmnopqrstuvwxyz".toCharArray(); + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < length; i++) { + builder.append(chars[random.nextInt(chars.length)]); + } + return builder.toString(); + } + +} diff --git a/contrib/hadoop/src/test/java/com/google/cloud/dataflow/contrib/hadoop/WritableCoderTest.java b/contrib/hadoop/src/test/java/com/google/cloud/dataflow/contrib/hadoop/WritableCoderTest.java new file mode 100644 index 000000000000..8eeb5e5167ad --- /dev/null +++ b/contrib/hadoop/src/test/java/com/google/cloud/dataflow/contrib/hadoop/WritableCoderTest.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow Hadoop Library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.contrib.hadoop; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; + +import org.apache.hadoop.io.IntWritable; +import org.junit.Test; + +/** + * Tests for WritableCoder. + */ +public class WritableCoderTest { + + @Test + public void testIntWritableEncoding() throws Exception { + IntWritable value = new IntWritable(42); + WritableCoder coder = WritableCoder.of(IntWritable.class); + + CoderProperties.coderDecodeEncodeEqual(coder, value); + } +} diff --git a/contrib/join-library/AUTHORS.md b/contrib/join-library/AUTHORS.md new file mode 100644 index 000000000000..d32b6a7ebaed --- /dev/null +++ b/contrib/join-library/AUTHORS.md @@ -0,0 +1,6 @@ +# Authors of join-library + +The following is the official list of authors for copyright purposes of this community-contributed module. + + Google Inc. + Magnus Runesson, M.Runesson [at] gmail [dot] com diff --git a/contrib/join-library/README.md b/contrib/join-library/README.md new file mode 100644 index 000000000000..8e2a011c3402 --- /dev/null +++ b/contrib/join-library/README.md @@ -0,0 +1,33 @@ +Join-library +============ + +Join-library provides inner join, outer left and right join functions to +Google Cloud Dataflow. The aim is to simplify the most common cases of join to a +simple function call. + +The functions are generic so it supports join of any types supported by +Dataflow. Input to the join functions are PCollections of Key/Values. Both the +left and right PCollections need the same type for the key. All the join +functions return a Key/Value where Key is the join key and value is +a Key/Value where the key is the left value and right is the value. + +In the cases of outer join, since null cannot be serialized the user have +to provide a value that represent null for that particular use case. + +Example how to use join-library: + + PCollection> leftPcollection = ... + PCollection> rightPcollection = ... + + PCollection>> joinedPcollection = + Join.innerJoin(leftPcollection, rightPcollection); + +Join-library can be found on maven-central: + + + org.linuxalert.dataflow + google-cloud-dataflow-java-contrib-joinlibrary + 0.0.3 + + +Questions or comments: `M.Runesson [at] gmail [dot] com` diff --git a/contrib/join-library/pom.xml b/contrib/join-library/pom.xml new file mode 100644 index 000000000000..df39545dacb7 --- /dev/null +++ b/contrib/join-library/pom.xml @@ -0,0 +1,185 @@ + + + + 4.0.0 + + org.linuxalert.dataflow + google-cloud-dataflow-java-contrib-joinlibrary + Google Cloud Dataflow Join Library + Library with generic join functions for Dataflow. + https://github.com/GoogleCloudPlatform/DataflowJavaSDK/tree/master/contrib/join-library + + + Google Inc. 
+ http://www.google.com + + + Magnus Runesson + M (dot) Runesson (at) gmail (dot) com + + Developer + + +1 + + + + + Magnus Runesson + M (dot) Runesson (at) gmail (dot) com + https://github.com/mrunesson + + + 0.0.4 + jar + + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + scm:git:git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + scm:git:git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + https://github.com/GoogleCloudPlatform/DataflowJavaSDK/tree/master/contrib/join-library + + + + UTF-8 + [1.0.0, 2.0.0) + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.2 + + 1.7 + 1.7 + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + ../../checkstyle.xml + true + true + true + + + + validate + + check + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.4 + + + attach-sources + + jar + + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.3 + + + attach-javadocs + + jar + + + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.6.3 + true + + ossrh + https://oss.sonatype.org/ + true + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.5 + + + sign-artifacts + verify + + sign + + + + + + + + + + ossrh + https://oss.sonatype.org/content/repositories/snapshots + + + ossrh + https://oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + ${google-cloud-dataflow-version} + + + + com.google.guava + guava + 19.0 + + + + + org.hamcrest + hamcrest-all + 1.3 + test + + + + junit + junit + 4.12 + test + + + + diff --git a/contrib/join-library/src/main/java/com/google/cloud/dataflow/contrib/joinlibrary/Join.java b/contrib/join-library/src/main/java/com/google/cloud/dataflow/contrib/joinlibrary/Join.java new file mode 100644 index 000000000000..81d8a7f9af75 --- /dev/null +++ b/contrib/join-library/src/main/java/com/google/cloud/dataflow/contrib/joinlibrary/Join.java @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow join-library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.contrib.joinlibrary; + +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Preconditions; + +/** + * Utility class with different versions of joins. All methods join two collections of + * key/value pairs (KV). + */ +public class Join { + + /** + * Inner join of two collections of KV elements. + * @param leftCollection Left side collection to join. 
+ * @param rightCollection Right side collection to join. + * @param Type of the key for both collections + * @param Type of the values for the left collection. + * @param Type of the values for the right collection. + * @return A joined collection of KV where Key is the key and value is a + * KV where Key is of type V1 and Value is type V2. + */ + public static PCollection>> innerJoin( + final PCollection> leftCollection, final PCollection> rightCollection) { + Preconditions.checkNotNull(leftCollection); + Preconditions.checkNotNull(rightCollection); + + final TupleTag v1Tuple = new TupleTag<>(); + final TupleTag v2Tuple = new TupleTag<>(); + + PCollection> coGbkResultCollection = + KeyedPCollectionTuple.of(v1Tuple, leftCollection) + .and(v2Tuple, rightCollection) + .apply(CoGroupByKey.create()); + + return coGbkResultCollection.apply(ParDo.of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + + Iterable leftValuesIterable = e.getValue().getAll(v1Tuple); + Iterable rightValuesIterable = e.getValue().getAll(v2Tuple); + + for (V1 leftValue : leftValuesIterable) { + for (V2 rightValue : rightValuesIterable) { + c.output(KV.of(e.getKey(), KV.of(leftValue, rightValue))); + } + } + } + })) + .setCoder(KvCoder.of(((KvCoder) leftCollection.getCoder()).getKeyCoder(), + KvCoder.of(((KvCoder) leftCollection.getCoder()).getValueCoder(), + ((KvCoder) rightCollection.getCoder()).getValueCoder()))); + } + + /** + * Left Outer Join of two collections of KV elements. + * @param leftCollection Left side collection to join. + * @param rightCollection Right side collection to join. + * @param nullValue Value to use as null value when right side do not match left side. + * @param Type of the key for both collections + * @param Type of the values for the left collection. + * @param Type of the values for the right collection. + * @return A joined collection of KV where Key is the key and value is a + * KV where Key is of type V1 and Value is type V2. Values that + * should be null or empty is replaced with nullValue. + */ + public static PCollection>> leftOuterJoin( + final PCollection> leftCollection, + final PCollection> rightCollection, + final V2 nullValue) { + Preconditions.checkNotNull(leftCollection); + Preconditions.checkNotNull(rightCollection); + Preconditions.checkNotNull(nullValue); + + final TupleTag v1Tuple = new TupleTag<>(); + final TupleTag v2Tuple = new TupleTag<>(); + + PCollection> coGbkResultCollection = + KeyedPCollectionTuple.of(v1Tuple, leftCollection) + .and(v2Tuple, rightCollection) + .apply(CoGroupByKey.create()); + + return coGbkResultCollection.apply(ParDo.of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + + Iterable leftValuesIterable = e.getValue().getAll(v1Tuple); + Iterable rightValuesIterable = e.getValue().getAll(v2Tuple); + + for (V1 leftValue : leftValuesIterable) { + if (rightValuesIterable.iterator().hasNext()) { + for (V2 rightValue : rightValuesIterable) { + c.output(KV.of(e.getKey(), KV.of(leftValue, rightValue))); + } + } else { + c.output(KV.of(e.getKey(), KV.of(leftValue, nullValue))); + } + } + } + })) + .setCoder(KvCoder.of(((KvCoder) leftCollection.getCoder()).getKeyCoder(), + KvCoder.of(((KvCoder) leftCollection.getCoder()).getValueCoder(), + ((KvCoder) rightCollection.getCoder()).getValueCoder()))); + } + + /** + * Right Outer Join of two collections of KV elements. + * @param leftCollection Left side collection to join. 
+ * @param rightCollection Right side collection to join. + * @param nullValue Value to use as null value when left side do not match right side. + * @param Type of the key for both collections + * @param Type of the values for the left collection. + * @param Type of the values for the right collection. + * @return A joined collection of KV where Key is the key and value is a + * KV where Key is of type V1 and Value is type V2. Keys that + * should be null or empty is replaced with nullValue. + */ + public static PCollection>> rightOuterJoin( + final PCollection> leftCollection, + final PCollection> rightCollection, + final V1 nullValue) { + Preconditions.checkNotNull(leftCollection); + Preconditions.checkNotNull(rightCollection); + Preconditions.checkNotNull(nullValue); + + final TupleTag v1Tuple = new TupleTag<>(); + final TupleTag v2Tuple = new TupleTag<>(); + + PCollection> coGbkResultCollection = + KeyedPCollectionTuple.of(v1Tuple, leftCollection) + .and(v2Tuple, rightCollection) + .apply(CoGroupByKey.create()); + + return coGbkResultCollection.apply(ParDo.of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + + Iterable leftValuesIterable = e.getValue().getAll(v1Tuple); + Iterable rightValuesIterable = e.getValue().getAll(v2Tuple); + + for (V2 rightValue : rightValuesIterable) { + if (leftValuesIterable.iterator().hasNext()) { + for (V1 leftValue : leftValuesIterable) { + c.output(KV.of(e.getKey(), KV.of(leftValue, rightValue))); + } + } else { + c.output(KV.of(e.getKey(), KV.of(nullValue, rightValue))); + } + } + } + })) + .setCoder(KvCoder.of(((KvCoder) leftCollection.getCoder()).getKeyCoder(), + KvCoder.of(((KvCoder) leftCollection.getCoder()).getValueCoder(), + ((KvCoder) rightCollection.getCoder()).getValueCoder()))); + } +} diff --git a/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/InnerJoinTest.java b/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/InnerJoinTest.java new file mode 100644 index 000000000000..839c4519508e --- /dev/null +++ b/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/InnerJoinTest.java @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow join-library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.contrib.joinlibrary; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * This test Inner Join functionality. 
+ */ +public class InnerJoinTest { + + Pipeline p; + List> leftListOfKv; + List> listRightOfKv; + List>> expectedResult; + + @Before + public void setup() { + + p = TestPipeline.create(); + leftListOfKv = new ArrayList<>(); + listRightOfKv = new ArrayList<>(); + + expectedResult = new ArrayList<>(); + } + + @Test + public void testJoinOneToOneMapping() { + leftListOfKv.add(KV.of("Key1", 5L)); + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = + p.apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key1", "foo")); + listRightOfKv.add(KV.of("Key2", "bar")); + PCollection> rightCollection = + p.apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.innerJoin( + leftCollection, rightCollection); + + expectedResult.add(KV.of("Key1", KV.of(5L, "foo"))); + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinOneToManyMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key2", "bar")); + listRightOfKv.add(KV.of("Key2", "gazonk")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.innerJoin( + leftCollection, rightCollection); + + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + expectedResult.add(KV.of("Key2", KV.of(4L, "gazonk"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinManyToOneMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + leftListOfKv.add(KV.of("Key2", 6L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key2", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.innerJoin( + leftCollection, rightCollection); + + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + expectedResult.add(KV.of("Key2", KV.of(6L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinNoneToNoneMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key3", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.innerJoin( + leftCollection, rightCollection); + + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + p.run(); + } + + @Test(expected = NullPointerException.class) + public void testJoinLeftCollectionNull() { + Join.innerJoin(null, p.apply(Create.of(listRightOfKv))); + } + + @Test(expected = NullPointerException.class) + public void testJoinRightCollectionNull() { + Join.innerJoin(p.apply(Create.of(leftListOfKv)), null); + } +} diff --git a/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/OuterLeftJoinTest.java b/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/OuterLeftJoinTest.java new file mode 100644 index 000000000000..3dd67422d7f6 --- /dev/null +++ b/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/OuterLeftJoinTest.java @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow join-library Authors + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.contrib.joinlibrary; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + + +/** + * This test Outer Left Join functionality. + */ +public class OuterLeftJoinTest { + + Pipeline p; + List> leftListOfKv; + List> listRightOfKv; + List>> expectedResult; + + @Before + public void setup() { + + p = TestPipeline.create(); + leftListOfKv = new ArrayList<>(); + listRightOfKv = new ArrayList<>(); + + expectedResult = new ArrayList<>(); + } + + @Test + public void testJoinOneToOneMapping() { + leftListOfKv.add(KV.of("Key1", 5L)); + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key1", "foo")); + listRightOfKv.add(KV.of("Key2", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.leftOuterJoin( + leftCollection, rightCollection, ""); + + expectedResult.add(KV.of("Key1", KV.of(5L, "foo"))); + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinOneToManyMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key2", "bar")); + listRightOfKv.add(KV.of("Key2", "gazonk")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.leftOuterJoin( + leftCollection, rightCollection, ""); + + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + expectedResult.add(KV.of("Key2", KV.of(4L, "gazonk"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinManyToOneMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + leftListOfKv.add(KV.of("Key2", 6L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key2", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.leftOuterJoin( + leftCollection, rightCollection, ""); + + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + expectedResult.add(KV.of("Key2", KV.of(6L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinOneToNoneMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key3", "bar")); + PCollection> 
rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.leftOuterJoin( + leftCollection, rightCollection, ""); + + expectedResult.add(KV.of("Key2", KV.of(4L, ""))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + p.run(); + } + + @Test(expected = NullPointerException.class) + public void testJoinLeftCollectionNull() { + Join.leftOuterJoin(null, p.apply(Create.of(listRightOfKv)), ""); + } + + @Test(expected = NullPointerException.class) + public void testJoinRightCollectionNull() { + Join.leftOuterJoin(p.apply(Create.of(leftListOfKv)), null, ""); + } + + @Test(expected = NullPointerException.class) + public void testJoinNullValueIsNull() { + Join.leftOuterJoin( + p.apply("CreateLeft", Create.of(leftListOfKv)), + p.apply("CreateRight", Create.of(listRightOfKv)), + null); + } +} diff --git a/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/OuterRightJoinTest.java b/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/OuterRightJoinTest.java new file mode 100644 index 000000000000..9f3b80b8c93e --- /dev/null +++ b/contrib/join-library/src/test/java/com/google/cloud/dataflow/contrib/joinlibrary/OuterRightJoinTest.java @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2015 The Google Cloud Dataflow join-library Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.contrib.joinlibrary; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + + +/** + * This test Outer Right Join functionality. 
+ */ +public class OuterRightJoinTest { + + Pipeline p; + List> leftListOfKv; + List> listRightOfKv; + List>> expectedResult; + + @Before + public void setup() { + + p = TestPipeline.create(); + leftListOfKv = new ArrayList<>(); + listRightOfKv = new ArrayList<>(); + + expectedResult = new ArrayList<>(); + } + + @Test + public void testJoinOneToOneMapping() { + leftListOfKv.add(KV.of("Key1", 5L)); + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key1", "foo")); + listRightOfKv.add(KV.of("Key2", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.rightOuterJoin( + leftCollection, rightCollection, -1L); + + expectedResult.add(KV.of("Key1", KV.of(5L, "foo"))); + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinOneToManyMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key2", "bar")); + listRightOfKv.add(KV.of("Key2", "gazonk")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.rightOuterJoin( + leftCollection, rightCollection, -1L); + + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + expectedResult.add(KV.of("Key2", KV.of(4L, "gazonk"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinManyToOneMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + leftListOfKv.add(KV.of("Key2", 6L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key2", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.rightOuterJoin( + leftCollection, rightCollection, -1L); + + expectedResult.add(KV.of("Key2", KV.of(4L, "bar"))); + expectedResult.add(KV.of("Key2", KV.of(6L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + + p.run(); + } + + @Test + public void testJoinNoneToOneMapping() { + leftListOfKv.add(KV.of("Key2", 4L)); + PCollection> leftCollection = p + .apply("CreateLeft", Create.of(leftListOfKv)); + + listRightOfKv.add(KV.of("Key3", "bar")); + PCollection> rightCollection = p + .apply("CreateRight", Create.of(listRightOfKv)); + + PCollection>> output = Join.rightOuterJoin( + leftCollection, rightCollection, -1L); + + expectedResult.add(KV.of("Key3", KV.of(-1L, "bar"))); + DataflowAssert.that(output).containsInAnyOrder(expectedResult); + p.run(); + } + + @Test(expected = NullPointerException.class) + public void testJoinLeftCollectionNull() { + Join.rightOuterJoin(null, p.apply(Create.of(listRightOfKv)), ""); + } + + @Test(expected = NullPointerException.class) + public void testJoinRightCollectionNull() { + Join.rightOuterJoin(p.apply(Create.of(leftListOfKv)), null, -1L); + } + + @Test(expected = NullPointerException.class) + public void testJoinNullValueIsNull() { + Join.rightOuterJoin( + p.apply("CreateLeft", Create.of(leftListOfKv)), + p.apply("CreateRight", Create.of(listRightOfKv)), + null); + } +} diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000000..cbcd01fc0f1c --- /dev/null +++ b/examples/README.md @@ -0,0 +1,95 @@ +# Example Pipelines + +The examples 
included in this module serve to demonstrate the basic +functionality of Google Cloud Dataflow, and act as starting points for +the development of more complex pipelines. + +## Word Count + +A good starting point for new users is our set of +[word count](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples) examples, which computes word frequencies. This series of four successively more detailed pipelines is described in detail in the accompanying [walkthrough](https://cloud.google.com/dataflow/examples/wordcount-example). + +1. [`MinimalWordCount`](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/MinimalWordCount.java) is the simplest word count pipeline and introduces basic concepts like [Pipelines](https://cloud.google.com/dataflow/model/pipelines), +[PCollections](https://cloud.google.com/dataflow/model/pcollection), +[ParDo](https://cloud.google.com/dataflow/model/par-do), +and [reading and writing data](https://cloud.google.com/dataflow/model/reading-and-writing-data) from external storage. + +1. [`WordCount`](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java) introduces Dataflow best practices like [PipelineOptions](https://cloud.google.com/dataflow/pipelines/constructing-your-pipeline#Creating) and custom [PTransforms](https://cloud.google.com/dataflow/model/composite-transforms). + +1. [`DebuggingWordCount`](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/DebuggingWordCount.java) +shows how to view live aggregators in the [Dataflow Monitoring Interface](https://cloud.google.com/dataflow/pipelines/dataflow-monitoring-intf), get the most out of +[Cloud Logging](https://cloud.google.com/dataflow/pipelines/logging) integration, and start writing +[good tests](https://cloud.google.com/dataflow/pipelines/testing-your-pipeline). + +1. [`WindowedWordCount`](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/WindowedWordCount.java) shows how to run the same pipeline over either unbounded PCollections in streaming mode or bounded PCollections in batch mode. + +## Building and Running + +The examples in this repository can be built and executed from the root directory by running: + + mvn compile exec:java -pl examples \ + -Dexec.mainClass=
\ + -Dexec.args="" + +For example, you can execute the `WordCount` pipeline on your local machine as follows: + + mvn compile exec:java -pl examples \ + -Dexec.mainClass=com.google.cloud.dataflow.examples.WordCount \ + -Dexec.args="--inputFile= --output=" + +Once you have followed the general Cloud Dataflow +[Getting Started](https://cloud.google.com/dataflow/getting-started) instructions, you can execute +the same pipeline on fully managed resources in Google Cloud Platform: + + mvn compile exec:java -pl examples \ + -Dexec.mainClass=com.google.cloud.dataflow.examples.WordCount \ + -Dexec.args="--project= \ + --stagingLocation= \ + --runner=BlockingDataflowPipelineRunner" + +Make sure to use your project ID, not the project number or the descriptive name. +The Cloud Storage location should be entered in the form of +`gs://bucket/path/to/staging/directory`. + +Alternatively, you may choose to bundle all dependencies into a single JAR and +execute it outside of the Maven environment. For example, you can execute the +following commands to create the +bundled JAR of the examples and execute it both locally and in Cloud +Platform: + + mvn package + + java -cp examples/target/google-cloud-dataflow-java-examples-all-bundled-.jar \ + com.google.cloud.dataflow.examples.WordCount \ + --inputFile= --output= + + java -cp examples/target/google-cloud-dataflow-java-examples-all-bundled-.jar \ + com.google.cloud.dataflow.examples.WordCount \ + --project= \ + --stagingLocation= \ + --runner=BlockingDataflowPipelineRunner + +Other examples can be run similarly by replacing the `WordCount` class path with the example's class path, e.g. +`com.google.cloud.dataflow.examples.cookbook.BigQueryTornadoes`, +and adjusting runtime options under the `-Dexec.args` parameter, as specified in +the example itself. + +Note that when running Maven on the Microsoft Windows platform, backslashes (`\`) +under the `-Dexec.args` parameter should be escaped with another backslash. For +example, an input file pattern of `c:\*.txt` should be entered as `c:\\*.txt`. + +## Beyond Word Count + +After you've finished running your first few word count pipelines, take a look at the [`cookbook`](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook) +directory for some common and useful patterns like joining, filtering, and combining. + +The [`complete`](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/complete) +directory contains a few realistic end-to-end pipelines. + +See the +[Java 8](https://github.com/GoogleCloudPlatform/DataflowJavaSDK/tree/master/examples/src/main/java8/com/google/cloud/dataflow/examples) +examples as well. This directory includes a Java 8 version of the +MinimalWordCount example, as well as a series of examples in a simple 'mobile +gaming' domain. This series introduces some advanced concepts and provides +additional examples of using Java 8 syntax. Other than the use of Java 8 lambda +expressions, the concepts that are used apply equally well in Java 7.
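For a quick taste of the filtering pattern mentioned in the cookbook note above, the sketch below drops every element of a PCollection that does not match a fixed word, in the same style as the `FilterTextFn` used by `DebuggingWordCount`. Treat it as an illustrative sketch rather than one of the shipped examples: the class name `FilterFlourish` and the hard-coded input words are hypothetical, and the sketch assumes only SDK classes (`Pipeline`, `Create`, `ParDo`, `DoFn`) that the examples in this module already use.

    package com.google.cloud.dataflow.examples;

    import com.google.cloud.dataflow.sdk.Pipeline;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
    import com.google.cloud.dataflow.sdk.transforms.Create;
    import com.google.cloud.dataflow.sdk.transforms.DoFn;
    import com.google.cloud.dataflow.sdk.transforms.ParDo;
    import com.google.cloud.dataflow.sdk.values.PCollection;

    import java.util.Arrays;

    /** Hypothetical sketch of the filtering pattern; not part of the shipped examples. */
    public class FilterFlourish {
      public static void main(String[] args) {
        // Uses the default (local) runner, so no Cloud project is needed for this sketch.
        Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

        PCollection<String> words =
            p.apply(Create.of(Arrays.asList("Flourish", "stomach", "king", "Flourish")));

        // Keep only the elements that match; everything else is silently dropped.
        words.apply(ParDo.named("KeepFlourish").of(new DoFn<String, String>() {
          @Override
          public void processElement(ProcessContext c) {
            if (c.element().equals("Flourish")) {
              c.output(c.element());
            }
          }
        }));

        p.run();
      }
    }

`DebuggingWordCount` follows the same shape, but matches on a regular expression and counts the kept and dropped elements with custom aggregators.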
diff --git a/examples/pom.xml b/examples/pom.xml new file mode 100644 index 000000000000..d7834cb3ef70 --- /dev/null +++ b/examples/pom.xml @@ -0,0 +1,521 @@ + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + 1.5.0-SNAPSHOT + + + com.google.cloud.dataflow + google-cloud-dataflow-java-examples-all + Google Cloud Dataflow Java Examples - All + Google Cloud Dataflow Java SDK provides a simple, Java-based + interface for processing virtually any size data using Google cloud + resources. This artifact includes all Dataflow Java SDK + examples. + http://cloud.google.com/dataflow + + jar + + + + DataflowPipelineTests + + true + com.google.cloud.dataflow.sdk.testing.RunnableOnService + both + + + + + java8 + + [1.8,) + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-java8-main-source + initialize + + add-source + + + + ${project.basedir}/src/main/java8 + + + + + + add-java8-test-source + initialize + + add-test-source + + + + ${project.basedir}/src/test/java8 + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + + default-testCompile + test-compile + + testCompile + + + 1.7 + 1.7 + + + **/*Java8Test.java + **/game/**/*.java + + + + + + + java8-test-compile + test-compile + + testCompile + + + 1.8 + 1.8 + + + **/*Java8Test.java + **/game/**/*.java + + + + + + + default-compile + compile + + compile + + + 1.7 + 1.7 + + + **/*Java8*.java + **/game/**/*.java + + + + + + + java8-compile + compile + + compile + + + 1.8 + 1.8 + + + **/*Java8*.java + **/game/**/*.java + + + + + + + + + + + + + + maven-compiler-plugin + + + + org.apache.maven.plugins + maven-dependency-plugin + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.12 + + + com.puppycrawl.tools + checkstyle + 6.6 + + + + ../checkstyle.xml + true + true + true + false + + + + + check + + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.4 + + + attach-sources + compile + + jar + + + + attach-test-sources + test-compile + + test-jar + + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + Google Cloud Dataflow Examples + Google Cloud Dataflow Examples + + com.google.cloud.dataflow.examples + -exclude com.google.cloud.dataflow.sdk.runners.worker:com.google.cloud.dataflow.sdk.runners.dataflow:com.google.cloud.dataflow.sdk.util ${dataflow.javadoc_opts} + false + true + ]]> + + + + + https://cloud.google.com/dataflow/java-sdk/JavaDoc/ + ${basedir}/../javadoc/dataflow-sdk-docs + + + + https://developers.google.com/api-client-library/java/google-api-java-client/reference/1.20.0/ + ${basedir}/../javadoc/apiclient-docs + + + http://avro.apache.org/docs/1.7.7/api/java/ + ${basedir}/../javadoc/avro-docs + + + https://developers.google.com/resources/api-libraries/documentation/bigquery/v2/java/latest/ + ${basedir}/../javadoc/bq-docs + + + https://cloud.google.com/datastore/docs/apis/javadoc/ + ${basedir}/../javadoc/datastore-docs + + + http://docs.guava-libraries.googlecode.com/git-history/release18/javadoc/ + ${basedir}/../javadoc/guava-docs + + + http://fasterxml.github.io/jackson-annotations/javadoc/2.7/ + ${basedir}/../javadoc/jackson-annotations-docs + + + http://fasterxml.github.io/jackson-databind/javadoc/2.7/ + ${basedir}/../javadoc/jackson-databind-docs + + + http://www.joda.org/joda-time/apidocs + ${basedir}/../javadoc/joda-docs + + + https://developers.google.com/api-client-library/java/google-oauth-java-client/reference/1.20.0/ + ${basedir}/../javadoc/oauth-docs + + + + + + + jar + + package + + + + + + org.apache.maven.plugins + 
maven-shade-plugin + 2.4.1 + + + package + + shade + + + ${project.artifactId}-bundled-${project.version} + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + default-jar + + jar + + + + default-test-jar + + test-jar + + + + + + + + org.jacoco + jacoco-maven-plugin + + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + ${project.version} + + + + com.google.api-client + google-api-client + ${google-clients.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-dataflow + ${dataflow.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-bigquery + ${bigquery.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.http-client + google-http-client + ${google-clients.version} + + + + com.google.guava + guava-jdk5 + + + + + + org.apache.avro + avro + ${avro.version} + + + + com.google.apis + google-api-services-datastore-protobuf + ${datastore.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-pubsub + ${pubsub.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.guava + guava + ${guava.version} + + + + com.google.code.findbugs + jsr305 + ${jsr305.version} + + + + joda-time + joda-time + ${joda.version} + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + org.slf4j + slf4j-jdk14 + ${slf4j.version} + runtime + + + + javax.servlet + javax.servlet-api + 3.1.0 + + + + + + org.hamcrest + hamcrest-all + ${hamcrest.version} + + + + junit + junit + ${junit.version} + + + + org.mockito + mockito-all + 1.9.5 + test + + + diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/DebuggingWordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/DebuggingWordCount.java new file mode 100644 index 000000000000..8823dbc32327 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/DebuggingWordCount.java @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.examples.WordCount.WordCountOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + + +/** + * An example that verifies word counts in Shakespeare and includes Dataflow best practices. + * + *

This class, {@link DebuggingWordCount}, is the third in a series of four successively more + * detailed 'word count' examples. You may first want to take a look at {@link MinimalWordCount} + * and {@link WordCount}. After you've looked at this example, see the + * {@link WindowedWordCount} pipeline for an introduction to additional concepts. + * + *

Basic concepts, also in the MinimalWordCount and WordCount examples: + * Reading text files; counting a PCollection; executing a Pipeline both locally + * and using the Dataflow service; defining DoFns. + * + *

New Concepts: + *

+ *   1. Logging to Cloud Logging
+ *   2. Controlling Dataflow worker log levels
+ *   3. Creating a custom aggregator
+ *   4. Testing your Pipeline via DataflowAssert
+ * 
+ * + *

To execute this pipeline locally, specify general pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * 
+ * + *

To execute this pipeline using the Dataflow service and the additional logging discussed + * below, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ *   --workerLogLevelOverrides={"com.google.cloud.dataflow.examples":"DEBUG"}
+ * }
+ * 
+ * + *

Note that when you run via mvn exec, you may need to escape + * the quotation marks as appropriate for your shell. For example, in bash: + *

+ * mvn compile exec:java ... \
+ *   -Dexec.args="... \
+ *     --workerLogLevelOverrides={\\\"com.google.cloud.dataflow.examples\\\":\\\"DEBUG\\\"}"
+ * 
+ * + *

Concept #2: Dataflow workers that execute user code are configured to log to Cloud + * Logging by default at the "INFO" log level and higher. One may override log levels for specific + * logging namespaces by specifying: + *


+ *   --workerLogLevelOverrides={"Name1":"Level1","Name2":"Level2",...}
+ * 
+ * For example, by specifying: + *

+ *   --workerLogLevelOverrides={"com.google.cloud.dataflow.examples":"DEBUG"}
+ * 
+ * when executing this pipeline using the Dataflow service, Cloud Logging would contain only + * "DEBUG" or higher level logs for the {@code com.google.cloud.dataflow.examples} package in + * addition to the default "INFO" or higher level logs. In addition, the default Dataflow worker + * logging configuration can be overridden by specifying + * {@code --defaultWorkerLogLevel=}. For example, + * by specifying {@code --defaultWorkerLogLevel=DEBUG} when executing this pipeline with + * the Dataflow service, Cloud Logging would contain all "DEBUG" or higher level logs. Note + * that changing the default worker log level to TRACE or DEBUG will significantly increase + * the amount of logs output. + * + *

The input file defaults to {@code gs://dataflow-samples/shakespeare/kinglear.txt} and can be + * overridden with {@code --inputFile}. + */ +public class DebuggingWordCount { + /** A DoFn that filters for a specific key based upon a regular expression. */ + public static class FilterTextFn extends DoFn, KV> { + /** + * Concept #1: The logger below uses the fully qualified class name of FilterTextFn + * as the logger. All log statements emitted by this logger will be referenced by this name + * and will be visible in the Cloud Logging UI. Learn more at https://cloud.google.com/logging + * about the Cloud Logging UI. + */ + private static final Logger LOG = LoggerFactory.getLogger(FilterTextFn.class); + + private final Pattern filter; + public FilterTextFn(String pattern) { + filter = Pattern.compile(pattern); + } + + /** + * Concept #3: A custom aggregator can track values in your pipeline as it runs. Those + * values will be displayed in the Dataflow Monitoring UI when this pipeline is run using the + * Dataflow service. These aggregators below track the number of matched and unmatched words. + * Learn more at https://cloud.google.com/dataflow/pipelines/dataflow-monitoring-intf about + * the Dataflow Monitoring UI. + */ + private final Aggregator matchedWords = + createAggregator("matchedWords", new Sum.SumLongFn()); + private final Aggregator unmatchedWords = + createAggregator("umatchedWords", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (filter.matcher(c.element().getKey()).matches()) { + // Log at the "DEBUG" level each element that we match. When executing this pipeline + // using the Dataflow service, these log lines will appear in the Cloud Logging UI + // only if the log level is set to "DEBUG" or lower. + LOG.debug("Matched: " + c.element().getKey()); + matchedWords.addValue(1L); + c.output(c.element()); + } else { + // Log at the "TRACE" level each element that is not matched. Different log levels + // can be used to control the verbosity of logging providing an effective mechanism + // to filter less important information. + LOG.trace("Did not match: " + c.element().getKey()); + unmatchedWords.addValue(1L); + } + } + } + + public static void main(String[] args) { + WordCountOptions options = PipelineOptionsFactory.fromArgs(args).withValidation() + .as(WordCountOptions.class); + Pipeline p = Pipeline.create(options); + + PCollection> filteredWords = + p.apply(TextIO.Read.named("ReadLines").from(options.getInputFile())) + .apply(new WordCount.CountWords()) + .apply(ParDo.of(new FilterTextFn("Flourish|stomach"))); + + /** + * Concept #4: DataflowAssert is a set of convenient PTransforms in the style of + * Hamcrest's collection matchers that can be used when writing Pipeline level tests + * to validate the contents of PCollections. DataflowAssert is best used in unit tests + * with small data sets but is demonstrated here as a teaching tool. + * + *

Below we verify that the set of filtered words matches our expected counts. Note + * that DataflowAssert does not provide any output and that successful completion of the + * Pipeline implies that the expectations were met. Learn more at + * https://cloud.google.com/dataflow/pipelines/testing-your-pipeline on how to test + * your Pipeline and see {@link DebuggingWordCountTest} for an example unit test. + */ + List> expectedResults = Arrays.asList( + KV.of("Flourish", 3L), + KV.of("stomach", 1L)); + DataflowAssert.that(filteredWords).containsInAnyOrder(expectedResults); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/MinimalWordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/MinimalWordCount.java new file mode 100644 index 000000000000..4ed05207c461 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/MinimalWordCount.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SimpleFunction; +import com.google.cloud.dataflow.sdk.values.KV; + + +/** + * An example that counts words in Shakespeare. + * + *

This class, {@link MinimalWordCount}, is the first in a series of four successively more + * detailed 'word count' examples. Here, for simplicity, we don't show any error-checking or + * argument processing, and focus on construction of the pipeline, which chains together the + * application of core transforms. + * + *

Next, see the {@link WordCount} pipeline, then the {@link DebuggingWordCount}, and finally + * the {@link WindowedWordCount} pipeline, for more detailed examples that introduce additional + * concepts. + * + *

Concepts: + *

+ *   1. Reading data from text files
+ *   2. Specifying 'inline' transforms
+ *   3. Counting a PCollection
+ *   4. Writing data to Cloud Storage as text files
+ * 
+ * + *

To execute this pipeline, first edit the code to set your project ID, the staging + * location, and the output location. The specified GCS bucket(s) must already exist. + * + *

Then, run the pipeline as described in the README. It will be deployed and run using the + * Dataflow service. No args are required to run the pipeline. You can see the results in your + * output bucket in the GCS browser. + */ +public class MinimalWordCount { + + public static void main(String[] args) { + // Create a DataflowPipelineOptions object. This object lets us set various execution + // options for our pipeline, such as the associated Cloud Platform project and the location + // in Google Cloud Storage to stage files. + DataflowPipelineOptions options = PipelineOptionsFactory.create() + .as(DataflowPipelineOptions.class); + options.setRunner(BlockingDataflowPipelineRunner.class); + // CHANGE 1/3: Your project ID is required in order to run your pipeline on the Google Cloud. + options.setProject("SET_YOUR_PROJECT_ID_HERE"); + // CHANGE 2/3: Your Google Cloud Storage path is required for staging local files. + options.setStagingLocation("gs://SET_YOUR_BUCKET_NAME_HERE/AND_STAGING_DIRECTORY"); + + // Create the Pipeline object with the options we defined above. + Pipeline p = Pipeline.create(options); + + // Apply the pipeline's transforms. + + // Concept #1: Apply a root transform to the pipeline; in this case, TextIO.Read to read a set + // of input text files. TextIO.Read returns a PCollection where each element is one line from + // the input text (a set of Shakespeare's texts). + p.apply(TextIO.Read.from("gs://dataflow-samples/shakespeare/*")) + // Concept #2: Apply a ParDo transform to our PCollection of text lines. This ParDo invokes a + // DoFn (defined in-line) on each element that tokenizes the text line into individual words. + // The ParDo returns a PCollection, where each element is an individual word in + // Shakespeare's collected texts. + .apply(ParDo.named("ExtractWords").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (String word : c.element().split("[^a-zA-Z']+")) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + })) + // Concept #3: Apply the Count transform to our PCollection of individual words. The Count + // transform returns a new PCollection of key/value pairs, where each key represents a unique + // word in the text. The associated value is the occurrence count for that word. + .apply(Count.perElement()) + // Apply a MapElements transform that formats our PCollection of word counts into a printable + // string, suitable for writing to an output file. + .apply("FormatResults", MapElements.via(new SimpleFunction, String>() { + @Override + public String apply(KV input) { + return input.getKey() + ": " + input.getValue(); + } + })) + // Concept #4: Apply a write transform, TextIO.Write, at the end of the pipeline. + // TextIO.Write writes the contents of a PCollection (in this case, our PCollection of + // formatted strings) to a series of text files in Google Cloud Storage. + // CHANGE 3/3: The Google Cloud Storage path is required for outputting the results to. + .apply(TextIO.Write.to("gs://YOUR_OUTPUT_BUCKET/AND_OUTPUT_PREFIX")); + + // Run the pipeline. + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/WindowedWordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/WindowedWordCount.java new file mode 100644 index 000000000000..2adac5562731 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/WindowedWordCount.java @@ -0,0 +1,269 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.examples.common.DataflowExampleOptions; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.common.ExampleBigQueryTableOptions; +import com.google.cloud.dataflow.examples.common.ExamplePubsubTopicOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + + +/** + * An example that counts words in text, and can run over either unbounded or bounded input + * collections. + * + *

This class, {@link WindowedWordCount}, is the last in a series of four successively more + * detailed 'word count' examples. First take a look at {@link MinimalWordCount}, + * {@link WordCount}, and {@link DebuggingWordCount}. + * + *

Basic concepts, also in the MinimalWordCount, WordCount, and DebuggingWordCount examples: + * Reading text files; counting a PCollection; writing to GCS; executing a Pipeline both locally + * and using the Dataflow service; defining DoFns; creating a custom aggregator; + * user-defined PTransforms; defining PipelineOptions. + * + *

New Concepts: + *

+ *   1. Unbounded and bounded pipeline input modes
+ *   2. Adding timestamps to data
+ *   3. PubSub topics as sources
+ *   4. Windowing
+ *   5. Re-using PTransforms over windowed PCollections
+ *   6. Writing to BigQuery
+ * 
+ * + *

To execute this pipeline locally, specify general pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * 
+ * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * 
+ * + *

Optionally specify the input file path via: + * {@code --inputFile=gs://INPUT_PATH}, + * which defaults to {@code gs://dataflow-samples/shakespeare/kinglear.txt}. + * + *

Specify an output BigQuery dataset and optionally, a table for the output. If you don't + * specify the table, one will be created for you using the job name. If you don't specify the + * dataset, a dataset called {@code dataflow-examples} must already exist in your project. + * {@code --bigQueryDataset=YOUR-DATASET --bigQueryTable=YOUR-NEW-TABLE-NAME}. + * + *

Decide whether you want your pipeline to run with 'bounded' (such as files in GCS) or + * 'unbounded' input (such as a PubSub topic). To run with unbounded input, set + * {@code --unbounded=true}. Then, optionally specify the Google Cloud PubSub topic to read from + * via {@code --pubsubTopic=projects/PROJECT_ID/topics/YOUR_TOPIC_NAME}. If the topic does not + * exist, the pipeline will create one for you. It will delete this topic when it terminates. + * The pipeline will automatically launch an auxiliary batch pipeline to populate the given PubSub + * topic with the contents of the {@code --inputFile}, in order to make the example easy to run. + * If you want to use an independently-populated PubSub topic, indicate this by setting + * {@code --inputFile=""}. In that case, the auxiliary pipeline will not be started. + * + *

By default, the pipeline will do fixed windowing, on 1-minute windows. You can + * change this interval by setting the {@code --windowSize} parameter, e.g. {@code --windowSize=10} + * for 10-minute windows. + */ +public class WindowedWordCount { + private static final Logger LOG = LoggerFactory.getLogger(WindowedWordCount.class); + static final int WINDOW_SIZE = 1; // Default window duration in minutes + + /** + * Concept #2: A DoFn that sets the data element timestamp. This is a silly method, just for + * this example, for the bounded data case. + * + *

Imagine that many ghosts of Shakespeare are all typing madly at the same time to recreate + * his masterworks. Each line of the corpus will get a random associated timestamp somewhere in a + * 2-hour period. + */ + static class AddTimestampFn extends DoFn { + private static final long RAND_RANGE = 7200000; // 2 hours in ms + + @Override + public void processElement(ProcessContext c) { + // Generate a timestamp that falls somewhere in the past two hours. + long randomTimestamp = System.currentTimeMillis() + - (int) (Math.random() * RAND_RANGE); + /** + * Concept #2: Set the data element with that timestamp. + */ + c.outputWithTimestamp(c.element(), new Instant(randomTimestamp)); + } + } + + /** A DoFn that converts a Word and Count into a BigQuery table row. */ + static class FormatAsTableRowFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + TableRow row = new TableRow() + .set("word", c.element().getKey()) + .set("count", c.element().getValue()) + // include a field for the window timestamp + .set("window_timestamp", c.timestamp().toString()); + c.output(row); + } + } + + /** + * Helper method that defines the BigQuery schema used for the output. + */ + private static TableSchema getSchema() { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("word").setType("STRING")); + fields.add(new TableFieldSchema().setName("count").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("window_timestamp").setType("TIMESTAMP")); + TableSchema schema = new TableSchema().setFields(fields); + return schema; + } + + /** + * Concept #6: We'll stream the results to a BigQuery table. The BigQuery output source is one + * that supports both bounded and unbounded data. This is a helper method that creates a + * TableReference from input options, to tell the pipeline where to write its BigQuery results. + */ + private static TableReference getTableReference(Options options) { + TableReference tableRef = new TableReference(); + tableRef.setProjectId(options.getProject()); + tableRef.setDatasetId(options.getBigQueryDataset()); + tableRef.setTableId(options.getBigQueryTable()); + return tableRef; + } + + /** + * Options supported by {@link WindowedWordCount}. + * + *

Inherits standard example configuration options, which allow specification of the BigQuery + * table and the PubSub topic, as well as the {@link WordCount.WordCountOptions} support for + * specification of the input file. + */ + public static interface Options extends WordCount.WordCountOptions, + DataflowExampleOptions, ExamplePubsubTopicOptions, ExampleBigQueryTableOptions { + @Description("Fixed window duration, in minutes") + @Default.Integer(WINDOW_SIZE) + Integer getWindowSize(); + void setWindowSize(Integer value); + + @Description("Whether to run the pipeline with unbounded input") + boolean isUnbounded(); + void setUnbounded(boolean value); + } + + public static void main(String[] args) throws IOException { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + options.setBigQuerySchema(getSchema()); + // DataflowExampleUtils creates the necessary input sources to simplify execution of this + // Pipeline. + DataflowExampleUtils exampleDataflowUtils = new DataflowExampleUtils(options, + options.isUnbounded()); + + Pipeline pipeline = Pipeline.create(options); + + /** + * Concept #1: the Dataflow SDK lets us run the same pipeline with either a bounded or + * unbounded input source. + */ + PCollection input; + if (options.isUnbounded()) { + LOG.info("Reading from PubSub."); + /** + * Concept #3: Read from the PubSub topic. A topic will be created if it wasn't + * specified as an argument. The data elements' timestamps will come from the pubsub + * injection. + */ + input = pipeline + .apply(PubsubIO.Read.topic(options.getPubsubTopic())); + } else { + /** Else, this is a bounded pipeline. Read from the GCS file. */ + input = pipeline + .apply(TextIO.Read.from(options.getInputFile())) + // Concept #2: Add an element timestamp, using an artificial time just to show windowing. + // See AddTimestampFn for more detail on this. + .apply(ParDo.of(new AddTimestampFn())); + } + + /** + * Concept #4: Window into fixed windows. The fixed window size for this example defaults to 1 + * minute (you can change this with a command-line option). See the documentation for more + * information on how fixed windows work, and for information on the other types of windowing + * available (e.g., sliding windows). + */ + PCollection windowedWords = input + .apply(Window.into( + FixedWindows.of(Duration.standardMinutes(options.getWindowSize())))); + + /** + * Concept #5: Re-use our existing CountWords transform that does not have knowledge of + * windows over a PCollection containing windowed values. + */ + PCollection> wordCounts = windowedWords.apply(new WordCount.CountWords()); + + /** + * Concept #6: Format the results for a BigQuery table, then write to BigQuery. + * The BigQuery output source supports both bounded and unbounded data. + */ + wordCounts.apply(ParDo.of(new FormatAsTableRowFn())) + .apply(BigQueryIO.Write + .to(getTableReference(options)) + .withSchema(getSchema()) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND)); + + PipelineResult result = pipeline.run(); + + /** + * To mock unbounded input from PubSub, we'll now start an auxiliary 'injector' pipeline that + * runs for a limited time, and publishes to the input PubSub topic. + * + * With an unbounded input source, you will need to explicitly shut down this pipeline when you + * are done with it, so that you do not continue to be charged for the instances. 
You can do + * this via a ctrl-C from the command line, or from the developer's console UI for Dataflow + * pipelines. The PubSub topic will also be deleted at this time. + */ + exampleDataflowUtils.mockUnboundedSource(options.getInputFile(), result); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java new file mode 100644 index 000000000000..1086106f0498 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SimpleFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + + +/** + * An example that counts words in Shakespeare and includes Dataflow best practices. + * + *

This class, {@link WordCount}, is the second in a series of four successively more detailed + * 'word count' examples. You may first want to take a look at {@link MinimalWordCount}. + * After you've looked at this example, then see the {@link DebuggingWordCount} + * pipeline, for introduction of additional concepts. + * + *

For a detailed walkthrough of this example, see + * + * https://cloud.google.com/dataflow/java-sdk/wordcount-example + * + * + *

Basic concepts, also in the MinimalWordCount example: + * Reading text files; counting a PCollection; writing to GCS. + * + *

New Concepts (a short assembly sketch follows this list): + *

+ *   1. Executing a Pipeline both locally and using the Dataflow service
+ *   2. Using ParDo with static DoFns defined out-of-line
+ *   3. Building a composite transform
+ *   4. Defining your own pipeline options
+ * 
+ * + *
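+ * A compact sketch of how these concepts fit together in the pipeline assembly (the full
+ * version appears in {@code main} at the end of this class):
+ * {@code
+ *   p.apply(TextIO.Read.named("ReadLines").from(options.getInputFile()))
+ *    .apply(new CountWords())
+ *    .apply(MapElements.via(new FormatAsTextFn()))
+ *    .apply(TextIO.Write.named("WriteCounts").to(options.getOutput()));
+ * }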

Concept #1: you can execute this pipeline either locally or using the Dataflow service. + * The execution environment is now controlled by command-line options rather than hard-coded + * as it was in the MinimalWordCount example. + * To execute this pipeline locally, specify general pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * 
+ * and a local output file or output prefix on GCS: + *
{@code
+ *   --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PREFIX]
+ * }
+ * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * 
+ * and an output prefix on GCS: + *
{@code
+ *   --output=gs://YOUR_OUTPUT_PREFIX
+ * }
+ * + *
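+ * The same configuration can also be supplied programmatically instead of via command-line
+ * flags; a minimal sketch (all values are placeholders):
+ * {@code
+ *   WordCountOptions options = PipelineOptionsFactory.as(WordCountOptions.class);
+ *   options.setOutput("gs://YOUR_OUTPUT_PREFIX/counts.txt");
+ *   options.as(DataflowPipelineOptions.class).setProject("YOUR_PROJECT_ID");
+ *   Pipeline p = Pipeline.create(options);
+ * }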

The input file defaults to {@code gs://dataflow-samples/shakespeare/kinglear.txt} and can be + * overridden with {@code --inputFile}. + */ +public class WordCount { + + /** + * Concept #2: You can make your pipeline code less verbose by defining your DoFns statically out- + * of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the + * pipeline. + */ + static class ExtractWordsFn extends DoFn { + private final Aggregator emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.addValue(1L); + } + + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + /** A SimpleFunction that converts a Word and Count into a printable string. */ + public static class FormatAsTextFn extends SimpleFunction, String> { + @Override + public String apply(KV input) { + return input.getKey() + ": " + input.getValue(); + } + } + + /** + * A PTransform that converts a PCollection containing lines of text into a PCollection of + * formatted word counts. + * + *

Concept #3: This is a custom composite transform that bundles two transforms (ParDo and + * Count) as a reusable PTransform subclass. Using composite transforms allows for easy reuse, + * modular testing, and an improved monitoring experience. + */ + public static class CountWords extends PTransform, + PCollection>> { + @Override + public PCollection> apply(PCollection lines) { + + // Convert lines of text into individual words. + PCollection words = lines.apply( + ParDo.of(new ExtractWordsFn())); + + // Count the number of times each word occurs. + PCollection> wordCounts = + words.apply(Count.perElement()); + + return wordCounts; + } + } + + /** + * Options supported by {@link WordCount}. + * + *

Concept #4: Defining your own configuration options. Here, you can add your own arguments + * to be processed by the command-line parser, and specify default values for them. You can then + * access the options values in your pipeline code. + * + *
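+ * For example, an additional option could be declared alongside the existing ones with only a
+ * getter/setter pair and annotations; the option shown here is hypothetical and not part of
+ * this example:
+ * {@code
+ *   @Description("Minimum word length to count")   // hypothetical option, for illustration only
+ *   @Default.Integer(1)
+ *   Integer getMinWordLength();
+ *   void setMinWordLength(Integer value);
+ * }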

Inherits standard configuration options. + */ + public static interface WordCountOptions extends PipelineOptions { + @Description("Path of the file to read from") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInputFile(); + void setInputFile(String value); + + @Description("Path of the file to write to") + @Default.InstanceFactory(OutputFactory.class) + String getOutput(); + void setOutput(String value); + + /** + * Returns "gs://${YOUR_STAGING_DIRECTORY}/counts.txt" as the default destination. + */ + public static class OutputFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + if (dataflowOptions.getStagingLocation() != null) { + return GcsPath.fromUri(dataflowOptions.getStagingLocation()) + .resolve("counts.txt").toString(); + } else { + throw new IllegalArgumentException("Must specify --output or --stagingLocation"); + } + } + } + + } + + public static void main(String[] args) { + WordCountOptions options = PipelineOptionsFactory.fromArgs(args).withValidation() + .as(WordCountOptions.class); + Pipeline p = Pipeline.create(options); + + // Concepts #2 and #3: Our pipeline applies the composite CountWords transform, and passes the + // static FormatAsTextFn() to the ParDo transform. + p.apply(TextIO.Read.named("ReadLines").from(options.getInputFile())) + .apply(new CountWords()) + .apply(MapElements.via(new FormatAsTextFn())) + .apply(TextIO.Write.named("WriteCounts").to(options.getOutput())); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/common/DataflowExampleOptions.java b/examples/src/main/java/com/google/cloud/dataflow/examples/common/DataflowExampleOptions.java new file mode 100644 index 000000000000..606bfb4c03e9 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/common/DataflowExampleOptions.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.common; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; + +/** + * Options that can be used to configure the Dataflow examples. 
+ */ +public interface DataflowExampleOptions extends DataflowPipelineOptions { + @Description("Whether to keep jobs running on the Dataflow service after local process exit") + @Default.Boolean(false) + boolean getKeepJobsRunning(); + void setKeepJobsRunning(boolean keepJobsRunning); + + @Description("Number of workers to use when executing the injector pipeline") + @Default.Integer(1) + int getInjectorNumWorkers(); + void setInjectorNumWorkers(int numWorkers); +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/common/DataflowExampleUtils.java b/examples/src/main/java/com/google/cloud/dataflow/examples/common/DataflowExampleUtils.java new file mode 100644 index 000000000000..4dfdd85b803a --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/common/DataflowExampleUtils.java @@ -0,0 +1,485 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.common; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.googleapis.services.AbstractGoogleClientRequest; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.Sleeper; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.Bigquery.Datasets; +import com.google.api.services.bigquery.Bigquery.Tables; +import com.google.api.services.bigquery.model.Dataset; +import com.google.api.services.bigquery.model.DatasetReference; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.model.Subscription; +import com.google.api.services.pubsub.model.Topic; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineJob; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.IntraBundleParallelization; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.AttemptBoundedExponentialBackOff; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Strings; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; +import 
com.google.common.collect.Sets; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import javax.servlet.http.HttpServletResponse; + +/** + * The utility class that sets up and tears down external resources, starts the Google Cloud Pub/Sub + * injector, and cancels the streaming and the injector pipelines once the program terminates. + * + *
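+ * A rough sketch of the intended usage, mirroring the WindowedWordCount example (here
+ * {@code options}, {@code isUnbounded}, and {@code inputFile} stand for the example's own
+ * settings):
+ * {@code
+ *   DataflowExampleUtils exampleUtils = new DataflowExampleUtils(options, isUnbounded);
+ *   Pipeline pipeline = Pipeline.create(options);
+ *   // ... construct the pipeline ...
+ *   PipelineResult result = pipeline.run();
+ *   exampleUtils.mockUnboundedSource(inputFile, result);
+ * }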

It is used to run Dataflow examples, such as TrafficMaxLaneFlow and TrafficRoutes. + */ +public class DataflowExampleUtils { + + private final DataflowPipelineOptions options; + private Bigquery bigQueryClient = null; + private Pubsub pubsubClient = null; + private Dataflow dataflowClient = null; + private Set jobsToCancel = Sets.newHashSet(); + private List pendingMessages = Lists.newArrayList(); + + public DataflowExampleUtils(DataflowPipelineOptions options) { + this.options = options; + } + + /** + * Do resources and runner options setup. + */ + public DataflowExampleUtils(DataflowPipelineOptions options, boolean isUnbounded) + throws IOException { + this.options = options; + setupResourcesAndRunner(isUnbounded); + } + + /** + * Sets up external resources that are required by the example, + * such as Pub/Sub topics and BigQuery tables. + * + * @throws IOException if there is a problem setting up the resources + */ + public void setup() throws IOException { + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backOff = new AttemptBoundedExponentialBackOff(3, 200); + Throwable lastException = null; + try { + do { + try { + setupPubsub(); + setupBigQueryTable(); + return; + } catch (GoogleJsonResponseException e) { + lastException = e; + } + } while (BackOffUtils.next(sleeper, backOff)); + } catch (InterruptedException e) { + // Ignore InterruptedException + } + Throwables.propagate(lastException); + } + + /** + * Set up external resources, and configure the runner appropriately. + */ + public void setupResourcesAndRunner(boolean isUnbounded) throws IOException { + if (isUnbounded) { + options.setStreaming(true); + } + setup(); + setupRunner(); + } + + /** + * Sets up the Google Cloud Pub/Sub topic. + * + *

If the topic doesn't exist, a new topic with the given name will be created. + * + * @throws IOException if there is a problem setting up the Pub/Sub topic + */ + public void setupPubsub() throws IOException { + ExamplePubsubTopicAndSubscriptionOptions pubsubOptions = + options.as(ExamplePubsubTopicAndSubscriptionOptions.class); + if (!pubsubOptions.getPubsubTopic().isEmpty()) { + pendingMessages.add("**********************Set Up Pubsub************************"); + setupPubsubTopic(pubsubOptions.getPubsubTopic()); + pendingMessages.add("The Pub/Sub topic has been set up for this example: " + + pubsubOptions.getPubsubTopic()); + + if (!pubsubOptions.getPubsubSubscription().isEmpty()) { + setupPubsubSubscription( + pubsubOptions.getPubsubTopic(), pubsubOptions.getPubsubSubscription()); + pendingMessages.add("The Pub/Sub subscription has been set up for this example: " + + pubsubOptions.getPubsubSubscription()); + } + } + } + + /** + * Sets up the BigQuery table with the given schema. + * + *

If the table already exists, the schema has to match the given one. Otherwise, the example + * will throw a RuntimeException. If the table doesn't exist, a new table with the given schema + * will be created. + * + * @throws IOException if there is a problem setting up the BigQuery table + */ + public void setupBigQueryTable() throws IOException { + ExampleBigQueryTableOptions bigQueryTableOptions = + options.as(ExampleBigQueryTableOptions.class); + if (bigQueryTableOptions.getBigQueryDataset() != null + && bigQueryTableOptions.getBigQueryTable() != null + && bigQueryTableOptions.getBigQuerySchema() != null) { + pendingMessages.add("******************Set Up Big Query Table*******************"); + setupBigQueryTable(bigQueryTableOptions.getProject(), + bigQueryTableOptions.getBigQueryDataset(), + bigQueryTableOptions.getBigQueryTable(), + bigQueryTableOptions.getBigQuerySchema()); + pendingMessages.add("The BigQuery table has been set up for this example: " + + bigQueryTableOptions.getProject() + + ":" + bigQueryTableOptions.getBigQueryDataset() + + "." + bigQueryTableOptions.getBigQueryTable()); + } + } + + /** + * Tears down external resources that can be deleted upon the example's completion. + */ + private void tearDown() { + pendingMessages.add("*************************Tear Down*************************"); + ExamplePubsubTopicAndSubscriptionOptions pubsubOptions = + options.as(ExamplePubsubTopicAndSubscriptionOptions.class); + if (!pubsubOptions.getPubsubTopic().isEmpty()) { + try { + deletePubsubTopic(pubsubOptions.getPubsubTopic()); + pendingMessages.add("The Pub/Sub topic has been deleted: " + + pubsubOptions.getPubsubTopic()); + } catch (IOException e) { + pendingMessages.add("Failed to delete the Pub/Sub topic : " + + pubsubOptions.getPubsubTopic()); + } + if (!pubsubOptions.getPubsubSubscription().isEmpty()) { + try { + deletePubsubSubscription(pubsubOptions.getPubsubSubscription()); + pendingMessages.add("The Pub/Sub subscription has been deleted: " + + pubsubOptions.getPubsubSubscription()); + } catch (IOException e) { + pendingMessages.add("Failed to delete the Pub/Sub subscription : " + + pubsubOptions.getPubsubSubscription()); + } + } + } + + ExampleBigQueryTableOptions bigQueryTableOptions = + options.as(ExampleBigQueryTableOptions.class); + if (bigQueryTableOptions.getBigQueryDataset() != null + && bigQueryTableOptions.getBigQueryTable() != null + && bigQueryTableOptions.getBigQuerySchema() != null) { + pendingMessages.add("The BigQuery table might contain the example's output, " + + "and it is not deleted automatically: " + + bigQueryTableOptions.getProject() + + ":" + bigQueryTableOptions.getBigQueryDataset() + + "." + bigQueryTableOptions.getBigQueryTable()); + pendingMessages.add("Please go to the Developers Console to delete it manually." 
+ + " Otherwise, you may be charged for its usage."); + } + } + + private void setupBigQueryTable(String projectId, String datasetId, String tableId, + TableSchema schema) throws IOException { + if (bigQueryClient == null) { + bigQueryClient = Transport.newBigQueryClient(options.as(BigQueryOptions.class)).build(); + } + + Datasets datasetService = bigQueryClient.datasets(); + if (executeNullIfNotFound(datasetService.get(projectId, datasetId)) == null) { + Dataset newDataset = new Dataset().setDatasetReference( + new DatasetReference().setProjectId(projectId).setDatasetId(datasetId)); + datasetService.insert(projectId, newDataset).execute(); + } + + Tables tableService = bigQueryClient.tables(); + Table table = executeNullIfNotFound(tableService.get(projectId, datasetId, tableId)); + if (table == null) { + Table newTable = new Table().setSchema(schema).setTableReference( + new TableReference().setProjectId(projectId).setDatasetId(datasetId).setTableId(tableId)); + tableService.insert(projectId, datasetId, newTable).execute(); + } else if (!table.getSchema().equals(schema)) { + throw new RuntimeException( + "Table exists and schemas do not match, expecting: " + schema.toPrettyString() + + ", actual: " + table.getSchema().toPrettyString()); + } + } + + private void setupPubsubTopic(String topic) throws IOException { + if (pubsubClient == null) { + pubsubClient = Transport.newPubsubClient(options).build(); + } + if (executeNullIfNotFound(pubsubClient.projects().topics().get(topic)) == null) { + pubsubClient.projects().topics().create(topic, new Topic().setName(topic)).execute(); + } + } + + private void setupPubsubSubscription(String topic, String subscription) throws IOException { + if (pubsubClient == null) { + pubsubClient = Transport.newPubsubClient(options).build(); + } + if (executeNullIfNotFound(pubsubClient.projects().subscriptions().get(subscription)) == null) { + Subscription subInfo = new Subscription() + .setAckDeadlineSeconds(60) + .setTopic(topic); + pubsubClient.projects().subscriptions().create(subscription, subInfo).execute(); + } + } + + /** + * Deletes the Google Cloud Pub/Sub topic. + * + * @throws IOException if there is a problem deleting the Pub/Sub topic + */ + private void deletePubsubTopic(String topic) throws IOException { + if (pubsubClient == null) { + pubsubClient = Transport.newPubsubClient(options).build(); + } + if (executeNullIfNotFound(pubsubClient.projects().topics().get(topic)) != null) { + pubsubClient.projects().topics().delete(topic).execute(); + } + } + + /** + * Deletes the Google Cloud Pub/Sub subscription. + * + * @throws IOException if there is a problem deleting the Pub/Sub subscription + */ + private void deletePubsubSubscription(String subscription) throws IOException { + if (pubsubClient == null) { + pubsubClient = Transport.newPubsubClient(options).build(); + } + if (executeNullIfNotFound(pubsubClient.projects().subscriptions().get(subscription)) != null) { + pubsubClient.projects().subscriptions().delete(subscription).execute(); + } + } + + /** + * If this is an unbounded (streaming) pipeline, and both inputFile and pubsub topic are defined, + * start an 'injector' pipeline that publishes the contents of the file to the given topic, first + * creating the topic if necessary. 
+ */ + public void startInjectorIfNeeded(String inputFile) { + ExamplePubsubTopicOptions pubsubTopicOptions = options.as(ExamplePubsubTopicOptions.class); + if (pubsubTopicOptions.isStreaming() + && !Strings.isNullOrEmpty(inputFile) + && !Strings.isNullOrEmpty(pubsubTopicOptions.getPubsubTopic())) { + runInjectorPipeline(inputFile, pubsubTopicOptions.getPubsubTopic()); + } + } + + /** + * Do some runner setup: check that the DirectPipelineRunner is not used in conjunction with + * streaming, and if streaming is specified, use the DataflowPipelineRunner. Return the streaming + * flag value. + */ + public void setupRunner() { + if (options.isStreaming() && options.getRunner() != DirectPipelineRunner.class) { + // In order to cancel the pipelines automatically, + // {@literal DataflowPipelineRunner} is forced to be used. + options.setRunner(DataflowPipelineRunner.class); + } + } + + /** + * Runs a batch pipeline to inject data into the PubSubIO input topic. + * + *

The injector pipeline will read from the given text file, and inject data + * into the Google Cloud Pub/Sub topic. + */ + public void runInjectorPipeline(String inputFile, String topic) { + runInjectorPipeline(TextIO.Read.from(inputFile), topic, null); + } + + /** + * Runs a batch pipeline to inject data into the PubSubIO input topic. + * + *

The injector pipeline will read from the given source, and inject data + * into the Google Cloud Pub/Sub topic. + */ + public void runInjectorPipeline(PTransform> readSource, + String topic, + String pubsubTimestampTabelKey) { + PubsubFileInjector.Bound injector; + if (Strings.isNullOrEmpty(pubsubTimestampTabelKey)) { + injector = PubsubFileInjector.publish(topic); + } else { + injector = PubsubFileInjector.withTimestampLabelKey(pubsubTimestampTabelKey).publish(topic); + } + DataflowPipelineOptions copiedOptions = options.cloneAs(DataflowPipelineOptions.class); + if (options.getServiceAccountName() != null) { + copiedOptions.setServiceAccountName(options.getServiceAccountName()); + } + if (options.getServiceAccountKeyfile() != null) { + copiedOptions.setServiceAccountKeyfile(options.getServiceAccountKeyfile()); + } + copiedOptions.setStreaming(false); + copiedOptions.setNumWorkers(options.as(DataflowExampleOptions.class).getInjectorNumWorkers()); + copiedOptions.setJobName(options.getJobName() + "-injector"); + Pipeline injectorPipeline = Pipeline.create(copiedOptions); + injectorPipeline.apply(readSource) + .apply(IntraBundleParallelization + .of(injector) + .withMaxParallelism(20)); + PipelineResult result = injectorPipeline.run(); + if (result instanceof DataflowPipelineJob) { + jobsToCancel.add(((DataflowPipelineJob) result)); + } + } + + /** + * Runs the provided pipeline to inject data into the PubSubIO input topic. + */ + public void runInjectorPipeline(Pipeline injectorPipeline) { + PipelineResult result = injectorPipeline.run(); + if (result instanceof DataflowPipelineJob) { + jobsToCancel.add(((DataflowPipelineJob) result)); + } + } + + /** + * Start the auxiliary injector pipeline, then wait for this pipeline to finish. + */ + public void mockUnboundedSource(String inputFile, PipelineResult result) { + startInjectorIfNeeded(inputFile); + waitToFinish(result); + } + + /** + * If {@literal DataflowPipelineRunner} or {@literal BlockingDataflowPipelineRunner} is used, + * waits for the pipeline to finish and cancels it (and the injector) before the program exists. + */ + public void waitToFinish(PipelineResult result) { + if (result instanceof DataflowPipelineJob) { + final DataflowPipelineJob job = (DataflowPipelineJob) result; + jobsToCancel.add(job); + if (!options.as(DataflowExampleOptions.class).getKeepJobsRunning()) { + addShutdownHook(jobsToCancel); + } + try { + job.waitToFinish(-1, TimeUnit.SECONDS, new MonitoringUtil.PrintHandler(System.out)); + } catch (Exception e) { + throw new RuntimeException("Failed to wait for job to finish: " + job.getJobId()); + } + } else { + // Do nothing if the given PipelineResult doesn't support waitToFinish(), + // such as EvaluationResults returned by DirectPipelineRunner. 
+ tearDown(); + printPendingMessages(); + } + } + + private void addShutdownHook(final Collection jobs) { + if (dataflowClient == null) { + dataflowClient = options.getDataflowClient(); + } + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + tearDown(); + printPendingMessages(); + for (DataflowPipelineJob job : jobs) { + System.out.println("Canceling example pipeline: " + job.getJobId()); + try { + job.cancel(); + } catch (IOException e) { + System.out.println("Failed to cancel the job," + + " please go to the Developers Console to cancel it manually"); + System.out.println( + MonitoringUtil.getJobMonitoringPageURL(job.getProjectId(), job.getJobId())); + } + } + + for (DataflowPipelineJob job : jobs) { + boolean cancellationVerified = false; + for (int retryAttempts = 6; retryAttempts > 0; retryAttempts--) { + if (job.getState().isTerminal()) { + cancellationVerified = true; + System.out.println("Canceled example pipeline: " + job.getJobId()); + break; + } else { + System.out.println( + "The example pipeline is still running. Verifying the cancellation."); + } + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + // Ignore + } + } + if (!cancellationVerified) { + System.out.println("Failed to verify the cancellation for job: " + job.getJobId()); + System.out.println("Please go to the Developers Console to verify manually:"); + System.out.println( + MonitoringUtil.getJobMonitoringPageURL(job.getProjectId(), job.getJobId())); + } + } + } + }); + } + + private void printPendingMessages() { + System.out.println(); + System.out.println("***********************************************************"); + System.out.println("***********************************************************"); + for (String message : pendingMessages) { + System.out.println(message); + } + System.out.println("***********************************************************"); + System.out.println("***********************************************************"); + } + + private static T executeNullIfNotFound( + AbstractGoogleClientRequest request) throws IOException { + try { + return request.execute(); + } catch (GoogleJsonResponseException e) { + if (e.getStatusCode() == HttpServletResponse.SC_NOT_FOUND) { + return null; + } else { + throw e; + } + } + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExampleBigQueryTableOptions.java b/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExampleBigQueryTableOptions.java new file mode 100644 index 000000000000..7c213b59d681 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExampleBigQueryTableOptions.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.common; + +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Options that can be used to configure BigQuery tables in Dataflow examples. + * The project defaults to the project being used to run the example. + */ +public interface ExampleBigQueryTableOptions extends DataflowPipelineOptions { + @Description("BigQuery dataset name") + @Default.String("dataflow_examples") + String getBigQueryDataset(); + void setBigQueryDataset(String dataset); + + @Description("BigQuery table name") + @Default.InstanceFactory(BigQueryTableFactory.class) + String getBigQueryTable(); + void setBigQueryTable(String table); + + @Description("BigQuery table schema") + TableSchema getBigQuerySchema(); + void setBigQuerySchema(TableSchema schema); + + /** + * Returns the job name as the default BigQuery table name. + */ + static class BigQueryTableFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + return options.as(DataflowPipelineOptions.class).getJobName() + .replace('-', '_'); + } + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExamplePubsubTopicAndSubscriptionOptions.java b/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExamplePubsubTopicAndSubscriptionOptions.java new file mode 100644 index 000000000000..d7bd4b8edc3d --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExamplePubsubTopicAndSubscriptionOptions.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.common; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Options that can be used to configure Pub/Sub topic/subscription in Dataflow examples. + */ +public interface ExamplePubsubTopicAndSubscriptionOptions extends ExamplePubsubTopicOptions { + @Description("Pub/Sub subscription") + @Default.InstanceFactory(PubsubSubscriptionFactory.class) + String getPubsubSubscription(); + void setPubsubSubscription(String subscription); + + /** + * Returns a default Pub/Sub subscription based on the project and the job names. 
+ */ + static class PubsubSubscriptionFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowPipelineOptions = + options.as(DataflowPipelineOptions.class); + return "projects/" + dataflowPipelineOptions.getProject() + + "/subscriptions/" + dataflowPipelineOptions.getJobName(); + } + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExamplePubsubTopicOptions.java b/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExamplePubsubTopicOptions.java new file mode 100644 index 000000000000..4bedf318ef5a --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/common/ExamplePubsubTopicOptions.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.common; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Options that can be used to configure Pub/Sub topic in Dataflow examples. + */ +public interface ExamplePubsubTopicOptions extends DataflowPipelineOptions { + @Description("Pub/Sub topic") + @Default.InstanceFactory(PubsubTopicFactory.class) + String getPubsubTopic(); + void setPubsubTopic(String topic); + + /** + * Returns a default Pub/Sub topic based on the project and the job names. + */ + static class PubsubTopicFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowPipelineOptions = + options.as(DataflowPipelineOptions.class); + return "projects/" + dataflowPipelineOptions.getProject() + + "/topics/" + dataflowPipelineOptions.getJobName(); + } + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/common/PubsubFileInjector.java b/examples/src/main/java/com/google/cloud/dataflow/examples/common/PubsubFileInjector.java new file mode 100644 index 000000000000..4a82ae612ae7 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/common/PubsubFileInjector.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.common; + +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.model.PublishRequest; +import com.google.api.services.pubsub.model.PubsubMessage; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.IntraBundleParallelization; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.common.collect.ImmutableMap; + +import java.io.IOException; +import java.util.Arrays; + +/** + * A batch Dataflow pipeline for injecting a set of GCS files into + * a PubSub topic line by line. Empty lines are skipped. + * + *

This is useful for testing streaming + * pipelines. Note that since batch pipelines might retry chunks, this + * does _not_ guarantee exactly-once injection of file data. Some lines may + * be published multiple times. + *
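+ * When run as a standalone pipeline via {@code main}, it requires at least the two options
+ * declared below, in addition to the usual Google Cloud settings; the values here are
+ * placeholders:
+ * {@code
+ *   --input=gs://YOUR_BUCKET/path/to/files*.txt
+ *   --outputTopic=projects/YOUR_PROJECT_ID/topics/YOUR_TOPIC
+ * }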

+ */ +public class PubsubFileInjector { + + /** + * An incomplete {@code PubsubFileInjector} transform with unbound output topic. + */ + public static class Unbound { + private final String timestampLabelKey; + + Unbound() { + this.timestampLabelKey = null; + } + + Unbound(String timestampLabelKey) { + this.timestampLabelKey = timestampLabelKey; + } + + Unbound withTimestampLabelKey(String timestampLabelKey) { + return new Unbound(timestampLabelKey); + } + + public Bound publish(String outputTopic) { + return new Bound(outputTopic, timestampLabelKey); + } + } + + /** A DoFn that publishes non-empty lines to Google Cloud PubSub. */ + public static class Bound extends DoFn { + private final String outputTopic; + private final String timestampLabelKey; + public transient Pubsub pubsub; + + public Bound(String outputTopic, String timestampLabelKey) { + this.outputTopic = outputTopic; + this.timestampLabelKey = timestampLabelKey; + } + + @Override + public void startBundle(Context context) { + this.pubsub = + Transport.newPubsubClient(context.getPipelineOptions().as(DataflowPipelineOptions.class)) + .build(); + } + + @Override + public void processElement(ProcessContext c) throws IOException { + if (c.element().isEmpty()) { + return; + } + PubsubMessage pubsubMessage = new PubsubMessage(); + pubsubMessage.encodeData(c.element().getBytes()); + if (timestampLabelKey != null) { + pubsubMessage.setAttributes( + ImmutableMap.of(timestampLabelKey, Long.toString(c.timestamp().getMillis()))); + } + PublishRequest publishRequest = new PublishRequest(); + publishRequest.setMessages(Arrays.asList(pubsubMessage)); + this.pubsub.projects().topics().publish(outputTopic, publishRequest).execute(); + } + } + + /** + * Creates a {@code PubsubFileInjector} transform with the given timestamp label key. + */ + public static Unbound withTimestampLabelKey(String timestampLabelKey) { + return new Unbound(timestampLabelKey); + } + + /** + * Creates a {@code PubsubFileInjector} transform that publishes to the given output topic. + */ + public static Bound publish(String outputTopic) { + return new Unbound().publish(outputTopic); + } + + /** + * Command line parameter options. + */ + private interface PubsubFileInjectorOptions extends PipelineOptions { + @Description("GCS location of files.") + @Validation.Required + String getInput(); + void setInput(String value); + + @Description("Topic to publish on.") + @Validation.Required + String getOutputTopic(); + void setOutputTopic(String value); + } + + /** + * Sets up and starts streaming pipeline. + */ + public static void main(String[] args) { + PubsubFileInjectorOptions options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(PubsubFileInjectorOptions.class); + + Pipeline pipeline = Pipeline.create(options); + + pipeline + .apply(TextIO.Read.from(options.getInput())) + .apply(IntraBundleParallelization.of(PubsubFileInjector.publish(options.getOutputTopic())) + .withMaxParallelism(20)); + + pipeline.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/AutoComplete.java b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/AutoComplete.java new file mode 100644 index 000000000000..1bccc4ace278 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/AutoComplete.java @@ -0,0 +1,510 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.Key; +import com.google.api.services.datastore.DatastoreV1.Value; +import com.google.api.services.datastore.client.DatastoreHelper; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.common.ExampleBigQueryTableOptions; +import com.google.cloud.dataflow.examples.common.ExamplePubsubTopicOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.DatastoreIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Partition; +import com.google.cloud.dataflow.sdk.transforms.Partition.PartitionFn; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Top; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.common.base.Preconditions; + +import org.joda.time.Duration; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * An example that computes the most popular hash tags + * for every prefix, which can be used for auto-completion. + * + *

Concepts: Using the same pipeline in both streaming and batch, combiners, + * composite transforms. + * + *

To execute this pipeline using the Dataflow service in batch mode, + * specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=DataflowPipelineRunner
+ *   --inputFile=gs://path/to/input*.txt
+ * }
+ * + *

To execute this pipeline using the Dataflow service in streaming mode, + * specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=DataflowPipelineRunner
+ *   --inputFile=gs://YOUR_INPUT_DIRECTORY/*.txt
+ *   --streaming
+ * }
+ * + *
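+ * In streaming mode the top completions are recomputed over a 30-minute sliding window; the
+ * window setup in {@code main} below has roughly this shape (with {@code updatePeriod} standing
+ * in for the refresh interval configured there):
+ * {@code
+ *   Window.into(SlidingWindows.of(Duration.standardMinutes(30)).every(updatePeriod))
+ * }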

This will update the datastore every 10 seconds based on the last + * 30 minutes of data received. + */ +public class AutoComplete { + + /** + * A PTransform that takes as input a list of tokens and returns + * the most common tokens per prefix. + */ + public static class ComputeTopCompletions + extends PTransform, PCollection>>> { + private final int candidatesPerPrefix; + private final boolean recursive; + + protected ComputeTopCompletions(int candidatesPerPrefix, boolean recursive) { + this.candidatesPerPrefix = candidatesPerPrefix; + this.recursive = recursive; + } + + public static ComputeTopCompletions top(int candidatesPerPrefix, boolean recursive) { + return new ComputeTopCompletions(candidatesPerPrefix, recursive); + } + + @Override + public PCollection>> apply(PCollection input) { + PCollection candidates = input + // First count how often each token appears. + .apply(new Count.PerElement()) + + // Map the KV outputs of Count into our own CompletionCandiate class. + .apply(ParDo.named("CreateCompletionCandidates").of( + new DoFn, CompletionCandidate>() { + @Override + public void processElement(ProcessContext c) { + c.output(new CompletionCandidate(c.element().getKey(), c.element().getValue())); + } + })); + + // Compute the top via either a flat or recursive algorithm. + if (recursive) { + return candidates + .apply(new ComputeTopRecursive(candidatesPerPrefix, 1)) + .apply(Flatten.>>pCollections()); + } else { + return candidates + .apply(new ComputeTopFlat(candidatesPerPrefix, 1)); + } + } + } + + /** + * Lower latency, but more expensive. + */ + private static class ComputeTopFlat + extends PTransform, + PCollection>>> { + private final int candidatesPerPrefix; + private final int minPrefix; + + public ComputeTopFlat(int candidatesPerPrefix, int minPrefix) { + this.candidatesPerPrefix = candidatesPerPrefix; + this.minPrefix = minPrefix; + } + + @Override + public PCollection>> apply( + PCollection input) { + return input + // For each completion candidate, map it to all prefixes. + .apply(ParDo.of(new AllPrefixes(minPrefix))) + + // Find and return the top candiates for each prefix. + .apply(Top.largestPerKey(candidatesPerPrefix) + .withHotKeyFanout(new HotKeyFanout())); + } + + private static class HotKeyFanout implements SerializableFunction { + @Override + public Integer apply(String input) { + return (int) Math.pow(4, 5 - input.length()); + } + } + } + + /** + * Cheaper but higher latency. + * + *

Returns two PCollections, the first is top prefixes of size greater + * than minPrefix, and the second is top prefixes of size exactly + * minPrefix. + */ + private static class ComputeTopRecursive + extends PTransform, + PCollectionList>>> { + private final int candidatesPerPrefix; + private final int minPrefix; + + public ComputeTopRecursive(int candidatesPerPrefix, int minPrefix) { + this.candidatesPerPrefix = candidatesPerPrefix; + this.minPrefix = minPrefix; + } + + private class KeySizePartitionFn implements PartitionFn>> { + @Override + public int partitionFor(KV> elem, int numPartitions) { + return elem.getKey().length() > minPrefix ? 0 : 1; + } + } + + private static class FlattenTops + extends DoFn>, CompletionCandidate> { + @Override + public void processElement(ProcessContext c) { + for (CompletionCandidate cc : c.element().getValue()) { + c.output(cc); + } + } + } + + @Override + public PCollectionList>> apply( + PCollection input) { + if (minPrefix > 10) { + // Base case, partitioning to return the output in the expected format. + return input + .apply(new ComputeTopFlat(candidatesPerPrefix, minPrefix)) + .apply(Partition.of(2, new KeySizePartitionFn())); + } else { + // If a candidate is in the top N for prefix a...b, it must also be in the top + // N for a...bX for every X, which is typlically a much smaller set to consider. + // First, compute the top candidate for prefixes of size at least minPrefix + 1. + PCollectionList>> larger = input + .apply(new ComputeTopRecursive(candidatesPerPrefix, minPrefix + 1)); + // Consider the top candidates for each prefix of length minPrefix + 1... + PCollection>> small = + PCollectionList + .of(larger.get(1).apply(ParDo.of(new FlattenTops()))) + // ...together with those (previously excluded) candidates of length + // exactly minPrefix... + .and(input.apply(Filter.byPredicate( + new SerializableFunction() { + @Override + public Boolean apply(CompletionCandidate c) { + return c.getValue().length() == minPrefix; + } + }))) + .apply("FlattenSmall", Flatten.pCollections()) + // ...set the key to be the minPrefix-length prefix... + .apply(ParDo.of(new AllPrefixes(minPrefix, minPrefix))) + // ...and (re)apply the Top operator to all of them together. + .apply(Top.largestPerKey(candidatesPerPrefix)); + + PCollection>> flattenLarger = larger + .apply("FlattenLarge", Flatten.>>pCollections()); + + return PCollectionList.of(flattenLarger).and(small); + } + } + } + + /** + * A DoFn that keys each candidate by all its prefixes. + */ + private static class AllPrefixes + extends DoFn> { + private final int minPrefix; + private final int maxPrefix; + public AllPrefixes(int minPrefix) { + this(minPrefix, Integer.MAX_VALUE); + } + public AllPrefixes(int minPrefix, int maxPrefix) { + this.minPrefix = minPrefix; + this.maxPrefix = maxPrefix; + } + @Override + public void processElement(ProcessContext c) { + String word = c.element().value; + for (int i = minPrefix; i <= Math.min(word.length(), maxPrefix); i++) { + c.output(KV.of(word.substring(0, i), c.element())); + } + } + } + + /** + * Class used to store tag-count pairs. + */ + @DefaultCoder(AvroCoder.class) + static class CompletionCandidate implements Comparable { + private long count; + private String value; + + public CompletionCandidate(String value, long count) { + this.value = value; + this.count = count; + } + + public long getCount() { + return count; + } + + public String getValue() { + return value; + } + + // Empty constructor required for Avro decoding. 
+ public CompletionCandidate() {} + + @Override + public int compareTo(CompletionCandidate o) { + if (this.count < o.count) { + return -1; + } else if (this.count == o.count) { + return this.value.compareTo(o.value); + } else { + return 1; + } + } + + @Override + public boolean equals(Object other) { + if (other instanceof CompletionCandidate) { + CompletionCandidate that = (CompletionCandidate) other; + return this.count == that.count && this.value.equals(that.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Long.valueOf(count).hashCode() ^ value.hashCode(); + } + + @Override + public String toString() { + return "CompletionCandidate[" + value + ", " + count + "]"; + } + } + + /** + * Takes as input a set of strings, and emits each #hashtag found therein. + */ + static class ExtractHashtags extends DoFn { + @Override + public void processElement(ProcessContext c) { + Matcher m = Pattern.compile("#\\S+").matcher(c.element()); + while (m.find()) { + c.output(m.group().substring(1)); + } + } + } + + static class FormatForBigquery extends DoFn>, TableRow> { + @Override + public void processElement(ProcessContext c) { + List completions = new ArrayList<>(); + for (CompletionCandidate cc : c.element().getValue()) { + completions.add(new TableRow() + .set("count", cc.getCount()) + .set("tag", cc.getValue())); + } + TableRow row = new TableRow() + .set("prefix", c.element().getKey()) + .set("tags", completions); + c.output(row); + } + + /** + * Defines the BigQuery schema used for the output. + */ + static TableSchema getSchema() { + List tagFields = new ArrayList<>(); + tagFields.add(new TableFieldSchema().setName("count").setType("INTEGER")); + tagFields.add(new TableFieldSchema().setName("tag").setType("STRING")); + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("prefix").setType("STRING")); + fields.add(new TableFieldSchema() + .setName("tags").setType("RECORD").setMode("REPEATED").setFields(tagFields)); + return new TableSchema().setFields(fields); + } + } + + /** + * Takes as input a the top candidates per prefix, and emits an entity + * suitable for writing to Datastore. + */ + static class FormatForDatastore extends DoFn>, Entity> { + private String kind; + + public FormatForDatastore(String kind) { + this.kind = kind; + } + + @Override + public void processElement(ProcessContext c) { + Entity.Builder entityBuilder = Entity.newBuilder(); + Key key = DatastoreHelper.makeKey(kind, c.element().getKey()).build(); + + entityBuilder.setKey(key); + List candidates = new ArrayList<>(); + for (CompletionCandidate tag : c.element().getValue()) { + Entity.Builder tagEntity = Entity.newBuilder(); + tagEntity.addProperty( + DatastoreHelper.makeProperty("tag", DatastoreHelper.makeValue(tag.value))); + tagEntity.addProperty( + DatastoreHelper.makeProperty("count", DatastoreHelper.makeValue(tag.count))); + candidates.add(DatastoreHelper.makeValue(tagEntity).setIndexed(false).build()); + } + entityBuilder.addProperty( + DatastoreHelper.makeProperty("candidates", DatastoreHelper.makeValue(candidates))); + c.output(entityBuilder.build()); + } + } + + /** + * Options supported by this class. + * + *

Inherits standard Dataflow configuration options. + */ + private static interface Options extends ExamplePubsubTopicOptions, ExampleBigQueryTableOptions { + @Description("Input text file") + String getInputFile(); + void setInputFile(String value); + + @Description("Whether to use the recursive algorithm") + @Default.Boolean(true) + Boolean getRecursive(); + void setRecursive(Boolean value); + + @Description("Dataset entity kind") + @Default.String("autocomplete-demo") + String getKind(); + void setKind(String value); + + @Description("Whether output to BigQuery") + @Default.Boolean(true) + Boolean getOutputToBigQuery(); + void setOutputToBigQuery(Boolean value); + + @Description("Whether output to Datastoree") + @Default.Boolean(false) + Boolean getOutputToDatastore(); + void setOutputToDatastore(Boolean value); + } + + public static void main(String[] args) throws IOException { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + + if (options.isStreaming()) { + // In order to cancel the pipelines automatically, + // {@literal DataflowPipelineRunner} is forced to be used. + options.setRunner(DataflowPipelineRunner.class); + } + + options.setBigQuerySchema(FormatForBigquery.getSchema()); + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options); + + // We support running the same pipeline in either + // batch or windowed streaming mode. + PTransform> readSource; + WindowFn windowFn; + if (options.isStreaming()) { + Preconditions.checkArgument( + !options.getOutputToDatastore(), "DatastoreIO is not supported in streaming."); + dataflowUtils.setupPubsub(); + + readSource = PubsubIO.Read.topic(options.getPubsubTopic()); + windowFn = SlidingWindows.of(Duration.standardMinutes(30)).every(Duration.standardSeconds(5)); + } else { + readSource = TextIO.Read.from(options.getInputFile()); + windowFn = new GlobalWindows(); + } + + // Create the pipeline. + Pipeline p = Pipeline.create(options); + PCollection>> toWrite = p + .apply(readSource) + .apply(ParDo.of(new ExtractHashtags())) + .apply(Window.into(windowFn)) + .apply(ComputeTopCompletions.top(10, options.getRecursive())); + + if (options.getOutputToDatastore()) { + toWrite + .apply(ParDo.named("FormatForDatastore").of(new FormatForDatastore(options.getKind()))) + .apply(DatastoreIO.writeTo(options.getProject())); + } + if (options.getOutputToBigQuery()) { + dataflowUtils.setupBigQueryTable(); + + TableReference tableRef = new TableReference(); + tableRef.setProjectId(options.getProject()); + tableRef.setDatasetId(options.getBigQueryDataset()); + tableRef.setTableId(options.getBigQueryTable()); + + toWrite + .apply(ParDo.of(new FormatForBigquery())) + .apply(BigQueryIO.Write + .to(tableRef) + .withSchema(FormatForBigquery.getSchema()) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); + } + + // Run the pipeline. + PipelineResult result = p.run(); + + if (options.isStreaming() && !options.getInputFile().isEmpty()) { + // Inject the data into the Pub/Sub topic with a Dataflow batch pipeline. + dataflowUtils.runInjectorPipeline(options.getInputFile(), options.getPubsubTopic()); + } + + // dataflowUtils will try to cancel the pipeline and the injector before the program exists. 
+ dataflowUtils.waitToFinish(result); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/README.md b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/README.md new file mode 100644 index 000000000000..5fba15494e9b --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/README.md @@ -0,0 +1,44 @@ + +# "Complete" Examples + +This directory contains end-to-end example pipelines that perform complex data processing tasks. They include: + +

+* AutoComplete — An example that computes the most popular hash tags for every
+  prefix, which can be used for auto-completion. Demonstrates how to use the
+  same pipeline in both streaming and batch, combiners, and composite transforms.
+* StreamingWordExtract — A streaming pipeline example that reads lines of text from
+  a Cloud Pub/Sub topic, splits each line into individual words, capitalizes those
+  words, and writes the output to a BigQuery table.
+* TfIdf — An example that computes a basic TF-IDF search table for a directory or
+  Cloud Storage prefix. Demonstrates joining data, side inputs, and logging.
+* TopWikipediaSessions — An example that reads Wikipedia edit data from Cloud Storage
+  and computes the user with the longest string of edits separated by no more than an
+  hour within each month. Demonstrates using Cloud Dataflow Windowing to perform
+  time-based aggregations of data.
+* TrafficMaxLaneFlow — A streaming Cloud Dataflow example using BigQuery output in the
+  traffic sensor domain. Demonstrates the Cloud Dataflow streaming runner, sliding
+  windows, Cloud Pub/Sub topic ingestion, the use of the AvroCoder to encode a custom
+  class, and custom Combine transforms.
+* TrafficRoutes — A streaming Cloud Dataflow example using BigQuery output in the
+  traffic sensor domain. Demonstrates the Cloud Dataflow streaming runner, GroupByKey,
+  keyed state, sliding windows, and Cloud Pub/Sub topic ingestion.
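All of the pipelines listed above follow the same basic skeleton: a small `Options` interface declaring the pipeline-specific flags, a `Pipeline` created from those options, a chain of transforms, and a call to `run()`. The sketch below is a minimal, hypothetical illustration of that skeleton (the class name and the two flags are invented for illustration and do not belong to any of the examples):

```java
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.io.TextIO;
import com.google.cloud.dataflow.sdk.options.Description;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.options.Validation;

public class ExampleSkeleton {
  // Each example declares an interface like this for its pipeline-specific flags.
  private interface Options extends PipelineOptions {
    @Description("Path of the file to read from")
    String getInput();
    void setInput(String value);

    @Description("Path of the file to write to")
    @Validation.Required
    String getOutput();
    void setOutput(String value);
  }

  public static void main(String[] args) {
    // Parse and validate --input/--output plus the standard Dataflow flags.
    Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class);
    Pipeline p = Pipeline.create(options);
    p.apply(TextIO.Read.from(options.getInput()))
        // Example-specific transforms go here.
        .apply(TextIO.Write.to(options.getOutput()));
    p.run();
  }
}
```

Each example then differs only in the transforms it plugs into the middle of that chain.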
+ +See the [documentation](https://cloud.google.com/dataflow/getting-started) and the [Examples +README](../../../../../../../../../README.md) for +information about how to run these examples. diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/StreamingWordExtract.java b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/StreamingWordExtract.java new file mode 100644 index 000000000000..99c524936245 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/StreamingWordExtract.java @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.common.ExampleBigQueryTableOptions; +import com.google.cloud.dataflow.examples.common.ExamplePubsubTopicOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +import java.io.IOException; +import java.util.ArrayList; + +/** + * A streaming Dataflow Example using BigQuery output. + * + *

This pipeline example reads lines of text from a PubSub topic, splits each line + * into individual words, capitalizes those words, and writes the output to + * a BigQuery table. + * + *
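As a quick illustration of what this pipeline writes, the sketch below (a hypothetical class name; it mirrors the StringToRowConverter defined later in this file, just using Arrays.asList instead of an anonymous-initializer list) shows the single-column row and schema used for the BigQuery output:

```java
import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;

import java.util.Arrays;

class StringRowSketch {
  // One row per (uppercased) word, in a single STRING column named "string_field".
  static TableRow rowFor(String word) {
    return new TableRow().set("string_field", word);
  }

  static TableSchema schema() {
    return new TableSchema().setFields(Arrays.asList(
        new TableFieldSchema().setName("string_field").setType("STRING")));
  }
}
```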

By default, the example will run a separate pipeline to inject the data from the default + * {@literal --inputFile} to the Pub/Sub {@literal --pubsubTopic}. It will make it available for + * the streaming pipeline to process. You may override the default {@literal --inputFile} with the + * file of your choosing. You may also set {@literal --inputFile} to an empty string, which will + * disable the automatic Pub/Sub injection, and allow you to use separate tool to control the input + * to this example. + * + *

The example is configured to use the default Pub/Sub topic and the default BigQuery table + * from the example common package (there are no defaults for a general Dataflow pipeline). + * You can override them by using the {@literal --pubsubTopic}, {@literal --bigQueryDataset}, and + * {@literal --bigQueryTable} options. If the Pub/Sub topic or the BigQuery table do not exist, + * the example will try to create them. + * + *

The example will try to cancel the pipelines on the signal to terminate the process (CTRL-C) + * and then exits. + */ +public class StreamingWordExtract { + + /** A DoFn that tokenizes lines of text into individual words. */ + static class ExtractWords extends DoFn { + @Override + public void processElement(ProcessContext c) { + String[] words = c.element().split("[^a-zA-Z']+"); + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + /** A DoFn that uppercases a word. */ + static class Uppercase extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().toUpperCase()); + } + } + + /** + * Converts strings into BigQuery rows. + */ + static class StringToRowConverter extends DoFn { + /** + * In this example, put the whole string into single BigQuery field. + */ + @Override + public void processElement(ProcessContext c) { + c.output(new TableRow().set("string_field", c.element())); + } + + static TableSchema getSchema() { + return new TableSchema().setFields(new ArrayList() { + // Compose the list of TableFieldSchema from tableSchema. + { + add(new TableFieldSchema().setName("string_field").setType("STRING")); + } + }); + } + } + + /** + * Options supported by {@link StreamingWordExtract}. + * + *

Inherits standard configuration options. + */ + private interface StreamingWordExtractOptions + extends ExamplePubsubTopicOptions, ExampleBigQueryTableOptions { + @Description("Input file to inject to Pub/Sub topic") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInputFile(); + void setInputFile(String value); + } + + /** + * Sets up and starts streaming pipeline. + * + * @throws IOException if there is a problem setting up resources + */ + public static void main(String[] args) throws IOException { + StreamingWordExtractOptions options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(StreamingWordExtractOptions.class); + options.setStreaming(true); + // In order to cancel the pipelines automatically, + // {@literal DataflowPipelineRunner} is forced to be used. + options.setRunner(DataflowPipelineRunner.class); + + options.setBigQuerySchema(StringToRowConverter.getSchema()); + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options); + dataflowUtils.setup(); + + Pipeline pipeline = Pipeline.create(options); + + String tableSpec = new StringBuilder() + .append(options.getProject()).append(":") + .append(options.getBigQueryDataset()).append(".") + .append(options.getBigQueryTable()) + .toString(); + pipeline + .apply(PubsubIO.Read.topic(options.getPubsubTopic())) + .apply(ParDo.of(new ExtractWords())) + .apply(ParDo.of(new Uppercase())) + .apply(ParDo.of(new StringToRowConverter())) + .apply(BigQueryIO.Write.to(tableSpec) + .withSchema(StringToRowConverter.getSchema())); + + PipelineResult result = pipeline.run(); + + if (!options.getInputFile().isEmpty()) { + // Inject the data into the Pub/Sub topic with a Dataflow batch pipeline. + dataflowUtils.runInjectorPipeline(options.getInputFile(), options.getPubsubTopic()); + } + + // dataflowUtils will try to cancel the pipeline and the injector before the program exists. + dataflowUtils.waitToFinish(result); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TfIdf.java b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TfIdf.java new file mode 100644 index 000000000000..65ac7539876f --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TfIdf.java @@ -0,0 +1,431 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.transforms.Values; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashSet; +import java.util.Set; + +/** + * An example that computes a basic TF-IDF search table for a directory or GCS prefix. + * + *

Concepts: joining data; side inputs; logging + * + *
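Of these concepts, the side-input pattern is the one most worth spelling out: a singleton view of one PCollection is handed to every invocation of a DoFn running over another PCollection. The sketch below uses hypothetical names and is not part of this pipeline; it mirrors how ComputeTfIdf later passes the total document count into its ComputeDocFrequencies step.

```java
import com.google.cloud.dataflow.sdk.transforms.Count;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionView;

class SideInputSketch {
  static PCollection<String> tagWithTotal(PCollection<String> words) {
    // A singleton view over one PCollection (here: its element count)...
    final PCollectionView<Long> total =
        words.apply(Count.<String>globally()).apply(View.<Long>asSingleton());

    // ...made available to every invocation of a DoFn over another PCollection.
    return words.apply(ParDo.named("TagWithTotal")
        .withSideInputs(total)
        .of(new DoFn<String, String>() {
          @Override
          public void processElement(ProcessContext c) {
            c.output(c.element() + "/" + c.sideInput(total));
          }
        }));
  }
}
```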

To execute this pipeline locally, specify general pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * and a local output file or output prefix on GCS: + *
{@code
+ *   --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PREFIX]
+ * }
+ * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * and an output prefix on GCS:
+ *   --output=gs://YOUR_OUTPUT_PREFIX
+ * }
+ * + *
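For reference, the scoring applied by the ComputeTfIdf transform defined below can be written out in plain Java. In this example's basic formulation, term frequency is a word's count in a document divided by the document's word total, document frequency is the fraction of documents containing the word, and the score is tf * ln(1/df). The class and method names in this sketch are invented for illustration.

```java
public class TfIdfScoreSketch {
  // tf = wordCountInDoc / wordsInDoc, df = docsContainingWord / totalDocs,
  // score = tf * ln(1 / df), matching the ComputeTfIdf transform in this file.
  static double tfIdf(long wordCountInDoc, long wordsInDoc,
                      long docsContainingWord, long totalDocs) {
    double tf = (double) wordCountInDoc / wordsInDoc;
    double df = (double) docsContainingWord / totalDocs;
    return tf * Math.log(1 / df);
  }

  public static void main(String[] args) {
    // A word appearing 3 times in a 100-word document, found in 2 of 10 documents:
    // tf = 0.03, df = 0.2, score = 0.03 * ln(5), roughly 0.048.
    System.out.println(tfIdf(3, 100, 2, 10));
  }
}
```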

The default input is {@code gs://dataflow-samples/shakespeare/} and can be overridden with + * {@code --input}. + */ +public class TfIdf { + /** + * Options supported by {@link TfIdf}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Path to the directory or GCS prefix containing files to read from") + @Default.String("gs://dataflow-samples/shakespeare/") + String getInput(); + void setInput(String value); + + @Description("Prefix of output URI to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + /** + * Lists documents contained beneath the {@code options.input} prefix/directory. + */ + public static Set listInputDocuments(Options options) + throws URISyntaxException, IOException { + URI baseUri = new URI(options.getInput()); + + // List all documents in the directory or GCS prefix. + URI absoluteUri; + if (baseUri.getScheme() != null) { + absoluteUri = baseUri; + } else { + absoluteUri = new URI( + "file", + baseUri.getAuthority(), + baseUri.getPath(), + baseUri.getQuery(), + baseUri.getFragment()); + } + + Set uris = new HashSet<>(); + if (absoluteUri.getScheme().equals("file")) { + File directory = new File(absoluteUri); + for (String entry : directory.list()) { + File path = new File(directory, entry); + uris.add(path.toURI()); + } + } else if (absoluteUri.getScheme().equals("gs")) { + GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); + URI gcsUriGlob = new URI( + absoluteUri.getScheme(), + absoluteUri.getAuthority(), + absoluteUri.getPath() + "*", + absoluteUri.getQuery(), + absoluteUri.getFragment()); + for (GcsPath entry : gcsUtil.expand(GcsPath.fromUri(gcsUriGlob))) { + uris.add(entry.toUri()); + } + } + + return uris; + } + + /** + * Reads the documents at the provided uris and returns all lines + * from the documents tagged with which document they are from. + */ + public static class ReadDocuments + extends PTransform>> { + private Iterable uris; + + public ReadDocuments(Iterable uris) { + this.uris = uris; + } + + @Override + public Coder getDefaultOutputCoder() { + return KvCoder.of(StringDelegateCoder.of(URI.class), StringUtf8Coder.of()); + } + + @Override + public PCollection> apply(PInput input) { + Pipeline pipeline = input.getPipeline(); + + // Create one TextIO.Read transform for each document + // and add its output to a PCollectionList + PCollectionList> urisToLines = + PCollectionList.empty(pipeline); + + // TextIO.Read supports: + // - file: URIs and paths locally + // - gs: URIs on the service + for (final URI uri : uris) { + String uriString; + if (uri.getScheme().equals("file")) { + uriString = new File(uri).getPath(); + } else { + uriString = uri.toString(); + } + + PCollection> oneUriToLines = pipeline + .apply(TextIO.Read.from(uriString) + .named("TextIO.Read(" + uriString + ")")) + .apply("WithKeys(" + uriString + ")", WithKeys.of(uri)); + + urisToLines = urisToLines.and(oneUriToLines); + } + + return urisToLines.apply(Flatten.>pCollections()); + } + } + + /** + * A transform containing a basic TF-IDF pipeline. The input consists of KV objects + * where the key is the document's URI and the value is a piece + * of the document's content. The output is mapping from terms to + * scores for each document URI. + */ + public static class ComputeTfIdf + extends PTransform>, PCollection>>> { + public ComputeTfIdf() { } + + @Override + public PCollection>> apply( + PCollection> uriToContent) { + + // Compute the total number of documents, and + // prepare this singleton PCollectionView for + // use as a side input. 
+ final PCollectionView totalDocuments = + uriToContent + .apply("GetURIs", Keys.create()) + .apply("RemoveDuplicateDocs", RemoveDuplicates.create()) + .apply(Count.globally()) + .apply(View.asSingleton()); + + // Create a collection of pairs mapping a URI to each + // of the words in the document associated with that that URI. + PCollection> uriToWords = uriToContent + .apply(ParDo.named("SplitWords").of( + new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey(); + String line = c.element().getValue(); + for (String word : line.split("\\W+")) { + // Log INFO messages when the word “love” is found. + if (word.toLowerCase().equals("love")) { + LOG.info("Found {}", word.toLowerCase()); + } + + if (!word.isEmpty()) { + c.output(KV.of(uri, word.toLowerCase())); + } + } + } + })); + + // Compute a mapping from each word to the total + // number of documents in which it appears. + PCollection> wordToDocCount = uriToWords + .apply("RemoveDuplicateWords", RemoveDuplicates.>create()) + .apply(Values.create()) + .apply("CountDocs", Count.perElement()); + + // Compute a mapping from each URI to the total + // number of words in the document associated with that URI. + PCollection> uriToWordTotal = uriToWords + .apply("GetURIs2", Keys.create()) + .apply("CountWords", Count.perElement()); + + // Count, for each (URI, word) pair, the number of + // occurrences of that word in the document associated + // with the URI. + PCollection, Long>> uriAndWordToCount = uriToWords + .apply("CountWordDocPairs", Count.>perElement()); + + // Adjust the above collection to a mapping from + // (URI, word) pairs to counts into an isomorphic mapping + // from URI to (word, count) pairs, to prepare for a join + // by the URI key. + PCollection>> uriToWordAndCount = uriAndWordToCount + .apply(ParDo.named("ShiftKeys").of( + new DoFn, Long>, KV>>() { + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey().getKey(); + String word = c.element().getKey().getValue(); + Long occurrences = c.element().getValue(); + c.output(KV.of(uri, KV.of(word, occurrences))); + } + })); + + // Prepare to join the mapping of URI to (word, count) pairs with + // the mapping of URI to total word counts, by associating + // each of the input PCollection> with + // a tuple tag. Each input must have the same key type, URI + // in this case. The type parameter of the tuple tag matches + // the types of the values for each collection. + final TupleTag wordTotalsTag = new TupleTag(); + final TupleTag> wordCountsTag = new TupleTag>(); + KeyedPCollectionTuple coGbkInput = KeyedPCollectionTuple + .of(wordTotalsTag, uriToWordTotal) + .and(wordCountsTag, uriToWordAndCount); + + // Perform a CoGroupByKey (a sort of pre-join) on the prepared + // inputs. This yields a mapping from URI to a CoGbkResult + // (CoGroupByKey Result). The CoGbkResult is a mapping + // from the above tuple tags to the values in each input + // associated with a particular URI. In this case, each + // KV group a URI with the total number of + // words in that document as well as all the (word, count) + // pairs for particular words. + PCollection> uriToWordAndCountAndTotal = coGbkInput + .apply("CoGroupByUri", CoGroupByKey.create()); + + // Compute a mapping from each word to a (URI, term frequency) + // pair for each URI. A word's term frequency for a document + // is simply the number of times that word occurs in the document + // divided by the total number of words in the document. 
+ PCollection>> wordToUriAndTf = uriToWordAndCountAndTotal + .apply(ParDo.named("ComputeTermFrequencies").of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey(); + Long wordTotal = c.element().getValue().getOnly(wordTotalsTag); + + for (KV wordAndCount + : c.element().getValue().getAll(wordCountsTag)) { + String word = wordAndCount.getKey(); + Long wordCount = wordAndCount.getValue(); + Double termFrequency = wordCount.doubleValue() / wordTotal.doubleValue(); + c.output(KV.of(word, KV.of(uri, termFrequency))); + } + } + })); + + // Compute a mapping from each word to its document frequency. + // A word's document frequency in a corpus is the number of + // documents in which the word appears divided by the total + // number of documents in the corpus. Note how the total number of + // documents is passed as a side input; the same value is + // presented to each invocation of the DoFn. + PCollection> wordToDf = wordToDocCount + .apply(ParDo + .named("ComputeDocFrequencies") + .withSideInputs(totalDocuments) + .of(new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + String word = c.element().getKey(); + Long documentCount = c.element().getValue(); + Long documentTotal = c.sideInput(totalDocuments); + Double documentFrequency = documentCount.doubleValue() + / documentTotal.doubleValue(); + + c.output(KV.of(word, documentFrequency)); + } + })); + + // Join the term frequency and document frequency + // collections, each keyed on the word. + final TupleTag> tfTag = new TupleTag>(); + final TupleTag dfTag = new TupleTag(); + PCollection> wordToUriAndTfAndDf = KeyedPCollectionTuple + .of(tfTag, wordToUriAndTf) + .and(dfTag, wordToDf) + .apply(CoGroupByKey.create()); + + // Compute a mapping from each word to a (URI, TF-IDF) score + // for each URI. There are a variety of definitions of TF-IDF + // ("term frequency - inverse document frequency") score; + // here we use a basic version that is the term frequency + // divided by the log of the document frequency. + PCollection>> wordToUriAndTfIdf = wordToUriAndTfAndDf + .apply(ParDo.named("ComputeTfIdf").of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + String word = c.element().getKey(); + Double df = c.element().getValue().getOnly(dfTag); + + for (KV uriAndTf : c.element().getValue().getAll(tfTag)) { + URI uri = uriAndTf.getKey(); + Double tf = uriAndTf.getValue(); + Double tfIdf = tf * Math.log(1 / df); + c.output(KV.of(word, KV.of(uri, tfIdf))); + } + } + })); + + return wordToUriAndTfIdf; + } + + // Instantiate Logger. + // It is suggested that the user specify the class name of the containing class + // (in this case ComputeTfIdf). + private static final Logger LOG = LoggerFactory.getLogger(ComputeTfIdf.class); + } + + /** + * A {@link PTransform} to write, in CSV format, a mapping from term and URI + * to score. 
+ */ + public static class WriteTfIdf + extends PTransform>>, PDone> { + private String output; + + public WriteTfIdf(String output) { + this.output = output; + } + + @Override + public PDone apply(PCollection>> wordToUriAndTfIdf) { + return wordToUriAndTfIdf + .apply(ParDo.named("Format").of(new DoFn>, String>() { + @Override + public void processElement(ProcessContext c) { + c.output(String.format("%s,\t%s,\t%f", + c.element().getKey(), + c.element().getValue().getKey(), + c.element().getValue().getValue())); + } + })) + .apply(TextIO.Write + .to(output) + .withSuffix(".csv")); + } + } + + public static void main(String[] args) throws Exception { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline pipeline = Pipeline.create(options); + pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); + + pipeline + .apply(new ReadDocuments(listInputDocuments(options))) + .apply(new ComputeTfIdf()) + .apply(new WriteTfIdf(options.getOutput())); + + pipeline.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TopWikipediaSessions.java b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TopWikipediaSessions.java new file mode 100644 index 000000000000..c57a5f2aa602 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TopWikipediaSessions.java @@ -0,0 +1,223 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableComparator; +import com.google.cloud.dataflow.sdk.transforms.Top; +import com.google.cloud.dataflow.sdk.transforms.windowing.CalendarWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.List; + +/** + * An example that reads Wikipedia edit data from Cloud Storage and computes the user with + * the longest string of edits separated by no more than an hour within each month. + * + *

Concepts: Using Windowing to perform time-based aggregations of data. + * + *

It is not recommended to execute this pipeline locally, given the size of the default input + * data. + * + *
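The windowing at the heart of this example combines two WindowFns: per-user session windows with a one-hour gap, followed by one-month calendar windows over the per-session counts. The sketch below (hypothetical class and method names) shows just that combination, mirroring the ComputeSessions and TopPerMonth transforms defined later in this file.

```java
import com.google.cloud.dataflow.sdk.transforms.Count;
import com.google.cloud.dataflow.sdk.transforms.windowing.CalendarWindows;
import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions;
import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;

import org.joda.time.Duration;

class SessionWindowingSketch {
  // Count edits per user within each session, where a session ends after a
  // one-hour gap with no edits from that user.
  static PCollection<KV<String, Long>> countEditsPerSession(PCollection<String> userNames) {
    return userNames
        .apply(Window.<String>into(Sessions.withGapDuration(Duration.standardHours(1))))
        .apply(Count.<String>perElement());
  }

  // Re-window the per-session counts into one-month calendar windows so a
  // subsequent Top transform can pick the longest session per month.
  static PCollection<KV<String, Long>> windowByMonth(PCollection<KV<String, Long>> sessions) {
    return sessions.apply(Window.<KV<String, Long>>into(CalendarWindows.months(1)));
  }
}
```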

To execute this pipeline using the Dataflow service, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * 
+ * and an output prefix on GCS: + *
{@code
+ *   --output=gs://YOUR_OUTPUT_PREFIX
+ * }
+ * + *

The default input is {@code gs://dataflow-samples/wikipedia_edits/*.json} and can be + * overridden with {@code --input}. + * + *

The input for this example is large enough that it's a good place to enable (experimental) + * autoscaling: + *

{@code
+ *   --autoscalingAlgorithm=BASIC
+ *   --maxNumWorkers=20
+ * }
+ * 
+ * This will automatically scale the number of workers up over time until the job completes. + */ +public class TopWikipediaSessions { + private static final String EXPORTED_WIKI_TABLE = "gs://dataflow-samples/wikipedia_edits/*.json"; + + /** + * Extracts user and timestamp from a TableRow representing a Wikipedia edit. + */ + static class ExtractUserAndTimestamp extends DoFn { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + int timestamp = (Integer) row.get("timestamp"); + String userName = (String) row.get("contributor_username"); + if (userName != null) { + // Sets the implicit timestamp field to be used in windowing. + c.outputWithTimestamp(userName, new Instant(timestamp * 1000L)); + } + } + } + + /** + * Computes the number of edits in each user session. A session is defined as + * a string of edits where each is separated from the next by less than an hour. + */ + static class ComputeSessions + extends PTransform, PCollection>> { + @Override + public PCollection> apply(PCollection actions) { + return actions + .apply(Window.into(Sessions.withGapDuration(Duration.standardHours(1)))) + + .apply(Count.perElement()); + } + } + + /** + * Computes the longest session ending in each month. + */ + private static class TopPerMonth + extends PTransform>, PCollection>>> { + @Override + public PCollection>> apply(PCollection> sessions) { + return sessions + .apply(Window.>into(CalendarWindows.months(1))) + + .apply(Top.of(1, new SerializableComparator>() { + @Override + public int compare(KV o1, KV o2) { + return Long.compare(o1.getValue(), o2.getValue()); + } + }).withoutDefaults()); + } + } + + static class SessionsToStringsDoFn extends DoFn, KV> + implements RequiresWindowAccess { + + @Override + public void processElement(ProcessContext c) { + c.output(KV.of( + c.element().getKey() + " : " + c.window(), c.element().getValue())); + } + } + + static class FormatOutputDoFn extends DoFn>, String> + implements RequiresWindowAccess { + @Override + public void processElement(ProcessContext c) { + for (KV item : c.element()) { + String session = item.getKey(); + long count = item.getValue(); + c.output(session + " : " + count + " : " + ((IntervalWindow) c.window()).start()); + } + } + } + + static class ComputeTopSessions extends PTransform, PCollection> { + + private final double samplingThreshold; + + public ComputeTopSessions(double samplingThreshold) { + this.samplingThreshold = samplingThreshold; + } + + @Override + public PCollection apply(PCollection input) { + return input + .apply(ParDo.of(new ExtractUserAndTimestamp())) + + .apply(ParDo.named("SampleUsers").of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (Math.abs(c.element().hashCode()) <= Integer.MAX_VALUE * samplingThreshold) { + c.output(c.element()); + } + } + })) + + .apply(new ComputeSessions()) + + .apply(ParDo.named("SessionsToStrings").of(new SessionsToStringsDoFn())) + .apply(new TopPerMonth()) + .apply(ParDo.named("FormatOutput").of(new FormatOutputDoFn())); + } + } + + /** + * Options supported by this class. + * + *

Inherits standard Dataflow configuration options. + */ + private static interface Options extends PipelineOptions { + @Description( + "Input specified as a GCS path containing a BigQuery table exported as json") + @Default.String(EXPORTED_WIKI_TABLE) + String getInput(); + void setInput(String value); + + @Description("File to output results to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(Options.class); + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + + Pipeline p = Pipeline.create(dataflowOptions); + + double samplingThreshold = 0.1; + + p.apply(TextIO.Read + .from(options.getInput()) + .withCoder(TableRowJsonCoder.of())) + .apply(new ComputeTopSessions(samplingThreshold)) + .apply(TextIO.Write.named("Write").withoutSharding().to(options.getOutput())); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TrafficMaxLaneFlow.java b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TrafficMaxLaneFlow.java new file mode 100644 index 000000000000..2d5425208bdb --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TrafficMaxLaneFlow.java @@ -0,0 +1,425 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.examples.common.DataflowExampleOptions; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.common.ExampleBigQueryTableOptions; +import com.google.cloud.dataflow.examples.common.ExamplePubsubTopicAndSubscriptionOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Strings; + +import org.apache.avro.reflect.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * A Dataflow Example that runs in both batch and streaming modes with traffic sensor data. + * You can configure the running mode by setting {@literal --streaming} to true or false. + * + *

Concepts: The batch and streaming runners, sliding windows, Google Cloud Pub/Sub + * topic injection, use of the AvroCoder to encode a custom class, and custom Combine transforms. + * + *

This example analyzes traffic sensor data using SlidingWindows. For each window, + * it finds the lane that had the highest flow recorded, for each sensor station. It writes + * those max values along with auxiliary info to a BigQuery table. + * + *
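The sliding-window assignment this description refers to looks like the sketch below (hypothetical names; the 60-minute duration and 5-minute period match the WINDOW_DURATION and WINDOW_SLIDE_EVERY defaults defined further down). Each sensor reading is placed into every 60-minute window that contains its timestamp, so with a 5-minute period each element belongs to 12 overlapping windows.

```java
import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows;
import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
import com.google.cloud.dataflow.sdk.values.PCollection;

import org.joda.time.Duration;

class SlidingWindowSketch {
  // 60-minute windows that advance every 5 minutes; downstream aggregations
  // (such as the per-station max-flow Combine in this file) then run per window.
  static PCollection<String> applySlidingWindows(PCollection<String> readings) {
    return readings.apply(
        Window.<String>into(
            SlidingWindows.of(Duration.standardMinutes(60))
                .every(Duration.standardMinutes(5))));
  }
}
```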

In batch mode, the pipeline reads traffic sensor data from {@literal --inputFile}. + * + *

In streaming mode, the pipeline reads the data from a Pub/Sub topic. + * By default, the example will run a separate pipeline to inject the data from the default + * {@literal --inputFile} to the Pub/Sub {@literal --pubsubTopic}. It will make it available for + * the streaming pipeline to process. You may override the default {@literal --inputFile} with the + * file of your choosing. You may also set {@literal --inputFile} to an empty string, which will + * disable the automatic Pub/Sub injection, and allow you to use separate tool to control the input + * to this example. An example code, which publishes traffic sensor data to a Pub/Sub topic, + * is provided in + * . + * + *

The example is configured to use the default Pub/Sub topic and the default BigQuery table + * from the example common package (there are no defaults for a general Dataflow pipeline). + * You can override them by using the {@literal --pubsubTopic}, {@literal --bigQueryDataset}, and + * {@literal --bigQueryTable} options. If the Pub/Sub topic or the BigQuery table do not exist, + * the example will try to create them. + * + *

The example will try to cancel the pipelines on the signal to terminate the process (CTRL-C) + * and then exits. + */ +public class TrafficMaxLaneFlow { + + private static final String PUBSUB_TIMESTAMP_LABEL_KEY = "timestamp_ms"; + private static final Integer VALID_INPUTS = 4999; + + static final int WINDOW_DURATION = 60; // Default sliding window duration in minutes + static final int WINDOW_SLIDE_EVERY = 5; // Default window 'slide every' setting in minutes + + /** + * This class holds information about each lane in a station reading, along with some general + * information from the reading. + */ + @DefaultCoder(AvroCoder.class) + static class LaneInfo { + @Nullable String stationId; + @Nullable String lane; + @Nullable String direction; + @Nullable String freeway; + @Nullable String recordedTimestamp; + @Nullable Integer laneFlow; + @Nullable Integer totalFlow; + @Nullable Double laneAO; + @Nullable Double laneAS; + + public LaneInfo() {} + + public LaneInfo(String stationId, String lane, String direction, String freeway, + String timestamp, Integer laneFlow, Double laneAO, + Double laneAS, Integer totalFlow) { + this.stationId = stationId; + this.lane = lane; + this.direction = direction; + this.freeway = freeway; + this.recordedTimestamp = timestamp; + this.laneFlow = laneFlow; + this.laneAO = laneAO; + this.laneAS = laneAS; + this.totalFlow = totalFlow; + } + + public String getStationId() { + return this.stationId; + } + public String getLane() { + return this.lane; + } + public String getDirection() { + return this.direction; + } + public String getFreeway() { + return this.freeway; + } + public String getRecordedTimestamp() { + return this.recordedTimestamp; + } + public Integer getLaneFlow() { + return this.laneFlow; + } + public Double getLaneAO() { + return this.laneAO; + } + public Double getLaneAS() { + return this.laneAS; + } + public Integer getTotalFlow() { + return this.totalFlow; + } + } + + /** + * Extract the timestamp field from the input string, and use it as the element timestamp. + */ + static class ExtractTimestamps extends DoFn { + private static final DateTimeFormatter dateTimeFormat = + DateTimeFormat.forPattern("MM/dd/yyyy HH:mm:ss"); + + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + String[] items = c.element().split(","); + if (items.length > 0) { + try { + String timestamp = items[0]; + c.outputWithTimestamp(c.element(), new Instant(dateTimeFormat.parseMillis(timestamp))); + } catch (IllegalArgumentException e) { + // Skip the invalid input. + } + } + } + } + + /** + * Extract flow information for each of the 8 lanes in a reading, and output as separate tuples. + * This will let us determine which lane has the max flow for that station over the span of the + * window, and output not only the max flow from that calculation, but other associated + * information. The number of lanes for which data is present depends upon which freeway the data + * point comes from. + */ + static class ExtractFlowInfoFn extends DoFn> { + + @Override + public void processElement(ProcessContext c) { + String[] items = c.element().split(","); + if (items.length < 48) { + // Skip the invalid input. + return; + } + // extract the sensor information for the lanes from the input string fields. 
+ String timestamp = items[0]; + String stationId = items[1]; + String freeway = items[2]; + String direction = items[3]; + Integer totalFlow = tryIntParse(items[7]); + for (int i = 1; i <= 8; ++i) { + Integer laneFlow = tryIntParse(items[6 + 5 * i]); + Double laneAvgOccupancy = tryDoubleParse(items[7 + 5 * i]); + Double laneAvgSpeed = tryDoubleParse(items[8 + 5 * i]); + if (laneFlow == null || laneAvgOccupancy == null || laneAvgSpeed == null) { + return; + } + LaneInfo laneInfo = new LaneInfo(stationId, "lane" + i, direction, freeway, timestamp, + laneFlow, laneAvgOccupancy, laneAvgSpeed, totalFlow); + c.output(KV.of(stationId, laneInfo)); + } + } + } + + /** + * A custom 'combine function' used with the Combine.perKey transform. Used to find the max lane + * flow over all the data points in the Window. Extracts the lane flow from the input string and + * determines whether it's the max seen so far. We're using a custom combiner instead of the Max + * transform because we want to retain the additional information we've associated with the flow + * value. + */ + public static class MaxFlow implements SerializableFunction, LaneInfo> { + @Override + public LaneInfo apply(Iterable input) { + Integer max = 0; + LaneInfo maxInfo = new LaneInfo(); + for (LaneInfo item : input) { + Integer flow = item.getLaneFlow(); + if (flow != null && (flow >= max)) { + max = flow; + maxInfo = item; + } + } + return maxInfo; + } + } + + /** + * Format the results of the Max Lane flow calculation to a TableRow, to save to BigQuery. + * Add the timestamp from the window context. + */ + static class FormatMaxesFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + + LaneInfo laneInfo = c.element().getValue(); + TableRow row = new TableRow() + .set("station_id", c.element().getKey()) + .set("direction", laneInfo.getDirection()) + .set("freeway", laneInfo.getFreeway()) + .set("lane_max_flow", laneInfo.getLaneFlow()) + .set("lane", laneInfo.getLane()) + .set("avg_occ", laneInfo.getLaneAO()) + .set("avg_speed", laneInfo.getLaneAS()) + .set("total_flow", laneInfo.getTotalFlow()) + .set("recorded_timestamp", laneInfo.getRecordedTimestamp()) + .set("window_timestamp", c.timestamp().toString()); + c.output(row); + } + + /** Defines the BigQuery schema used for the output. */ + static TableSchema getSchema() { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("station_id").setType("STRING")); + fields.add(new TableFieldSchema().setName("direction").setType("STRING")); + fields.add(new TableFieldSchema().setName("freeway").setType("STRING")); + fields.add(new TableFieldSchema().setName("lane_max_flow").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("lane").setType("STRING")); + fields.add(new TableFieldSchema().setName("avg_occ").setType("FLOAT")); + fields.add(new TableFieldSchema().setName("avg_speed").setType("FLOAT")); + fields.add(new TableFieldSchema().setName("total_flow").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("window_timestamp").setType("TIMESTAMP")); + fields.add(new TableFieldSchema().setName("recorded_timestamp").setType("STRING")); + TableSchema schema = new TableSchema().setFields(fields); + return schema; + } + } + + /** + * This PTransform extracts lane info, calculates the max lane flow found for a given station (for + * the current Window) using a custom 'combiner', and formats the results for BigQuery. 
+ */ + static class MaxLaneFlow + extends PTransform>, PCollection> { + @Override + public PCollection apply(PCollection> flowInfo) { + // stationId, LaneInfo => stationId + max lane flow info + PCollection> flowMaxes = + flowInfo.apply(Combine.perKey( + new MaxFlow())); + + // ... => row... + PCollection results = flowMaxes.apply( + ParDo.of(new FormatMaxesFn())); + + return results; + } + } + + static class ReadFileAndExtractTimestamps extends PTransform> { + private final String inputFile; + + public ReadFileAndExtractTimestamps(String inputFile) { + this.inputFile = inputFile; + } + + @Override + public PCollection apply(PBegin begin) { + return begin + .apply(TextIO.Read.from(inputFile)) + .apply(ParDo.of(new ExtractTimestamps())); + } + } + + /** + * Options supported by {@link TrafficMaxLaneFlow}. + * + *

Inherits standard configuration options. + */ + private interface TrafficMaxLaneFlowOptions extends DataflowExampleOptions, + ExamplePubsubTopicAndSubscriptionOptions, ExampleBigQueryTableOptions { + @Description("Input file to inject to Pub/Sub topic") + @Default.String("gs://dataflow-samples/traffic_sensor/" + + "Freeways-5Minaa2010-01-01_to_2010-02-15_test2.csv") + String getInputFile(); + void setInputFile(String value); + + @Description("Numeric value of sliding window duration, in minutes") + @Default.Integer(WINDOW_DURATION) + Integer getWindowDuration(); + void setWindowDuration(Integer value); + + @Description("Numeric value of window 'slide every' setting, in minutes") + @Default.Integer(WINDOW_SLIDE_EVERY) + Integer getWindowSlideEvery(); + void setWindowSlideEvery(Integer value); + + @Description("Whether to run the pipeline with unbounded input") + @Default.Boolean(false) + boolean isUnbounded(); + void setUnbounded(boolean value); + } + + /** + * Sets up and starts streaming pipeline. + * + * @throws IOException if there is a problem setting up resources + */ + public static void main(String[] args) throws IOException { + TrafficMaxLaneFlowOptions options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(TrafficMaxLaneFlowOptions.class); + options.setBigQuerySchema(FormatMaxesFn.getSchema()); + // Using DataflowExampleUtils to set up required resources. + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options, options.isUnbounded()); + + Pipeline pipeline = Pipeline.create(options); + TableReference tableRef = new TableReference(); + tableRef.setProjectId(options.getProject()); + tableRef.setDatasetId(options.getBigQueryDataset()); + tableRef.setTableId(options.getBigQueryTable()); + + PCollection input; + if (options.isUnbounded()) { + // Read unbounded PubSubIO. + input = pipeline.apply(PubsubIO.Read + .timestampLabel(PUBSUB_TIMESTAMP_LABEL_KEY) + .subscription(options.getPubsubSubscription())); + } else { + // Read bounded PubSubIO. + input = pipeline.apply(PubsubIO.Read + .timestampLabel(PUBSUB_TIMESTAMP_LABEL_KEY) + .subscription(options.getPubsubSubscription()).maxNumRecords(VALID_INPUTS)); + + // To read bounded TextIO files, use: + // input = pipeline.apply(new ReadFileAndExtractTimestamps(options.getInputFile())); + } + input + // row... => ... + .apply(ParDo.of(new ExtractFlowInfoFn())) + // map the incoming data stream into sliding windows. The default window duration values + // work well if you're running the accompanying Pub/Sub generator script with the + // --replay flag, which simulates pauses in the sensor data publication. You may want to + // adjust them otherwise. + .apply(Window.>into(SlidingWindows.of( + Duration.standardMinutes(options.getWindowDuration())). + every(Duration.standardMinutes(options.getWindowSlideEvery())))) + .apply(new MaxLaneFlow()) + .apply(BigQueryIO.Write.to(tableRef) + .withSchema(FormatMaxesFn.getSchema())); + + // Inject the data into the Pub/Sub topic with a Dataflow batch pipeline. + if (!Strings.isNullOrEmpty(options.getInputFile()) + && !Strings.isNullOrEmpty(options.getPubsubTopic())) { + dataflowUtils.runInjectorPipeline( + new ReadFileAndExtractTimestamps(options.getInputFile()), + options.getPubsubTopic(), + PUBSUB_TIMESTAMP_LABEL_KEY); + } + + // Run the pipeline. + PipelineResult result = pipeline.run(); + + // dataflowUtils will try to cancel the pipeline and the injector before the program exists. 
+ dataflowUtils.waitToFinish(result); + } + + private static Integer tryIntParse(String number) { + try { + return Integer.parseInt(number); + } catch (NumberFormatException e) { + return null; + } + } + + private static Double tryDoubleParse(String number) { + try { + return Double.parseDouble(number); + } catch (NumberFormatException e) { + return null; + } + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TrafficRoutes.java b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TrafficRoutes.java new file mode 100644 index 000000000000..e3e88c22da57 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/complete/TrafficRoutes.java @@ -0,0 +1,459 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.examples.common.DataflowExampleOptions; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.common.ExampleBigQueryTableOptions; +import com.google.cloud.dataflow.examples.common.ExamplePubsubTopicAndSubscriptionOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Strings; +import com.google.common.collect.Lists; + +import org.apache.avro.reflect.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Hashtable; +import java.util.List; +import java.util.Map; + +/** + * A Dataflow 
Example that runs in both batch and streaming modes with traffic sensor data. + * You can configure the running mode by setting {@literal --streaming} to true or false. + * + *

Concepts: The batch and streaming runners, GroupByKey, sliding windows, and + * Google Cloud Pub/Sub topic injection. + * + *

This example analyzes traffic sensor data using SlidingWindows. For each window, + * it calculates the average speed over the window for some small set of predefined 'routes', + * and looks for 'slowdowns' in those routes. It writes its results to a BigQuery table. + * + *
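Stripped of the Dataflow plumbing, the per-route statistics that the GatherStats transform below computes for each window reduce to an average speed plus a coarse slowdown flag. The sketch below (a hypothetical class, assuming null speeds have already been filtered out as GatherStats does) shows those two calculations.

```java
import java.util.List;

class RouteStatsSketch {
  // Mean speed over the route's readings in the window.
  static double averageSpeed(List<Double> speeds) {
    double sum = 0;
    for (double speed : speeds) {
      sum += speed;
    }
    return speeds.isEmpty() ? 0.0 : sum / speeds.size();
  }

  // Matches GatherStats: flag a slowdown when at least twice as many consecutive
  // per-station readings decreased as increased within the window.
  static boolean isSlowdown(int slowdowns, int speedups) {
    return slowdowns >= 2 * speedups;
  }
}
```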

In batch mode, the pipeline reads traffic sensor data from {@literal --inputFile}. + * + *

In streaming mode, the pipeline reads the data from a Pub/Sub topic. + * By default, the example will run a separate pipeline to inject the data from the default + * {@literal --inputFile} to the Pub/Sub {@literal --pubsubTopic}. It will make it available for + * the streaming pipeline to process. You may override the default {@literal --inputFile} with the + * file of your choosing. You may also set {@literal --inputFile} to an empty string, which will + * disable the automatic Pub/Sub injection, and allow you to use separate tool to control the input + * to this example. An example code, which publishes traffic sensor data to a Pub/Sub topic, + * is provided in + * . + * + *

The example is configured to use the default Pub/Sub topic and the default BigQuery table + * from the example common package (there are no defaults for a general Dataflow pipeline). + * You can override them by using the {@literal --pubsubTopic}, {@literal --bigQueryDataset}, and + * {@literal --bigQueryTable} options. If the Pub/Sub topic or the BigQuery table do not exist, + * the example will try to create them. + * + *

The example will try to cancel the pipelines on the signal to terminate the process (CTRL-C) + * and then exits. + */ + +public class TrafficRoutes { + + private static final String PUBSUB_TIMESTAMP_LABEL_KEY = "timestamp_ms"; + private static final Integer VALID_INPUTS = 4999; + + // Instantiate some small predefined San Diego routes to analyze + static Map sdStations = buildStationInfo(); + static final int WINDOW_DURATION = 3; // Default sliding window duration in minutes + static final int WINDOW_SLIDE_EVERY = 1; // Default window 'slide every' setting in minutes + + /** + * This class holds information about a station reading's average speed. + */ + @DefaultCoder(AvroCoder.class) + static class StationSpeed implements Comparable { + @Nullable String stationId; + @Nullable Double avgSpeed; + @Nullable Long timestamp; + + public StationSpeed() {} + + public StationSpeed(String stationId, Double avgSpeed, Long timestamp) { + this.stationId = stationId; + this.avgSpeed = avgSpeed; + this.timestamp = timestamp; + } + + public String getStationId() { + return this.stationId; + } + public Double getAvgSpeed() { + return this.avgSpeed; + } + + @Override + public int compareTo(StationSpeed other) { + return Long.compare(this.timestamp, other.timestamp); + } + } + + /** + * This class holds information about a route's speed/slowdown. + */ + @DefaultCoder(AvroCoder.class) + static class RouteInfo { + @Nullable String route; + @Nullable Double avgSpeed; + @Nullable Boolean slowdownEvent; + + + public RouteInfo() {} + + public RouteInfo(String route, Double avgSpeed, Boolean slowdownEvent) { + this.route = route; + this.avgSpeed = avgSpeed; + this.slowdownEvent = slowdownEvent; + } + + public String getRoute() { + return this.route; + } + public Double getAvgSpeed() { + return this.avgSpeed; + } + public Boolean getSlowdownEvent() { + return this.slowdownEvent; + } + } + + /** + * Extract the timestamp field from the input string, and use it as the element timestamp. + */ + static class ExtractTimestamps extends DoFn { + private static final DateTimeFormatter dateTimeFormat = + DateTimeFormat.forPattern("MM/dd/yyyy HH:mm:ss"); + + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + String[] items = c.element().split(","); + String timestamp = tryParseTimestamp(items); + if (timestamp != null) { + try { + c.outputWithTimestamp(c.element(), new Instant(dateTimeFormat.parseMillis(timestamp))); + } catch (IllegalArgumentException e) { + // Skip the invalid input. + } + } + } + } + + /** + * Filter out readings for the stations along predefined 'routes', and output + * (station, speed info) keyed on route. + */ + static class ExtractStationSpeedFn extends DoFn> { + + @Override + public void processElement(ProcessContext c) { + String[] items = c.element().split(","); + String stationType = tryParseStationType(items); + // For this analysis, use only 'main line' station types + if (stationType != null && stationType.equals("ML")) { + Double avgSpeed = tryParseAvgSpeed(items); + String stationId = tryParseStationId(items); + // For this simple example, filter out everything but some hardwired routes. + if (avgSpeed != null && stationId != null && sdStations.containsKey(stationId)) { + StationSpeed stationSpeed = + new StationSpeed(stationId, avgSpeed, c.timestamp().getMillis()); + // The tuple key is the 'route' name stored in the 'sdStations' hash. 
+ KV outputValue = KV.of(sdStations.get(stationId), stationSpeed); + c.output(outputValue); + } + } + } + } + + /** + * For a given route, track average speed for the window. Calculate whether + * traffic is currently slowing down, via a predefined threshold. If a supermajority of + * speeds in this sliding window are less than the previous reading we call this a 'slowdown'. + * Note: these calculations are for example purposes only, and are unrealistic and oversimplified. + */ + static class GatherStats + extends DoFn>, KV> { + @Override + public void processElement(ProcessContext c) throws IOException { + String route = c.element().getKey(); + double speedSum = 0.0; + int speedCount = 0; + int speedups = 0; + int slowdowns = 0; + List infoList = Lists.newArrayList(c.element().getValue()); + // StationSpeeds sort by embedded timestamp. + Collections.sort(infoList); + Map prevSpeeds = new HashMap<>(); + // For all stations in the route, sum (non-null) speeds. Keep a count of the non-null speeds. + for (StationSpeed item : infoList) { + Double speed = item.getAvgSpeed(); + if (speed != null) { + speedSum += speed; + speedCount++; + Double lastSpeed = prevSpeeds.get(item.getStationId()); + if (lastSpeed != null) { + if (lastSpeed < speed) { + speedups += 1; + } else { + slowdowns += 1; + } + } + prevSpeeds.put(item.getStationId(), speed); + } + } + if (speedCount == 0) { + // No average to compute. + return; + } + double speedAvg = speedSum / speedCount; + boolean slowdownEvent = slowdowns >= 2 * speedups; + RouteInfo routeInfo = new RouteInfo(route, speedAvg, slowdownEvent); + c.output(KV.of(route, routeInfo)); + } + } + + /** + * Format the results of the slowdown calculations to a TableRow, to save to BigQuery. + */ + static class FormatStatsFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + RouteInfo routeInfo = c.element().getValue(); + TableRow row = new TableRow() + .set("avg_speed", routeInfo.getAvgSpeed()) + .set("slowdown_event", routeInfo.getSlowdownEvent()) + .set("route", c.element().getKey()) + .set("window_timestamp", c.timestamp().toString()); + c.output(row); + } + + /** + * Defines the BigQuery schema used for the output. + */ + static TableSchema getSchema() { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("route").setType("STRING")); + fields.add(new TableFieldSchema().setName("avg_speed").setType("FLOAT")); + fields.add(new TableFieldSchema().setName("slowdown_event").setType("BOOLEAN")); + fields.add(new TableFieldSchema().setName("window_timestamp").setType("TIMESTAMP")); + TableSchema schema = new TableSchema().setFields(fields); + return schema; + } + } + + /** + * This PTransform extracts speed info from traffic station readings. + * It groups the readings by 'route' and analyzes traffic slowdown for that route. + * Lastly, it formats the results for BigQuery. + */ + static class TrackSpeed extends + PTransform>, PCollection> { + @Override + public PCollection apply(PCollection> stationSpeed) { + // Apply a GroupByKey transform to collect a list of all station + // readings for a given route. + PCollection>> timeGroup = stationSpeed.apply( + GroupByKey.create()); + + // Analyze 'slowdown' over the route readings. 
+ PCollection> stats = timeGroup.apply(ParDo.of(new GatherStats())); + + // Format the results for writing to BigQuery + PCollection results = stats.apply( + ParDo.of(new FormatStatsFn())); + + return results; + } + } + + static class ReadFileAndExtractTimestamps extends PTransform> { + private final String inputFile; + + public ReadFileAndExtractTimestamps(String inputFile) { + this.inputFile = inputFile; + } + + @Override + public PCollection apply(PBegin begin) { + return begin + .apply(TextIO.Read.from(inputFile)) + .apply(ParDo.of(new ExtractTimestamps())); + } + } + + /** + * Options supported by {@link TrafficRoutes}. + * + *

Inherits standard configuration options. + */ + private interface TrafficRoutesOptions extends DataflowExampleOptions, + ExamplePubsubTopicAndSubscriptionOptions, ExampleBigQueryTableOptions { + @Description("Input file to inject to Pub/Sub topic") + @Default.String("gs://dataflow-samples/traffic_sensor/" + + "Freeways-5Minaa2010-01-01_to_2010-02-15_test2.csv") + String getInputFile(); + void setInputFile(String value); + + @Description("Numeric value of sliding window duration, in minutes") + @Default.Integer(WINDOW_DURATION) + Integer getWindowDuration(); + void setWindowDuration(Integer value); + + @Description("Numeric value of window 'slide every' setting, in minutes") + @Default.Integer(WINDOW_SLIDE_EVERY) + Integer getWindowSlideEvery(); + void setWindowSlideEvery(Integer value); + + @Description("Whether to run the pipeline with unbounded input") + @Default.Boolean(false) + boolean isUnbounded(); + void setUnbounded(boolean value); + } + + /** + * Sets up and starts streaming pipeline. + * + * @throws IOException if there is a problem setting up resources + */ + public static void main(String[] args) throws IOException { + TrafficRoutesOptions options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(TrafficRoutesOptions.class); + + options.setBigQuerySchema(FormatStatsFn.getSchema()); + // Using DataflowExampleUtils to set up required resources. + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options, options.isUnbounded()); + + Pipeline pipeline = Pipeline.create(options); + TableReference tableRef = new TableReference(); + tableRef.setProjectId(options.getProject()); + tableRef.setDatasetId(options.getBigQueryDataset()); + tableRef.setTableId(options.getBigQueryTable()); + + PCollection input; + if (options.isUnbounded()) { + // Read unbounded PubSubIO. + input = pipeline.apply(PubsubIO.Read + .timestampLabel(PUBSUB_TIMESTAMP_LABEL_KEY) + .subscription(options.getPubsubSubscription())); + } else { + // Read bounded PubSubIO. + input = pipeline.apply(PubsubIO.Read + .timestampLabel(PUBSUB_TIMESTAMP_LABEL_KEY) + .subscription(options.getPubsubSubscription()).maxNumRecords(VALID_INPUTS)); + + // To read bounded TextIO files, use: + // input = pipeline.apply(TextIO.Read.from(options.getInputFile())) + // .apply(ParDo.of(new ExtractTimestamps())); + } + input + // row... => ... + .apply(ParDo.of(new ExtractStationSpeedFn())) + // map the incoming data stream into sliding windows. + // The default window duration values work well if you're running the accompanying Pub/Sub + // generator script without the --replay flag, so that there are no simulated pauses in + // the sensor data publication. You may want to adjust the values otherwise. + .apply(Window.>into(SlidingWindows.of( + Duration.standardMinutes(options.getWindowDuration())). + every(Duration.standardMinutes(options.getWindowSlideEvery())))) + .apply(new TrackSpeed()) + .apply(BigQueryIO.Write.to(tableRef) + .withSchema(FormatStatsFn.getSchema())); + + // Inject the data into the Pub/Sub topic with a Dataflow batch pipeline. + if (!Strings.isNullOrEmpty(options.getInputFile()) + && !Strings.isNullOrEmpty(options.getPubsubTopic())) { + dataflowUtils.runInjectorPipeline( + new ReadFileAndExtractTimestamps(options.getInputFile()), + options.getPubsubTopic(), + PUBSUB_TIMESTAMP_LABEL_KEY); + } + + // Run the pipeline. + PipelineResult result = pipeline.run(); + + // dataflowUtils will try to cancel the pipeline and the injector before the program exists. 
+ dataflowUtils.waitToFinish(result); + } + + private static Double tryParseAvgSpeed(String[] inputItems) { + try { + return Double.parseDouble(tryParseString(inputItems, 9)); + } catch (NumberFormatException e) { + return null; + } catch (NullPointerException e) { + return null; + } + } + + private static String tryParseStationType(String[] inputItems) { + return tryParseString(inputItems, 4); + } + + private static String tryParseStationId(String[] inputItems) { + return tryParseString(inputItems, 1); + } + + private static String tryParseTimestamp(String[] inputItems) { + return tryParseString(inputItems, 0); + } + + private static String tryParseString(String[] inputItems, int index) { + return inputItems.length >= index ? inputItems[index] : null; + } + + /** + * Define some small hard-wired San Diego 'routes' to track based on sensor station ID. + */ + private static Map buildStationInfo() { + Map stations = new Hashtable(); + stations.put("1108413", "SDRoute1"); // from freeway 805 S + stations.put("1108699", "SDRoute2"); // from freeway 78 E + stations.put("1108702", "SDRoute2"); + return stations; + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/BigQueryTornadoes.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/BigQueryTornadoes.java new file mode 100644 index 000000000000..503bcadf5332 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/BigQueryTornadoes.java @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.util.ArrayList; +import java.util.List; + +/** + * An example that reads the public samples of weather data from BigQuery, counts the number of + * tornadoes that occur in each month, and writes the results to BigQuery. + * + *

Concepts: Reading/writing BigQuery; counting a PCollection; user-defined PTransforms + * + *

Note: Before running this example, you must create a BigQuery dataset to contain your output + * table. + * + *

To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * </pre>
+ * and the BigQuery table for the output, with the form
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ *

To execute this pipeline using the Dataflow service, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * </pre>
+ * and the BigQuery table for the output:
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ *

The BigQuery input table defaults to {@code clouddataflow-readonly:samples.weather_stations} + * and can be overridden with {@code --input}. + */ +public class BigQueryTornadoes { + // Default to using a 1000 row subset of the public weather station table publicdata:samples.gsod. + private static final String WEATHER_SAMPLES_TABLE = + "clouddataflow-readonly:samples.weather_stations"; + + /** + * Examines each row in the input table. If a tornado was recorded + * in that sample, the month in which it occurred is output. + */ + static class ExtractTornadoesFn extends DoFn { + @Override + public void processElement(ProcessContext c){ + TableRow row = c.element(); + if ((Boolean) row.get("tornado")) { + c.output(Integer.parseInt((String) row.get("month"))); + } + } + } + + /** + * Prepares the data for writing to BigQuery by building a TableRow object containing an + * integer representation of month and the number of tornadoes that occurred in each month. + */ + static class FormatCountsFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + TableRow row = new TableRow() + .set("month", c.element().getKey()) + .set("tornado_count", c.element().getValue()); + c.output(row); + } + } + + /** + * Takes rows from a table and generates a table of counts. + * + *

The input schema is described by + * https://developers.google.com/bigquery/docs/dataset-gsod . + * The output contains the total number of tornadoes found in each month in + * the following schema: + *

    + *
+ *   - month: integer
+ *   - tornado_count: integer
+ *
+ */ + static class CountTornadoes + extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection rows) { + + // row... => month... + PCollection tornadoes = rows.apply( + ParDo.of(new ExtractTornadoesFn())); + + // month... => ... + PCollection> tornadoCounts = + tornadoes.apply(Count.perElement()); + + // ... => row... + PCollection results = tornadoCounts.apply( + ParDo.of(new FormatCountsFn())); + + return results; + } + } + + /** + * Options supported by {@link BigQueryTornadoes}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Table to read from, specified as " + + ":.") + @Default.String(WEATHER_SAMPLES_TABLE) + String getInput(); + void setInput(String value); + + @Description("BigQuery table to write to, specified as " + + ":.. The dataset must already exist.") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + + Pipeline p = Pipeline.create(options); + + // Build the table schema for the output table. + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("month").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("tornado_count").setType("INTEGER")); + TableSchema schema = new TableSchema().setFields(fields); + + p.apply(BigQueryIO.Read.from(options.getInput())) + .apply(new CountTornadoes()) + .apply(BigQueryIO.Write + .to(options.getOutput()) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/CombinePerKeyExamples.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/CombinePerKeyExamples.java new file mode 100644 index 000000000000..9540dd448226 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/CombinePerKeyExamples.java @@ -0,0 +1,223 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.util.ArrayList; +import java.util.List; + +/** + * An example that reads the public 'Shakespeare' data, and for each word in + * the dataset that is over a given length, generates a string containing the + * list of play names in which that word appears, and saves this information + * to a bigquery table. + * + *

Concepts: the Combine.perKey transform, which lets you combine the values in a + * key-grouped Collection, and how to use an Aggregator to track information in the + * Monitoring UI. + * + *

Note: Before running this example, you must create a BigQuery dataset to contain your output + * table. + * + *

To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * </pre>
+ * and the BigQuery table for the output:
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ *

To execute this pipeline using the Dataflow service, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * </pre>
+ * and the BigQuery table for the output:
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ *

The BigQuery input table defaults to {@code publicdata:samples.shakespeare} and can + * be overridden with {@code --input}. + */ +public class CombinePerKeyExamples { + // Use the shakespeare public BigQuery sample + private static final String SHAKESPEARE_TABLE = + "publicdata:samples.shakespeare"; + // We'll track words >= this word length across all plays in the table. + private static final int MIN_WORD_LENGTH = 9; + + /** + * Examines each row in the input table. If the word is greater than or equal to MIN_WORD_LENGTH, + * outputs word, play_name. + */ + static class ExtractLargeWordsFn extends DoFn> { + private final Aggregator smallerWords = + createAggregator("smallerWords", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c){ + TableRow row = c.element(); + String playName = (String) row.get("corpus"); + String word = (String) row.get("word"); + if (word.length() >= MIN_WORD_LENGTH) { + c.output(KV.of(word, playName)); + } else { + // Track how many smaller words we're not including. This information will be + // visible in the Monitoring UI. + smallerWords.addValue(1L); + } + } + } + + + /** + * Prepares the data for writing to BigQuery by building a TableRow object + * containing a word with a string listing the plays in which it appeared. + */ + static class FormatShakespeareOutputFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + TableRow row = new TableRow() + .set("word", c.element().getKey()) + .set("all_plays", c.element().getValue()); + c.output(row); + } + } + + /** + * Reads the public 'Shakespeare' data, and for each word in the dataset + * over a given length, generates a string containing the list of play names + * in which that word appears. It does this via the Combine.perKey + * transform, with the ConcatWords combine function. + * + *

Combine.perKey is similar to a GroupByKey followed by a ParDo, but + * has more restricted semantics that allow it to be executed more + * efficiently. These records are then formatted as BQ table rows. + */ + static class PlaysForWord + extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection rows) { + + // row... => ... + PCollection> words = rows.apply( + ParDo.of(new ExtractLargeWordsFn())); + + // word, play_name => word, all_plays ... + PCollection> wordAllPlays = + words.apply(Combine.perKey( + new ConcatWords())); + + // ... => row... + PCollection results = wordAllPlays.apply( + ParDo.of(new FormatShakespeareOutputFn())); + + return results; + } + } + + /** + * A 'combine function' used with the Combine.perKey transform. Builds a + * comma-separated string of all input items. So, it will build a string + * containing all the different Shakespeare plays in which the given input + * word has appeared. + */ + public static class ConcatWords implements SerializableFunction, String> { + @Override + public String apply(Iterable input) { + StringBuilder all = new StringBuilder(); + for (String item : input) { + if (!item.isEmpty()) { + if (all.length() == 0) { + all.append(item); + } else { + all.append(","); + all.append(item); + } + } + } + return all.toString(); + } + } + + /** + * Options supported by {@link CombinePerKeyExamples}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Table to read from, specified as " + + ":.") + @Default.String(SHAKESPEARE_TABLE) + String getInput(); + void setInput(String value); + + @Description("Table to write to, specified as " + + ":.. " + + "The dataset_id must already exist") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) + throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + + // Build the table schema for the output table. + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("word").setType("STRING")); + fields.add(new TableFieldSchema().setName("all_plays").setType("STRING")); + TableSchema schema = new TableSchema().setFields(fields); + + p.apply(BigQueryIO.Read.from(options.getInput())) + .apply(new PlaysForWord()) + .apply(BigQueryIO.Write + .to(options.getOutput()) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/DatastoreWordCount.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/DatastoreWordCount.java new file mode 100644 index 000000000000..eaf1e2053d55 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/DatastoreWordCount.java @@ -0,0 +1,269 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import static com.google.api.services.datastore.client.DatastoreHelper.getPropertyMap; +import static com.google.api.services.datastore.client.DatastoreHelper.getString; +import static com.google.api.services.datastore.client.DatastoreHelper.makeFilter; +import static com.google.api.services.datastore.client.DatastoreHelper.makeKey; +import static com.google.api.services.datastore.client.DatastoreHelper.makeValue; + +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.Key; +import com.google.api.services.datastore.DatastoreV1.Property; +import com.google.api.services.datastore.DatastoreV1.PropertyFilter; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.api.services.datastore.DatastoreV1.Value; +import com.google.cloud.dataflow.examples.WordCount; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.DatastoreIO; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +import java.util.Map; +import java.util.UUID; + +import javax.annotation.Nullable; + +/** + * A WordCount example using DatastoreIO. + * + *

This example shows how to use DatastoreIO to read from Datastore and + * write the results to Cloud Storage. Note that this example will write + * data to Datastore, which may incur charge for Datastore operations. + * + *

To run this example, users need to use gcloud to get credentials for Datastore:
+ * <pre>{@code
+ * $ gcloud auth login
+ * }</pre>
+ *

To run this pipeline locally, the following options must be provided:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --dataset=YOUR_DATASET_ID
+ *   --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PATH]
+ * }</pre>
+ *

To run this example using Dataflow service, you must additionally + * provide either {@literal --stagingLocation} or {@literal --tempLocation}, and + * select one of the Dataflow pipeline runners, eg + * {@literal --runner=BlockingDataflowPipelineRunner}. + * + *

Note: this example creates entities with Ancestor keys to ensure that all + * entities created are in the same entity group. Similarly, the query used to read from the Cloud + * Datastore uses an Ancestor filter. Ancestors are used to ensure strongly consistent + * results in Cloud Datastore. For more information, see the Cloud Datastore documentation on + * + * Structing Data for Strong Consistency. + */ +public class DatastoreWordCount { + + /** + * A DoFn that gets the content of an entity (one line in a + * Shakespeare play) and converts it to a string. + */ + static class GetContentFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Map props = getPropertyMap(c.element()); + Value value = props.get("content"); + if (value != null) { + c.output(getString(value)); + } + } + } + + /** + * A helper function to create the ancestor key for all created and queried entities. + * + *

We use ancestor keys and ancestor queries for strong consistency. See + * {@link DatastoreWordCount} javadoc for more information. + */ + static Key makeAncestorKey(@Nullable String namespace, String kind) { + Key.Builder keyBuilder = makeKey(kind, "root"); + if (namespace != null) { + keyBuilder.getPartitionIdBuilder().setNamespace(namespace); + } + return keyBuilder.build(); + } + + /** + * A DoFn that creates entity for every line in Shakespeare. + */ + static class CreateEntityFn extends DoFn { + private final String namespace; + private final String kind; + private final Key ancestorKey; + + CreateEntityFn(String namespace, String kind) { + this.namespace = namespace; + this.kind = kind; + + // Build the ancestor key for all created entities once, including the namespace. + ancestorKey = makeAncestorKey(namespace, kind); + } + + public Entity makeEntity(String content) { + Entity.Builder entityBuilder = Entity.newBuilder(); + + // All created entities have the same ancestor Key. + Key.Builder keyBuilder = makeKey(ancestorKey, kind, UUID.randomUUID().toString()); + // NOTE: Namespace is not inherited between keys created with DatastoreHelper.makeKey, so + // we must set the namespace on keyBuilder. TODO: Once partitionId inheritance is added, + // we can simplify this code. + if (namespace != null) { + keyBuilder.getPartitionIdBuilder().setNamespace(namespace); + } + + entityBuilder.setKey(keyBuilder.build()); + entityBuilder.addProperty(Property.newBuilder().setName("content") + .setValue(Value.newBuilder().setStringValue(content))); + return entityBuilder.build(); + } + + @Override + public void processElement(ProcessContext c) { + c.output(makeEntity(c.element())); + } + } + + /** + * Options supported by {@link DatastoreWordCount}. + * + *

Inherits standard configuration options. + */ + public static interface Options extends PipelineOptions { + @Description("Path of the file to read from and store to Datastore") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInput(); + void setInput(String value); + + @Description("Path of the file to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + + @Description("Dataset ID to read from datastore") + @Validation.Required + String getDataset(); + void setDataset(String value); + + @Description("Dataset entity kind") + @Default.String("shakespeare-demo") + String getKind(); + void setKind(String value); + + @Description("Dataset namespace") + String getNamespace(); + void setNamespace(@Nullable String value); + + @Description("Read an existing dataset, do not write first") + boolean isReadOnly(); + void setReadOnly(boolean value); + + @Description("Number of output shards") + @Default.Integer(0) // If the system should choose automatically. + int getNumShards(); + void setNumShards(int value); + } + + /** + * An example that creates a pipeline to populate DatastoreIO from a + * text input. Forces use of DirectPipelineRunner for local execution mode. + */ + public static void writeDataToDatastore(Options options) { + Pipeline p = Pipeline.create(options); + p.apply(TextIO.Read.named("ReadLines").from(options.getInput())) + .apply(ParDo.of(new CreateEntityFn(options.getNamespace(), options.getKind()))) + .apply(DatastoreIO.writeTo(options.getDataset())); + + p.run(); + } + + /** + * Build a Cloud Datastore ancestor query for the specified {@link Options#getNamespace} and + * {@link Options#getKind}. + * + *

We use ancestor keys and ancestor queries for strong consistency. See + * {@link DatastoreWordCount} javadoc for more information. + * + * @see Ancestor filters + */ + static Query makeAncestorKindQuery(Options options) { + Query.Builder q = Query.newBuilder(); + q.addKindBuilder().setName(options.getKind()); + q.setFilter(makeFilter( + "__key__", + PropertyFilter.Operator.HAS_ANCESTOR, + makeValue(makeAncestorKey(options.getNamespace(), options.getKind())))); + return q.build(); + } + + /** + * An example that creates a pipeline to do DatastoreIO.Read from Datastore. + */ + public static void readDataFromDatastore(Options options) { + Query query = makeAncestorKindQuery(options); + + // For Datastore sources, the read namespace can be set on the entire query. + DatastoreIO.Source source = DatastoreIO.source() + .withDataset(options.getDataset()) + .withQuery(query) + .withNamespace(options.getNamespace()); + + Pipeline p = Pipeline.create(options); + p.apply("ReadShakespeareFromDatastore", Read.from(source)) + .apply("StringifyEntity", ParDo.of(new GetContentFn())) + .apply("CountWords", new WordCount.CountWords()) + .apply("PrintWordCount", MapElements.via(new WordCount.FormatAsTextFn())) + .apply("WriteLines", TextIO.Write.to(options.getOutput()) + .withNumShards(options.getNumShards())); + p.run(); + } + + /** + * An example to demo how to use {@link DatastoreIO}. The runner here is + * customizable, which means users could pass either {@code DirectPipelineRunner} + * or {@code DataflowPipelineRunner} in the pipeline options. + */ + public static void main(String args[]) { + // The options are used in two places, for Dataflow service, and + // building DatastoreIO.Read object + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + + if (!options.isReadOnly()) { + // First example: write data to Datastore for reading later. + // + // NOTE: this write does not delete any existing Entities in the Datastore, so if run + // multiple times with the same output dataset, there may be duplicate entries. The + // Datastore Query tool in the Google Developers Console can be used to inspect or erase all + // entries with a particular namespace and/or kind. + DatastoreWordCount.writeDataToDatastore(options); + } + + // Second example: do parallel read from Datastore. + DatastoreWordCount.readDataFromDatastore(options); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/DeDupExample.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/DeDupExample.java new file mode 100644 index 000000000000..9873561e5e72 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/DeDupExample.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; + +/** + * This example uses as input Shakespeare's plays as plaintext files, and will remove any + * duplicate lines across all the files. (The output does not preserve any input order). + * + *

Concepts: the RemoveDuplicates transform, and how to wire transforms together. + * Demonstrates {@link com.google.cloud.dataflow.sdk.io.TextIO.Read}/ + * {@link RemoveDuplicates}/{@link com.google.cloud.dataflow.sdk.io.TextIO.Write}. + * + *

To execute this pipeline locally, specify general pipeline configuration: + * --project=YOUR_PROJECT_ID + * and a local output file or output prefix on GCS: + * --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PREFIX] + * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration: + * --project=YOUR_PROJECT_ID + * --stagingLocation=gs://YOUR_STAGING_DIRECTORY + * --runner=BlockingDataflowPipelineRunner + * and an output prefix on GCS: + * --output=gs://YOUR_OUTPUT_PREFIX + * + *

The input defaults to {@code gs://dataflow-samples/shakespeare/*} and can be + * overridden with {@code --input}. + */ +public class DeDupExample { + + /** + * Options supported by {@link DeDupExample}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Path to the directory or GCS prefix containing files to read from") + @Default.String("gs://dataflow-samples/shakespeare/*") + String getInput(); + void setInput(String value); + + @Description("Path of the file to write to") + @Default.InstanceFactory(OutputFactory.class) + String getOutput(); + void setOutput(String value); + + /** Returns gs://${STAGING_LOCATION}/"deduped.txt". */ + public static class OutputFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + if (dataflowOptions.getStagingLocation() != null) { + return GcsPath.fromUri(dataflowOptions.getStagingLocation()) + .resolve("deduped.txt").toString(); + } else { + throw new IllegalArgumentException("Must specify --output or --stagingLocation"); + } + } + } + } + + + public static void main(String[] args) + throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + + p.apply(TextIO.Read.named("ReadLines").from(options.getInput())) + .apply(RemoveDuplicates.create()) + .apply(TextIO.Write.named("DedupedShakespeare") + .to(options.getOutput())); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/FilterExamples.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/FilterExamples.java new file mode 100644 index 000000000000..781873a07883 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/FilterExamples.java @@ -0,0 +1,266 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Mean; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; + +/** + * This is an example that demonstrates several approaches to filtering, and use of the Mean + * transform. It shows how to dynamically set parameters by defining and using new pipeline options, + * and how to use a value derived by the pipeline. + * + *

Concepts: The Mean transform; Options configuration; using pipeline-derived data as a side + * input; approaches to filtering, selection, and projection. + * + *

The example reads public samples of weather data from BigQuery. It performs a + * projection on the data, finds the global mean of the temperature readings, filters on readings + * for a single given month, and then outputs only data (for that month) that has a mean temp + * smaller than the derived global mean. +* + *

Note: Before running this example, you must create a BigQuery dataset to contain your output + * table. + * + *

To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * </pre>
+ * and the BigQuery table for the output:
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ *   [--monthFilter=<month_number>]
+ * }
+ * </pre>
+ * where optional parameter {@code --monthFilter} is set to a number 1-12. + * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * </pre>
+ * and the BigQuery table for the output:
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ *   [--monthFilter=<month_number>]
+ * }
+ * </pre>
+ * where optional parameter {@code --monthFilter} is set to a number 1-12. + * + *

The BigQuery input table defaults to {@code clouddataflow-readonly:samples.weather_stations} + * and can be overridden with {@code --input}. + */ +public class FilterExamples { + // Default to using a 1000 row subset of the public weather station table publicdata:samples.gsod. + private static final String WEATHER_SAMPLES_TABLE = + "clouddataflow-readonly:samples.weather_stations"; + static final Logger LOG = Logger.getLogger(FilterExamples.class.getName()); + static final int MONTH_TO_FILTER = 7; + + /** + * Examines each row in the input table. Outputs only the subset of the cells this example + * is interested in-- the mean_temp and year, month, and day-- as a bigquery table row. + */ + static class ProjectionFn extends DoFn { + @Override + public void processElement(ProcessContext c){ + TableRow row = c.element(); + // Grab year, month, day, mean_temp from the row + Integer year = Integer.parseInt((String) row.get("year")); + Integer month = Integer.parseInt((String) row.get("month")); + Integer day = Integer.parseInt((String) row.get("day")); + Double meanTemp = Double.parseDouble(row.get("mean_temp").toString()); + // Prepares the data for writing to BigQuery by building a TableRow object + TableRow outRow = new TableRow() + .set("year", year).set("month", month) + .set("day", day).set("mean_temp", meanTemp); + c.output(outRow); + } + } + + /** + * Implements 'filter' functionality. + * + *

Examines each row in the input table. Outputs only rows from the month + * monthFilter, which is passed in as a parameter during construction of this DoFn. + */ + static class FilterSingleMonthDataFn extends DoFn { + Integer monthFilter; + + public FilterSingleMonthDataFn(Integer monthFilter) { + this.monthFilter = monthFilter; + } + + @Override + public void processElement(ProcessContext c){ + TableRow row = c.element(); + Integer month; + month = (Integer) row.get("month"); + if (month.equals(this.monthFilter)) { + c.output(row); + } + } + } + + /** + * Examines each row (weather reading) in the input table. Output the temperature + * reading for that row ('mean_temp'). + */ + static class ExtractTempFn extends DoFn { + @Override + public void processElement(ProcessContext c){ + TableRow row = c.element(); + Double meanTemp = Double.parseDouble(row.get("mean_temp").toString()); + c.output(meanTemp); + } + } + + + + /* + * Finds the global mean of the mean_temp for each day/record, and outputs + * only data that has a mean temp larger than this global mean. + **/ + static class BelowGlobalMean + extends PTransform, PCollection> { + Integer monthFilter; + + public BelowGlobalMean(Integer monthFilter) { + this.monthFilter = monthFilter; + } + + + @Override + public PCollection apply(PCollection rows) { + + // Extract the mean_temp from each row. + PCollection meanTemps = rows.apply( + ParDo.of(new ExtractTempFn())); + + // Find the global mean, of all the mean_temp readings in the weather data, + // and prepare this singleton PCollectionView for use as a side input. + final PCollectionView globalMeanTemp = + meanTemps.apply(Mean.globally()) + .apply(View.asSingleton()); + + // Rows filtered to remove all but a single month + PCollection monthFilteredRows = rows + .apply(ParDo.of(new FilterSingleMonthDataFn(monthFilter))); + + // Then, use the global mean as a side input, to further filter the weather data. + // By using a side input to pass in the filtering criteria, we can use a value + // that is computed earlier in pipeline execution. + // We'll only output readings with temperatures below this mean. + PCollection filteredRows = monthFilteredRows + .apply(ParDo + .named("ParseAndFilter") + .withSideInputs(globalMeanTemp) + .of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + Double meanTemp = Double.parseDouble(c.element().get("mean_temp").toString()); + Double gTemp = c.sideInput(globalMeanTemp); + if (meanTemp < gTemp) { + c.output(c.element()); + } + } + })); + + return filteredRows; + } + } + + + /** + * Options supported by {@link FilterExamples}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Table to read from, specified as " + + ":.") + @Default.String(WEATHER_SAMPLES_TABLE) + String getInput(); + void setInput(String value); + + @Description("Table to write to, specified as " + + ":.. " + + "The dataset_id must already exist") + @Validation.Required + String getOutput(); + void setOutput(String value); + + @Description("Numeric value of month to filter on") + @Default.Integer(MONTH_TO_FILTER) + Integer getMonthFilter(); + void setMonthFilter(Integer value); + } + + /** + * Helper method to build the table schema for the output table. + */ + private static TableSchema buildWeatherSchemaProjection() { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("year").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("month").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("day").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("mean_temp").setType("FLOAT")); + TableSchema schema = new TableSchema().setFields(fields); + return schema; + } + + public static void main(String[] args) + throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + + TableSchema schema = buildWeatherSchemaProjection(); + + p.apply(BigQueryIO.Read.from(options.getInput())) + .apply(ParDo.of(new ProjectionFn())) + .apply(new BelowGlobalMean(options.getMonthFilter())) + .apply(BigQueryIO.Write + .to(options.getOutput()) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/JoinExamples.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/JoinExamples.java new file mode 100644 index 000000000000..745c5d6719db --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/JoinExamples.java @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +/** + * This example shows how to do a join on two collections. + * It uses a sample of the GDELT 'world event' data (http://goo.gl/OB6oin), joining the event + * 'action' country code against a table that maps country codes to country names. + * + *

Concepts: Join operation; multiple input sources. + * + *

To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * </pre>
+ * and a local output file or output prefix on GCS:
+ * <pre>{@code
+ *   --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PREFIX]
+ * }</pre>
+ *

To execute this pipeline using the Dataflow service, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * </pre>
+ * and an output prefix on GCS:
+ * <pre>{@code
+ *   --output=gs://YOUR_OUTPUT_PREFIX
+ * }</pre>
+ */ +public class JoinExamples { + + // A 1000-row sample of the GDELT data here: gdelt-bq:full.events. + private static final String GDELT_EVENTS_TABLE = + "clouddataflow-readonly:samples.gdelt_sample"; + // A table that maps country codes to country names. + private static final String COUNTRY_CODES = + "gdelt-bq:full.crosswalk_geocountrycodetohuman"; + + /** + * Join two collections, using country code as the key. + */ + static PCollection joinEvents(PCollection eventsTable, + PCollection countryCodes) throws Exception { + + final TupleTag eventInfoTag = new TupleTag(); + final TupleTag countryInfoTag = new TupleTag(); + + // transform both input collections to tuple collections, where the keys are country + // codes in both cases. + PCollection> eventInfo = eventsTable.apply( + ParDo.of(new ExtractEventDataFn())); + PCollection> countryInfo = countryCodes.apply( + ParDo.of(new ExtractCountryInfoFn())); + + // country code 'key' -> CGBKR (, ) + PCollection> kvpCollection = KeyedPCollectionTuple + .of(eventInfoTag, eventInfo) + .and(countryInfoTag, countryInfo) + .apply(CoGroupByKey.create()); + + // Process the CoGbkResult elements generated by the CoGroupByKey transform. + // country code 'key' -> string of , + PCollection> finalResultCollection = + kvpCollection.apply(ParDo.named("Process").of( + new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + String countryCode = e.getKey(); + String countryName = "none"; + countryName = e.getValue().getOnly(countryInfoTag); + for (String eventInfo : c.element().getValue().getAll(eventInfoTag)) { + // Generate a string that combines information from both collection values + c.output(KV.of(countryCode, "Country name: " + countryName + + ", Event info: " + eventInfo)); + } + } + })); + + // write to GCS + PCollection formattedResults = finalResultCollection + .apply(ParDo.named("Format").of(new DoFn, String>() { + @Override + public void processElement(ProcessContext c) { + String outputstring = "Country code: " + c.element().getKey() + + ", " + c.element().getValue(); + c.output(outputstring); + } + })); + return formattedResults; + } + + /** + * Examines each row (event) in the input table. Output a KV with the key the country + * code of the event, and the value a string encoding event information. + */ + static class ExtractEventDataFn extends DoFn> { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + String countryCode = (String) row.get("ActionGeo_CountryCode"); + String sqlDate = (String) row.get("SQLDATE"); + String actor1Name = (String) row.get("Actor1Name"); + String sourceUrl = (String) row.get("SOURCEURL"); + String eventInfo = "Date: " + sqlDate + ", Actor1: " + actor1Name + ", url: " + sourceUrl; + c.output(KV.of(countryCode, eventInfo)); + } + } + + + /** + * Examines each row (country info) in the input table. Output a KV with the key the country + * code, and the value the country name. + */ + static class ExtractCountryInfoFn extends DoFn> { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + String countryCode = (String) row.get("FIPSCC"); + String countryName = (String) row.get("HumanName"); + c.output(KV.of(countryCode, countryName)); + } + } + + + /** + * Options supported by {@link JoinExamples}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Path of the file to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) throws Exception { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + // the following two 'applys' create multiple inputs to our pipeline, one for each + // of our two input sources. + PCollection eventsTable = p.apply(BigQueryIO.Read.from(GDELT_EVENTS_TABLE)); + PCollection countryCodes = p.apply(BigQueryIO.Read.from(COUNTRY_CODES)); + PCollection formattedResults = joinEvents(eventsTable, countryCodes); + formattedResults.apply(TextIO.Write.to(options.getOutput())); + p.run(); + } + +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/MaxPerKeyExamples.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/MaxPerKeyExamples.java new file mode 100644 index 000000000000..1c26d0f19e30 --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/MaxPerKeyExamples.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.util.ArrayList; +import java.util.List; + +/** + * An example that reads the public samples of weather data from BigQuery, and finds + * the maximum temperature ('mean_temp') for each month. + * + *

Concepts: The 'Max' statistical combination function, and how to find the max per + * key group. + * + *

Note: Before running this example, you must create a BigQuery dataset to contain your output + * table. + * + *

To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * </pre>
+ * and the BigQuery table for the output, with the form
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ *

To execute this pipeline using the Dataflow service, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * </pre>
+ * and the BigQuery table for the output:
+ * <pre>{@code
+ *   --output=YOUR_PROJECT_ID:DATASET_ID.TABLE_ID
+ * }</pre>
+ *

The BigQuery input table defaults to {@code clouddataflow-readonly:samples.weather_stations } + * and can be overridden with {@code --input}. + */ +public class MaxPerKeyExamples { + // Default to using a 1000 row subset of the public weather station table publicdata:samples.gsod. + private static final String WEATHER_SAMPLES_TABLE = + "clouddataflow-readonly:samples.weather_stations"; + + /** + * Examines each row (weather reading) in the input table. Output the month of the reading, + * and the mean_temp. + */ + static class ExtractTempFn extends DoFn> { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + Integer month = Integer.parseInt((String) row.get("month")); + Double meanTemp = Double.parseDouble(row.get("mean_temp").toString()); + c.output(KV.of(month, meanTemp)); + } + } + + /** + * Format the results to a TableRow, to save to BigQuery. + * + */ + static class FormatMaxesFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + TableRow row = new TableRow() + .set("month", c.element().getKey()) + .set("max_mean_temp", c.element().getValue()); + c.output(row); + } + } + + /** + * Reads rows from a weather data table, and finds the max mean_temp for each + * month via the 'Max' statistical combination function. + */ + static class MaxMeanTemp + extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection rows) { + + // row... => ... + PCollection> temps = rows.apply( + ParDo.of(new ExtractTempFn())); + + // month, mean_temp... => ... + PCollection> tempMaxes = + temps.apply(Max.doublesPerKey()); + + // ... => row... + PCollection results = tempMaxes.apply( + ParDo.of(new FormatMaxesFn())); + + return results; + } + } + + /** + * Options supported by {@link MaxPerKeyExamples}. + * + *

Inherits standard configuration options. + */ + private static interface Options extends PipelineOptions { + @Description("Table to read from, specified as " + + ":.") + @Default.String(WEATHER_SAMPLES_TABLE) + String getInput(); + void setInput(String value); + + @Description("Table to write to, specified as " + + ":.") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) + throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + + // Build the table schema for the output table. + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("month").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("max_mean_temp").setType("FLOAT")); + TableSchema schema = new TableSchema().setFields(fields); + + p.apply(BigQueryIO.Read.from(options.getInput())) + .apply(new MaxMeanTemp()) + .apply(BigQueryIO.Write + .to(options.getOutput()) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); + + p.run(); + } +} diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/README.md b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/README.md new file mode 100644 index 000000000000..99f3080a06ba --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/README.md @@ -0,0 +1,55 @@ + +# "Cookbook" Examples + +This directory holds simple "cookbook" examples, which show how to define +commonly-used data analysis patterns that you would likely incorporate into a +larger Dataflow pipeline. They include: + +

    +
+* BigQueryTornadoes — An example that reads the public samples of weather data from Google
+  BigQuery, counts the number of tornadoes that occur in each month, and writes the results to
+  BigQuery. Demonstrates reading/writing BigQuery, counting a PCollection, and user-defined
+  PTransforms.
+* CombinePerKeyExamples — An example that reads the public "Shakespeare" data, and for each word
+  in the dataset that exceeds a given length, generates a string containing the list of play names
+  in which that word appears. Demonstrates the Combine.perKey transform, which lets you combine
+  the values in a key-grouped PCollection.
+* DatastoreWordCount — An example that shows you how to read from Google Cloud Datastore.
+* DeDupExample — An example that uses Shakespeare's plays as plain text files, and removes
+  duplicate lines across all the files. Demonstrates the RemoveDuplicates, TextIO.Read, and
+  TextIO.Write transforms, and how to wire transforms together.
+* FilterExamples — An example that shows different approaches to filtering, including selection
+  and projection. It also shows how to dynamically set parameters by defining and using new
+  pipeline options, and how to use a value derived by the pipeline. Demonstrates the Mean
+  transform, Options configuration, and using pipeline-derived data as a side input.
+* JoinExamples — An example that shows how to join two collections. It uses a sample of the
+  GDELT "world event" data, joining the event action country code against a table that maps
+  country codes to country names. Demonstrates the Join operation, and using multiple input
+  sources.
+* MaxPerKeyExamples — An example that reads the public samples of weather data from BigQuery,
+  and finds the maximum temperature (mean_temp) for each month. Demonstrates the Max statistical
+  combination transform, and how to find the max-per-key group. A minimal sketch of this pattern
+  appears after this list.
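+
+As a minimal sketch of the shape these pipelines take (assuming `p` is the example's `Pipeline`,
+`options` holds its pipeline options, and `ExtractTempFn` is the DoFn defined in
+MaxPerKeyExamples), the "max per key" pattern boils down to a few transforms:
+
+```java
+PCollection<TableRow> rows =
+    p.apply(BigQueryIO.Read.from(options.getInput()));   // weather readings
+PCollection<KV<Integer, Double>> temps =
+    rows.apply(ParDo.of(new ExtractTempFn()));           // (month, mean_temp) pairs
+PCollection<KV<Integer, Double>> monthlyMaxes =
+    temps.apply(Max.<Integer>doublesPerKey());           // max mean_temp per month
+```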
+ +See the [documentation](https://cloud.google.com/dataflow/getting-started) and the [Examples +README](../../../../../../../../../README.md) for +information about how to run these examples. diff --git a/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/TriggerExample.java b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/TriggerExample.java new file mode 100644 index 000000000000..ce5e08e7d2fa --- /dev/null +++ b/examples/src/main/java/com/google/cloud/dataflow/examples/cookbook/TriggerExample.java @@ -0,0 +1,564 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.examples.common.DataflowExampleOptions; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.common.ExampleBigQueryTableOptions; +import com.google.cloud.dataflow.examples.common.ExamplePubsubTopicOptions; +import com.google.cloud.dataflow.examples.common.PubsubFileInjector; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.IntraBundleParallelization; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterEach; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Repeatedly; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.List; +import 
java.util.concurrent.TimeUnit; + +/** + * This example illustrates the basic concepts behind triggering. It shows how to use different + * trigger definitions to produce partial (speculative) results before all the data is processed and + * to control when updated results are produced for late data. The example performs a streaming + * analysis of the data coming in from PubSub and writes the results to BigQuery. It divides the + * data into {@link Window windows} to be processed, and demonstrates using various kinds of {@link + * Trigger triggers} to control when the results for each window are emitted. + * + *

This example uses a portion of real traffic data from San Diego freeways. It contains + * readings from sensor stations set up along each freeway. Each sensor reading includes a + * calculation of the 'total flow' across all lanes in that freeway direction. + * + *

Concepts: + *

+ *   1. The default triggering behavior
+ *   2. Late data with the default trigger
+ *   3. How to get speculative estimates
+ *   4. Combining late data and speculative estimates
+ * 
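+ * As a rough sketch (the full, working configuration is in {@code CalculateTotalFlow} below),
+ * Concept #1, the default behavior, corresponds to windowing the input like this:
+ * {@code
+ *   Window.<KV<String, Integer>>into(FixedWindows.of(Duration.standardMinutes(WINDOW_DURATION)))
+ *       .triggering(Repeatedly.forever(AfterWatermark.pastEndOfWindow()))
+ *       .withAllowedLateness(Duration.ZERO)
+ *       .discardingFiredPanes()
+ * }
+ *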
+ * + *

Before running this example, it will be useful to familiarize yourself with Dataflow triggers + * and to understand the concept of 'late data'. See + * https://cloud.google.com/dataflow/model/triggers and + * https://cloud.google.com/dataflow/model/windowing#Advanced + * + *

The example pipeline reads data from a Pub/Sub topic. By default, running the example will + * also run an auxiliary pipeline to inject data from the default {@code --input} file to the + * {@code --pubsubTopic}. The auxiliary pipeline puts a timestamp on the injected data so that the + * example pipeline can operate on event time (rather than arrival time). The auxiliary + * pipeline also randomly simulates late data by setting the timestamps of some of the data + * elements to be in the past. You may override the default {@code --input} with a file of your + * choosing, or set {@code --input=""}, which disables the automatic Pub/Sub injection and allows + * you to use a separate tool to publish to the given topic. + *
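+ * For reference, the injector's late-data simulation (see {@code InsertDelays} below) amounts to
+ * occasionally shifting an element's timestamp into the past before publishing, roughly:
+ * {@code
+ *   Instant timestamp = Instant.now();
+ *   // with a small probability, subtract a random delay of 1 to 100 minutes
+ *   timestamp = new Instant(timestamp.getMillis() - delayInMillis);
+ *   c.outputWithTimestamp(c.element(), timestamp);
+ * }
+ *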

The example is configured to use the default Pub/Sub topic and the default BigQuery table + * from the example common package (there are no defaults for a general Dataflow pipeline). + * You can override them by using the {@code --pubsubTopic}, {@code --bigQueryDataset}, and + * {@code --bigQueryTable} options. If the Pub/Sub topic or the BigQuery table do not exist, + * the example will try to create them. + * + *

The pipeline outputs its results to a BigQuery table. + * Here are some queries you can use to see interesting results: + * Replace {@code } in the queries below with the name of the BigQuery table. + * Replace {@code } in the queries below with the window interval. + * + *

To see the results of the default trigger: + * Note: when you start up your pipeline, you'll initially see results from 'late' data. Wait until + * after the window duration, when the first pane of non-late data has been emitted, to see more + * interesting results. + * {@code SELECT * FROM enter_table_name WHERE trigger_type = "default" ORDER BY window DESC} + * + *

To see the late data, i.e. the data dropped by the default trigger: + * {@code SELECT * FROM WHERE trigger_type = "withAllowedLateness" and + * (timing = "LATE" or timing = "ON_TIME") and freeway = "5" ORDER BY window DESC, processing_time} + *

To see the difference between accumulation mode and discarding mode: + * {@code SELECT * FROM WHERE (timing = "LATE" or timing = "ON_TIME") AND + * (trigger_type = "withAllowedLateness" or trigger_type = "sequential") and freeway = "5" ORDER BY + * window DESC, processing_time} + *

To see speculative results every minute, + * {@code SELECT * FROM WHERE trigger_type = "speculative" and freeway = "5" + * ORDER BY window DESC, processing_time} + * + *

To see speculative results every five minutes after the end of the window + * {@code SELECT * FROM WHERE trigger_type = "sequential" and timing != "EARLY" + * and freeway = "5" ORDER BY window DESC, processing_time} + * + *

To see the first and the last pane for a freeway in a window for all the trigger types, + * {@code SELECT * FROM WHERE (isFirst = true or isLast = true) ORDER BY window} + * + *

To reduce the number of results for each query, we can add additional WHERE clauses. + * For example, to see the results of the default trigger: + * {@code SELECT * FROM WHERE trigger_type = "default" AND freeway = "5" AND + * window = ""} + *

The example will try to cancel the pipelines on the signal to terminate the process (CTRL-C) + * and then exits. + */ + +public class TriggerExample { + //Numeric value of fixed window duration, in minutes + public static final int WINDOW_DURATION = 30; + // Constants used in triggers. + // Speeding up ONE_MINUTE or FIVE_MINUTES helps you get an early approximation of results. + // ONE_MINUTE is used only with processing time before the end of the window + public static final Duration ONE_MINUTE = Duration.standardMinutes(1); + // FIVE_MINUTES is used only with processing time after the end of the window + public static final Duration FIVE_MINUTES = Duration.standardMinutes(5); + // ONE_DAY is used to specify the amount of lateness allowed for the data elements. + public static final Duration ONE_DAY = Duration.standardDays(1); + + /** + * This transform demonstrates using triggers to control when data is produced for each window + * Consider an example to understand the results generated by each type of trigger. + * The example uses "freeway" as the key. Event time is the timestamp associated with the data + * element and processing time is the time when the data element gets processed in the pipeline. + * For freeway 5, suppose there are 10 elements in the [10:00:00, 10:30:00) window. + * Key (freeway) | Value (total_flow) | event time | processing time + * 5 | 50 | 10:00:03 | 10:00:47 + * 5 | 30 | 10:01:00 | 10:01:03 + * 5 | 30 | 10:02:00 | 11:07:00 + * 5 | 20 | 10:04:10 | 10:05:15 + * 5 | 60 | 10:05:00 | 11:03:00 + * 5 | 20 | 10:05:01 | 11.07:30 + * 5 | 60 | 10:15:00 | 10:27:15 + * 5 | 40 | 10:26:40 | 10:26:43 + * 5 | 60 | 10:27:20 | 10:27:25 + * 5 | 60 | 10:29:00 | 11:11:00 + * + *

Dataflow tracks a watermark, which records up to what point in event time the data is + * complete. For the purposes of the example, we'll assume the watermark is approximately 15m + * behind the current processing time. In practice, the actual value would vary over time based + * on the system's knowledge of the current PubSub delay and the contents of the backlog (data + * that has not yet been processed). + *

If the watermark is 15m behind, then the window [10:00:00, 10:30:00) (in event time) would + * close at 10:44:59, when the watermark passes 10:30:00. + */ + static class CalculateTotalFlow + extends PTransform >, PCollectionList> { + private int windowDuration; + + CalculateTotalFlow(int windowDuration) { + this.windowDuration = windowDuration; + } + + @Override + public PCollectionList apply(PCollection> flowInfo) { + + // Concept #1: The default triggering behavior + // By default Dataflow uses a trigger which fires when the watermark has passed the end of the + // window. This would be written {@code Repeatedly.forever(AfterWatermark.pastEndOfWindow())}. + + // The system also defaults to dropping late data -- data which arrives after the watermark + // has passed the event timestamp of the arriving element. This means that the default trigger + // will only fire once. + + // Each pane produced by the default trigger with no allowed lateness will be the first and + // last pane in the window, and will be ON_TIME. + + // The results for the example above with the default trigger and zero allowed lateness + // would be: + // Key (freeway) | Value (total_flow) | number_of_records | isFirst | isLast | timing + // 5 | 260 | 6 | true | true | ON_TIME + + // At 11:03:00 (processing time) the system watermark may have advanced to 10:54:00. As a + // result, when the data record with event time 10:05:00 arrives at 11:03:00, it is considered + // late, and dropped. + + PCollection defaultTriggerResults = flowInfo + .apply("Default", Window + // The default window duration values work well if you're running the default input + // file. You may want to adjust the window duration otherwise. + .>into(FixedWindows.of(Duration.standardMinutes(windowDuration))) + // The default trigger first emits output when the system's watermark passes the end + // of the window. + .triggering(Repeatedly.forever(AfterWatermark.pastEndOfWindow())) + // Late data is dropped + .withAllowedLateness(Duration.ZERO) + // Discard elements after emitting each pane. + // With no allowed lateness and the specified trigger there will only be a single + // pane, so this doesn't have a noticeable effect. See concept 2 for more details. + .discardingFiredPanes()) + .apply(new TotalFlow("default")); + + // Concept #2: Late data with the default trigger + // This uses the same trigger as concept #1, but allows data that is up to ONE_DAY late. This + // leads to each window staying open for ONE_DAY after the watermark has passed the end of the + // window. Any late data will result in an additional pane being fired for that same window. + + // The first pane produced will be ON_TIME and the remaining panes will be LATE. + // To definitely get the last pane when the window closes, use + // .withAllowedLateness(ONE_DAY, ClosingBehavior.FIRE_ALWAYS). 
+ + // The results for the example above with the default trigger and ONE_DAY allowed lateness + // would be: + // Key (freeway) | Value (total_flow) | number_of_records | isFirst | isLast | timing + // 5 | 260 | 6 | true | false | ON_TIME + // 5 | 60 | 1 | false | false | LATE + // 5 | 30 | 1 | false | false | LATE + // 5 | 20 | 1 | false | false | LATE + // 5 | 60 | 1 | false | false | LATE + PCollection withAllowedLatenessResults = flowInfo + .apply("WithLateData", Window + .>into(FixedWindows.of(Duration.standardMinutes(windowDuration))) + // Late data is emitted as it arrives + .triggering(Repeatedly.forever(AfterWatermark.pastEndOfWindow())) + // Once the output is produced, the pane is dropped and we start preparing the next + // pane for the window + .discardingFiredPanes() + // Late data is handled up to one day + .withAllowedLateness(ONE_DAY)) + .apply(new TotalFlow("withAllowedLateness")); + + // Concept #3: How to get speculative estimates + // We can specify a trigger that fires independent of the watermark, for instance after + // ONE_MINUTE of processing time. This allows us to produce speculative estimates before + // all the data is available. Since we don't have any triggers that depend on the watermark + // we don't get an ON_TIME firing. Instead, all panes are either EARLY or LATE. + + // We also use accumulatingFiredPanes to build up the results across each pane firing. + + // The results for the example above for this trigger would be: + // Key (freeway) | Value (total_flow) | number_of_records | isFirst | isLast | timing + // 5 | 80 | 2 | true | false | EARLY + // 5 | 100 | 3 | false | false | EARLY + // 5 | 260 | 6 | false | false | EARLY + // 5 | 320 | 7 | false | false | LATE + // 5 | 370 | 9 | false | false | LATE + // 5 | 430 | 10 | false | false | LATE + PCollection speculativeResults = flowInfo + .apply("Speculative" , Window + .>into(FixedWindows.of(Duration.standardMinutes(windowDuration))) + // Trigger fires every minute. + .triggering(Repeatedly.forever(AfterProcessingTime.pastFirstElementInPane() + // Speculative every ONE_MINUTE + .plusDelayOf(ONE_MINUTE))) + // After emitting each pane, it will continue accumulating the elements so that each + // approximation includes all of the previous data in addition to the newly arrived + // data. + .accumulatingFiredPanes() + .withAllowedLateness(ONE_DAY)) + .apply(new TotalFlow("speculative")); + + // Concept #4: Combining late data and speculative estimates + // We can put the previous concepts together to get EARLY estimates, an ON_TIME result, + // and LATE updates based on late data. + + // Each time a triggering condition is satisfied it advances to the next trigger. + // If there are new elements this trigger emits a window under following condition: + // > Early approximations every minute till the end of the window. + // > An on-time firing when the watermark has passed the end of the window + // > Every five minutes of late data. + + // Every pane produced will either be EARLY, ON_TIME or LATE. + + // The results for the example above for this trigger would be: + // Key (freeway) | Value (total_flow) | number_of_records | isFirst | isLast | timing + // 5 | 80 | 2 | true | false | EARLY + // 5 | 100 | 3 | false | false | EARLY + // 5 | 260 | 6 | false | false | EARLY + // [First pane fired after the end of the window] + // 5 | 320 | 7 | false | false | ON_TIME + // 5 | 430 | 10 | false | false | LATE + + // For more possibilities of how to build advanced triggers, see {@link Trigger}. 
+ PCollection sequentialResults = flowInfo + .apply("Sequential", Window + .>into(FixedWindows.of(Duration.standardMinutes(windowDuration))) + .triggering(AfterEach.inOrder( + Repeatedly.forever(AfterProcessingTime.pastFirstElementInPane() + // Speculative every ONE_MINUTE + .plusDelayOf(ONE_MINUTE)).orFinally(AfterWatermark.pastEndOfWindow()), + Repeatedly.forever(AfterProcessingTime.pastFirstElementInPane() + // Late data every FIVE_MINUTES + .plusDelayOf(FIVE_MINUTES)))) + .accumulatingFiredPanes() + // For up to ONE_DAY + .withAllowedLateness(ONE_DAY)) + .apply(new TotalFlow("sequential")); + + // Adds the results generated by each trigger type to a PCollectionList. + PCollectionList resultsList = PCollectionList.of(defaultTriggerResults) + .and(withAllowedLatenessResults) + .and(speculativeResults) + .and(sequentialResults); + + return resultsList; + } + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The remaining parts of the pipeline are needed to produce the output for each + // concept above. Not directly relevant to understanding the trigger examples. + + /** + * Calculate total flow and number of records for each freeway and format the results to TableRow + * objects, to save to BigQuery. + */ + static class TotalFlow extends + PTransform >, PCollection> { + private String triggerType; + + public TotalFlow(String triggerType) { + this.triggerType = triggerType; + } + + @Override + public PCollection apply(PCollection> flowInfo) { + PCollection>> flowPerFreeway = flowInfo + .apply(GroupByKey.create()); + + PCollection> results = flowPerFreeway.apply(ParDo.of( + new DoFn >, KV>() { + + @Override + public void processElement(ProcessContext c) throws Exception { + Iterable flows = c.element().getValue(); + Integer sum = 0; + Long numberOfRecords = 0L; + for (Integer value : flows) { + sum += value; + numberOfRecords++; + } + c.output(KV.of(c.element().getKey(), sum + "," + numberOfRecords)); + } + })); + PCollection output = results.apply(ParDo.of(new FormatTotalFlow(triggerType))); + return output; + } + } + + /** + * Format the results of the Total flow calculation to a TableRow, to save to BigQuery. + * Adds the triggerType, pane information, processing time and the window timestamp. + * */ + static class FormatTotalFlow extends DoFn, TableRow> + implements RequiresWindowAccess { + private String triggerType; + + public FormatTotalFlow(String triggerType) { + this.triggerType = triggerType; + } + @Override + public void processElement(ProcessContext c) throws Exception { + String[] values = c.element().getValue().split(","); + TableRow row = new TableRow() + .set("trigger_type", triggerType) + .set("freeway", c.element().getKey()) + .set("total_flow", Integer.parseInt(values[0])) + .set("number_of_records", Long.parseLong(values[1])) + .set("window", c.window().toString()) + .set("isFirst", c.pane().isFirst()) + .set("isLast", c.pane().isLast()) + .set("timing", c.pane().getTiming().toString()) + .set("event_time", c.timestamp().toString()) + .set("processing_time", Instant.now().toString()); + c.output(row); + } + } + + /** + * Extract the freeway and total flow in a reading. + * Freeway is used as key since we are calculating the total flow for each freeway. 
+ */ + static class ExtractFlowInfo extends DoFn> { + @Override + public void processElement(ProcessContext c) throws Exception { + String[] laneInfo = c.element().split(","); + if (laneInfo[0].equals("timestamp")) { + // Header row + return; + } + if (laneInfo.length < 48) { + //Skip the invalid input. + return; + } + String freeway = laneInfo[2]; + Integer totalFlow = tryIntegerParse(laneInfo[7]); + // Ignore the records with total flow 0 to easily understand the working of triggers. + // Skip the records with total flow -1 since they are invalid input. + if (totalFlow == null || totalFlow <= 0) { + return; + } + c.output(KV.of(freeway, totalFlow)); + } + } + + /** + * Inherits standard configuration options. + */ + public interface TrafficFlowOptions + extends ExamplePubsubTopicOptions, ExampleBigQueryTableOptions, DataflowExampleOptions { + + @Description("Input file to inject to Pub/Sub topic") + @Default.String("gs://dataflow-samples/traffic_sensor/" + + "Freeways-5Minaa2010-01-01_to_2010-02-15.csv") + String getInput(); + void setInput(String value); + + @Description("Numeric value of window duration for fixed windows, in minutes") + @Default.Integer(WINDOW_DURATION) + Integer getWindowDuration(); + void setWindowDuration(Integer value); + } + + private static final String PUBSUB_TIMESTAMP_LABEL_KEY = "timestamp_ms"; + + public static void main(String[] args) throws Exception { + TrafficFlowOptions options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(TrafficFlowOptions.class); + options.setStreaming(true); + + // In order to cancel the pipelines automatically, + // {@code DataflowPipelineRunner} is forced to be used. + options.setRunner(DataflowPipelineRunner.class); + options.setBigQuerySchema(getSchema()); + + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options); + dataflowUtils.setup(); + + Pipeline pipeline = Pipeline.create(options); + + TableReference tableRef = getTableReference(options.getProject(), + options.getBigQueryDataset(), options.getBigQueryTable()); + + PCollectionList resultList = pipeline.apply(PubsubIO.Read.named("ReadPubsubInput") + .timestampLabel(PUBSUB_TIMESTAMP_LABEL_KEY) + .topic(options.getPubsubTopic())) + .apply(ParDo.of(new ExtractFlowInfo())) + .apply(new CalculateTotalFlow(options.getWindowDuration())); + + for (int i = 0; i < resultList.size(); i++){ + resultList.get(i).apply(BigQueryIO.Write.to(tableRef).withSchema(getSchema())); + } + + PipelineResult result = pipeline.run(); + if (!options.getInput().isEmpty()){ + //Inject the data into the pubsub topic + dataflowUtils.runInjectorPipeline(runInjector(options)); + } + // dataflowUtils will try to cancel the pipeline and the injector before the program exits. 
+ dataflowUtils.waitToFinish(result); + } + + private static Pipeline runInjector(TrafficFlowOptions options){ + DataflowPipelineOptions copiedOptions = options.cloneAs(DataflowPipelineOptions.class); + copiedOptions.setStreaming(false); + copiedOptions.setNumWorkers(options.as(DataflowExampleOptions.class).getInjectorNumWorkers()); + copiedOptions.setJobName(options.getJobName() + "-injector"); + Pipeline injectorPipeline = Pipeline.create(copiedOptions); + injectorPipeline + .apply(TextIO.Read.named("ReadMyFile").from(options.getInput())) + .apply(ParDo.named("InsertRandomDelays").of(new InsertDelays())) + .apply(IntraBundleParallelization.of(PubsubFileInjector + .withTimestampLabelKey(PUBSUB_TIMESTAMP_LABEL_KEY) + .publish(options.getPubsubTopic())) + .withMaxParallelism(20)); + + return injectorPipeline; + } + + /** + * Add current time to each record. + * Also insert a delay at random to demo the triggers. + */ + public static class InsertDelays extends DoFn { + private static final double THRESHOLD = 0.001; + // MIN_DELAY and MAX_DELAY in minutes. + private static final int MIN_DELAY = 1; + private static final int MAX_DELAY = 100; + + @Override + public void processElement(ProcessContext c) throws Exception { + Instant timestamp = Instant.now(); + if (Math.random() < THRESHOLD){ + int range = MAX_DELAY - MIN_DELAY; + int delayInMinutes = (int) (Math.random() * range) + MIN_DELAY; + long delayInMillis = TimeUnit.MINUTES.toMillis(delayInMinutes); + timestamp = new Instant(timestamp.getMillis() - delayInMillis); + } + c.outputWithTimestamp(c.element(), timestamp); + } + } + + + /**Sets the table reference. **/ + private static TableReference getTableReference(String project, String dataset, String table){ + TableReference tableRef = new TableReference(); + tableRef.setProjectId(project); + tableRef.setDatasetId(dataset); + tableRef.setTableId(table); + return tableRef; + } + + /** Defines the BigQuery schema used for the output. */ + private static TableSchema getSchema() { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("trigger_type").setType("STRING")); + fields.add(new TableFieldSchema().setName("freeway").setType("STRING")); + fields.add(new TableFieldSchema().setName("total_flow").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("number_of_records").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("window").setType("STRING")); + fields.add(new TableFieldSchema().setName("isFirst").setType("BOOLEAN")); + fields.add(new TableFieldSchema().setName("isLast").setType("BOOLEAN")); + fields.add(new TableFieldSchema().setName("timing").setType("STRING")); + fields.add(new TableFieldSchema().setName("event_time").setType("TIMESTAMP")); + fields.add(new TableFieldSchema().setName("processing_time").setType("TIMESTAMP")); + TableSchema schema = new TableSchema().setFields(fields); + return schema; + } + + private static Integer tryIntegerParse(String number) { + try { + return Integer.parseInt(number); + } catch (NumberFormatException e) { + return null; + } + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/MinimalWordCountJava8.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/MinimalWordCountJava8.java new file mode 100644 index 000000000000..c115ea0a3336 --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/MinimalWordCountJava8.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.FlatMapElements; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import java.util.Arrays; + +/** + * An example that counts words in Shakespeare, using Java 8 language features. + * + *

See {@link MinimalWordCount} for a comprehensive explanation. + */ +public class MinimalWordCountJava8 { + + public static void main(String[] args) { + DataflowPipelineOptions options = PipelineOptionsFactory.create() + .as(DataflowPipelineOptions.class); + + options.setRunner(BlockingDataflowPipelineRunner.class); + + // CHANGE 1 of 3: Your project ID is required in order to run your pipeline on the Google Cloud. + options.setProject("SET_YOUR_PROJECT_ID_HERE"); + + // CHANGE 2 of 3: Your Google Cloud Storage path is required for staging local files. + options.setStagingLocation("gs://SET_YOUR_BUCKET_NAME_HERE/AND_STAGING_DIRECTORY"); + + Pipeline p = Pipeline.create(options); + + p.apply(TextIO.Read.from("gs://dataflow-samples/shakespeare/*")) + .apply(FlatMapElements.via((String word) -> Arrays.asList(word.split("[^a-zA-Z']+"))) + .withOutputType(new TypeDescriptor() {})) + .apply(Filter.byPredicate((String word) -> !word.isEmpty())) + .apply(Count.perElement()) + .apply(MapElements + .via((KV wordCount) -> wordCount.getKey() + ": " + wordCount.getValue()) + .withOutputType(new TypeDescriptor() {})) + + // CHANGE 3 of 3: The Google Cloud Storage path is required for outputting the results to. + .apply(TextIO.Write.to("gs://YOUR_OUTPUT_BUCKET/AND_OUTPUT_PREFIX")); + + p.run(); + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/GameStats.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/GameStats.java new file mode 100644 index 000000000000..39d7a760c0a7 --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/GameStats.java @@ -0,0 +1,347 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.complete.game.utils.WriteWindowedToBigQuery; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.Mean; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.Values; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.joda.time.DateTimeZone; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.TimeZone; + +/** + * This class is the fourth in a series of four pipelines that tell a story in a 'gaming' + * domain, following {@link UserScore}, {@link HourlyTeamScore}, and {@link LeaderBoard}. + * New concepts: session windows and finding session duration; use of both + * singleton and non-singleton side inputs. + * + *

This pipeline builds on the {@link LeaderBoard} functionality, and adds some "business + * intelligence" analysis: abuse detection and usage patterns. The pipeline derives the mean user + * score sum for a window, and uses that information to identify likely spammers/robots. (The robots + * have a higher click rate than the human users.) The 'robot' users are then filtered out when + * calculating the team scores. + *
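+ * A sketch of the abuse-detection step (the complete version is {@code CalculateSpammyUsers}
+ * below): the per-user score sums are compared against a global mean supplied as a singleton
+ * side input.
+ * {@code
+ *   PCollection<KV<String, Integer>> sumScores =
+ *       userScores.apply("UserSum", Sum.integersPerKey());
+ *   final PCollectionView<Double> globalMeanScore =
+ *       sumScores.apply(Values.create()).apply(Mean.<Integer>globally().asSingletonView());
+ *   // a ParDo taking globalMeanScore as a side input then keeps only the users whose
+ *   // summed score exceeds (mean * SCORE_WEIGHT)
+ * }
+ *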

Additionally, user sessions are tracked: that is, we find bursts of user activity using + * session windows. Then, the mean session duration information is recorded in the context of + * subsequent fixed windowing. (This could be used to tell us which games are giving us greater + * user retention.) + *
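+ * A sketch of the session-analysis step (durations come from the pipeline options defined
+ * below; the full version is in the pipeline body):
+ * {@code
+ *   userEvents
+ *       .apply(Window.named("WindowIntoSessions")
+ *           .<KV<String, Integer>>into(
+ *               Sessions.withGapDuration(Duration.standardMinutes(options.getSessionGap())))
+ *           .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow()))
+ *       .apply(Combine.perKey(x -> 0))             // one element per user session
+ *       .apply(ParDo.of(new UserSessionInfoFn()))  // session duration, in minutes
+ *       .apply(Window.named("WindowToExtractSessionMean")
+ *           .into(FixedWindows.of(
+ *               Duration.standardMinutes(options.getUserActivityWindowDuration()))))
+ *       .apply(Mean.globally().withoutDefaults()); // mean session length per window
+ * }
+ *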

Run {@link injector.Injector} to generate pubsub data for this pipeline. The Injector + * documentation provides more detail. + * + *

To execute this pipeline using the Dataflow service, specify the pipeline configuration + * like this: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ *   --dataset=YOUR-DATASET
+ *   --topic=projects/YOUR-PROJECT/topics/YOUR-TOPIC
+ * }
+ * 
+ * where the BigQuery dataset you specify must already exist. The PubSub topic you specify should + * be the same topic to which the Injector is publishing. + */ +public class GameStats extends LeaderBoard { + + private static final String TIMESTAMP_ATTRIBUTE = "timestamp_ms"; + private static final Logger LOG = LoggerFactory.getLogger(GameStats.class); + + private static DateTimeFormatter fmt = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS") + .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))); + static final Duration FIVE_MINUTES = Duration.standardMinutes(5); + static final Duration TEN_MINUTES = Duration.standardMinutes(10); + + + /** + * Filter out all but those users with a high clickrate, which we will consider as 'spammy' uesrs. + * We do this by finding the mean total score per user, then using that information as a side + * input to filter out all but those user scores that are > (mean * SCORE_WEIGHT) + */ + // [START DocInclude_AbuseDetect] + public static class CalculateSpammyUsers + extends PTransform>, PCollection>> { + private static final Logger LOG = LoggerFactory.getLogger(CalculateSpammyUsers.class); + private static final double SCORE_WEIGHT = 2.5; + + @Override + public PCollection> apply(PCollection> userScores) { + + // Get the sum of scores for each user. + PCollection> sumScores = userScores + .apply("UserSum", Sum.integersPerKey()); + + // Extract the score from each element, and use it to find the global mean. + final PCollectionView globalMeanScore = sumScores.apply(Values.create()) + .apply(Mean.globally().asSingletonView()); + + // Filter the user sums using the global mean. + PCollection> filtered = sumScores + .apply(ParDo + .named("ProcessAndFilter") + // use the derived mean total score as a side input + .withSideInputs(globalMeanScore) + .of(new DoFn, KV>() { + private final Aggregator numSpammerUsers = + createAggregator("SpammerUsers", new Sum.SumLongFn()); + @Override + public void processElement(ProcessContext c) { + Integer score = c.element().getValue(); + Double gmc = c.sideInput(globalMeanScore); + if (score > (gmc * SCORE_WEIGHT)) { + LOG.info("user " + c.element().getKey() + " spammer score " + score + + " with mean " + gmc); + numSpammerUsers.addValue(1L); + c.output(c.element()); + } + } + })); + return filtered; + } + } + // [END DocInclude_AbuseDetect] + + /** + * Calculate and output an element's session duration. + */ + private static class UserSessionInfoFn extends DoFn, Integer> + implements RequiresWindowAccess { + + @Override + public void processElement(ProcessContext c) { + IntervalWindow w = (IntervalWindow) c.window(); + int duration = new Duration( + w.start(), w.end()).toPeriod().toStandardMinutes().getMinutes(); + c.output(duration); + } + } + + + /** + * Options supported by {@link GameStats}. 
+ */ + static interface Options extends LeaderBoard.Options { + @Description("Pub/Sub topic to read from") + @Validation.Required + String getTopic(); + void setTopic(String value); + + @Description("Numeric value of fixed window duration for user analysis, in minutes") + @Default.Integer(60) + Integer getFixedWindowDuration(); + void setFixedWindowDuration(Integer value); + + @Description("Numeric value of gap between user sessions, in minutes") + @Default.Integer(5) + Integer getSessionGap(); + void setSessionGap(Integer value); + + @Description("Numeric value of fixed window for finding mean of user session duration, " + + "in minutes") + @Default.Integer(30) + Integer getUserActivityWindowDuration(); + void setUserActivityWindowDuration(Integer value); + + @Description("Prefix used for the BigQuery table names") + @Default.String("game_stats") + String getTablePrefix(); + void setTablePrefix(String value); + } + + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is used to write information about team score sums. + */ + protected static Map>> + configureWindowedWrite() { + Map>> tableConfigure = + new HashMap>>(); + tableConfigure.put("team", + new WriteWindowedToBigQuery.FieldInfo>("STRING", + c -> c.element().getKey())); + tableConfigure.put("total_score", + new WriteWindowedToBigQuery.FieldInfo>("INTEGER", + c -> c.element().getValue())); + tableConfigure.put("window_start", + new WriteWindowedToBigQuery.FieldInfo>("STRING", + c -> { IntervalWindow w = (IntervalWindow) c.window(); + return fmt.print(w.start()); })); + tableConfigure.put("processing_time", + new WriteWindowedToBigQuery.FieldInfo>( + "STRING", c -> fmt.print(Instant.now()))); + return tableConfigure; + } + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is used to write information about mean user session time. + */ + protected static Map> + configureSessionWindowWrite() { + + Map> tableConfigure = + new HashMap>(); + tableConfigure.put("window_start", + new WriteWindowedToBigQuery.FieldInfo("STRING", + c -> { IntervalWindow w = (IntervalWindow) c.window(); + return fmt.print(w.start()); })); + tableConfigure.put("mean_duration", + new WriteWindowedToBigQuery.FieldInfo("FLOAT", c -> c.element())); + return tableConfigure; + } + + + + public static void main(String[] args) throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + // Enforce that this pipeline is always run in streaming mode. + options.setStreaming(true); + // Allow the pipeline to be cancelled automatically. + options.setRunner(DataflowPipelineRunner.class); + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options); + Pipeline pipeline = Pipeline.create(options); + + // Read Events from Pub/Sub using custom timestamps + PCollection rawEvents = pipeline + .apply(PubsubIO.Read.timestampLabel(TIMESTAMP_ATTRIBUTE).topic(options.getTopic())) + .apply(ParDo.named("ParseGameEvent").of(new ParseEventFn())); + + // Extract username/score pairs from the event stream + PCollection> userEvents = + rawEvents.apply("ExtractUserScore", + MapElements.via((GameActionInfo gInfo) -> KV.of(gInfo.getUser(), gInfo.getScore())) + .withOutputType(new TypeDescriptor>() {})); + + // Calculate the total score per user over fixed windows, and + // cumulative updates for late data. 
+ final PCollectionView> spammersView = userEvents + .apply(Window.named("FixedWindowsUser") + .>into(FixedWindows.of( + Duration.standardMinutes(options.getFixedWindowDuration()))) + ) + + // Filter out everyone but those with (SCORE_WEIGHT * avg) clickrate. + // These might be robots/spammers. + .apply("CalculateSpammyUsers", new CalculateSpammyUsers()) + // Derive a view from the collection of spammer users. It will be used as a side input + // in calculating the team score sums, below. + .apply("CreateSpammersView", View.asMap()); + + // [START DocInclude_FilterAndCalc] + // Calculate the total score per team over fixed windows, + // and emit cumulative updates for late data. Uses the side input derived above-- the set of + // suspected robots-- to filter out scores from those users from the sum. + // Write the results to BigQuery. + rawEvents + .apply(Window.named("WindowIntoFixedWindows") + .into(FixedWindows.of( + Duration.standardMinutes(options.getFixedWindowDuration()))) + ) + // Filter out the detected spammer users, using the side input derived above. + .apply(ParDo.named("FilterOutSpammers") + .withSideInputs(spammersView) + .of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + // If the user is not in the spammers Map, output the data element. + if (c.sideInput(spammersView).get(c.element().getUser().trim()) == null) { + c.output(c.element()); + }}})) + // Extract and sum teamname/score pairs from the event data. + .apply("ExtractTeamScore", new ExtractAndSumScore("team")) + // [END DocInclude_FilterAndCalc] + // Write the result to BigQuery + .apply("WriteTeamSums", + new WriteWindowedToBigQuery>( + options.getTablePrefix() + "_team", configureWindowedWrite())); + + + // [START DocInclude_SessionCalc] + // Detect user sessions-- that is, a burst of activity separated by a gap from further + // activity. Find and record the mean session lengths. + // This information could help the game designers track the changing user engagement + // as their set of games changes. + userEvents + .apply(Window.named("WindowIntoSessions") + .>into( + Sessions.withGapDuration(Duration.standardMinutes(options.getSessionGap()))) + .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow())) + // For this use, we care only about the existence of the session, not any particular + // information aggregated over it, so the following is an efficient way to do that. + .apply(Combine.perKey(x -> 0)) + // Get the duration per session. + .apply("UserSessionActivity", ParDo.of(new UserSessionInfoFn())) + // [END DocInclude_SessionCalc] + // [START DocInclude_Rewindow] + // Re-window to process groups of session sums according to when the sessions complete. + .apply(Window.named("WindowToExtractSessionMean") + .into( + FixedWindows.of(Duration.standardMinutes(options.getUserActivityWindowDuration())))) + // Find the mean session duration in each window. + .apply(Mean.globally().withoutDefaults()) + // Write this info to a BigQuery table. + .apply("WriteAvgSessionLength", + new WriteWindowedToBigQuery( + options.getTablePrefix() + "_sessions", configureSessionWindowWrite())); + // [END DocInclude_Rewindow] + + + // Run the pipeline and wait for the pipeline to finish; capture cancellation requests from the + // command line. 
+ PipelineResult result = pipeline.run(); + dataflowUtils.waitToFinish(result); + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/HourlyTeamScore.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/HourlyTeamScore.java new file mode 100644 index 000000000000..481b9df35b1d --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/HourlyTeamScore.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.cloud.dataflow.examples.complete.game.utils.WriteWindowedToBigQuery; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.WithTimestamps; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.joda.time.DateTimeZone; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +import java.util.HashMap; +import java.util.Map; +import java.util.TimeZone; + +/** + * This class is the second in a series of four pipelines that tell a story in a 'gaming' + * domain, following {@link UserScore}. In addition to the concepts introduced in {@link UserScore}, + * new concepts include: windowing and element timestamps; use of {@code Filter.byPredicate()}. + * + *

This pipeline processes data collected from gaming events in batch, building on {@link + * UserScore} but using fixed windows. It calculates the sum of scores per team, for each window, + * optionally allowing specification of two timestamps before and after which data is filtered out. + * This allows a model where late data collected after the intended analysis window can be included, + * and any late-arriving data prior to the beginning of the analysis window can be removed as well. + * By using windowing and adding element timestamps, we can do finer-grained analysis than with the + * {@link UserScore} pipeline. However, our batch processing is high-latency, in that we don't get + * results from plays at the beginning of the batch's time period until the batch is processed. + * + *
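+ * The key new step, shown in full in the pipeline below, is to attach an event timestamp to each
+ * parsed element and then window by that timestamp:
+ * {@code
+ *   // ...after reading and parsing the 'gaming' events:
+ *   .apply("AddEventTimestamps",
+ *       WithTimestamps.of((GameActionInfo i) -> new Instant(i.getTimestamp())))
+ *   .apply(Window.named("FixedWindowsTeam")
+ *       .into(FixedWindows.of(Duration.standardMinutes(options.getWindowDuration()))))
+ * }
+ *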

To execute this pipeline using the Dataflow service, specify the pipeline configuration + * like this: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ *   --dataset=YOUR-DATASET
+ * }
+ * 
+ * where the BigQuery dataset you specify must already exist. + * + *

Optionally include {@code --input} to specify the batch input file path. + * To indicate a time after which the data should be filtered out, include the + * {@code --stopMin} arg. E.g., {@code --stopMin=2015-10-18-23-59} indicates that any data + * timestamped after 23:59 PST on 2015-10-18 should not be included in the analysis. + * To indicate a time before which data should be filtered out, include the {@code --startMin} arg. + * If you're using the default input specified in {@link UserScore}, + * "gs://dataflow-samples/game/gaming_data*.csv", then + * {@code --startMin=2015-11-16-16-10 --stopMin=2015-11-17-16-10} are good values. + */ +public class HourlyTeamScore extends UserScore { + + private static DateTimeFormatter fmt = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS") + .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))); + private static DateTimeFormatter minFmt = + DateTimeFormat.forPattern("yyyy-MM-dd-HH-mm") + .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))); + + + /** + * Options supported by {@link HourlyTeamScore}. + */ + static interface Options extends UserScore.Options { + + @Description("Numeric value of fixed window duration, in minutes") + @Default.Integer(60) + Integer getWindowDuration(); + void setWindowDuration(Integer value); + + @Description("String representation of the first minute after which to generate results," + + "in the format: yyyy-MM-dd-HH-mm . This time should be in PST." + + "Any input data timestamped prior to that minute won't be included in the sums.") + @Default.String("1970-01-01-00-00") + String getStartMin(); + void setStartMin(String value); + + @Description("String representation of the first minute for which to not generate results," + + "in the format: yyyy-MM-dd-HH-mm . This time should be in PST." + + "Any input data timestamped after that minute won't be included in the sums.") + @Default.String("2100-01-01-00-00") + String getStopMin(); + void setStopMin(String value); + + @Description("The BigQuery table name. Should not already exist.") + @Default.String("hourly_team_score") + String getTableName(); + void setTableName(String value); + } + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is passed to the {@link WriteWindowedToBigQuery} constructor to write team score sums and + * includes information about window start time. + */ + protected static Map>> + configureWindowedTableWrite() { + Map>> tableConfig = + new HashMap>>(); + tableConfig.put("team", + new WriteWindowedToBigQuery.FieldInfo>("STRING", + c -> c.element().getKey())); + tableConfig.put("total_score", + new WriteWindowedToBigQuery.FieldInfo>("INTEGER", + c -> c.element().getValue())); + tableConfig.put("window_start", + new WriteWindowedToBigQuery.FieldInfo>("STRING", + c -> { IntervalWindow w = (IntervalWindow) c.window(); + return fmt.print(w.start()); })); + return tableConfig; + } + + + /** + * Run a batch pipeline to do windowed analysis of the data. + */ + // [START DocInclude_HTSMain] + public static void main(String[] args) throws Exception { + // Begin constructing a pipeline configured by commandline flags. 
+ Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline pipeline = Pipeline.create(options); + + final Instant stopMinTimestamp = new Instant(minFmt.parseMillis(options.getStopMin())); + final Instant startMinTimestamp = new Instant(minFmt.parseMillis(options.getStartMin())); + + // Read 'gaming' events from a text file. + pipeline.apply(TextIO.Read.from(options.getInput())) + // Parse the incoming data. + .apply(ParDo.named("ParseGameEvent").of(new ParseEventFn())) + + // Filter out data before and after the given times so that it is not included + // in the calculations. As we collect data in batches (say, by day), the batch for the day + // that we want to analyze could potentially include some late-arriving data from the previous + // day. If so, we want to weed it out. Similarly, if we include data from the following day + // (to scoop up late-arriving events from the day we're analyzing), we need to weed out events + // that fall after the time period we want to analyze. + // [START DocInclude_HTSFilters] + .apply("FilterStartTime", Filter.byPredicate( + (GameActionInfo gInfo) + -> gInfo.getTimestamp() > startMinTimestamp.getMillis())) + .apply("FilterEndTime", Filter.byPredicate( + (GameActionInfo gInfo) + -> gInfo.getTimestamp() < stopMinTimestamp.getMillis())) + // [END DocInclude_HTSFilters] + + // [START DocInclude_HTSAddTsAndWindow] + // Add an element timestamp based on the event log, and apply fixed windowing. + .apply("AddEventTimestamps", + WithTimestamps.of((GameActionInfo i) -> new Instant(i.getTimestamp()))) + .apply(Window.named("FixedWindowsTeam") + .into(FixedWindows.of( + Duration.standardMinutes(options.getWindowDuration())))) + // [END DocInclude_HTSAddTsAndWindow] + + // Extract and sum teamname/score pairs from the event data. + .apply("ExtractTeamScore", new ExtractAndSumScore("team")) + .apply("WriteTeamScoreSums", + new WriteWindowedToBigQuery>(options.getTableName(), + configureWindowedTableWrite())); + + + pipeline.run(); + } + // [END DocInclude_HTSMain] + +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/LeaderBoard.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/LeaderBoard.java new file mode 100644 index 000000000000..41853768680b --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/LeaderBoard.java @@ -0,0 +1,237 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.cloud.dataflow.examples.common.DataflowExampleOptions; +import com.google.cloud.dataflow.examples.common.DataflowExampleUtils; +import com.google.cloud.dataflow.examples.complete.game.utils.WriteToBigQuery; +import com.google.cloud.dataflow.examples.complete.game.utils.WriteWindowedToBigQuery; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Repeatedly; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.DateTimeZone; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +import java.util.HashMap; +import java.util.Map; +import java.util.TimeZone; + +/** + * This class is the third in a series of four pipelines that tell a story in a 'gaming' domain, + * following {@link UserScore} and {@link HourlyTeamScore}. Concepts include: processing unbounded + * data using fixed windows; use of custom timestamps and event-time processing; generation of + * early/speculative results; using .accumulatingFiredPanes() to do cumulative processing of late- + * arriving data. + * + *

This pipeline processes an unbounded stream of 'game events'. The calculation of the team + * scores uses fixed windowing based on event time (the time of the game play event), not + * processing time (the time that an event is processed by the pipeline). The pipeline calculates + * the sum of scores per team, for each window. By default, the team scores are calculated using + * one-hour windows. + * + *

In contrast, to demo another windowing option, the user scores are calculated using a
+ * global window, which periodically (every ten minutes) emits cumulative user score sums.
+ *
+ *
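For reference, the global-window behavior described above is built from the SDK's GlobalWindows, Repeatedly, and AfterProcessingTime primitives; the full wiring appears in the pipeline code later in this file. A condensed, illustrative sketch under the same assumptions (the class and method names here are invented for illustration only):

```java
import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime;
import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows;
import com.google.cloud.dataflow.sdk.transforms.windowing.Repeatedly;
import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;

import org.joda.time.Duration;

/** Illustrative sketch: periodic cumulative output over a global window. */
class GlobalWindowSketch {
  static PCollection<KV<String, Integer>> windowForPeriodicUpdates(
      PCollection<KV<String, Integer>> userScores, Duration allowedLateness) {
    return userScores.apply(
        Window.<KV<String, Integer>>into(new GlobalWindows())
            // Emit a new pane ten minutes after the first element of each pane arrives.
            .triggering(Repeatedly.forever(AfterProcessingTime.pastFirstElementInPane()
                .plusDelayOf(Duration.standardMinutes(10))))
            // Each pane contains everything seen so far, so the sums are cumulative.
            .accumulatingFiredPanes()
            .withAllowedLateness(allowedLateness));
  }
}
```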

In contrast to the previous pipelines in the series, which used static, finite input data,
+ * here we're using an unbounded data source, which lets us provide speculative results and
+ * handle late data at much lower latency. We can use the early/speculative results to keep a
+ * 'leaderboard' updated in near-realtime. Our handling of late data lets us generate correct
+ * results, e.g. for 'team prizes'. We're now outputting window results as they're
+ * calculated, giving us much lower latency than with the previous batch examples.
+ *
+ *
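The speculative and late results come from the trigger attached to the fixed windows, shown in full later in this file. A condensed sketch, assuming hour-long windows and the same five- and ten-minute firing delays the pipeline uses (class and method names are illustrative):

```java
import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime;
import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark;
import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;

import org.joda.time.Duration;

/** Illustrative sketch: fixed windows with early (speculative) and late firings. */
class SpeculativeWindowSketch {
  static PCollection<KV<String, Integer>> windowWithSpeculativeResults(
      PCollection<KV<String, Integer>> teamScores, Duration allowedLateness) {
    return teamScores.apply(
        Window.<KV<String, Integer>>into(FixedWindows.of(Duration.standardHours(1)))
            .triggering(AfterWatermark.pastEndOfWindow()
                // A speculative pane five minutes after the first element arrives...
                .withEarlyFirings(AfterProcessingTime.pastFirstElementInPane()
                    .plusDelayOf(Duration.standardMinutes(5)))
                // ...and late panes for data that arrives behind the watermark.
                .withLateFirings(AfterProcessingTime.pastFirstElementInPane()
                    .plusDelayOf(Duration.standardMinutes(10))))
            .withAllowedLateness(allowedLateness)
            // Late panes include everything seen so far, so results stay cumulative.
            .accumulatingFiredPanes());
  }
}
```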

Run {@link injector.Injector} to generate pubsub data for this pipeline. The Injector + * documentation provides more detail on how to do this. + * + *

To execute this pipeline using the Dataflow service, specify the pipeline configuration + * like this: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ *   --dataset=YOUR-DATASET
+ *   --topic=projects/YOUR-PROJECT/topics/YOUR-TOPIC
+ * }
+ * 
+ * where the BigQuery dataset you specify must already exist. + * The PubSub topic you specify should be the same topic to which the Injector is publishing. + */ +public class LeaderBoard extends HourlyTeamScore { + + private static final String TIMESTAMP_ATTRIBUTE = "timestamp_ms"; + + private static DateTimeFormatter fmt = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS") + .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))); + static final Duration FIVE_MINUTES = Duration.standardMinutes(5); + static final Duration TEN_MINUTES = Duration.standardMinutes(10); + + + /** + * Options supported by {@link LeaderBoard}. + */ + static interface Options extends HourlyTeamScore.Options, DataflowExampleOptions { + + @Description("Pub/Sub topic to read from") + @Validation.Required + String getTopic(); + void setTopic(String value); + + @Description("Numeric value of fixed window duration for team analysis, in minutes") + @Default.Integer(60) + Integer getTeamWindowDuration(); + void setTeamWindowDuration(Integer value); + + @Description("Numeric value of allowed data lateness, in minutes") + @Default.Integer(120) + Integer getAllowedLateness(); + void setAllowedLateness(Integer value); + + @Description("Prefix used for the BigQuery table names") + @Default.String("leaderboard") + String getTableName(); + void setTableName(String value); + } + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is used to write team score sums and includes event timing information. + */ + protected static Map>> + configureWindowedTableWrite() { + + Map>> tableConfigure = + new HashMap>>(); + tableConfigure.put("team", + new WriteWindowedToBigQuery.FieldInfo>("STRING", + c -> c.element().getKey())); + tableConfigure.put("total_score", + new WriteWindowedToBigQuery.FieldInfo>("INTEGER", + c -> c.element().getValue())); + tableConfigure.put("window_start", + new WriteWindowedToBigQuery.FieldInfo>("STRING", + c -> { IntervalWindow w = (IntervalWindow) c.window(); + return fmt.print(w.start()); })); + tableConfigure.put("processing_time", + new WriteWindowedToBigQuery.FieldInfo>( + "STRING", c -> fmt.print(Instant.now()))); + tableConfigure.put("timing", + new WriteWindowedToBigQuery.FieldInfo>( + "STRING", c -> c.pane().getTiming().toString())); + return tableConfigure; + } + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is used to write user score sums. + */ + protected static Map>> + configureGlobalWindowBigQueryWrite() { + + Map>> tableConfigure = + configureBigQueryWrite(); + tableConfigure.put("processing_time", + new WriteToBigQuery.FieldInfo>( + "STRING", c -> fmt.print(Instant.now()))); + return tableConfigure; + } + + + public static void main(String[] args) throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + // Enforce that this pipeline is always run in streaming mode. + options.setStreaming(true); + // For example purposes, allow the pipeline to be easily cancelled instead of running + // continuously. + options.setRunner(DataflowPipelineRunner.class); + DataflowExampleUtils dataflowUtils = new DataflowExampleUtils(options); + Pipeline pipeline = Pipeline.create(options); + + // Read game events from Pub/Sub using custom timestamps, which are extracted from the pubsub + // data elements, and parse the data. 
+ PCollection gameEvents = pipeline + .apply(PubsubIO.Read.timestampLabel(TIMESTAMP_ATTRIBUTE).topic(options.getTopic())) + .apply(ParDo.named("ParseGameEvent").of(new ParseEventFn())); + + // [START DocInclude_WindowAndTrigger] + // Extract team/score pairs from the event stream, using hour-long windows by default. + gameEvents + .apply(Window.named("LeaderboardTeamFixedWindows") + .into(FixedWindows.of( + Duration.standardMinutes(options.getTeamWindowDuration()))) + // We will get early (speculative) results as well as cumulative + // processing of late data. + .triggering( + AfterWatermark.pastEndOfWindow() + .withEarlyFirings(AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(FIVE_MINUTES)) + .withLateFirings(AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(TEN_MINUTES))) + .withAllowedLateness(Duration.standardMinutes(options.getAllowedLateness())) + .accumulatingFiredPanes()) + // Extract and sum teamname/score pairs from the event data. + .apply("ExtractTeamScore", new ExtractAndSumScore("team")) + // Write the results to BigQuery. + .apply("WriteTeamScoreSums", + new WriteWindowedToBigQuery>( + options.getTableName() + "_team", configureWindowedTableWrite())); + // [END DocInclude_WindowAndTrigger] + + // [START DocInclude_ProcTimeTrigger] + // Extract user/score pairs from the event stream using processing time, via global windowing. + // Get periodic updates on all users' running scores. + gameEvents + .apply(Window.named("LeaderboardUserGlobalWindow") + .into(new GlobalWindows()) + // Get periodic results every ten minutes. + .triggering(Repeatedly.forever(AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(TEN_MINUTES))) + .accumulatingFiredPanes() + .withAllowedLateness(Duration.standardMinutes(options.getAllowedLateness()))) + // Extract and sum username/score pairs from the event data. + .apply("ExtractUserScore", new ExtractAndSumScore("user")) + // Write the results to BigQuery. + .apply("WriteUserScoreSums", + new WriteToBigQuery>( + options.getTableName() + "_user", configureGlobalWindowBigQueryWrite())); + // [END DocInclude_ProcTimeTrigger] + + // Run the pipeline and wait for the pipeline to finish; capture cancellation requests from the + // command line. + PipelineResult result = pipeline.run(); + dataflowUtils.waitToFinish(result); + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/README.md b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/README.md new file mode 100644 index 000000000000..4cad16d5f59f --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/README.md @@ -0,0 +1,119 @@ + +# 'Gaming' examples + + +This directory holds a series of example Dataflow pipelines in a simple 'mobile +gaming' domain. They all require Java 8. Each pipeline successively introduces +new concepts, and gives some examples of using Java 8 syntax in constructing +Dataflow pipelines. Other than usage of Java 8 lambda expressions, the concepts +that are used apply equally well in Java 7. + +In the gaming scenario, many users play, as members of different teams, over +the course of a day, and their actions are logged for processing. Some of the +logged game events may be late-arriving, if users play on mobile devices and go +transiently offline for a period. + +The scenario includes not only "regular" users, but "robot users", which have a +higher click rate than the regular users, and may move from team to team. 
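Each logged event is a CSV line of the form `username,teamname,score,timestamp_in_ms,readable_time` (see the `UserScore` and `Injector` javadocs). The following is a minimal, illustrative sketch of what parsing one such line involves; the pipelines do this inside their `ParseEventFn`, and the class name below is invented for illustration:

```java
/** Illustrative only: the fields carried by one gaming event line. */
class ParsedEvent {
  final String user;
  final String team;
  final int score;
  final long timestampMs;

  ParsedEvent(String user, String team, int score, long timestampMs) {
    this.user = user;
    this.team = team;
    this.score = score;
    this.timestampMs = timestampMs;
  }

  /** Parse a line like: user2_AsparagusPig,AsparagusPig,10,1445230923951,2015-11-02 09:09:28.224 */
  static ParsedEvent parse(String line) {
    String[] parts = line.split(",");
    // The trailing human-readable time string is redundant and ignored, as in ParseEventFn.
    return new ParsedEvent(parts[0].trim(), parts[1].trim(),
        Integer.parseInt(parts[2].trim()), Long.parseLong(parts[3].trim()));
  }
}
```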
+ +The first two pipelines in the series use pre-generated batch data samples. The +second two pipelines read from a [PubSub](https://cloud.google.com/pubsub/) +topic input. For these examples, you will also need to run the +`injector.Injector` program, which generates and publishes the gaming data to +PubSub. The javadocs for each pipeline have more detailed information on how to +run that pipeline. + +All of these pipelines write their results to BigQuery table(s). + + +## The pipelines in the 'gaming' series + +### UserScore + +The first pipeline in the series is `UserScore`. This pipeline does batch +processing of data collected from gaming events. It calculates the sum of +scores per user, over an entire batch of gaming data (collected, say, for each +day). The batch processing will not include any late data that arrives after +the day's cutoff point. + +### HourlyTeamScore + +The next pipeline in the series is `HourlyTeamScore`. This pipeline also +processes data collected from gaming events in batch. It builds on `UserScore`, +but uses [fixed windows](https://cloud.google.com/dataflow/model/windowing), by +default an hour in duration. It calculates the sum of scores per team, for each +window, optionally allowing specification of two timestamps before and after +which data is filtered out. This allows a model where late data collected after +the intended analysis window can be included in the analysis, and any late- +arriving data prior to the beginning of the analysis window can be removed as +well. + +By using windowing and adding element timestamps, we can do finer-grained +analysis than with the `UserScore` pipeline — we're now tracking scores for +each hour rather than over the course of a whole day. However, our batch +processing is high-latency, in that we don't get results from plays at the +beginning of the batch's time period until the complete batch is processed. + +### LeaderBoard + +The third pipeline in the series is `LeaderBoard`. This pipeline processes an +unbounded stream of 'game events' from a PubSub topic. The calculation of the +team scores uses fixed windowing based on event time (the time of the game play +event), not processing time (the time that an event is processed by the +pipeline). The pipeline calculates the sum of scores per team, for each window. +By default, the team scores are calculated using one-hour windows. + +In contrast — to demo another windowing option — the user scores are calculated +using a global window, which periodically (every ten minutes) emits cumulative +user score sums. + +In contrast to the previous pipelines in the series, which used static, finite +input data, here we're using an unbounded data source, which lets us provide +_speculative_ results, and allows handling of late data, at much lower latency. +E.g., we could use the early/speculative results to keep a 'leaderboard' +updated in near-realtime. Our handling of late data lets us generate correct +results, e.g. for 'team prizes'. We're now outputing window results as they're +calculated, giving us much lower latency than with the previous batch examples. + +### GameStats + +The fourth pipeline in the series is `GameStats`. This pipeline builds +on the `LeaderBoard` functionality — supporting output of speculative and late +data — and adds some "business intelligence" analysis: identifying abuse +detection. The pipeline derives the Mean user score sum for a window, and uses +that information to identify likely spammers/robots. 
(The injector is designed +so that the "robots" have a higher click rate than the "real" users). The robot +users are then filtered out when calculating the team scores. + +Additionally, user sessions are tracked: that is, we find bursts of user +activity using session windows. Then, the mean session duration information is +recorded in the context of subsequent fixed windowing. (This could be used to +tell us what games are giving us greater user retention). + +### Running the PubSub Injector + +The `LeaderBoard` and `GameStats` example pipelines read unbounded data +from a PubSub topic. + +Use the `injector.Injector` program to generate this data and publish to a +PubSub topic. See the `Injector`javadocs for more information on how to run the +injector. Set up the injector before you start one of these pipelines. Then, +when you start the pipeline, pass as an argument the name of that PubSub topic. +See the pipeline javadocs for the details. + +## Viewing the results in BigQuery + +All of the pipelines write their results to BigQuery. `UserScore` and +`HourlyTeamScore` each write one table, and `LeaderBoard` and +`GameStats` each write two. The pipelines have default table names that +you can override when you start up the pipeline if those tables already exist. + +Depending on the windowing intervals defined in a given pipeline, you may have +to wait for a while (more than an hour) before you start to see results written +to the BigQuery tables. + + + + + + diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/UserScore.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/UserScore.java new file mode 100644 index 000000000000..de06ce3aaa58 --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/UserScore.java @@ -0,0 +1,239 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.cloud.dataflow.examples.complete.game.utils.WriteToBigQuery; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.apache.avro.reflect.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; + +/** + * This class is the first in a series of four pipelines that tell a story in a 'gaming' domain. + * Concepts: batch processing; reading input from Google Cloud Storage and writing output to + * BigQuery; using standalone DoFns; use of the sum by key transform; examples of + * Java 8 lambda syntax. + * + *

In this gaming scenario, many users play, as members of different teams, over the course of a + * day, and their actions are logged for processing. Some of the logged game events may be late- + * arriving, if users play on mobile devices and go transiently offline for a period. + * + *

This pipeline does batch processing of data collected from gaming events. It calculates the + * sum of scores per user, over an entire batch of gaming data (collected, say, for each day). The + * batch processing will not include any late data that arrives after the day's cutoff point. + * + *
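The per-user sum is a key extraction followed by the SDK's sum-by-key transform; the ExtractAndSumScore transform below does this for a configurable key. A condensed sketch of the same idea, assuming it lives in the same package so it can see the package-private GameActionInfo (the wrapper class name is illustrative):

```java
package com.google.cloud.dataflow.examples.complete.game;

import com.google.cloud.dataflow.sdk.transforms.MapElements;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.TypeDescriptor;

/** Illustrative sketch: sum scores per user with a Java 8 lambda and Sum.integersPerKey(). */
class SumPerUserSketch {
  static PCollection<KV<String, Integer>> sumPerUser(
      PCollection<UserScore.GameActionInfo> events) {
    return events
        // Key each event by user name; withOutputType tells the SDK the lambda's output type.
        .apply(MapElements
            .via((UserScore.GameActionInfo info) -> KV.of(info.getUser(), info.getScore()))
            .withOutputType(new TypeDescriptor<KV<String, Integer>>() {}))
        // Sum the integer score values for each user key.
        .apply(Sum.integersPerKey());
  }
}
```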

To execute this pipeline using the Dataflow service and static example input data, specify + * the pipeline configuration like this: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ *   --dataset=YOUR-DATASET
+ * }
+ * 
+ * where the BigQuery dataset you specify must already exist. + * + *

Optionally include the --input argument to specify a batch input file. + * See the --input default value for example batch data file, or use {@link injector.Injector} to + * generate your own batch data. + */ +public class UserScore { + + /** + * Class to hold info about a game event. + */ + @DefaultCoder(AvroCoder.class) + static class GameActionInfo { + @Nullable String user; + @Nullable String team; + @Nullable Integer score; + @Nullable Long timestamp; + + public GameActionInfo() {} + + public GameActionInfo(String user, String team, Integer score, Long timestamp) { + this.user = user; + this.team = team; + this.score = score; + this.timestamp = timestamp; + } + + public String getUser() { + return this.user; + } + public String getTeam() { + return this.team; + } + public Integer getScore() { + return this.score; + } + public String getKey(String keyname) { + if (keyname.equals("team")) { + return this.team; + } else { // return username as default + return this.user; + } + } + public Long getTimestamp() { + return this.timestamp; + } + } + + + /** + * Parses the raw game event info into GameActionInfo objects. Each event line has the following + * format: username,teamname,score,timestamp_in_ms,readable_time + * e.g.: + * user2_AsparagusPig,AsparagusPig,10,1445230923951,2015-11-02 09:09:28.224 + * The human-readable time string is not used here. + */ + static class ParseEventFn extends DoFn { + + // Log and count parse errors. + private static final Logger LOG = LoggerFactory.getLogger(ParseEventFn.class); + private final Aggregator numParseErrors = + createAggregator("ParseErrors", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + String[] components = c.element().split(","); + try { + String user = components[0].trim(); + String team = components[1].trim(); + Integer score = Integer.parseInt(components[2].trim()); + Long timestamp = Long.parseLong(components[3].trim()); + GameActionInfo gInfo = new GameActionInfo(user, team, score, timestamp); + c.output(gInfo); + } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) { + numParseErrors.addValue(1L); + LOG.info("Parse error on " + c.element() + ", " + e.getMessage()); + } + } + } + + /** + * A transform to extract key/score information from GameActionInfo, and sum the scores. The + * constructor arg determines whether 'team' or 'user' info is extracted. + */ + // [START DocInclude_USExtractXform] + public static class ExtractAndSumScore + extends PTransform, PCollection>> { + + private final String field; + + ExtractAndSumScore(String field) { + this.field = field; + } + + @Override + public PCollection> apply( + PCollection gameInfo) { + + return gameInfo + .apply(MapElements + .via((GameActionInfo gInfo) -> KV.of(gInfo.getKey(field), gInfo.getScore())) + .withOutputType(new TypeDescriptor>() {})) + .apply(Sum.integersPerKey()); + } + } + // [END DocInclude_USExtractXform] + + + /** + * Options supported by {@link UserScore}. + */ + public static interface Options extends PipelineOptions { + + @Description("Path to the data file(s) containing game data.") + // The default maps to two large Google Cloud Storage files (each ~12GB) holding two subsequent + // day's worth (roughly) of data. + @Default.String("gs://dataflow-samples/game/gaming_data*.csv") + String getInput(); + void setInput(String value); + + @Description("BigQuery Dataset to write tables to. 
Must already exist.") + @Validation.Required + String getDataset(); + void setDataset(String value); + + @Description("The BigQuery table name. Should not already exist.") + @Default.String("user_score") + String getTableName(); + void setTableName(String value); + } + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is passed to the {@link WriteToBigQuery} constructor to write user score sums. + */ + protected static Map>> + configureBigQueryWrite() { + Map>> tableConfigure = + new HashMap>>(); + tableConfigure.put("user", + new WriteToBigQuery.FieldInfo>("STRING", c -> c.element().getKey())); + tableConfigure.put("total_score", + new WriteToBigQuery.FieldInfo>("INTEGER", c -> c.element().getValue())); + return tableConfigure; + } + + + /** + * Run a batch pipeline. + */ + // [START DocInclude_USMain] + public static void main(String[] args) throws Exception { + // Begin constructing a pipeline configured by commandline flags. + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline pipeline = Pipeline.create(options); + + // Read events from a text file and parse them. + pipeline.apply(TextIO.Read.from(options.getInput())) + .apply(ParDo.named("ParseGameEvent").of(new ParseEventFn())) + // Extract and sum username/score pairs from the event data. + .apply("ExtractUserScore", new ExtractAndSumScore("user")) + .apply("WriteUserScoreSums", + new WriteToBigQuery>(options.getTableName(), + configureBigQueryWrite())); + + // Run the batch pipeline. + pipeline.run(); + } + // [END DocInclude_USMain] + +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/Injector.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/Injector.java new file mode 100644 index 000000000000..d47886db43ab --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/Injector.java @@ -0,0 +1,417 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game.injector; + +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.model.PublishRequest; +import com.google.api.services.pubsub.model.PubsubMessage; + +import com.google.common.collect.ImmutableMap; +import org.joda.time.DateTimeZone; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; +import java.io.BufferedOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.security.GeneralSecurityException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.TimeZone; + + +/** + * This is a generator that simulates usage data from a mobile game, and either publishes the data + * to a pubsub topic or writes it to a file. + * + *

The general model used by the generator is the following. There is a set of teams with team + * members. Each member is scoring points for their team. After some period, a team will dissolve + * and a new one will be created in its place. There is also a set of 'Robots', or spammer users. + * They hop from team to team. The robots are set to have a higher 'click rate' (generate more + * events) than the regular team members. + * + *
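A condensed sketch of how a single event's user might be chosen under that model, biasing selection toward the team's robot; the real logic lives in generateEvent() below, and the class and method names here are illustrative:

```java
import java.util.Random;

/** Illustrative sketch: pick the user for the next event, biased toward a team's robot. */
class UserPickerSketch {
  private static final Random random = new Random();

  /**
   * Returns the robot with probability about 2/numMembers (so a robot clicks more often than
   * any single regular member), otherwise a random regular team member.
   * Assumes numMembers >= 2, as the injector's team sizes guarantee.
   */
  static String pickUser(String robot, int numMembers, String teamName) {
    if (robot != null && random.nextInt(numMembers / 2) == 0) {
      return robot;
    }
    return "user" + random.nextInt(numMembers) + "_" + teamName;
  }
}
```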

Each generated line of data has the following form: + * username,teamname,score,timestamp_in_ms,readable_time + * e.g.: + * user2_AsparagusPig,AsparagusPig,10,1445230923951,2015-11-02 09:09:28.224 + * + *

The Injector writes either to a PubSub topic, or a file. It will use the PubSub topic if + * specified. It takes the following arguments: + * {@code Injector project-name (topic-name|none) (filename|none)}. + * + *
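When publishing to Pub/Sub, the injector carries the event time as the message attribute timestamp_ms, which is what lets the LeaderBoard pipeline read with PubsubIO's timestampLabel and process in event time. A condensed sketch of building one such message, mirroring publishData() below (the helper class and method names are illustrative):

```java
import com.google.api.services.pubsub.model.PubsubMessage;
import com.google.common.collect.ImmutableMap;

import java.io.IOException;

/** Illustrative sketch: wrap one event line in a PubsubMessage carrying its event time. */
class MessageBuilderSketch {
  static PubsubMessage buildMessage(String eventLine, long eventTimeMillis) throws IOException {
    // The payload is the CSV event line itself.
    PubsubMessage message = new PubsubMessage().encodeData(eventLine.getBytes("UTF-8"));
    // The event time rides along as an attribute; the pipeline reads it back via
    // PubsubIO.Read.timestampLabel("timestamp_ms") to assign element timestamps.
    message.setAttributes(ImmutableMap.of("timestamp_ms", Long.toString(eventTimeMillis)));
    return message;
  }
}
```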

To run the Injector in the mode where it publishes to PubSub, you will need to authenticate
+ * locally using project-based service account credentials to avoid running over PubSub
+ * quota.
+ * See https://developers.google.com/identity/protocols/application-default-credentials
+ * for more information on using service account credentials. Set the GOOGLE_APPLICATION_CREDENTIALS
+ * environment variable to point to your downloaded service account credentials before starting the
+ * program, e.g.:
+ * {@code export GOOGLE_APPLICATION_CREDENTIALS=/path/to/your/credentials-key.json}.
+ * If you do not do this, then your injector will only run for a few minutes on your
+ * 'user account' credentials before you start to see quota error messages like
+ * "Request throttled due to user QPS limit being reached", and this exception:
+ * "com.google.api.client.googleapis.json.GoogleJsonResponseException: 429 Too Many Requests".
+ * Once you've set up your credentials, run the Injector like this:
+ *

{@code
+ * Injector   none
+ * }
+ * 
+ * The pubsub topic will be created if it does not exist. + * + *

To run the injector in write-to-file-mode, set the topic name to "none" and specify the + * filename: + *

{@code
+ * Injector  none 
+ * }
+ * 
+ */ +class Injector { + private static Pubsub pubsub; + private static Random random = new Random(); + private static String topic; + private static String project; + private static final String TIMESTAMP_ATTRIBUTE = "timestamp_ms"; + + // QPS ranges from 800 to 1000. + private static final int MIN_QPS = 800; + private static final int QPS_RANGE = 200; + // How long to sleep, in ms, between creation of the threads that make API requests to PubSub. + private static final int THREAD_SLEEP_MS = 500; + + // Lists used to generate random team names. + private static final ArrayList COLORS = + new ArrayList(Arrays.asList( + "Magenta", "AliceBlue", "Almond", "Amaranth", "Amber", + "Amethyst", "AndroidGreen", "AntiqueBrass", "Fuchsia", "Ruby", "AppleGreen", + "Apricot", "Aqua", "ArmyGreen", "Asparagus", "Auburn", "Azure", "Banana", + "Beige", "Bisque", "BarnRed", "BattleshipGrey")); + + private static final ArrayList ANIMALS = + new ArrayList(Arrays.asList( + "Echidna", "Koala", "Wombat", "Marmot", "Quokka", "Kangaroo", "Dingo", "Numbat", "Emu", + "Wallaby", "CaneToad", "Bilby", "Possum", "Cassowary", "Kookaburra", "Platypus", + "Bandicoot", "Cockatoo", "Antechinus")); + + // The list of live teams. + private static ArrayList liveTeams = new ArrayList(); + + private static DateTimeFormatter fmt = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS") + .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))); + + + // The total number of robots in the system. + private static final int NUM_ROBOTS = 20; + // Determines the chance that a team will have a robot team member. + private static final int ROBOT_PROBABILITY = 3; + private static final int NUM_LIVE_TEAMS = 15; + private static final int BASE_MEMBERS_PER_TEAM = 5; + private static final int MEMBERS_PER_TEAM = 15; + private static final int MAX_SCORE = 20; + private static final int LATE_DATA_RATE = 5 * 60 * 2; // Every 10 minutes + private static final int BASE_DELAY_IN_MILLIS = 5 * 60 * 1000; // 5-10 minute delay + private static final int FUZZY_DELAY_IN_MILLIS = 5 * 60 * 1000; + + // The minimum time a 'team' can live. + private static final int BASE_TEAM_EXPIRATION_TIME_IN_MINS = 20; + private static final int TEAM_EXPIRATION_TIME_IN_MINS = 20; + + + /** + * A class for holding team info: the name of the team, when it started, + * and the current team members. Teams may but need not include one robot team member. + */ + private static class TeamInfo { + String teamName; + long startTimeInMillis; + int expirationPeriod; + // The team might but need not include 1 robot. Will be non-null if so. + String robot; + int numMembers; + + private TeamInfo(String teamName, long startTimeInMillis, String robot) { + this.teamName = teamName; + this.startTimeInMillis = startTimeInMillis; + // How long until this team is dissolved. + this.expirationPeriod = random.nextInt(TEAM_EXPIRATION_TIME_IN_MINS) + + BASE_TEAM_EXPIRATION_TIME_IN_MINS; + this.robot = robot; + // Determine the number of team members. 
+ numMembers = random.nextInt(MEMBERS_PER_TEAM) + BASE_MEMBERS_PER_TEAM; + } + + String getTeamName() { + return teamName; + } + String getRobot() { + return robot; + } + + long getStartTimeInMillis() { + return startTimeInMillis; + } + long getEndTimeInMillis() { + return startTimeInMillis + (expirationPeriod * 60 * 1000); + } + String getRandomUser() { + int userNum = random.nextInt(numMembers); + return "user" + userNum + "_" + teamName; + } + + int numMembers() { + return numMembers; + } + + @Override + public String toString() { + return "(" + teamName + ", num members: " + numMembers() + ", starting at: " + + startTimeInMillis + ", expires in: " + expirationPeriod + ", robot: " + robot + ")"; + } + } + + /** Utility to grab a random element from an array of Strings. */ + private static String randomElement(ArrayList list) { + int index = random.nextInt(list.size()); + return list.get(index); + } + + /** + * Get and return a random team. If the selected team is too old w.r.t its expiration, remove + * it, replacing it with a new team. + */ + private static TeamInfo randomTeam(ArrayList list) { + int index = random.nextInt(list.size()); + TeamInfo team = list.get(index); + // If the selected team is expired, remove it and return a new team. + long currTime = System.currentTimeMillis(); + if ((team.getEndTimeInMillis() < currTime) || team.numMembers() == 0) { + System.out.println("\nteam " + team + " is too old; replacing."); + System.out.println("start time: " + team.getStartTimeInMillis() + + ", end time: " + team.getEndTimeInMillis() + + ", current time:" + currTime); + removeTeam(index); + // Add a new team in its stead. + return (addLiveTeam()); + } else { + return team; + } + } + + /** + * Create and add a team. Possibly add a robot to the team. + */ + private static synchronized TeamInfo addLiveTeam() { + String teamName = randomElement(COLORS) + randomElement(ANIMALS); + String robot = null; + // Decide if we want to add a robot to the team. + if (random.nextInt(ROBOT_PROBABILITY) == 0) { + robot = "Robot-" + random.nextInt(NUM_ROBOTS); + } + long currTime = System.currentTimeMillis(); + // Create the new team. + TeamInfo newTeam = new TeamInfo(teamName, System.currentTimeMillis(), robot); + liveTeams.add(newTeam); + System.out.println("[+" + newTeam + "]"); + return newTeam; + } + + /** + * Remove a specific team. + */ + private static synchronized void removeTeam(int teamIndex) { + TeamInfo removedTeam = liveTeams.remove(teamIndex); + System.out.println("[-" + removedTeam + "]"); + } + + /** Generate a user gaming event. */ + private static String generateEvent(Long currTime, int delayInMillis) { + TeamInfo team = randomTeam(liveTeams); + String teamName = team.getTeamName(); + String user; + int PARSE_ERROR_RATE = 900000; + + String robot = team.getRobot(); + // If the team has an associated robot team member... + if (robot != null) { + // Then use that robot for the message with some probability. + // Set this probability to higher than that used to select any of the 'regular' team + // members, so that if there is a robot on the team, it has a higher click rate. + if (random.nextInt(team.numMembers() / 2) == 0) { + user = robot; + } else { + user = team.getRandomUser(); + } + } else { // No robot. + user = team.getRandomUser(); + } + String event = user + "," + teamName + "," + random.nextInt(MAX_SCORE); + // Randomly introduce occasional parse errors. 
You can see a custom counter tracking the number + // of such errors in the Dataflow Monitoring UI, as the example pipeline runs. + if (random.nextInt(PARSE_ERROR_RATE) == 0) { + System.out.println("Introducing a parse error."); + event = "THIS LINE REPRESENTS CORRUPT DATA AND WILL CAUSE A PARSE ERROR"; + } + return addTimeInfoToEvent(event, currTime, delayInMillis); + } + + /** + * Add time info to a generated gaming event. + */ + private static String addTimeInfoToEvent(String message, Long currTime, int delayInMillis) { + String eventTimeString = + Long.toString((currTime - delayInMillis) / 1000 * 1000); + // Add a (redundant) 'human-readable' date string to make the data semantics more clear. + String dateString = fmt.print(currTime); + message = message + "," + eventTimeString + "," + dateString; + return message; + } + + /** + * Publish 'numMessages' arbitrary events from live users with the provided delay, to a + * PubSub topic. + */ + public static void publishData(int numMessages, int delayInMillis) + throws IOException { + List pubsubMessages = new ArrayList<>(); + + for (int i = 0; i < Math.max(1, numMessages); i++) { + Long currTime = System.currentTimeMillis(); + String message = generateEvent(currTime, delayInMillis); + PubsubMessage pubsubMessage = new PubsubMessage() + .encodeData(message.getBytes("UTF-8")); + pubsubMessage.setAttributes( + ImmutableMap.of(TIMESTAMP_ATTRIBUTE, + Long.toString((currTime - delayInMillis) / 1000 * 1000))); + if (delayInMillis != 0) { + System.out.println(pubsubMessage.getAttributes()); + System.out.println("late data for: " + message); + } + pubsubMessages.add(pubsubMessage); + } + + PublishRequest publishRequest = new PublishRequest(); + publishRequest.setMessages(pubsubMessages); + pubsub.projects().topics().publish(topic, publishRequest).execute(); + } + + /** + * Publish generated events to a file. + */ + public static void publishDataToFile(String fileName, int numMessages, int delayInMillis) + throws IOException { + List pubsubMessages = new ArrayList<>(); + PrintWriter out = new PrintWriter(new OutputStreamWriter( + new BufferedOutputStream(new FileOutputStream(fileName, true)), "UTF-8")); + + try { + for (int i = 0; i < Math.max(1, numMessages); i++) { + Long currTime = System.currentTimeMillis(); + String message = generateEvent(currTime, delayInMillis); + out.println(message); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + if (out != null) { + out.flush(); + out.close(); + } + } + } + + + public static void main(String[] args) + throws GeneralSecurityException, IOException, InterruptedException { + if (args.length < 3) { + System.out.println("Usage: Injector project-name (topic-name|none) (filename|none)"); + System.exit(1); + } + boolean writeToFile = false; + boolean writeToPubsub = true; + project = args[0]; + String topicName = args[1]; + String fileName = args[2]; + // The Injector writes either to a PubSub topic, or a file. It will use the PubSub topic if + // specified; otherwise, it will try to write to a file. + if (topicName.equalsIgnoreCase("none")) { + writeToFile = true; + writeToPubsub = false; + } + if (writeToPubsub) { + // Create the PubSub client. + pubsub = InjectorUtils.getClient(); + // Create the PubSub topic as necessary. 
+ topic = InjectorUtils.getFullyQualifiedTopicName(project, topicName); + InjectorUtils.createTopic(pubsub, topic); + System.out.println("Injecting to topic: " + topic); + } else { + if (fileName.equalsIgnoreCase("none")) { + System.out.println("Filename not specified."); + System.exit(1); + } + System.out.println("Writing to file: " + fileName); + } + System.out.println("Starting Injector"); + + // Start off with some random live teams. + while (liveTeams.size() < NUM_LIVE_TEAMS) { + addLiveTeam(); + } + + // Publish messages at a rate determined by the QPS and Thread sleep settings. + for (int i = 0; true; i++) { + if (Thread.activeCount() > 10) { + System.err.println("I'm falling behind!"); + } + + // Decide if this should be a batch of late data. + final int numMessages; + final int delayInMillis; + if (i % LATE_DATA_RATE == 0) { + // Insert delayed data for one user (one message only) + delayInMillis = BASE_DELAY_IN_MILLIS + random.nextInt(FUZZY_DELAY_IN_MILLIS); + numMessages = 1; + System.out.println("DELAY(" + delayInMillis + ", " + numMessages + ")"); + } else { + System.out.print("."); + delayInMillis = 0; + numMessages = MIN_QPS + random.nextInt(QPS_RANGE); + } + + if (writeToFile) { // Won't use threading for the file write. + publishDataToFile(fileName, numMessages, delayInMillis); + } else { // Write to PubSub. + // Start a thread to inject some data. + new Thread(){ + public void run() { + try { + publishData(numMessages, delayInMillis); + } catch (IOException e) { + System.err.println(e); + } + } + }.start(); + } + + // Wait before creating another injector thread. + Thread.sleep(THREAD_SLEEP_MS); + } + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/InjectorUtils.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/InjectorUtils.java new file mode 100644 index 000000000000..55982df933e3 --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/InjectorUtils.java @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete.game.injector; + + +import com.google.api.client.googleapis.auth.oauth2.GoogleCredential; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.googleapis.util.Utils; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.http.HttpStatusCodes; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.json.JsonFactory; +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.PubsubScopes; +import com.google.api.services.pubsub.model.Topic; + +import com.google.common.base.Preconditions; + +import java.io.IOException; + +class InjectorUtils { + + private static final String APP_NAME = "injector"; + + /** + * Builds a new Pubsub client and returns it. 
+ */ + public static Pubsub getClient(final HttpTransport httpTransport, + final JsonFactory jsonFactory) + throws IOException { + Preconditions.checkNotNull(httpTransport); + Preconditions.checkNotNull(jsonFactory); + GoogleCredential credential = + GoogleCredential.getApplicationDefault(httpTransport, jsonFactory); + if (credential.createScopedRequired()) { + credential = credential.createScoped(PubsubScopes.all()); + } + if (credential.getClientAuthentication() != null) { + System.out.println("\n***Warning! You are not using service account credentials to " + + "authenticate.\nYou need to use service account credentials for this example," + + "\nsince user-level credentials do not have enough pubsub quota,\nand so you will run " + + "out of PubSub quota very quickly.\nSee " + + "https://developers.google.com/identity/protocols/application-default-credentials."); + System.exit(1); + } + HttpRequestInitializer initializer = + new RetryHttpInitializerWrapper(credential); + return new Pubsub.Builder(httpTransport, jsonFactory, initializer) + .setApplicationName(APP_NAME) + .build(); + } + + /** + * Builds a new Pubsub client with default HttpTransport and + * JsonFactory and returns it. + */ + public static Pubsub getClient() throws IOException { + return getClient(Utils.getDefaultTransport(), + Utils.getDefaultJsonFactory()); + } + + + /** + * Returns the fully qualified topic name for Pub/Sub. + */ + public static String getFullyQualifiedTopicName( + final String project, final String topic) { + return String.format("projects/%s/topics/%s", project, topic); + } + + /** + * Create a topic if it doesn't exist. + */ + public static void createTopic(Pubsub client, String fullTopicName) + throws IOException { + try { + client.projects().topics().get(fullTopicName).execute(); + } catch (GoogleJsonResponseException e) { + if (e.getStatusCode() == HttpStatusCodes.STATUS_CODE_NOT_FOUND) { + Topic topic = client.projects().topics() + .create(fullTopicName, new Topic()) + .execute(); + System.out.printf("Topic %s was created.\n", topic.getName()); + } + } + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/RetryHttpInitializerWrapper.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/RetryHttpInitializerWrapper.java new file mode 100644 index 000000000000..eeeabcef8beb --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/injector/RetryHttpInitializerWrapper.java @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game.injector; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.http.HttpBackOffIOExceptionHandler; +import com.google.api.client.http.HttpBackOffUnsuccessfulResponseHandler; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpUnsuccessfulResponseHandler; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.client.util.Sleeper; +import com.google.common.base.Preconditions; + +import java.io.IOException; +import java.util.logging.Logger; + +/** + * RetryHttpInitializerWrapper will automatically retry upon RPC + * failures, preserving the auto-refresh behavior of the Google + * Credentials. + */ +public class RetryHttpInitializerWrapper implements HttpRequestInitializer { + + /** + * A private logger. + */ + private static final Logger LOG = + Logger.getLogger(RetryHttpInitializerWrapper.class.getName()); + + /** + * One minutes in miliseconds. + */ + private static final int ONEMINITUES = 60000; + + /** + * Intercepts the request for filling in the "Authorization" + * header field, as well as recovering from certain unsuccessful + * error codes wherein the Credential must refresh its token for a + * retry. + */ + private final Credential wrappedCredential; + + /** + * A sleeper; you can replace it with a mock in your test. + */ + private final Sleeper sleeper; + + /** + * A constructor. + * + * @param wrappedCredential Credential which will be wrapped and + * used for providing auth header. + */ + public RetryHttpInitializerWrapper(final Credential wrappedCredential) { + this(wrappedCredential, Sleeper.DEFAULT); + } + + /** + * A protected constructor only for testing. + * + * @param wrappedCredential Credential which will be wrapped and + * used for providing auth header. + * @param sleeper Sleeper for easy testing. + */ + RetryHttpInitializerWrapper( + final Credential wrappedCredential, final Sleeper sleeper) { + this.wrappedCredential = Preconditions.checkNotNull(wrappedCredential); + this.sleeper = sleeper; + } + + /** + * Initializes the given request. + */ + @Override + public final void initialize(final HttpRequest request) { + request.setReadTimeout(2 * ONEMINITUES); // 2 minutes read timeout + final HttpUnsuccessfulResponseHandler backoffHandler = + new HttpBackOffUnsuccessfulResponseHandler( + new ExponentialBackOff()) + .setSleeper(sleeper); + request.setInterceptor(wrappedCredential); + request.setUnsuccessfulResponseHandler( + new HttpUnsuccessfulResponseHandler() { + @Override + public boolean handleResponse( + final HttpRequest request, + final HttpResponse response, + final boolean supportsRetry) throws IOException { + if (wrappedCredential.handleResponse( + request, response, supportsRetry)) { + // If credential decides it can handle it, + // the return code or message indicated + // something specific to authentication, + // and no backoff is desired. + return true; + } else if (backoffHandler.handleResponse( + request, response, supportsRetry)) { + // Otherwise, we defer to the judgement of + // our internal backoff handler. 
+ LOG.info("Retrying " + + request.getUrl().toString()); + return true; + } else { + return false; + } + } + }); + request.setIOExceptionHandler( + new HttpBackOffIOExceptionHandler(new ExponentialBackOff()) + .setSleeper(sleeper)); + } +} + diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/utils/WriteToBigQuery.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/utils/WriteToBigQuery.java new file mode 100644 index 000000000000..2cf719a7eff2 --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/utils/WriteToBigQuery.java @@ -0,0 +1,134 @@ + /* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete.game.utils; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.examples.complete.game.UserScore; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Generate, format, and write BigQuery table row information. Use provided information about + * the field names and types, as well as lambda functions that describe how to generate their + * values. + */ +public class WriteToBigQuery + extends PTransform, PDone> { + + protected String tableName; + protected Map> fieldInfo; + + public WriteToBigQuery() { + } + + public WriteToBigQuery(String tableName, + Map> fieldInfo) { + this.tableName = tableName; + this.fieldInfo = fieldInfo; + } + + /** Define a class to hold information about output table field definitions. 
*/ + public static class FieldInfo implements Serializable { + // The BigQuery 'type' of the field + private String fieldType; + // A lambda function to generate the field value + private SerializableFunction.ProcessContext, Object> fieldFn; + + public FieldInfo(String fieldType, + SerializableFunction.ProcessContext, Object> fieldFn) { + this.fieldType = fieldType; + this.fieldFn = fieldFn; + } + + String getFieldType() { + return this.fieldType; + } + + SerializableFunction.ProcessContext, Object> getFieldFn() { + return this.fieldFn; + } + } + /** Convert each key/score pair into a BigQuery TableRow as specified by fieldFn. */ + protected class BuildRowFn extends DoFn { + + @Override + public void processElement(ProcessContext c) { + + TableRow row = new TableRow(); + for (Map.Entry> entry : fieldInfo.entrySet()) { + String key = entry.getKey(); + FieldInfo fcnInfo = entry.getValue(); + SerializableFunction.ProcessContext, Object> fcn = + fcnInfo.getFieldFn(); + row.set(key, fcn.apply(c)); + } + c.output(row); + } + } + + /** Build the output table schema. */ + protected TableSchema getSchema() { + List fields = new ArrayList<>(); + for (Map.Entry> entry : fieldInfo.entrySet()) { + String key = entry.getKey(); + FieldInfo fcnInfo = entry.getValue(); + String bqType = fcnInfo.getFieldType(); + fields.add(new TableFieldSchema().setName(key).setType(bqType)); + } + return new TableSchema().setFields(fields); + } + + @Override + public PDone apply(PCollection teamAndScore) { + return teamAndScore + .apply(ParDo.named("ConvertToRow").of(new BuildRowFn())) + .apply(BigQueryIO.Write + .to(getTable(teamAndScore.getPipeline(), + tableName)) + .withSchema(getSchema()) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_APPEND)); + } + + /** Utility to construct an output table reference. */ + static TableReference getTable(Pipeline pipeline, String tableName) { + PipelineOptions options = pipeline.getOptions(); + TableReference table = new TableReference(); + table.setDatasetId(options.as(UserScore.Options.class).getDataset()); + table.setProjectId(options.as(GcpOptions.class).getProject()); + table.setTableId(tableName); + return table; + } +} diff --git a/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/utils/WriteWindowedToBigQuery.java b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/utils/WriteWindowedToBigQuery.java new file mode 100644 index 000000000000..8433021f2ee0 --- /dev/null +++ b/examples/src/main/java8/com/google/cloud/dataflow/examples/complete/game/utils/WriteWindowedToBigQuery.java @@ -0,0 +1,76 @@ + /* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game.utils; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; + +import java.util.Map; + +/** + * Generate, format, and write BigQuery table row information. Subclasses {@link WriteToBigQuery} + * to require windowing; so this subclass may be used for writes that require access to the + * context's window information. + */ +public class WriteWindowedToBigQuery + extends WriteToBigQuery { + + public WriteWindowedToBigQuery(String tableName, + Map> fieldInfo) { + super(tableName, fieldInfo); + } + + /** Convert each key/score pair into a BigQuery TableRow. */ + protected class BuildRowFn extends DoFn + implements RequiresWindowAccess { + + @Override + public void processElement(ProcessContext c) { + + TableRow row = new TableRow(); + for (Map.Entry> entry : fieldInfo.entrySet()) { + String key = entry.getKey(); + FieldInfo fcnInfo = entry.getValue(); + SerializableFunction.ProcessContext, Object> fcn = + fcnInfo.getFieldFn(); + row.set(key, fcn.apply(c)); + } + c.output(row); + } + } + + @Override + public PDone apply(PCollection teamAndScore) { + return teamAndScore + .apply(ParDo.named("ConvertToRow").of(new BuildRowFn())) + .apply(BigQueryIO.Write + .to(getTable(teamAndScore.getPipeline(), + tableName)) + .withSchema(getSchema()) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_APPEND)); + } + +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/DebuggingWordCountTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/DebuggingWordCountTest.java new file mode 100644 index 000000000000..77d7bc878a9c --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/DebuggingWordCountTest.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.common.io.Files; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.nio.charset.StandardCharsets; + +/** + * Tests for {@link DebuggingWordCount}. 
+ */ +@RunWith(JUnit4.class) +public class DebuggingWordCountTest { + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testDebuggingWordCount() throws Exception { + File file = tmpFolder.newFile(); + Files.write("stomach secret Flourish message Flourish here Flourish", file, + StandardCharsets.UTF_8); + DebuggingWordCount.main(new String[]{"--inputFile=" + file.getAbsolutePath()}); + } +} + diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/WordCountTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/WordCountTest.java new file mode 100644 index 000000000000..4542c4854099 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/WordCountTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.examples.WordCount.CountWords; +import com.google.cloud.dataflow.examples.WordCount.ExtractWordsFn; +import com.google.cloud.dataflow.examples.WordCount.FormatAsTextFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests of WordCount. + */ +@RunWith(JUnit4.class) +public class WordCountTest { + + /** Example test that tests a specific DoFn. */ + @Test + public void testExtractWordsFn() { + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractWordsFn()); + + Assert.assertThat(extractWordsFn.processBatch(" some input words "), + CoreMatchers.hasItems("some", "input", "words")); + Assert.assertThat(extractWordsFn.processBatch(" "), + CoreMatchers.hasItems()); + Assert.assertThat(extractWordsFn.processBatch(" some ", " input", " words"), + CoreMatchers.hasItems("some", "input", "words")); + } + + static final String[] WORDS_ARRAY = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final List WORDS = Arrays.asList(WORDS_ARRAY); + + static final String[] COUNTS_ARRAY = new String[] { + "hi: 5", "there: 1", "sue: 2", "bob: 2"}; + + /** Example test that tests a PTransform by using an in-memory input and inspecting the output. 
*/ + @Test + @Category(RunnableOnService.class) + public void testCountWords() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())); + + PCollection output = input.apply(new CountWords()) + .apply(MapElements.via(new FormatAsTextFn())); + + DataflowAssert.that(output).containsInAnyOrder(COUNTS_ARRAY); + p.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/complete/AutoCompleteTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/complete/AutoCompleteTest.java new file mode 100644 index 000000000000..aec1557c28b0 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/complete/AutoCompleteTest.java @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.cloud.dataflow.examples.complete.AutoComplete.CompletionCandidate; +import com.google.cloud.dataflow.examples.complete.AutoComplete.ComputeTopCompletions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * Tests of AutoComplete. 
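+ * Runs each test under the {@code Parameterized} runner so that both the recursive and the
+ * non-recursive {@code ComputeTopCompletions} implementations are exercised.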
+ */ +@RunWith(Parameterized.class) +public class AutoCompleteTest implements Serializable { + private boolean recursive; + + public AutoCompleteTest(Boolean recursive) { + this.recursive = recursive; + } + + @Parameterized.Parameters + public static Collection testRecursive() { + return Arrays.asList(new Object[][] { + { true }, + { false } + }); + } + + @Test + public void testAutoComplete() { + List words = Arrays.asList( + "apple", + "apple", + "apricot", + "banana", + "blackberry", + "blackberry", + "blackberry", + "blueberry", + "blueberry", + "cherry"); + + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(words)); + + PCollection>> output = + input.apply(new ComputeTopCompletions(2, recursive)) + .apply(Filter.byPredicate( + new SerializableFunction>, Boolean>() { + @Override + public Boolean apply(KV> element) { + return element.getKey().length() <= 2; + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", parseList("apple:2", "apricot:1")), + KV.of("ap", parseList("apple:2", "apricot:1")), + KV.of("b", parseList("blackberry:3", "blueberry:2")), + KV.of("ba", parseList("banana:1")), + KV.of("bl", parseList("blackberry:3", "blueberry:2")), + KV.of("c", parseList("cherry:1")), + KV.of("ch", parseList("cherry:1"))); + p.run(); + } + + @Test + public void testTinyAutoComplete() { + List words = Arrays.asList("x", "x", "x", "xy", "xy", "xyz"); + + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(words)); + + PCollection>> output = + input.apply(new ComputeTopCompletions(2, recursive)); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("x", parseList("x:3", "xy:2")), + KV.of("xy", parseList("xy:2", "xyz:1")), + KV.of("xyz", parseList("xyz:1"))); + p.run(); + } + + @Test + public void testWindowedAutoComplete() { + List> words = Arrays.asList( + TimestampedValue.of("xA", new Instant(1)), + TimestampedValue.of("xA", new Instant(1)), + TimestampedValue.of("xB", new Instant(1)), + TimestampedValue.of("xB", new Instant(2)), + TimestampedValue.of("xB", new Instant(2))); + + Pipeline p = TestPipeline.create(); + + PCollection input = p + .apply(Create.of(words)) + .apply(new ReifyTimestamps()); + + PCollection>> output = + input.apply(Window.into(SlidingWindows.of(new Duration(2)))) + .apply(new ComputeTopCompletions(2, recursive)); + + DataflowAssert.that(output).containsInAnyOrder( + // Window [0, 2) + KV.of("x", parseList("xA:2", "xB:1")), + KV.of("xA", parseList("xA:2")), + KV.of("xB", parseList("xB:1")), + + // Window [1, 3) + KV.of("x", parseList("xB:3", "xA:2")), + KV.of("xA", parseList("xA:2")), + KV.of("xB", parseList("xB:3")), + + // Window [2, 3) + KV.of("x", parseList("xB:2")), + KV.of("xB", parseList("xB:2"))); + p.run(); + } + + private static List parseList(String... 
entries) { + List all = new ArrayList<>(); + for (String s : entries) { + String[] countValue = s.split(":"); + all.add(new CompletionCandidate(countValue[0], Integer.valueOf(countValue[1]))); + } + return all; + } + + private static class ReifyTimestamps + extends PTransform>, PCollection> { + @Override + public PCollection apply(PCollection> input) { + return input.apply(ParDo.of(new DoFn, T>() { + @Override + public void processElement(ProcessContext c) { + c.outputWithTimestamp(c.element().getValue(), c.element().getTimestamp()); + } + })); + } + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/complete/TfIdfTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/complete/TfIdfTest.java new file mode 100644 index 000000000000..5ee136cee2a4 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/complete/TfIdfTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.net.URI; +import java.util.Arrays; + +/** + * Tests of {@link TfIdf}. + */ +@RunWith(JUnit4.class) +public class TfIdfTest { + + /** Test that the example runs. 
*/ + @Test + @Category(RunnableOnService.class) + public void testTfIdf() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); + + PCollection>> wordToUriAndTfIdf = pipeline + .apply(Create.of( + KV.of(new URI("x"), "a b c d"), + KV.of(new URI("y"), "a b c"), + KV.of(new URI("z"), "a m n"))) + .apply(new TfIdf.ComputeTfIdf()); + + PCollection words = wordToUriAndTfIdf + .apply(Keys.create()) + .apply(RemoveDuplicates.create()); + + DataflowAssert.that(words).containsInAnyOrder(Arrays.asList("a", "m", "n", "b", "c", "d")); + + pipeline.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/complete/TopWikipediaSessionsTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/complete/TopWikipediaSessionsTest.java new file mode 100644 index 000000000000..ce9de5140996 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/complete/TopWikipediaSessionsTest.java @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** Unit tests for {@link TopWikipediaSessions}. 
*/ +@RunWith(JUnit4.class) +public class TopWikipediaSessionsTest { + @Test + @Category(RunnableOnService.class) + public void testComputeTopUsers() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(Arrays.asList( + new TableRow().set("timestamp", 0).set("contributor_username", "user1"), + new TableRow().set("timestamp", 1).set("contributor_username", "user1"), + new TableRow().set("timestamp", 2).set("contributor_username", "user1"), + new TableRow().set("timestamp", 0).set("contributor_username", "user2"), + new TableRow().set("timestamp", 1).set("contributor_username", "user2"), + new TableRow().set("timestamp", 3601).set("contributor_username", "user2"), + new TableRow().set("timestamp", 3602).set("contributor_username", "user2"), + new TableRow().set("timestamp", 35 * 24 * 3600).set("contributor_username", "user3")))) + .apply(new TopWikipediaSessions.ComputeTopSessions(1.0)); + + DataflowAssert.that(output).containsInAnyOrder(Arrays.asList( + "user1 : [1970-01-01T00:00:00.000Z..1970-01-01T01:00:02.000Z)" + + " : 3 : 1970-01-01T00:00:00.000Z", + "user3 : [1970-02-05T00:00:00.000Z..1970-02-05T01:00:00.000Z)" + + " : 1 : 1970-02-01T00:00:00.000Z")); + + p.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/BigQueryTornadoesTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/BigQueryTornadoesTest.java new file mode 100644 index 000000000000..6dce4eddfa0c --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/BigQueryTornadoesTest.java @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.cookbook.BigQueryTornadoes.ExtractTornadoesFn; +import com.google.cloud.dataflow.examples.cookbook.BigQueryTornadoes.FormatCountsFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** + * Test case for {@link BigQueryTornadoes}. 
+ */ +@RunWith(JUnit4.class) +public class BigQueryTornadoesTest { + + @Test + public void testExtractTornadoes() throws Exception { + TableRow row = new TableRow() + .set("month", "6") + .set("tornado", true); + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractTornadoesFn()); + Assert.assertThat(extractWordsFn.processBatch(row), + CoreMatchers.hasItems(6)); + } + + @Test + public void testNoTornadoes() throws Exception { + TableRow row = new TableRow() + .set("month", 6) + .set("tornado", false); + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractTornadoesFn()); + Assert.assertTrue(extractWordsFn.processBatch(row).isEmpty()); + } + + @Test + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testFormatCounts() throws Exception { + DoFnTester, TableRow> formatCountsFn = + DoFnTester.of(new FormatCountsFn()); + KV empty[] = {}; + List results = formatCountsFn.processBatch(empty); + Assert.assertTrue(results.size() == 0); + KV input[] = { KV.of(3, 0L), + KV.of(4, Long.MAX_VALUE), + KV.of(5, Long.MIN_VALUE) }; + results = formatCountsFn.processBatch(input); + Assert.assertEquals(results.size(), 3); + Assert.assertEquals(results.get(0).get("month"), 3); + Assert.assertEquals(results.get(0).get("tornado_count"), 0L); + Assert.assertEquals(results.get(1).get("month"), 4); + Assert.assertEquals(results.get(1).get("tornado_count"), Long.MAX_VALUE); + Assert.assertEquals(results.get(2).get("month"), 5); + Assert.assertEquals(results.get(2).get("tornado_count"), Long.MIN_VALUE); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/CombinePerKeyExamplesTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/CombinePerKeyExamplesTest.java new file mode 100644 index 000000000000..fe4823d0994f --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/CombinePerKeyExamplesTest.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.cookbook.CombinePerKeyExamples.ExtractLargeWordsFn; +import com.google.cloud.dataflow.examples.cookbook.CombinePerKeyExamples.FormatShakespeareOutputFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** Unit tests for {@link CombinePerKeyExamples}. 
*/ +@RunWith(JUnit4.class) +public class CombinePerKeyExamplesTest { + + private static final TableRow row1 = new TableRow() + .set("corpus", "king_lear").set("word", "snuffleupaguses"); + private static final TableRow row2 = new TableRow() + .set("corpus", "macbeth").set("word", "antidisestablishmentarianism"); + private static final TableRow row3 = new TableRow() + .set("corpus", "king_lear").set("word", "antidisestablishmentarianism"); + private static final TableRow row4 = new TableRow() + .set("corpus", "macbeth").set("word", "bob"); + private static final TableRow row5 = new TableRow() + .set("corpus", "king_lear").set("word", "hi"); + + static final TableRow[] ROWS_ARRAY = new TableRow[] { + row1, row2, row3, row4, row5 + }; + + private static final KV tuple1 = KV.of("snuffleupaguses", "king_lear"); + private static final KV tuple2 = KV.of("antidisestablishmentarianism", "macbeth"); + private static final KV tuple3 = KV.of("antidisestablishmentarianism", + "king_lear"); + + private static final KV combinedTuple1 = KV.of("antidisestablishmentarianism", + "king_lear,macbeth"); + private static final KV combinedTuple2 = KV.of("snuffleupaguses", "king_lear"); + + @SuppressWarnings({"unchecked", "rawtypes"}) + static final KV[] COMBINED_TUPLES_ARRAY = new KV[] { + combinedTuple1, combinedTuple2 + }; + + private static final TableRow resultRow1 = new TableRow() + .set("word", "snuffleupaguses").set("all_plays", "king_lear"); + private static final TableRow resultRow2 = new TableRow() + .set("word", "antidisestablishmentarianism") + .set("all_plays", "king_lear,macbeth"); + + @Test + public void testExtractLargeWordsFn() { + DoFnTester> extractLargeWordsFn = + DoFnTester.of(new ExtractLargeWordsFn()); + List> results = extractLargeWordsFn.processBatch(ROWS_ARRAY); + Assert.assertThat(results, CoreMatchers.hasItem(tuple1)); + Assert.assertThat(results, CoreMatchers.hasItem(tuple2)); + Assert.assertThat(results, CoreMatchers.hasItem(tuple3)); + } + + @Test + public void testFormatShakespeareOutputFn() { + DoFnTester, TableRow> formatShakespeareOutputFn = + DoFnTester.of(new FormatShakespeareOutputFn()); + List results = formatShakespeareOutputFn.processBatch(COMBINED_TUPLES_ARRAY); + Assert.assertThat(results, CoreMatchers.hasItem(resultRow1)); + Assert.assertThat(results, CoreMatchers.hasItem(resultRow2)); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/DeDupExampleTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/DeDupExampleTest.java new file mode 100644 index 000000000000..bce6b118312b --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/DeDupExampleTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link DeDupExample}. */ +@RunWith(JUnit4.class) +public class DeDupExampleTest { + + @Test + @Category(RunnableOnService.class) + public void testRemoveDuplicates() { + List strings = Arrays.asList( + "k1", + "k5", + "k5", + "k2", + "k1", + "k2", + "k3"); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(strings) + .withCoder(StringUtf8Coder.of())); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + DataflowAssert.that(output) + .containsInAnyOrder("k1", "k5", "k2", "k3"); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testRemoveDuplicatesEmpty() { + List strings = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(strings) + .withCoder(StringUtf8Coder.of())); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + DataflowAssert.that(output).empty(); + p.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/FilterExamplesTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/FilterExamplesTest.java new file mode 100644 index 000000000000..6d822f980519 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/FilterExamplesTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.cookbook.FilterExamples.FilterSingleMonthDataFn; +import com.google.cloud.dataflow.examples.cookbook.FilterExamples.ProjectionFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link FilterExamples}. 
*/ +@RunWith(JUnit4.class) +public class FilterExamplesTest { + + private static final TableRow row1 = new TableRow() + .set("month", "6").set("day", "21") + .set("year", "2014").set("mean_temp", "85.3") + .set("tornado", true); + private static final TableRow row2 = new TableRow() + .set("month", "7").set("day", "20") + .set("year", "2014").set("mean_temp", "75.4") + .set("tornado", false); + private static final TableRow row3 = new TableRow() + .set("month", "6").set("day", "18") + .set("year", "2014").set("mean_temp", "45.3") + .set("tornado", true); + static final TableRow[] ROWS_ARRAY = new TableRow[] { + row1, row2, row3 + }; + static final List ROWS = Arrays.asList(ROWS_ARRAY); + + private static final TableRow outRow1 = new TableRow() + .set("year", 2014).set("month", 6) + .set("day", 21).set("mean_temp", 85.3); + private static final TableRow outRow2 = new TableRow() + .set("year", 2014).set("month", 7) + .set("day", 20).set("mean_temp", 75.4); + private static final TableRow outRow3 = new TableRow() + .set("year", 2014).set("month", 6) + .set("day", 18).set("mean_temp", 45.3); + private static final TableRow[] PROJROWS_ARRAY = new TableRow[] { + outRow1, outRow2, outRow3 + }; + + + @Test + public void testProjectionFn() { + DoFnTester projectionFn = + DoFnTester.of(new ProjectionFn()); + List results = projectionFn.processBatch(ROWS_ARRAY); + Assert.assertThat(results, CoreMatchers.hasItem(outRow1)); + Assert.assertThat(results, CoreMatchers.hasItem(outRow2)); + Assert.assertThat(results, CoreMatchers.hasItem(outRow3)); + } + + @Test + public void testFilterSingleMonthDataFn() { + DoFnTester filterSingleMonthDataFn = + DoFnTester.of(new FilterSingleMonthDataFn(7)); + List results = filterSingleMonthDataFn.processBatch(PROJROWS_ARRAY); + Assert.assertThat(results, CoreMatchers.hasItem(outRow2)); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/JoinExamplesTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/JoinExamplesTest.java new file mode 100644 index 000000000000..db3ae34e7dab --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/JoinExamplesTest.java @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.cookbook.JoinExamples.ExtractCountryInfoFn; +import com.google.cloud.dataflow.examples.cookbook.JoinExamples.ExtractEventDataFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link JoinExamples}. */ +@RunWith(JUnit4.class) +public class JoinExamplesTest { + + private static final TableRow row1 = new TableRow() + .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") + .set("Actor1Name", "BANGKOK").set("SOURCEURL", "http://cnn.com"); + private static final TableRow row2 = new TableRow() + .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") + .set("Actor1Name", "LAOS").set("SOURCEURL", "http://www.chicagotribune.com"); + private static final TableRow row3 = new TableRow() + .set("ActionGeo_CountryCode", "BE").set("SQLDATE", "20141213") + .set("Actor1Name", "AFGHANISTAN").set("SOURCEURL", "http://cnn.com"); + static final TableRow[] EVENTS = new TableRow[] { + row1, row2, row3 + }; + static final List EVENT_ARRAY = Arrays.asList(EVENTS); + + private static final KV kv1 = KV.of("VM", + "Date: 20141212, Actor1: LAOS, url: http://www.chicagotribune.com"); + private static final KV kv2 = KV.of("BE", + "Date: 20141213, Actor1: AFGHANISTAN, url: http://cnn.com"); + private static final KV kv3 = KV.of("BE", "Belgium"); + private static final KV kv4 = KV.of("VM", "Vietnam"); + + private static final TableRow cc1 = new TableRow() + .set("FIPSCC", "VM").set("HumanName", "Vietnam"); + private static final TableRow cc2 = new TableRow() + .set("FIPSCC", "BE").set("HumanName", "Belgium"); + static final TableRow[] CCS = new TableRow[] { + cc1, cc2 + }; + static final List CC_ARRAY = Arrays.asList(CCS); + + static final String[] JOINED_EVENTS = new String[] { + "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: LAOS, " + + "url: http://www.chicagotribune.com", + "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: BANGKOK, " + + "url: http://cnn.com", + "Country code: BE, Country name: Belgium, Event info: Date: 20141213, Actor1: AFGHANISTAN, " + + "url: http://cnn.com" + }; + + @Test + public void testExtractEventDataFn() { + DoFnTester> extractEventDataFn = + DoFnTester.of(new ExtractEventDataFn()); + List> results = extractEventDataFn.processBatch(EVENTS); + Assert.assertThat(results, CoreMatchers.hasItem(kv1)); + Assert.assertThat(results, CoreMatchers.hasItem(kv2)); + } + + @Test + public void testExtractCountryInfoFn() { + DoFnTester> extractCountryInfoFn = + DoFnTester.of(new ExtractCountryInfoFn()); + List> results = extractCountryInfoFn.processBatch(CCS); + Assert.assertThat(results, CoreMatchers.hasItem(kv3)); + Assert.assertThat(results, CoreMatchers.hasItem(kv4)); + } + + + @Test + 
@Category(RunnableOnService.class) + public void testJoin() throws java.lang.Exception { + Pipeline p = TestPipeline.create(); + PCollection input1 = p.apply("CreateEvent", Create.of(EVENT_ARRAY)); + PCollection input2 = p.apply("CreateCC", Create.of(CC_ARRAY)); + + PCollection output = JoinExamples.joinEvents(input1, input2); + DataflowAssert.that(output).containsInAnyOrder(JOINED_EVENTS); + p.run(); + } +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/MaxPerKeyExamplesTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/MaxPerKeyExamplesTest.java new file mode 100644 index 000000000000..3deff2a2e354 --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/MaxPerKeyExamplesTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.cookbook.MaxPerKeyExamples.ExtractTempFn; +import com.google.cloud.dataflow.examples.cookbook.MaxPerKeyExamples.FormatMaxesFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** Unit tests for {@link MaxPerKeyExamples}. 
*/ +@RunWith(JUnit4.class) +public class MaxPerKeyExamplesTest { + + private static final TableRow row1 = new TableRow() + .set("month", "6").set("day", "21") + .set("year", "2014").set("mean_temp", "85.3") + .set("tornado", true); + private static final TableRow row2 = new TableRow() + .set("month", "7").set("day", "20") + .set("year", "2014").set("mean_temp", "75.4") + .set("tornado", false); + private static final TableRow row3 = new TableRow() + .set("month", "6").set("day", "18") + .set("year", "2014").set("mean_temp", "45.3") + .set("tornado", true); + private static final List TEST_ROWS = ImmutableList.of(row1, row2, row3); + + private static final KV kv1 = KV.of(6, 85.3); + private static final KV kv2 = KV.of(6, 45.3); + private static final KV kv3 = KV.of(7, 75.4); + + private static final List> TEST_KVS = ImmutableList.of(kv1, kv2, kv3); + + private static final TableRow resultRow1 = new TableRow() + .set("month", 6) + .set("max_mean_temp", 85.3); + private static final TableRow resultRow2 = new TableRow() + .set("month", 7) + .set("max_mean_temp", 75.4); + + + @Test + public void testExtractTempFn() { + DoFnTester> extractTempFn = + DoFnTester.of(new ExtractTempFn()); + List> results = extractTempFn.processBatch(TEST_ROWS); + Assert.assertThat(results, CoreMatchers.hasItem(kv1)); + Assert.assertThat(results, CoreMatchers.hasItem(kv2)); + Assert.assertThat(results, CoreMatchers.hasItem(kv3)); + } + + @Test + public void testFormatMaxesFn() { + DoFnTester, TableRow> formatMaxesFnFn = + DoFnTester.of(new FormatMaxesFn()); + List results = formatMaxesFnFn.processBatch(TEST_KVS); + Assert.assertThat(results, CoreMatchers.hasItem(resultRow1)); + Assert.assertThat(results, CoreMatchers.hasItem(resultRow2)); + } + +} diff --git a/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/TriggerExampleTest.java b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/TriggerExampleTest.java new file mode 100644 index 000000000000..209ea521bd9d --- /dev/null +++ b/examples/src/test/java/com/google/cloud/dataflow/examples/cookbook/TriggerExampleTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.cookbook; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.cookbook.TriggerExample.ExtractFlowInfo; +import com.google.cloud.dataflow.examples.cookbook.TriggerExample.TotalFlow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Unit Tests for {@link TriggerExample}. + * The results generated by triggers are by definition non-deterministic and hence hard to test. + * The unit test does not test all aspects of the example. + */ +@RunWith(JUnit4.class) +public class TriggerExampleTest { + + private static final String[] INPUT = + {"01/01/2010 00:00:00,1108302,94,E,ML,36,100,29,0.0065,66,9,1,0.001,74.8,1,9,3,0.0028,71,1,9," + + "12,0.0099,67.4,1,9,13,0.0121,99.0,1,,,,,0,,,,,0,,,,,0,,,,,0", "01/01/2010 00:00:00," + + "1100333,5,N,FR,9,0,39,,,9,,,,0,,,,,0,,,,,0,,,,,0,,,,,0,,,,,0,,,,,0,,,,"}; + + private static final List> TIME_STAMPED_INPUT = Arrays.asList( + TimestampedValue.of("01/01/2010 00:00:00,1108302,5,W,ML,36,100,30,0.0065,66,9,1,0.001," + + "74.8,1,9,3,0.0028,71,1,9,12,0.0099,87.4,1,9,13,0.0121,99.0,1,,,,,0,,,,,0,,,,,0,,," + + ",,0", new Instant(60000)), + TimestampedValue.of("01/01/2010 00:00:00,1108302,110,E,ML,36,100,40,0.0065,66,9,1,0.001," + + "74.8,1,9,3,0.0028,71,1,9,12,0.0099,67.4,1,9,13,0.0121,99.0,1,,,,,0,,,,,0,,,,,0,,," + + ",,0", new Instant(1)), + TimestampedValue.of("01/01/2010 00:00:00,1108302,110,E,ML,36,100,50,0.0065,66,9,1," + + "0.001,74.8,1,9,3,0.0028,71,1,9,12,0.0099,97.4,1,9,13,0.0121,50.0,1,,,,,0,,,,,0" + + ",,,,,0,,,,,0", new Instant(1))); + + private static final TableRow OUT_ROW_1 = new TableRow() + .set("trigger_type", "default") + .set("freeway", "5").set("total_flow", 30) + .set("number_of_records", 1) + .set("isFirst", true).set("isLast", true) + .set("timing", "ON_TIME") + .set("window", "[1970-01-01T00:01:00.000Z..1970-01-01T00:02:00.000Z)"); + + private static final TableRow OUT_ROW_2 = new TableRow() + .set("trigger_type", "default") + .set("freeway", "110").set("total_flow", 90) + .set("number_of_records", 2) + .set("isFirst", true).set("isLast", true) + .set("timing", "ON_TIME") + .set("window", "[1970-01-01T00:00:00.000Z..1970-01-01T00:01:00.000Z)"); + + @Test + public void testExtractTotalFlow() { + DoFnTester> extractFlowInfow = DoFnTester + .of(new ExtractFlowInfo()); + + List> results = extractFlowInfow.processBatch(INPUT); + Assert.assertEquals(results.size(), 1); + Assert.assertEquals(results.get(0).getKey(), "94"); + 
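+    // Of the two INPUT records, only the first survives ExtractFlowInfo, yielding
+    // freeway "94" with a total flow of 29.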
Assert.assertEquals(results.get(0).getValue(), new Integer(29)); + + List> output = extractFlowInfow.processBatch(""); + Assert.assertEquals(output.size(), 0); + } + + @Test + @Category(RunnableOnService.class) + public void testTotalFlow () { + Pipeline pipeline = TestPipeline.create(); + PCollection> flow = pipeline + .apply(Create.timestamped(TIME_STAMPED_INPUT)) + .apply(ParDo.of(new ExtractFlowInfo())); + + PCollection totalFlow = flow + .apply(Window.>into(FixedWindows.of(Duration.standardMinutes(1)))) + .apply(new TotalFlow("default")); + + PCollection results = totalFlow.apply(ParDo.of(new FormatResults())); + + + DataflowAssert.that(results).containsInAnyOrder(OUT_ROW_1, OUT_ROW_2); + pipeline.run(); + + } + + static class FormatResults extends DoFn { + @Override + public void processElement(ProcessContext c) throws Exception { + TableRow element = c.element(); + TableRow row = new TableRow() + .set("trigger_type", element.get("trigger_type")) + .set("freeway", element.get("freeway")) + .set("total_flow", element.get("total_flow")) + .set("number_of_records", element.get("number_of_records")) + .set("isFirst", element.get("isFirst")) + .set("isLast", element.get("isLast")) + .set("timing", element.get("timing")) + .set("window", element.get("window")); + c.output(row); + } + } +} + + diff --git a/examples/src/test/java8/com/google/cloud/dataflow/examples/MinimalWordCountJava8Test.java b/examples/src/test/java8/com/google/cloud/dataflow/examples/MinimalWordCountJava8Test.java new file mode 100644 index 000000000000..fcae41c6bb52 --- /dev/null +++ b/examples/src/test/java8/com/google/cloud/dataflow/examples/MinimalWordCountJava8Test.java @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.FlatMapElements; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.io.Serializable; +import java.nio.channels.FileChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.List; + +/** + * To keep {@link MinimalWordCountJava8} simple, it is not factored or testable. This test + * file should be maintained with a copy of its code for a basic smoke test. + */ +@RunWith(JUnit4.class) +public class MinimalWordCountJava8Test implements Serializable { + + /** + * A basic smoke test that ensures there is no crash at pipeline construction time. + */ + @Test + public void testMinimalWordCountJava8() throws Exception { + Pipeline p = TestPipeline.create(); + p.getOptions().as(GcsOptions.class).setGcsUtil(buildMockGcsUtil()); + + p.apply(TextIO.Read.from("gs://dataflow-samples/shakespeare/*")) + .apply(FlatMapElements.via((String word) -> Arrays.asList(word.split("[^a-zA-Z']+"))) + .withOutputType(new TypeDescriptor() {})) + .apply(Filter.byPredicate((String word) -> !word.isEmpty())) + .apply(Count.perElement()) + .apply(MapElements + .via((KV wordCount) -> wordCount.getKey() + ": " + wordCount.getValue()) + .withOutputType(new TypeDescriptor() {})) + .apply(TextIO.Write.to("gs://YOUR_OUTPUT_BUCKET/AND_OUTPUT_PREFIX")); + } + + private GcsUtil buildMockGcsUtil() throws IOException { + GcsUtil mockGcsUtil = Mockito.mock(GcsUtil.class); + + // Any request to open gets a new bogus channel + Mockito + .when(mockGcsUtil.open(Mockito.any(GcsPath.class))) + .then(new Answer() { + @Override + public SeekableByteChannel answer(InvocationOnMock invocation) throws Throwable { + return FileChannel.open( + Files.createTempFile("channel-", ".tmp"), + StandardOpenOption.CREATE, StandardOpenOption.DELETE_ON_CLOSE); + } + }); + + // Any request for expansion returns a list containing the original GcsPath + // This is required to pass validation that occurs in TextIO during apply() + Mockito + .when(mockGcsUtil.expand(Mockito.any(GcsPath.class))) + .then(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return ImmutableList.of((GcsPath) invocation.getArguments()[0]); + } + }); + + return mockGcsUtil; + } +} diff --git a/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/GameStatsTest.java b/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/GameStatsTest.java new file mode 100644 index 000000000000..4795de2fc32d --- /dev/null +++ 
b/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/GameStatsTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.complete.game.GameStats.CalculateSpammyUsers; +import com.google.cloud.dataflow.examples.complete.game.UserScore.GameActionInfo; +import com.google.cloud.dataflow.examples.complete.game.UserScore.ParseEventFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.WithTimestamps; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.hamcrest.CoreMatchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** + * Tests of GameStats. + * Because the pipeline was designed for easy readability and explanations, it lacks good + * modularity for testing. See our testing documentation for better ideas: + * https://cloud.google.com/dataflow/pipelines/testing-your-pipeline. 
+ */ +@RunWith(JUnit4.class) +public class GameStatsTest implements Serializable { + + // User scores + static final KV[] USER_SCORES_ARRAY = new KV[] { + KV.of("Robot-2", 66), KV.of("Robot-1", 116), KV.of("user7_AndroidGreenKookaburra", 23), + KV.of("user7_AndroidGreenKookaburra", 1), + KV.of("user19_BisqueBilby", 14), KV.of("user13_ApricotQuokka", 15), + KV.of("user18_BananaEmu", 25), KV.of("user6_AmberEchidna", 8), + KV.of("user2_AmberQuokka", 6), KV.of("user0_MagentaKangaroo", 4), + KV.of("user0_MagentaKangaroo", 3), KV.of("user2_AmberCockatoo", 13), + KV.of("user7_AlmondWallaby", 15), KV.of("user6_AmberNumbat", 11), + KV.of("user6_AmberQuokka", 4) + }; + + static final List> USER_SCORES = Arrays.asList(USER_SCORES_ARRAY); + + // The expected list of 'spammers'. + static final KV[] SPAMMERS = new KV[] { + KV.of("Robot-2", 66), KV.of("Robot-1", 116) + }; + + + /** Test the calculation of 'spammy users'. */ + @Test + @Category(RunnableOnService.class) + public void testCalculateSpammyUsers() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection> input = p.apply(Create.of(USER_SCORES)); + PCollection> output = input.apply(new CalculateSpammyUsers()); + + // Check the set of spammers. + DataflowAssert.that(output).containsInAnyOrder(SPAMMERS); + + p.run(); + } + +} diff --git a/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/HourlyTeamScoreTest.java b/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/HourlyTeamScoreTest.java new file mode 100644 index 000000000000..fe163037f63c --- /dev/null +++ b/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/HourlyTeamScoreTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.examples.complete.game.UserScore.ExtractAndSumScore; +import com.google.cloud.dataflow.examples.complete.game.UserScore.GameActionInfo; +import com.google.cloud.dataflow.examples.complete.game.UserScore.ParseEventFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.WithTimestamps; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.hamcrest.CoreMatchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** + * Tests of HourlyTeamScore. + * Because the pipeline was designed for easy readability and explanations, it lacks good + * modularity for testing. See our testing documentation for better ideas: + * https://cloud.google.com/dataflow/pipelines/testing-your-pipeline. + */ +@RunWith(JUnit4.class) +public class HourlyTeamScoreTest implements Serializable { + + static final String[] GAME_EVENTS_ARRAY = new String[] { + "user0_MagentaKangaroo,MagentaKangaroo,3,1447955630000,2015-11-19 09:53:53.444", + "user13_ApricotQuokka,ApricotQuokka,15,1447955630000,2015-11-19 09:53:53.444", + "user6_AmberNumbat,AmberNumbat,11,1447955630000,2015-11-19 09:53:53.444", + "user7_AlmondWallaby,AlmondWallaby,15,1447955630000,2015-11-19 09:53:53.444", + "user7_AndroidGreenKookaburra,AndroidGreenKookaburra,12,1447955630000,2015-11-19 09:53:53.444", + "user7_AndroidGreenKookaburra,AndroidGreenKookaburra,11,1447955630000,2015-11-19 09:53:53.444", + "user19_BisqueBilby,BisqueBilby,6,1447955630000,2015-11-19 09:53:53.444", + "user19_BisqueBilby,BisqueBilby,8,1447955630000,2015-11-19 09:53:53.444", + // time gap... + "user0_AndroidGreenEchidna,AndroidGreenEchidna,0,1447965690000,2015-11-19 12:41:31.053", + "user0_MagentaKangaroo,MagentaKangaroo,4,1447965690000,2015-11-19 12:41:31.053", + "user2_AmberCockatoo,AmberCockatoo,13,1447965690000,2015-11-19 12:41:31.053", + "user18_BananaEmu,BananaEmu,7,1447965690000,2015-11-19 12:41:31.053", + "user3_BananaEmu,BananaEmu,17,1447965690000,2015-11-19 12:41:31.053", + "user18_BananaEmu,BananaEmu,1,1447965690000,2015-11-19 12:41:31.053", + "user18_ApricotCaneToad,ApricotCaneToad,14,1447965690000,2015-11-19 12:41:31.053" + }; + + + static final List GAME_EVENTS = Arrays.asList(GAME_EVENTS_ARRAY); + + + // Used to check the filtering. 
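+  // These are exactly the GAME_EVENTS entries timestamped 1447965690000 (after the "time gap"
+  // above), which is later than the test's startMinTimestamp of 1447965680000L.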
+ static final KV[] FILTERED_EVENTS = new KV[] { + KV.of("user0_AndroidGreenEchidna", 0), KV.of("user0_MagentaKangaroo", 4), + KV.of("user2_AmberCockatoo", 13), + KV.of("user18_BananaEmu", 7), KV.of("user3_BananaEmu", 17), + KV.of("user18_BananaEmu", 1), KV.of("user18_ApricotCaneToad", 14) + }; + + + /** Test the filtering. */ + @Test + @Category(RunnableOnService.class) + public void testUserScoresFilter() throws Exception { + Pipeline p = TestPipeline.create(); + + final Instant startMinTimestamp = new Instant(1447965680000L); + + PCollection input = p.apply(Create.of(GAME_EVENTS).withCoder(StringUtf8Coder.of())); + + PCollection> output = input + .apply(ParDo.named("ParseGameEvent").of(new ParseEventFn())) + + .apply("FilterStartTime", Filter.byPredicate( + (GameActionInfo gInfo) + -> gInfo.getTimestamp() > startMinTimestamp.getMillis())) + // run a map to access the fields in the result. + .apply(MapElements + .via((GameActionInfo gInfo) -> KV.of(gInfo.getUser(), gInfo.getScore())) + .withOutputType(new TypeDescriptor>() {})); + + DataflowAssert.that(output).containsInAnyOrder(FILTERED_EVENTS); + + p.run(); + } + +} diff --git a/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/UserScoreTest.java b/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/UserScoreTest.java new file mode 100644 index 000000000000..69601be1bb4d --- /dev/null +++ b/examples/src/test/java8/com/google/cloud/dataflow/examples/complete/game/UserScoreTest.java @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.examples.complete.game; + +import com.google.cloud.dataflow.examples.complete.game.UserScore.ExtractAndSumScore; +import com.google.cloud.dataflow.examples.complete.game.UserScore.GameActionInfo; +import com.google.cloud.dataflow.examples.complete.game.UserScore.ParseEventFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** + * Tests of UserScore. 
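+ * Covers ParseEventFn, the per-user and per-team ExtractAndSumScore aggregations, and the
+ * handling of malformed input records.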
+ */ +@RunWith(JUnit4.class) +public class UserScoreTest implements Serializable { + + static final String[] GAME_EVENTS_ARRAY = new String[] { + "user0_MagentaKangaroo,MagentaKangaroo,3,1447955630000,2015-11-19 09:53:53.444", + "user13_ApricotQuokka,ApricotQuokka,15,1447955630000,2015-11-19 09:53:53.444", + "user6_AmberNumbat,AmberNumbat,11,1447955630000,2015-11-19 09:53:53.444", + "user7_AlmondWallaby,AlmondWallaby,15,1447955630000,2015-11-19 09:53:53.444", + "user7_AndroidGreenKookaburra,AndroidGreenKookaburra,12,1447955630000,2015-11-19 09:53:53.444", + "user6_AliceBlueDingo,AliceBlueDingo,4,xxxxxxx,2015-11-19 09:53:53.444", + "user7_AndroidGreenKookaburra,AndroidGreenKookaburra,11,1447955630000,2015-11-19 09:53:53.444", + "THIS IS A PARSE ERROR,2015-11-19 09:53:53.444", + "user19_BisqueBilby,BisqueBilby,6,1447955630000,2015-11-19 09:53:53.444", + "user19_BisqueBilby,BisqueBilby,8,1447955630000,2015-11-19 09:53:53.444" + }; + + static final String[] GAME_EVENTS_ARRAY2 = new String[] { + "user6_AliceBlueDingo,AliceBlueDingo,4,xxxxxxx,2015-11-19 09:53:53.444", + "THIS IS A PARSE ERROR,2015-11-19 09:53:53.444", + "user13_BisqueBilby,BisqueBilby,xxx,1447955630000,2015-11-19 09:53:53.444" + }; + + static final List GAME_EVENTS = Arrays.asList(GAME_EVENTS_ARRAY); + static final List GAME_EVENTS2 = Arrays.asList(GAME_EVENTS_ARRAY2); + + static final KV[] USER_SUMS = new KV[] { + KV.of("user0_MagentaKangaroo", 3), KV.of("user13_ApricotQuokka", 15), + KV.of("user6_AmberNumbat", 11), KV.of("user7_AlmondWallaby", 15), + KV.of("user7_AndroidGreenKookaburra", 23), + KV.of("user19_BisqueBilby", 14) }; + + static final KV[] TEAM_SUMS = new KV[] { + KV.of("MagentaKangaroo", 3), KV.of("ApricotQuokka", 15), + KV.of("AmberNumbat", 11), KV.of("AlmondWallaby", 15), + KV.of("AndroidGreenKookaburra", 23), + KV.of("BisqueBilby", 14) }; + + /** Test the ParseEventFn DoFn. */ + @Test + public void testParseEventFn() { + DoFnTester parseEventFn = + DoFnTester.of(new ParseEventFn()); + + List results = parseEventFn.processBatch(GAME_EVENTS_ARRAY); + Assert.assertEquals(results.size(), 8); + Assert.assertEquals(results.get(0).getUser(), "user0_MagentaKangaroo"); + Assert.assertEquals(results.get(0).getTeam(), "MagentaKangaroo"); + Assert.assertEquals(results.get(0).getScore(), new Integer(3)); + } + + /** Tests ExtractAndSumScore("user"). */ + @Test + @Category(RunnableOnService.class) + public void testUserScoreSums() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(GAME_EVENTS).withCoder(StringUtf8Coder.of())); + + PCollection> output = input + .apply(ParDo.of(new ParseEventFn())) + // Extract and sum username/score pairs from the event data. + .apply("ExtractUserScore", new ExtractAndSumScore("user")); + + // Check the user score sums. + DataflowAssert.that(output).containsInAnyOrder(USER_SUMS); + + p.run(); + } + + /** Tests ExtractAndSumScore("team"). */ + @Test + @Category(RunnableOnService.class) + public void testTeamScoreSums() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(GAME_EVENTS).withCoder(StringUtf8Coder.of())); + + PCollection> output = input + .apply(ParDo.of(new ParseEventFn())) + // Extract and sum teamname/score pairs from the event data. + .apply("ExtractTeamScore", new ExtractAndSumScore("team")); + + // Check the team score sums. + DataflowAssert.that(output).containsInAnyOrder(TEAM_SUMS); + + p.run(); + } + + /** Test that bad input data is dropped appropriately. 
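+   * Every row in GAME_EVENTS2 is either unparseable or carries an invalid score, so the
+   * extracted user/score pairs should be empty.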
*/ + @Test + @Category(RunnableOnService.class) + public void testUserScoresBadInput() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(GAME_EVENTS2).withCoder(StringUtf8Coder.of())); + + PCollection> extract = input + .apply(ParDo.of(new ParseEventFn())) + .apply( + MapElements.via((GameActionInfo gInfo) -> KV.of(gInfo.getUser(), gInfo.getScore())) + .withOutputType(new TypeDescriptor>() {})); + + DataflowAssert.that(extract).empty(); + + p.run(); + } +} diff --git a/javadoc/README.md b/javadoc/README.md new file mode 100644 index 000000000000..8240d3ce00e8 --- /dev/null +++ b/javadoc/README.md @@ -0,0 +1,4 @@ +# SDK Javadoc + +This directory contains package-info files for external javadoc we would like +our javadoc to link to using `-linkoffline`. diff --git a/javadoc/apiclient-docs/package-list b/javadoc/apiclient-docs/package-list new file mode 100644 index 000000000000..3ec1471f3041 --- /dev/null +++ b/javadoc/apiclient-docs/package-list @@ -0,0 +1,34 @@ +com.google.api.client.googleapis +com.google.api.client.googleapis.apache +com.google.api.client.googleapis.auth.clientlogin +com.google.api.client.googleapis.auth.oauth2 +com.google.api.client.googleapis.batch +com.google.api.client.googleapis.batch.json +com.google.api.client.googleapis.compute +com.google.api.client.googleapis.extensions.android.accounts +com.google.api.client.googleapis.extensions.android.gms.auth +com.google.api.client.googleapis.extensions.appengine.auth.oauth2 +com.google.api.client.googleapis.extensions.appengine.notifications +com.google.api.client.googleapis.extensions.appengine.testing.auth.oauth2 +com.google.api.client.googleapis.extensions.java6.auth.oauth2 +com.google.api.client.googleapis.extensions.servlet.notifications +com.google.api.client.googleapis.javanet +com.google.api.client.googleapis.json +com.google.api.client.googleapis.media +com.google.api.client.googleapis.notifications +com.google.api.client.googleapis.notifications.json +com.google.api.client.googleapis.notifications.json.gson +com.google.api.client.googleapis.notifications.json.jackson2 +com.google.api.client.googleapis.services +com.google.api.client.googleapis.services.json +com.google.api.client.googleapis.services.protobuf +com.google.api.client.googleapis.testing +com.google.api.client.googleapis.testing.auth.oauth2 +com.google.api.client.googleapis.testing.compute +com.google.api.client.googleapis.testing.json +com.google.api.client.googleapis.testing.notifications +com.google.api.client.googleapis.testing.services +com.google.api.client.googleapis.testing.services.json +com.google.api.client.googleapis.testing.services.protobuf +com.google.api.client.googleapis.util +com.google.api.client.googleapis.xml.atom diff --git a/javadoc/avro-docs/package-list b/javadoc/avro-docs/package-list new file mode 100644 index 000000000000..319ff01fdec2 --- /dev/null +++ b/javadoc/avro-docs/package-list @@ -0,0 +1,30 @@ +org.apache.avro +org.apache.avro.compiler.idl +org.apache.avro.compiler.specific +org.apache.avro.data +org.apache.avro.file +org.apache.avro.generic +org.apache.avro.hadoop.file +org.apache.avro.hadoop.io +org.apache.avro.hadoop.util +org.apache.avro.io +org.apache.avro.io.parsing +org.apache.avro.ipc +org.apache.avro.ipc.generic +org.apache.avro.ipc.reflect +org.apache.avro.ipc.specific +org.apache.avro.ipc.stats +org.apache.avro.ipc.trace +org.apache.avro.mapred +org.apache.avro.mapred.tether +org.apache.avro.mapreduce +org.apache.avro.mojo +org.apache.avro.protobuf 
+org.apache.avro.reflect +org.apache.avro.specific +org.apache.avro.thrift +org.apache.avro.tool +org.apache.avro.util +org.apache.trevni +org.apache.trevni.avro +org.apache.trevni.avro.mapreduce diff --git a/javadoc/bq-docs/package-list b/javadoc/bq-docs/package-list new file mode 100644 index 000000000000..384b3fc2d81d --- /dev/null +++ b/javadoc/bq-docs/package-list @@ -0,0 +1,2 @@ +com.google.api.services.bigquery +com.google.api.services.bigquery.model diff --git a/javadoc/dataflow-sdk-docs/package-list b/javadoc/dataflow-sdk-docs/package-list new file mode 100644 index 000000000000..a26f5a35cb84 --- /dev/null +++ b/javadoc/dataflow-sdk-docs/package-list @@ -0,0 +1,11 @@ +com.google.cloud.dataflow.sdk +com.google.cloud.dataflow.sdk.annotations +com.google.cloud.dataflow.sdk.coders +com.google.cloud.dataflow.sdk.io +com.google.cloud.dataflow.sdk.options +com.google.cloud.dataflow.sdk.runners +com.google.cloud.dataflow.sdk.testing +com.google.cloud.dataflow.sdk.transforms +com.google.cloud.dataflow.sdk.transforms.join +com.google.cloud.dataflow.sdk.transforms.windowing +com.google.cloud.dataflow.sdk.values diff --git a/javadoc/datastore-docs/package-list b/javadoc/datastore-docs/package-list new file mode 100644 index 000000000000..ebbafd860d2f --- /dev/null +++ b/javadoc/datastore-docs/package-list @@ -0,0 +1,2 @@ +com.google.api.services.datastore +com.google.api.services.datastore.client diff --git a/javadoc/guava-docs/package-list b/javadoc/guava-docs/package-list new file mode 100644 index 000000000000..f8551784fd3f --- /dev/null +++ b/javadoc/guava-docs/package-list @@ -0,0 +1,15 @@ +com.google.common.annotations +com.google.common.base +com.google.common.cache +com.google.common.collect +com.google.common.escape +com.google.common.eventbus +com.google.common.hash +com.google.common.html +com.google.common.io +com.google.common.math +com.google.common.net +com.google.common.primitives +com.google.common.reflect +com.google.common.util.concurrent +com.google.common.xml diff --git a/javadoc/hamcrest-docs/package-list b/javadoc/hamcrest-docs/package-list new file mode 100644 index 000000000000..3f5e945f7afb --- /dev/null +++ b/javadoc/hamcrest-docs/package-list @@ -0,0 +1,10 @@ +org.hamcrest +org.hamcrest.beans +org.hamcrest.collection +org.hamcrest.core +org.hamcrest.integration +org.hamcrest.internal +org.hamcrest.number +org.hamcrest.object +org.hamcrest.text +org.hamcrest.xml diff --git a/javadoc/jackson-annotations-docs/package-list b/javadoc/jackson-annotations-docs/package-list new file mode 100644 index 000000000000..768b3bab1cda --- /dev/null +++ b/javadoc/jackson-annotations-docs/package-list @@ -0,0 +1 @@ +com.fasterxml.jackson.annotation diff --git a/javadoc/jackson-databind-docs/package-list b/javadoc/jackson-databind-docs/package-list new file mode 100644 index 000000000000..8a2cd8be56f8 --- /dev/null +++ b/javadoc/jackson-databind-docs/package-list @@ -0,0 +1,20 @@ +com.fasterxml.jackson.databind +com.fasterxml.jackson.databind.annotation +com.fasterxml.jackson.databind.cfg +com.fasterxml.jackson.databind.deser +com.fasterxml.jackson.databind.deser.impl +com.fasterxml.jackson.databind.deser.std +com.fasterxml.jackson.databind.exc +com.fasterxml.jackson.databind.ext +com.fasterxml.jackson.databind.introspect +com.fasterxml.jackson.databind.jsonFormatVisitors +com.fasterxml.jackson.databind.jsonschema +com.fasterxml.jackson.databind.jsontype +com.fasterxml.jackson.databind.jsontype.impl +com.fasterxml.jackson.databind.module +com.fasterxml.jackson.databind.node 
+com.fasterxml.jackson.databind.ser +com.fasterxml.jackson.databind.ser.impl +com.fasterxml.jackson.databind.ser.std +com.fasterxml.jackson.databind.type +com.fasterxml.jackson.databind.util diff --git a/javadoc/joda-docs/package-list b/javadoc/joda-docs/package-list new file mode 100644 index 000000000000..2ab05aa0cf08 --- /dev/null +++ b/javadoc/joda-docs/package-list @@ -0,0 +1,7 @@ +org.joda.time +org.joda.time.base +org.joda.time.chrono +org.joda.time.convert +org.joda.time.field +org.joda.time.format +org.joda.time.tz diff --git a/javadoc/junit-docs/package-list b/javadoc/junit-docs/package-list new file mode 100644 index 000000000000..0735177ae6ed --- /dev/null +++ b/javadoc/junit-docs/package-list @@ -0,0 +1,7 @@ +org.hamcrest.core +org.junit +org.junit.matchers +org.junit.runner +org.junit.runner.manipulation +org.junit.runner.notification +org.junit.runners diff --git a/javadoc/oauth-docs/package-list b/javadoc/oauth-docs/package-list new file mode 100644 index 000000000000..38fc0461af4c --- /dev/null +++ b/javadoc/oauth-docs/package-list @@ -0,0 +1,11 @@ +com.google.api.client.auth.oauth +com.google.api.client.auth.oauth2 +com.google.api.client.auth.openidconnect +com.google.api.client.extensions.appengine.auth +com.google.api.client.extensions.appengine.auth.oauth2 +com.google.api.client.extensions.auth.helpers +com.google.api.client.extensions.auth.helpers.oauth +com.google.api.client.extensions.java6.auth.oauth2 +com.google.api.client.extensions.jetty.auth.oauth2 +com.google.api.client.extensions.servlet.auth +com.google.api.client.extensions.servlet.auth.oauth2 diff --git a/javadoc/overview.html b/javadoc/overview.html new file mode 100644 index 000000000000..4ffd33f22fca --- /dev/null +++ b/javadoc/overview.html @@ -0,0 +1,31 @@ + + + + Google Cloud Dataflow Java SDK + + +

The Google Cloud Dataflow SDK for Java provides a simple and elegant + programming model to express your data processing pipelines; + see our product page + for more information and getting started instructions.

+ +

The easiest way to use the Google Cloud Dataflow SDK for Java is via + one of the released artifacts from the + + Maven Central Repository. + See our + release notes for more information about each released version.

+ +

Version numbers use the form major.minor.incremental + and are incremented as follows:

+  <ul>
+    <li>major version for incompatible API changes</li>
+    <li>minor version for new functionality added in a backward-compatible manner</li>
+    <li>incremental version for forward-compatible bug fixes</li>
+  </ul>
+

Please note that APIs marked + {@link com.google.cloud.dataflow.sdk.annotations.Experimental @Experimental} + may change at any point and are not guaranteed to remain compatible across versions.

+ + diff --git a/maven-archetypes/examples/pom.xml b/maven-archetypes/examples/pom.xml new file mode 100644 index 000000000000..6cb1852562a5 --- /dev/null +++ b/maven-archetypes/examples/pom.xml @@ -0,0 +1,56 @@ + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + 1.5.0-SNAPSHOT + ../../pom.xml + + + com.google.cloud.dataflow + google-cloud-dataflow-java-archetypes-examples + Google Cloud Dataflow Java SDK - Examples Archetype + A Maven Archetype to create a project containing all the + example pipelines from the Google Cloud Dataflow Java SDK. + http://cloud.google.com/dataflow + + maven-archetype + + + + + org.apache.maven.archetype + archetype-packaging + 2.4 + + + + + + + maven-archetype-plugin + 2.4 + + + + + diff --git a/maven-archetypes/examples/src/main/resources/META-INF/maven/archetype-metadata.xml b/maven-archetypes/examples/src/main/resources/META-INF/maven/archetype-metadata.xml new file mode 100644 index 000000000000..7742af4e7242 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/META-INF/maven/archetype-metadata.xml @@ -0,0 +1,29 @@ + + + + + + 1.7 + + + + + + src/main/java + + **/*.java + + + + + src/test/java + + **/*.java + + + + diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml b/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml new file mode 100644 index 000000000000..bffa376f5666 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml @@ -0,0 +1,204 @@ + + + + 4.0.0 + + ${groupId} + ${artifactId} + ${version} + + jar + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.3 + + ${targetPlatform} + ${targetPlatform} + + + + + org.apache.maven.plugins + maven-shade-plugin + 2.3 + + + package + + shade + + + ${project.artifactId}-bundled-${project.version} + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.18.1 + + all + 4 + true + + + + org.apache.maven.surefire + surefire-junit47 + 2.18.1 + + + + + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + [1.0.0, 2.0.0) + + + + com.google.api-client + google-api-client + 1.21.0 + + + + com.google.guava + guava-jdk5 + + + + + + + com.google.apis + google-api-services-bigquery + v2-rev248-1.21.0 + + + + com.google.guava + guava-jdk5 + + + + + + com.google.http-client + google-http-client + 1.21.0 + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-pubsub + v1-rev7-1.21.0 + + + + com.google.guava + guava-jdk5 + + + + + + joda-time + joda-time + 2.4 + + + + com.google.guava + guava + 18.0 + + + + javax.servlet + javax.servlet-api + 3.1.0 + + + + + org.slf4j + slf4j-api + 1.7.7 + + + + org.slf4j + slf4j-jdk14 + 1.7.7 + + runtime + + + + + org.hamcrest + hamcrest-all + 1.3 + + + + junit + junit + 4.11 + + + diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/DebuggingWordCount.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/DebuggingWordCount.java new file mode 100644 index 000000000000..3cf2bc0dffa3 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/DebuggingWordCount.java @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import ${package}.WordCount.WordCountOptions; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + + +/** + * An example that verifies word counts in Shakespeare and includes Dataflow best practices. + * + *

This class, {@link DebuggingWordCount}, is the third in a series of four successively more + * detailed 'word count' examples. You may first want to take a look at {@link MinimalWordCount} + * and {@link WordCount}. After you've looked at this example, then see the + * {@link WindowedWordCount} pipeline, for introduction of additional concepts. + * + *

+ * <p>Basic concepts, also in the MinimalWordCount and WordCount examples:
+ * Reading text files; counting a PCollection; executing a Pipeline both locally
+ * and using the Dataflow service; defining DoFns.
+ *
+ * <p>New Concepts:
+ * <pre>
+ *   1. Logging to Cloud Logging
+ *   2. Controlling Dataflow worker log levels
+ *   3. Creating a custom aggregator
+ *   4. Testing your Pipeline via DataflowAssert
+ * </pre>
+ *
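A minimal sketch of Concept #3, assuming only the Aggregator, Sum, and DoFn classes this file already imports; the DoFn name is illustrative, and the real FilterTextFn further down wires the same pattern into the pipeline:

    static class TrackMatchesFn extends DoFn<String, String> {
      // Aggregator values are summed across workers and shown in the Dataflow Monitoring UI.
      private final Aggregator<Long, Long> matchedWords =
          createAggregator("matchedWords", new Sum.SumLongFn());

      @Override
      public void processElement(ProcessContext c) {
        if (c.element().contains("Flourish")) {
          matchedWords.addValue(1L);  // count every matching element
          c.output(c.element());
        }
      }
    }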

+ * <p>To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * </pre>
+ *
+ * <p>To execute this pipeline using the Dataflow service and the additional logging discussed
+ * below, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ *   --workerLogLevelOverrides={"com.google.cloud.dataflow.examples":"DEBUG"}
+ * }
+ * </pre>
+ *

Note that when you run via mvn exec, you may need to escape + * the quotations as appropriate for your shell. For example, in bash: + *

+ * mvn compile exec:java ... \
+ *   -Dexec.args="... \
+ *     --workerLogLevelOverrides={\\\"com.google.cloud.dataflow.examples\\\":\\\"DEBUG\\\"}"
+ * 
+ * + *

Concept #2: Dataflow workers which execute user code are configured to log to Cloud + * Logging by default at "INFO" log level and higher. One may override log levels for specific + * logging namespaces by specifying: + *


+ *   --workerLogLevelOverrides={"Name1":"Level1","Name2":"Level2",...}
+ * 
+ * For example, by specifying: + *

+ *   --workerLogLevelOverrides={"com.google.cloud.dataflow.examples":"DEBUG"}
+ * 
+ * when executing this pipeline using the Dataflow service, Cloud Logging would contain only + * "DEBUG" or higher level logs for the {@code com.google.cloud.dataflow.examples} package in + * addition to the default "INFO" or higher level logs. In addition, the default Dataflow worker + * logging configuration can be overridden by specifying + * {@code --defaultWorkerLogLevel=}. For example, + * by specifying {@code --defaultWorkerLogLevel=DEBUG} when executing this pipeline with + * the Dataflow service, Cloud Logging would contain all "DEBUG" or higher level logs. Note + * that changing the default worker log level to TRACE or DEBUG will significantly increase + * the amount of logs output. + * + *
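Concretely, the "logging namespace" is just the SLF4J logger name; FilterTextFn below logs through a logger named after its own class, roughly as in this abridged excerpt:

    private static final Logger LOG = LoggerFactory.getLogger(FilterTextFn.class);

    @Override
    public void processElement(ProcessContext c) {
      // Dropped by the workers unless this namespace (or a parent such as
      // com.google.cloud.dataflow.examples) is overridden to DEBUG or finer.
      LOG.debug("Matched: " + c.element().getKey());
      c.output(c.element());
    }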

The input file defaults to {@code gs://dataflow-samples/shakespeare/kinglear.txt} and can be + * overridden with {@code --inputFile}. + */ +public class DebuggingWordCount { + /** A DoFn that filters for a specific key based upon a regular expression. */ + public static class FilterTextFn extends DoFn, KV> { + /** + * Concept #1: The logger below uses the fully qualified class name of FilterTextFn + * as the logger. All log statements emitted by this logger will be referenced by this name + * and will be visible in the Cloud Logging UI. Learn more at https://cloud.google.com/logging + * about the Cloud Logging UI. + */ + private static final Logger LOG = LoggerFactory.getLogger(FilterTextFn.class); + + private final Pattern filter; + public FilterTextFn(String pattern) { + filter = Pattern.compile(pattern); + } + + /** + * Concept #3: A custom aggregator can track values in your pipeline as it runs. Those + * values will be displayed in the Dataflow Monitoring UI when this pipeline is run using the + * Dataflow service. These aggregators below track the number of matched and unmatched words. + * Learn more at https://cloud.google.com/dataflow/pipelines/dataflow-monitoring-intf about + * the Dataflow Monitoring UI. + */ + private final Aggregator matchedWords = + createAggregator("matchedWords", new Sum.SumLongFn()); + private final Aggregator unmatchedWords = + createAggregator("umatchedWords", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (filter.matcher(c.element().getKey()).matches()) { + // Log at the "DEBUG" level each element that we match. When executing this pipeline + // using the Dataflow service, these log lines will appear in the Cloud Logging UI + // only if the log level is set to "DEBUG" or lower. + LOG.debug("Matched: " + c.element().getKey()); + matchedWords.addValue(1L); + c.output(c.element()); + } else { + // Log at the "TRACE" level each element that is not matched. Different log levels + // can be used to control the verbosity of logging providing an effective mechanism + // to filter less important information. + LOG.trace("Did not match: " + c.element().getKey()); + unmatchedWords.addValue(1L); + } + } + } + + public static void main(String[] args) { + WordCountOptions options = PipelineOptionsFactory.fromArgs(args).withValidation() + .as(WordCountOptions.class); + Pipeline p = Pipeline.create(options); + + PCollection> filteredWords = + p.apply(TextIO.Read.named("ReadLines").from(options.getInputFile())) + .apply(new WordCount.CountWords()) + .apply(ParDo.of(new FilterTextFn("Flourish|stomach"))); + + /** + * Concept #4: DataflowAssert is a set of convenient PTransforms in the style of + * Hamcrest's collection matchers that can be used when writing Pipeline level tests + * to validate the contents of PCollections. DataflowAssert is best used in unit tests + * with small data sets but is demonstrated here as a teaching tool. + * + *

Below we verify that the set of filtered words matches our expected counts. Note + * that DataflowAssert does not provide any output and that successful completion of the + * Pipeline implies that the expectations were met. Learn more at + * https://cloud.google.com/dataflow/pipelines/testing-your-pipeline on how to test + * your Pipeline and see {@link DebuggingWordCountTest} for an example unit test. + */ + List> expectedResults = Arrays.asList( + KV.of("Flourish", 3L), + KV.of("stomach", 1L)); + DataflowAssert.that(filteredWords).containsInAnyOrder(expectedResults); + + p.run(); + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/MinimalWordCount.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/MinimalWordCount.java new file mode 100644 index 000000000000..035db01e4a80 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/MinimalWordCount.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; + + +/** + * An example that counts words in Shakespeare. + * + *

This class, {@link MinimalWordCount}, is the first in a series of four successively more + * detailed 'word count' examples. Here, for simplicity, we don't show any error-checking or + * argument processing, and focus on construction of the pipeline, which chains together the + * application of core transforms. + * + *

Next, see the {@link WordCount} pipeline, then the {@link DebuggingWordCount}, and finally + * the {@link WindowedWordCount} pipeline, for more detailed examples that introduce additional + * concepts. + * + *

+ * <p>Concepts:
+ * <pre>
+ *   1. Reading data from text files
+ *   2. Specifying 'inline' transforms
+ *   3. Counting a PCollection
+ *   4. Writing data to Cloud Storage as text files
+ * </pre>
+ *
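Condensed into one chain (with ExtractWordsFn and FormatAsTextFn standing in for the anonymous DoFns the full example defines inline below), the four concepts line up roughly as:

    p.apply(TextIO.Read.from("gs://dataflow-samples/shakespeare/*"))        // 1. read text files
     .apply(ParDo.named("ExtractWords").of(new ExtractWordsFn()))           // 2. an 'inline' transform
     .apply(Count.<String>perElement())                                     // 3. count the PCollection
     .apply(ParDo.named("FormatResults").of(new FormatAsTextFn()))          // turn KV<String, Long> into text
     .apply(TextIO.Write.to("gs://YOUR_OUTPUT_BUCKET/AND_OUTPUT_PREFIX"));  // 4. write to Cloud Storage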

To execute this pipeline, first edit the code to set your project ID, the staging + * location, and the output location. The specified GCS bucket(s) must already exist. + * + *

Then, run the pipeline as described in the README. It will be deployed and run using the + * Dataflow service. No args are required to run the pipeline. You can see the results in your + * output bucket in the GCS browser. + */ +public class MinimalWordCount { + + public static void main(String[] args) { + // Create a DataflowPipelineOptions object. This object lets us set various execution + // options for our pipeline, such as the associated Cloud Platform project and the location + // in Google Cloud Storage to stage files. + DataflowPipelineOptions options = PipelineOptionsFactory.create() + .as(DataflowPipelineOptions.class); + options.setRunner(BlockingDataflowPipelineRunner.class); + // CHANGE 1/3: Your project ID is required in order to run your pipeline on the Google Cloud. + options.setProject("SET_YOUR_PROJECT_ID_HERE"); + // CHANGE 2/3: Your Google Cloud Storage path is required for staging local files. + options.setStagingLocation("gs://SET_YOUR_BUCKET_NAME_HERE/AND_STAGING_DIRECTORY"); + + // Create the Pipeline object with the options we defined above. + Pipeline p = Pipeline.create(options); + + // Apply the pipeline's transforms. + + // Concept #1: Apply a root transform to the pipeline; in this case, TextIO.Read to read a set + // of input text files. TextIO.Read returns a PCollection where each element is one line from + // the input text (a set of Shakespeare's texts). + p.apply(TextIO.Read.from("gs://dataflow-samples/shakespeare/*")) + // Concept #2: Apply a ParDo transform to our PCollection of text lines. This ParDo invokes a + // DoFn (defined in-line) on each element that tokenizes the text line into individual words. + // The ParDo returns a PCollection, where each element is an individual word in + // Shakespeare's collected texts. + .apply(ParDo.named("ExtractWords").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (String word : c.element().split("[^a-zA-Z']+")) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + })) + // Concept #3: Apply the Count transform to our PCollection of individual words. The Count + // transform returns a new PCollection of key/value pairs, where each key represents a unique + // word in the text. The associated value is the occurrence count for that word. + .apply(Count.perElement()) + // Apply another ParDo transform that formats our PCollection of word counts into a printable + // string, suitable for writing to an output file. + .apply(ParDo.named("FormatResults").of(new DoFn, String>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey() + ": " + c.element().getValue()); + } + })) + // Concept #4: Apply a write transform, TextIO.Write, at the end of the pipeline. + // TextIO.Write writes the contents of a PCollection (in this case, our PCollection of + // formatted strings) to a series of text files in Google Cloud Storage. + // CHANGE 3/3: The Google Cloud Storage path is required for outputting the results to. + .apply(TextIO.Write.to("gs://YOUR_OUTPUT_BUCKET/AND_OUTPUT_PREFIX")); + + // Run the pipeline. 
+ p.run(); + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/WindowedWordCount.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/WindowedWordCount.java new file mode 100644 index 000000000000..29921e235fb1 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/WindowedWordCount.java @@ -0,0 +1,262 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import ${package}.common.DataflowExampleUtils; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + + +/** + * An example that counts words in text, and can run over either unbounded or bounded input + * collections. + * + *

This class, {@link WindowedWordCount}, is the last in a series of four successively more + * detailed 'word count' examples. First take a look at {@link MinimalWordCount}, + * {@link WordCount}, and {@link DebuggingWordCount}. + * + *

Basic concepts, also in the MinimalWordCount, WordCount, and DebuggingWordCount examples: + * Reading text files; counting a PCollection; writing to GCS; executing a Pipeline both locally + * and using the Dataflow service; defining DoFns; creating a custom aggregator; + * user-defined PTransforms; defining PipelineOptions. + * + *

+ * <p>New Concepts:
+ * <pre>
+ *   1. Unbounded and bounded pipeline input modes
+ *   2. Adding timestamps to data
+ *   3. PubSub topics as sources
+ *   4. Windowing
+ *   5. Re-using PTransforms over windowed PCollections
+ *   6. Writing to BigQuery
+ * </pre>
+ *
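Concept #4 in miniature, using the Window, FixedWindows, and Duration classes this file already imports (the full pipeline below reads the size from the --windowSize option rather than hard-coding one minute):

    PCollection<String> windowedWords = input.apply(
        Window.<String>into(FixedWindows.of(Duration.standardMinutes(1))));
    // Downstream transforms such as CountWords then run per window unchanged (Concept #5).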

To execute this pipeline locally, specify general pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * 
+ * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * 
+ * + *

Optionally specify the input file path via: + * {@code --inputFile=gs://INPUT_PATH}, + * which defaults to {@code gs://dataflow-samples/shakespeare/kinglear.txt}. + * + *

+ * <p>Specify an output BigQuery dataset and, optionally, a table for the output. If you don't
+ * specify the table, one will be created for you using the job name. If you don't specify the
+ * dataset, a dataset called {@code dataflow-examples} must already exist in your project.
+ * {@code --bigQueryDataset=YOUR-DATASET --bigQueryTable=YOUR-NEW-TABLE-NAME}.
+ *
+ *

Decide whether you want your pipeline to run with 'bounded' (such as files in GCS) or + * 'unbounded' input (such as a PubSub topic). To run with unbounded input, set + * {@code --unbounded=true}. Then, optionally specify the Google Cloud PubSub topic to read from + * via {@code --pubsubTopic=projects/PROJECT_ID/topics/YOUR_TOPIC_NAME}. If the topic does not + * exist, the pipeline will create one for you. It will delete this topic when it terminates. + * The pipeline will automatically launch an auxiliary batch pipeline to populate the given PubSub + * topic with the contents of the {@code --inputFile}, in order to make the example easy to run. + * If you want to use an independently-populated PubSub topic, indicate this by setting + * {@code --inputFile=""}. In that case, the auxiliary pipeline will not be started. + * + *

By default, the pipeline will do fixed windowing, on 1-minute windows. You can + * change this interval by setting the {@code --windowSize} parameter, e.g. {@code --windowSize=10} + * for 10-minute windows. + */ +public class WindowedWordCount { + private static final Logger LOG = LoggerFactory.getLogger(WindowedWordCount.class); + static final int WINDOW_SIZE = 1; // Default window duration in minutes + + /** + * Concept #2: A DoFn that sets the data element timestamp. This is a silly method, just for + * this example, for the bounded data case. + * + *

Imagine that many ghosts of Shakespeare are all typing madly at the same time to recreate + * his masterworks. Each line of the corpus will get a random associated timestamp somewhere in a + * 2-hour period. + */ + static class AddTimestampFn extends DoFn { + private static final long RAND_RANGE = 7200000; // 2 hours in ms + + @Override + public void processElement(ProcessContext c) { + // Generate a timestamp that falls somewhere in the past two hours. + long randomTimestamp = System.currentTimeMillis() + - (int) (Math.random() * RAND_RANGE); + /** + * Concept #2: Set the data element with that timestamp. + */ + c.outputWithTimestamp(c.element(), new Instant(randomTimestamp)); + } + } + + /** A DoFn that converts a Word and Count into a BigQuery table row. */ + static class FormatAsTableRowFn extends DoFn, TableRow> { + @Override + public void processElement(ProcessContext c) { + TableRow row = new TableRow() + .set("word", c.element().getKey()) + .set("count", c.element().getValue()) + // include a field for the window timestamp + .set("window_timestamp", c.timestamp().toString()); + c.output(row); + } + } + + /** + * Helper method that defines the BigQuery schema used for the output. + */ + private static TableSchema getSchema() { + List fields = new ArrayList<>(); + fields.add(new TableFieldSchema().setName("word").setType("STRING")); + fields.add(new TableFieldSchema().setName("count").setType("INTEGER")); + fields.add(new TableFieldSchema().setName("window_timestamp").setType("TIMESTAMP")); + TableSchema schema = new TableSchema().setFields(fields); + return schema; + } + + /** + * Concept #6: We'll stream the results to a BigQuery table. The BigQuery output source is one + * that supports both bounded and unbounded data. This is a helper method that creates a + * TableReference from input options, to tell the pipeline where to write its BigQuery results. + */ + private static TableReference getTableReference(Options options) { + TableReference tableRef = new TableReference(); + tableRef.setProjectId(options.getProject()); + tableRef.setDatasetId(options.getBigQueryDataset()); + tableRef.setTableId(options.getBigQueryTable()); + return tableRef; + } + + /** + * Options supported by {@link WindowedWordCount}. + * + *

Inherits standard example configuration options, which allow specification of the BigQuery + * table and the PubSub topic, as well as the {@link WordCount.WordCountOptions} support for + * specification of the input file. + */ + public static interface Options + extends WordCount.WordCountOptions, DataflowExampleUtils.DataflowExampleUtilsOptions { + @Description("Fixed window duration, in minutes") + @Default.Integer(WINDOW_SIZE) + Integer getWindowSize(); + void setWindowSize(Integer value); + + @Description("Whether to run the pipeline with unbounded input") + boolean isUnbounded(); + void setUnbounded(boolean value); + } + + public static void main(String[] args) throws IOException { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + options.setBigQuerySchema(getSchema()); + // DataflowExampleUtils creates the necessary input sources to simplify execution of this + // Pipeline. + DataflowExampleUtils exampleDataflowUtils = new DataflowExampleUtils(options, + options.isUnbounded()); + + Pipeline pipeline = Pipeline.create(options); + + /** + * Concept #1: the Dataflow SDK lets us run the same pipeline with either a bounded or + * unbounded input source. + */ + PCollection input; + if (options.isUnbounded()) { + LOG.info("Reading from PubSub."); + /** + * Concept #3: Read from the PubSub topic. A topic will be created if it wasn't + * specified as an argument. The data elements' timestamps will come from the pubsub + * injection. + */ + input = pipeline + .apply(PubsubIO.Read.topic(options.getPubsubTopic())); + } else { + /** Else, this is a bounded pipeline. Read from the GCS file. */ + input = pipeline + .apply(TextIO.Read.from(options.getInputFile())) + // Concept #2: Add an element timestamp, using an artificial time just to show windowing. + // See AddTimestampFn for more detail on this. + .apply(ParDo.of(new AddTimestampFn())); + } + + /** + * Concept #4: Window into fixed windows. The fixed window size for this example defaults to 1 + * minute (you can change this with a command-line option). See the documentation for more + * information on how fixed windows work, and for information on the other types of windowing + * available (e.g., sliding windows). + */ + PCollection windowedWords = input + .apply(Window.into( + FixedWindows.of(Duration.standardMinutes(options.getWindowSize())))); + + /** + * Concept #5: Re-use our existing CountWords transform that does not have knowledge of + * windows over a PCollection containing windowed values. + */ + PCollection> wordCounts = windowedWords.apply(new WordCount.CountWords()); + + /** + * Concept #6: Format the results for a BigQuery table, then write to BigQuery. + * The BigQuery output source supports both bounded and unbounded data. + */ + wordCounts.apply(ParDo.of(new FormatAsTableRowFn())) + .apply(BigQueryIO.Write.to(getTableReference(options)).withSchema(getSchema())); + + PipelineResult result = pipeline.run(); + + /** + * To mock unbounded input from PubSub, we'll now start an auxiliary 'injector' pipeline that + * runs for a limited time, and publishes to the input PubSub topic. + * + * With an unbounded input source, you will need to explicitly shut down this pipeline when you + * are done with it, so that you do not continue to be charged for the instances. You can do + * this via a ctrl-C from the command line, or from the developer's console UI for Dataflow + * pipelines. The PubSub topic will also be deleted at this time. 
+ */ + exampleDataflowUtils.mockUnboundedSource(options.getInputFile(), result); + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/WordCount.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/WordCount.java new file mode 100644 index 000000000000..150b60d2d25f --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/WordCount.java @@ -0,0 +1,204 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + + +/** + * An example that counts words in Shakespeare and includes Dataflow best practices. + * + *

This class, {@link WordCount}, is the second in a series of four successively more detailed + * 'word count' examples. You may first want to take a look at {@link MinimalWordCount}. + * After you've looked at this example, then see the {@link DebuggingWordCount} + * pipeline, for introduction of additional concepts. + * + *

For a detailed walkthrough of this example, see + * + * https://cloud.google.com/dataflow/java-sdk/wordcount-example + * + * + *

Basic concepts, also in the MinimalWordCount example: + * Reading text files; counting a PCollection; writing to GCS. + * + *

+ * <p>New Concepts:
+ * <pre>
+ *   1. Executing a Pipeline both locally and using the Dataflow service
+ *   2. Using ParDo with static DoFns defined out-of-line
+ *   3. Building a composite transform
+ *   4. Defining your own pipeline options
+ * </pre>
+ *
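Concept #4 in miniature: a user-defined options interface that PipelineOptionsFactory populates from command-line arguments. The interface name here is illustrative; WordCountOptions below is the real one used by this example:

    public interface MyOptions extends PipelineOptions {
      @Description("Path of the file to read from")
      @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt")
      String getInputFile();
      void setInputFile(String value);
    }

    // MyOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().as(MyOptions.class);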

Concept #1: you can execute this pipeline either locally or using the Dataflow service. + * These are now command-line options and not hard-coded as they were in the MinimalWordCount + * example. + * To execute this pipeline locally, specify general pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ * }
+ * 
+ * and a local output file or output prefix on GCS: + *
{@code
+ *   --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PREFIX]
+ * }
+ * + *

To execute this pipeline using the Dataflow service, specify pipeline configuration: + *

{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * }
+ * 
+ * and an output prefix on GCS: + *
{@code
+ *   --output=gs://YOUR_OUTPUT_PREFIX
+ * }
+ * + *

The input file defaults to {@code gs://dataflow-samples/shakespeare/kinglear.txt} and can be + * overridden with {@code --inputFile}. + */ +public class WordCount { + + /** + * Concept #2: You can make your pipeline code less verbose by defining your DoFns statically out- + * of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the + * pipeline. + */ + static class ExtractWordsFn extends DoFn { + private final Aggregator emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.addValue(1L); + } + + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + /** A DoFn that converts a Word and Count into a printable string. */ + public static class FormatAsTextFn extends DoFn, String> { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey() + ": " + c.element().getValue()); + } + } + + /** + * A PTransform that converts a PCollection containing lines of text into a PCollection of + * formatted word counts. + * + *

Concept #3: This is a custom composite transform that bundles two transforms (ParDo and + * Count) as a reusable PTransform subclass. Using composite transforms allows for easy reuse, + * modular testing, and an improved monitoring experience. + */ + public static class CountWords extends PTransform, + PCollection>> { + @Override + public PCollection> apply(PCollection lines) { + + // Convert lines of text into individual words. + PCollection words = lines.apply( + ParDo.of(new ExtractWordsFn())); + + // Count the number of times each word occurs. + PCollection> wordCounts = + words.apply(Count.perElement()); + + return wordCounts; + } + } + + /** + * Options supported by {@link WordCount}. + * + *

Concept #4: Defining your own configuration options. Here, you can add your own arguments + * to be processed by the command-line parser, and specify default values for them. You can then + * access the options values in your pipeline code. + * + *

Inherits standard configuration options. + */ + public static interface WordCountOptions extends PipelineOptions { + @Description("Path of the file to read from") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInputFile(); + void setInputFile(String value); + + @Description("Path of the file to write to") + @Default.InstanceFactory(OutputFactory.class) + String getOutput(); + void setOutput(String value); + + /** + * Returns "gs://${YOUR_STAGING_DIRECTORY}/counts.txt" as the default destination. + */ + public static class OutputFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + if (dataflowOptions.getStagingLocation() != null) { + return GcsPath.fromUri(dataflowOptions.getStagingLocation()) + .resolve("counts.txt").toString(); + } else { + throw new IllegalArgumentException("Must specify --output or --stagingLocation"); + } + } + } + + } + + public static void main(String[] args) { + WordCountOptions options = PipelineOptionsFactory.fromArgs(args).withValidation() + .as(WordCountOptions.class); + Pipeline p = Pipeline.create(options); + + // Concepts #2 and #3: Our pipeline applies the composite CountWords transform, and passes the + // static FormatAsTextFn() to the ParDo transform. + p.apply(TextIO.Read.named("ReadLines").from(options.getInputFile())) + .apply(new CountWords()) + .apply(ParDo.of(new FormatAsTextFn())) + .apply(TextIO.Write.named("WriteCounts").to(options.getOutput())); + + p.run(); + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/DataflowExampleOptions.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/DataflowExampleOptions.java new file mode 100644 index 000000000000..e182f4cd2bd5 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/DataflowExampleOptions.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}.common; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; + +/** + * Options that can be used to configure the Dataflow examples. 
+ */ +public interface DataflowExampleOptions extends DataflowPipelineOptions { + @Description("Whether to keep jobs running on the Dataflow service after local process exit") + @Default.Boolean(false) + boolean getKeepJobsRunning(); + void setKeepJobsRunning(boolean keepJobsRunning); +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/DataflowExampleUtils.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/DataflowExampleUtils.java new file mode 100644 index 000000000000..98617699ad88 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/DataflowExampleUtils.java @@ -0,0 +1,398 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}.common; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.googleapis.services.AbstractGoogleClientRequest; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.Bigquery.Datasets; +import com.google.api.services.bigquery.Bigquery.Tables; +import com.google.api.services.bigquery.model.Dataset; +import com.google.api.services.bigquery.model.DatasetReference; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.model.Topic; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineJob; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.IntraBundleParallelization; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import javax.servlet.http.HttpServletResponse; + +/** + * The utility class that sets up and tears down external resources, starts the Google Cloud Pub/Sub + * injector, and cancels the streaming and the injector pipelines once the program terminates. + * + *

It is used to run Dataflow examples, such as TrafficMaxLaneFlow and TrafficRoutes. + */ +public class DataflowExampleUtils { + + private final DataflowPipelineOptions options; + private Bigquery bigQueryClient = null; + private Pubsub pubsubClient = null; + private Dataflow dataflowClient = null; + private Set jobsToCancel = Sets.newHashSet(); + private List pendingMessages = Lists.newArrayList(); + + /** + * Define an interface that supports the PubSub and BigQuery example options. + */ + public static interface DataflowExampleUtilsOptions + extends DataflowExampleOptions, ExamplePubsubTopicOptions, ExampleBigQueryTableOptions { + } + + public DataflowExampleUtils(DataflowPipelineOptions options) { + this.options = options; + } + + /** + * Do resources and runner options setup. + */ + public DataflowExampleUtils(DataflowPipelineOptions options, boolean isUnbounded) + throws IOException { + this.options = options; + setupResourcesAndRunner(isUnbounded); + } + + /** + * Sets up external resources that are required by the example, + * such as Pub/Sub topics and BigQuery tables. + * + * @throws IOException if there is a problem setting up the resources + */ + public void setup() throws IOException { + setupPubsubTopic(); + setupBigQueryTable(); + } + + /** + * Set up external resources, and configure the runner appropriately. + */ + public void setupResourcesAndRunner(boolean isUnbounded) throws IOException { + if (isUnbounded) { + options.setStreaming(true); + } + setup(); + setupRunner(); + } + + /** + * Sets up the Google Cloud Pub/Sub topic. + * + *

If the topic doesn't exist, a new topic with the given name will be created. + * + * @throws IOException if there is a problem setting up the Pub/Sub topic + */ + public void setupPubsubTopic() throws IOException { + ExamplePubsubTopicOptions pubsubTopicOptions = options.as(ExamplePubsubTopicOptions.class); + if (!pubsubTopicOptions.getPubsubTopic().isEmpty()) { + pendingMessages.add("*******************Set Up Pubsub Topic*********************"); + setupPubsubTopic(pubsubTopicOptions.getPubsubTopic()); + pendingMessages.add("The Pub/Sub topic has been set up for this example: " + + pubsubTopicOptions.getPubsubTopic()); + } + } + + /** + * Sets up the BigQuery table with the given schema. + * + *

If the table already exists, the schema has to match the given one. Otherwise, the example + * will throw a RuntimeException. If the table doesn't exist, a new table with the given schema + * will be created. + * + * @throws IOException if there is a problem setting up the BigQuery table + */ + public void setupBigQueryTable() throws IOException { + ExampleBigQueryTableOptions bigQueryTableOptions = + options.as(ExampleBigQueryTableOptions.class); + if (bigQueryTableOptions.getBigQueryDataset() != null + && bigQueryTableOptions.getBigQueryTable() != null + && bigQueryTableOptions.getBigQuerySchema() != null) { + pendingMessages.add("******************Set Up Big Query Table*******************"); + setupBigQueryTable(bigQueryTableOptions.getProject(), + bigQueryTableOptions.getBigQueryDataset(), + bigQueryTableOptions.getBigQueryTable(), + bigQueryTableOptions.getBigQuerySchema()); + pendingMessages.add("The BigQuery table has been set up for this example: " + + bigQueryTableOptions.getProject() + + ":" + bigQueryTableOptions.getBigQueryDataset() + + "." + bigQueryTableOptions.getBigQueryTable()); + } + } + + /** + * Tears down external resources that can be deleted upon the example's completion. + */ + private void tearDown() { + pendingMessages.add("*************************Tear Down*************************"); + ExamplePubsubTopicOptions pubsubTopicOptions = options.as(ExamplePubsubTopicOptions.class); + if (!pubsubTopicOptions.getPubsubTopic().isEmpty()) { + try { + deletePubsubTopic(pubsubTopicOptions.getPubsubTopic()); + pendingMessages.add("The Pub/Sub topic has been deleted: " + + pubsubTopicOptions.getPubsubTopic()); + } catch (IOException e) { + pendingMessages.add("Failed to delete the Pub/Sub topic : " + + pubsubTopicOptions.getPubsubTopic()); + } + } + + ExampleBigQueryTableOptions bigQueryTableOptions = + options.as(ExampleBigQueryTableOptions.class); + if (bigQueryTableOptions.getBigQueryDataset() != null + && bigQueryTableOptions.getBigQueryTable() != null + && bigQueryTableOptions.getBigQuerySchema() != null) { + pendingMessages.add("The BigQuery table might contain the example's output, " + + "and it is not deleted automatically: " + + bigQueryTableOptions.getProject() + + ":" + bigQueryTableOptions.getBigQueryDataset() + + "." + bigQueryTableOptions.getBigQueryTable()); + pendingMessages.add("Please go to the Developers Console to delete it manually." 
+ + " Otherwise, you may be charged for its usage."); + } + } + + private void setupBigQueryTable(String projectId, String datasetId, String tableId, + TableSchema schema) throws IOException { + if (bigQueryClient == null) { + bigQueryClient = Transport.newBigQueryClient(options.as(BigQueryOptions.class)).build(); + } + + Datasets datasetService = bigQueryClient.datasets(); + if (executeNullIfNotFound(datasetService.get(projectId, datasetId)) == null) { + Dataset newDataset = new Dataset().setDatasetReference( + new DatasetReference().setProjectId(projectId).setDatasetId(datasetId)); + datasetService.insert(projectId, newDataset).execute(); + } + + Tables tableService = bigQueryClient.tables(); + Table table = executeNullIfNotFound(tableService.get(projectId, datasetId, tableId)); + if (table == null) { + Table newTable = new Table().setSchema(schema).setTableReference( + new TableReference().setProjectId(projectId).setDatasetId(datasetId).setTableId(tableId)); + tableService.insert(projectId, datasetId, newTable).execute(); + } else if (!table.getSchema().equals(schema)) { + throw new RuntimeException( + "Table exists and schemas do not match, expecting: " + schema.toPrettyString() + + ", actual: " + table.getSchema().toPrettyString()); + } + } + + private void setupPubsubTopic(String topic) throws IOException { + if (pubsubClient == null) { + pubsubClient = Transport.newPubsubClient(options).build(); + } + if (executeNullIfNotFound(pubsubClient.projects().topics().get(topic)) == null) { + pubsubClient.projects().topics().create(topic, new Topic().setName(topic)).execute(); + } + } + + /** + * Deletes the Google Cloud Pub/Sub topic. + * + * @throws IOException if there is a problem deleting the Pub/Sub topic + */ + private void deletePubsubTopic(String topic) throws IOException { + if (pubsubClient == null) { + pubsubClient = Transport.newPubsubClient(options).build(); + } + if (executeNullIfNotFound(pubsubClient.projects().topics().get(topic)) != null) { + pubsubClient.projects().topics().delete(topic).execute(); + } + } + + /** + * If this is an unbounded (streaming) pipeline, and both inputFile and pubsub topic are defined, + * start an 'injector' pipeline that publishes the contents of the file to the given topic, first + * creating the topic if necessary. + */ + public void startInjectorIfNeeded(String inputFile) { + ExamplePubsubTopicOptions pubsubTopicOptions = options.as(ExamplePubsubTopicOptions.class); + if (pubsubTopicOptions.isStreaming() + && inputFile != null && !inputFile.isEmpty() + && pubsubTopicOptions.getPubsubTopic() != null + && !pubsubTopicOptions.getPubsubTopic().isEmpty()) { + runInjectorPipeline(inputFile, pubsubTopicOptions.getPubsubTopic()); + } + } + + /** + * Do some runner setup: check that the DirectPipelineRunner is not used in conjunction with + * streaming, and if streaming is specified, use the DataflowPipelineRunner. Return the streaming + * flag value. + */ + public void setupRunner() { + if (options.isStreaming()) { + if (options.getRunner() == DirectPipelineRunner.class) { + throw new IllegalArgumentException( + "Processing of unbounded input sources is not supported with the DirectPipelineRunner."); + } + // In order to cancel the pipelines automatically, + // {@literal DataflowPipelineRunner} is forced to be used. + options.setRunner(DataflowPipelineRunner.class); + } + } + + /** + * Runs the batch injector for the streaming pipeline. + * + *
<p>
The injector pipeline will read from the given text file, and inject data + * into the Google Cloud Pub/Sub topic. + */ + public void runInjectorPipeline(String inputFile, String topic) { + DataflowPipelineOptions copiedOptions = options.cloneAs(DataflowPipelineOptions.class); + copiedOptions.setStreaming(false); + copiedOptions.setNumWorkers( + options.as(ExamplePubsubTopicOptions.class).getInjectorNumWorkers()); + copiedOptions.setJobName(options.getJobName() + "-injector"); + Pipeline injectorPipeline = Pipeline.create(copiedOptions); + injectorPipeline.apply(TextIO.Read.from(inputFile)) + .apply(IntraBundleParallelization + .of(PubsubFileInjector.publish(topic)) + .withMaxParallelism(20)); + DataflowPipelineJob injectorJob = (DataflowPipelineJob) injectorPipeline.run(); + jobsToCancel.add(injectorJob); + } + + /** + * Runs the provided injector pipeline for the streaming pipeline. + */ + public void runInjectorPipeline(Pipeline injectorPipeline) { + DataflowPipelineJob injectorJob = (DataflowPipelineJob) injectorPipeline.run(); + jobsToCancel.add(injectorJob); + } + + /** + * Start the auxiliary injector pipeline, then wait for this pipeline to finish. + */ + public void mockUnboundedSource(String inputFile, PipelineResult result) { + startInjectorIfNeeded(inputFile); + waitToFinish(result); + } + + /** + * If {@literal DataflowPipelineRunner} or {@literal BlockingDataflowPipelineRunner} is used, + * waits for the pipeline to finish and cancels it (and the injector) before the program exists. + */ + public void waitToFinish(PipelineResult result) { + if (result instanceof DataflowPipelineJob) { + final DataflowPipelineJob job = (DataflowPipelineJob) result; + jobsToCancel.add(job); + if (!options.as(DataflowExampleOptions.class).getKeepJobsRunning()) { + addShutdownHook(jobsToCancel); + } + try { + job.waitToFinish(-1, TimeUnit.SECONDS, new MonitoringUtil.PrintHandler(System.out)); + } catch (Exception e) { + throw new RuntimeException("Failed to wait for job to finish: " + job.getJobId()); + } + } else { + // Do nothing if the given PipelineResult doesn't support waitToFinish(), + // such as EvaluationResults returned by DirectPipelineRunner. + } + } + + private void addShutdownHook(final Collection jobs) { + if (dataflowClient == null) { + dataflowClient = options.getDataflowClient(); + } + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + tearDown(); + printPendingMessages(); + for (DataflowPipelineJob job : jobs) { + System.out.println("Canceling example pipeline: " + job.getJobId()); + try { + job.cancel(); + } catch (IOException e) { + System.out.println("Failed to cancel the job," + + " please go to the Developers Console to cancel it manually"); + System.out.println( + MonitoringUtil.getJobMonitoringPageURL(job.getProjectId(), job.getJobId())); + } + } + + for (DataflowPipelineJob job : jobs) { + boolean cancellationVerified = false; + for (int retryAttempts = 6; retryAttempts > 0; retryAttempts--) { + if (job.getState().isTerminal()) { + cancellationVerified = true; + System.out.println("Canceled example pipeline: " + job.getJobId()); + break; + } else { + System.out.println( + "The example pipeline is still running. 
Verifying the cancellation."); + } + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + // Ignore + } + } + if (!cancellationVerified) { + System.out.println("Failed to verify the cancellation for job: " + job.getJobId()); + System.out.println("Please go to the Developers Console to verify manually:"); + System.out.println( + MonitoringUtil.getJobMonitoringPageURL(job.getProjectId(), job.getJobId())); + } + } + } + }); + } + + private void printPendingMessages() { + System.out.println(); + System.out.println("***********************************************************"); + System.out.println("***********************************************************"); + for (String message : pendingMessages) { + System.out.println(message); + } + System.out.println("***********************************************************"); + System.out.println("***********************************************************"); + } + + private static T executeNullIfNotFound( + AbstractGoogleClientRequest request) throws IOException { + try { + return request.execute(); + } catch (GoogleJsonResponseException e) { + if (e.getStatusCode() == HttpServletResponse.SC_NOT_FOUND) { + return null; + } else { + throw e; + } + } + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/ExampleBigQueryTableOptions.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/ExampleBigQueryTableOptions.java new file mode 100644 index 000000000000..bef5bfdd83bd --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/ExampleBigQueryTableOptions.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}.common; + +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Options that can be used to configure BigQuery tables in Dataflow examples. + * The project defaults to the project being used to run the example. + */ +public interface ExampleBigQueryTableOptions extends DataflowPipelineOptions { + @Description("BigQuery dataset name") + @Default.String("dataflow_examples") + String getBigQueryDataset(); + void setBigQueryDataset(String dataset); + + @Description("BigQuery table name") + @Default.InstanceFactory(BigQueryTableFactory.class) + String getBigQueryTable(); + void setBigQueryTable(String table); + + @Description("BigQuery table schema") + TableSchema getBigQuerySchema(); + void setBigQuerySchema(TableSchema schema); + + /** + * Returns the job name as the default BigQuery table name. 
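+ * <p>For instance, a job named {@code wordcount-example} (an illustrative name, not one
+ * taken from the code) would yield the default table name {@code wordcount_example}.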
+ */ + static class BigQueryTableFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + return options.as(DataflowPipelineOptions.class).getJobName() + .replace('-', '_'); + } + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/ExamplePubsubTopicOptions.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/ExamplePubsubTopicOptions.java new file mode 100644 index 000000000000..525de69cdd77 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/ExamplePubsubTopicOptions.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}.common; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Options that can be used to configure Pub/Sub topic in Dataflow examples. + */ +public interface ExamplePubsubTopicOptions extends DataflowPipelineOptions { + @Description("Pub/Sub topic") + @Default.InstanceFactory(PubsubTopicFactory.class) + String getPubsubTopic(); + void setPubsubTopic(String topic); + + @Description("Number of workers to use when executing the injector pipeline") + @Default.Integer(1) + int getInjectorNumWorkers(); + void setInjectorNumWorkers(int numWorkers); + + /** + * Returns a default Pub/Sub topic based on the project and the job names. + */ + static class PubsubTopicFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowPipelineOptions = + options.as(DataflowPipelineOptions.class); + return "projects/" + dataflowPipelineOptions.getProject() + + "/topics/" + dataflowPipelineOptions.getJobName(); + } + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/PubsubFileInjector.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/PubsubFileInjector.java new file mode 100644 index 000000000000..f6f80aec7d64 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/main/java/common/PubsubFileInjector.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}.common; + +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.model.PublishRequest; +import com.google.api.services.pubsub.model.PubsubMessage; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.IntraBundleParallelization; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.common.collect.ImmutableMap; + +import java.io.IOException; +import java.util.Arrays; + +/** + * A batch Dataflow pipeline for injecting a set of GCS files into + * a PubSub topic line by line. Empty lines are skipped. + * + *
<p>
This is useful for testing streaming + * pipelines. Note that since batch pipelines might retry chunks, this + * does _not_ guarantee exactly-once injection of file data. Some lines may + * be published multiple times. + *
</p>
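+ *
+ * <p>Besides being run directly via {@code main}, the {@code publish} transform can be
+ * embedded in another batch pipeline. The example utilities elsewhere in this change do so
+ * roughly as follows, a sketch in which 20 is simply the maximum parallelism they choose:
+ * <pre> {@code
+ * injectorPipeline.apply(TextIO.Read.from(inputFile))
+ *     .apply(IntraBundleParallelization
+ *         .of(PubsubFileInjector.publish(topic))
+ *         .withMaxParallelism(20));
+ * } </pre>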
+ */ +public class PubsubFileInjector { + + /** + * An incomplete {@code PubsubFileInjector} transform with unbound output topic. + */ + public static class Unbound { + private final String timestampLabelKey; + + Unbound() { + this.timestampLabelKey = null; + } + + Unbound(String timestampLabelKey) { + this.timestampLabelKey = timestampLabelKey; + } + + Unbound withTimestampLabelKey(String timestampLabelKey) { + return new Unbound(timestampLabelKey); + } + + public Bound publish(String outputTopic) { + return new Bound(outputTopic, timestampLabelKey); + } + } + + /** A DoFn that publishes non-empty lines to Google Cloud PubSub. */ + public static class Bound extends DoFn { + private final String outputTopic; + private final String timestampLabelKey; + public transient Pubsub pubsub; + + public Bound(String outputTopic, String timestampLabelKey) { + this.outputTopic = outputTopic; + this.timestampLabelKey = timestampLabelKey; + } + + @Override + public void startBundle(Context context) { + this.pubsub = + Transport.newPubsubClient(context.getPipelineOptions().as(DataflowPipelineOptions.class)) + .build(); + } + + @Override + public void processElement(ProcessContext c) throws IOException { + if (c.element().isEmpty()) { + return; + } + PubsubMessage pubsubMessage = new PubsubMessage(); + pubsubMessage.encodeData(c.element().getBytes()); + if (timestampLabelKey != null) { + pubsubMessage.setAttributes( + ImmutableMap.of(timestampLabelKey, Long.toString(c.timestamp().getMillis()))); + } + PublishRequest publishRequest = new PublishRequest(); + publishRequest.setMessages(Arrays.asList(pubsubMessage)); + this.pubsub.projects().topics().publish(outputTopic, publishRequest).execute(); + } + } + + /** + * Creates a {@code PubsubFileInjector} transform with the given timestamp label key. + */ + public static Unbound withTimestampLabelKey(String timestampLabelKey) { + return new Unbound(timestampLabelKey); + } + + /** + * Creates a {@code PubsubFileInjector} transform that publishes to the given output topic. + */ + public static Bound publish(String outputTopic) { + return new Unbound().publish(outputTopic); + } + + /** + * Command line parameter options. + */ + private interface PubsubFileInjectorOptions extends PipelineOptions { + @Description("GCS location of files.") + @Validation.Required + String getInput(); + void setInput(String value); + + @Description("Topic to publish on.") + @Validation.Required + String getOutputTopic(); + void setOutputTopic(String value); + } + + /** + * Sets up and starts streaming pipeline. + */ + public static void main(String[] args) { + PubsubFileInjectorOptions options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(PubsubFileInjectorOptions.class); + + Pipeline pipeline = Pipeline.create(options); + + pipeline + .apply(TextIO.Read.from(options.getInput())) + .apply(IntraBundleParallelization.of(PubsubFileInjector.publish(options.getOutputTopic())) + .withMaxParallelism(20)); + + pipeline.run(); + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/test/java/DebuggingWordCountTest.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/test/java/DebuggingWordCountTest.java new file mode 100644 index 000000000000..7a9aa4c5c67c --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/test/java/DebuggingWordCountTest.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import com.google.common.io.Files; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.nio.charset.StandardCharsets; + +/** + * Tests for {@link DebuggingWordCount}. + */ +@RunWith(JUnit4.class) +public class DebuggingWordCountTest { + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testDebuggingWordCount() throws Exception { + File file = tmpFolder.newFile(); + Files.write("stomach secret Flourish message Flourish here Flourish", file, + StandardCharsets.UTF_8); + DebuggingWordCount.main(new String[]{"--inputFile=" + file.getAbsolutePath()}); + } +} diff --git a/maven-archetypes/examples/src/main/resources/archetype-resources/src/test/java/WordCountTest.java b/maven-archetypes/examples/src/main/resources/archetype-resources/src/test/java/WordCountTest.java new file mode 100644 index 000000000000..45555ce3ce48 --- /dev/null +++ b/maven-archetypes/examples/src/main/resources/archetype-resources/src/test/java/WordCountTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import ${package}.WordCount.CountWords; +import ${package}.WordCount.ExtractWordsFn; +import ${package}.WordCount.FormatAsTextFn; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests of WordCount. + */ +@RunWith(JUnit4.class) +public class WordCountTest { + + /** Example test that tests a specific DoFn. 
*/ + @Test + public void testExtractWordsFn() { + DoFnTester extractWordsFn = + DoFnTester.of(new ExtractWordsFn()); + + Assert.assertThat(extractWordsFn.processBatch(" some input words "), + CoreMatchers.hasItems("some", "input", "words")); + Assert.assertThat(extractWordsFn.processBatch(" "), + CoreMatchers.hasItems()); + Assert.assertThat(extractWordsFn.processBatch(" some ", " input", " words"), + CoreMatchers.hasItems("some", "input", "words")); + } + + static final String[] WORDS_ARRAY = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final List WORDS = Arrays.asList(WORDS_ARRAY); + + static final String[] COUNTS_ARRAY = new String[] { + "hi: 5", "there: 1", "sue: 2", "bob: 2"}; + + /** Example test that tests a PTransform by using an in-memory input and inspecting the output. */ + @Test + @Category(RunnableOnService.class) + public void testCountWords() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())); + + PCollection output = input.apply(new CountWords()) + .apply(ParDo.of(new FormatAsTextFn())); + + DataflowAssert.that(output).containsInAnyOrder(COUNTS_ARRAY); + p.run(); + } +} diff --git a/maven-archetypes/examples/src/test/resources/projects/basic/archetype.properties b/maven-archetypes/examples/src/test/resources/projects/basic/archetype.properties new file mode 100644 index 000000000000..c59e77a9d55b --- /dev/null +++ b/maven-archetypes/examples/src/test/resources/projects/basic/archetype.properties @@ -0,0 +1,5 @@ +package=it.pkg +version=0.1-SNAPSHOT +groupId=archetype.it +artifactId=basic +targetPlatform=1.7 diff --git a/maven-archetypes/examples/src/test/resources/projects/basic/goal.txt b/maven-archetypes/examples/src/test/resources/projects/basic/goal.txt new file mode 100644 index 000000000000..0b5987362fe3 --- /dev/null +++ b/maven-archetypes/examples/src/test/resources/projects/basic/goal.txt @@ -0,0 +1 @@ +verify diff --git a/maven-archetypes/starter/pom.xml b/maven-archetypes/starter/pom.xml new file mode 100644 index 000000000000..b5b32514dcb2 --- /dev/null +++ b/maven-archetypes/starter/pom.xml @@ -0,0 +1,56 @@ + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + 1.5.0-SNAPSHOT + ../../pom.xml + + + com.google.cloud.dataflow + google-cloud-dataflow-java-archetypes-starter + Google Cloud Dataflow Java SDK - Starter Archetype + A Maven archetype to create a simple starter pipeline to + get started using the Google Cloud Dataflow Java SDK. 
+ http://cloud.google.com/dataflow + + maven-archetype + + + + + org.apache.maven.archetype + archetype-packaging + 2.4 + + + + + + + maven-archetype-plugin + 2.4 + + + + + diff --git a/maven-archetypes/starter/src/main/resources/META-INF/maven/archetype-metadata.xml b/maven-archetypes/starter/src/main/resources/META-INF/maven/archetype-metadata.xml new file mode 100644 index 000000000000..bf75798d39b7 --- /dev/null +++ b/maven-archetypes/starter/src/main/resources/META-INF/maven/archetype-metadata.xml @@ -0,0 +1,21 @@ + + + + + 1.7 + + + + + + src/main/java + + **/*.java + + + + diff --git a/maven-archetypes/starter/src/main/resources/archetype-resources/pom.xml b/maven-archetypes/starter/src/main/resources/archetype-resources/pom.xml new file mode 100644 index 000000000000..bb679a00b4dc --- /dev/null +++ b/maven-archetypes/starter/src/main/resources/archetype-resources/pom.xml @@ -0,0 +1,43 @@ + + 4.0.0 + + ${groupId} + ${artifactId} + ${version} + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.3 + + ${targetPlatform} + ${targetPlatform} + + + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + [1.0.0, 2.0.0) + + + + + org.slf4j + slf4j-api + 1.7.7 + + + org.slf4j + slf4j-jdk14 + 1.7.7 + + + diff --git a/maven-archetypes/starter/src/main/resources/archetype-resources/src/main/java/StarterPipeline.java b/maven-archetypes/starter/src/main/resources/archetype-resources/src/main/java/StarterPipeline.java new file mode 100644 index 000000000000..ffabbc01cff6 --- /dev/null +++ b/maven-archetypes/starter/src/main/resources/archetype-resources/src/main/java/StarterPipeline.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package ${package}; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A starter example for writing Google Cloud Dataflow programs. + * + *
<p>
The example takes two strings, converts them to their upper-case + * representation and logs them. + * + *
<p>
To run this starter example locally using DirectPipelineRunner, just + * execute it without any additional parameters from your favorite development + * environment. + * + *
<p>
To run this starter example using managed resource in Google Cloud + * Platform, you should specify the following command-line options: + * --project= + * --stagingLocation= + * --runner=BlockingDataflowPipelineRunner + */ +public class StarterPipeline { + private static final Logger LOG = LoggerFactory.getLogger(StarterPipeline.class); + + public static void main(String[] args) { + Pipeline p = Pipeline.create( + PipelineOptionsFactory.fromArgs(args).withValidation().create()); + + p.apply(Create.of("Hello", "World")) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().toUpperCase()); + } + })) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + LOG.info(c.element()); + } + })); + + p.run(); + } +} diff --git a/maven-archetypes/starter/src/test/resources/projects/basic/archetype.properties b/maven-archetypes/starter/src/test/resources/projects/basic/archetype.properties new file mode 100644 index 000000000000..c59e77a9d55b --- /dev/null +++ b/maven-archetypes/starter/src/test/resources/projects/basic/archetype.properties @@ -0,0 +1,5 @@ +package=it.pkg +version=0.1-SNAPSHOT +groupId=archetype.it +artifactId=basic +targetPlatform=1.7 diff --git a/maven-archetypes/starter/src/test/resources/projects/basic/goal.txt b/maven-archetypes/starter/src/test/resources/projects/basic/goal.txt new file mode 100644 index 000000000000..0b5987362fe3 --- /dev/null +++ b/maven-archetypes/starter/src/test/resources/projects/basic/goal.txt @@ -0,0 +1 @@ +verify diff --git a/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml b/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml new file mode 100644 index 000000000000..d8c563d07fe4 --- /dev/null +++ b/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml @@ -0,0 +1,43 @@ + + 4.0.0 + + archetype.it + basic + 0.1-SNAPSHOT + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.3 + + 1.7 + 1.7 + + + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + [1.0.0, 2.0.0) + + + + + org.slf4j + slf4j-api + 1.7.7 + + + org.slf4j + slf4j-jdk14 + 1.7.7 + + + diff --git a/maven-archetypes/starter/src/test/resources/projects/basic/reference/src/main/java/it/pkg/StarterPipeline.java b/maven-archetypes/starter/src/test/resources/projects/basic/reference/src/main/java/it/pkg/StarterPipeline.java new file mode 100644 index 000000000000..2e7c4e1fc985 --- /dev/null +++ b/maven-archetypes/starter/src/test/resources/projects/basic/reference/src/main/java/it/pkg/StarterPipeline.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package it.pkg; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A starter example for writing Google Cloud Dataflow programs. + * + *
<p>
The example takes two strings, converts them to their upper-case + * representation and logs them. + * + *
<p>
To run this starter example locally using DirectPipelineRunner, just + * execute it without any additional parameters from your favorite development + * environment. + * + *
<p>
To run this starter example using managed resource in Google Cloud + * Platform, you should specify the following command-line options: + * --project= + * --stagingLocation= + * --runner=BlockingDataflowPipelineRunner + */ +public class StarterPipeline { + private static final Logger LOG = LoggerFactory.getLogger(StarterPipeline.class); + + public static void main(String[] args) { + Pipeline p = Pipeline.create( + PipelineOptionsFactory.fromArgs(args).withValidation().create()); + + p.apply(Create.of("Hello", "World")) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().toUpperCase()); + } + })) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + LOG.info(c.element()); + } + })); + + p.run(); + } +} diff --git a/pom.xml b/pom.xml new file mode 100644 index 000000000000..ba130d25a3d2 --- /dev/null +++ b/pom.xml @@ -0,0 +1,331 @@ + + + + 4.0.0 + + + com.google + google + 5 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + Google Cloud Dataflow Java SDK - Parent + Google Cloud Dataflow Java SDK provides a simple, Java-based + interface for processing virtually any size data using Google cloud + resources. This artifact includes the parent POM for other Dataflow + artifacts. + http://cloud.google.com/dataflow + 2013 + + 1.5.0-SNAPSHOT + + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + + Google Inc. + http://www.google.com + + + + + scm:git:git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + scm:git:git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + git@github.com:GoogleCloudPlatform/DataflowJavaSDK.git + + + + 3.0.3 + + + + UTF-8 + + + + 1.7.7 + v2-rev248-1.21.0 + 0.2.3 + v1b3-rev19-1.21.0 + 0.5.160222 + v1beta2-rev1-4.0.0 + 1.21.0 + 19.0 + 1.3 + 2.7.0 + 2.4 + 3.0.1 + 4.11 + 3.0.0-beta-1 + v1-rev7-1.21.0 + 1.7.14 + 3.1.4 + v1-rev53-1.21.0 + 4.4.1 + + + pom + + sdk + examples + maven-archetypes/starter + maven-archetypes/examples + + + + + doclint-java8-disable + + [1.8,) + + + -Xdoclint:-missing + + + + + + + + + maven-compiler-plugin + 3.1 + + 1.7 + 1.7 + + -Xlint:all + -Werror + + -Xlint:-options + + -Xlint:-cast + -Xlint:-deprecation + -Xlint:-processing + -Xlint:-rawtypes + -Xlint:-serial + -Xlint:-try + -Xlint:-unchecked + -Xlint:-varargs + + + + + true + + false + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.5 + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.3 + + + + org.codehaus.mojo + versions-maven-plugin + 2.1 + + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + verify + + java + + + + + + + java.util.logging.config.file + logging.properties + + + + + + + + org.jacoco + jacoco-maven-plugin + 0.7.5.201505241946 + + + + prepare-agent + + + file + true + + + + report + prepare-package + + report + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.18.1 + + ${testParallelValue} + 4 + + ${project.build.directory}/${project.artifactId}-${project.version}.jar + ${project.build.directory}/${project.artifactId}-${project.version}-tests.jar + + ${testGroups} + + ${runIntegrationTestOnService} + ${dataflowProjectName} + + false + false + true + + + + org.apache.maven.surefire + surefire-junit47 + 2.18.1 + + + + + + + org.eclipse.m2e + lifecycle-mapping + 1.0.0 + + + + + + org.apache.avro + avro-maven-plugin + ${avro.version} + + schema + + + + + false + + + + + + org.apache.maven.plugins + maven-jar-plugin + [2.5,) + + jar + test-jar + + + + + + + + + 
org.jacoco + jacoco-maven-plugin + [0.7.5,) + + report + prepare-agent + + + + + + + + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.10 + + true + + + + + org.codehaus.mojo + build-helper-maven-plugin + 1.10 + + + + + + + + + org.codehaus.mojo + versions-maven-plugin + 2.1 + + + + dependency-updates-report + plugin-updates-report + + + + + + + diff --git a/sdk/pom.xml b/sdk/pom.xml new file mode 100644 index 000000000000..4995da06d3c0 --- /dev/null +++ b/sdk/pom.xml @@ -0,0 +1,755 @@ + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + 1.5.0-SNAPSHOT + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + Google Cloud Dataflow Java SDK - All + Google Cloud Dataflow Java SDK provides a simple, Java-based + interface for processing virtually any size data using Google cloud + resources. This artifact includes entire Dataflow Java SDK. + http://cloud.google.com/dataflow + + jar + + + ${maven.build.timestamp} + yyyy-MM-dd HH:mm + com.google.cloud.dataflow + false + none + + + + + + + DataflowPipelineTests + + true + com.google.cloud.dataflow.sdk.testing.RunnableOnService + both + + + + + java8tests + + [1.8,) + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-java8-test-source + initialize + + add-test-source + + + + ${project.basedir}/src/test/java8 + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + default-testCompile + test-compile + + testCompile + + + 1.7 + 1.7 + + + **/*Java8Test.java + + + + + + java8-testCompile + test-compile + + testCompile + + + 1.8 + 1.8 + + **/*Java8Test.java + + + + + + + + + + + + + + src/main/resources + true + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.apache.maven.plugins + maven-dependency-plugin + + + analyze-only + + true + + + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.12 + + + com.puppycrawl.tools + checkstyle + 6.6 + + + + ../checkstyle.xml + true + true + false + true + ${project.build.directory}/generated-test-sources/** + + + + + check + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + default-jar + + jar + + + + default-test-jar + + test-jar + + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.4 + + + attach-sources + compile + + jar + + + + attach-test-sources + test-compile + + test-jar + + + + + + + org.apache.maven.plugins maven-javadoc-plugin + + Google Cloud Dataflow SDK ${project.version} API + Google Cloud Dataflow SDK for Java, version ${project.version} + ../javadoc/overview.html + + com.google.cloud.dataflow.sdk + -exclude com.google.cloud.dataflow.sdk.runners.worker:com.google.cloud.dataflow.sdk.runners.dataflow:com.google.cloud.dataflow.sdk.util:com.google.cloud.dataflow.sdk.runners.inprocess ${dataflow.javadoc_opts} + false + true + ]]> + + + + https://developers.google.com/api-client-library/java/google-api-java-client/reference/1.20.0/ + ${basedir}/../javadoc/apiclient-docs + + + http://avro.apache.org/docs/1.7.7/api/java/ + ${basedir}/../javadoc/avro-docs + + + https://developers.google.com/resources/api-libraries/documentation/bigquery/v2/java/latest/ + ${basedir}/../javadoc/bq-docs + + + https://cloud.google.com/datastore/docs/apis/javadoc/ + ${basedir}/../javadoc/datastore-docs + + + http://docs.guava-libraries.googlecode.com/git-history/release19/javadoc/ + ${basedir}/../javadoc/guava-docs + + + http://hamcrest.org/JavaHamcrest/javadoc/1.3/ + ${basedir}/../javadoc/hamcrest-docs + + + http://fasterxml.github.io/jackson-annotations/javadoc/2.7/ 
+ ${basedir}/../javadoc/jackson-annotations-docs + + + http://fasterxml.github.io/jackson-databind/javadoc/2.7/ + ${basedir}/../javadoc/jackson-databind-docs + + + http://www.joda.org/joda-time/apidocs + ${basedir}/../javadoc/joda-docs + + + http://junit.sourceforge.net/javadoc/ + ${basedir}/../javadoc/junit-docs + + + https://developers.google.com/api-client-library/java/google-oauth-java-client/reference/1.20.0/ + ${basedir}/../javadoc/oauth-docs + + + + + + + jar + + package + + + + + + org.apache.maven.plugins + maven-shade-plugin + 2.4.1 + + + + bundle-and-repackage + package + + shade + + + true + + + com.google.cloud.bigtable:bigtable-client-core + com.google.guava:guava + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + com.google.common + com.google.cloud.dataflow.sdk.repackaged.com.google.common + + + com.google.thirdparty + com.google.cloud.dataflow.sdk.repackaged.com.google.thirdparty + + + com.google.cloud.bigtable + com.google.cloud.dataflow.sdk.repackaged.com.google.cloud.bigtable + + com.google.cloud.bigtable.config.BigtableOptions* + com.google.cloud.bigtable.config.CredentialOptions* + com.google.cloud.bigtable.config.RetryOptions* + com.google.cloud.bigtable.grpc.BigtableClusterName + com.google.cloud.bigtable.grpc.BigtableTableName + + + + + + + + + bundle-rest-without-repackaging + package + + shade + + + true + ${project.artifactId}-bundled-${project.version} + + + com.google.cloud.bigtable:bigtable-client-core + com.google.guava:guava + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + + org.jacoco + jacoco-maven-plugin + + + + + org.apache.avro + avro-maven-plugin + ${avro.version} + + + schemas + generate-test-sources + + schema + + + ${project.basedir}/src/test/ + ${project.build.directory}/generated-test-sources/java + + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 1.9.1 + + + add-test-source + generate-test-sources + + add-test-source + + + + ${project.build.directory}/generated-test-sources/java + + + + + + + + + + + com.google.apis + google-api-services-dataflow + ${dataflow.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.cloud.dataflow + google-cloud-dataflow-java-proto-library-all + 0.5.160222 + + + + io.grpc + grpc-all + 0.12.0 + + + + com.google.cloud.bigtable + bigtable-protos + ${bigtable.version} + + + + com.google.cloud.bigtable + bigtable-client-core + ${bigtable.version} + + + + com.google.api-client + google-api-client + ${google-clients.version} + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-bigquery + ${bigquery.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-pubsub + ${pubsub.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-storage + ${storage.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.http-client + google-http-client + ${google-clients.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.http-client + google-http-client-jackson2 + ${google-clients.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.oauth-client + google-oauth-client-java6 + ${google-clients.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.oauth-client + google-oauth-client + ${google-clients.version} + + + + com.google.guava + guava-jdk5 + + + + + + com.google.apis + google-api-services-datastore-protobuf + ${datastore.version} + + + + com.google.guava + guava-jdk5 
+ + + + + + com.google.cloud.bigdataoss + gcsio + 1.4.3 + + + + com.google.cloud.bigdataoss + util + 1.4.3 + + + + com.google.guava + guava + + ${guava.version} + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + + com.google.code.findbugs + jsr305 + ${jsr305.version} + + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + org.apache.avro + avro + ${avro.version} + + + + org.apache.commons + commons-compress + 1.9 + + + + joda-time + joda-time + ${joda.version} + + + + + org.codehaus.woodstox + stax2-api + ${stax2.version} + true + + + + org.codehaus.woodstox + woodstox-core-asl + ${woodstox.version} + runtime + true + + + + javax.xml.stream + stax-api + + + + + + + org.tukaani + xz + 1.5 + runtime + true + + + + + com.google.auto.service + auto-service + 1.0-rc2 + true + + + + + org.hamcrest + hamcrest-all + ${hamcrest.version} + provided + + + + junit + junit + ${junit.version} + provided + + + + org.slf4j + slf4j-jdk14 + ${slf4j.version} + test + + + + org.mockito + mockito-all + 1.9.5 + test + + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/Pipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/Pipeline.java new file mode 100644 index 000000000000..b166673e6e9c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/Pipeline.java @@ -0,0 +1,502 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk; + +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.TransformHierarchy; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A {@link Pipeline} manages a directed acyclic graph of {@link PTransform PTransforms}, and the + * {@link PCollection PCollections} that the {@link PTransform}s consume and produce. + * + *
<p>
A {@link Pipeline} is initialized with a {@link PipelineRunner} that will later + * execute the {@link Pipeline}. + * + *
<p>
{@link Pipeline Pipelines} are independent, so they can be constructed and executed + * concurrently. + * + *
<p>
Each {@link Pipeline} is self-contained and isolated from any other + * {@link Pipeline}. The {@link PValue PValues} that are inputs and outputs of each of a + * {@link Pipeline Pipeline's} {@link PTransform PTransforms} are also owned by that + * {@link Pipeline}. A {@link PValue} owned by one {@link Pipeline} can be read only by + * {@link PTransform PTransforms} also owned by that {@link Pipeline}. + * + *
<p>
Here is a typical example of use: + *
<pre>
 {@code
+ * // Start by defining the options for the pipeline.
+ * PipelineOptions options = PipelineOptionsFactory.create();
+ * // Then create the pipeline. The runner is determined by the options.
+ * Pipeline p = Pipeline.create(options);
+ *
+ * // A root PTransform, like TextIO.Read or Create, gets added
+ * // to the Pipeline by being applied:
+ * PCollection<String> lines =
+ *     p.apply(TextIO.Read.from("gs://bucket/dir/file*.txt"));
+ *
+ * // A Pipeline can have multiple root transforms:
+ * PCollection<String> moreLines =
+ *     p.apply(TextIO.Read.from("gs://bucket/other/dir/file*.txt"));
+ * PCollection<String> yetMoreLines =
+ *     p.apply(Create.of("yet", "more", "lines").withCoder(StringUtf8Coder.of()));
+ *
+ * // Further PTransforms can be applied, in an arbitrary (acyclic) graph.
+ * // Subsequent PTransforms (and intermediate PCollections etc.) are
+ * // implicitly part of the same Pipeline.
+ * PCollection<String> allLines =
+ *     PCollectionList.of(lines).and(moreLines).and(yetMoreLines)
+ *     .apply(new Flatten());
+ * PCollection<KV<String, Integer>> wordCounts =
+ *     allLines
+ *     .apply(ParDo.of(new ExtractWords()))
+ *     .apply(new Count());
+ * PCollection<String> formattedWordCounts =
+ *     wordCounts.apply(ParDo.of(new FormatCounts()));
+ * formattedWordCounts.apply(TextIO.Write.to("gs://bucket/dir/counts.txt"));
+ *
+ * // PTransforms aren't executed when they're applied; rather, they're
+ * // just added to the Pipeline.  Once the whole Pipeline of PTransforms
+ * // is constructed, the Pipeline's PTransforms can be run using a
+ * // PipelineRunner.  The default PipelineRunner executes the Pipeline
+ * // directly, sequentially, in this one process, which is useful for
+ * // unit tests and simple experiments:
+ * p.run();
+ *
+ * } </pre>
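+ *
+ * <p>Because the runner is determined by the {@code PipelineOptions}, the same graph can be
+ * moved from the default local runner to a managed service purely through configuration. A
+ * minimal sketch, mirroring how the StarterPipeline archetype in this change builds its
+ * options from command-line arguments:
+ * <pre> {@code
+ * PipelineOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().create();
+ * Pipeline p = Pipeline.create(options);
+ * // ... apply PTransforms as above ...
+ * p.run();
+ * } </pre>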
+ */ +public class Pipeline { + private static final Logger LOG = LoggerFactory.getLogger(Pipeline.class); + + /** + * Thrown during execution of a {@link Pipeline}, whenever user code within that + * {@link Pipeline} throws an exception. + * + *
<p>
The original exception thrown by user code may be retrieved via {@link #getCause}. + */ + public static class PipelineExecutionException extends RuntimeException { + /** + * Wraps {@code cause} into a {@link PipelineExecutionException}. + */ + public PipelineExecutionException(Throwable cause) { + super(cause); + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Public operations. + + /** + * Constructs a pipeline from the provided options. + * + * @return The newly created pipeline. + */ + public static Pipeline create(PipelineOptions options) { + Pipeline pipeline = new Pipeline(PipelineRunner.fromOptions(options), options); + LOG.debug("Creating {}", pipeline); + return pipeline; + } + + /** + * Returns a {@link PBegin} owned by this Pipeline. This is useful + * as the input of a root PTransform such as {@link Read} or + * {@link Create}. + */ + public PBegin begin() { + return PBegin.in(this); + } + + /** + * Like {@link #apply(String, PTransform)} but the transform node in the {@link Pipeline} + * graph will be named according to {@link PTransform#getName}. + * + * @see #apply(String, PTransform) + */ + public OutputT apply( + PTransform root) { + return begin().apply(root); + } + + /** + * Adds a root {@link PTransform}, such as {@link Read} or {@link Create}, + * to this {@link Pipeline}. + * + *
<p>
The node in the {@link Pipeline} graph will use the provided {@code name}. + * This name is used in various places, including the monitoring UI, logging, + * and to stably identify this node in the {@link Pipeline} graph upon update. + * + *
<p>
Alias for {@code begin().apply(name, root)}. + */ + public OutputT apply( + String name, PTransform root) { + return begin().apply(name, root); + } + + /** + * Runs the {@link Pipeline} using its {@link PipelineRunner}. + */ + public PipelineResult run() { + LOG.debug("Running {} via {}", this, runner); + try { + return runner.run(this); + } catch (UserCodeException e) { + // This serves to replace the stack with one that ends here and + // is caused by the caught UserCodeException, thereby splicing + // out all the stack frames in between the PipelineRunner itself + // and where the worker calls into the user's code. + throw new PipelineExecutionException(e.getCause()); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + // Below here are operations that aren't normally called by users. + + /** + * Returns the {@link CoderRegistry} that this {@link Pipeline} uses. + */ + public CoderRegistry getCoderRegistry() { + if (coderRegistry == null) { + coderRegistry = new CoderRegistry(); + coderRegistry.registerStandardCoders(); + } + return coderRegistry; + } + + /** + * Sets the {@link CoderRegistry} that this {@link Pipeline} uses. + */ + public void setCoderRegistry(CoderRegistry coderRegistry) { + this.coderRegistry = coderRegistry; + } + + /** + * A {@link PipelineVisitor} can be passed into + * {@link Pipeline#traverseTopologically} to be called for each of the + * transforms and values in the {@link Pipeline}. + */ + public interface PipelineVisitor { + /** + * Called for each composite transform after all topological predecessors have been visited + * but before any of its component transforms. + */ + public void enterCompositeTransform(TransformTreeNode node); + + /** + * Called for each composite transform after all of its component transforms and their outputs + * have been visited. + */ + public void leaveCompositeTransform(TransformTreeNode node); + + /** + * Called for each primitive transform after all of its topological predecessors + * and inputs have been visited. + */ + public void visitTransform(TransformTreeNode node); + + /** + * Called for each value after the transform that produced the value has been + * visited. + */ + public void visitValue(PValue value, TransformTreeNode producer); + } + + /** + * Invokes the {@link PipelineVisitor PipelineVisitor's} + * {@link PipelineVisitor#visitTransform} and + * {@link PipelineVisitor#visitValue} operations on each of this + * {@link Pipeline Pipeline's} transform and value nodes, in forward + * topological order. + * + *
<p>
Traversal of the {@link Pipeline} causes {@link PTransform PTransforms} and + * {@link PValue PValues} owned by the {@link Pipeline} to be marked as finished, + * at which point they may no longer be modified. + * + *
<p>
Typically invoked by {@link PipelineRunner} subclasses. + */ + public void traverseTopologically(PipelineVisitor visitor) { + Set visitedValues = new HashSet<>(); + // Visit all the transforms, which should implicitly visit all the values. + transforms.visit(visitor, visitedValues); + if (!visitedValues.containsAll(values)) { + throw new RuntimeException( + "internal error: should have visited all the values " + + "after visiting all the transforms"); + } + } + + /** + * Like {@link #applyTransform(String, PInput, PTransform)} but defaulting to the name + * provided by the {@link PTransform}. + */ + public static + OutputT applyTransform(InputT input, + PTransform transform) { + return input.getPipeline().applyInternal(transform.getName(), input, transform); + } + + /** + * Applies the given {@code PTransform} to this input {@code InputT} and returns + * its {@code OutputT}. This uses {@code name} to identify this specific application + * of the transform. This name is used in various places, including the monitoring UI, + * logging, and to stably identify this application node in the {@link Pipeline} graph during + * update. + * + *
<p>
Each {@link PInput} subclass that provides an {@code apply} method should delegate to + * this method to ensure proper registration with the {@link PipelineRunner}. + */ + public static + OutputT applyTransform(String name, InputT input, + PTransform transform) { + return input.getPipeline().applyInternal(name, input, transform); + } + + ///////////////////////////////////////////////////////////////////////////// + // Below here are internal operations, never called by users. + + private final PipelineRunner runner; + private final PipelineOptions options; + private final TransformHierarchy transforms = new TransformHierarchy(); + private Collection values = new ArrayList<>(); + private Set usedFullNames = new HashSet<>(); + private CoderRegistry coderRegistry; + private Multimap, AppliedPTransform> transformApplicationsForTesting = + HashMultimap.create(); + + /** + * @deprecated replaced by {@link #Pipeline(PipelineRunner, PipelineOptions)} + */ + @Deprecated + protected Pipeline(PipelineRunner runner) { + this(runner, PipelineOptionsFactory.create()); + } + + protected Pipeline(PipelineRunner runner, PipelineOptions options) { + this.runner = runner; + this.options = options; + } + + @Override + public String toString() { + return "Pipeline#" + hashCode(); + } + + /** + * Applies a {@link PTransform} to the given {@link PInput}. + * + * @see Pipeline#apply + */ + private + OutputT applyInternal(String name, InputT input, + PTransform transform) { + input.finishSpecifying(); + + TransformTreeNode parent = transforms.getCurrent(); + String namePrefix = parent.getFullName(); + String fullName = uniquifyInternal(namePrefix, name); + + boolean nameIsUnique = fullName.equals(buildName(namePrefix, name)); + + if (!nameIsUnique) { + switch (getOptions().getStableUniqueNames()) { + case OFF: + break; + case WARNING: + LOG.warn("Transform {} does not have a stable unique name. " + + "This will prevent updating of pipelines.", fullName); + break; + case ERROR: + throw new IllegalStateException( + "Transform " + fullName + " does not have a stable unique name. " + + "This will prevent updating of pipelines."); + default: + throw new IllegalArgumentException( + "Unrecognized value for stable unique names: " + getOptions().getStableUniqueNames()); + } + } + + TransformTreeNode child = + new TransformTreeNode(parent, transform, fullName, input); + parent.addComposite(child); + + transforms.addInput(child, input); + + LOG.debug("Adding {} to {}", transform, this); + try { + transforms.pushNode(child); + transform.validate(input); + OutputT output = runner.apply(transform, input); + transforms.setOutput(child, output); + + AppliedPTransform applied = AppliedPTransform.of( + child.getFullName(), input, output, transform); + transformApplicationsForTesting.put(transform, applied); + // recordAsOutput is a NOOP if already called; + output.recordAsOutput(applied); + verifyOutputState(output, child); + return output; + } finally { + transforms.popNode(); + } + } + + /** + * Returns all producing transforms for the {@link PValue PValues} contained + * in {@code output}. 
+ */ + private List> getProducingTransforms(POutput output) { + List> producingTransforms = new ArrayList<>(); + for (PValue value : output.expand()) { + AppliedPTransform transform = value.getProducingTransformInternal(); + if (transform != null) { + producingTransforms.add(transform); + } + } + return producingTransforms; + } + + /** + * Verifies that the output of a {@link PTransform} is correctly configured in its + * {@link TransformTreeNode} in the {@link Pipeline} graph. + * + *
<p>
A non-composite {@link PTransform} must have all + * of its outputs registered as produced by that {@link PTransform}. + * + *
<p>
A composite {@link PTransform} must have all of its outputs + * registered as produced by the contained primitive {@link PTransform PTransforms}. + * They have each had the above check performed already, when + * they were applied, so the only possible failure state is + * that the composite {@link PTransform} has returned a primitive output. + */ + private void verifyOutputState(POutput output, TransformTreeNode node) { + if (!node.isCompositeNode()) { + PTransform thisTransform = node.getTransform(); + List> producingTransforms = getProducingTransforms(output); + for (AppliedPTransform producingTransform : producingTransforms) { + // Using != because object identity indicates that the transforms + // are the same node in the pipeline + if (thisTransform != producingTransform.getTransform()) { + throw new IllegalArgumentException("Output of non-composite transform " + + thisTransform + " is registered as being produced by" + + " a different transform: " + producingTransform); + } + } + } else { + PTransform thisTransform = node.getTransform(); + List> producingTransforms = getProducingTransforms(output); + for (AppliedPTransform producingTransform : producingTransforms) { + // Using == because object identity indicates that the transforms + // are the same node in the pipeline + if (thisTransform == producingTransform.getTransform()) { + throw new IllegalStateException("Output of composite transform " + + thisTransform + " is registered as being produced by it," + + " but the output of every composite transform should be" + + " produced by a primitive transform contained therein."); + } + } + } + } + + /** + * Returns the configured {@link PipelineRunner}. + */ + public PipelineRunner getRunner() { + return runner; + } + + /** + * Returns the configured {@link PipelineOptions}. + */ + public PipelineOptions getOptions() { + return options; + } + + /** + * @deprecated this method is no longer compatible with the design of {@link Pipeline}, + * as {@link PTransform PTransforms} can be applied multiple times, with different names + * each time. + */ + @Deprecated + public String getFullNameForTesting(PTransform transform) { + Collection> uses = + transformApplicationsForTesting.get(transform); + Preconditions.checkState(uses.size() > 0, "Unknown transform: " + transform); + Preconditions.checkState(uses.size() <= 1, "Transform used multiple times: " + transform); + return Iterables.getOnlyElement(uses).getFullName(); + } + + /** + * Returns a unique name for a transform with the given prefix (from + * enclosing transforms) and initial name. + * + *

For internal use only. + */ + private String uniquifyInternal(String namePrefix, String origName) { + String name = origName; + int suffixNum = 2; + while (true) { + String candidate = buildName(namePrefix, name); + if (usedFullNames.add(candidate)) { + return candidate; + } + // A duplicate! Retry. + name = origName + suffixNum++; + } + } + + /** + * Builds a name from a "/"-delimited prefix and a name. + */ + private String buildName(String namePrefix, String name) { + return namePrefix.isEmpty() ? name : namePrefix + "/" + name; + } + + /** + * Adds the given {@link PValue} to this {@link Pipeline}. + * + *

For internal use only. + */ + public void addValueInternal(PValue value) { + this.values.add(value); + LOG.debug("Adding {} to {}", value, this); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/PipelineResult.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/PipelineResult.java new file mode 100644 index 000000000000..6b9a36b728e6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/PipelineResult.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; + +/** + * Result of {@link Pipeline#run()}. + */ +public interface PipelineResult { + + /** + * Retrieves the current state of the pipeline execution. + * + * @return the {@link State} representing the state of this pipeline. + */ + State getState(); + + /** + * Retrieves the current value of the provided {@link Aggregator}. + * + * @param aggregator the {@link Aggregator} to retrieve values for. + * @return the current values of the {@link Aggregator}, + * which may be empty if there are no values yet. + * @throws AggregatorRetrievalException if the {@link Aggregator} values could not be retrieved. + */ + AggregatorValues getAggregatorValues(Aggregator aggregator) + throws AggregatorRetrievalException; + + // TODO: method to retrieve error messages. + + /** Named constants for common values for the job state. */ + public enum State { + + /** The job state could not be obtained or was not specified. */ + UNKNOWN(false, false), + + /** The job has been paused, or has not yet started. */ + STOPPED(false, false), + + /** The job is currently running. */ + RUNNING(false, false), + + /** The job has successfully completed. */ + DONE(true, false), + + /** The job has failed. */ + FAILED(true, false), + + /** The job has been explicitly cancelled. */ + CANCELLED(true, false), + + /** The job has been updated. */ + UPDATED(true, true); + + private final boolean terminal; + + private final boolean hasReplacement; + + private State(boolean terminal, boolean hasReplacement) { + this.terminal = terminal; + this.hasReplacement = hasReplacement; + } + + /** + * @return {@code true} if the job state can no longer complete work. + */ + public final boolean isTerminal() { + return terminal; + } + + /** + * @return {@code true} if this job state indicates that a replacement job exists. 
+ */ + public final boolean hasReplacementJob() { + return hasReplacement; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/annotations/Experimental.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/annotations/Experimental.java new file mode 100644 index 000000000000..cac2aa8435db --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/annotations/Experimental.java @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Signifies that a public API (public class, method or field) is subject to + * incompatible changes, or even removal, in a future release. An API bearing + * this annotation is exempt from any compatibility guarantees made by its + * containing library. Note that the presence of this annotation implies nothing + * about the quality or performance of the API in question, only the fact that + * it is not "API-frozen." + * + *
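+ * <p>A hypothetical usage sketch (the annotated class is illustrative, not part of the SDK):
+ * <pre>
+ * {@literal @}Experimental(Experimental.Kind.SOURCE_SINK)
+ * public class MyCustomSource { ... }
+ * </pre>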

It is generally safe for applications to depend on experimental + * APIs, at the cost of some extra work during upgrades. However, it is + * generally inadvisable for libraries (which get included on users' + * class paths, outside the library developers' control) to do so. + */ +@Retention(RetentionPolicy.CLASS) +@Target({ + ElementType.ANNOTATION_TYPE, + ElementType.CONSTRUCTOR, + ElementType.FIELD, + ElementType.METHOD, + ElementType.TYPE}) +@Documented +public @interface Experimental { + public Kind value() default Kind.UNSPECIFIED; + + /** + * An enumeration of various kinds of experimental APIs. + */ + public enum Kind { + /** Generic group of experimental APIs. This is the default value. */ + UNSPECIFIED, + + /** Sources and sinks related experimental APIs. */ + SOURCE_SINK, + + /** Auto-scaling related experimental APIs. */ + AUTOSCALING, + + /** Trigger-related experimental APIs. */ + TRIGGER, + + /** Aggregator-related experimental APIs. */ + AGGREGATOR, + + /** Experimental APIs for Coder binary format identifiers. */ + CODER_ENCODING_ID, + + /** State-related experimental APIs. */ + STATE, + + /** Timer-related experimental APIs. */ + TIMERS, + + /** Experimental APIs related to customizing the output time for computed values. */ + OUTPUT_TIME + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/annotations/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/annotations/package-info.java new file mode 100644 index 000000000000..6c224a6a8e8d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/annotations/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines annotations used across the SDK. + */ +package com.google.cloud.dataflow.sdk.annotations; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AtomicCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AtomicCoder.java new file mode 100644 index 000000000000..c4951b40041d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AtomicCoder.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import java.util.Collections; +import java.util.List; + +/** + * A {@link Coder} that has no component {@link Coder Coders} or other state. + * + *

Note that, unless the behavior is overridden, atomic coders are presumed to be deterministic + * and all instances are considered equal. + * + * @param the type of the values being transcoded + */ +public abstract class AtomicCoder extends DeterministicStandardCoder { + protected AtomicCoder() { } + + @Override + public List> getCoderArguments() { + return null; + } + + /** + * Returns a list of values contained in the provided example + * value, one per type parameter. If there are no type parameters, + * returns an empty list. + * + *

Because {@link AtomicCoder} has no components, always returns an empty list. + * + * @param exampleValue unused, but part of the latent interface expected by + * {@link CoderFactories#fromStaticMethods} + */ + public static List getInstanceComponents(T exampleValue) { + return Collections.emptyList(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AvroCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AvroCoder.java new file mode 100644 index 000000000000..91efb43f35ac --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/AvroCoder.java @@ -0,0 +1,714 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.IndexedRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.reflect.AvroEncode; +import org.apache.avro.reflect.AvroName; +import org.apache.avro.reflect.AvroSchema; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.avro.reflect.ReflectDatumWriter; +import org.apache.avro.reflect.Union; +import org.apache.avro.specific.SpecificData; +import org.apache.avro.util.ClassUtils; +import org.apache.avro.util.Utf8; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; + +import javax.annotation.Nullable; + +/** + * A {@link Coder} using Avro binary format. + * + *

Each instance of {@code AvroCoder} encapsulates an Avro schema for objects of type + * {@code T}. + * + *

The Avro schema may be provided explicitly via {@link AvroCoder#of(Class, Schema)} or + * omitted via {@link AvroCoder#of(Class)}, in which case it will be inferred + * using Avro's {@link org.apache.avro.reflect.ReflectData}. + * + *

For complete details about schema generation and how it can be controlled please see + * the {@link org.apache.avro.reflect} package. + * Only concrete classes with a no-argument constructor can be mapped to Avro records. + * All inherited fields that are not static or transient are included. Fields are not permitted to + * be null unless annotated by {@link Nullable} or a {@link Union} schema + * containing {@code "null"}. + * + *
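+ * <p>For example, a class along these lines (illustrative only, not part of the SDK) can have
+ * its schema inferred reflectively:
+ * <pre>
+ * public class MyCustomElement {
+ *   int id;                                                    // required field
+ *   {@literal @}org.apache.avro.reflect.Nullable String note;  // may be null
+ *   public MyCustomElement() { }                               // no-argument constructor
+ * }
+ * </pre>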

To use, specify the {@code Coder} type on a PCollection: + *

+ * {@code
+ * PCollection<MyCustomElement> records =
+ *     input.apply(...)
+ *          .setCoder(AvroCoder.of(MyCustomElement.class));
+ * }
+ * 
+ * + *

or annotate the element class using {@code @DefaultCoder}. + *


+ * {@literal @}DefaultCoder(AvroCoder.class)
+ * public class MyCustomElement {
+ *   ...
+ * }
+ * 
+ * + *

The implementation attempts to determine if the Avro encoding of the given type will satisfy + * the criteria of {@link Coder#verifyDeterministic} by inspecting both the type and the + * Schema provided or generated by Avro. Only coders that are deterministic can be used in + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} operations. + * + * @param the type of elements handled by this coder + */ +public class AvroCoder extends StandardCoder { + + /** + * Returns an {@code AvroCoder} instance for the provided element type. + * @param the element type + */ + public static AvroCoder of(TypeDescriptor type) { + @SuppressWarnings("unchecked") + Class clazz = (Class) type.getRawType(); + return of(clazz); + } + + /** + * Returns an {@code AvroCoder} instance for the provided element class. + * @param the element type + */ + public static AvroCoder of(Class clazz) { + return new AvroCoder<>(clazz, ReflectData.get().getSchema(clazz)); + } + + /** + * Returns an {@code AvroCoder} instance for the Avro schema. The implicit + * type is GenericRecord. + */ + public static AvroCoder of(Schema schema) { + return new AvroCoder<>(GenericRecord.class, schema); + } + + /** + * Returns an {@code AvroCoder} instance for the provided element type + * using the provided Avro schema. + * + *

If the type argument is GenericRecord, the schema may be arbitrary. + * Otherwise, the schema must correspond to the type provided. + * + * @param the element type + */ + public static AvroCoder of(Class type, Schema schema) { + return new AvroCoder<>(type, schema); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @JsonCreator + public static AvroCoder of( + @JsonProperty("type") String classType, + @JsonProperty("schema") String schema) throws ClassNotFoundException { + Schema.Parser parser = new Schema.Parser(); + return new AvroCoder(Class.forName(classType), parser.parse(schema)); + } + + public static final CoderProvider PROVIDER = new CoderProvider() { + @Override + public Coder getCoder(TypeDescriptor typeDescriptor) { + // This is a downcast from `? super T` to T. However, because + // it comes from a TypeDescriptor, the class object itself + // is the same so the supertype in question shares the same + // generated AvroCoder schema. + @SuppressWarnings("unchecked") + Class rawType = (Class) typeDescriptor.getRawType(); + return AvroCoder.of(rawType); + } + }; + + private final Class type; + private final Schema schema; + + private final List nonDeterministicReasons; + + // Factories allocated by .get() are thread-safe and immutable. + private static final EncoderFactory ENCODER_FACTORY = EncoderFactory.get(); + private static final DecoderFactory DECODER_FACTORY = DecoderFactory.get(); + // Cache the old encoder/decoder and let the factories reuse them when possible. To be threadsafe, + // these are ThreadLocal. This code does not need to be re-entrant as AvroCoder does not use + // an inner coder. + private final ThreadLocal decoder; + private final ThreadLocal encoder; + private final ThreadLocal> writer; + private final ThreadLocal> reader; + + protected AvroCoder(Class type, Schema schema) { + this.type = type; + this.schema = schema; + + nonDeterministicReasons = new AvroDeterminismChecker().check(TypeDescriptor.of(type), schema); + + // Decoder and Encoder start off null for each thread. They are allocated and potentially + // reused inside encode/decode. + this.decoder = new ThreadLocal<>(); + this.encoder = new ThreadLocal<>(); + + // Reader and writer are allocated once per thread and are "final" for thread-local Coder + // instance. + this.reader = new ThreadLocal>() { + @Override + public DatumReader initialValue() { + return createDatumReader(); + } + }; + this.writer = new ThreadLocal>() { + @Override + public DatumWriter initialValue() { + return createDatumWriter(); + } + }; + } + + /** + * The encoding identifier is designed to support evolution as per the design of Avro + * In order to use this class effectively, carefully read the Avro + * documentation at + * Schema Resolution + * to ensure that the old and new schema match. + * + *

In particular, this encoding identifier is guaranteed to be the same for {@code AvroCoder} + * instances of the same principal class, and otherwise distinct. The schema is not included + * in the identifier. + * + *

When modifying a class to be encoded as Avro, here are some guidelines; see the above link + * for greater detail. + * + *

    + *
  • Avoid changing field names. + *
  • Never remove a required field. + *
  • Only add optional fields, with sensible defaults. + *
  • When changing the type of a field, consult the Avro documentation to ensure the new and + * old types are interchangeable. + *
+ * + *

Code consuming this message class should be prepared to support all versions of + * the class until it is certain that no remaining serialized instances exist. + * + *

If backwards incompatible changes must be made, the best recourse is to change the name + * of your class. + */ + @Override + public String getEncodingId() { + return type.getName(); + } + + /** + * Returns the type this coder encodes/decodes. + */ + public Class getType() { + return type; + } + + private Object writeReplace() { + // When serialized by Java, instances of AvroCoder should be replaced by + // a SerializedAvroCoderProxy. + return new SerializedAvroCoderProxy<>(type, schema.toString()); + } + + @Override + public void encode(T value, OutputStream outStream, Context context) throws IOException { + // Get a BinaryEncoder instance from the ThreadLocal cache and attempt to reuse it. + BinaryEncoder encoderInstance = ENCODER_FACTORY.directBinaryEncoder(outStream, encoder.get()); + // Save the potentially-new instance for reuse later. + encoder.set(encoderInstance); + writer.get().write(value, encoderInstance); + // Direct binary encoder does not buffer any data and need not be flushed. + } + + @Override + public T decode(InputStream inStream, Context context) throws IOException { + // Get a BinaryDecoder instance from the ThreadLocal cache and attempt to reuse it. + BinaryDecoder decoderInstance = DECODER_FACTORY.directBinaryDecoder(inStream, decoder.get()); + // Save the potentially-new instance for later. + decoder.set(decoderInstance); + return reader.get().read(null, decoderInstance); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addString(result, "type", type.getName()); + addString(result, "schema", schema.toString()); + return result; + } + + /** + * @throws NonDeterministicException when the type may not be deterministically + * encoded using the given {@link Schema}, the {@code directBinaryEncoder}, and the + * {@link ReflectDatumWriter} or {@link GenericDatumWriter}. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + if (!nonDeterministicReasons.isEmpty()) { + throw new NonDeterministicException(this, nonDeterministicReasons); + } + } + + /** + * Returns a new {@link DatumReader} that can be used to read from an Avro file directly. Assumes + * the schema used to read is the same as the schema that was used when writing. + * + * @deprecated For {@code AvroCoder} internal use only. + */ + // TODO: once we can remove this deprecated function, inline in constructor. + @Deprecated + public DatumReader createDatumReader() { + if (type.equals(GenericRecord.class)) { + return new GenericDatumReader<>(schema); + } else { + return new ReflectDatumReader<>(schema); + } + } + + /** + * Returns a new {@link DatumWriter} that can be used to write to an Avro file directly. + * + * @deprecated For {@code AvroCoder} internal use only. + */ + // TODO: once we can remove this deprecated function, inline in constructor. + @Deprecated + public DatumWriter createDatumWriter() { + if (type.equals(GenericRecord.class)) { + return new GenericDatumWriter<>(schema); + } else { + return new ReflectDatumWriter<>(schema); + } + } + + /** + * Returns the schema used by this coder. + */ + public Schema getSchema() { + return schema; + } + + /** + * Proxy to use in place of serializing the {@link AvroCoder}. This allows the fields + * to remain final. 
+ */ + private static class SerializedAvroCoderProxy implements Serializable { + private final Class type; + private final String schemaStr; + + public SerializedAvroCoderProxy(Class type, String schemaStr) { + this.type = type; + this.schemaStr = schemaStr; + } + + private Object readResolve() { + // When deserialized, instances of this object should be replaced by + // constructing an AvroCoder. + Schema.Parser parser = new Schema.Parser(); + return new AvroCoder(type, parser.parse(schemaStr)); + } + } + + /** + * Helper class encapsulating the various pieces of state maintained by the + * recursive walk used for checking if the encoding will be deterministic. + */ + private static class AvroDeterminismChecker { + + // Reasons that the original type are not deterministic. This accumulates + // the actual output. + private List reasons = new ArrayList<>(); + + // Types that are currently "open". Used to make sure we don't have any + // recursive types. Note that we assume that all occurrences of a given type + // are equal, rather than tracking pairs of type + schema. + private Set> activeTypes = new HashSet<>(); + + // Similarly to how we record active types, we record the schemas we visit + // to make sure we don't encounter recursive fields. + private Set activeSchemas = new HashSet<>(); + + /** + * Report an error in the current context. + */ + private void reportError(String context, String fmt, Object... args) { + String message = String.format(fmt, args); + reasons.add(context + ": " + message); + } + + /** + * Classes that are serialized by Avro as a String include + *

    + *
  • Subtypes of CharSequence (including String, Avro's mutable Utf8, etc.) + *
  • Several predefined classes (BigDecimal, BigInteger, URI, URL) + *
  • Classes annotated with @Stringable (uses their #toString() and a String constructor) + *
+ * + *

Rather than determine which of these cases are deterministic, we list some classes + * that definitely are, and treat any others as non-deterministic. + */ + private static final Set> DETERMINISTIC_STRINGABLE_CLASSES = new HashSet<>(); + static { + // CharSequences: + DETERMINISTIC_STRINGABLE_CLASSES.add(String.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(Utf8.class); + + // Explicitly Stringable: + DETERMINISTIC_STRINGABLE_CLASSES.add(java.math.BigDecimal.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(java.math.BigInteger.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(java.net.URI.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(java.net.URL.class); + + // Classes annotated with @Stringable: + } + + /** + * Return true if the given type token is a subtype of *any* of the listed parents. + */ + private static boolean isSubtypeOf(TypeDescriptor type, Class... parents) { + for (Class parent : parents) { + if (type.isSubtypeOf(TypeDescriptor.of(parent))) { + return true; + } + } + return false; + } + + protected AvroDeterminismChecker() {} + + // The entry point for the check. Should not be recursively called. + public List check(TypeDescriptor type, Schema schema) { + recurse(type.getRawType().getName(), type, schema); + return reasons; + } + + // This is the method that should be recursively called. It sets up the path + // and visited types correctly. + private void recurse(String context, TypeDescriptor type, Schema schema) { + if (type.getRawType().isAnnotationPresent(AvroSchema.class)) { + reportError(context, "Custom schemas are not supported -- remove @AvroSchema."); + return; + } + + if (!activeTypes.add(type)) { + reportError(context, "%s appears recursively", type); + return; + } + + // If the the record isn't a true class, but rather a GenericRecord, SpecificRecord, etc. + // with a specified schema, then we need to make the decision based on the generated + // implementations. + if (isSubtypeOf(type, IndexedRecord.class)) { + checkIndexedRecord(context, schema, null); + } else { + doCheck(context, type, schema); + } + + activeTypes.remove(type); + } + + private void doCheck(String context, TypeDescriptor type, Schema schema) { + switch (schema.getType()) { + case ARRAY: + checkArray(context, type, schema); + break; + case ENUM: + // Enums should be deterministic, since they depend only on the ordinal. + break; + case FIXED: + // Depending on the implementation of GenericFixed, we don't know how + // the given field will be encoded. So, we assume that it isn't + // deterministic. + reportError(context, "FIXED encodings are not guaranteed to be deterministic"); + break; + case MAP: + checkMap(context, type, schema); + break; + case RECORD: + checkRecord(type, schema); + break; + case UNION: + checkUnion(context, type, schema); + break; + case STRING: + checkString(context, type); + break; + case BOOLEAN: + case BYTES: + case DOUBLE: + case INT: + case FLOAT: + case LONG: + case NULL: + // For types that Avro encodes using one of the above primitives, we assume they are + // deterministic. + break; + default: + // In any other case (eg., new types added to Avro) we cautiously return + // false. + reportError(context, "Unknown schema type %s may be non-deterministic", schema.getType()); + break; + } + } + + private void checkString(String context, TypeDescriptor type) { + // For types that are encoded as strings, we need to make sure they're in an approved + // whitelist. 
For other types that are annotated @Stringable, Avro will just use the + // #toString() methods, which has no guarantees of determinism. + if (!DETERMINISTIC_STRINGABLE_CLASSES.contains(type.getRawType())) { + reportError(context, "%s may not have deterministic #toString()", type); + } + } + + private static final Schema AVRO_NULL_SCHEMA = Schema.create(Schema.Type.NULL); + + private void checkUnion(String context, TypeDescriptor type, Schema schema) { + final List unionTypes = schema.getTypes(); + + if (!type.getRawType().isAnnotationPresent(Union.class)) { + // First check for @Nullable field, which shows up as a union of field type and null. + if (unionTypes.size() == 2 && unionTypes.contains(AVRO_NULL_SCHEMA)) { + // Find the Schema that is not NULL and recursively check that it is deterministic. + Schema nullableFieldSchema = unionTypes.get(0).equals(AVRO_NULL_SCHEMA) + ? unionTypes.get(1) : unionTypes.get(0); + doCheck(context, type, nullableFieldSchema); + return; + } + + // Otherwise report a schema error. + reportError(context, "Expected type %s to have @Union annotation", type); + return; + } + + // Errors associated with this union will use the base class as their context. + String baseClassContext = type.getRawType().getName(); + + // For a union, we need to make sure that each possible instantiation is deterministic. + for (Schema concrete : unionTypes) { + @SuppressWarnings("unchecked") + TypeDescriptor unionType = TypeDescriptor.of(ReflectData.get().getClass(concrete)); + + recurse(baseClassContext, unionType, concrete); + } + } + + private void checkRecord(TypeDescriptor type, Schema schema) { + // For a record, we want to make sure that all the fields are deterministic. + Class clazz = type.getRawType(); + for (org.apache.avro.Schema.Field fieldSchema : schema.getFields()) { + Field field = getField(clazz, fieldSchema.name()); + String fieldContext = field.getDeclaringClass().getName() + "#" + field.getName(); + + if (field.isAnnotationPresent(AvroEncode.class)) { + reportError(fieldContext, + "Custom encoders may be non-deterministic -- remove @AvroEncode"); + continue; + } + + if (!IndexedRecord.class.isAssignableFrom(field.getType()) + && field.isAnnotationPresent(AvroSchema.class)) { + // TODO: We should be able to support custom schemas on POJO fields, but we shouldn't + // need to, so we just allow it in the case of IndexedRecords. + reportError(fieldContext, + "Custom schemas are only supported for subtypes of IndexedRecord."); + continue; + } + + TypeDescriptor fieldType = type.resolveType(field.getGenericType()); + recurse(fieldContext, fieldType, fieldSchema.schema()); + } + } + + private void checkIndexedRecord(String context, Schema schema, + @Nullable String specificClassStr) { + + if (!activeSchemas.add(schema)) { + reportError(context, "%s appears recursively", schema.getName()); + return; + } + + switch (schema.getType()) { + case ARRAY: + // Generic Records use GenericData.Array to implement arrays, which is + // essentially an ArrayList, and therefore ordering is deterministic. + // The array is thus deterministic if the elements are deterministic. + checkIndexedRecord(context, schema.getElementType(), null); + break; + case ENUM: + // Enums are deterministic because they encode as a single integer. + break; + case FIXED: + // In the case of GenericRecords, FIXED is deterministic because it + // encodes/decodes as a Byte[]. 
+ break; + case MAP: + reportError(context, + "GenericRecord and SpecificRecords use a HashMap to represent MAPs," + + " so it is non-deterministic"); + break; + case RECORD: + for (org.apache.avro.Schema.Field field : schema.getFields()) { + checkIndexedRecord( + schema.getName() + "." + field.name(), + field.schema(), + field.getProp(SpecificData.CLASS_PROP)); + } + break; + case STRING: + // GenericDatumWriter#findStringClass will use a CharSequence or a String + // for each string, so it is deterministic. + + // SpecificCompiler#getStringType will use java.lang.String, org.apache.avro.util.Utf8, + // or java.lang.CharSequence, unless SpecificData.CLASS_PROP overrides that. + if (specificClassStr != null) { + Class specificClass; + try { + specificClass = ClassUtils.forName(specificClassStr); + if (!DETERMINISTIC_STRINGABLE_CLASSES.contains(specificClass)) { + reportError(context, "Specific class %s is not known to be deterministic", + specificClassStr); + } + } catch (ClassNotFoundException e) { + reportError(context, "Specific class %s is not known to be deterministic", + specificClassStr); + } + } + break; + case UNION: + for (org.apache.avro.Schema subschema : schema.getTypes()) { + checkIndexedRecord(subschema.getName(), subschema, null); + } + break; + case BOOLEAN: + case BYTES: + case DOUBLE: + case INT: + case FLOAT: + case LONG: + case NULL: + // For types that Avro encodes using one of the above primitives, we assume they are + // deterministic. + break; + default: + reportError(context, "Unknown schema type %s may be non-deterministic", schema.getType()); + break; + } + + activeSchemas.remove(schema); + } + + private void checkMap(String context, TypeDescriptor type, Schema schema) { + if (!isSubtypeOf(type, SortedMap.class)) { + reportError(context, "%s may not be deterministically ordered", type); + } + + // Avro (currently) asserts that all keys are strings. + // In case that changes, we double check that the key was a string: + Class keyType = type.resolveType(Map.class.getTypeParameters()[0]).getRawType(); + if (!String.class.equals(keyType)) { + reportError(context, "map keys should be Strings, but was %s", keyType); + } + + recurse(context, + type.resolveType(Map.class.getTypeParameters()[1]), + schema.getValueType()); + } + + private void checkArray(String context, TypeDescriptor type, Schema schema) { + TypeDescriptor elementType = null; + if (type.isArray()) { + // The type is an array (with ordering)-> deterministic iff the element is deterministic. + elementType = type.getComponentType(); + } else if (isSubtypeOf(type, Collection.class)) { + if (isSubtypeOf(type, List.class, SortedSet.class)) { + // Ordered collection -> deterministic iff the element is deterministic + elementType = type.resolveType(Collection.class.getTypeParameters()[0]); + } else { + // Not an ordered collection -> not deterministic + reportError(context, "%s may not be deterministically ordered", type); + return; + } + } else { + // If it was an unknown type encoded as an array, be conservative and assume + // that we don't know anything about the order. + reportError(context, "encoding %s as an ARRAY was unexpected"); + return; + } + + // If we get here, it's either a deterministically-ordered Collection, or + // an array. Either way, the type is deterministic iff the element type is + // deterministic. + recurse(context, elementType, schema.getElementType()); + } + + /** + * Extract a field from a class. We need to look at the declared fields so that we can + * see private fields. 
We may need to walk up to the parent to get classes from the parent. + */ + private static Field getField(Class clazz, String name) { + while (clazz != null) { + for (Field field : clazz.getDeclaredFields()) { + AvroName avroName = field.getAnnotation(AvroName.class); + if (avroName != null && name.equals(avroName.value())) { + return field; + } else if (avroName == null && name.equals(field.getName())) { + return field; + } + } + clazz = clazz.getSuperclass(); + } + + throw new IllegalArgumentException( + "Unable to get field " + name + " from class " + clazz); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoder.java new file mode 100644 index 000000000000..24f6a45433ad --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoder.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A {@link BigEndianIntegerCoder} encodes {@link Integer Integers} in 4 bytes, big-endian. + */ +public class BigEndianIntegerCoder extends AtomicCoder { + + @JsonCreator + public static BigEndianIntegerCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final BigEndianIntegerCoder INSTANCE = new BigEndianIntegerCoder(); + + private BigEndianIntegerCoder() {} + + @Override + public void encode(Integer value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + new DataOutputStream(outStream).writeInt(value); + } + + @Override + public Integer decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new DataInputStream(inStream).readInt(); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}, because {@link #getEncodedElementByteSize} runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Integer value, Context context) { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code 4}, the size in bytes of an integer's big endian encoding. 
+ */ + @Override + protected long getEncodedElementByteSize(Integer value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + return 4; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoder.java new file mode 100644 index 000000000000..4196608b54e4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoder.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A {@link BigEndianLongCoder} encodes {@link Long}s in 8 bytes, big-endian. + */ +public class BigEndianLongCoder extends AtomicCoder { + + @JsonCreator + public static BigEndianLongCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final BigEndianLongCoder INSTANCE = new BigEndianLongCoder(); + + private BigEndianLongCoder() {} + + @Override + public void encode(Long value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + new DataOutputStream(outStream).writeLong(value); + } + + @Override + public Long decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new DataInputStream(inStream).readLong(); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}, since {@link #getEncodedElementByteSize} returns a constant. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Long value, Context context) { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code 8}, the byte size of a big-endian encoded {@code Long}. 
+ */ + @Override + protected long getEncodedElementByteSize(Long value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + return 8; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoder.java new file mode 100644 index 000000000000..1e555c67bf62 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoder.java @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.ExposedByteArrayOutputStream; +import com.google.cloud.dataflow.sdk.util.StreamUtils; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.io.ByteStreams; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} for {@code byte[]}. + * + *

The encoding format is as follows: + *

    + *
  • If in a non-nested context (the {@code byte[]} is the only value in the stream), the + * bytes are read/written directly.
  • If in a nested context, the bytes are prefixed with the length of the array, + * encoded via a {@link VarIntCoder}.
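+ *
+ * <p>A usage sketch (illustrative; the byte counts follow from the format above):
+ * <pre>{@code
+ * ByteArrayCoder coder = ByteArrayCoder.of();
+ * ByteArrayOutputStream out = new ByteArrayOutputStream();
+ * coder.encode(new byte[] {1, 2, 3}, out, Coder.Context.NESTED);
+ * // out now holds a one-byte VarInt length (3) followed by the three array bytes.
+ * }</pre>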
+ */ +public class ByteArrayCoder extends AtomicCoder { + + @JsonCreator + public static ByteArrayCoder of() { + return INSTANCE; + } + + + ///////////////////////////////////////////////////////////////////////////// + + private static final ByteArrayCoder INSTANCE = new ByteArrayCoder(); + + private ByteArrayCoder() {} + + @Override + public void encode(byte[] value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null byte[]"); + } + if (!context.isWholeStream) { + VarInt.encode(value.length, outStream); + outStream.write(value); + } else { + outStream.write(value); + } + } + + /** + * Encodes the provided {@code value} with the identical encoding to {@link #encode}, but with + * optimizations that take ownership of the value. + * + *

Once passed to this method, {@code value} should never be observed or mutated again. + */ + public void encodeAndOwn(byte[] value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (!context.isWholeStream) { + VarInt.encode(value.length, outStream); + outStream.write(value); + } else { + if (outStream instanceof ExposedByteArrayOutputStream) { + ((ExposedByteArrayOutputStream) outStream).writeAndOwn(value); + } else { + outStream.write(value); + } + } + } + + @Override + public byte[] decode(InputStream inStream, Context context) + throws IOException, CoderException { + if (context.isWholeStream) { + return StreamUtils.getBytes(inStream); + } else { + int length = VarInt.decodeInt(inStream); + if (length < 0) { + throw new IOException("invalid length " + length); + } + byte[] value = new byte[length]; + ByteStreams.readFully(inStream, value); + return value; + } + } + + /** + * {@inheritDoc} + * + * @return objects that are equal if the two arrays contain the same bytes. + */ + @Override + public Object structuralValue(byte[] value) { + return new StructuralByteArray(value); + } + + /** + * {@inheritDoc} + * + * @return {@code true} since {@link #getEncodedElementByteSize} runs in + * constant time using the {@code length} of the provided array. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(byte[] value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(byte[] value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null byte[]"); + } + long size = 0; + if (!context.isWholeStream) { + size += VarInt.getLength(value.length); + } + return size + value.length; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteCoder.java new file mode 100644 index 000000000000..9f17497d8dc4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteCoder.java @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A {@link ByteCoder} encodes {@link Byte} values in 1 byte using Java serialization. 
+ */ +public class ByteCoder extends AtomicCoder { + + @JsonCreator + public static ByteCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final ByteCoder INSTANCE = new ByteCoder(); + + private ByteCoder() {} + + @Override + public void encode(Byte value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Byte"); + } + outStream.write(value.byteValue()); + } + + @Override + public Byte decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + // value will be between 0-255, -1 for EOF + int value = inStream.read(); + if (value == -1) { + throw new EOFException("EOF encountered decoding 1 byte from input stream"); + } + return (byte) value; + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * {@inheritDoc} + * + * {@link ByteCoder} will never throw a {@link Coder.NonDeterministicException}; bytes can always + * be encoded deterministically. + */ + @Override + public void verifyDeterministic() {} + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link ByteCoder#getEncodedElementByteSize} returns a constant. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Byte value, Context context) { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code 1}, the byte size of a {@link Byte} encoded using Java serialization. + */ + @Override + protected long getEncodedElementByteSize(Byte value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot estimate size for unsupported null value"); + } + return 1; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteStringCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteStringCoder.java new file mode 100644 index 000000000000..b7c1a3cd0ade --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ByteStringCoder.java @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.io.ByteStreams; +import com.google.protobuf.ByteString; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} for {@link ByteString} objects based on their encoded Protocol Buffer form. + * + *

When this code is used in a nested {@link Coder.Context}, the serialized {@link ByteString} + * objects are first delimited by their size. + */ +public class ByteStringCoder extends AtomicCoder { + + @JsonCreator + public static ByteStringCoder of() { + return INSTANCE; + } + + /***************************/ + + private static final ByteStringCoder INSTANCE = new ByteStringCoder(); + + private ByteStringCoder() {} + + @Override + public void encode(ByteString value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null ByteString"); + } + + if (!context.isWholeStream) { + // ByteString is not delimited, so write its size before its contents. + VarInt.encode(value.size(), outStream); + } + value.writeTo(outStream); + } + + @Override + public ByteString decode(InputStream inStream, Context context) throws IOException { + if (context.isWholeStream) { + return ByteString.readFrom(inStream); + } + + int size = VarInt.decodeInt(inStream); + // ByteString reads to the end of the input stream, so give it a limited stream of exactly + // the right length. Also set its chunk size so that the ByteString will contain exactly + // one chunk. + return ByteString.readFrom(ByteStreams.limit(inStream, size), size); + } + + @Override + protected long getEncodedElementByteSize(ByteString value, Context context) throws Exception { + int size = value.size(); + + if (context.isWholeStream) { + return size; + } + return VarInt.getLength(size) + size; + } + + /** + * {@inheritDoc} + * + *

Returns true; the encoded output of two invocations of {@link ByteStringCoder} in the same + * {@link Coder.Context} will be identical if and only if the original {@link ByteString} objects + * are equal according to {@link Object#equals}. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + *

Returns true. {@link ByteString#size} returns the size of an array and a {@link VarInt}. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(ByteString value, Context context) { + return true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CannotProvideCoderException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CannotProvideCoderException.java new file mode 100644 index 000000000000..97b5e238b0c4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CannotProvideCoderException.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +/** + * The exception thrown when a {@link CoderProvider} cannot + * provide a {@link Coder} that has been requested. + */ +public class CannotProvideCoderException extends Exception { + private final ReasonCode reason; + + public CannotProvideCoderException(String message) { + this(message, ReasonCode.UNKNOWN); + } + + public CannotProvideCoderException(String message, ReasonCode reason) { + super(message); + this.reason = reason; + } + + public CannotProvideCoderException(String message, Throwable cause) { + this(message, cause, ReasonCode.UNKNOWN); + } + + public CannotProvideCoderException(String message, Throwable cause, ReasonCode reason) { + super(message, cause); + this.reason = reason; + } + + public CannotProvideCoderException(Throwable cause) { + this(cause, ReasonCode.UNKNOWN); + } + + public CannotProvideCoderException(Throwable cause, ReasonCode reason) { + super(cause); + this.reason = reason; + } + + /** + * @return the reason that Coder inference failed. + */ + public ReasonCode getReason() { + return reason; + } + + /** + * Returns the inner-most {@link CannotProvideCoderException} when they are deeply nested. + * + *

For example, if a coder for {@code List>} cannot be provided because + * there is no known coder for {@code Whatsit}, the root cause of the exception should be a + * CannotProvideCoderException with details pertinent to {@code Whatsit}, suppressing the + * intermediate layers. + */ + public Throwable getRootCause() { + Throwable cause = getCause(); + if (cause == null) { + return this; + } else if (!(cause instanceof CannotProvideCoderException)) { + return cause; + } else { + return ((CannotProvideCoderException) cause).getRootCause(); + } + } + + /** + * Indicates the reason that {@link Coder} inference failed. + */ + public static enum ReasonCode { + /** + * The reason a coder could not be provided is unknown or does have an established + * {@link ReasonCode}. + */ + UNKNOWN, + + /** + * The reason a coder could not be provided is type erasure, for example when requesting + * coder inference for a {@code List} where {@code T} is unknown. + */ + TYPE_ERASURE + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Coder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Coder.java new file mode 100644 index 000000000000..f3a8bec4a620 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Coder.java @@ -0,0 +1,298 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.base.Joiner; +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A {@link Coder Coder<T>} defines how to encode and decode values of type {@code T} into + * byte streams. + * + *
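+ * <p>A round-trip sketch (illustrative only) using one of the coders in this package:
+ * <pre>{@code
+ * Coder<Long> coder = BigEndianLongCoder.of();
+ * ByteArrayOutputStream out = new ByteArrayOutputStream();
+ * coder.encode(42L, out, Coder.Context.OUTER);
+ * Long decoded = coder.decode(new ByteArrayInputStream(out.toByteArray()), Coder.Context.OUTER);
+ * }</pre>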

{@link Coder} instances are serialized during job creation and deserialized + * before use, via JSON serialization. See {@link SerializableCoder} for an example of a + * {@link Coder} that adds a custom field to + * the {@link Coder} serialization. It provides a constructor annotated with + * {@link com.fasterxml.jackson.annotation.JsonCreator}, which is a factory method used when + * deserializing a {@link Coder} instance. + * + *

{@link Coder} classes for compound types are often composed from coder classes for types + * contained therein. The composition of {@link Coder} instances into a coder for the compound + * class is the subject of the {@link CoderFactory} type, which enables automatic generic + * composition of {@link Coder} classes within the {@link CoderRegistry}. With particular + * static methods on a compound {@link Coder} class, a {@link CoderFactory} can be automatically + * inferred. See {@link KvCoder} for an example of a simple compound {@link Coder} that supports + * automatic composition in the {@link CoderRegistry}. + *
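As a concrete sketch of this kind of composition (using only coder classes that appear elsewhere in this change; the wrapper class and main method are illustrative scaffolding, not SDK code), a coder for key-value pairs can be assembled from the coders for its two components:

    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.KvCoder;
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
    import com.google.cloud.dataflow.sdk.values.KV;

    public class CompoundCoderExample {
      public static void main(String[] args) {
        // Compose a coder for KV<String, Integer> from the coders for its components.
        Coder<KV<String, Integer>> kvCoder =
            KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of());
        System.out.println(kvCoder);
      }
    }

The same composition is what the CoderRegistry performs automatically once a CoderFactory for the compound class has been registered.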

The binary format of a {@link Coder} is identified by {@link #getEncodingId()}; be sure to + * understand the requirements for evolving coder formats. + * + *

All methods of a {@link Coder} are required to be thread safe. + * + * @param the type of the values being transcoded + */ +public interface Coder extends Serializable { + /** The context in which encoding or decoding is being done. */ + public static class Context { + /** + * The outer context: the value being encoded or decoded takes + * up the remainder of the record/stream contents. + */ + public static final Context OUTER = new Context(true); + + /** + * The nested context: the value being encoded or decoded is + * (potentially) a part of a larger record/stream contents, and + * may have other parts encoded or decoded after it. + */ + public static final Context NESTED = new Context(false); + + /** + * Whether the encoded or decoded value fills the remainder of the + * output or input (resp.) record/stream contents. If so, then + * the size of the decoded value can be determined from the + * remaining size of the record/stream contents, and so explicit + * lengths aren't required. + */ + public final boolean isWholeStream; + + public Context(boolean isWholeStream) { + this.isWholeStream = isWholeStream; + } + + public Context nested() { + return NESTED; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Context)) { + return false; + } + return Objects.equal(isWholeStream, ((Context) obj).isWholeStream); + } + + @Override + public int hashCode() { + return Objects.hashCode(isWholeStream); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(Context.class) + .addValue(isWholeStream ? "OUTER" : "NESTED").toString(); + } + } + + /** + * Encodes the given value of type {@code T} onto the given output stream + * in the given context. + * + * @throws IOException if writing to the {@code OutputStream} fails + * for some reason + * @throws CoderException if the value could not be encoded for some reason + */ + public void encode(T value, OutputStream outStream, Context context) + throws CoderException, IOException; + + /** + * Decodes a value of type {@code T} from the given input stream in + * the given context. Returns the decoded value. + * + * @throws IOException if reading from the {@code InputStream} fails + * for some reason + * @throws CoderException if the value could not be decoded for some reason + */ + public T decode(InputStream inStream, Context context) + throws CoderException, IOException; + + /** + * If this is a {@code Coder} for a parameterized type, returns the + * list of {@code Coder}s being used for each of the parameters, or + * returns {@code null} if this cannot be done or this is not a + * parameterized type. + */ + public List> getCoderArguments(); + + /** + * Returns the {@link CloudObject} that represents this {@code Coder}. + */ + public CloudObject asCloudObject(); + + /** + * Throw {@link NonDeterministicException} if the coding is not deterministic. + * + *

In order for a {@code Coder} to be considered deterministic, + * the following must be true: + *

    + *
  • two values that compare as equal (via {@code Object.equals()} + * or {@code Comparable.compareTo()}, if supported) have the same + * encoding. + *
  • the {@code Coder} always produces a canonical encoding, which is the + * same for an instance of an object even if produced on different + * computers at different times. + *
+ * + * @throws Coder.NonDeterministicException if this coder is not deterministic. + */ + public void verifyDeterministic() throws Coder.NonDeterministicException; + + /** + * Returns {@code true} if this {@link Coder} is injective with respect to {@link Objects#equals}. + * + *

Whenever the encoded bytes of two values are equal, then the original values are equal + * according to {@code Objects.equals()}. Note that this is well-defined for {@code null}. + * + *

This condition is most notably false for arrays. More generally, this condition is false + * whenever {@code equals()} compares object identity, rather than performing a + * semantic/structural comparison. + */ + public boolean consistentWithEquals(); + + /** + * Returns an object with an {@code Object.equals()} method that represents structural equality + * on the argument. + * + *

For any two values {@code x} and {@code y} of type {@code T}, if their encoded bytes are the + * same, then it must be the case that {@code structuralValue(x).equals(structuralValue(y))}. + *

Most notably: + *

    + *
  • The structural value for an array coder should perform a structural comparison of the + * contents of the arrays, rather than the default behavior of comparing according to object + * identity. + *
  • The structural value for a coder accepting {@code null} should be a proper object with + * an {@code equals()} method, even if the input value is {@code null}. + *
+ * + *
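A short sketch of how a caller might honor the contract above when comparing two values; structurallyEqual is a made-up helper name, and only methods declared on this interface are used:

    import com.google.cloud.dataflow.sdk.coders.Coder;

    import java.util.Objects;

    public class StructuralEqualityExample {
      // Hypothetical helper: compares two values the way the structuralValue contract suggests.
      public static <T> boolean structurallyEqual(Coder<T> coder, T a, T b) throws Exception {
        if (coder.consistentWithEquals()) {
          // Encoded-byte equality agrees with Objects.equals, so ordinary equality suffices.
          return Objects.equals(a, b);
        }
        // Otherwise compare the coder-defined structural representations, which are required
        // to be proper objects with a meaningful equals(), even for null inputs.
        return Objects.equals(coder.structuralValue(a), coder.structuralValue(b));
      }
    }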

See also {@link #consistentWithEquals()}. + */ + public Object structuralValue(T value) throws Exception; + + /** + * Returns whether {@link #registerByteSizeObserver} is cheap enough to + * call for every element, that is, if this {@code Coder} can + * calculate the byte size of the element to be coded in roughly + * constant time (or lazily). + *

Not intended to be called by user code, but instead by + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * implementations. + */ + public boolean isRegisterByteSizeObserverCheap(T value, Context context); + + /** + * Notifies the {@code ElementByteSizeObserver} about the byte size + * of the encoded value using this {@code Coder}. + * + *

Not intended to be called by user code, but instead by + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * implementations. + */ + public void registerByteSizeObserver( + T value, ElementByteSizeObserver observer, Context context) + throws Exception; + + /** + * An identifier for the binary format written by {@link #encode}. + * + *

This value, along with the fully qualified class name, forms an identifier for the + * binary format of this coder. Whenever this value changes, the new encoding is considered + * incompatible with the prior format: It is presumed that the prior version of the coder will + * be unable to correctly read the new format and the new version of the coder will be unable to + * correctly read the old format. + * + *

If the format is changed in a backwards-compatible way (the Coder can still accept data from + * the prior format), such as by adding optional fields to a Protocol Buffer or Avro definition, + * and you want Dataflow to understand that the new coder is compatible with the prior coder, + * this value must remain unchanged. It is then the responsibility of {@link #decode} to correctly + * read data from the prior format. + */ + @Experimental(Kind.CODER_ENCODING_ID) + public String getEncodingId(); + + /** + * A collection of encodings supported by {@link #decode} in addition to the encoding + * from {@link #getEncodingId()} (which is assumed supported). + * + *

This information is not currently used for any purpose. It is descriptive only, + * and this method is subject to change. + * + * @see #getEncodingId() + */ + @Experimental(Kind.CODER_ENCODING_ID) + public Collection getAllowedEncodings(); + + /** + * Exception thrown by {@link Coder#verifyDeterministic()} if the encoding is + * not deterministic, including details of why the encoding is not deterministic. + */ + public static class NonDeterministicException extends Throwable { + private Coder coder; + private List reasons; + + public NonDeterministicException( + Coder coder, String reason, @Nullable NonDeterministicException e) { + this(coder, Arrays.asList(reason), e); + } + + public NonDeterministicException(Coder coder, String reason) { + this(coder, Arrays.asList(reason), null); + } + + public NonDeterministicException(Coder coder, List reasons) { + this(coder, reasons, null); + } + + public NonDeterministicException( + Coder coder, + List reasons, + @Nullable NonDeterministicException cause) { + super(cause); + Preconditions.checkArgument(reasons.size() > 0, + "Reasons must not be empty."); + this.reasons = reasons; + this.coder = coder; + } + + public Iterable getReasons() { + return reasons; + } + + @Override + public String getMessage() { + return String.format("%s is not deterministic because:\n %s", + coder, Joiner.on("\n ").join(reasons)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderException.java new file mode 100644 index 000000000000..8ff8571e5c27 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderException.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import java.io.IOException; + +/** + * An {@link Exception} thrown if there is a problem encoding or decoding a value. + */ +public class CoderException extends IOException { + public CoderException(String message) { + super(message); + } + + public CoderException(String message, Throwable cause) { + super(message, cause); + } + + public CoderException(Throwable cause) { + super(cause); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderFactories.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderFactories.java new file mode 100644 index 000000000000..82b40a489fd8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderFactories.java @@ -0,0 +1,274 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Static utility methods for creating and working with {@link Coder}s. + */ +public final class CoderFactories { + private CoderFactories() { } // Static utility class + + /** + * Creates a {@link CoderFactory} built from particular static methods of a class that + * implements {@link Coder}. + * + *

The class must have the following static methods: + * + *

    + *
  • {@code + * public static Coder of(Coder argCoder1, Coder argCoder2, ...) + * } + *
  • {@code + * public static List getInstanceComponents(T exampleValue); + * } + * + * + *

    The {@code of(...)} method will be used to construct a + * {@code Coder} from component {@link Coder}s. + * It must accept one {@link Coder} argument for each + * generic type parameter of {@code T}. If {@code T} takes no generic + * type parameters, then the {@code of()} factory method should take + * no arguments. + * + *

    The {@code getInstanceComponents} method will be used to + * decompose a value during the {@link Coder} inference process, + * to automatically choose coders for the components. + * + *

    Note that the class {@code T} to be coded may be a + * not-yet-specialized generic class. + * For a generic class {@code MyClass} and an actual type parameter + * {@code Foo}, the {@link CoderFactoryFromStaticMethods} will + * accept any {@code Coder} and produce a {@code Coder>}. + * + *

For example, the {@link CoderFactory} returned by + * {@code fromStaticMethods(ListCoder.class)} + * will produce a {@code Coder<List<X>>} for any {@code Coder<X>}. + */ + public static CoderFactory fromStaticMethods(Class clazz) { + return new CoderFactoryFromStaticMethods(clazz); + } + + /** + * Creates a {@link CoderFactory} that always returns the + * given coder. + *
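A usage sketch of the two factory styles, fromStaticMethods and forCoder; the wildcard generics are assumptions about the erased signatures shown in this diff, and the wrapper class is illustrative only:

    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.CoderFactories;
    import com.google.cloud.dataflow.sdk.coders.CoderFactory;
    import com.google.cloud.dataflow.sdk.coders.ListCoder;
    import com.google.cloud.dataflow.sdk.coders.VarIntCoder;

    import java.util.Arrays;
    import java.util.Collections;

    public class CoderFactoryExample {
      public static void main(String[] args) {
        // Built from ListCoder's static methods: create() composes a coder for a list of
        // integers from a coder for Integer.
        CoderFactory listFactory = CoderFactories.fromStaticMethods(ListCoder.class);
        Coder<?> listOfIntsCoder = listFactory.create(Arrays.<Coder<?>>asList(VarIntCoder.of()));

        // Always returns the given coder, regardless of component coders.
        CoderFactory constantFactory = CoderFactories.forCoder(VarIntCoder.of());
        Coder<?> intCoder = constantFactory.create(Collections.<Coder<?>>emptyList());

        System.out.println(listOfIntsCoder + " / " + intCoder);
      }
    }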

    The {@code getInstanceComponents} method of this + * {@link CoderFactory} always returns an empty list. + */ + public static CoderFactory forCoder(Coder coder) { + return new CoderFactoryForCoder<>(coder); + } + + /** + * See {@link #fromStaticMethods} for a detailed description + * of the characteristics of this {@link CoderFactory}. + */ + private static class CoderFactoryFromStaticMethods implements CoderFactory { + + @Override + @SuppressWarnings("rawtypes") + public Coder create(List> componentCoders) { + try { + return (Coder) factoryMethod.invoke( + null /* static */, componentCoders.toArray()); + } catch (IllegalAccessException | + IllegalArgumentException | + InvocationTargetException | + NullPointerException | + ExceptionInInitializerError exn) { + throw new IllegalStateException( + "error when invoking Coder factory method " + factoryMethod, + exn); + } + } + + @Override + public List getInstanceComponents(Object value) { + try { + @SuppressWarnings("unchecked") + List components = (List) getComponentsMethod.invoke( + null /* static */, value); + return components; + } catch (IllegalAccessException + | IllegalArgumentException + | InvocationTargetException + | NullPointerException + | ExceptionInInitializerError exn) { + throw new IllegalStateException( + "error when invoking Coder getComponents method " + getComponentsMethod, + exn); + } + } + + //////////////////////////////////////////////////////////////////////////////// + + // Method to create a coder given component coders + // For a Coder class of kind * -> * -> ... n times ... -> * + // this has type Coder -> Coder -> ... n times ... -> Coder + private Method factoryMethod; + + // Method to decompose a value of type T into its parts. + // For a Coder class of kind * -> * -> ... n times ... -> * + // this has type T -> List + // where the list has n elements. + private Method getComponentsMethod; + + /** + * Returns a CoderFactory that invokes the given static factory method + * to create the Coder. + */ + private CoderFactoryFromStaticMethods(Class coderClazz) { + this.factoryMethod = getFactoryMethod(coderClazz); + this.getComponentsMethod = getInstanceComponentsMethod(coderClazz); + } + + /** + * Returns the static {@code of} constructor method on {@code coderClazz} + * if it exists. It is assumed to have one {@link Coder} parameter for + * each type parameter of {@code coderClazz}. + */ + private Method getFactoryMethod(Class coderClazz) { + Method factoryMethodCandidate; + + // Find the static factory method of coderClazz named 'of' with + // the appropriate number of type parameters. 
+ int numTypeParameters = coderClazz.getTypeParameters().length; + Class[] factoryMethodArgTypes = new Class[numTypeParameters]; + Arrays.fill(factoryMethodArgTypes, Coder.class); + try { + factoryMethodCandidate = + coderClazz.getDeclaredMethod("of", factoryMethodArgTypes); + } catch (NoSuchMethodException | SecurityException exn) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "does not have an accessible method named 'of' with " + + numTypeParameters + " arguments of Coder type", + exn); + } + if (!Modifier.isStatic(factoryMethodCandidate.getModifiers())) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "method named 'of' with " + numTypeParameters + + " arguments of Coder type is not static"); + } + if (!coderClazz.isAssignableFrom(factoryMethodCandidate.getReturnType())) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "method named 'of' with " + numTypeParameters + + " arguments of Coder type does not return a " + coderClazz); + } + try { + if (!factoryMethodCandidate.isAccessible()) { + factoryMethodCandidate.setAccessible(true); + } + } catch (SecurityException exn) { + throw new IllegalArgumentException( + "cannot register Coder " + coderClazz + ": " + + "method named 'of' with " + numTypeParameters + + " arguments of Coder type is not accessible", + exn); + } + + return factoryMethodCandidate; + } + + /** + * Finds the static method on {@code coderType} to use + * to decompose a value of type {@code T} into components, + * each corresponding to an argument of the {@code of} + * method. + */ + private Method getInstanceComponentsMethod(Class coderClazz) { + TypeDescriptor coderType = TypeDescriptor.of(coderClazz); + TypeDescriptor argumentType = getCodedType(coderType); + + // getInstanceComponents may be implemented in a superclass, + // so we search them all for an applicable method. We do not + // try to be clever about finding the best overload. It may + // be in a generic superclass, erased to accept an Object. + // However, subtypes are listed before supertypes (it is a + // topological ordering) so probably the best one will be chosen + // if there are more than one (which should be rare) + for (TypeDescriptor supertype : coderType.getClasses()) { + for (Method method : supertype.getRawType().getDeclaredMethods()) { + if (method.getName().equals("getInstanceComponents")) { + TypeDescriptor formalArgumentType = supertype.getArgumentTypes(method).get(0); + if (formalArgumentType.getRawType().isAssignableFrom(argumentType.getRawType())) { + return method; + } + } + } + } + + throw new IllegalArgumentException( + "cannot create a CoderFactory from " + coderType + ": " + + "does not have an accessible method " + + "'getInstanceComponents'"); + } + + /** + * If {@code coderType} is a subclass of {@link Coder} for a specific + * type {@code T}, returns {@code T.class}. Otherwise, raises IllegalArgumentException. 
+ */ + private TypeDescriptor getCodedType(TypeDescriptor coderType) { + for (TypeDescriptor ifaceType : coderType.getInterfaces()) { + if (ifaceType.getRawType().equals(Coder.class)) { + ParameterizedType coderIface = (ParameterizedType) ifaceType.getType(); + @SuppressWarnings("unchecked") + TypeDescriptor token = + (TypeDescriptor) TypeDescriptor.of(coderIface.getActualTypeArguments()[0]); + return token; + } + } + throw new IllegalArgumentException( + "cannot build CoderFactory from class " + coderType + + ": does not implement Coder for any T."); + } + } + + /** + * See {@link #forCoder} for a detailed description of this + * {@link CoderFactory}. + */ + private static class CoderFactoryForCoder implements CoderFactory { + private Coder coder; + + public CoderFactoryForCoder(Coder coder) { + this.coder = coder; + } + + @Override + public Coder create(List> componentCoders) { + return this.coder; + } + + @Override + public List getInstanceComponents(Object value) { + return Collections.emptyList(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderFactory.java new file mode 100644 index 000000000000..541256c97814 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import java.util.List; + +/** + * A {@link CoderFactory} creates coders and decomposes values. + * It may operate on a parameterized type, such as {@link List}, + * in which case the {@link #create} method accepts a list of + * coders to use for the type parameters. + */ +public interface CoderFactory { + + /** + * Returns a {@code Coder}, given argument coder to use for + * values of a particular type, given the Coders for each of + * the type's generic parameter types. + */ + public Coder create(List> componentCoders); + + /** + * Returns a list of objects contained in {@code value}, one per + * type argument, or {@code null} if none can be determined. + * The list of returned objects should be the same size as the + * list of coders required by {@link #create}. + */ + public List getInstanceComponents(Object value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderProvider.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderProvider.java new file mode 100644 index 000000000000..a3e6ec4e7cad --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderProvider.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +/** + * A {@link CoderProvider} may create a {@link Coder} for + * any concrete class. + */ +public interface CoderProvider { + + /** + * Provides a coder for a given class, if possible. + * + * @throws CannotProvideCoderException if no coder can be provided + */ + public Coder getCoder(TypeDescriptor type) throws CannotProvideCoderException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderProviders.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderProviders.java new file mode 100644 index 000000000000..8b0aedd2b446 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderProviders.java @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import java.lang.reflect.InvocationTargetException; +import java.util.List; + +/** + * Static utility methods for working with {@link CoderProvider CoderProviders}. + */ +public final class CoderProviders { + + // Static utility class + private CoderProviders() { } + + /** + * Creates a {@link CoderProvider} built from particular static methods of a class that + * implements {@link Coder}. The requirements for this method are precisely the requirements + * for a {@link Coder} class to be usable with {@link DefaultCoder} annotations. + * + *

    The class must have the following static method: + * + *

    {@code
    +   * public static Coder of(TypeDescriptor type)
    +   * }
    +   * 
    + */ + public static CoderProvider fromStaticMethods(Class clazz) { + return new CoderProviderFromStaticMethods(clazz); + } + + + /** + * Returns a {@link CoderProvider} that consults each of the provider {@code coderProviders} + * and returns the first {@link Coder} provided. + * + *
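For instance, the chain used by this change's own CoderRegistry consults ProtoCoder's provider first and then Java serialization. A sketch of building such a chain and asking it for a coder directly; the generic types are assumptions about the erased signatures shown here, and the wrapper class is illustrative scaffolding:

    import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.CoderProvider;
    import com.google.cloud.dataflow.sdk.coders.CoderProviders;
    import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
    import com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder;
    import com.google.cloud.dataflow.sdk.values.TypeDescriptor;

    public class ChainedProviderExample {
      public static void main(String[] args) throws CannotProvideCoderException {
        // Try the protobuf-based provider first, then fall back to Java serialization.
        CoderProvider provider =
            CoderProviders.firstOf(ProtoCoder.coderProvider(), SerializableCoder.PROVIDER);
        Coder<String> coder = provider.getCoder(TypeDescriptor.of(String.class));
        System.out.println(coder);
      }
    }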

    Note that the order in which the providers are listed matters: While the set of types + * handled will be the union of those handled by all of the providers in the list, the actual + * {@link Coder} provided by the first successful provider may differ, and may have inferior + * properties. For example, not all {@link Coder Coders} are deterministic, handle {@code null} + * values, or have comparable performance. + */ + public static CoderProvider firstOf(CoderProvider... coderProviders) { + return new FirstOf(ImmutableList.copyOf(coderProviders)); + } + + /////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * @see #firstOf + */ + private static class FirstOf implements CoderProvider { + + private Iterable providers; + + public FirstOf(Iterable providers) { + this.providers = providers; + } + + @Override + public Coder getCoder(TypeDescriptor type) throws CannotProvideCoderException { + List messages = Lists.newArrayList(); + for (CoderProvider provider : providers) { + try { + return provider.getCoder(type); + } catch (CannotProvideCoderException exc) { + messages.add(String.format("%s could not provide a Coder for type %s: %s", + provider, type, exc.getMessage())); + } + } + throw new CannotProvideCoderException( + String.format("Cannot provide coder for type %s: %s.", + type, Joiner.on("; ").join(messages))); + } + } + + private static class CoderProviderFromStaticMethods implements CoderProvider { + + /** If true, then clazz has {@code of(TypeDescriptor)}. If false, {@code of(Class)}. */ + private final boolean takesTypeDescriptor; + private final Class clazz; + + public CoderProviderFromStaticMethods(Class clazz) { + // Note that the second condition supports older classes, which only needed to provide + // of(Class), not of(TypeDescriptor). Our own classes have updated to accept a + // TypeDescriptor. Hence the error message points only to the current specification, + // not both acceptable conditions. 
+ checkArgument(classTakesTypeDescriptor(clazz) || classTakesClass(clazz), + "Class " + clazz.getCanonicalName() + + " is missing required static method of(TypeDescriptor)."); + + this.takesTypeDescriptor = classTakesTypeDescriptor(clazz); + this.clazz = clazz; + } + + @Override + public Coder getCoder(TypeDescriptor type) throws CannotProvideCoderException { + try { + if (takesTypeDescriptor) { + @SuppressWarnings("unchecked") + Coder result = InstanceBuilder.ofType(Coder.class) + .fromClass(clazz) + .fromFactoryMethod("of") + .withArg(TypeDescriptor.class, type) + .build(); + return result; + } else { + @SuppressWarnings("unchecked") + Coder result = InstanceBuilder.ofType(Coder.class) + .fromClass(clazz) + .fromFactoryMethod("of") + .withArg(Class.class, type.getRawType()) + .build(); + return result; + } + } catch (RuntimeException exc) { + if (exc.getCause() instanceof InvocationTargetException) { + throw new CannotProvideCoderException(exc.getCause().getCause()); + } + throw exc; + } + } + + private boolean classTakesTypeDescriptor(Class clazz) { + try { + clazz.getDeclaredMethod("of", TypeDescriptor.class); + return true; + } catch (NoSuchMethodException | SecurityException exc) { + return false; + } + } + + private boolean classTakesClass(Class clazz) { + try { + clazz.getDeclaredMethod("of", Class.class); + return true; + } catch (NoSuchMethodException | SecurityException exc) { + return false; + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderRegistry.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderRegistry.java new file mode 100644 index 000000000000..00982e64ffd3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CoderRegistry.java @@ -0,0 +1,843 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException.ReasonCode; +import com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.protobuf.ByteString; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.lang.reflect.WildcardType; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * A {@link CoderRegistry} allows registering the default {@link Coder} to use for a Java class, + * and looking up and instantiating the default {@link Coder} for a Java type. + * + *

    {@link CoderRegistry} uses the following mechanisms to determine a default {@link Coder} for a + * Java class, in order of precedence: + *

      + *
    1. Registration: + *
        + *
      • A {@link CoderFactory} can be registered to handle a particular class via + * {@link #registerCoder(Class, CoderFactory)}.
      • A {@link Coder} class with the static methods to satisfy + * {@link CoderFactories#fromStaticMethods} can be registered via + * {@link #registerCoder(Class, Class)}.
      • Built-in types are registered via + * {@link #registerStandardCoders()}.
      + *
    2. Annotations: {@link DefaultCoder} can be used to annotate a type with + * the default {@code Coder} type. The {@link Coder} class must satisfy the requirements + * of {@link CoderProviders#fromStaticMethods}. + *
3. Fallback: A fallback {@link CoderProvider} is used to attempt to provide a {@link Coder} + * for any type. By default, this is {@link SerializableCoder#PROVIDER}, which can provide + * a {@link Coder} for any type that is serializable via Java serialization. The fallback + * {@link CoderProvider} can be retrieved and set via {@link #getFallbackCoderProvider()} + * and {@link #setFallbackCoderProvider}. Multiple fallbacks can be chained together using + * {@link CoderProviders#firstOf}. + *
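To make the precedence above concrete, a minimal usage sketch; the wrapper class is illustrative scaffolding, and the lookup methods are the ones declared below in this class:

    import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.CoderRegistry;

    public class RegistryLookupExample {
      public static void main(String[] args) throws CannotProvideCoderException {
        CoderRegistry registry = new CoderRegistry();
        registry.registerStandardCoders();

        // Resolved by registration: registerStandardCoders() maps String to StringUtf8Coder.
        Coder<String> stringCoder = registry.getDefaultCoder(String.class);
        System.out.println(stringCoder);

        // Types that are neither registered nor annotated fall through to the fallback provider.
      }
    }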
    + */ +public class CoderRegistry implements CoderProvider { + + private static final Logger LOG = LoggerFactory.getLogger(CoderRegistry.class); + + public CoderRegistry() { + setFallbackCoderProvider( + CoderProviders.firstOf(ProtoCoder.coderProvider(), SerializableCoder.PROVIDER)); + } + + /** + * Registers standard Coders with this CoderRegistry. + */ + public void registerStandardCoders() { + registerCoder(Byte.class, ByteCoder.class); + registerCoder(ByteString.class, ByteStringCoder.class); + registerCoder(Double.class, DoubleCoder.class); + registerCoder(Instant.class, InstantCoder.class); + registerCoder(Integer.class, VarIntCoder.class); + registerCoder(Iterable.class, IterableCoder.class); + registerCoder(KV.class, KvCoder.class); + registerCoder(List.class, ListCoder.class); + registerCoder(Long.class, VarLongCoder.class); + registerCoder(Map.class, MapCoder.class); + registerCoder(Set.class, SetCoder.class); + registerCoder(String.class, StringUtf8Coder.class); + registerCoder(TableRow.class, TableRowJsonCoder.class); + registerCoder(TimestampedValue.class, TimestampedValue.TimestampedValueCoder.class); + registerCoder(Void.class, VoidCoder.class); + registerCoder(byte[].class, ByteArrayCoder.class); + } + + /** + * Registers {@code coderClazz} as the default {@link Coder} class to handle encoding and + * decoding instances of {@code clazz}, overriding prior registrations if any exist. + * + *

    Supposing {@code T} is the static type corresponding to the {@code clazz}, then + * {@code coderClazz} should have a static factory method with the following signature: + * + *

     {@code
    +   * public static Coder of(Coder argCoder1, Coder argCoder2, ...)
    +   * } 
    + * + *

    This method will be called to create instances of {@code Coder} for values of type + * {@code T}, passing Coders for each of the generic type parameters of {@code T}. If {@code T} + * takes no generic type parameters, then the {@code of()} factory method should have no + * arguments. + * + *
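The sketch below shows the shape such a factory method takes for a type with no generic parameters, so of() needs no Coder arguments. Order and OrderCoder are made-up names; the CustomCoder base class, assumed here to supply defaults for the rest of the Coder interface, is an assumption about the SDK rather than something established by this change, and the field-by-field delegation is just one way to implement encode and decode.

    import com.google.cloud.dataflow.sdk.coders.CoderException;
    import com.google.cloud.dataflow.sdk.coders.CustomCoder;   // assumed convenience base class
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import com.google.cloud.dataflow.sdk.coders.VarLongCoder;

    import java.io.IOException;
    import java.io.InputStream;
    import java.io.OutputStream;

    /** Hypothetical value type, used only for illustration. */
    class Order {
      final String id;
      final long amountCents;

      Order(String id, long amountCents) {
        this.id = id;
        this.amountCents = amountCents;
      }
    }

    /** Hypothetical coder whose static factory method has the shape described above. */
    class OrderCoder extends CustomCoder<Order> {
      // Order has no type parameters, so the factory method takes no Coder arguments.
      public static OrderCoder of() {
        return new OrderCoder();
      }

      private static final StringUtf8Coder ID_CODER = StringUtf8Coder.of();
      private static final VarLongCoder AMOUNT_CODER = VarLongCoder.of();

      @Override
      public void encode(Order value, OutputStream outStream, Context context)
          throws CoderException, IOException {
        // Delegate each field to an existing coder; all but the last field use the nested context.
        ID_CODER.encode(value.id, outStream, context.nested());
        AMOUNT_CODER.encode(value.amountCents, outStream, context);
      }

      @Override
      public Order decode(InputStream inStream, Context context)
          throws CoderException, IOException {
        String id = ID_CODER.decode(inStream, context.nested());
        long amountCents = AMOUNT_CODER.decode(inStream, context);
        return new Order(id, amountCents);
      }
    }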

    If {@code T} is a parameterized type, then it should additionally have a method with the + * following signature: + * + *

     {@code
    +   * public static List getInstanceComponents(T exampleValue);
    +   * } 
    +   *
    +   * 

    This method will be called to decompose a value during the {@link Coder} inference process, + * to automatically choose {@link Coder Coders} for the components. + * + * @param clazz the class of objects to be encoded + * @param coderClazz a class with static factory methods to provide {@link Coder Coders} + */ + public void registerCoder(Class clazz, Class coderClazz) { + registerCoder(clazz, CoderFactories.fromStaticMethods(coderClazz)); + } + + /** + * Registers {@code coderFactory} as the default {@link CoderFactory} to produce {@code Coder} + * instances to decode and encode instances of {@code clazz}. This will override prior + * registrations if any exist. + */ + public void registerCoder(Class clazz, CoderFactory coderFactory) { + coderFactoryMap.put(clazz, coderFactory); + } + + /** + * Register the provided {@link Coder} for encoding all values of the specified {@code Class}. + * This will override prior registrations if any exist. + * + *
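As a usage sketch, reusing the hypothetical Order and OrderCoder from the earlier sketch (they are not SDK classes); since Order has no type parameters, a coder instance can be registered directly:

    import com.google.cloud.dataflow.sdk.coders.CoderRegistry;

    public class RegisterCoderExample {
      public static void main(String[] args) {
        CoderRegistry registry = new CoderRegistry();
        registry.registerStandardCoders();

        // Registers (or overrides) the default coder for the non-generic class Order.
        registry.registerCoder(Order.class, OrderCoder.of());
      }
    }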

    Not for use with generic rawtypes. Instead, register a {@link CoderFactory} via + * {@link #registerCoder(Class, CoderFactory)} or ensure your {@code Coder} class has the + * appropriate static methods and register it directly via {@link #registerCoder(Class, Class)}. + */ + public void registerCoder(Class rawClazz, Coder coder) { + Preconditions.checkArgument( + rawClazz.getTypeParameters().length == 0, + "CoderRegistry.registerCoder(Class, Coder) may not be used " + + "with unspecialized generic classes"); + + CoderFactory factory = CoderFactories.forCoder(coder); + registerCoder(rawClazz, factory); + } + + /** + * Returns the {@link Coder} to use by default for values of the given type. + * + * @throws CannotProvideCoderException if there is no default Coder. + */ + public Coder getDefaultCoder(TypeDescriptor typeDescriptor) + throws CannotProvideCoderException { + return getDefaultCoder(typeDescriptor, Collections.>emptyMap()); + } + + /** + * See {@link #getDefaultCoder(TypeDescriptor)}. + */ + @Override + public Coder getCoder(TypeDescriptor typeDescriptor) + throws CannotProvideCoderException { + return getDefaultCoder(typeDescriptor); + } + + /** + * Returns the {@link Coder} to use by default for values of the given type, where the given input + * type uses the given {@link Coder}. + * + * @throws CannotProvideCoderException if there is no default Coder. + */ + public Coder getDefaultCoder( + TypeDescriptor typeDescriptor, + TypeDescriptor inputTypeDescriptor, + Coder inputCoder) + throws CannotProvideCoderException { + return getDefaultCoder( + typeDescriptor, getTypeToCoderBindings(inputTypeDescriptor.getType(), inputCoder)); + } + + /** + * Returns the {@link Coder} to use on elements produced by this function, given the {@link Coder} + * used for its input elements. + */ + public Coder getDefaultOutputCoder( + SerializableFunction fn, Coder inputCoder) + throws CannotProvideCoderException { + + ParameterizedType fnType = (ParameterizedType) + TypeDescriptor.of(fn.getClass()).getSupertype(SerializableFunction.class).getType(); + + return getDefaultCoder( + fn.getClass(), + SerializableFunction.class, + ImmutableMap.of(fnType.getActualTypeArguments()[0], inputCoder), + SerializableFunction.class.getTypeParameters()[1]); + } + + /** + * Returns the {@link Coder} to use for the specified type parameter specialization of the + * subclass, given {@link Coder Coders} to use for all other type parameters (if any). + * + * @throws CannotProvideCoderException if there is no default Coder. + */ + public Coder getDefaultCoder( + Class subClass, + Class baseClass, + Map> knownCoders, + TypeVariable param) + throws CannotProvideCoderException { + + Map> inferredCoders = getDefaultCoders(subClass, baseClass, knownCoders); + + @SuppressWarnings("unchecked") + Coder paramCoderOrNull = (Coder) inferredCoders.get(param); + if (paramCoderOrNull != null) { + return paramCoderOrNull; + } else { + throw new CannotProvideCoderException( + "Cannot infer coder for type parameter " + param.getName()); + } + } + + /** + * Returns the {@link Coder} to use for the provided example value, if it can be determined. + * + * @throws CannotProvideCoderException if there is no default {@link Coder} or + * more than one {@link Coder} matches + */ + public Coder getDefaultCoder(T exampleValue) throws CannotProvideCoderException { + Class clazz = exampleValue == null ? 
Void.class : exampleValue.getClass(); + + if (clazz.getTypeParameters().length == 0) { + // Trust that getDefaultCoder returns a valid + // Coder for non-generic clazz. + @SuppressWarnings("unchecked") + Coder coder = (Coder) getDefaultCoder(clazz); + return coder; + } else { + CoderFactory factory = getDefaultCoderFactory(clazz); + + List components = factory.getInstanceComponents(exampleValue); + if (components == null) { + throw new CannotProvideCoderException(String.format( + "Cannot provide coder based on value with class %s: The registered CoderFactory with " + + "class %s failed to decompose the value, which is required in order to provide " + + "Coders for the components.", + clazz.getCanonicalName(), factory.getClass().getCanonicalName())); + } + + // componentcoders = components.map(this.getDefaultCoder) + List> componentCoders = new ArrayList<>(); + for (Object component : components) { + try { + Coder componentCoder = getDefaultCoder(component); + componentCoders.add(componentCoder); + } catch (CannotProvideCoderException exc) { + throw new CannotProvideCoderException( + String.format("Cannot provide coder based on value with class %s", + clazz.getCanonicalName()), + exc); + } + } + + // Trust that factory.create maps from valid component Coders + // to a valid Coder. + @SuppressWarnings("unchecked") + Coder coder = (Coder) factory.create(componentCoders); + return coder; + } + } + + /** + * Returns the {@link Coder} to use by default for values of the given class. The following three + * sources for a {@link Coder} will be attempted, in order: + * + *
      + *
    1. A {@link Coder} class registered explicitly via a call to {@link #registerCoder}, + *
    2. A {@link DefaultCoder} annotation on the class, + *
    3. This registry's fallback {@link CoderProvider}, which may be able to generate a + * {@link Coder} for an arbitrary class. + *
    + * + * @throws CannotProvideCoderException if a {@link Coder} cannot be provided + */ + public Coder getDefaultCoder(Class clazz) throws CannotProvideCoderException { + + CannotProvideCoderException factoryException; + try { + CoderFactory coderFactory = getDefaultCoderFactory(clazz); + LOG.debug("Default coder for {} found by factory", clazz); + @SuppressWarnings("unchecked") + Coder coder = (Coder) coderFactory.create(Collections.>emptyList()); + return coder; + } catch (CannotProvideCoderException exc) { + factoryException = exc; + } + + CannotProvideCoderException annotationException; + try { + return getDefaultCoderFromAnnotation(clazz); + } catch (CannotProvideCoderException exc) { + annotationException = exc; + } + + CannotProvideCoderException fallbackException; + if (getFallbackCoderProvider() != null) { + try { + return getFallbackCoderProvider().getCoder(TypeDescriptor.of(clazz)); + } catch (CannotProvideCoderException exc) { + fallbackException = exc; + } + } else { + fallbackException = new CannotProvideCoderException("no fallback CoderProvider configured"); + } + + // Build up the error message and list of causes. + StringBuilder messageBuilder = new StringBuilder() + .append("Unable to provide a default Coder for ").append(clazz.getCanonicalName()) + .append(". Correct one of the following root causes:"); + + messageBuilder + .append("\n Building a Coder using a registered CoderFactory failed: ") + .append(factoryException.getMessage()); + + messageBuilder + .append("\n Building a Coder from the @DefaultCoder annotation failed: ") + .append(annotationException.getMessage()); + + messageBuilder + .append("\n Building a Coder from the fallback CoderProvider failed: ") + .append(fallbackException.getMessage()); + + throw new CannotProvideCoderException(messageBuilder.toString()); + } + + /** + * Sets the fallback {@link CoderProvider} for this registry. If no other method succeeds in + * providing a {@code Coder} for a type {@code T}, then the registry will attempt to create + * a {@link Coder} using this {@link CoderProvider}. + * + *

    By default, this is set to {@link SerializableCoder#PROVIDER}. + * + *

    See {@link #getFallbackCoderProvider}. + */ + public void setFallbackCoderProvider(CoderProvider coderProvider) { + fallbackCoderProvider = coderProvider; + } + + /** + * Returns the fallback {@link CoderProvider} for this registry. + * + *

    See {@link #setFallbackCoderProvider}. + */ + public CoderProvider getFallbackCoderProvider() { + return fallbackCoderProvider; + } + + /** + * Returns a {@code Map} from each of {@code baseClass}'s type parameters to the {@link Coder} to + * use by default for it, in the context of {@code subClass}'s specialization of + * {@code baseClass}. + * + *

    If no {@link Coder} can be inferred for a particular type parameter, then that type variable + * will be absent from the returned {@code Map}. + * + *

    For example, if {@code baseClass} is {@code Map.class}, where {@code Map} has type + * parameters {@code K} and {@code V}, and {@code subClass} extends {@code Map} + * then the result will map the type variable {@code K} to a {@code Coder} and the + * type variable {@code V} to a {@code Coder}. + * + *

    The {@code knownCoders} parameter can be used to provide known {@link Coder Coders} for any + * of the parameters; these will be used to infer the others. + * + *

    Note that inference is attempted for every type variable. For a type + * {@code MyType} inference will be attempted for all of {@code One}, + * {@code Two}, {@code Three}, even if the requester only wants a {@link Coder} for {@code Two}. + * + *

    For this reason {@code getDefaultCoders} (plural) does not throw an exception if a + * {@link Coder} for a particular type variable cannot be inferred, but merely omits the entry + * from the returned {@code Map}. It is the responsibility of the caller (usually + * {@link #getDefaultCoder} to extract the desired coder or throw a + * {@link CannotProvideCoderException} when appropriate. + * + * @param subClass the concrete type whose specializations are being inferred + * @param baseClass the base type, a parameterized class + * @param knownCoders a map corresponding to the set of known {@link Coder Coders} indexed by + * parameter name + * + * @deprecated this method is not part of the public interface and will be made private + */ + @Deprecated + public Map> getDefaultCoders( + Class subClass, + Class baseClass, + Map> knownCoders) { + TypeVariable>[] typeParams = baseClass.getTypeParameters(); + Coder[] knownCodersArray = new Coder[typeParams.length]; + for (int i = 0; i < typeParams.length; i++) { + knownCodersArray[i] = knownCoders.get(typeParams[i]); + } + Coder[] resultArray = getDefaultCoders( + subClass, baseClass, knownCodersArray); + Map> result = new HashMap<>(); + for (int i = 0; i < typeParams.length; i++) { + if (resultArray[i] != null) { + result.put(typeParams[i], resultArray[i]); + } + } + return result; + } + + /** + * Returns an array listing, for each of {@code baseClass}'s type parameters, the {@link Coder} to + * use by default for it, in the context of {@code subClass}'s specialization of + * {@code baseClass}. + * + *

    If a {@link Coder} cannot be inferred for a type variable, its slot in the resulting array + * will be {@code null}. + * + *

    For example, if {@code baseClass} is {@code Map.class}, where {@code Map} has type + * parameters {@code K} and {@code V} in that order, and {@code subClass} extends + * {@code Map} then the result will contain a {@code Coder} and a + * {@code Coder}, in that order. + * + *

The {@code knownCoders} parameter can be used to provide known {@link Coder Coders} for any + * of the type parameters. These will be used to infer the others. If non-null, the length of this + * array must match the number of type parameters of {@code baseClass}, and simply be filled with + * {@code null} values for each type parameter without a known {@link Coder}. + *

Note that inference is attempted for every type variable. For a type + * {@code MyType} inference will be attempted for all of {@code One}, + * {@code Two}, {@code Three}, even if the requester only wants a {@link Coder} for {@code Two}. + *

    For this reason {@code getDefaultCoders} (plural) does not throw an exception if a + * {@link Coder} for a particular type variable cannot be inferred. Instead, it results in a + * {@code null} in the array. It is the responsibility of the caller (usually + * {@link #getDefaultCoder} to extract the desired coder or throw a + * {@link CannotProvideCoderException} when appropriate. + * + * @param subClass the concrete type whose specializations are being inferred + * @param baseClass the base type, a parameterized class + * @param knownCoders an array corresponding to the set of base class type parameters. Each entry + * can be either a {@link Coder} (in which case it will be used for inference) or + * {@code null} (in which case it will be inferred). May be {@code null} to indicate the + * entire set of parameters should be inferred. + * @throws IllegalArgumentException if baseClass doesn't have type parameters or if the length of + * {@code knownCoders} is not equal to the number of type parameters of {@code baseClass}. + */ + private Coder[] getDefaultCoders( + Class subClass, + Class baseClass, + @Nullable Coder[] knownCoders) { + Type type = TypeDescriptor.of(subClass).getSupertype(baseClass).getType(); + if (!(type instanceof ParameterizedType)) { + throw new IllegalArgumentException(type + " is not a ParameterizedType"); + } + ParameterizedType parameterizedType = (ParameterizedType) type; + Type[] typeArgs = parameterizedType.getActualTypeArguments(); + if (knownCoders == null) { + knownCoders = new Coder[typeArgs.length]; + } else if (typeArgs.length != knownCoders.length) { + throw new IllegalArgumentException( + String.format("Class %s has %d parameters, but %d coders are requested.", + baseClass.getCanonicalName(), typeArgs.length, knownCoders.length)); + } + + Map> context = new HashMap<>(); + for (int i = 0; i < knownCoders.length; i++) { + if (knownCoders[i] != null) { + try { + verifyCompatible(knownCoders[i], typeArgs[i]); + } catch (IncompatibleCoderException exn) { + throw new IllegalArgumentException( + String.format("Provided coders for type arguments of %s contain incompatibilities:" + + " Cannot encode elements of type %s with coder %s", + baseClass, + typeArgs[i], knownCoders[i]), + exn); + } + context.putAll(getTypeToCoderBindings(typeArgs[i], knownCoders[i])); + } + } + + Coder[] result = new Coder[typeArgs.length]; + for (int i = 0; i < knownCoders.length; i++) { + if (knownCoders[i] != null) { + result[i] = knownCoders[i]; + } else { + try { + result[i] = getDefaultCoder(typeArgs[i], context); + } catch (CannotProvideCoderException exc) { + result[i] = null; + } + } + } + return result; + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Thrown when a {@link Coder} cannot possibly encode a type, yet has been proposed as a + * {@link Coder} for that type. + */ + @VisibleForTesting static class IncompatibleCoderException extends RuntimeException { + private Coder coder; + private Type type; + + public IncompatibleCoderException(String message, Coder coder, Type type) { + super(message); + this.coder = coder; + this.type = type; + } + + public IncompatibleCoderException(String message, Coder coder, Type type, Throwable cause) { + super(message, cause); + this.coder = coder; + this.type = type; + } + + public Coder getCoder() { + return coder; + } + + public Type getType() { + return type; + } + } + + /** + * Returns {@code true} if the given {@link Coder} can possibly encode elements + * of the given type. 
+ */ + @VisibleForTesting static , CandidateT> + void verifyCompatible(CoderT coder, Type candidateType) throws IncompatibleCoderException { + + // Various representations of the coder's class + @SuppressWarnings("unchecked") + Class coderClass = (Class) coder.getClass(); + TypeDescriptor coderDescriptor = TypeDescriptor.of(coderClass); + + // Various representations of the actual coded type + @SuppressWarnings("unchecked") + TypeDescriptor codedDescriptor = CoderUtils.getCodedType(coderDescriptor); + @SuppressWarnings("unchecked") + Class codedClass = (Class) codedDescriptor.getRawType(); + Type codedType = codedDescriptor.getType(); + + // Various representations of the candidate type + @SuppressWarnings("unchecked") + TypeDescriptor candidateDescriptor = + (TypeDescriptor) TypeDescriptor.of(candidateType); + @SuppressWarnings("unchecked") + Class candidateClass = (Class) candidateDescriptor.getRawType(); + + // If coder has type Coder where the actual value of T is lost + // to erasure, then we cannot rule it out. + if (candidateType instanceof TypeVariable) { + return; + } + + // If the raw types are not compatible, we can certainly rule out + // coder compatibility + if (!codedClass.isAssignableFrom(candidateClass)) { + throw new IncompatibleCoderException( + String.format("Cannot encode elements of type %s with coder %s because the" + + " coded type %s is not assignable from %s", + candidateType, coder, codedClass, candidateType), + coder, candidateType); + } + // we have established that this is a covariant upcast... though + // coders are invariant, we are just checking one direction + @SuppressWarnings("unchecked") + TypeDescriptor candidateOkDescriptor = (TypeDescriptor) candidateDescriptor; + + // If the coded type is a parameterized type where any of the actual + // type parameters are not compatible, then the whole thing is certainly not + // compatible. + if ((codedType instanceof ParameterizedType) && !isNullOrEmpty(coder.getCoderArguments())) { + ParameterizedType parameterizedSupertype = ((ParameterizedType) + candidateOkDescriptor.getSupertype(codedClass).getType()); + Type[] typeArguments = parameterizedSupertype.getActualTypeArguments(); + List> typeArgumentCoders = coder.getCoderArguments(); + if (typeArguments.length < typeArgumentCoders.size()) { + throw new IncompatibleCoderException( + String.format("Cannot encode elements of type %s with coder %s:" + + " the generic supertype %s has %s type parameters, which is less than the" + + " number of coder arguments %s has (%s).", + candidateOkDescriptor, coder, + parameterizedSupertype, typeArguments.length, + coder, typeArgumentCoders.size()), + coder, candidateOkDescriptor.getType()); + } + for (int i = 0; i < typeArgumentCoders.size(); i++) { + try { + verifyCompatible( + typeArgumentCoders.get(i), + candidateDescriptor.resolveType(typeArguments[i]).getType()); + } catch (IncompatibleCoderException exn) { + throw new IncompatibleCoderException( + String.format("Cannot encode elements of type %s with coder %s" + + " because some component coder is incompatible", + candidateType, coder), + coder, candidateType, exn); + } + } + } + } + + private static boolean isNullOrEmpty(Collection c) { + return c == null || c.size() == 0; + } + + /** + * The map of classes to the CoderFactories to use to create their + * default Coders. + */ + private Map, CoderFactory> coderFactoryMap = new HashMap<>(); + + /** + * A provider of coders for types where no coder is registered. 
+ */ + private CoderProvider fallbackCoderProvider; + + /** + * Returns the {@link CoderFactory} to use to create default {@link Coder Coders} for instances of + * the given class, or {@code null} if there is no default {@link CoderFactory} registered. + */ + private CoderFactory getDefaultCoderFactory(Class clazz) throws CannotProvideCoderException { + CoderFactory coderFactoryOrNull = coderFactoryMap.get(clazz); + if (coderFactoryOrNull != null) { + return coderFactoryOrNull; + } else { + throw new CannotProvideCoderException( + String.format("Cannot provide coder based on value with class %s: No CoderFactory has " + + "been registered for the class.", clazz.getCanonicalName())); + } + } + + /** + * Returns the {@link Coder} returned according to the {@link CoderProvider} from any + * {@link DefaultCoder} annotation on the given class. + */ + private Coder getDefaultCoderFromAnnotation(Class clazz) + throws CannotProvideCoderException { + DefaultCoder defaultAnnotation = clazz.getAnnotation(DefaultCoder.class); + if (defaultAnnotation == null) { + throw new CannotProvideCoderException( + String.format("Class %s does not have a @DefaultCoder annotation.", + clazz.getCanonicalName())); + } + + LOG.debug("DefaultCoder annotation found for {} with value {}", + clazz, defaultAnnotation.value()); + CoderProvider coderProvider = CoderProviders.fromStaticMethods(defaultAnnotation.value()); + return coderProvider.getCoder(TypeDescriptor.of(clazz)); + } + + /** + * Returns the {@link Coder} to use by default for values of the given type, + * in a context where the given types use the given coders. + * + * @throws CannotProvideCoderException if a coder cannot be provided + */ + private Coder getDefaultCoder( + TypeDescriptor typeDescriptor, + Map> typeCoderBindings) + throws CannotProvideCoderException { + + Coder defaultCoder = getDefaultCoder(typeDescriptor.getType(), typeCoderBindings); + LOG.debug("Default coder for {}: {}", typeDescriptor, defaultCoder); + @SuppressWarnings("unchecked") + Coder result = (Coder) defaultCoder; + return result; + } + + /** + * Returns the {@link Coder} to use by default for values of the given type, + * in a context where the given types use the given coders. + * + * @throws CannotProvideCoderException if a coder cannot be provided + */ + private Coder getDefaultCoder(Type type, Map> typeCoderBindings) + throws CannotProvideCoderException { + Coder coder = typeCoderBindings.get(type); + if (coder != null) { + return coder; + } + if (type instanceof Class) { + Class clazz = (Class) type; + return getDefaultCoder(clazz); + } else if (type instanceof ParameterizedType) { + return getDefaultCoder((ParameterizedType) type, typeCoderBindings); + } else if (type instanceof TypeVariable || type instanceof WildcardType) { + // No default coder for an unknown generic type. + throw new CannotProvideCoderException( + String.format("Cannot provide a coder for type variable %s" + + " (declared by %s) because the actual type is unknown due to erasure.", + type, + ((TypeVariable) type).getGenericDeclaration()), + ReasonCode.TYPE_ERASURE); + } else { + throw new RuntimeException( + "Internal error: unexpected kind of Type: " + type); + } + } + + /** + * Returns the {@link Coder} to use by default for values of the given + * parameterized type, in a context where the given types use the + * given {@link Coder Coders}. 
+ * + * @throws CannotProvideCoderException if no coder can be provided + */ + private Coder getDefaultCoder( + ParameterizedType type, + Map> typeCoderBindings) + throws CannotProvideCoderException { + + CannotProvideCoderException factoryException; + try { + return getDefaultCoderFromFactory(type, typeCoderBindings); + } catch (CannotProvideCoderException exc) { + factoryException = exc; + } + + CannotProvideCoderException annotationException; + try { + Class rawClazz = (Class) type.getRawType(); + return getDefaultCoderFromAnnotation(rawClazz); + } catch (CannotProvideCoderException exc) { + annotationException = exc; + } + + // Build up the error message and list of causes. + StringBuilder messageBuilder = new StringBuilder() + .append("Unable to provide a default Coder for ").append(type) + .append(". Correct one of the following root causes:"); + + messageBuilder + .append("\n Building a Coder using a registered CoderFactory failed: ") + .append(factoryException.getMessage()); + + messageBuilder + .append("\n Building a Coder from the @DefaultCoder annotation failed: ") + .append(annotationException.getMessage()); + + throw new CannotProvideCoderException(messageBuilder.toString()); + } + + private Coder getDefaultCoderFromFactory( + ParameterizedType type, + Map> typeCoderBindings) + throws CannotProvideCoderException { + Class rawClazz = (Class) type.getRawType(); + CoderFactory coderFactory = getDefaultCoderFactory(rawClazz); + List> typeArgumentCoders = new ArrayList<>(); + for (Type typeArgument : type.getActualTypeArguments()) { + try { + Coder typeArgumentCoder = getDefaultCoder(typeArgument, + typeCoderBindings); + typeArgumentCoders.add(typeArgumentCoder); + } catch (CannotProvideCoderException exc) { + throw new CannotProvideCoderException( + String.format("Cannot provide coder for parameterized type %s: %s", + type, + exc.getMessage()), + exc); + } + } + return coderFactory.create(typeArgumentCoders); + } + + /** + * Returns an immutable {@code Map} from each of the type variables + * embedded in the given type to the corresponding types + * in the given {@link Coder}. + */ + private Map> getTypeToCoderBindings(Type type, Coder coder) { + if (type instanceof TypeVariable || type instanceof Class) { + return ImmutableMap.>of(type, coder); + } else if (type instanceof ParameterizedType) { + return getTypeToCoderBindings((ParameterizedType) type, coder); + } else { + return ImmutableMap.of(); + } + } + + /** + * Returns an immutable {@code Map} from the type arguments of the parameterized type to their + * corresponding {@link Coder Coders}, and so on recursively for their type parameters. + * + *
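A small sketch of the bindings this produces for a nested parameterized type; the concrete coders are assumptions made only for illustration.

Coder<KV<String, List<Integer>>> kvCoder =
    KvCoder.of(StringUtf8Coder.of(), ListCoder.of(VarIntCoder.of()));
// Conceptually, matching the type KV<String, List<Integer>> against kvCoder binds:
//   KV<String, List<Integer>> -> kvCoder
//   String                    -> StringUtf8Coder
//   List<Integer>             -> ListCoder(VarIntCoder)
//   Integer                   -> VarIntCoder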
    This method is simply a specialization to break out the most + * elaborate case of {@link #getTypeToCoderBindings(Type, Coder)}. + */ + private Map> getTypeToCoderBindings(ParameterizedType type, Coder coder) { + List typeArguments = Arrays.asList(type.getActualTypeArguments()); + List> coderArguments = coder.getCoderArguments(); + + if ((coderArguments == null) || (typeArguments.size() != coderArguments.size())) { + return ImmutableMap.of(); + } else { + Map> typeToCoder = Maps.newHashMap(); + + typeToCoder.put(type, coder); + + for (int i = 0; i < typeArguments.size(); i++) { + Type typeArgument = typeArguments.get(i); + Coder coderArgument = coderArguments.get(i); + typeToCoder.putAll(getTypeToCoderBindings(typeArgument, coderArgument)); + } + + return ImmutableMap.>builder().putAll(typeToCoder).build(); + } + + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CollectionCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CollectionCoder.java new file mode 100644 index 000000000000..a028317e673d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CollectionCoder.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collection; +import java.util.List; + +/** + * A {@link CollectionCoder} encodes {@link Collection Collections} in the format + * of {@link IterableLikeCoder}. + */ +public class CollectionCoder extends IterableLikeCoder> { + + public static CollectionCoder of(Coder elemCoder) { + return new CollectionCoder<>(elemCoder); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + /** + * {@inheritDoc} + * + * @return the decoded elements directly, since {@link List} is a subtype of + * {@link Collection}. + */ + @Override + protected final Collection decodeToIterable(List decodedElements) { + return decodedElements; + } + + @JsonCreator + public static CollectionCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + /** + * Returns the first element in this collection if it is non-empty, + * otherwise returns {@code null}. 
+ */ + public static List getInstanceComponents( + Collection exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + protected CollectionCoder(Coder elemCoder) { + super(elemCoder, "Collection"); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CustomCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CustomCoder.java new file mode 100644 index 000000000000..b34ef8cf6dec --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/CustomCoder.java @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.addStringList; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.common.collect.Lists; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.Collection; + +/** + * An abstract base class for writing a {@link Coder} class that encodes itself via Java + * serialization. + * + *
    To complete an implementation, subclasses must implement {@link Coder#encode} + * and {@link Coder#decode} methods. Anonymous subclasses must furthermore override + * {@link #getEncodingId}. + * + *
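A minimal sketch of such a named subclass, assuming a hypothetical Money value type with a currency code and an amount in micros; it simply delegates to coders defined elsewhere in this package.

public class MoneyCoder extends CustomCoder<Money> {
  private static final StringUtf8Coder CURRENCY_CODER = StringUtf8Coder.of();
  private static final VarLongCoder AMOUNT_CODER = VarLongCoder.of();

  @Override
  public void encode(Money value, OutputStream outStream, Context context)
      throws CoderException, IOException {
    // Encode the components in a nested context, mirroring KvCoder.
    CURRENCY_CODER.encode(value.getCurrencyCode(), outStream, context.nested());
    AMOUNT_CODER.encode(value.getAmountMicros(), outStream, context.nested());
  }

  @Override
  public Money decode(InputStream inStream, Context context)
      throws CoderException, IOException {
    String currency = CURRENCY_CODER.decode(inStream, context.nested());
    long amountMicros = AMOUNT_CODER.decode(inStream, context.nested());
    return new Money(currency, amountMicros);  // Money is a hypothetical type.
  }
}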
    Not to be confused with {@link SerializableCoder} that encodes objects that implement the + * {@link Serializable} interface. + * + * @param the type of elements handled by this coder + */ +public abstract class CustomCoder extends AtomicCoder + implements Serializable { + @JsonCreator + public static CustomCoder of( + // N.B. typeId is a required parameter here, since a field named "@type" + // is presented to the deserializer as an input. + // + // If this method did not consume the field, Jackson2 would observe an + // unconsumed field and a returned value of a derived type. So Jackson2 + // would attempt to update the returned value with the unconsumed field + // data, The standard JsonDeserializer does not implement a mechanism for + // updating constructed values, so it would throw an exception, causing + // deserialization to fail. + @JsonProperty(value = "@type", required = false) String typeId, + @JsonProperty(value = "encoding_id", required = false) String encodingId, + @JsonProperty("type") String type, + @JsonProperty("serialized_coder") String serializedCoder) { + return (CustomCoder) SerializableUtils.deserializeFromByteArray( + StringUtils.jsonStringToByteArray(serializedCoder), + type); + } + + /** + * {@inheritDoc} + * + * @return A thin {@link CloudObject} wrapping of the Java serialization of {@code this}. + */ + @Override + public CloudObject asCloudObject() { + // N.B. We use the CustomCoder class, not the derived class, since during + // deserialization we will be using the CustomCoder's static factory method + // to construct an instance of the derived class. + CloudObject result = CloudObject.forClass(CustomCoder.class); + addString(result, "type", getClass().getName()); + addString(result, "serialized_coder", + StringUtils.byteArrayToJsonString( + SerializableUtils.serializeToByteArray(this))); + + String encodingId = getEncodingId(); + checkNotNull(encodingId, "Coder.getEncodingId() must not return null."); + if (!encodingId.isEmpty()) { + addString(result, PropertyNames.ENCODING_ID, encodingId); + } + + Collection allowedEncodings = getAllowedEncodings(); + if (!allowedEncodings.isEmpty()) { + addStringList(result, PropertyNames.ALLOWED_ENCODINGS, Lists.newArrayList(allowedEncodings)); + } + + return result; + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException a {@link CustomCoder} is presumed + * nondeterministic. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "CustomCoder implementations must override verifyDeterministic," + + " or they are presumed nondeterministic."); + } + + /** + * {@inheritDoc} + * + * @return The canonical class name for this coder. For stable data formats that are independent + * of class name, it is recommended to override this method. + * + * @throws UnsupportedOperationException when an anonymous class is used, since they do not have + * a stable canonical class name. + */ + @Override + public String getEncodingId() { + if (getClass().isAnonymousClass()) { + throw new UnsupportedOperationException( + String.format("Anonymous CustomCoder subclass %s must override getEncodingId()." + + " Otherwise, convert to a named class and getEncodingId() will be automatically" + + " generated from the fully qualified class name.", + getClass())); + } + return getClass().getCanonicalName(); + } + + // This coder inherits isRegisterByteSizeObserverCheap, + // getEncodedElementByteSize and registerByteSizeObserver + // from StandardCoder. 
Override if we can do better. +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DefaultCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DefaultCoder.java new file mode 100644 index 000000000000..110579b64c38 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DefaultCoder.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * The {@link DefaultCoder} annotation + * specifies a default {@link Coder} class to handle encoding and decoding + * instances of the annotated class. + * + *
    The specified {@link Coder} must satisfy the requirements of + * {@link CoderProviders#fromStaticMethods}. Two classes provided by the SDK that + * are intended for use with this annotation include {@link SerializableCoder} + * and {@link AvroCoder}. + * + *
    To configure the use of Java serialization as the default + * for a class, annotate the class to use + * {@link SerializableCoder} as follows: + * + *
    {@literal @}DefaultCoder(SerializableCoder.class)
    + * public class MyCustomDataType implements Serializable {
    + *   // ...
    + * }
    + * + *
    Similarly, to configure the use of + * {@link AvroCoder} as the default: + *
    {@literal @}DefaultCoder(AvroCoder.class)
    + * public class MyCustomDataType {
    + *   public MyCustomDataType() {}  // Avro requires an empty constructor.
    + *   // ...
    + * }
    + * + *
    Coders specified explicitly via + * {@link PCollection#setCoder} + * take precedence, followed by Coders registered at runtime via + * {@link CoderRegistry#registerCoder}. See {@link CoderRegistry} for a more detailed discussion + * of the precedence rules. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +@SuppressWarnings("rawtypes") +public @interface DefaultCoder { + Class value(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DelegateCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DelegateCoder.java new file mode 100644 index 000000000000..cdd882b07a19 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DelegateCoder.java @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.common.collect.Lists; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Collection; +import java.util.List; + +/** + * A {@code DelegateCoder} wraps a {@link Coder} for {@code IntermediateT} and + * encodes/decodes values of type {@code T} by converting + * to/from {@code IntermediateT} and then encoding/decoding using the underlying + * {@code Coder}. + * + *
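For instance, a sketch (not part of this change) that codes java.net.URI values by converting them to Strings and delegating to StringUtf8Coder; the variable names are illustrative.

Coder<URI> uriCoder = DelegateCoder.of(
    StringUtf8Coder.of(),
    new DelegateCoder.CodingFunction<URI, String>() {
      @Override
      public String apply(URI uri) throws Exception {
        return uri.toString();
      }
    },
    new DelegateCoder.CodingFunction<String, URI>() {
      @Override
      public URI apply(String encoded) throws Exception {
        return new URI(encoded);
      }
    });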
    The conversions from {@code T} to {@code IntermediateT} and vice versa + * must be supplied as {@link CodingFunction}, a serializable + * function that may throw any {@code Exception}. If a thrown + * exception is an instance of {@link CoderException} or + * {@link IOException}, it will be re-thrown, otherwise it will be wrapped as + * a {@link CoderException}. + * + * @param The type of objects coded by this Coder. + * @param The type of objects a {@code T} will be converted to for coding. + */ +public class DelegateCoder extends CustomCoder { + /** + * A {@link DelegateCoder.CodingFunction CodingFunction<InputT, OutputT>} is a serializable + * function from {@code InputT} to {@code OutputT} that may throw any {@link Exception}. + */ + public static interface CodingFunction extends Serializable { + public abstract OutputT apply(InputT input) throws Exception; + } + + public static DelegateCoder of(Coder coder, + CodingFunction toFn, + CodingFunction fromFn) { + return new DelegateCoder(coder, toFn, fromFn); + } + + @Override + public void encode(T value, OutputStream outStream, Context context) + throws CoderException, IOException { + coder.encode(applyAndWrapExceptions(toFn, value), outStream, context); + } + + @Override + public T decode(InputStream inStream, Context context) throws CoderException, IOException { + return applyAndWrapExceptions(fromFn, coder.decode(inStream, context)); + } + + /** + * Returns the coder used to encode/decode the intermediate values produced/consumed by the + * coding functions of this {@code DelegateCoder}. + */ + public Coder getCoder() { + return coder; + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException when the underlying coder's {@code verifyDeterministic()} + * throws a {@link Coder.NonDeterministicException}. For this to be safe, the + * intermediate {@code CodingFunction} must also be deterministic. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + coder.verifyDeterministic(); + } + + /** + * {@inheritDoc} + * + * @return a structural for a value of type {@code T} obtained by first converting to + * {@code IntermediateT} and then obtaining a structural value according to the underlying + * coder. + */ + @Override + public Object structuralValue(T value) throws Exception { + return coder.structuralValue(toFn.apply(value)); + } + + @Override + public String toString() { + return "DelegateCoder(" + coder + ")"; + } + + /** + * {@inheritDoc} + * + * @return a {@link String} composed from the underlying coder class name and its encoding id. + * Note that this omits any description of the coding functions. These should be modified + * with care. + */ + @Override + public String getEncodingId() { + return delegateEncodingId(coder.getClass(), coder.getEncodingId()); + } + + /** + * {@inheritDoc} + * + * @return allowed encodings which are composed from the underlying coder class and its allowed + * encoding ids. Note that this omits any description of the coding functions. These + * should be modified with care. 
+ */ + @Override + public Collection getAllowedEncodings() { + List allowedEncodings = Lists.newArrayList(); + for (String allowedEncoding : coder.getAllowedEncodings()) { + allowedEncodings.add(delegateEncodingId(coder.getClass(), allowedEncoding)); + } + return allowedEncodings; + } + + private String delegateEncodingId(Class delegateClass, String encodingId) { + return String.format("%s:%s", delegateClass.getName(), encodingId); + } + + ///////////////////////////////////////////////////////////////////////////// + + private OutputT applyAndWrapExceptions( + CodingFunction fn, + InputT input) throws CoderException, IOException { + try { + return fn.apply(input); + } catch (IOException exc) { + throw exc; + } catch (Exception exc) { + throw new CoderException(exc); + } + } + + private final Coder coder; + private final CodingFunction toFn; + private final CodingFunction fromFn; + + protected DelegateCoder(Coder coder, + CodingFunction toFn, + CodingFunction fromFn) { + this.coder = coder; + this.fromFn = fromFn; + this.toFn = toFn; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DeterministicStandardCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DeterministicStandardCoder.java new file mode 100644 index 000000000000..0e0018afd4d4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DeterministicStandardCoder.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +/** + * A {@link DeterministicStandardCoder} is a {@link StandardCoder} that is + * deterministic, in the sense that for objects considered equal + * according to {@link Object#equals(Object)}, the encoded bytes are + * also equal. + * + * @param the type of the values being transcoded + */ +public abstract class DeterministicStandardCoder extends StandardCoder { + protected DeterministicStandardCoder() {} + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException never, unless overridden. A + * {@link DeterministicStandardCoder} is presumed deterministic. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DoubleCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DoubleCoder.java new file mode 100644 index 000000000000..68d58df10259 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DoubleCoder.java @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A {@link DoubleCoder} encodes {@link Double} values in 8 bytes using Java serialization. + */ +public class DoubleCoder extends AtomicCoder { + + @JsonCreator + public static DoubleCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final DoubleCoder INSTANCE = new DoubleCoder(); + + private DoubleCoder() {} + + @Override + public void encode(Double value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Double"); + } + new DataOutputStream(outStream).writeDouble(value); + } + + @Override + public Double decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new DataInputStream(inStream).readDouble(); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. + * Floating-point operations are not guaranteed to be deterministic, even + * if the storage format might be, so floating point representations are not + * recommended for use in operations that require deterministic inputs. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "Floating point encodings are not guaranteed to be deterministic."); + } + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link DoubleCoder#getEncodedElementByteSize} returns a constant. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Double value, Context context) { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code 8}, the byte size of a {@link Double} encoded using Java serialization. + */ + @Override + protected long getEncodedElementByteSize(Double value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Double"); + } + return 8; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DurationCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DurationCoder.java new file mode 100644 index 000000000000..25527f05df21 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/DurationCoder.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.joda.time.Duration; +import org.joda.time.ReadableDuration; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} that encodes a joda {@link Duration} as a {@link Long} using the format of + * {@link VarLongCoder}. + */ +public class DurationCoder extends AtomicCoder { + + @JsonCreator + public static DurationCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final DurationCoder INSTANCE = new DurationCoder(); + + private final VarLongCoder longCoder = VarLongCoder.of(); + + private DurationCoder() {} + + private Long toLong(ReadableDuration value) { + return value.getMillis(); + } + + private ReadableDuration fromLong(Long decoded) { + return Duration.millis(decoded); + } + + @Override + public void encode(ReadableDuration value, OutputStream outStream, Context context) + throws CoderException, IOException { + if (value == null) { + throw new CoderException("cannot encode a null ReadableDuration"); + } + longCoder.encode(toLong(value), outStream, context); + } + + @Override + public ReadableDuration decode(InputStream inStream, Context context) + throws CoderException, IOException { + return fromLong(longCoder.decode(inStream, context)); + } + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}, because it is cheap to ascertain the byte size of a long. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(ReadableDuration value, Context context) { + return longCoder.isRegisterByteSizeObserverCheap(toLong(value), context); + } + + @Override + public void registerByteSizeObserver( + ReadableDuration value, ElementByteSizeObserver observer, Context context) throws Exception { + longCoder.registerByteSizeObserver(toLong(value), observer, context); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/EntityCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/EntityCoder.java new file mode 100644 index 000000000000..3ae857f065ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/EntityCoder.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.datastore.DatastoreV1.Entity; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} for {@link Entity} objects based on their encoded Protocol Buffer form. + */ +public class EntityCoder extends AtomicCoder { + + @JsonCreator + public static EntityCoder of() { + return INSTANCE; + } + + /***************************/ + + private static final EntityCoder INSTANCE = new EntityCoder(); + + private EntityCoder() {} + + @Override + public void encode(Entity value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Entity"); + } + + // Since Entity implements com.google.protobuf.MessageLite, + // we could directly use writeTo to write to a OutputStream object + outStream.write(java.nio.ByteBuffer.allocate(4).putInt(value.getSerializedSize()).array()); + value.writeTo(outStream); + outStream.flush(); + } + + @Override + public Entity decode(InputStream inStream, Context context) + throws IOException { + byte[] entitySize = new byte[4]; + inStream.read(entitySize, 0, 4); + int size = java.nio.ByteBuffer.wrap(entitySize).getInt(); + byte[] data = new byte[size]; + inStream.read(data, 0, size); + return Entity.parseFrom(data); + } + + @Override + protected long getEncodedElementByteSize(Entity value, Context context) + throws Exception { + return value.getSerializedSize(); + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. + * A datastore kind can hold arbitrary {@link Object} instances, which + * makes the encoding non-deterministic. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "Datastore encodings can hold arbitrary Object instances"); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/InstantCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/InstantCoder.java new file mode 100644 index 000000000000..99b58ce4197b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/InstantCoder.java @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.base.Converter; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} for joda {@link Instant} that encodes it as a big endian {@link Long} + * shifted such that lexicographic ordering of the bytes corresponds to chronological order. + */ +public class InstantCoder extends AtomicCoder { + + @JsonCreator + public static InstantCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final InstantCoder INSTANCE = new InstantCoder(); + + private final BigEndianLongCoder longCoder = BigEndianLongCoder.of(); + + private InstantCoder() {} + + /** + * Converts {@link Instant} to a {@code Long} representing its millis-since-epoch, + * but shifted so that the byte representation of negative values are lexicographically + * ordered before the byte representation of positive values. + * + *
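A small sketch of the effect, assuming the CoderUtils.encodeToByteArray helper and Guava's UnsignedBytes, neither of which is shown in this change; the timestamps are arbitrary.

Instant early = new Instant(-1000L);  // one second before the epoch
Instant late = new Instant(1000L);    // one second after the epoch
byte[] earlyBytes = CoderUtils.encodeToByteArray(InstantCoder.of(), early);
byte[] lateBytes = CoderUtils.encodeToByteArray(InstantCoder.of(), late);
// With the Long.MIN_VALUE shift, unsigned lexicographic comparison of the big
// endian bytes matches chronological order; without the shift, the sign bit
// would make pre-epoch instants sort after post-epoch ones.
assert UnsignedBytes.lexicographicalComparator().compare(earlyBytes, lateBytes) < 0;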
    This deliberately utilizes the well-defined overflow for {@code Long} values. + * See http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.18.2 + */ + private static final Converter ORDER_PRESERVING_CONVERTER = + new Converter() { + + @Override + protected Long doForward(Instant instant) { + return instant.getMillis() - Long.MIN_VALUE; + } + + @Override + protected Instant doBackward(Long shiftedMillis) { + return new Instant(shiftedMillis + Long.MIN_VALUE); + } + }; + + @Override + public void encode(Instant value, OutputStream outStream, Context context) + throws CoderException, IOException { + if (value == null) { + throw new CoderException("cannot encode a null Instant"); + } + longCoder.encode(ORDER_PRESERVING_CONVERTER.convert(value), outStream, context); + } + + @Override + public Instant decode(InputStream inStream, Context context) + throws CoderException, IOException { + return ORDER_PRESERVING_CONVERTER.reverse().convert(longCoder.decode(inStream, context)); + } + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. The byte size for a big endian long is a constant. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Instant value, Context context) { + return longCoder.isRegisterByteSizeObserverCheap( + ORDER_PRESERVING_CONVERTER.convert(value), context); + } + + @Override + public void registerByteSizeObserver( + Instant value, ElementByteSizeObserver observer, Context context) throws Exception { + longCoder.registerByteSizeObserver( + ORDER_PRESERVING_CONVERTER.convert(value), observer, context); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableCoder.java new file mode 100644 index 000000000000..70dcd84e5222 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableCoder.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * An {@link IterableCoder} encodes any {@link Iterable} in the format + * of {@link IterableLikeCoder}. 
+ * + * @param the type of the elements of the iterables being transcoded + */ +public class IterableCoder extends IterableLikeCoder> { + + public static IterableCoder of(Coder elemCoder) { + return new IterableCoder<>(elemCoder); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + @Override + protected final Iterable decodeToIterable(List decodedElements) { + return decodedElements; + } + + @JsonCreator + public static IterableCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of(components.get(0)); + } + + /** + * Returns the first element in this iterable if it is non-empty, + * otherwise returns {@code null}. + */ + public static List getInstanceComponents( + Iterable exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + protected IterableCoder(Coder elemCoder) { + super(elemCoder, "Iterable"); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_STREAM_LIKE, true); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableLikeCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableLikeCoder.java new file mode 100644 index 000000000000..7fb573a5c65c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/IterableLikeCoder.java @@ -0,0 +1,278 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.BufferedElementCountingOutputStream; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservableIterable; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.base.Preconditions; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Observable; +import java.util.Observer; + +/** + * An abstract base class with functionality for assembling a + * {@link Coder} for a class that implements {@code Iterable}. + * + *
    To complete a subclass, implement the {@link #decodeToIterable} method. This superclass + * will decode the elements in the input stream into a {@link List} and then pass them to that + * method to be converted into the appropriate iterable type. Note that this means the input + * iterables must fit into memory. + * + *
    The format of this coder is as follows: + * + *
• If the input {@link Iterable} has a known and finite size, then the size is written to the + * output stream in big endian format, followed by all of the encoded elements.
• If the input {@link Iterable} is not known to have a finite size, then the size is recorded as -1 + * and the elements are written in batches, each batch preceded by its element count encoded as a + * variable-length integer and terminated by a final count of zero (both layouts are contrasted + * in the sketch below).
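The sketch below contrasts the two layouts, assuming the CoderUtils.encodeToByteArray helper used elsewhere in this SDK; the element values and variable names are arbitrary.

final List<Integer> sized = Arrays.asList(1, 2, 3);
Iterable<Integer> unsized = new Iterable<Integer>() {
  @Override
  public Iterator<Integer> iterator() {
    return sized.iterator();  // same elements, but the size is not exposed
  }
};
IterableCoder<Integer> coder = IterableCoder.of(VarIntCoder.of());
// Collection input: 4-byte big endian size 0x00000003, then the encoded elements.
byte[] withSize = CoderUtils.encodeToByteArray(coder, sized);
// Non-Collection input: 4-byte marker 0xFFFFFFFF (-1), then count-prefixed batches
// of encoded elements, terminated by a count of zero.
byte[] withoutSize = CoderUtils.encodeToByteArray(coder, unsized);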
    + * + * @param the type of the elements of the {@code Iterable}s being transcoded + * @param the type of the Iterables being transcoded + */ +public abstract class IterableLikeCoder> + extends StandardCoder { + public Coder getElemCoder() { + return elementCoder; + } + + /** + * Builds an instance of {@code IterableT}, this coder's associated {@link Iterable}-like + * subtype, from a list of decoded elements. + */ + protected abstract IterableT decodeToIterable(List decodedElements); + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + private final Coder elementCoder; + private final String iterableName; + + /** + * Returns the first element in the iterable-like {@code exampleValue} if it is non-empty, + * otherwise returns {@code null}. + */ + protected static > + List getInstanceComponentsHelper(IterableT exampleValue) { + for (T value : exampleValue) { + return Arrays.asList(value); + } + return null; + } + + protected IterableLikeCoder(Coder elementCoder, String iterableName) { + Preconditions.checkArgument(elementCoder != null, + "element Coder for IterableLikeCoder must not be null"); + Preconditions.checkArgument(iterableName != null, + "iterable name for IterableLikeCoder must not be null"); + this.elementCoder = elementCoder; + this.iterableName = iterableName; + } + + @Override + public void encode( + IterableT iterable, OutputStream outStream, Context context) + throws IOException, CoderException { + if (iterable == null) { + throw new CoderException("cannot encode a null " + iterableName); + } + Context nestedContext = context.nested(); + DataOutputStream dataOutStream = new DataOutputStream(outStream); + if (iterable instanceof Collection) { + // We can know the size of the Iterable. Use an encoding with a + // leading size field, followed by that many elements. + Collection collection = (Collection) iterable; + dataOutStream.writeInt(collection.size()); + for (T elem : collection) { + elementCoder.encode(elem, dataOutStream, nestedContext); + } + } else { + // We don't know the size without traversing it so use a fixed size buffer + // and encode as many elements as possible into it before outputting the size followed + // by the elements. + dataOutStream.writeInt(-1); + BufferedElementCountingOutputStream countingOutputStream = + new BufferedElementCountingOutputStream(dataOutStream); + for (T elem : iterable) { + countingOutputStream.markElementStart(); + elementCoder.encode(elem, countingOutputStream, nestedContext); + } + countingOutputStream.finish(); + } + // Make sure all our output gets pushed to the underlying outStream. + dataOutStream.flush(); + } + + @Override + public IterableT decode(InputStream inStream, Context context) + throws IOException, CoderException { + Context nestedContext = context.nested(); + DataInputStream dataInStream = new DataInputStream(inStream); + int size = dataInStream.readInt(); + if (size >= 0) { + List elements = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + elements.add(elementCoder.decode(dataInStream, nestedContext)); + } + return decodeToIterable(elements); + } else { + List elements = new ArrayList<>(); + long count; + // We don't know the size a priori. Check if we're done with + // each block of elements. 
+ while ((count = VarInt.decodeLong(dataInStream)) > 0) { + while (count > 0) { + elements.add(elementCoder.decode(dataInStream, nestedContext)); + count -= 1; + } + } + return decodeToIterable(elements); + } + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(elementCoder); + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. + * Encoding is not deterministic for the general {@link Iterable} case, as it depends + * upon the type of iterable. This may allow two objects to compare as equal + * while the encoding differs. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "IterableLikeCoder can not guarantee deterministic ordering."); + } + + /** + * {@inheritDoc} + * + * @return {@code true} if the iterable is of a known class that supports lazy counting + * of byte size, since that requires minimal extra computation. + */ + @Override + public boolean isRegisterByteSizeObserverCheap( + IterableT iterable, Context context) { + return iterable instanceof ElementByteSizeObservableIterable; + } + + @Override + public void registerByteSizeObserver( + IterableT iterable, ElementByteSizeObserver observer, Context context) + throws Exception { + if (iterable == null) { + throw new CoderException("cannot encode a null Iterable"); + } + Context nestedContext = context.nested(); + + if (iterable instanceof ElementByteSizeObservableIterable) { + observer.setLazy(); + ElementByteSizeObservableIterable observableIterable = + (ElementByteSizeObservableIterable) iterable; + observableIterable.addObserver( + new IteratorObserver(observer, iterable instanceof Collection)); + } else { + if (iterable instanceof Collection) { + // We can know the size of the Iterable. Use an encoding with a + // leading size field, followed by that many elements. + Collection collection = (Collection) iterable; + observer.update(4L); + for (T elem : collection) { + elementCoder.registerByteSizeObserver(elem, observer, nestedContext); + } + } else { + // TODO: Update to use an accurate count depending on size and count, currently we + // are under estimating the size by up to 10 bytes per block of data since we are + // not encoding the count prefix which occurs at most once per 64k of data and is upto + // 10 bytes long. Since we include the total count we can upper bound the underestimate + // to be 10 / 65536 ~= 0.0153% of the actual size. + observer.update(4L); + long count = 0; + for (T elem : iterable) { + count += 1; + elementCoder.registerByteSizeObserver(elem, observer, nestedContext); + } + if (count > 0) { + // Update the length based upon the number of counted elements, this helps + // eliminate the case where all the elements are encoded in the first block and + // it is quite short (e.g. Long.MAX_VALUE nulls encoded with VoidCoder). + observer.update(VarInt.getLength(count)); + } + // Update with the terminator byte. + observer.update(1L); + } + } + } + + /** + * An observer that gets notified when an observable iterator + * returns a new value. This observer just notifies an outerObserver + * about this event. Additionally, the outerObserver is notified + * about additional separators that are transparently added by this + * coder. 
+ */ + private class IteratorObserver implements Observer { + private final ElementByteSizeObserver outerObserver; + private final boolean countable; + + public IteratorObserver(ElementByteSizeObserver outerObserver, + boolean countable) { + this.outerObserver = outerObserver; + this.countable = countable; + + if (countable) { + // Additional 4 bytes are due to size. + outerObserver.update(4L); + } else { + // Additional 5 bytes are due to size = -1 (4 bytes) and + // hasNext = false (1 byte). + outerObserver.update(5L); + } + } + + @Override + public void update(Observable obs, Object obj) { + if (!(obj instanceof Long)) { + throw new AssertionError("unexpected parameter object"); + } + + if (countable) { + outerObserver.update(obs, obj); + } else { + // Additional 1 byte is due to hasNext = true flag. + outerObserver.update(obs, 1 + (long) obj); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java new file mode 100644 index 000000000000..2b0190b5f3d7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.Structs; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.FilterInputStream; +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import javax.xml.bind.JAXBContext; +import javax.xml.bind.JAXBException; +import javax.xml.bind.Marshaller; +import javax.xml.bind.Unmarshaller; + +/** + * A coder for JAXB annotated objects. This coder uses JAXB marshalling/unmarshalling mechanisms + * to encode/decode the objects. Users must provide the {@code Class} of the JAXB annotated object. + * + * @param type of JAXB annotated objects that will be serialized. + */ +public class JAXBCoder extends AtomicCoder { + + private final Class jaxbClass; + private transient Marshaller jaxbMarshaller = null; + private transient Unmarshaller jaxbUnmarshaller = null; + + public Class getJAXBClass() { + return jaxbClass; + } + + private JAXBCoder(Class jaxbClass) { + this.jaxbClass = jaxbClass; + } + + /** + * Create a coder for a given type of JAXB annotated objects. + * + * @param jaxbClass the {@code Class} of the JAXB annotated objects. 
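A sketch with a hypothetical JAXB-annotated type; JAXB requires a no-arg constructor, and the class and field names here are illustrative only.

@XmlRootElement
public class PurchaseRecord {
  public String item;
  public long amountCents;
  public PurchaseRecord() {}  // required by JAXB
}

// Elsewhere, e.g. when choosing a coder for a PCollection<PurchaseRecord>:
JAXBCoder<PurchaseRecord> recordCoder = JAXBCoder.of(PurchaseRecord.class);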
+ */ + public static JAXBCoder of(Class jaxbClass) { + return new JAXBCoder<>(jaxbClass); + } + + @Override + public void encode(T value, OutputStream outStream, Context context) + throws CoderException, IOException { + try { + if (jaxbMarshaller == null) { + JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); + jaxbMarshaller = jaxbContext.createMarshaller(); + } + + jaxbMarshaller.marshal(value, new FilterOutputStream(outStream) { + // JAXB closes the underyling stream so we must filter out those calls. + @Override + public void close() throws IOException { + } + }); + } catch (JAXBException e) { + throw new CoderException(e); + } + } + + @Override + public T decode(InputStream inStream, Context context) throws CoderException, IOException { + try { + if (jaxbUnmarshaller == null) { + JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); + jaxbUnmarshaller = jaxbContext.createUnmarshaller(); + } + + @SuppressWarnings("unchecked") + T obj = (T) jaxbUnmarshaller.unmarshal(new FilterInputStream(inStream) { + // JAXB closes the underyling stream so we must filter out those calls. + @Override + public void close() throws IOException { + } + }); + return obj; + } catch (JAXBException e) { + throw new CoderException(e); + } + } + + @Override + public String getEncodingId() { + return getJAXBClass().getName(); + } + + //////////////////////////////////////////////////////////////////////////////////// + // JSON Serialization details below + + private static final String JAXB_CLASS = "jaxb_class"; + + /** + * Constructor for JSON deserialization only. + */ + @JsonCreator + public static JAXBCoder of( + @JsonProperty(JAXB_CLASS) String jaxbClassName) { + try { + @SuppressWarnings("unchecked") + Class jaxbClass = (Class) Class.forName(jaxbClassName); + return of(jaxbClass); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + Structs.addString(result, JAXB_CLASS, jaxbClass.getName()); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoder.java new file mode 100644 index 000000000000..33085cf2af1b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoder.java @@ -0,0 +1,162 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; + +/** + * A {@code KvCoder} encodes {@link KV}s. + * + * @param the type of the keys of the KVs being transcoded + * @param the type of the values of the KVs being transcoded + */ +public class KvCoder extends KvCoderBase> { + public static KvCoder of(Coder keyCoder, + Coder valueCoder) { + return new KvCoder<>(keyCoder, valueCoder); + } + + @JsonCreator + public static KvCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + return of(components.get(0), components.get(1)); + } + + public static List getInstanceComponents( + KV exampleValue) { + return Arrays.asList( + exampleValue.getKey(), + exampleValue.getValue()); + } + + public Coder getKeyCoder() { + return keyCoder; + } + + public Coder getValueCoder() { + return valueCoder; + } + + ///////////////////////////////////////////////////////////////////////////// + + private final Coder keyCoder; + private final Coder valueCoder; + + private KvCoder(Coder keyCoder, Coder valueCoder) { + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + } + + @Override + public void encode(KV kv, OutputStream outStream, Context context) + throws IOException, CoderException { + if (kv == null) { + throw new CoderException("cannot encode a null KV"); + } + Context nestedContext = context.nested(); + keyCoder.encode(kv.getKey(), outStream, nestedContext); + valueCoder.encode(kv.getValue(), outStream, nestedContext); + } + + @Override + public KV decode(InputStream inStream, Context context) + throws IOException, CoderException { + Context nestedContext = context.nested(); + K key = keyCoder.decode(inStream, nestedContext); + V value = valueCoder.decode(inStream, nestedContext); + return KV.of(key, value); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(keyCoder, valueCoder); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic("Key coder must be deterministic", getKeyCoder()); + verifyDeterministic("Value coder must be deterministic", getValueCoder()); + } + + @Override + public boolean consistentWithEquals() { + return keyCoder.consistentWithEquals() && valueCoder.consistentWithEquals(); + } + + @Override + public Object structuralValue(KV kv) throws Exception { + if (consistentWithEquals()) { + return kv; + } else { + return KV.of(getKeyCoder().structuralValue(kv.getKey()), + getValueCoder().structuralValue(kv.getValue())); + } + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_PAIR_LIKE, true); + return result; + } + + /** + * Returns whether both keyCoder and valueCoder are considered not expensive. 
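A sketch of a round trip through this coder, assuming the CoderUtils encode/decode helpers available elsewhere in this SDK; the key and value are arbitrary.

KvCoder<String, Long> coder = KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of());
byte[] bytes = CoderUtils.encodeToByteArray(coder, KV.of("user-42", 1234L));
KV<String, Long> roundTripped = CoderUtils.decodeFromByteArray(coder, bytes);
// roundTripped.getKey() is "user-42" and roundTripped.getValue() is 1234L.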
+ */ + @Override + public boolean isRegisterByteSizeObserverCheap(KV kv, Context context) { + return keyCoder.isRegisterByteSizeObserverCheap(kv.getKey(), + context.nested()) + && valueCoder.isRegisterByteSizeObserverCheap(kv.getValue(), + context.nested()); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + KV kv, ElementByteSizeObserver observer, Context context) + throws Exception { + if (kv == null) { + throw new CoderException("cannot encode a null KV"); + } + keyCoder.registerByteSizeObserver( + kv.getKey(), observer, context.nested()); + valueCoder.registerByteSizeObserver( + kv.getValue(), observer, context.nested()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoderBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoderBase.java new file mode 100644 index 000000000000..4a12ee0d9663 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/KvCoderBase.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * A abstract base class for KvCoder. Works around a Jackson2 bug tickled when building + * {@link KvCoder} directly (as of this writing, Jackson2 walks off the end of + * an array when it tries to deserialize a class with multiple generic type + * parameters). This class should be removed when possible. + * + * @param the type of values being transcoded + */ +@Deprecated +public abstract class KvCoderBase extends StandardCoder { + /** + * A constructor used only for decoding from JSON. + * + * @param typeId present in the JSON encoding, but unused + * @param isPairLike present in the JSON encoding, but unused + */ + @Deprecated + @JsonCreator + public static KvCoderBase of( + // N.B. typeId is a required parameter here, since a field named "@type" + // is presented to the deserializer as an input. + // + // If this method did not consume the field, Jackson2 would observe an + // unconsumed field and a returned value of a derived type. So Jackson2 + // would attempt to update the returned value with the unconsumed field + // data. The standard JsonDeserializer does not implement a mechanism for + // updating constructed values, so it would throw an exception, causing + // deserialization to fail. 
+ @JsonProperty(value = "@type", required = false) String typeId, + @JsonProperty(value = PropertyNames.IS_PAIR_LIKE, required = false) boolean isPairLike, + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) List> components) { + return KvCoder.of(components); + } + + protected KvCoderBase() {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ListCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ListCoder.java new file mode 100644 index 000000000000..bc74404c5053 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/ListCoder.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * A {@link Coder} for {@link List}, using the format of {@link IterableLikeCoder}. + * + * @param the type of the elements of the Lists being transcoded + */ +public class ListCoder extends IterableLikeCoder> { + + public static ListCoder of(Coder elemCoder) { + return new ListCoder<>(elemCoder); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. + + @Override + protected final List decodeToIterable(List decodedElements) { + return decodedElements; + } + + @JsonCreator + public static ListCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + /** + * Returns the first element in this list if it is non-empty, + * otherwise returns {@code null}. + */ + public static List getInstanceComponents(List exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + protected ListCoder(Coder elemCoder) { + super(elemCoder, "List"); + } + + /** + * List sizes are always known, so ListIterable may be deterministic while + * the general IterableLikeCoder is not. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "ListCoder.elemCoder must be deterministic", getElemCoder()); + } + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoder.java new file mode 100644 index 000000000000..b6f31030e41f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoder.java @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * A {@link Coder} for {@link Map Maps} that encodes them according to provided + * coders for keys and values. + * + * @param the type of the keys of the KVs being transcoded + * @param the type of the values of the KVs being transcoded + */ +public class MapCoder extends MapCoderBase> { + /** + * Produces a MapCoder with the given keyCoder and valueCoder. + */ + public static MapCoder of( + Coder keyCoder, + Coder valueCoder) { + return new MapCoder<>(keyCoder, valueCoder); + } + + @JsonCreator + public static MapCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + return of((Coder) components.get(0), (Coder) components.get(1)); + } + + /** + * Returns the key and value for an arbitrary element of this map, + * if it is non-empty, otherwise returns {@code null}. 
+ */ + public static List getInstanceComponents( + Map exampleValue) { + for (Map.Entry entry : exampleValue.entrySet()) { + return Arrays.asList(entry.getKey(), entry.getValue()); + } + return null; + } + + public Coder getKeyCoder() { + return keyCoder; + } + + public Coder getValueCoder() { + return valueCoder; + } + + ///////////////////////////////////////////////////////////////////////////// + + Coder keyCoder; + Coder valueCoder; + + MapCoder(Coder keyCoder, Coder valueCoder) { + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + } + + @Override + public void encode( + Map map, + OutputStream outStream, + Context context) + throws IOException, CoderException { + if (map == null) { + throw new CoderException("cannot encode a null Map"); + } + DataOutputStream dataOutStream = new DataOutputStream(outStream); + dataOutStream.writeInt(map.size()); + for (Entry entry : map.entrySet()) { + keyCoder.encode(entry.getKey(), outStream, context.nested()); + valueCoder.encode(entry.getValue(), outStream, context.nested()); + } + dataOutStream.flush(); + } + + @Override + public Map decode(InputStream inStream, Context context) + throws IOException, CoderException { + DataInputStream dataInStream = new DataInputStream(inStream); + int size = dataInStream.readInt(); + Map retval = Maps.newHashMapWithExpectedSize(size); + for (int i = 0; i < size; ++i) { + K key = keyCoder.decode(inStream, context.nested()); + V value = valueCoder.decode(inStream, context.nested()); + retval.put(key, value); + } + return retval; + } + + /** + * {@inheritDoc} + * + * @return a {@link List} containing the key coder at index 0 at the and value coder at index 1. + */ + @Override + public List> getCoderArguments() { + return Arrays.asList(keyCoder, valueCoder); + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. Not all maps have a deterministic encoding. + * For example, {@code HashMap} comparison does not depend on element order, so + * two {@code HashMap} instances may be equal but produce different encodings. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "Ordering of entries in a Map may be non-deterministic."); + } + + @Override + public void registerByteSizeObserver( + Map map, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(4L); + for (Entry entry : map.entrySet()) { + keyCoder.registerByteSizeObserver( + entry.getKey(), observer, context.nested()); + valueCoder.registerByteSizeObserver( + entry.getValue(), observer, context.nested()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoderBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoderBase.java new file mode 100644 index 000000000000..d32406c50d3a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/MapCoderBase.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * A abstract base class for MapCoder. Works around a Jackson2 bug tickled when building + * {@link MapCoder} directly (as of this writing, Jackson2 walks off the end of + * an array when it tries to deserialize a class with multiple generic type + * parameters). This should be removed in favor of a better workaround. + * @param the type of values being transcoded + */ +@Deprecated +public abstract class MapCoderBase extends StandardCoder { + @Deprecated + @JsonCreator + public static MapCoderBase of( + // N.B. typeId is a required parameter here, since a field named "@type" + // is presented to the deserializer as an input. + // + // If this method did not consume the field, Jackson2 would observe an + // unconsumed field and a returned value of a derived type. So Jackson2 + // would attempt to update the returned value with the unconsumed field + // data, The standard JsonDeserializer does not implement a mechanism for + // updating constructed values, so it would throw an exception, causing + // deserialization to fail. + @JsonProperty(value = "@type", required = false) String typeId, + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + return MapCoder.of(components); + } + + protected MapCoderBase() {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/NullableCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/NullableCoder.java new file mode 100644 index 000000000000..5598a71b0501 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/NullableCoder.java @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A {@link NullableCoder} encodes nullable values of type {@code T} using a nested + * {@code Coder} that does not tolerate {@code null} values. {@link NullableCoder} uses + * exactly 1 byte per entry to indicate whether the value is {@code null}, then adds the encoding + * of the inner coder for non-null values. 
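+ *
+ * <p>A minimal usage sketch, not part of the original file ({@code MyRecord} is a placeholder
+ * type assumed to implement {@link java.io.Serializable}):
+ *
+ * <pre>
+ * {@code
+ * // Wrap an existing coder so that null values become encodable as well.
+ * NullableCoder<MyRecord> coder =
+ *     NullableCoder.of(SerializableCoder.of(MyRecord.class));
+ * }
+ * </pre>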
+ * + * @param the type of the values being transcoded + */ +public class NullableCoder extends StandardCoder { + public static NullableCoder of(Coder valueCoder) { + return new NullableCoder<>(valueCoder); + } + + @JsonCreator + public static NullableCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 components, got " + components.size()); + return of(components.get(0)); + } + + ///////////////////////////////////////////////////////////////////////////// + + private final Coder valueCoder; + private static final int ENCODE_NULL = 0; + private static final int ENCODE_PRESENT = 1; + + private NullableCoder(Coder valueCoder) { + this.valueCoder = valueCoder; + } + + @Override + public void encode(@Nullable T value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + outStream.write(ENCODE_NULL); + } else { + outStream.write(ENCODE_PRESENT); + valueCoder.encode(value, outStream, context.nested()); + } + } + + @Override + @Nullable + public T decode(InputStream inStream, Context context) throws IOException, CoderException { + int b = inStream.read(); + if (b == ENCODE_NULL) { + return null; + } else if (b != ENCODE_PRESENT) { + throw new CoderException(String.format( + "NullableCoder expects either a byte valued %s (null) or %s (present), got %s", + ENCODE_NULL, ENCODE_PRESENT, b)); + } + return valueCoder.decode(inStream, context.nested()); + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(valueCoder); + } + + /** + * {@code NullableCoder} is deterministic if the nested {@code Coder} is. + * + * {@inheritDoc} + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic("Value coder must be deterministic", valueCoder); + } + + /** + * {@code NullableCoder} is consistent with equals if the nested {@code Coder} is. + * + * {@inheritDoc} + */ + @Override + public boolean consistentWithEquals() { + return valueCoder.consistentWithEquals(); + } + + @Override + public Object structuralValue(@Nullable T value) throws Exception { + if (value == null) { + return Optional.absent(); + } + return Optional.of(valueCoder.structuralValue(value)); + } + + /** + * Overridden to short-circuit the default {@code StandardCoder} behavior of encoding and + * counting the bytes. The size is known (1 byte) when {@code value} is {@code null}, otherwise + * the size is 1 byte plus the size of nested {@code Coder}'s encoding of {@code value}. + * + * {@inheritDoc} + */ + @Override + public void registerByteSizeObserver( + @Nullable T value, ElementByteSizeObserver observer, Context context) throws Exception { + observer.update(1); + if (value != null) { + valueCoder.registerByteSizeObserver(value, observer, context.nested()); + } + } + + /** + * Overridden to short-circuit the default {@code StandardCoder} behavior of encoding and + * counting the bytes. The size is known (1 byte) when {@code value} is {@code null}, otherwise + * the size is 1 byte plus the size of nested {@code Coder}'s encoding of {@code value}. + * + * {@inheritDoc} + */ + @Override + protected long getEncodedElementByteSize(@Nullable T value, Context context) throws Exception { + if (value == null) { + return 1; + } + + if (valueCoder instanceof StandardCoder) { + // If valueCoder is a StandardCoder then we can ask it directly for the encoded size of + // the value, adding 1 byte to count the null indicator. 
+ return 1 + ((StandardCoder) valueCoder) + .getEncodedElementByteSize(value, context.nested()); + } + + // If value is not a StandardCoder then fall back to the default StandardCoder behavior + // of encoding and counting the bytes. The encoding will include the null indicator byte. + return super.getEncodedElementByteSize(value, context); + } + + /** + * {@code NullableCoder} is cheap if {@code valueCoder} is cheap. + * + * {@inheritDoc} + */ + @Override + public boolean isRegisterByteSizeObserverCheap(@Nullable T value, Context context) { + return valueCoder.isRegisterByteSizeObserverCheap(value, context.nested()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Proto2Coder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Proto2Coder.java new file mode 100644 index 000000000000..ef91ba96e8c9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/Proto2Coder.java @@ -0,0 +1,361 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.Structs; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; +import com.google.protobuf.Parser; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * A {@link Coder} using Google Protocol Buffers 2 binary format. + * + *

+ * <p>To learn more about Protocol Buffers, visit:
+ * https://developers.google.com/protocol-buffers
+ *
+ * <p>To use, specify the {@link Coder} type on a PCollection containing Protocol Buffers messages.
+ *
+ * <pre>
+ * {@code
+ * PCollection<MyProto.Message> records =
+ *     input.apply(...)
+ *          .setCoder(Proto2Coder.of(MyProto.Message.class));
+ * }
+ * </pre>
+ *
+ * <p>Custom message extensions are also supported, but the coder must be made
+ * aware of them explicitly:
+ *
+ * <pre>
+ * {@code
+ * PCollection<MyProto.Message> records =
+ *     input.apply(...)
+ *          .setCoder(Proto2Coder.of(MyProto.Message.class)
+ *              .addExtensionsFrom(MyProto.class));
+ * }
+ * </pre>
    + * + * @param the type of elements handled by this coder, must extend {@code Message} + * @deprecated Use {@link ProtoCoder}. + */ +@Deprecated +public class Proto2Coder extends AtomicCoder { + + /** The class of Protobuf message to be encoded. */ + private final Class protoMessageClass; + + /** + * All extension host classes included in this Proto2Coder. The extensions from + * these classes will be included in the {@link ExtensionRegistry} used during + * encoding and decoding. + */ + private final List> extensionHostClasses; + + private Proto2Coder(Class protoMessageClass, List> extensionHostClasses) { + this.protoMessageClass = protoMessageClass; + this.extensionHostClasses = extensionHostClasses; + } + + private static final CoderProvider PROVIDER = + new CoderProvider() { + @Override + public Coder getCoder(TypeDescriptor type) throws CannotProvideCoderException { + if (type.isSubtypeOf(new TypeDescriptor() {})) { + @SuppressWarnings("unchecked") + TypeDescriptor messageType = + (TypeDescriptor) type; + @SuppressWarnings("unchecked") + Coder coder = (Coder) Proto2Coder.of(messageType); + return coder; + } else { + throw new CannotProvideCoderException( + String.format( + "Cannot provide Proto2Coder because %s " + + "is not a subclass of protocol buffer Messsage", + type)); + } + } + }; + + public static CoderProvider coderProvider() { + return PROVIDER; + } + + /** + * Returns a {@code Proto2Coder} for the given Protobuf message class. + */ + public static Proto2Coder of(Class protoMessageClass) { + return new Proto2Coder(protoMessageClass, Collections.>emptyList()); + } + + /** + * Returns a {@code Proto2Coder} for the given Protobuf message class. + */ + public static Proto2Coder of(TypeDescriptor protoMessageType) { + @SuppressWarnings("unchecked") + Class protoMessageClass = (Class) protoMessageType.getRawType(); + return of(protoMessageClass); + } + + /** + * Produces a {@code Proto2Coder} like this one, but with the extensions from + * the given classes registered. + * + * @param moreExtensionHosts an iterable of classes that define a static + * method {@code registerAllExtensions(ExtensionRegistry)} + */ + public Proto2Coder withExtensionsFrom(Iterable> moreExtensionHosts) { + for (Class extensionHost : moreExtensionHosts) { + // Attempt to access the required method, to make sure it's present. + try { + Method registerAllExtensions = + extensionHost.getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class); + checkArgument( + Modifier.isStatic(registerAllExtensions.getModifiers()), + "Method registerAllExtensions() must be static for use with Proto2Coder"); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException(e); + } + } + + return new Proto2Coder( + protoMessageClass, + new ImmutableList.Builder>() + .addAll(extensionHostClasses) + .addAll(moreExtensionHosts) + .build()); + } + + /** + * See {@link #withExtensionsFrom(Iterable)}. + */ + public Proto2Coder withExtensionsFrom(Class... extensionHosts) { + return withExtensionsFrom(ImmutableList.copyOf(extensionHosts)); + } + + /** + * Adds custom Protobuf extensions to the coder. Returns {@code this} + * for method chaining. + * + * @param extensionHosts must be a class that defines a static + * method name {@code registerAllExtensions} + * @deprecated use {@link #withExtensionsFrom} + */ + @Deprecated + public Proto2Coder addExtensionsFrom(Class... 
extensionHosts) { + return addExtensionsFrom(ImmutableList.copyOf(extensionHosts)); + } + + /** + * Adds custom Protobuf extensions to the coder. Returns {@code this} + * for method chaining. + * + * @param extensionHosts must be a class that defines a static + * method name {@code registerAllExtensions} + * @deprecated use {@link #withExtensionsFrom} + */ + @Deprecated + public Proto2Coder addExtensionsFrom(Iterable> extensionHosts) { + for (Class extensionHost : extensionHosts) { + try { + // Attempt to access the declared method, to make sure it's present. + extensionHost.getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException(e); + } + extensionHostClasses.add(extensionHost); + } + // The memoized extension registry needs to be recomputed because we have mutated this object. + synchronized (this) { + memoizedExtensionRegistry = null; + getExtensionRegistry(); + } + return this; + } + + @Override + public void encode(T value, OutputStream outStream, Context context) throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null " + protoMessageClass.getSimpleName()); + } + if (context.isWholeStream) { + value.writeTo(outStream); + } else { + value.writeDelimitedTo(outStream); + } + } + + @Override + public T decode(InputStream inStream, Context context) throws IOException { + if (context.isWholeStream) { + return getParser().parseFrom(inStream, getExtensionRegistry()); + } else { + return getParser().parseDelimitedFrom(inStream, getExtensionRegistry()); + } + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof Proto2Coder)) { + return false; + } + Proto2Coder otherCoder = (Proto2Coder) other; + return protoMessageClass.equals(otherCoder.protoMessageClass) + && Sets.newHashSet(extensionHostClasses) + .equals(Sets.newHashSet(otherCoder.extensionHostClasses)); + } + + @Override + public int hashCode() { + return Objects.hash(protoMessageClass, extensionHostClasses); + } + + /** + * The encoding identifier is designed to support evolution as per the design of Protocol + * Buffers. In order to use this class effectively, carefully follow the advice in the Protocol + * Buffers documentation at + * Updating + * A Message Type. + * + *

+ * <p>In particular, the encoding identifier is guaranteed to be the same for {@code Proto2Coder}
+ * instances of the same principal message class, and otherwise distinct. Loaded extensions do not
+ * affect the id, nor does it encode the full schema.
+ *
+ * <p>When modifying a message class, here are the broadest guidelines; see the above link
+ * for greater detail.
+ *
+ * <ul>
+ * <li>Do not change the numeric tags for any fields.</li>
+ * <li>Never remove a required field.</li>
+ * <li>Only add optional or repeated fields, with sensible defaults.</li>
+ * <li>When changing the type of a field, consult the Protocol Buffers documentation to ensure
+ * the new and old types are interchangeable.</li>
+ * </ul>
+ *
+ * <p>Code consuming this message class should be prepared to support all versions of
+ * the class until it is certain that no remaining serialized instances exist.
+ *
+ * <p>
    If backwards incompatible changes must be made, the best recourse is to change the name + * of your Protocol Buffers message class. + */ + @Override + public String getEncodingId() { + return protoMessageClass.getName(); + } + + private transient Parser memoizedParser; + + private Parser getParser() { + if (memoizedParser == null) { + try { + @SuppressWarnings("unchecked") + T protoMessageInstance = (T) protoMessageClass.getMethod("getDefaultInstance").invoke(null); + @SuppressWarnings("unchecked") + Parser tParser = (Parser) protoMessageInstance.getParserForType(); + memoizedParser = tParser; + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalArgumentException(e); + } + } + return memoizedParser; + } + + private transient ExtensionRegistry memoizedExtensionRegistry; + + private synchronized ExtensionRegistry getExtensionRegistry() { + if (memoizedExtensionRegistry == null) { + ExtensionRegistry registry = ExtensionRegistry.newInstance(); + for (Class extensionHost : extensionHostClasses) { + try { + extensionHost + .getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class) + .invoke(null, registry); + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalStateException(e); + } + } + memoizedExtensionRegistry = registry.getUnmodifiable(); + } + return memoizedExtensionRegistry; + } + + //////////////////////////////////////////////////////////////////////////////////// + // JSON Serialization details below + + private static final String PROTO_MESSAGE_CLASS = "proto_message_class"; + private static final String PROTO_EXTENSION_HOSTS = "proto_extension_hosts"; + + /** + * Constructor for JSON deserialization only. + */ + @JsonCreator + public static Proto2Coder of( + @JsonProperty(PROTO_MESSAGE_CLASS) String protoMessageClassName, + @Nullable @JsonProperty(PROTO_EXTENSION_HOSTS) List extensionHostClassNames) { + + try { + @SuppressWarnings("unchecked") + Class protoMessageClass = (Class) Class.forName(protoMessageClassName); + List> extensionHostClasses = Lists.newArrayList(); + if (extensionHostClassNames != null) { + for (String extensionHostClassName : extensionHostClassNames) { + extensionHostClasses.add(Class.forName(extensionHostClassName)); + } + } + return of(protoMessageClass).withExtensionsFrom(extensionHostClasses); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + Structs.addString(result, PROTO_MESSAGE_CLASS, protoMessageClass.getName()); + List extensionHostClassNames = Lists.newArrayList(); + for (Class clazz : extensionHostClasses) { + extensionHostClassNames.add(CloudObject.forString(clazz.getName())); + } + Structs.addList(result, PROTO_EXTENSION_HOSTS, extensionHostClassNames); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SerializableCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SerializableCoder.java new file mode 100644 index 000000000000..593c9f0f809b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SerializableCoder.java @@ -0,0 +1,183 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; +import java.io.OutputStream; +import java.io.Serializable; + +/** + * A {@link Coder} for Java classes that implement {@link Serializable}. + * + *

+ * <p>To use, specify the coder type on a PCollection:
+ *
+ * <pre>
+ * {@code
+ *   PCollection<MyRecord> records =
+ *       foo.apply(...).setCoder(SerializableCoder.of(MyRecord.class));
+ * }
+ * </pre>
+ *
+ * <p>
    {@link SerializableCoder} does not guarantee a deterministic encoding, as Java + * serialization may produce different binary encodings for two equivalent + * objects. + * + * @param the type of elements handled by this coder + */ +public class SerializableCoder extends AtomicCoder { + + /** + * Returns a {@link SerializableCoder} instance for the provided element type. + * @param the element type + */ + public static SerializableCoder of(TypeDescriptor type) { + @SuppressWarnings("unchecked") + Class clazz = (Class) type.getRawType(); + return of(clazz); + } + + /** + * Returns a {@link SerializableCoder} instance for the provided element class. + * @param the element type + */ + public static SerializableCoder of(Class clazz) { + return new SerializableCoder<>(clazz); + } + + @JsonCreator + @SuppressWarnings("unchecked") + public static SerializableCoder of(@JsonProperty("type") String classType) + throws ClassNotFoundException { + Class clazz = Class.forName(classType); + if (!Serializable.class.isAssignableFrom(clazz)) { + throw new ClassNotFoundException( + "Class " + classType + " does not implement Serializable"); + } + return of((Class) clazz); + } + + /** + * A {@link CoderProvider} that constructs a {@link SerializableCoder} + * for any class that implements serializable. + */ + public static final CoderProvider PROVIDER = new CoderProvider() { + @Override + public Coder getCoder(TypeDescriptor typeDescriptor) + throws CannotProvideCoderException { + Class clazz = typeDescriptor.getRawType(); + if (Serializable.class.isAssignableFrom(clazz)) { + @SuppressWarnings("unchecked") + Class serializableClazz = + (Class) clazz; + @SuppressWarnings("unchecked") + Coder coder = (Coder) SerializableCoder.of(serializableClazz); + return coder; + } else { + throw new CannotProvideCoderException( + "Cannot provide SerializableCoder because " + typeDescriptor + + " does not implement Serializable"); + } + } + }; + + + private final Class type; + + protected SerializableCoder(Class type) { + this.type = type; + } + + public Class getRecordType() { + return type; + } + + @Override + public void encode(T value, OutputStream outStream, Context context) + throws IOException, CoderException { + try { + ObjectOutputStream oos = new ObjectOutputStream(outStream); + oos.writeObject(value); + oos.flush(); + } catch (IOException exn) { + throw new CoderException("unable to serialize record " + value, exn); + } + } + + @Override + public T decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + ObjectInputStream ois = new ObjectInputStream(inStream); + return type.cast(ois.readObject()); + } catch (ClassNotFoundException e) { + throw new CoderException("unable to deserialize record", e); + } + } + + @Override + public String getEncodingId() { + return String.format("%s:%s", + type.getName(), + ObjectStreamClass.lookup(type).getSerialVersionUID()); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + result.put("type", type.getName()); + return result; + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. Java serialization is not + * deterministic with respect to {@link Object#equals} for all types. 
+ */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "Java Serialization may be non-deterministic."); + } + + @Override + public boolean equals(Object other) { + if (getClass() != other.getClass()) { + return false; + } + return type == ((SerializableCoder) other).type; + } + + @Override + public int hashCode() { + return type.hashCode(); + } + + // This coder inherits isRegisterByteSizeObserverCheap, + // getEncodedElementByteSize and registerByteSizeObserver + // from StandardCoder. Looks like we cannot do much better + // in this case. +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SetCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SetCoder.java new file mode 100644 index 000000000000..36b3606f84ea --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/SetCoder.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A {@link SetCoder} encodes any {@link Set} using the format of {@link IterableLikeCoder}. The + * elements may not be in a deterministic order, depending on the {@code Set} implementation. + * + * @param the type of the elements of the set + */ +public class SetCoder extends IterableLikeCoder> { + + /** + * Produces a {@link SetCoder} with the given {@code elementCoder}. + */ + public static SetCoder of(Coder elementCoder) { + return new SetCoder<>(elementCoder); + } + + /** + * Dynamically typed constructor for JSON deserialization. + */ + @JsonCreator + public static SetCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. Sets are not ordered, but + * they are encoded in the order of an arbitrary iteration. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "Ordering of elements in a set may be non-deterministic."); + } + + /** + * Returns the first element in this set if it is non-empty, + * otherwise returns {@code null}. + */ + public static List getInstanceComponents( + Set exampleValue) { + return getInstanceComponentsHelper(exampleValue); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal operations below here. 
+ + /** + * {@inheritDoc} + * + * @return A new {@link Set} built from the elements in the {@link List} decoded by + * {@link IterableLikeCoder}. + */ + @Override + protected final Set decodeToIterable(List decodedElements) { + return new HashSet<>(decodedElements); + } + + protected SetCoder(Coder elemCoder) { + super(elemCoder, "Set"); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StandardCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StandardCoder.java new file mode 100644 index 000000000000..faa98619ecdf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StandardCoder.java @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.cloud.dataflow.sdk.util.Structs.addList; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.addStringList; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.collect.Lists; +import com.google.common.io.ByteStreams; +import com.google.common.io.CountingOutputStream; + +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * An abstract base class to implement a {@link Coder} that defines equality, hashing, and printing + * via the class name and recursively using {@link #getComponents}. + * + *

+ * <p>To extend {@link StandardCoder}, override the following methods as appropriate:
+ *
+ * <ul>
+ * <li>{@link #getComponents}: the default implementation returns {@link #getCoderArguments}.</li>
+ * <li>{@link #getEncodedElementByteSize} and
+ * {@link #isRegisterByteSizeObserverCheap}: the
+ * default implementation encodes values to bytes and counts the bytes, which is considered
+ * expensive.</li>
+ * <li>{@link #getEncodingId} and {@link #getAllowedEncodings}: by default, the encoding id
+ * is the empty string, so only the canonical name of the subclass will be used for
+ * compatibility checks, and no other encoding ids are allowed.</li>
+ * </ul>
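+ *
+ * <p>A minimal sketch of a subclass, not part of the original file ({@code Point} is an assumed
+ * two-field value type used only for illustration; imports from {@code java.io} and
+ * {@code java.util} are elided):
+ *
+ * <pre>
+ * {@code
+ * class PointCoder extends StandardCoder<Point> {
+ *   private final VarIntCoder intCoder = VarIntCoder.of();
+ *
+ *   @Override
+ *   public void encode(Point p, OutputStream out, Context context)
+ *       throws CoderException, IOException {
+ *     // Encode the two fields back to back; each component uses a nested context.
+ *     intCoder.encode(p.x, out, context.nested());
+ *     intCoder.encode(p.y, out, context.nested());
+ *   }
+ *
+ *   @Override
+ *   public Point decode(InputStream in, Context context)
+ *       throws CoderException, IOException {
+ *     return new Point(intCoder.decode(in, context.nested()),
+ *                      intCoder.decode(in, context.nested()));
+ *   }
+ *
+ *   @Override
+ *   public List<? extends Coder<?>> getCoderArguments() {
+ *     return Collections.emptyList();
+ *   }
+ *
+ *   @Override
+ *   public void verifyDeterministic() {}
+ * }
+ * }
+ * </pre>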
    + */ +public abstract class StandardCoder implements Coder { + protected StandardCoder() {} + + @Override + public String getEncodingId() { + return ""; + } + + @Override + public Collection getAllowedEncodings() { + return Collections.emptyList(); + } + + /** + * Returns the list of {@link Coder Coders} that are components of this {@link Coder}. + */ + public List> getComponents() { + List> coderArguments = getCoderArguments(); + if (coderArguments == null) { + return Collections.emptyList(); + } else { + return coderArguments; + } + } + + /** + * {@inheritDoc} + * + * @return {@code true} if the two {@link StandardCoder} instances have the + * same class and equal components. + */ + @Override + public boolean equals(Object o) { + if (o == null || this.getClass() != o.getClass()) { + return false; + } + StandardCoder that = (StandardCoder) o; + return this.getComponents().equals(that.getComponents()); + } + + @Override + public int hashCode() { + return getClass().hashCode() * 31 + getComponents().hashCode(); + } + + @Override + public String toString() { + String s = getClass().getName(); + s = s.substring(s.lastIndexOf('.') + 1); + List> componentCoders = getComponents(); + if (!componentCoders.isEmpty()) { + s += "("; + boolean first = true; + for (Coder componentCoder : componentCoders) { + if (first) { + first = false; + } else { + s += ", "; + } + s += componentCoder.toString(); + } + s += ")"; + } + return s; + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = CloudObject.forClass(getClass()); + + List> components = getComponents(); + if (!components.isEmpty()) { + List cloudComponents = new ArrayList<>(components.size()); + for (Coder coder : components) { + cloudComponents.add(coder.asCloudObject()); + } + addList(result, PropertyNames.COMPONENT_ENCODINGS, cloudComponents); + } + + String encodingId = getEncodingId(); + checkNotNull(encodingId, "Coder.getEncodingId() must not return null."); + if (!encodingId.isEmpty()) { + addString(result, PropertyNames.ENCODING_ID, encodingId); + } + + Collection allowedEncodings = getAllowedEncodings(); + if (!allowedEncodings.isEmpty()) { + addStringList(result, PropertyNames.ALLOWED_ENCODINGS, Lists.newArrayList(allowedEncodings)); + } + + return result; + } + + /** + * {@inheritDoc} + * + * @return {@code false} unless it is overridden. {@link StandardCoder#registerByteSizeObserver} + * invokes {@link #getEncodedElementByteSize} which requires re-encoding an element + * unless it is overridden. This is considered expensive. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(T value, Context context) { + return false; + } + + /** + * Returns the size in bytes of the encoded value using this coder. + */ + protected long getEncodedElementByteSize(T value, Context context) + throws Exception { + try { + CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream()); + encode(value, os, context); + return os.getCount(); + } catch (Exception exn) { + throw new IllegalArgumentException( + "Unable to encode element '" + value + "' with coder '" + this + "'.", exn); + } + } + + /** + * {@inheritDoc} + * + *

    For {@link StandardCoder} subclasses, this notifies {@code observer} about the byte size + * of the encoded value using this coder as returned by {@link #getEncodedElementByteSize}. + */ + @Override + public void registerByteSizeObserver( + T value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(getEncodedElementByteSize(value, context)); + } + + protected void verifyDeterministic(String message, Iterable> coders) + throws NonDeterministicException { + for (Coder coder : coders) { + try { + coder.verifyDeterministic(); + } catch (NonDeterministicException e) { + throw new NonDeterministicException(this, message, e); + } + } + } + + protected void verifyDeterministic(String message, Coder... coders) + throws NonDeterministicException { + verifyDeterministic(message, Arrays.asList(coders)); + } + + /** + * {@inheritDoc} + * + * @return {@code false} for {@link StandardCoder} unless overridden. + */ + @Override + public boolean consistentWithEquals() { + return false; + } + + @Override + public Object structuralValue(T value) throws Exception { + if (value != null && consistentWithEquals()) { + return value; + } else { + try { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + encode(value, os, Context.OUTER); + return new StructuralByteArray(os.toByteArray()); + } catch (Exception exn) { + throw new IllegalArgumentException( + "Unable to encode element '" + value + "' with coder '" + this + "'.", exn); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringDelegateCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringDelegateCoder.java new file mode 100644 index 000000000000..1fc1247226a8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringDelegateCoder.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder; + +import java.lang.reflect.InvocationTargetException; + +/** + * A {@link Coder} that wraps a {@code Coder} + * and encodes/decodes values via string representations. + * + *

+ * <p>To decode, the input byte stream is decoded to
+ * a {@link String}, and this is passed to the single-argument
+ * constructor for {@code T}.
+ *
+ * <p>To encode, the input value is converted via {@code toString()},
+ * and this string is encoded.
+ *
+ * <p>In order for this to operate correctly for a class {@code Clazz},
+ * it must be the case for any instance {@code x} that
+ * {@code x.equals(new Clazz(x.toString()))}.
+ *
+ * <p>
    This method of encoding is not designed for ease of evolution of {@code Clazz}; + * it should only be used in cases where the class is stable or the encoding is not + * important. If evolution of the class is important, see {@link ProtoCoder}, {@link AvroCoder}, + * or {@link JAXBCoder}. + * + * @param The type of objects coded. + */ +public class StringDelegateCoder extends DelegateCoder { + public static StringDelegateCoder of(Class clazz) { + return new StringDelegateCoder(clazz); + } + + @Override + public String toString() { + return "StringDelegateCoder(" + clazz + ")"; + } + + private final Class clazz; + + protected StringDelegateCoder(final Class clazz) { + super(StringUtf8Coder.of(), + new CodingFunction() { + @Override + public String apply(T input) { + return input.toString(); + } + }, + new CodingFunction() { + @Override + public T apply(String input) throws + NoSuchMethodException, + InstantiationException, + IllegalAccessException, + InvocationTargetException { + return clazz.getConstructor(String.class).newInstance(input); + } + }); + + this.clazz = clazz; + } + + /** + * The encoding id is the fully qualified name of the encoded/decoded class. + */ + @Override + public String getEncodingId() { + return clazz.getName(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringUtf8Coder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringUtf8Coder.java new file mode 100644 index 000000000000..179840c3c43e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StringUtf8Coder.java @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.ExposedByteArrayOutputStream; +import com.google.cloud.dataflow.sdk.util.StreamUtils; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.base.Utf8; +import com.google.common.io.ByteStreams; +import com.google.common.io.CountingOutputStream; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; +import java.nio.charset.StandardCharsets; + +/** + * A {@link Coder} that encodes {@link String Strings} in UTF-8 encoding. + * If in a nested context, prefixes the string with an integer length field, + * encoded via a {@link VarIntCoder}. 
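+ *
+ * <p>A minimal encoding sketch, not part of the original file:
+ *
+ * <pre>
+ * {@code
+ * ByteArrayOutputStream out = new ByteArrayOutputStream();
+ * // In the outer context the raw UTF-8 bytes are written with no length prefix.
+ * StringUtf8Coder.of().encode("hi", out, Coder.Context.OUTER);
+ * // out now holds exactly the two UTF-8 bytes of "hi"; a nested context would
+ * // prepend a varint length of 2.
+ * }
+ * </pre>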
+ */ +public class StringUtf8Coder extends AtomicCoder { + + @JsonCreator + public static StringUtf8Coder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final StringUtf8Coder INSTANCE = new StringUtf8Coder(); + + private static void writeString(String value, DataOutputStream dos) + throws IOException { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + VarInt.encode(bytes.length, dos); + dos.write(bytes); + } + + private static String readString(DataInputStream dis) throws IOException { + int len = VarInt.decodeInt(dis); + if (len < 0) { + throw new CoderException("Invalid encoded string length: " + len); + } + byte[] bytes = new byte[len]; + dis.readFully(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + private StringUtf8Coder() {} + + @Override + public void encode(String value, OutputStream outStream, Context context) + throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null String"); + } + if (context.isWholeStream) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + if (outStream instanceof ExposedByteArrayOutputStream) { + ((ExposedByteArrayOutputStream) outStream).writeAndOwn(bytes); + } else { + outStream.write(bytes); + } + } else { + writeString(value, new DataOutputStream(outStream)); + } + } + + @Override + public String decode(InputStream inStream, Context context) + throws IOException { + if (context.isWholeStream) { + byte[] bytes = StreamUtils.getBytes(inStream); + return new String(bytes, StandardCharsets.UTF_8); + } else { + try { + return readString(new DataInputStream(inStream)); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + } + + /** + * {@inheritDoc} + * + * @return {@code true}. This coder is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return the byte size of the UTF-8 encoding of the a string or, in a nested context, + * the byte size of the encoding plus the encoded length prefix. + */ + @Override + protected long getEncodedElementByteSize(String value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null String"); + } + if (context.isWholeStream) { + return Utf8.encodedLength(value); + } else { + CountingOutputStream countingStream = + new CountingOutputStream(ByteStreams.nullOutputStream()); + DataOutputStream stream = new DataOutputStream(countingStream); + writeString(value, stream); + return countingStream.getCount(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StructuralByteArray.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StructuralByteArray.java new file mode 100644 index 000000000000..ea18eb971a9e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/StructuralByteArray.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.api.client.util.Base64.encodeBase64String; + +import java.util.Arrays; + +/** + * A wrapper around a byte[] that uses structural, value-based + * equality rather than byte[]'s normal object identity. + */ +public class StructuralByteArray { + byte[] value; + + public StructuralByteArray(byte[] value) { + this.value = value; + } + + public byte[] getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (o instanceof StructuralByteArray) { + StructuralByteArray that = (StructuralByteArray) o; + return Arrays.equals(this.value, that.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "base64:" + encodeBase64String(value); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoder.java new file mode 100644 index 000000000000..bed88b080fe0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoder.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.bigquery.model.TableRow; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} that encodes BigQuery {@link TableRow} objects in their native JSON format. 
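+ *
+ * <p>A minimal usage sketch, not part of the original file (the surrounding pipeline and read
+ * step are assumed):
+ *
+ * <pre>
+ * {@code
+ * PCollection<TableRow> rows =
+ *     pipeline.apply(...)
+ *             .setCoder(TableRowJsonCoder.of());
+ * }
+ * </pre>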
+ */ +public class TableRowJsonCoder extends AtomicCoder { + + @JsonCreator + public static TableRowJsonCoder of() { + return INSTANCE; + } + + @Override + public void encode(TableRow value, OutputStream outStream, Context context) + throws IOException { + String strValue = MAPPER.writeValueAsString(value); + StringUtf8Coder.of().encode(strValue, outStream, context); + } + + @Override + public TableRow decode(InputStream inStream, Context context) + throws IOException { + String strValue = StringUtf8Coder.of().decode(inStream, context); + return MAPPER.readValue(strValue, TableRow.class); + } + + @Override + protected long getEncodedElementByteSize(TableRow value, Context context) + throws Exception { + String strValue = MAPPER.writeValueAsString(value); + return StringUtf8Coder.of().getEncodedElementByteSize(strValue, context); + } + + ///////////////////////////////////////////////////////////////////////////// + + // FAIL_ON_EMPTY_BEANS is disabled in order to handle null values in + // TableRow. + private static final ObjectMapper MAPPER = + new ObjectMapper().disable(SerializationFeature.FAIL_ON_EMPTY_BEANS); + + private static final TableRowJsonCoder INSTANCE = new TableRowJsonCoder(); + + private TableRowJsonCoder() { } + + /** + * {@inheritDoc} + * + * @throws NonDeterministicException always. A {@link TableRow} can hold arbitrary + * {@link Object} instances, which makes the encoding non-deterministic. + */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "TableCell can hold arbitrary instances, which may be non-deterministic."); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoder.java new file mode 100644 index 000000000000..9250c683f85b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoder.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} that encodes {@code Integer Integers} as the ASCII bytes of + * their textual, decimal, representation. 
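+ *
+ * <p>A minimal encoding sketch, not part of the original file:
+ *
+ * <pre>
+ * {@code
+ * ByteArrayOutputStream out = new ByteArrayOutputStream();
+ * TextualIntegerCoder.of().encode(42, out, Coder.Context.OUTER);
+ * // out now holds the two ASCII bytes '4' and '2'; a nested context would add
+ * // a varint length prefix via the underlying StringUtf8Coder.
+ * }
+ * </pre>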
+ */ +public class TextualIntegerCoder extends AtomicCoder { + + @JsonCreator + public static TextualIntegerCoder of() { + return new TextualIntegerCoder(); + } + + ///////////////////////////////////////////////////////////////////////////// + + protected TextualIntegerCoder() {} + + @Override + public void encode(Integer value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + String textualValue = value.toString(); + StringUtf8Coder.of().encode(textualValue, outStream, context); + } + + @Override + public Integer decode(InputStream inStream, Context context) + throws IOException, CoderException { + String textualValue = StringUtf8Coder.of().decode(inStream, context); + try { + return Integer.valueOf(textualValue); + } catch (NumberFormatException exn) { + throw new CoderException("error when decoding a textual integer", exn); + } + } + + @Override + protected long getEncodedElementByteSize(Integer value, Context context) throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + String textualValue = value.toString(); + return StringUtf8Coder.of().getEncodedElementByteSize(textualValue, context); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarIntCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarIntCoder.java new file mode 100644 index 000000000000..18ec250381b0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarIntCoder.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A {@link Coder} that encodes {@link Integer Integers} using between 1 and 5 bytes. Negative + * numbers always take 5 bytes, so {@link BigEndianIntegerCoder} may be preferable for + * integers that are known to often be large or negative. 
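+ *
+ * A rough sketch of the trade-off described above (again assuming the SDK's
+ * {@code CoderUtils} helper and the outer encoding context):
+ *
+ * byte[] small = CoderUtils.encodeToByteArray(VarIntCoder.of(), 1);     // 1 byte
+ * byte[] negative = CoderUtils.encodeToByteArray(VarIntCoder.of(), -1); // 5 bytes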
+ */ +public class VarIntCoder extends AtomicCoder { + + @JsonCreator + public static VarIntCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final VarIntCoder INSTANCE = + new VarIntCoder(); + + private VarIntCoder() {} + + @Override + public void encode(Integer value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + VarInt.encode(value.intValue(), outStream); + } + + @Override + public Integer decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return VarInt.decodeInt(inStream); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link VarIntCoder} is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link #getEncodedElementByteSize} is cheap. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Integer value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Integer value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Integer"); + } + return VarInt.getLength(value.longValue()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarLongCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarLongCoder.java new file mode 100644 index 000000000000..520245e49749 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VarLongCoder.java @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.VarInt; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; + +/** + * A {@link Coder} that encodes {@link Long Longs} using between 1 and 10 bytes. Negative + * numbers always take 10 bytes, so {@link BigEndianLongCoder} may be preferable for + * longs that are known to often be large or negative. 
+ */ +public class VarLongCoder extends AtomicCoder { + + @JsonCreator + public static VarLongCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final VarLongCoder INSTANCE = new VarLongCoder(); + + private VarLongCoder() {} + + @Override + public void encode(Long value, OutputStream outStream, Context context) + throws IOException, CoderException { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + VarInt.encode(value.longValue(), outStream); + } + + @Override + public Long decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return VarInt.decodeLong(inStream); + } catch (EOFException | UTFDataFormatException exn) { + // These exceptions correspond to decoding problems, so change + // what kind of exception they're branded as. + throw new CoderException(exn); + } + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link VarLongCoder} is injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link #getEncodedElementByteSize} is cheap. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Long value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Long value, Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null Long"); + } + return VarInt.getLength(value.longValue()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VoidCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VoidCoder.java new file mode 100644 index 000000000000..0de606b789bc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/VoidCoder.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link Coder} for {@link Void}. Uses zero bytes per {@link Void}. + */ +public class VoidCoder extends AtomicCoder { + + @JsonCreator + public static VoidCoder of() { + return INSTANCE; + } + + ///////////////////////////////////////////////////////////////////////////// + + private static final VoidCoder INSTANCE = new VoidCoder(); + + private VoidCoder() {} + + @Override + public void encode(Void value, OutputStream outStream, Context context) { + // Nothing to write! + } + + @Override + public Void decode(InputStream inStream, Context context) { + // Nothing to read! + return null; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. {@link VoidCoder} is (vacuously) injective. + */ + @Override + public boolean consistentWithEquals() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. 
{@link VoidCoder#getEncodedElementByteSize} runs in constant time. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(Void value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Void value, Context context) + throws Exception { + return 0; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/package-info.java new file mode 100644 index 000000000000..fdf931f4a8eb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/package-info.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.coders.Coder Coders} + * to specify how data is encoded to and decoded from byte strings. + * + *

    During execution of a Pipeline, elements in a + * {@link com.google.cloud.dataflow.sdk.values.PCollection} + * may need to be encoded into byte strings. + * This happens both at the beginning and end of a pipeline when data is read from and written to + * persistent storage and also during execution of a pipeline when elements are communicated between + * machines. + * + *

    Exactly when PCollection elements are encoded during execution depends on which + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} is being used and how that runner + * chooses to execute the pipeline. As such, Dataflow requires that all PCollections have an + * appropriate Coder in case encoding becomes necessary. In many cases, the Coder can be inferred from + * the available Java type + * information and the Pipeline's {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry}. It + * can be specified per PCollection via + * {@link com.google.cloud.dataflow.sdk.values.PCollection#setCoder(Coder)} or per type using the + * {@link com.google.cloud.dataflow.sdk.coders.DefaultCoder} annotation. + *
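+ *
+ * For example, a brief sketch of both options; the pipeline {@code p} and the element
+ * values are assumed, and a coder for {@code Integer} would normally be inferred anyway:
+ *
+ * PCollection<Integer> numbers = p.apply(Create.of(1, 2, 3));
+ * numbers.setCoder(VarIntCoder.of());
+ *
+ * Annotating a custom type with {@code @DefaultCoder(AvroCoder.class)} sets its
+ * per-type default instead.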

    This package provides a number of coders for common types like {@code Integer}, + * {@code String}, and {@code List}, as well as coders like + * {@link com.google.cloud.dataflow.sdk.coders.AvroCoder} that can be used to encode many custom + * types. + * + */ +package com.google.cloud.dataflow.sdk.coders; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtoCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtoCoder.java new file mode 100644 index 000000000000..d8c8e9e2438e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtoCoder.java @@ -0,0 +1,411 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.coders.protobuf; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.api.services.datastore.DatastoreV1; +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderProvider; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageA; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.Structs; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; +import com.google.protobuf.Parser; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +import javax.annotation.Nullable; + +/** + * A {@link Coder} using Google Protocol Buffers binary format. {@link ProtoCoder} supports both + * Protocol Buffers syntax versions 2 and 3. + * + *

    To learn more about Protocol Buffers, visit: + * https://developers.google.com/protocol-buffers + * + *

    {@link ProtoCoder} is registered in the global {@link CoderRegistry} as the default + * {@link Coder} for any {@link Message} object. Custom message extensions are also supported, but + * these extensions must be registered for a particular {@link ProtoCoder} instance and that + * instance must be registered on the {@link PCollection} that needs the extensions: + * + *

    {@code
    + * import MyProtoFile;
    + * import MyProtoFile.MyMessage;
    + *
+ * Coder<MyMessage> coder = ProtoCoder.of(MyMessage.class).withExtensionsFrom(MyProtoFile.class);
+ * PCollection<MyMessage> records = input.apply(...).setCoder(coder);
    + * }
    + * + *

    Versioning

    + * + *

    {@link ProtoCoder} supports both versions 2 and 3 of the Protocol Buffers syntax. However, + * the Java runtime version of the {@code com.google.protobuf} library must match exactly the + * version of protoc that was used to produce the JAR files containing the compiled + * .proto messages. + *

    For more information, see the + * Protocol Buffers documentation. + * + *

    {@link ProtoCoder} and Determinism

    + * + *

    In general, Protocol Buffers messages can be encoded deterministically within a single + * pipeline as long as: + * + *

      + *
    • The encoded messages (and any transitively linked messages) do not use map + * fields.
    • + *
    • Every Java VM that encodes or decodes the messages uses the same runtime version of the + * Protocol Buffers library and the same compiled .proto file JAR.
    • + *
    + * + *

    {@link ProtoCoder} and Encoding Stability

    + * + *

    When changing Protocol Buffers messages, follow the rules in the Protocol Buffers language + * guides for + * {@code proto2} + * and + * {@code proto3} + * syntaxes, depending on your message type. Following these guidelines will ensure that the + * old encoded data can be read by new versions of the code. + * + *

    Generally, any change to the message type, registered extensions, runtime library, or + * compiled proto JARs may change the encoding. Thus even if both the original and updated messages + * can be encoded deterministically within a single job, these deterministic encodings may not be + * the same across jobs. + * + * @param the Protocol Buffers {@link Message} handled by this {@link Coder}. + */ +public class ProtoCoder extends AtomicCoder { + + /** + * A {@link CoderProvider} that returns a {@link ProtoCoder} with an empty + * {@link ExtensionRegistry}. + */ + public static CoderProvider coderProvider() { + return PROVIDER; + } + + /** + * Returns a {@link ProtoCoder} for the given Protocol Buffers {@link Message}. + */ + public static ProtoCoder of(Class protoMessageClass) { + return new ProtoCoder(protoMessageClass, ImmutableSet.>of()); + } + + /** + * Returns a {@link ProtoCoder} for the Protocol Buffers {@link Message} indicated by the given + * {@link TypeDescriptor}. + */ + public static ProtoCoder of(TypeDescriptor protoMessageType) { + @SuppressWarnings("unchecked") + Class protoMessageClass = (Class) protoMessageType.getRawType(); + return of(protoMessageClass); + } + + /** + * Returns a {@link ProtoCoder} like this one, but with the extensions from the given classes + * registered. + * + *

    Each of the extension host classes must be a class automatically generated by the + * Protocol Buffers compiler, {@code protoc}, that contains messages. For example, the class + * {@link Proto2CoderTestMessages} is the extension host for the {@link Message} class + * {@link MessageA Proto2CoderTestMessages.MessageA}, and the class {@link DatastoreV1} is the + * extension host for the Google Cloud Datastore {@link Entity} type. + * + *
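+ *
+ * For instance, a sketch using the Datastore classes named above:
+ *
+ * ProtoCoder<Entity> entityCoder =
+ *     ProtoCoder.of(Entity.class).withExtensionsFrom(DatastoreV1.class);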

    Does not modify this object. + */ + public ProtoCoder withExtensionsFrom(Iterable> moreExtensionHosts) { + for (Class extensionHost : moreExtensionHosts) { + // Attempt to access the required method, to make sure it's present. + try { + Method registerAllExtensions = + extensionHost.getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class); + checkArgument( + Modifier.isStatic(registerAllExtensions.getModifiers()), + "Method registerAllExtensions() must be static"); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + String.format("Unable to register extensions for %s", extensionHost.getCanonicalName()), + e); + } + } + + return new ProtoCoder( + protoMessageClass, + new ImmutableSet.Builder>() + .addAll(extensionHostClasses) + .addAll(moreExtensionHosts) + .build()); + } + + /** + * See {@link #withExtensionsFrom(Iterable)}. + * + *

    Does not modify this object. + */ + public ProtoCoder withExtensionsFrom(Class... moreExtensionHosts) { + return withExtensionsFrom(Arrays.asList(moreExtensionHosts)); + } + + @Override + public void encode(T value, OutputStream outStream, Context context) throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null " + protoMessageClass.getSimpleName()); + } + if (context.isWholeStream) { + value.writeTo(outStream); + } else { + value.writeDelimitedTo(outStream); + } + } + + @Override + public T decode(InputStream inStream, Context context) throws IOException { + if (context.isWholeStream) { + return getParser().parseFrom(inStream, getExtensionRegistry()); + } else { + return getParser().parseDelimitedFrom(inStream, getExtensionRegistry()); + } + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ProtoCoder)) { + return false; + } + ProtoCoder otherCoder = (ProtoCoder) other; + return protoMessageClass.equals(otherCoder.protoMessageClass) + && Sets.newHashSet(extensionHostClasses) + .equals(Sets.newHashSet(otherCoder.extensionHostClasses)); + } + + @Override + public int hashCode() { + return Objects.hash(protoMessageClass, extensionHostClasses); + } + + /** + * The encoding identifier is designed to support evolution as per the design of Protocol + * Buffers. In order to use this class effectively, carefully follow the advice in the Protocol + * Buffers documentation at + * Updating + * A Message Type. + * + *

    In particular, the encoding identifier is guaranteed to be the same for {@link ProtoCoder} + * instances of the same principal message class, with the same registered extension host classes, + * and otherwise distinct. Note that the encoding ID does not encode any version of the message + * or extensions, nor does it include the message schema. + * + *

    When modifying a message class, here are the broadest guidelines; see the above link + * for greater detail. + * + *

      + *
    • Do not change the numeric tags for any fields. + *
    • Never remove a required field. + *
    • Only add optional or repeated fields, with sensible defaults. + *
    • When changing the type of a field, consult the Protocol Buffers documentation to ensure + * the new and old types are interchangeable. + *
    + * + *

    Code consuming this message class should be prepared to support all versions of + * the class until it is certain that no remaining serialized instances exist. + * + *

    If backwards incompatible changes must be made, the best recourse is to change the name + * of your Protocol Buffers message class. + */ + @Override + public String getEncodingId() { + return protoMessageClass.getName() + getSortedExtensionClasses().toString(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + ProtobufUtil.verifyDeterministic(this); + } + + /** + * Returns the Protocol Buffers {@link Message} type this {@link ProtoCoder} supports. + */ + public Class getMessageType() { + return protoMessageClass; + } + + /** + * Returns the {@link ExtensionRegistry} listing all known Protocol Buffers extension messages + * to {@code T} registered with this {@link ProtoCoder}. + */ + public ExtensionRegistry getExtensionRegistry() { + if (memoizedExtensionRegistry == null) { + ExtensionRegistry registry = ExtensionRegistry.newInstance(); + for (Class extensionHost : extensionHostClasses) { + try { + extensionHost + .getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class) + .invoke(null, registry); + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalStateException(e); + } + } + memoizedExtensionRegistry = registry.getUnmodifiable(); + } + return memoizedExtensionRegistry; + } + + //////////////////////////////////////////////////////////////////////////////////// + // Private implementation details below. + + /** The {@link Message} type to be coded. */ + private final Class protoMessageClass; + + /** + * All extension host classes included in this {@link ProtoCoder}. The extensions from these + * classes will be included in the {@link ExtensionRegistry} used during encoding and decoding. + */ + private final Set> extensionHostClasses; + + // Constants used to serialize and deserialize + private static final String PROTO_MESSAGE_CLASS = "proto_message_class"; + private static final String PROTO_EXTENSION_HOSTS = "proto_extension_hosts"; + + // Transient fields that are lazy initialized and then memoized. + private transient ExtensionRegistry memoizedExtensionRegistry; + private transient Parser memoizedParser; + + /** Private constructor. */ + private ProtoCoder(Class protoMessageClass, Set> extensionHostClasses) { + this.protoMessageClass = protoMessageClass; + this.extensionHostClasses = extensionHostClasses; + } + + /** + * @deprecated For JSON deserialization only. 
+ */ + @JsonCreator + @Deprecated + public static ProtoCoder of( + @JsonProperty(PROTO_MESSAGE_CLASS) String protoMessageClassName, + @Nullable @JsonProperty(PROTO_EXTENSION_HOSTS) List extensionHostClassNames) { + + try { + @SuppressWarnings("unchecked") + Class protoMessageClass = (Class) Class.forName(protoMessageClassName); + List> extensionHostClasses = Lists.newArrayList(); + if (extensionHostClassNames != null) { + for (String extensionHostClassName : extensionHostClassNames) { + extensionHostClasses.add(Class.forName(extensionHostClassName)); + } + } + return of(protoMessageClass).withExtensionsFrom(extensionHostClasses); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + Structs.addString(result, PROTO_MESSAGE_CLASS, protoMessageClass.getName()); + List extensionHostClassNames = Lists.newArrayList(); + for (String className : getSortedExtensionClasses()) { + extensionHostClassNames.add(CloudObject.forString(className)); + } + Structs.addList(result, PROTO_EXTENSION_HOSTS, extensionHostClassNames); + return result; + } + + /** Get the memoized {@link Parser}, possibly initializing it lazily. */ + private Parser getParser() { + if (memoizedParser == null) { + try { + @SuppressWarnings("unchecked") + T protoMessageInstance = (T) protoMessageClass.getMethod("getDefaultInstance").invoke(null); + @SuppressWarnings("unchecked") + Parser tParser = (Parser) protoMessageInstance.getParserForType(); + memoizedParser = tParser; + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalArgumentException(e); + } + } + return memoizedParser; + } + + /** + * The implementation of the {@link CoderProvider} for this {@link ProtoCoder} returned by + * {@link #coderProvider()}. + */ + private static final CoderProvider PROVIDER = + new CoderProvider() { + @Override + public Coder getCoder(TypeDescriptor type) throws CannotProvideCoderException { + if (!type.isSubtypeOf(new TypeDescriptor() {})) { + throw new CannotProvideCoderException( + String.format( + "Cannot provide %s because %s is not a subclass of %s", + ProtoCoder.class.getSimpleName(), + type, + Message.class.getName())); + } + + @SuppressWarnings("unchecked") + TypeDescriptor messageType = (TypeDescriptor) type; + try { + @SuppressWarnings("unchecked") + Coder coder = (Coder) ProtoCoder.of(messageType); + return coder; + } catch (IllegalArgumentException e) { + throw new CannotProvideCoderException(e); + } + } + }; + + private SortedSet getSortedExtensionClasses() { + SortedSet ret = new TreeSet<>(); + for (Class clazz : extensionHostClasses) { + ret.add(clazz.getName()); + } + return ret; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtobufUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtobufUtil.java new file mode 100644 index 000000000000..597b1de8430e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtobufUtil.java @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders.protobuf; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor.Syntax; +import com.google.protobuf.Descriptors.GenericDescriptor; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.ExtensionRegistry.ExtensionInfo; +import com.google.protobuf.Message; + +import java.lang.reflect.InvocationTargetException; +import java.util.HashSet; +import java.util.Set; + +/** + * Utility functions for reflecting and analyzing Protocol Buffers classes. + * + *

    Used by {@link ProtoCoder}, but in a separate file for testing and isolation. + */ +class ProtobufUtil { + /** + * Returns the {@link Descriptor} for the given Protocol Buffers {@link Message}. + * + * @throws IllegalArgumentException if there is an error in Java reflection. + */ + static Descriptor getDescriptorForClass(Class clazz) { + try { + return (Descriptor) clazz.getMethod("getDescriptor").invoke(null); + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * Returns the {@link Descriptor} for the given Protocol Buffers {@link Message} as well as + * every class it can include transitively. + * + * @throws IllegalArgumentException if there is an error in Java reflection. + */ + static Set getRecursiveDescriptorsForClass( + Class clazz, ExtensionRegistry registry) { + Descriptor root = getDescriptorForClass(clazz); + Set descriptors = new HashSet<>(); + recursivelyAddDescriptors(root, descriptors, registry); + return descriptors; + } + + /** + * Recursively walks the given {@link Message} class and verifies that every field or message + * linked in uses the Protocol Buffers proto2 syntax. + */ + static void checkProto2Syntax(Class clazz, ExtensionRegistry registry) { + for (GenericDescriptor d : getRecursiveDescriptorsForClass(clazz, registry)) { + Syntax s = d.getFile().getSyntax(); + checkArgument( + s == Syntax.PROTO2, + "Message %s or one of its dependencies does not use proto2 syntax: %s in file %s", + clazz.getName(), + d.getFullName(), + d.getFile().getName()); + } + } + + /** + * Recursively checks whether the specified class uses any Protocol Buffers fields that cannot + * be deterministically encoded. + * + * @throws NonDeterministicException if the object cannot be encoded deterministically. + */ + static void verifyDeterministic(ProtoCoder coder) throws NonDeterministicException { + Class message = coder.getMessageType(); + ExtensionRegistry registry = coder.getExtensionRegistry(); + Set descriptors = getRecursiveDescriptorsForClass(message, registry); + for (Descriptor d : descriptors) { + for (FieldDescriptor fd : d.getFields()) { + // If there is a transitively reachable Protocol Buffers map field, then this object cannot + // be encoded deterministically. + if (fd.isMapField()) { + String reason = + String.format( + "Protocol Buffers message %s transitively includes Map field %s (from file %s)." 
+ + " Maps cannot be deterministically encoded.", + message.getName(), + fd.getFullName(), + fd.getFile().getFullName()); + throw new NonDeterministicException(coder, reason); + } + } + } + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + // Disable construction of utility class + private ProtobufUtil() {} + + private static void recursivelyAddDescriptors( + Descriptor message, Set descriptors, ExtensionRegistry registry) { + if (descriptors.contains(message)) { + return; + } + descriptors.add(message); + + for (FieldDescriptor f : message.getFields()) { + recursivelyAddDescriptors(f, descriptors, registry); + } + for (FieldDescriptor f : message.getExtensions()) { + recursivelyAddDescriptors(f, descriptors, registry); + } + for (ExtensionInfo info : + registry.getAllImmutableExtensionsByExtendedType(message.getFullName())) { + recursivelyAddDescriptors(info.descriptor, descriptors, registry); + } + for (ExtensionInfo info : + registry.getAllMutableExtensionsByExtendedType(message.getFullName())) { + recursivelyAddDescriptors(info.descriptor, descriptors, registry); + } + } + + private static void recursivelyAddDescriptors( + FieldDescriptor field, Set descriptors, ExtensionRegistry registry) { + switch (field.getType()) { + case BOOL: + case BYTES: + case DOUBLE: + case ENUM: + case FIXED32: + case FIXED64: + case FLOAT: + case INT32: + case INT64: + case SFIXED32: + case SFIXED64: + case SINT32: + case SINT64: + case STRING: + case UINT32: + case UINT64: + // Primitive types do not transitively access anything else. + break; + + case GROUP: + case MESSAGE: + // Recursively adds all the fields from this nested Message. + recursivelyAddDescriptors(field.getMessageType(), descriptors, registry); + break; + + default: + throw new UnsupportedOperationException( + "Unexpected Protocol Buffers field type: " + field.getType()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroIO.java new file mode 100644 index 000000000000..f016b5b47bad --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroIO.java @@ -0,0 +1,810 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.Read.Bounded; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.reflect.ReflectData; + +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * {@link PTransform}s for reading and writing Avro files. + * + *

    To read a {@link PCollection} from one or more Avro files, use + * {@link AvroIO.Read}, specifying {@link AvroIO.Read#from} to specify + * the path of the file(s) to read from (e.g., a local filename or + * filename pattern if running locally, or a Google Cloud Storage + * filename or filename pattern of the form + * {@code "gs://<bucket>/<filepath>"}), and optionally + * {@link AvroIO.Read#named} to specify the name of the pipeline step. + *

    It is required to specify {@link AvroIO.Read#withSchema}. To + * read specific records, such as Avro-generated classes, provide an + * Avro-generated class type. To read {@link GenericRecord GenericRecords}, provide either + * a {@link Schema} object or an Avro schema in a JSON-encoded string form. + * An exception will be thrown if a record doesn't match the specified + * schema. + * + *

    For example: + *

     {@code
    + * Pipeline p = ...;
    + *
    + * // A simple Read of a local file (only runs locally):
+ * PCollection<AvroAutoGenClass> records =
    + *     p.apply(AvroIO.Read.from("/path/to/file.avro")
    + *                        .withSchema(AvroAutoGenClass.class));
    + *
    + * // A Read from a GCS file (runs locally and via the Google Cloud
    + * // Dataflow service):
    + * Schema schema = new Schema.Parser().parse(new File("schema.avsc"));
+ * PCollection<GenericRecord> records =
    + *     p.apply(AvroIO.Read.named("ReadFromAvro")
    + *                        .from("gs://my_bucket/path/to/records-*.avro")
    + *                        .withSchema(schema));
    + * } 
    + * + *

    To write a {@link PCollection} to one or more Avro files, use + * {@link AvroIO.Write}, specifying {@link AvroIO.Write#to} to specify + * the path of the file to write to (e.g., a local filename or sharded + * filename pattern if running locally, or a Google Cloud Storage + * filename or sharded filename pattern of the form + * {@code "gs://<bucket>/<filepath>"}), and optionally + * {@link AvroIO.Write#named} to specify the name of the pipeline step. + *

    It is required to specify {@link AvroIO.Write#withSchema}. To + * write specific records, such as Avro-generated classes, provide an + * Avro-generated class type. To write {@link GenericRecord GenericRecords}, provide either + * a {@link Schema} object or a schema in a JSON-encoded string form. + * An exception will be thrown if a record doesn't match the specified + * schema. + * + *

    For example: + *

     {@code
    + * // A simple Write to a local file (only runs locally):
+ * PCollection<AvroAutoGenClass> records = ...;
    + * records.apply(AvroIO.Write.to("/path/to/file.avro")
    + *                           .withSchema(AvroAutoGenClass.class));
    + *
    + * // A Write to a sharded GCS file (runs locally and via the Google Cloud
    + * // Dataflow service):
    + * Schema schema = new Schema.Parser().parse(new File("schema.avsc"));
+ * PCollection<GenericRecord> records = ...;
    + * records.apply(AvroIO.Write.named("WriteToAvro")
    + *                           .to("gs://my_bucket/path/to/numbers")
    + *                           .withSchema(schema)
    + *                           .withSuffix(".avro"));
    + * } 
    + * + *

    Permissions

    + * Permission requirements depend on the {@link PipelineRunner} that is used to execute the + * Dataflow job. Please refer to the documentation of corresponding {@link PipelineRunner}s for + * more details. + */ +public class AvroIO { + /** + * A root {@link PTransform} that reads from an Avro file (or multiple Avro + * files matching a pattern) and returns a {@link PCollection} containing + * the decoding of each record. + */ + public static class Read { + /** + * Returns a {@link PTransform} with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(GenericRecord.class).named(name); + } + + /** + * Returns a {@link PTransform} that reads from the file(s) + * with the given name or pattern. This can be a local filename + * or filename pattern (if running locally), or a Google Cloud + * Storage filename or filename pattern of the form + * {@code "gs:///"} (if running locally or via + * the Google Cloud Dataflow service). Standard + * Java + * Filesystem glob patterns ("*", "?", "[..]") are supported. + */ + public static Bound from(String filepattern) { + return new Bound<>(GenericRecord.class).from(filepattern); + } + + /** + * Returns a {@link PTransform} that reads Avro file(s) + * containing records whose type is the specified Avro-generated class. + * + * @param the type of the decoded elements, and the elements + * of the resulting {@link PCollection} + */ + public static Bound withSchema(Class type) { + return new Bound<>(type).withSchema(type); + } + + /** + * Returns a {@link PTransform} that reads Avro file(s) + * containing records of the specified schema. + */ + public static Bound withSchema(Schema schema) { + return new Bound<>(GenericRecord.class).withSchema(schema); + } + + /** + * Returns a {@link PTransform} that reads Avro file(s) + * containing records of the specified schema in a JSON-encoded + * string form. + */ + public static Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + /** + * Returns a {@link PTransform} that reads Avro file(s) + * that has GCS path validation on pipeline creation disabled. + * + *

    This can be useful in the case where the GCS input location does + * not exist at the pipeline creation time, but is expected to be available + * at execution time. + */ + public static Bound withoutValidation() { + return new Bound<>(GenericRecord.class).withoutValidation(); + } + + /** + * A {@link PTransform} that reads from an Avro file (or multiple Avro + * files matching a pattern) and returns a bounded {@link PCollection} containing + * the decoding of each record. + * + * @param the type of each of the elements of the resulting + * PCollection + */ + public static class Bound extends PTransform> { + /** The filepattern to read from. */ + @Nullable + final String filepattern; + /** The class type of the records. */ + final Class type; + /** The schema of the input file. */ + @Nullable + final Schema schema; + /** An option to indicate if input validation is desired. Default is true. */ + final boolean validate; + + Bound(Class type) { + this(null, null, type, null, true); + } + + Bound(String name, String filepattern, Class type, Schema schema, boolean validate) { + super(name); + this.filepattern = filepattern; + this.type = type; + this.schema = schema; + this.validate = validate; + } + + /** + * Returns a new {@link PTransform} that's like this one but + * with the given step name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filepattern, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that reads from the file(s) with the given name or pattern. + * (See {@link AvroIO.Read#from} for a description of + * filepatterns.) + * + *

    Does not modify this object. + */ + public Bound from(String filepattern) { + return new Bound<>(name, filepattern, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that reads Avro file(s) containing records whose type is the + * specified Avro-generated class. + * + *

    Does not modify this object. + * + * @param the type of the decoded elements and the elements of + * the resulting PCollection + */ + public Bound withSchema(Class type) { + return new Bound<>(name, filepattern, type, ReflectData.get().getSchema(type), validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that reads Avro file(s) containing records of the specified schema. + * + *

    Does not modify this object. + */ + public Bound withSchema(Schema schema) { + return new Bound<>(name, filepattern, GenericRecord.class, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that reads Avro file(s) containing records of the specified schema + * in a JSON-encoded string form. + * + *

    Does not modify this object. + */ + public Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that has GCS input path validation on pipeline creation disabled. + * + *

    Does not modify this object. + * + *

    This can be useful in the case where the GCS input location does + * not exist at the pipeline creation time, but is expected to be + * available at execution time. + */ + public Bound withoutValidation() { + return new Bound<>(name, filepattern, type, schema, false); + } + + @Override + public PCollection apply(PInput input) { + if (filepattern == null) { + throw new IllegalStateException( + "need to set the filepattern of an AvroIO.Read transform"); + } + if (schema == null) { + throw new IllegalStateException("need to set the schema of an AvroIO.Read transform"); + } + if (validate) { + try { + checkState( + !IOChannelUtils.getFactory(filepattern).match(filepattern).isEmpty(), + "Unable to find any files matching %s", + filepattern); + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to validate %s", filepattern), e); + } + } + + @SuppressWarnings("unchecked") + Bounded read = + type == GenericRecord.class + ? (Bounded) com.google.cloud.dataflow.sdk.io.Read.from( + AvroSource.from(filepattern).withSchema(schema)) + : com.google.cloud.dataflow.sdk.io.Read.from( + AvroSource.from(filepattern).withSchema(type)); + + PCollection pcol = input.getPipeline().apply("Read", read); + // Honor the default output coder that would have been used by this PTransform. + pcol.setCoder(getDefaultOutputCoder()); + return pcol; + } + + @Override + protected Coder getDefaultOutputCoder() { + return AvroCoder.of(type, schema); + } + + public String getFilepattern() { + return filepattern; + } + + public Schema getSchema() { + return schema; + } + + public boolean needsValidation() { + return validate; + } + } + + /** Disallow construction of utility class. */ + private Read() {} + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A root {@link PTransform} that writes a {@link PCollection} to an Avro file (or + * multiple Avro files matching a sharding pattern). + */ + public static class Write { + /** + * Returns a {@link PTransform} with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(GenericRecord.class).named(name); + } + + /** + * Returns a {@link PTransform} that writes to the file(s) + * with the given prefix. This can be a local filename + * (if running locally), or a Google Cloud Storage filename of + * the form {@code "gs:///"} + * (if running locally or via the Google Cloud Dataflow service). + * + *

    The files written will begin with this prefix, followed by + * a shard identifier (see {@link Bound#withNumShards}, and end + * in a common extension, if given by {@link Bound#withSuffix}. + */ + public static Bound to(String prefix) { + return new Bound<>(GenericRecord.class).to(prefix); + } + + /** + * Returns a {@link PTransform} that writes to the file(s) with the + * given filename suffix. + */ + public static Bound withSuffix(String filenameSuffix) { + return new Bound<>(GenericRecord.class).withSuffix(filenameSuffix); + } + + /** + * Returns a {@link PTransform} that uses the provided shard count. + * + *

    Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + */ + public static Bound withNumShards(int numShards) { + return new Bound<>(GenericRecord.class).withNumShards(numShards); + } + + /** + * Returns a {@link PTransform} that uses the given shard name + * template. + * + *

    See {@link ShardNameTemplate} for a description of shard templates. + */ + public static Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(GenericRecord.class).withShardNameTemplate(shardTemplate); + } + + /** + * Returns a {@link PTransform} that forces a single file as + * output. + * + *

    Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + */ + public static Bound withoutSharding() { + return new Bound<>(GenericRecord.class).withoutSharding(); + } + + /** + * Returns a {@link PTransform} that writes Avro file(s) + * containing records whose type is the specified Avro-generated class. + * + * @param the type of the elements of the input PCollection + */ + public static Bound withSchema(Class type) { + return new Bound<>(type).withSchema(type); + } + + /** + * Returns a {@link PTransform} that writes Avro file(s) + * containing records of the specified schema. + */ + public static Bound withSchema(Schema schema) { + return new Bound<>(GenericRecord.class).withSchema(schema); + } + + /** + * Returns a {@link PTransform} that writes Avro file(s) + * containing records of the specified schema in a JSON-encoded + * string form. + */ + public static Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + /** + * Returns a {@link PTransform} that writes Avro file(s) that has GCS path validation on + * pipeline creation disabled. + * + *

    This can be useful in the case where the GCS output location does + * not exist at the pipeline creation time, but is expected to be available + * at execution time. + */ + public static Bound withoutValidation() { + return new Bound<>(GenericRecord.class).withoutValidation(); + } + + /** + * A {@link PTransform} that writes a bounded {@link PCollection} to an Avro file (or + * multiple Avro files matching a sharding pattern). + * + * @param the type of each of the elements of the input PCollection + */ + public static class Bound extends PTransform, PDone> { + /** The filename to write to. */ + @Nullable + final String filenamePrefix; + /** Suffix to use for each filename. */ + final String filenameSuffix; + /** Requested number of shards. 0 for automatic. */ + final int numShards; + /** Shard template string. */ + final String shardTemplate; + /** The class type of the records. */ + final Class type; + /** The schema of the output file. */ + @Nullable + final Schema schema; + /** An option to indicate if output validation is desired. Default is true. */ + final boolean validate; + + Bound(Class type) { + this(null, null, "", 0, ShardNameTemplate.INDEX_OF_MAX, type, null, true); + } + + Bound( + String name, + String filenamePrefix, + String filenameSuffix, + int numShards, + String shardTemplate, + Class type, + Schema schema, + boolean validate) { + super(name); + this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + this.numShards = numShards; + this.shardTemplate = shardTemplate; + this.type = type; + this.schema = schema; + this.validate = validate; + } + + /** + * Returns a new {@link PTransform} that's like this one but + * with the given step name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>( + name, filenamePrefix, filenameSuffix, numShards, shardTemplate, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that writes to the file(s) with the given filename prefix. + * + *

    See {@link AvroIO.Write#to(String)} for more information + * about filenames. + * + *

    Does not modify this object. + */ + public Bound to(String filenamePrefix) { + validateOutputComponent(filenamePrefix); + return new Bound<>( + name, filenamePrefix, filenameSuffix, numShards, shardTemplate, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that writes to the file(s) with the given filename suffix. + * + *

    See {@link ShardNameTemplate} for a description of shard templates. + * + *

    Does not modify this object. + */ + public Bound withSuffix(String filenameSuffix) { + validateOutputComponent(filenameSuffix); + return new Bound<>( + name, filenamePrefix, filenameSuffix, numShards, shardTemplate, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that uses the provided shard count. + * + *

    Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + *

    Does not modify this object. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + * @see ShardNameTemplate + */ + public Bound withNumShards(int numShards) { + Preconditions.checkArgument(numShards >= 0); + return new Bound<>( + name, filenamePrefix, filenameSuffix, numShards, shardTemplate, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that uses the given shard name template. + * + *

    Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>( + name, filenamePrefix, filenameSuffix, numShards, shardTemplate, type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that forces a single file as output. + * + *

    This is a shortcut for + * {@code .withNumShards(1).withShardNameTemplate("")} + * + *

    Does not modify this object. + */ + public Bound withoutSharding() { + return new Bound<>(name, filenamePrefix, filenameSuffix, 1, "", type, schema, validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that writes to Avro file(s) containing records whose type is the + * specified Avro-generated class. + * + *

    Does not modify this object. + * + * @param the type of the elements of the input PCollection + */ + public Bound withSchema(Class type) { + return new Bound<>( + name, + filenamePrefix, + filenameSuffix, + numShards, + shardTemplate, + type, + ReflectData.get().getSchema(type), + validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that writes to Avro file(s) containing records of the specified + * schema. + * + *

    Does not modify this object. + */ + public Bound withSchema(Schema schema) { + return new Bound<>( + name, + filenamePrefix, + filenameSuffix, + numShards, + shardTemplate, + GenericRecord.class, + schema, + validate); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that writes to Avro file(s) containing records of the specified + * schema in a JSON-encoded string form. + * + *

    Does not modify this object. + */ + public Bound withSchema(String schema) { + return withSchema((new Schema.Parser()).parse(schema)); + } + + /** + * Returns a new {@link PTransform} that's like this one but + * that has GCS output path validation on pipeline creation disabled. + * + *

    Does not modify this object. + * + *

    This can be useful in the case where the GCS output location does + * not exist at the pipeline creation time, but is expected to be + * available at execution time. + */ + public Bound withoutValidation() { + return new Bound<>( + name, filenamePrefix, filenameSuffix, numShards, shardTemplate, type, schema, false); + } + + @Override + public PDone apply(PCollection input) { + if (filenamePrefix == null) { + throw new IllegalStateException( + "need to set the filename prefix of an AvroIO.Write transform"); + } + if (schema == null) { + throw new IllegalStateException("need to set the schema of an AvroIO.Write transform"); + } + + // Note that custom sinks currently do not expose sharding controls. + // Thus pipeline runner writers need to individually add support internally to + // apply user requested sharding limits. + return input.apply( + "Write", + com.google.cloud.dataflow.sdk.io.Write.to( + new AvroSink<>( + filenamePrefix, filenameSuffix, shardTemplate, AvroCoder.of(type, schema)))); + } + + /** + * Returns the current shard name template string. + */ + public String getShardNameTemplate() { + return shardTemplate; + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + public String getFilenamePrefix() { + return filenamePrefix; + } + + public String getShardTemplate() { + return shardTemplate; + } + + public int getNumShards() { + return numShards; + } + + public String getFilenameSuffix() { + return filenameSuffix; + } + + public Class getType() { + return type; + } + + public Schema getSchema() { + return schema; + } + + public boolean needsValidation() { + return validate; + } + } + + /** Disallow construction of utility class. */ + private Write() {} + } + + // Pattern which matches old-style shard output patterns, which are now + // disallowed. + private static final Pattern SHARD_OUTPUT_PATTERN = Pattern.compile("@([0-9]+|\\*)"); + + private static void validateOutputComponent(String partialFilePattern) { + Preconditions.checkArgument( + !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(), + "Output name components are not allowed to contain @* or @N patterns: " + + partialFilePattern); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Disallow construction of utility class. */ + private AvroIO() {} + + /** + * A {@link FileBasedSink} for Avro files. + */ + @VisibleForTesting + static class AvroSink extends FileBasedSink { + private final AvroCoder coder; + + @VisibleForTesting + AvroSink( + String baseOutputFilename, String extension, String fileNameTemplate, AvroCoder coder) { + super(baseOutputFilename, extension, fileNameTemplate); + this.coder = coder; + } + + @Override + public FileBasedSink.FileBasedWriteOperation createWriteOperation(PipelineOptions options) { + return new AvroWriteOperation<>(this, coder); + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriteOperation + * FileBasedWriteOperation} for Avro files. + */ + private static class AvroWriteOperation extends FileBasedWriteOperation { + private final AvroCoder coder; + + private AvroWriteOperation(AvroSink sink, AvroCoder coder) { + super(sink); + this.coder = coder; + } + + @Override + public FileBasedWriter createWriter(PipelineOptions options) throws Exception { + return new AvroWriter<>(this, coder); + } + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriter FileBasedWriter} + * for Avro files. 
+ */ + private static class AvroWriter extends FileBasedWriter { + private final AvroCoder coder; + private DataFileWriter dataFileWriter; + + public AvroWriter(FileBasedWriteOperation writeOperation, AvroCoder coder) { + super(writeOperation); + this.mimeType = MimeTypes.BINARY; + this.coder = coder; + } + + @SuppressWarnings("deprecation") // uses internal test functionality. + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + dataFileWriter = new DataFileWriter<>(coder.createDatumWriter()); + dataFileWriter.create(coder.getSchema(), Channels.newOutputStream(channel)); + } + + @Override + public void write(T value) throws Exception { + dataFileWriter.append(value); + } + + @Override + protected void writeFooter() throws Exception { + dataFileWriter.flush(); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroSource.java new file mode 100644 index 000000000000..297663e96fdd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/AvroSource.java @@ -0,0 +1,647 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.util.AvroUtils; +import com.google.cloud.dataflow.sdk.util.AvroUtils.AvroMetadata; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Preconditions; + +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; +import org.apache.commons.compress.compressors.snappy.SnappyCompressorInputStream; +import org.apache.commons.compress.compressors.xz.XZCompressorInputStream; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.PushbackInputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.Collection; +import java.util.zip.Inflater; +import java.util.zip.InflaterInputStream; + +// CHECKSTYLE.OFF: JavadocStyle +/** + * A {@link FileBasedSource} for reading Avro files. + * + *

    To read a {@link PCollection} of objects from one or more Avro files, use + * {@link AvroSource#from} to specify the path(s) of the files to read. The {@link AvroSource} that + * is returned will read objects of type {@link GenericRecord} with the schema(s) that were written + * at file creation. To further configure the {@link AvroSource} to read with a user-defined schema, + * or to return records of a type other than {@link GenericRecord}, use + * {@link AvroSource#withSchema(Schema)} (using an Avro {@link Schema}), + * {@link AvroSource#withSchema(String)} (using a JSON schema), or + * {@link AvroSource#withSchema(Class)} (to return objects of the Avro-generated class specified). + * + *

    An {@link AvroSource} can be read from using the {@link Read} transform. For example: + * + *

    + * {@code
+ * AvroSource<MyType> source = AvroSource.from(file.toPath()).withSchema(MyType.class);
+ * PCollection<MyType> records = Read.from(source);
    + * }
    + * 
    + * + *

    The {@link AvroSource#readFromFileWithClass(String, Class)} method is a convenience method + * that returns a read transform. For example: + * + *

    + * {@code
+ * PCollection<MyType> records = AvroSource.readFromFileWithClass(file.toPath(), MyType.class);
    + * }
    + * 
    + * + *

    This class's implementation is based on the Avro 1.7.7 specification and implements + * parsing of some parts of Avro Object Container Files. The rationale for doing so is that the Avro + * API does not provide efficient ways of computing the precise offsets of blocks within a file, + * which is necessary to support dynamic work rebalancing. However, whenever it is possible to use + * the Avro API in a way that supports maintaining precise offsets, this class uses the Avro API. + * + *

    Avro Object Container files store records in blocks. Each block contains a collection of + * records. Blocks may be encoded (e.g., with bzip2, deflate, snappy, etc.). Blocks are delineated + * from one another by a 16-byte sync marker. + * + *

    An {@link AvroSource} for a subrange of a single file contains records in the blocks such that + * the start offset of the block is greater than or equal to the start offset of the source and less + * than the end offset of the source. + * + *
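As a minimal sketch of the half-open range rule above (the helper name is hypothetical and not part of the SDK), a block is assigned to a subrange source exactly when its starting offset falls in [sourceStart, sourceEnd):

  // Hypothetical helper illustrating the block-assignment rule for a subrange source.
  static boolean blockBelongsToSubrange(long blockStart, long sourceStart, long sourceEnd) {
    // Only the block's start offset matters; a block may extend past sourceEnd and still belong here.
    return blockStart >= sourceStart && blockStart < sourceEnd;
  }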

    To use XZ-encoded Avro files, please include an explicit dependency on {@code xz-1.5.jar}, + * which has been marked as optional in the Maven {@code sdk/pom.xml} for Google Cloud Dataflow: + * + *

    {@code
+ * <dependency>
+ *   <groupId>org.tukaani</groupId>
+ *   <artifactId>xz</artifactId>
+ *   <version>1.5</version>
+ * </dependency>
    + * }
    + * + *

    Permissions

    + *

    Permission requirements depend on the {@link PipelineRunner} that is used to execute the + * Dataflow job. Please refer to the documentation of corresponding {@link PipelineRunner}s for + * more details. + * + * @param The type of records to be read from the source. + */ +// CHECKSTYLE.ON: JavadocStyle +@Experimental(Experimental.Kind.SOURCE_SINK) +public class AvroSource extends BlockBasedSource { + // Default minimum bundle size (chosen as two default-size Avro blocks to attempt to + // ensure that every source has at least one block of records). + // The default sync interval is 64k. + static final long DEFAULT_MIN_BUNDLE_SIZE = 2 * DataFileConstants.DEFAULT_SYNC_INTERVAL; + + // The JSON schema used to encode records. + private final String readSchemaString; + + // The JSON schema that was used to write the source Avro file (may differ from the schema we will + // use to read from it). + private final String fileSchemaString; + + // The type of the records contained in the file. + private final Class type; + + // The following metadata fields are not user-configurable. They are extracted from the object + // container file header upon subsource creation. + + // The codec used to encode the blocks in the Avro file. String value drawn from those in + // https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html + private final String codec; + + // The object container file's 16-byte sync marker. + private final byte[] syncMarker; + + // Default output coder, lazily initialized. + private transient AvroCoder coder = null; + + // Schema of the file, lazily initialized. + private transient Schema fileSchema; + + // Schema used to encode records, lazily initialized. + private transient Schema readSchema; + + /** + * Creates a {@link Read} transform that will read from an {@link AvroSource} that is configured + * to read records of the given type from a file pattern. + */ + public static Read.Bounded readFromFileWithClass(String filePattern, Class clazz) { + return Read.from(new AvroSource(filePattern, DEFAULT_MIN_BUNDLE_SIZE, + ReflectData.get().getSchema(clazz).toString(), clazz, null, null)); + } + + /** + * Creates an {@link AvroSource} that reads from the given file name or pattern ("glob"). The + * returned source can be further configured by calling {@link #withSchema} to return a type other + * than {@link GenericRecord}. + */ + public static AvroSource from(String fileNameOrPattern) { + return new AvroSource<>( + fileNameOrPattern, DEFAULT_MIN_BUNDLE_SIZE, null, GenericRecord.class, null, null); + } + + /** + * Returns an {@link AvroSource} that's like this one but reads files containing records that + * conform to the given schema. + * + *

    Does not modify this object. + */ + public AvroSource withSchema(String schema) { + return new AvroSource<>( + getFileOrPatternSpec(), getMinBundleSize(), schema, GenericRecord.class, codec, syncMarker); + } + + /** + * Returns an {@link AvroSource} that's like this one but reads files containing records that + * conform to the given schema. + * + *

    Does not modify this object. + */ + public AvroSource withSchema(Schema schema) { + return new AvroSource<>(getFileOrPatternSpec(), getMinBundleSize(), schema.toString(), + GenericRecord.class, codec, syncMarker); + } + + /** + * Returns an {@link AvroSource} that's like this one but reads files containing records of the + * type of the given class. + * + *

    Does not modify this object. + */ + public AvroSource withSchema(Class clazz) { + return new AvroSource(getFileOrPatternSpec(), getMinBundleSize(), + ReflectData.get().getSchema(clazz).toString(), clazz, codec, syncMarker); + } + + /** + * Returns an {@link AvroSource} that's like this one but uses the supplied minimum bundle size. + * Refer to {@link OffsetBasedSource} for a description of {@code minBundleSize} and its use. + * + *

    Does not modify this object. + */ + public AvroSource withMinBundleSize(long minBundleSize) { + return new AvroSource( + getFileOrPatternSpec(), minBundleSize, readSchemaString, type, codec, syncMarker); + } + + private AvroSource(String fileNameOrPattern, long minBundleSize, String schema, Class type, + String codec, byte[] syncMarker) { + super(fileNameOrPattern, minBundleSize); + this.readSchemaString = schema; + this.codec = codec; + this.syncMarker = syncMarker; + this.type = type; + this.fileSchemaString = null; + } + + private AvroSource(String fileName, long minBundleSize, long startOffset, long endOffset, + String schema, Class type, String codec, byte[] syncMarker, String fileSchema) { + super(fileName, minBundleSize, startOffset, endOffset); + this.readSchemaString = schema; + this.codec = codec; + this.syncMarker = syncMarker; + this.type = type; + this.fileSchemaString = fileSchema; + } + + @Override + public void validate() { + // AvroSource objects do not need to be configured with more than a file pattern. Overridden to + // make this explicit. + super.validate(); + } + + @Override + public BlockBasedSource createForSubrangeOfFile(String fileName, long start, long end) { + byte[] syncMarker = this.syncMarker; + String codec = this.codec; + String readSchemaString = this.readSchemaString; + String fileSchemaString = this.fileSchemaString; + // codec and syncMarker are initially null when the source is created, as they differ + // across input files and must be read from the file. Here, when we are creating a source + // for a subrange of a file, we can initialize these values. When the resulting AvroSource + // is further split, they do not need to be read again. + if (codec == null || syncMarker == null || fileSchemaString == null) { + AvroMetadata metadata; + try { + Collection files = FileBasedSource.expandFilePattern(fileName); + Preconditions.checkArgument(files.size() <= 1, "More than 1 file matched %s"); + metadata = AvroUtils.readMetadataFromFile(fileName); + } catch (IOException e) { + throw new RuntimeException("Error reading metadata from file " + fileName, e); + } + codec = metadata.getCodec(); + syncMarker = metadata.getSyncMarker(); + fileSchemaString = metadata.getSchemaString(); + // If the source was created with a null schema, use the schema that we read from the file's + // metadata. + if (readSchemaString == null) { + readSchemaString = metadata.getSchemaString(); + } + } + return new AvroSource(fileName, getMinBundleSize(), start, end, readSchemaString, type, + codec, syncMarker, fileSchemaString); + } + + @Override + protected BlockBasedReader createSingleFileReader(PipelineOptions options) { + return new AvroReader(this); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public AvroCoder getDefaultOutputCoder() { + if (coder == null) { + Schema.Parser parser = new Schema.Parser(); + coder = AvroCoder.of(type, parser.parse(readSchemaString)); + } + return coder; + } + + public String getSchema() { + return readSchemaString; + } + + private Schema getReadSchema() { + if (readSchemaString == null) { + return null; + } + + // If the schema has not been parsed, parse it. + if (readSchema == null) { + Schema.Parser parser = new Schema.Parser(); + readSchema = parser.parse(readSchemaString); + } + return readSchema; + } + + private Schema getFileSchema() { + if (fileSchemaString == null) { + return null; + } + + // If the schema has not been parsed, parse it. 
+ if (fileSchema == null) { + Schema.Parser parser = new Schema.Parser(); + fileSchema = parser.parse(fileSchemaString); + } + return fileSchema; + } + + private byte[] getSyncMarker() { + return syncMarker; + } + + private String getCodec() { + return codec; + } + + private DatumReader createDatumReader() { + Schema readSchema = getReadSchema(); + Schema fileSchema = getFileSchema(); + Preconditions.checkNotNull( + readSchema, "No read schema has been initialized for source %s", this); + Preconditions.checkNotNull( + fileSchema, "No file schema has been initialized for source %s", this); + if (type == GenericRecord.class) { + return new GenericDatumReader<>(fileSchema, readSchema); + } else { + return new ReflectDatumReader<>(fileSchema, readSchema); + } + } + + /** + * A {@link BlockBasedSource.Block} of Avro records. + * + * @param The type of records stored in the block. + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + static class AvroBlock extends Block { + // The number of records in the block. + private final long numRecords; + + // The current record in the block. + private T currentRecord; + + // The index of the current record in the block. + private long currentRecordIndex = 0; + + // A DatumReader to read records from the block. + private final DatumReader reader; + + // A BinaryDecoder used by the reader to decode records. + private final BinaryDecoder decoder; + + /** + * Decodes a byte array as an InputStream. The byte array may be compressed using some + * codec. Reads from the returned stream will result in decompressed bytes. + * + *

    This supports the same codecs as Avro's {@link CodecFactory}, namely those defined in + * {@link DataFileConstants}. + * + *

      + *
    • "snappy" : Google's Snappy compression + *
    • "deflate" : deflate compression + *
    • "bzip2" : Bzip2 compression + *
    • "xz" : xz compression + *
    • "null" (the string, not the value): Uncompressed data + *
    + */ + private static InputStream decodeAsInputStream(byte[] data, String codec) throws IOException { + ByteArrayInputStream byteStream = new ByteArrayInputStream(data); + switch (codec) { + case DataFileConstants.SNAPPY_CODEC: + return new SnappyCompressorInputStream(byteStream); + case DataFileConstants.DEFLATE_CODEC: + // nowrap == true: Do not expect ZLIB header or checksum, as Avro does not write them. + Inflater inflater = new Inflater(true); + return new InflaterInputStream(byteStream, inflater); + case DataFileConstants.XZ_CODEC: + return new XZCompressorInputStream(byteStream); + case DataFileConstants.BZIP2_CODEC: + return new BZip2CompressorInputStream(byteStream); + case DataFileConstants.NULL_CODEC: + return byteStream; + default: + throw new IllegalArgumentException("Unsupported codec: " + codec); + } + } + + AvroBlock(byte[] data, long numRecords, AvroSource source) throws IOException { + this.numRecords = numRecords; + this.reader = source.createDatumReader(); + this.decoder = + DecoderFactory.get().binaryDecoder(decodeAsInputStream(data, source.getCodec()), null); + } + + @Override + public T getCurrentRecord() { + return currentRecord; + } + + @Override + public boolean readNextRecord() throws IOException { + if (currentRecordIndex >= numRecords) { + return false; + } + currentRecord = reader.read(null, decoder); + currentRecordIndex++; + return true; + } + + @Override + public double getFractionOfBlockConsumed() { + return ((double) currentRecordIndex) / numRecords; + } + } + + /** + * A {@link BlockBasedSource.BlockBasedReader} for reading blocks from Avro files. + * + *

An Avro Object Container File consists of a header followed by a 16-byte sync marker + * and then a sequence of blocks, where each block begins with two encoded longs representing + * the total number of records in the block and the block's size in bytes, followed by the + * block's (optionally-encoded) records. Each block is terminated by a 16-byte sync marker. + *
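As a rough sketch of that header layout (assuming the leading bytes of a block are already available in a byte array), the two longs could be decoded with Avro's BinaryDecoder, which the reader below also uses:

  // Sketch: decode the record count and block size that begin an Avro block.
  static long[] readBlockHeader(byte[] headerBytes) throws IOException {
    BinaryDecoder d = DecoderFactory.get().binaryDecoder(headerBytes, null);
    long numRecords = d.readLong();      // number of records in the block
    long blockSizeBytes = d.readLong();  // size of the block's data in bytes
    return new long[] {numRecords, blockSizeBytes};
  }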

    Here, we consider the sync marker that precedes a block to be its offset, as this allows + * a reader that begins reading at that offset to detect the sync marker and the beginning of + * the block. + * + * @param The type of records contained in the block. + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + public static class AvroReader extends BlockBasedReader { + // The current block. + private AvroBlock currentBlock; + + // Offset of the block. + private long currentBlockOffset = 0; + + // Size of the current block. + private long currentBlockSizeBytes = 0; + + // Current offset within the stream. + private long currentOffset = 0; + + // Stream used to read from the underlying file. + // A pushback stream is used to restore bytes buffered during seeking/decoding. + private PushbackInputStream stream; + + // Small buffer for reading encoded values from the stream. + // The maximum size of an encoded long is 10 bytes, and this buffer will be used to read two. + private final byte[] readBuffer = new byte[20]; + + // Decoder to decode binary-encoded values from the buffer. + private BinaryDecoder decoder; + + /** + * Reads Avro records of type {@code T} from the specified source. + */ + public AvroReader(AvroSource source) { + super(source); + } + + @Override + public synchronized AvroSource getCurrentSource() { + return (AvroSource) super.getCurrentSource(); + } + + @Override + public boolean readNextBlock() throws IOException { + // The next block in the file is after the first sync marker that can be read starting from + // the current offset. First, we seek past the next sync marker, if it exists. After a sync + // marker is the start of a block. A block begins with the number of records contained in + // the block, encoded as a long, followed by the size of the block in bytes, encoded as a + // long. The currentOffset after this method should be last byte after this block, and the + // currentBlockOffset should be the start of the sync marker before this block. + + // Seek to the next sync marker, if one exists. + currentOffset += advancePastNextSyncMarker(stream, getCurrentSource().getSyncMarker()); + + // The offset of the current block includes its preceding sync marker. + currentBlockOffset = currentOffset - getCurrentSource().getSyncMarker().length; + + // Read a small buffer to parse the block header. + // We cannot use a BinaryDecoder to do this directly from the stream because a BinaryDecoder + // internally buffers data and we only want to read as many bytes from the stream as the size + // of the header. Though BinaryDecoder#InputStream returns an input stream that is aware of + // its internal buffering, we would have to re-wrap this input stream to seek for the next + // block in the file. + int read = stream.read(readBuffer); + // We reached the last sync marker in the file. + if (read <= 0) { + return false; + } + decoder = DecoderFactory.get().binaryDecoder(readBuffer, decoder); + long numRecords = decoder.readLong(); + long blockSize = decoder.readLong(); + + // The decoder buffers data internally, but since we know the size of the stream the + // decoder has constructed from the readBuffer, the number of bytes available in the + // input stream is equal to the number of unconsumed bytes. + int headerSize = readBuffer.length - decoder.inputStream().available(); + stream.unread(readBuffer, headerSize, read - headerSize); + + // Create the current block by reading blockSize bytes. 
Block sizes permitted by the Avro + // specification are [32, 2^30], so this narrowing is ok. + byte[] data = new byte[(int) blockSize]; + stream.read(data); + currentBlock = new AvroBlock<>(data, numRecords, getCurrentSource()); + currentBlockSizeBytes = blockSize; + + // Update current offset with the number of bytes we read to get the next block. + currentOffset += headerSize + blockSize; + return true; + } + + @Override + public AvroBlock getCurrentBlock() { + return currentBlock; + } + + @Override + public long getCurrentBlockOffset() { + return currentBlockOffset; + } + + @Override + public long getCurrentBlockSize() { + return currentBlockSizeBytes; + } + + /** + * Creates a {@link PushbackInputStream} that has a large enough pushback buffer to be able + * to push back the syncBuffer and the readBuffer. + */ + private PushbackInputStream createStream(ReadableByteChannel channel) { + return new PushbackInputStream( + Channels.newInputStream(channel), + getCurrentSource().getSyncMarker().length + readBuffer.length); + } + + /** + * Starts reading from the provided channel. Assumes that the channel is already seeked to + * the source's start offset. + */ + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + stream = createStream(channel); + currentOffset = getCurrentSource().getStartOffset(); + } + + /** + * Advances to the first byte after the next occurrence of the sync marker in the + * stream when reading from the current offset. Returns the number of bytes consumed + * from the stream. Note that this method requires a PushbackInputStream with a buffer + * at least as big as the marker it is seeking for. + */ + static long advancePastNextSyncMarker(PushbackInputStream stream, byte[] syncMarker) + throws IOException { + Seeker seeker = new Seeker(syncMarker); + byte[] syncBuffer = new byte[syncMarker.length]; + long totalBytesConsumed = 0; + // Seek until either a sync marker is found or we reach the end of the file. + int mark = -1; // Position of the last byte in the sync marker. + int read; // Number of bytes read. + do { + read = stream.read(syncBuffer); + if (read >= 0) { + mark = seeker.find(syncBuffer, read); + // Update the currentOffset with the number of bytes read. + totalBytesConsumed += read; + } + } while (mark < 0 && read > 0); + + // If the sync marker was found, unread block data and update the current offsets. + if (mark >= 0) { + // The current offset after this call should be just past the sync marker, so we should + // unread the remaining buffer contents and update the currentOffset accordingly. + stream.unread(syncBuffer, mark + 1, read - (mark + 1)); + totalBytesConsumed = totalBytesConsumed - (read - (mark + 1)); + } + return totalBytesConsumed; + } + + /** + * A {@link Seeker} looks for a given marker within a byte buffer. Uses naive string matching + * with a sliding window, as sync markers are small and random. + */ + static class Seeker { + // The marker to search for. + private byte[] marker; + + // Buffer used for the sliding window. + private byte[] searchBuffer; + + // Number of bytes available to be matched in the buffer. + private int available = 0; + + /** + * Create a {@link Seeker} that looks for the given marker. + */ + public Seeker(byte[] marker) { + this.marker = marker; + this.searchBuffer = new byte[marker.length]; + } + + /** + * Find the marker in the byte buffer. Returns the index of the end of the marker in the + * buffer. If the marker is not found, returns -1. + * + *

    State is maintained between calls. If the marker was partially matched, a subsequent + * call to find will resume matching the marker. + * + * @param buffer + * @return the index of the end of the marker within the buffer, or -1 if the buffer was not + * found. + */ + public int find(byte[] buffer, int length) { + for (int i = 0; i < length; i++) { + System.arraycopy(searchBuffer, 1, searchBuffer, 0, searchBuffer.length - 1); + searchBuffer[searchBuffer.length - 1] = buffer[i]; + available = Math.min(available + 1, searchBuffer.length); + if (ByteBuffer.wrap(searchBuffer, searchBuffer.length - available, available) + .equals(ByteBuffer.wrap(marker))) { + available = 0; + return i; + } + } + return -1; + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BigQueryIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BigQueryIO.java new file mode 100644 index 000000000000..ab7df6f08147 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BigQueryIO.java @@ -0,0 +1,1499 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.api.client.json.JsonFactory; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.QueryRequest; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.BigQueryTableInserter; +import com.google.cloud.dataflow.sdk.util.BigQueryTableRowIterator; +import 
com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Reshuffle; +import com.google.cloud.dataflow.sdk.util.SystemDoFnInternal; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.hadoop.util.ApiErrorExtractor; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * {@link PTransform}s for reading and writing + * BigQuery tables. + * + *

    Table References

    + *

    A fully-qualified BigQuery table name consists of three components: + *

      + *
    • {@code projectId}: the Cloud project id (defaults to + * {@link GcpOptions#getProject()}). + *
    • {@code datasetId}: the BigQuery dataset id, unique within a project. + *
    • {@code tableId}: a table id, unique within a dataset. + *
    + * + *

    BigQuery table references are stored as a {@link TableReference}, which comes + * from the + * BigQuery Java Client API. + * Tables can be referred to as Strings, with or without the {@code projectId}. + * A helper function is provided ({@link BigQueryIO#parseTableSpec(String)}) + * that parses the following string forms into a {@link TableReference}: + * + *

      + *
    • [{@code project_id}]:[{@code dataset_id}].[{@code table_id}] + *
    • [{@code dataset_id}].[{@code table_id}] + *
    + * + *
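For example (a brief sketch; the project, dataset, and table names are placeholders), both forms listed above parse to the same dataset and table ids:

  // Sketch: parsing fully-qualified and project-less table specifications.
  TableReference ref = BigQueryIO.parseTableSpec("my-project:my_dataset.my_table");
  // ref.getProjectId() is "my-project", ref.getDatasetId() is "my_dataset", ref.getTableId() is "my_table".
  TableReference sameTable = BigQueryIO.parseTableSpec("my_dataset.my_table");
  // sameTable.getProjectId() is null; the default project id is filled in when the pipeline is validated.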

    Reading

    + *

    To read from a BigQuery table, apply a {@link BigQueryIO.Read} transformation. + * This produces a {@link PCollection} of {@link TableRow TableRows} as output: + *

    {@code
+ * PCollection<TableRow> shakespeare = pipeline.apply(
    + *     BigQueryIO.Read.named("Read")
    + *                    .from("clouddataflow-readonly:samples.weather_stations"));
    + * }
    + * + *

    See {@link TableRow} for more information on the {@link TableRow} object. + * + *

    Users may provide a query to read from rather than reading all of a BigQuery table. If + * specified, the result obtained by executing the specified query will be used as the data of the + * input transform. + * + *

    {@code
+ * PCollection<TableRow> shakespeare = pipeline.apply(
    + *     BigQueryIO.Read.named("Read")
    + *                    .fromQuery("SELECT year, mean_temp FROM samples.weather_stations"));
    + * }
    + * + *

    When creating a BigQuery input transform, users should provide either a query or a table. + * Pipeline construction will fail with a validation error if neither or both are specified. + * + *

    Writing

    + *

    To write to a BigQuery table, apply a {@link BigQueryIO.Write} transformation. + * This consumes a {@link PCollection} of {@link TableRow TableRows} as input. + *

    {@code
+ * PCollection<TableRow> quotes = ...
    + *
+ * List<TableFieldSchema> fields = new ArrayList<>();
    + * fields.add(new TableFieldSchema().setName("source").setType("STRING"));
    + * fields.add(new TableFieldSchema().setName("quote").setType("STRING"));
    + * TableSchema schema = new TableSchema().setFields(fields);
    + *
    + * quotes.apply(BigQueryIO.Write
    + *     .named("Write")
    + *     .to("my-project:output.output_table")
    + *     .withSchema(schema)
    + *     .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
    + * }
    + * + *

    See {@link BigQueryIO.Write} for details on how to specify if a write should + * append to an existing table, replace the table, or verify that the table is + * empty. Note that the dataset being written to must already exist. Write + * dispositions are not supported in streaming mode. + * + *
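As a sketch of combining the dispositions described above (reusing the {@code quotes} collection and {@code schema} from the earlier example; the table name is a placeholder), a write that appends rows, creating the table on first use, could look like:

  // Sketch: append rows, creating the table if it does not yet exist.
  quotes.apply(BigQueryIO.Write
      .named("AppendQuotes")
      .to("my-project:output.output_table")
      .withSchema(schema)
      .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
      .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND));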

    Sharding BigQuery output tables

    + *

    A common use case is to dynamically generate BigQuery table names based on + * the current window. To support this, + * {@link BigQueryIO.Write#to(SerializableFunction)} + * accepts a function mapping the current window to a tablespec. For example, + * here's code that outputs daily tables to BigQuery: + *

    {@code
+ * PCollection<TableRow> quotes = ...
    + * quotes.apply(Window.into(CalendarWindows.days(1)))
    + *       .apply(BigQueryIO.Write
    + *         .named("Write")
    + *         .withSchema(schema)
+ *         .to(new SerializableFunction<BoundedWindow, String>() {
    + *           public String apply(BoundedWindow window) {
    + *             // The cast below is safe because CalendarWindows.days(1) produces IntervalWindows.
    + *             String dayString = DateTimeFormat.forPattern("yyyy_MM_dd")
    + *                  .withZone(DateTimeZone.UTC)
    + *                  .print(((IntervalWindow) window).start());
    + *             return "my-project:output.output_table_" + dayString;
    + *           }
    + *         }));
    + * }
    + * + *

    Per-window tables are not yet supported in batch mode. + * + *

    Permissions

    + *

    Permission requirements depend on the {@link PipelineRunner} that is used to execute the + * Dataflow job. Please refer to the documentation of corresponding {@link PipelineRunner}s for + * more details. + * + *

    Please see BigQuery Access Control + * for security and permission related information specific to BigQuery. + */ +public class BigQueryIO { + private static final Logger LOG = LoggerFactory.getLogger(BigQueryIO.class); + + /** + * Singleton instance of the JSON factory used to read and write JSON + * formatted rows. + */ + private static final JsonFactory JSON_FACTORY = Transport.getJsonFactory(); + + /** + * Project IDs must contain 6-63 lowercase letters, digits, or dashes. + * IDs must start with a letter and may not end with a dash. + * This regex isn't exact - this allows for patterns that would be rejected by + * the service, but this is sufficient for basic parsing of table references. + */ + private static final String PROJECT_ID_REGEXP = "[a-z][-a-z0-9:.]{4,61}[a-z0-9]"; + + /** + * Regular expression that matches Dataset IDs. + */ + private static final String DATASET_REGEXP = "[-\\w.]{1,1024}"; + + /** + * Regular expression that matches Table IDs. + */ + private static final String TABLE_REGEXP = "[-\\w$@]{1,1024}"; + + /** + * Matches table specifications in the form {@code "[project_id]:[dataset_id].[table_id]"} or + * {@code "[dataset_id].[table_id]"}. + */ + private static final String DATASET_TABLE_REGEXP = + String.format("((?%s):)?(?%s)\\.(?%s)", PROJECT_ID_REGEXP, + DATASET_REGEXP, TABLE_REGEXP); + + private static final Pattern TABLE_SPEC = Pattern.compile(DATASET_TABLE_REGEXP); + + // TODO: make this private and remove improper access from BigQueryIOTranslator. + public static final String SET_PROJECT_FROM_OPTIONS_WARNING = + "No project specified for BigQuery table \"%1$s.%2$s\". Assuming it is in \"%3$s\". If the" + + " table is in a different project please specify it as a part of the BigQuery table" + + " definition."; + + private static final String RESOURCE_NOT_FOUND_ERROR = + "BigQuery %1$s not found for table \"%2$s\" . Please create the %1$s before pipeline" + + " execution. If the %1$s is created by an earlier stage of the pipeline, this" + + " validation can be disabled using #withoutValidation."; + + private static final String UNABLE_TO_CONFIRM_PRESENCE_OF_RESOURCE_ERROR = + "Unable to confirm BigQuery %1$s presence for table \"%2$s\". If the %1$s is created by" + + " an earlier stage of the pipeline, this validation can be disabled using" + + " #withoutValidation."; + + /** + * Parse a table specification in the form + * {@code "[project_id]:[dataset_id].[table_id]"} or {@code "[dataset_id].[table_id]"}. + * + *

    If the project id is omitted, the default project id is used. + */ + public static TableReference parseTableSpec(String tableSpec) { + Matcher match = TABLE_SPEC.matcher(tableSpec); + if (!match.matches()) { + throw new IllegalArgumentException( + "Table reference is not in [project_id]:[dataset_id].[table_id] " + + "format: " + tableSpec); + } + + TableReference ref = new TableReference(); + ref.setProjectId(match.group("PROJECT")); + + return ref.setDatasetId(match.group("DATASET")).setTableId(match.group("TABLE")); + } + + /** + * Returns a canonical string representation of the {@link TableReference}. + */ + public static String toTableSpec(TableReference ref) { + StringBuilder sb = new StringBuilder(); + if (ref.getProjectId() != null) { + sb.append(ref.getProjectId()); + sb.append(":"); + } + + sb.append(ref.getDatasetId()).append('.').append(ref.getTableId()); + return sb.toString(); + } + + /** + * A {@link PTransform} that reads from a BigQuery table and returns a + * {@link PCollection} of {@link TableRow TableRows} containing each of the rows of the table. + * + *

    Each {@link TableRow} contains values indexed by column name. Here is a + * sample processing function that processes a "line" column from rows: + *

    {@code
+   * static class ExtractWordsFn extends DoFn<TableRow, String> {
    +   *   public void processElement(ProcessContext c) {
    +   *     // Get the "line" field of the TableRow object, split it into words, and emit them.
    +   *     TableRow row = c.element();
    +   *     String[] words = row.get("line").toString().split("[^a-zA-Z']+");
    +   *     for (String word : words) {
    +   *       if (!word.isEmpty()) {
    +   *         c.output(word);
    +   *       }
    +   *     }
    +   *   }
    +   * }}
    + */ + public static class Read { + /** + * Returns a {@link Read.Bound} with the given name. The BigQuery table or query to be read + * from has not yet been configured. + */ + public static Bound named(String name) { + return new Bound().named(name); + } + + /** + * Reads a BigQuery table specified as {@code "[project_id]:[dataset_id].[table_id]"} or + * {@code "[dataset_id].[table_id]"} for tables within the current project. + */ + public static Bound from(String tableSpec) { + return new Bound().from(tableSpec); + } + + /** + * Reads results received after executing the given query. + */ + public static Bound fromQuery(String query) { + return new Bound().fromQuery(query); + } + + /** + * Reads a BigQuery table specified as a {@link TableReference} object. + */ + public static Bound from(TableReference table) { + return new Bound().from(table); + } + + /** + * Disables BigQuery table validation, which is enabled by default. + */ + public static Bound withoutValidation() { + return new Bound().withoutValidation(); + } + + /** + * A {@link PTransform} that reads from a BigQuery table and returns a bounded + * {@link PCollection} of {@link TableRow TableRows}. + */ + public static class Bound extends PTransform> { + TableReference table; + final String query; + final boolean validate; + @Nullable + Boolean flattenResults; + + private static final String QUERY_VALIDATION_FAILURE_ERROR = + "Validation of query \"%1$s\" failed. If the query depends on an earlier stage of the" + + " pipeline, This validation can be disabled using #withoutValidation."; + + private Bound() { + this(null, null, null, true, null); + } + + private Bound(String name, String query, TableReference reference, boolean validate, + Boolean flattenResults) { + super(name); + this.table = reference; + this.query = query; + this.validate = validate; + this.flattenResults = flattenResults; + } + + /** + * Returns a copy of this transform using the name associated with this transformation. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound(name, query, table, validate, flattenResults); + } + + /** + * Returns a copy of this transform that reads from the specified table. Refer to + * {@link #parseTableSpec(String)} for the specification format. + * + *

    Does not modify this object. + */ + public Bound from(String tableSpec) { + return from(parseTableSpec(tableSpec)); + } + + /** + * Returns a copy of this transform that reads from the specified table. + * + *

    Does not modify this object. + */ + public Bound from(TableReference table) { + return new Bound(name, query, table, validate, flattenResults); + } + + /** + * Returns a copy of this transform that reads the results of the specified query. + * + *

    Does not modify this object. + * + *

    By default, the query results will be flattened -- see + * "flattenResults" in the + * Jobs documentation for more information. To disable flattening, use + * {@link BigQueryIO.Read.Bound#withoutResultFlattening}. + */ + public Bound fromQuery(String query) { + return new Bound(name, query, table, validate, + MoreObjects.firstNonNull(flattenResults, Boolean.TRUE)); + } + + /** + * Disable table validation. + */ + public Bound withoutValidation() { + return new Bound(name, query, table, false, flattenResults); + } + + /** + * Disable + * flattening of query results. + * + *

    Only valid when a query is used ({@link #fromQuery}). Setting this option when reading + * from a table will cause an error during validation. + */ + public Bound withoutResultFlattening() { + return new Bound(name, query, table, validate, false); + } + + /** + * Validates the current {@link PTransform}. + */ + @Override + public void validate(PInput input) { + if (table == null && query == null) { + throw new IllegalStateException( + "Invalid BigQuery read operation, either table reference or query has to be set"); + } else if (table != null && query != null) { + throw new IllegalStateException("Invalid BigQuery read operation. Specifies both a" + + " query and a table, only one of these should be provided"); + } else if (table != null && flattenResults != null) { + throw new IllegalStateException("Invalid BigQuery read operation. Specifies a" + + " table with a result flattening preference, which is not configurable"); + } else if (query != null && flattenResults == null) { + throw new IllegalStateException("Invalid BigQuery read operation. Specifies a" + + " query without a result flattening preference"); + } + + BigQueryOptions bqOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); + if (table != null && table.getProjectId() == null) { + // If user does not specify a project we assume the table to be located in the project + // that owns the Dataflow job. + LOG.warn(String.format(SET_PROJECT_FROM_OPTIONS_WARNING, table.getDatasetId(), + table.getTableId(), bqOptions.getProject())); + table.setProjectId(bqOptions.getProject()); + } + + if (validate) { + // Check for source table/query presence for early failure notification. + // Note that a presence check can fail if the table or dataset are created by earlier + // stages of the pipeline or if a query depends on earlier stages of a pipeline. For these + // cases the withoutValidation method can be used to disable the check. + if (table != null) { + verifyDatasetPresence(bqOptions, table); + verifyTablePresence(bqOptions, table); + } + if (query != null) { + dryRunQuery(bqOptions, query); + } + } + } + + private static void dryRunQuery(BigQueryOptions options, String query) { + Bigquery client = Transport.newBigQueryClient(options).build(); + QueryRequest request = new QueryRequest(); + request.setQuery(query); + request.setDryRun(true); + + try { + BigQueryTableRowIterator.executeWithBackOff( + client.jobs().query(options.getProject(), request), QUERY_VALIDATION_FAILURE_ERROR, + query); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format(QUERY_VALIDATION_FAILURE_ERROR, query), e); + } + } + + @Override + public PCollection apply(PInput input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + IsBounded.BOUNDED) + // Force the output's Coder to be what the read is using, and + // unchangeable later, to ensure that we read the input in the + // format specified by the Read transform. + .setCoder(TableRowJsonCoder.of()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return TableRowJsonCoder.of(); + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, DirectPipelineRunner.EvaluationContext context) { + evaluateReadHelper(transform, context); + } + }); + } + + /** + * Returns the table to write, or {@code null} if reading from a query instead. 
+ */ + public TableReference getTable() { + return table; + } + + /** + * Returns the query to be read, or {@code null} if reading from a table instead. + */ + public String getQuery() { + return query; + } + + /** + * Returns true if table validation is enabled. + */ + public boolean getValidate() { + return validate; + } + + /** + * Returns true/false if result flattening is enabled/disabled, or null if not applicable. + */ + public Boolean getFlattenResults() { + return flattenResults; + } + } + + /** Disallow construction of utility class. */ + private Read() {} + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@link PTransform} that writes a {@link PCollection} containing {@link TableRow TableRows} + * to a BigQuery table. + * + *

In BigQuery, each table has an enclosing dataset. The dataset being written must already + * exist. + *

    By default, tables will be created if they do not exist, which corresponds to a + * {@link CreateDisposition#CREATE_IF_NEEDED} disposition that matches the default of BigQuery's + * Jobs API. A schema must be provided (via {@link BigQueryIO.Write#withSchema(TableSchema)}), + * or else the transform may fail at runtime with an {@link IllegalArgumentException}. + * + *

    By default, writes require an empty table, which corresponds to + * a {@link WriteDisposition#WRITE_EMPTY} disposition that matches the + * default of BigQuery's Jobs API. + * + *

    Here is a sample transform that produces TableRow values containing + * "word" and "count" columns: + *

    {@code
+   * static class FormatCountsFn extends DoFn<KV<String, Long>, TableRow> {
    +   *   public void processElement(ProcessContext c) {
    +   *     TableRow row = new TableRow()
    +   *         .set("word", c.element().getKey())
    +   *         .set("count", c.element().getValue().intValue());
    +   *     c.output(row);
    +   *   }
    +   * }}
+ */ + public static class Write { + /** + * An enumeration type for the BigQuery create disposition strings. + * + * @see + * configuration.query.createDisposition in the BigQuery Jobs API + */ + public enum CreateDisposition { + /** + * Specifies that tables should not be created. + *

    If the output table does not exist, the write fails. + */ + CREATE_NEVER, + + /** + * Specifies that tables should be created if needed. This is the default + * behavior. + * + *

    Requires that a table schema is provided via {@link BigQueryIO.Write#withSchema}. + * This precondition is checked before starting a job. The schema is + * not required to match an existing table's schema. + * + *

    When this transformation is executed, if the output table does not + * exist, the table is created from the provided schema. Note that even if + * the table exists, it may be recreated if necessary when paired with a + * {@link WriteDisposition#WRITE_TRUNCATE}. + */ + CREATE_IF_NEEDED + } + + /** + * An enumeration type for the BigQuery write disposition strings. + * + * @see + * configuration.query.writeDisposition in the BigQuery Jobs API + */ + public enum WriteDisposition { + /** + * Specifies that write should replace a table. + * + *

    The replacement may occur in multiple steps - for instance by first + * removing the existing table, then creating a replacement, then filling + * it in. This is not an atomic operation, and external programs may + * see the table in any of these intermediate steps. + */ + WRITE_TRUNCATE, + + /** + * Specifies that rows may be appended to an existing table. + */ + WRITE_APPEND, + + /** + * Specifies that the output table must be empty. This is the default + * behavior. + * + *

    If the output table is not empty, the write fails at runtime. + * + *

    This check may occur long before data is written, and does not + * guarantee exclusive access to the table. If two programs are run + * concurrently, each specifying the same output table and + * a {@link WriteDisposition} of {@link WriteDisposition#WRITE_EMPTY}, it is possible + * for both to succeed. + */ + WRITE_EMPTY + } + + /** + * Creates a write transformation with the given transform name. The BigQuery table to be + * written has not yet been configured. + */ + public static Bound named(String name) { + return new Bound().named(name); + } + + /** + * Creates a write transformation for the given table specification. + * + *

    Refer to {@link #parseTableSpec(String)} for the specification format. + */ + public static Bound to(String tableSpec) { + return new Bound().to(tableSpec); + } + + /** Creates a write transformation for the given table. */ + public static Bound to(TableReference table) { + return new Bound().to(table); + } + + /** + * Creates a write transformation from a function that maps windows to table specifications. + * Each time a new window is encountered, this function will be called and the resulting table + * will be created. Records within that window will be written to the associated table. + * + *

    See {@link #parseTableSpec(String)} for the format that {@code tableSpecFunction} should + * return. + * + *

    {@code tableSpecFunction} should be deterministic. When given the same window, it should + * always return the same table specification. + */ + public static Bound to(SerializableFunction tableSpecFunction) { + return new Bound().to(tableSpecFunction); + } + + /** + * Creates a write transformation from a function that maps windows to {@link TableReference} + * objects. + * + *

    {@code tableRefFunction} should be deterministic. When given the same window, it should + * always return the same table reference. + */ + public static Bound toTableReference( + SerializableFunction tableRefFunction) { + return new Bound().toTableReference(tableRefFunction); + } + + /** + * Creates a write transformation with the specified schema to use in table creation. + * + *

    The schema is required only if writing to a table that does not already + * exist, and {@link CreateDisposition} is set to + * {@link CreateDisposition#CREATE_IF_NEEDED}. + */ + public static Bound withSchema(TableSchema schema) { + return new Bound().withSchema(schema); + } + + /** Creates a write transformation with the specified options for creating the table. */ + public static Bound withCreateDisposition(CreateDisposition disposition) { + return new Bound().withCreateDisposition(disposition); + } + + /** Creates a write transformation with the specified options for writing to the table. */ + public static Bound withWriteDisposition(WriteDisposition disposition) { + return new Bound().withWriteDisposition(disposition); + } + + /** + * Creates a write transformation with BigQuery table validation disabled. + */ + public static Bound withoutValidation() { + return new Bound().withoutValidation(); + } + + /** + * A {@link PTransform} that can write either a bounded or unbounded + * {@link PCollection} of {@link TableRow TableRows} to a BigQuery table. + */ + public static class Bound extends PTransform, PDone> { + final TableReference table; + + final SerializableFunction tableRefFunction; + + // Table schema. The schema is required only if the table does not exist. + final TableSchema schema; + + // Options for creating the table. Valid values are CREATE_IF_NEEDED and + // CREATE_NEVER. + final CreateDisposition createDisposition; + + // Options for writing to the table. Valid values are WRITE_TRUNCATE, + // WRITE_APPEND and WRITE_EMPTY. + final WriteDisposition writeDisposition; + + // An option to indicate if table validation is desired. Default is true. + final boolean validate; + + private static class TranslateTableSpecFunction implements + SerializableFunction { + private SerializableFunction tableSpecFunction; + + TranslateTableSpecFunction(SerializableFunction tableSpecFunction) { + this.tableSpecFunction = tableSpecFunction; + } + + @Override + public TableReference apply(BoundedWindow value) { + return parseTableSpec(tableSpecFunction.apply(value)); + } + } + + /** + * @deprecated Should be private. Instead, use one of the factory methods in + * {@link BigQueryIO.Write}, such as {@link BigQueryIO.Write#to(String)}, to create an + * instance of this class. + */ + @Deprecated + public Bound() { + this(null, null, null, null, CreateDisposition.CREATE_IF_NEEDED, + WriteDisposition.WRITE_EMPTY, true); + } + + private Bound(String name, TableReference ref, + SerializableFunction tableRefFunction, TableSchema schema, + CreateDisposition createDisposition, WriteDisposition writeDisposition, + boolean validate) { + super(name); + this.table = ref; + this.tableRefFunction = tableRefFunction; + this.schema = schema; + this.createDisposition = createDisposition; + this.writeDisposition = writeDisposition; + this.validate = validate; + } + + /** + * Returns a copy of this write transformation, but with the specified transform name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Returns a copy of this write transformation, but writing to the specified table. Refer to + * {@link #parseTableSpec(String)} for the specification format. + * + *

    Does not modify this object. + */ + public Bound to(String tableSpec) { + return to(parseTableSpec(tableSpec)); + } + + /** + * Returns a copy of this write transformation, but writing to the specified table. + * + *

    Does not modify this object. + */ + public Bound to(TableReference table) { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Returns a copy of this write transformation, but using the specified function to determine + * which table to write to for each window. + * + *

    Does not modify this object. + * + *

    {@code tableSpecFunction} should be deterministic. When given the same window, it + * should always return the same table specification. + */ + public Bound to( + SerializableFunction tableSpecFunction) { + return toTableReference(new TranslateTableSpecFunction(tableSpecFunction)); + } + + /** + * Returns a copy of this write transformation, but using the specified function to determine + * which table to write to for each window. + * + *

    Does not modify this object. + * + *

    {@code tableRefFunction} should be deterministic. When given the same window, it should + * always return the same table reference. + */ + public Bound toTableReference( + SerializableFunction tableRefFunction) { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Returns a copy of this write transformation, but using the specified schema for rows + * to be written. + * + *

    Does not modify this object. + */ + public Bound withSchema(TableSchema schema) { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Returns a copy of this write transformation, but using the specified create disposition. + * + *

    Does not modify this object. + */ + public Bound withCreateDisposition(CreateDisposition createDisposition) { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Returns a copy of this write transformation, but using the specified write disposition. + * + *

    Does not modify this object. + */ + public Bound withWriteDisposition(WriteDisposition writeDisposition) { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, validate); + } + + /** + * Returns a copy of this write transformation, but without BigQuery table validation. + * + *

    Does not modify this object. + */ + public Bound withoutValidation() { + return new Bound(name, table, tableRefFunction, schema, createDisposition, + writeDisposition, false); + } + + private static void verifyTableEmpty( + BigQueryOptions options, + TableReference table) { + try { + Bigquery client = Transport.newBigQueryClient(options).build(); + BigQueryTableInserter inserter = new BigQueryTableInserter(client); + if (!inserter.isEmpty(table)) { + throw new IllegalArgumentException( + "BigQuery table is not empty: " + BigQueryIO.toTableSpec(table)); + } + } catch (IOException e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if (errorExtractor.itemNotFound(e)) { + // Nothing to do. If the table does not exist, it is considered empty. + } else { + throw new RuntimeException( + "unable to confirm BigQuery table emptiness for table " + + BigQueryIO.toTableSpec(table), e); + } + } + } + + @Override + public PDone apply(PCollection input) { + BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class); + + if (table == null && tableRefFunction == null) { + throw new IllegalStateException( + "must set the table reference of a BigQueryIO.Write transform"); + } + if (table != null && tableRefFunction != null) { + throw new IllegalStateException( + "Cannot set both a table reference and a table function for a BigQueryIO.Write " + + "transform"); + } + + if (createDisposition == CreateDisposition.CREATE_IF_NEEDED && schema == null) { + throw new IllegalArgumentException("CreateDisposition is CREATE_IF_NEEDED, " + + "however no schema was provided."); + } + + if (table != null && table.getProjectId() == null) { + // If user does not specify a project we assume the table to be located in the project + // that owns the Dataflow job. + String projectIdFromOptions = options.getProject(); + LOG.warn(String.format(BigQueryIO.SET_PROJECT_FROM_OPTIONS_WARNING, table.getDatasetId(), + table.getTableId(), projectIdFromOptions)); + table.setProjectId(projectIdFromOptions); + } + + // Check for destination table presence and emptiness for early failure notification. + // Note that a presence check can fail if the table or dataset are created by earlier stages + // of the pipeline. For these cases the withoutValidation method can be used to disable + // the check. + // Unfortunately we can't validate anything early in case tableRefFunction is specified. + if (table != null && validate) { + verifyDatasetPresence(options, table); + if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) { + verifyTablePresence(options, table); + } + if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) { + verifyTableEmpty(options, table); + } + } + + // In streaming, BigQuery write is taken care of by StreamWithDeDup transform. + // We also currently do this if a tablespec function is specified. 
+ if (options.isStreaming() || tableRefFunction != null) { + if (createDisposition == CreateDisposition.CREATE_NEVER) { + throw new IllegalArgumentException("CreateDispostion.CREATE_NEVER is not " + + "supported for unbounded PCollections or when using tablespec functions."); + } + + if (writeDisposition == WriteDisposition.WRITE_TRUNCATE) { + throw new IllegalArgumentException("WriteDisposition.WRITE_TRUNCATE is not " + + "supported for unbounded PCollections or when using tablespec functions."); + } + + return input.apply(new StreamWithDeDup(table, tableRefFunction, schema)); + } + + return PDone.in(input.getPipeline()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, DirectPipelineRunner.EvaluationContext context) { + evaluateWriteHelper(transform, context); + } + }); + } + + /** Returns the create disposition. */ + public CreateDisposition getCreateDisposition() { + return createDisposition; + } + + /** Returns the write disposition. */ + public WriteDisposition getWriteDisposition() { + return writeDisposition; + } + + /** Returns the table schema. */ + public TableSchema getSchema() { + return schema; + } + + /** Returns the table reference, or {@code null} if a . */ + public TableReference getTable() { + return table; + } + + /** Returns {@code true} if table validation is enabled. */ + public boolean getValidate() { + return validate; + } + } + + /** Disallow construction of utility class. */ + private Write() {} + } + + private static void verifyDatasetPresence(BigQueryOptions options, TableReference table) { + try { + Bigquery client = Transport.newBigQueryClient(options).build(); + BigQueryTableRowIterator.executeWithBackOff( + client.datasets().get(table.getProjectId(), table.getDatasetId()), + RESOURCE_NOT_FOUND_ERROR, "dataset", BigQueryIO.toTableSpec(table)); + } catch (Exception e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if ((e instanceof IOException) && errorExtractor.itemNotFound((IOException) e)) { + throw new IllegalArgumentException( + String.format(RESOURCE_NOT_FOUND_ERROR, "dataset", BigQueryIO.toTableSpec(table)), + e); + } else { + throw new RuntimeException( + String.format(UNABLE_TO_CONFIRM_PRESENCE_OF_RESOURCE_ERROR, "dataset", + BigQueryIO.toTableSpec(table)), + e); + } + } + } + + private static void verifyTablePresence(BigQueryOptions options, TableReference table) { + try { + Bigquery client = Transport.newBigQueryClient(options).build(); + BigQueryTableRowIterator.executeWithBackOff( + client.tables().get(table.getProjectId(), table.getDatasetId(), table.getTableId()), + RESOURCE_NOT_FOUND_ERROR, "table", BigQueryIO.toTableSpec(table)); + } catch (Exception e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if ((e instanceof IOException) && errorExtractor.itemNotFound((IOException) e)) { + throw new IllegalArgumentException( + String.format(RESOURCE_NOT_FOUND_ERROR, "table", BigQueryIO.toTableSpec(table)), e); + } else { + throw new RuntimeException( + String.format(UNABLE_TO_CONFIRM_PRESENCE_OF_RESOURCE_ERROR, "table", + BigQueryIO.toTableSpec(table)), + e); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Implementation of DoFn to perform streaming BigQuery write. 
+ */ + @SystemDoFnInternal + private static class StreamingWriteFn + extends DoFn, TableRowInfo>, Void> { + /** TableSchema in JSON. Use String to make the class Serializable. */ + private final String jsonTableSchema; + + /** JsonTableRows to accumulate BigQuery rows in order to batch writes. */ + private transient Map> tableRows; + + /** The list of unique ids for each BigQuery table row. */ + private transient Map> uniqueIdsForTableRows; + + /** The list of tables created so far, so we don't try the creation + each time. */ + private static Set createdTables = + Collections.newSetFromMap(new ConcurrentHashMap()); + + /** Tracks bytes written, exposed as "ByteCount" Counter. */ + private Aggregator byteCountAggregator = + createAggregator("ByteCount", new Sum.SumLongFn()); + + /** Constructor. */ + StreamingWriteFn(TableSchema schema) { + try { + jsonTableSchema = JSON_FACTORY.toString(schema); + } catch (IOException e) { + throw new RuntimeException("Cannot initialize BigQuery streaming writer.", e); + } + } + + /** Prepares a target BigQuery table. */ + @Override + public void startBundle(Context context) { + tableRows = new HashMap<>(); + uniqueIdsForTableRows = new HashMap<>(); + } + + /** Accumulates the input into JsonTableRows and uniqueIdsForTableRows. */ + @Override + public void processElement(ProcessContext context) { + String tableSpec = context.element().getKey().getKey(); + List rows = getOrCreateMapListValue(tableRows, tableSpec); + List uniqueIds = getOrCreateMapListValue(uniqueIdsForTableRows, tableSpec); + + rows.add(context.element().getValue().tableRow); + uniqueIds.add(context.element().getValue().uniqueId); + } + + /** Writes the accumulated rows into BigQuery with streaming API. */ + @Override + public void finishBundle(Context context) throws Exception { + BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class); + Bigquery client = Transport.newBigQueryClient(options).build(); + + for (String tableSpec : tableRows.keySet()) { + TableReference tableReference = getOrCreateTable(options, tableSpec); + flushRows(client, tableReference, tableRows.get(tableSpec), + uniqueIdsForTableRows.get(tableSpec)); + } + tableRows.clear(); + uniqueIdsForTableRows.clear(); + } + + public TableReference getOrCreateTable(BigQueryOptions options, String tableSpec) + throws IOException { + TableReference tableReference = parseTableSpec(tableSpec); + if (!createdTables.contains(tableSpec)) { + synchronized (createdTables) { + // Another thread may have succeeded in creating the table in the meanwhile, so + // check again. This check isn't needed for correctness, but we add it to prevent + // every thread from attempting a create and overwhelming our BigQuery quota. + if (!createdTables.contains(tableSpec)) { + TableSchema tableSchema = JSON_FACTORY.fromString(jsonTableSchema, TableSchema.class); + Bigquery client = Transport.newBigQueryClient(options).build(); + BigQueryTableInserter inserter = new BigQueryTableInserter(client); + inserter.getOrCreateTable(tableReference, WriteDisposition.WRITE_APPEND, + CreateDisposition.CREATE_IF_NEEDED, tableSchema); + createdTables.add(tableSpec); + } + } + } + return tableReference; + } + + /** Writes the accumulated rows into BigQuery with streaming API. 
*/ + private void flushRows(Bigquery client, TableReference tableReference, + List tableRows, List uniqueIds) { + if (!tableRows.isEmpty()) { + try { + BigQueryTableInserter inserter = new BigQueryTableInserter(client); + inserter.insertAll(tableReference, tableRows, uniqueIds, byteCountAggregator); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + } + + private static class ShardedKey { + private final K key; + private final int shardNumber; + + public static ShardedKey of(K key, int shardNumber) { + return new ShardedKey(key, shardNumber); + } + + private ShardedKey(K key, int shardNumber) { + this.key = key; + this.shardNumber = shardNumber; + } + + public K getKey() { + return key; + } + + public int getShardNumber() { + return shardNumber; + } + } + + /** + * A {@link Coder} for {@link ShardedKey}, using a wrapped key {@link Coder}. + */ + private static class ShardedKeyCoder + extends StandardCoder> { + public static ShardedKeyCoder of(Coder keyCoder) { + return new ShardedKeyCoder<>(keyCoder); + } + + @JsonCreator + public static ShardedKeyCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of(components.get(0)); + } + + protected ShardedKeyCoder(Coder keyCoder) { + this.keyCoder = keyCoder; + this.shardNumberCoder = VarIntCoder.of(); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(keyCoder); + } + + @Override + public void encode(ShardedKey key, OutputStream outStream, Context context) + throws IOException { + keyCoder.encode(key.getKey(), outStream, context.nested()); + shardNumberCoder.encode(key.getShardNumber(), outStream, context); + } + + @Override + public ShardedKey decode(InputStream inStream, Context context) + throws IOException { + return new ShardedKey( + keyCoder.decode(inStream, context.nested()), + shardNumberCoder.decode(inStream, context)); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + keyCoder.verifyDeterministic(); + } + + Coder keyCoder; + VarIntCoder shardNumberCoder; + } + + private static class TableRowInfoCoder extends AtomicCoder { + private static final TableRowInfoCoder INSTANCE = new TableRowInfoCoder(); + + @JsonCreator + public static TableRowInfoCoder of() { + return INSTANCE; + } + + @Override + public void encode(TableRowInfo value, OutputStream outStream, Context context) + throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null value"); + } + tableRowCoder.encode(value.tableRow, outStream, context.nested()); + idCoder.encode(value.uniqueId, outStream, context.nested()); + } + + @Override + public TableRowInfo decode(InputStream inStream, Context context) + throws IOException { + return new TableRowInfo( + tableRowCoder.decode(inStream, context.nested()), + idCoder.decode(inStream, context.nested())); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, "TableRows are not deterministic."); + } + + TableRowJsonCoder tableRowCoder = TableRowJsonCoder.of(); + StringUtf8Coder idCoder = StringUtf8Coder.of(); + } + + private static class TableRowInfo { + TableRowInfo(TableRow tableRow, String uniqueId) { + this.tableRow = tableRow; + this.uniqueId = uniqueId; + } + + final TableRow tableRow; + final String uniqueId; + } + + 
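+  // Illustrative note (not part of the original source): the streaming path below combines these
+  // pieces with KvCoder.of(ShardedKeyCoder.of(StringUtf8Coder.of()), TableRowInfoCoder.of()), so
+  // each tagged element has the shape (values here are hypothetical)
+  //   KV.of(ShardedKey.of("project:dataset.table", 17), new TableRowInfo(row, uniqueId))
+  // and can be redistributed by Reshuffle before StreamingWriteFn flushes it to BigQuery.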
///////////////////////////////////////////////////////////////////////////// + + /** + * Fn that tags each table row with a unique id and destination table. + * To avoid calling UUID.randomUUID() for each element, which can be costly, + * a randomUUID is generated only once per bucket of data. The actual unique + * id is created by concatenating this randomUUID with a sequential number. + */ + private static class TagWithUniqueIdsAndTable + extends DoFn, TableRowInfo>> + implements DoFn.RequiresWindowAccess { + /** TableSpec to write to. */ + private final String tableSpec; + + /** User function mapping windows to {@link TableReference} in JSON. */ + private final SerializableFunction tableRefFunction; + + private transient String randomUUID; + private transient long sequenceNo = 0L; + + TagWithUniqueIdsAndTable(BigQueryOptions options, TableReference table, + SerializableFunction tableRefFunction) { + Preconditions.checkArgument(table == null ^ tableRefFunction == null, + "Exactly one of table or tableRefFunction should be set"); + if (table != null) { + if (table.getProjectId() == null) { + table.setProjectId(options.as(BigQueryOptions.class).getProject()); + } + this.tableSpec = toTableSpec(table); + } else { + tableSpec = null; + } + this.tableRefFunction = tableRefFunction; + } + + + @Override + public void startBundle(Context context) { + randomUUID = UUID.randomUUID().toString(); + } + + /** Tag the input with a unique id. */ + @Override + public void processElement(ProcessContext context) throws IOException { + String uniqueId = randomUUID + sequenceNo++; + ThreadLocalRandom randomGenerator = ThreadLocalRandom.current(); + String tableSpec = tableSpecFromWindow( + context.getPipelineOptions().as(BigQueryOptions.class), context.window()); + // We output on keys 0-50 to ensure that there's enough batching for + // BigQuery. + context.output(KV.of(ShardedKey.of(tableSpec, randomGenerator.nextInt(0, 50)), + new TableRowInfo(context.element(), uniqueId))); + } + + private String tableSpecFromWindow(BigQueryOptions options, BoundedWindow window) { + if (tableSpec != null) { + return tableSpec; + } else { + TableReference table = tableRefFunction.apply(window); + if (table.getProjectId() == null) { + table.setProjectId(options.getProject()); + } + return toTableSpec(table); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * PTransform that performs streaming BigQuery write. To increase consistency, + * it leverages BigQuery best effort de-dup mechanism. + */ + private static class StreamWithDeDup extends PTransform, PDone> { + private final transient TableReference tableReference; + private final SerializableFunction tableRefFunction; + private final transient TableSchema tableSchema; + + /** Constructor. */ + StreamWithDeDup(TableReference tableReference, + SerializableFunction tableRefFunction, + TableSchema tableSchema) { + this.tableReference = tableReference; + this.tableRefFunction = tableRefFunction; + this.tableSchema = tableSchema; + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + @Override + public PDone apply(PCollection input) { + // A naive implementation would be to simply stream data directly to BigQuery. + // However, this could occasionally lead to duplicated data, e.g., when + // a VM that runs this code is restarted and the code is re-run. + + // The above risk is mitigated in this implementation by relying on + // BigQuery built-in best effort de-dup mechanism. 
+ + // To use this mechanism, each input TableRow is tagged with a generated + // unique id, which is then passed to BigQuery and used to ignore duplicates. + + PCollection, TableRowInfo>> tagged = input.apply(ParDo.of( + new TagWithUniqueIdsAndTable(input.getPipeline().getOptions().as(BigQueryOptions.class), + tableReference, tableRefFunction))); + + // To prevent having the same TableRow processed more than once with regenerated + // different unique ids, this implementation relies on "checkpointing", which is + // achieved as a side effect of having StreamingWriteFn immediately follow a GBK, + // performed by Reshuffle. + tagged + .setCoder(KvCoder.of(ShardedKeyCoder.of(StringUtf8Coder.of()), TableRowInfoCoder.of())) + .apply(Reshuffle., TableRowInfo>of()) + .apply(ParDo.of(new StreamingWriteFn(tableSchema))); + + // Note that the implementation to return PDone here breaks the + // implicit assumption about the job execution order. If a user + // implements a PTransform that takes PDone returned here as its + // input, the transform may not necessarily be executed after + // the BigQueryIO.Write. + + return PDone.in(input.getPipeline()); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Disallow construction of utility class. */ + private BigQueryIO() {} + + /** + * Direct mode read evaluator. + * + *

    This loads the entire table into an in-memory PCollection. + */ + private static void evaluateReadHelper( + Read.Bound transform, DirectPipelineRunner.EvaluationContext context) { + BigQueryOptions options = context.getPipelineOptions(); + Bigquery client = Transport.newBigQueryClient(options).build(); + if (transform.table != null && transform.table.getProjectId() == null) { + transform.table.setProjectId(options.getProject()); + } + + BigQueryTableRowIterator iterator; + if (transform.query != null) { + LOG.info("Reading from BigQuery query {}", transform.query); + iterator = + BigQueryTableRowIterator.fromQuery( + transform.query, options.getProject(), client, transform.getFlattenResults()); + } else { + LOG.info("Reading from BigQuery table {}", toTableSpec(transform.table)); + iterator = BigQueryTableRowIterator.fromTable(transform.table, client); + } + + try (BigQueryTableRowIterator ignored = iterator) { + List elems = new ArrayList<>(); + iterator.open(); + while (iterator.advance()) { + elems.add(iterator.getCurrent()); + } + LOG.info("Number of records read from BigQuery: {}", elems.size()); + context.setPCollection(context.getOutput(transform), elems); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + private static List getOrCreateMapListValue(Map> map, K key) { + List value = map.get(key); + if (value == null) { + value = new ArrayList<>(); + map.put(key, value); + } + return value; + } + + /** + * Direct mode write evaluator. + * + *

    This writes the entire table in a single BigQuery request. + * The table will be created if necessary. + */ + private static void evaluateWriteHelper( + Write.Bound transform, DirectPipelineRunner.EvaluationContext context) { + BigQueryOptions options = context.getPipelineOptions(); + Bigquery client = Transport.newBigQueryClient(options).build(); + BigQueryTableInserter inserter = new BigQueryTableInserter(client); + + try { + Map> tableRows = new HashMap<>(); + for (WindowedValue windowedValue : context.getPCollectionWindowedValues( + context.getInput(transform))) { + for (BoundedWindow window : windowedValue.getWindows()) { + TableReference ref; + if (transform.tableRefFunction != null) { + ref = transform.tableRefFunction.apply(window); + } else { + ref = transform.table; + } + if (ref.getProjectId() == null) { + ref.setProjectId(options.getProject()); + } + + List rows = getOrCreateMapListValue(tableRows, ref); + rows.add(windowedValue.getValue()); + } + } + + for (TableReference ref : tableRows.keySet()) { + LOG.info("Writing to BigQuery table {}", toTableSpec(ref)); + // {@link BigQueryTableInserter#getOrCreateTable} validates {@link CreateDisposition} + // and {@link WriteDisposition}. + // For each {@link TableReference}, it can only be called before rows are written. + inserter.getOrCreateTable( + ref, transform.writeDisposition, transform.createDisposition, transform.schema); + inserter.insertAll(ref, tableRows.get(ref)); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BlockBasedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BlockBasedSource.java new file mode 100644 index 000000000000..f4a9c7db0710 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BlockBasedSource.java @@ -0,0 +1,237 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.io.IOException; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * A {@code BlockBasedSource} is a {@link FileBasedSource} where a file consists of blocks of + * records. + * + *

{@code BlockBasedSource} should be subclassed when a file format does not support efficient + * seeking to a record in the file, but can support efficient seeking to a block. Put another way, + * records in the file cannot be offset-addressed, but blocks can (it is not possible to say + * that record {@code i} starts at offset {@code m}, but it is possible to say that block {@code j} + * starts at offset {@code n}). + * + *

    The records that will be read from a {@code BlockBasedSource} that corresponds to a subrange + * of a file {@code [startOffset, endOffset)} are those records such that the record is contained in + * a block that starts at offset {@code i}, where {@code i >= startOffset} and + * {@code i < endOffset}. In other words, a record will be read from the source if its first byte is + * contained in a block that begins within the range described by the source. + * + *
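+ * For example (with illustrative offsets), if a source describes the range {@code [200, 500)} and
+ * the file contains blocks starting at offsets 100, 250, and 600, only the records in the block
+ * starting at offset 250 are read by that source; the blocks starting at 100 and 600 begin outside
+ * the range and are read, if at all, by sources covering those offsets.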

    This entails that it is possible to determine the start offsets of all blocks in a file. + * + *

    Progress reporting for reading from a {@code BlockBasedSource} is inaccurate. A {@link + * BlockBasedReader} reports its current offset as {@code (offset of current block) + (current block + * size) * (fraction of block consumed)}. However, only the offset of the current block is required + * to be accurately reported by subclass implementations. As such, in the worst case, the current + * offset is only updated at block boundaries. + * + *
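+ * For example (with illustrative numbers), a reader positioned halfway through a block of size 200
+ * that starts at offset 1000 would report its current offset as 1100.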

    {@code BlockBasedSource} supports dynamic splitting. However, because records in a {@code + * BlockBasedSource} are not required to have offsets and progress reporting is inaccurate, {@code + * BlockBasedReader} only supports splitting at block boundaries. + * In other words, {@link BlockBasedReader#atSplitPoint} returns true iff the current record is the + * first record in a block. See {@link FileBasedSource.FileBasedReader} for discussion about split + * points. + * + * @param The type of records to be read from the source. + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public abstract class BlockBasedSource extends FileBasedSource { + /** + * Creates a {@code BlockBasedSource} based on a file name or pattern. Subclasses must call this + * constructor when creating a {@code BlockBasedSource} for a file pattern. See + * {@link FileBasedSource} for more information. + */ + public BlockBasedSource(String fileOrPatternSpec, long minBundleSize) { + super(fileOrPatternSpec, minBundleSize); + } + + /** + * Creates a {@code BlockBasedSource} for a single file. Subclasses must call this constructor + * when implementing {@link BlockBasedSource#createForSubrangeOfFile}. See documentation in + * {@link FileBasedSource}. + */ + public BlockBasedSource(String fileName, long minBundleSize, long startOffset, long endOffset) { + super(fileName, minBundleSize, startOffset, endOffset); + } + + /** + * Creates a {@code BlockBasedSource} for the specified range in a single file. + */ + @Override + protected abstract BlockBasedSource createForSubrangeOfFile( + String fileName, long start, long end); + + /** + * Creates a {@code BlockBasedReader}. + */ + @Override + protected abstract BlockBasedReader createSingleFileReader(PipelineOptions options); + + /** + * A {@code Block} represents a block of records that can be read. + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + protected abstract static class Block { + /** + * Returns the current record. + */ + public abstract T getCurrentRecord(); + + /** + * Reads the next record from the block and returns true iff one exists. + */ + public abstract boolean readNextRecord() throws IOException; + + /** + * Returns the fraction of the block already consumed, if possible, as a value in + * {@code [0, 1]}. It should not include the current record. Successive results from this method + * must be monotonically increasing. + * + *

    If it is not possible to compute the fraction of the block consumed this method may + * return zero. For example, when the total number of records in the block is unknown. + */ + public abstract double getFractionOfBlockConsumed(); + } + + /** + * A {@code Reader} that reads records from a {@link BlockBasedSource}. If the source is a + * subrange of a file, the blocks that will be read by this reader are those such that the first + * byte of the block is within the range {@code [start, end)}. + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + protected abstract static class BlockBasedReader extends FileBasedReader { + private boolean atSplitPoint; + + protected BlockBasedReader(BlockBasedSource source) { + super(source); + } + + /** + * Read the next block from the input. + */ + public abstract boolean readNextBlock() throws IOException; + + /** + * Returns the current block (the block that was read by the last successful call to + * {@link BlockBasedReader#readNextBlock}). May return null initially, or if no block has been + * successfully read. + */ + @Nullable + public abstract Block getCurrentBlock(); + + /** + * Returns the size of the current block in bytes as it is represented in the underlying file, + * if possible. This method may return {@code 0} if the size of the current block is unknown. + * + *

    The size returned by this method must be such that for two successive blocks A and B, + * {@code offset(A) + size(A) <= offset(B)}. If this is not satisfied, the progress reported + * by the {@code BlockBasedReader} will be non-monotonic and will interfere with the quality + * (but not correctness) of dynamic work rebalancing. + * + *

    This method and {@link Block#getFractionOfBlockConsumed} are used to provide an estimate + * of progress within a block ({@code getCurrentBlock().getFractionOfBlockConsumed() * + * getCurrentBlockSize()}). It is acceptable for the result of this computation to be {@code 0}, + * but progress estimation will be inaccurate. + */ + public abstract long getCurrentBlockSize(); + + /** + * Returns the largest offset such that starting to read from that offset includes the current + * block. + */ + public abstract long getCurrentBlockOffset(); + + @Override + public final T getCurrent() throws NoSuchElementException { + Block currentBlock = getCurrentBlock(); + if (currentBlock == null) { + throw new NoSuchElementException( + "No block has been successfully read from " + getCurrentSource()); + } + return currentBlock.getCurrentRecord(); + } + + /** + * Returns true if the reader is at a split point. A {@code BlockBasedReader} is at a split + * point if the current record is the first record in a block. In other words, split points + * are block boundaries. + */ + @Override + protected boolean isAtSplitPoint() { + return atSplitPoint; + } + + /** + * Reads the next record from the {@link #getCurrentBlock() current block} if + * possible. Will call {@link #readNextBlock()} to advance to the next block if not. + * + *

    The first record read from a block is treated as a split point. + */ + @Override + protected final boolean readNextRecord() throws IOException { + atSplitPoint = false; + + while (getCurrentBlock() == null || !getCurrentBlock().readNextRecord()) { + if (!readNextBlock()) { + return false; + } + // The first record in a block is a split point. + atSplitPoint = true; + } + return true; + } + + @Override + public Double getFractionConsumed() { + if (getCurrentSource().getEndOffset() == Long.MAX_VALUE) { + return null; + } + Block currentBlock = getCurrentBlock(); + if (currentBlock == null) { + // There is no current block (i.e., the read has not yet begun). + return 0.0; + } + long currentBlockOffset = getCurrentBlockOffset(); + long startOffset = getCurrentSource().getStartOffset(); + long endOffset = getCurrentSource().getEndOffset(); + double fractionAtBlockStart = + ((double) (currentBlockOffset - startOffset)) / (endOffset - startOffset); + double fractionAtBlockEnd = + ((double) (currentBlockOffset + getCurrentBlockSize() - startOffset) + / (endOffset - startOffset)); + return Math.min( + 1.0, + fractionAtBlockStart + + currentBlock.getFractionOfBlockConsumed() + * (fractionAtBlockEnd - fractionAtBlockStart)); + } + + @Override + protected long getCurrentOffset() { + return getCurrentBlockOffset(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSource.java new file mode 100644 index 000000000000..52c730cc3928 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSource.java @@ -0,0 +1,271 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.util.StringUtils.approximateSimpleName; + +import com.google.api.client.util.BackOff; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.IntervalBoundedExponentialBackOff; +import com.google.cloud.dataflow.sdk.util.ValueWithRecordId; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + + +/** + * {@link PTransform} that reads a bounded amount of data from an {@link UnboundedSource}, + * specified as one or both of a maximum number of elements or a maximum period of time to read. + * + *
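+ * For example, a sketch of typical use (assuming a {@code Pipeline p} and an existing
+ * {@code UnboundedSource<String> mySource}; the exact factory methods on {@link Read} may differ):
+ * {@code
+ *   PCollection<String> sample =
+ *       p.apply(Read.from(mySource).withMaxNumRecords(1000));
+ * }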

    Created by {@link Read}. + */ +class BoundedReadFromUnboundedSource extends PTransform> { + private final UnboundedSource source; + private final long maxNumRecords; + private final Duration maxReadTime; + + /** + * Returns a new {@link BoundedReadFromUnboundedSource} that reads a bounded amount + * of data from the given {@link UnboundedSource}. The bound is specified as a number + * of records to read. + * + *

    This may take a long time to execute if the splits of this source are slow to read + * records. + */ + public BoundedReadFromUnboundedSource withMaxNumRecords(long maxNumRecords) { + return new BoundedReadFromUnboundedSource(source, maxNumRecords, maxReadTime); + } + + /** + * Returns a new {@link BoundedReadFromUnboundedSource} that reads a bounded amount + * of data from the given {@link UnboundedSource}. The bound is specified as an amount + * of time to read for. Each split of the source will read for this much time. + */ + public BoundedReadFromUnboundedSource withMaxReadTime(Duration maxReadTime) { + return new BoundedReadFromUnboundedSource(source, maxNumRecords, maxReadTime); + } + + BoundedReadFromUnboundedSource( + UnboundedSource source, long maxNumRecords, Duration maxReadTime) { + this.source = source; + this.maxNumRecords = maxNumRecords; + this.maxReadTime = maxReadTime; + } + + @Override + public PCollection apply(PInput input) { + PCollection> read = Pipeline.applyTransform(input, + Read.from(new UnboundedToBoundedSourceAdapter<>(source, maxNumRecords, maxReadTime))); + if (source.requiresDeduping()) { + read = read.apply(RemoveDuplicates.withRepresentativeValueFn( + new SerializableFunction, byte[]>() { + @Override + public byte[] apply(ValueWithRecordId input) { + return input.getId(); + } + })); + } + return read.apply(ValueWithRecordId.stripIds()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return source.getDefaultOutputCoder(); + } + + @Override + public String getKindString() { + return "Read(" + approximateSimpleName(source.getClass()) + ")"; + } + + private static class UnboundedToBoundedSourceAdapter + extends BoundedSource> { + private final UnboundedSource source; + private final long maxNumRecords; + private final Duration maxReadTime; + + private UnboundedToBoundedSourceAdapter( + UnboundedSource source, long maxNumRecords, Duration maxReadTime) { + this.source = source; + this.maxNumRecords = maxNumRecords; + this.maxReadTime = maxReadTime; + } + + /** + * Divide the given number of records into {@code numSplits} approximately + * equal parts that sum to {@code numRecords}. + */ + private static long[] splitNumRecords(long numRecords, int numSplits) { + long[] splitNumRecords = new long[numSplits]; + for (int i = 0; i < numSplits; i++) { + splitNumRecords[i] = numRecords / numSplits; + } + for (int i = 0; i < numRecords % numSplits; i++) { + splitNumRecords[i] = splitNumRecords[i] + 1; + } + return splitNumRecords; + } + + /** + * Pick a number of initial splits based on the number of records expected to be processed. + */ + private static int numInitialSplits(long numRecords) { + final int maxSplits = 100; + final long recordsPerSplit = 10000; + return (int) Math.min(maxSplits, numRecords / recordsPerSplit + 1); + } + + @Override + public List>> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + List> result = new ArrayList<>(); + int numInitialSplits = numInitialSplits(maxNumRecords); + List> splits = + source.generateInitialSplits(numInitialSplits, options); + int numSplits = splits.size(); + long[] numRecords = splitNumRecords(maxNumRecords, numSplits); + for (int i = 0; i < numSplits; i++) { + result.add( + new UnboundedToBoundedSourceAdapter(splits.get(i), numRecords[i], maxReadTime)); + } + return result; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) { + // No way to estimate bytes, so returning 0. 
+ return 0L; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) { + return false; + } + + @Override + public Coder> getDefaultOutputCoder() { + return ValueWithRecordId.ValueWithRecordIdCoder.of(source.getDefaultOutputCoder()); + } + + @Override + public void validate() { + source.validate(); + } + + @Override + public BoundedReader> createReader(PipelineOptions options) { + return new Reader(source.createReader(options, null)); + } + + private class Reader extends BoundedReader> { + private long recordsRead = 0L; + private Instant endTime = Instant.now().plus(maxReadTime); + private UnboundedSource.UnboundedReader reader; + + private Reader(UnboundedSource.UnboundedReader reader) { + this.recordsRead = 0L; + if (maxReadTime != null) { + this.endTime = Instant.now().plus(maxReadTime); + } else { + this.endTime = null; + } + this.reader = reader; + } + + @Override + public boolean start() throws IOException { + if (maxNumRecords <= 0 || (maxReadTime != null && maxReadTime.getMillis() == 0)) { + return false; + } + + recordsRead++; + if (reader.start()) { + return true; + } else { + return advanceWithBackoff(); + } + } + + @Override + public boolean advance() throws IOException { + if (recordsRead >= maxNumRecords) { + finalizeCheckpoint(); + return false; + } + recordsRead++; + return advanceWithBackoff(); + } + + private boolean advanceWithBackoff() throws IOException { + // Try reading from the source with exponential backoff + BackOff backoff = new IntervalBoundedExponentialBackOff(10000, 10); + long nextSleep = backoff.nextBackOffMillis(); + while (nextSleep != BackOff.STOP) { + if (endTime != null && Instant.now().isAfter(endTime)) { + finalizeCheckpoint(); + return false; + } + if (reader.advance()) { + return true; + } + try { + Thread.sleep(nextSleep); + } catch (InterruptedException e) {} + nextSleep = backoff.nextBackOffMillis(); + } + finalizeCheckpoint(); + return false; + } + + private void finalizeCheckpoint() throws IOException { + reader.getCheckpointMark().finalizeCheckpoint(); + } + + @Override + public ValueWithRecordId getCurrent() throws NoSuchElementException { + return new ValueWithRecordId<>(reader.getCurrent(), reader.getCurrentRecordId()); + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return reader.getCurrentTimestamp(); + } + + @Override + public void close() throws IOException { + reader.close(); + } + + @Override + public BoundedSource> getCurrentSource() { + return UnboundedToBoundedSourceAdapter.this; + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BoundedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BoundedSource.java new file mode 100644 index 000000000000..be3a415cff93 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/BoundedSource.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * A {@link Source} that reads a finite amount of input and, because of that, supports + * some additional operations. + * + *

    The operations are: + *

      + *
• Splitting into bundles of a given size: {@link #splitIntoBundles}; + *
    • Size estimation: {@link #getEstimatedSizeBytes}; + *
    • Telling whether or not this source produces key/value pairs in sorted order: + * {@link #producesSortedKeys}; + *
    • The reader ({@link BoundedReader}) supports progress estimation + * ({@link BoundedReader#getFractionConsumed}) and dynamic splitting + * ({@link BoundedReader#splitAtFraction}). + *
    + * + *

    To use this class for supporting your custom input type, derive your class + * class from it, and override the abstract methods. For an example, see {@link DatastoreIO}. + * + * @param Type of records read by the source. + */ +public abstract class BoundedSource extends Source { + /** + * Splits the source into bundles of approximately {@code desiredBundleSizeBytes}. + */ + public abstract List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception; + + /** + * An estimate of the total size (in bytes) of the data that would be read from this source. + * This estimate is in terms of external storage size, before any decompression or other + * processing done by the reader. + */ + public abstract long getEstimatedSizeBytes(PipelineOptions options) throws Exception; + + /** + * Whether this source is known to produce key/value pairs sorted by lexicographic order on + * the bytes of the encoded key. + */ + public abstract boolean producesSortedKeys(PipelineOptions options) throws Exception; + + /** + * Returns a new {@link BoundedReader} that reads from this source. + */ + public abstract BoundedReader createReader(PipelineOptions options) throws IOException; + + /** + * A {@code Reader} that reads a bounded amount of input and supports some additional + * operations, such as progress estimation and dynamic work rebalancing. + * + *

    Boundedness

    + *

    Once {@link #start} or {@link #advance} has returned false, neither will be called + * again on this object. + * + *

    Thread safety

+ * All methods will be run from the same thread except {@link #splitAtFraction}, + * {@link #getFractionConsumed} and {@link #getCurrentSource}, which can be called concurrently + * from a different thread. There will not be multiple concurrent calls to + * {@link #splitAtFraction}, but there can be multiple concurrent calls to + * {@link #getFractionConsumed} if {@link #splitAtFraction} is implemented. + * + *

    If the source does not implement {@link #splitAtFraction}, you do not need to worry about + * thread safety. If implemented, it must be safe to call {@link #splitAtFraction} and + * {@link #getFractionConsumed} concurrently with other methods. + * + *

    Additionally, a successful {@link #splitAtFraction} call must, by definition, cause + * {@link #getCurrentSource} to start returning a different value. + * Callers of {@link #getCurrentSource} need to be aware of the possibility that the returned + * value can change at any time, and must only access the properties of the source returned by + * {@link #getCurrentSource} which do not change between {@link #splitAtFraction} calls. + * + *

    Implementing {@link #splitAtFraction}

    + * In the course of dynamic work rebalancing, the method {@link #splitAtFraction} + * may be called concurrently with {@link #advance} or {@link #start}. It is critical that + * their interaction is implemented in a thread-safe way, otherwise data loss is possible. + * + *

    Sources which support dynamic work rebalancing should use + * {@link com.google.cloud.dataflow.sdk.io.range.RangeTracker} to manage the (source-specific) + * range of positions that is being split. If your source supports dynamic work rebalancing, + * please use that class to implement it if possible; if not possible, please contact the team + * at dataflow-feedback@google.com. + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + public abstract static class BoundedReader extends Source.Reader { + /** + * Returns a value in [0, 1] representing approximately what fraction of the + * {@link #getCurrentSource current source} this reader has read so far, or {@code null} if such + * an estimate is not available. + * + *

    It is recommended that this method should satisfy the following properties: + *

      + *
    • Should return 0 before the {@link #start} call. + *
    • Should return 1 after a {@link #start} or {@link #advance} call that returns false. + *
    • The returned values should be non-decreasing (though they don't have to be unique). + *
    + * + *

    By default, returns null to indicate that this cannot be estimated. + * + *
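+     * As an illustration only (the fields named here are hypothetical), an offset-based reader
+     * might implement it as:
+     * {@code
+     *   public Double getFractionConsumed() {
+     *     if (!started) { return 0.0; }
+     *     if (done) { return 1.0; }
+     *     return Math.min(1.0,
+     *         (lastReturnedOffset - startOffset) / (double) (endOffset - startOffset));
+     *   }
+     * }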

    Thread safety
    + * If {@link #splitAtFraction} is implemented, this method can be called concurrently to other + * methods (including itself), and it is therefore critical for it to be implemented + * in a thread-safe way. + */ + public Double getFractionConsumed() { + return null; + } + + /** + * Returns a {@code Source} describing the same input that this {@code Reader} currently reads + * (including items already read). + * + *

    Usage

    + *

    Reader subclasses can use this method for convenience to access unchanging properties of + * the source being read. Alternatively, they can cache these properties in the constructor. + *

    The framework will call this method in the course of dynamic work rebalancing, e.g. after + * a successful {@link BoundedSource.BoundedReader#splitAtFraction} call. + * + *

    Mutability and thread safety

+ * Remember that {@link Source} objects must always be immutable. However, the return value of + * this function may be affected by dynamic work rebalancing, happening asynchronously via + * {@link BoundedSource.BoundedReader#splitAtFraction}, meaning it can return a different + * {@link Source} object. However, the returned object itself will still be immutable. + * Callers must take care not to rely on properties of the returned source that may be + * asynchronously changed as a result of this process (e.g. do not cache an end offset when + * reading a file). + * + *

    Implementation

    + * For convenience, subclasses should usually return the most concrete subclass of + * {@link Source} possible. + * In practice, the implementation of this method should nearly always be one of the following: + *
      + *
    • Source that inherits from a base class that already implements + * {@link #getCurrentSource}: delegate to base class. In this case, it is almost always + * an error for the subclass to maintain its own copy of the source. + *
      {@code
      +     *   public FooReader(FooSource source) {
      +     *     super(source);
      +     *   }
      +     *
      +     *   public FooSource getCurrentSource() {
      +     *     return (FooSource)super.getCurrentSource();
      +     *   }
      +     * }
      + *
    • Source that does not support dynamic work rebalancing: return a private final variable. + *
      {@code
      +     *   private final FooSource source;
      +     *
      +     *   public FooReader(FooSource source) {
      +     *     this.source = source;
      +     *   }
      +     *
      +     *   public FooSource getCurrentSource() {
      +     *     return source;
      +     *   }
      +     * }
      + *
    • {@link BoundedSource.BoundedReader} that explicitly supports dynamic work rebalancing: + * maintain a variable pointing to an immutable source object, and protect it with + * synchronization. + *
      {@code
      +     *   private FooSource source;
      +     *
      +     *   public FooReader(FooSource source) {
      +     *     this.source = source;
      +     *   }
      +     *
      +     *   public synchronized FooSource getCurrentSource() {
      +     *     return source;
      +     *   }
      +     *
      +     *   public synchronized FooSource splitAtFraction(double fraction) {
      +     *     ...
      +     *     FooSource primary = ...;
      +     *     FooSource residual = ...;
      +     *     this.source = primary;
      +     *     return residual;
      +     *   }
      +     * }
      + *
    + */ + @Override + public abstract BoundedSource getCurrentSource(); + + /** + * Tells the reader to narrow the range of the input it's going to read and give up + * the remainder, so that the new range would contain approximately the given + * fraction of the amount of data in the current range. + * + *

    Returns a {@code BoundedSource} representing the remainder. + * + *

    Detailed description
    + * Assuming the following sequence of calls: + *
    {@code
    +     *   BoundedSource initial = reader.getCurrentSource();
    +     *   BoundedSource residual = reader.splitAtFraction(fraction);
    +     *   BoundedSource primary = reader.getCurrentSource();
    +     * }
    + *
      + *
    • The "primary" and "residual" sources, when read, should together cover the same + * set of records as "initial". + *
    • The current reader should continue to be in a valid state, and continuing to read + * from it should, together with the records it already read, yield the same records + * as would have been read by "primary". + *
    • The amount of data read by "primary" should ideally represent approximately + * the given fraction of the amount of data read by "initial". + *
    + * For example, a reader that reads a range of offsets [A, B) in a file might implement + * this method by truncating the current range to [A, A + fraction*(B-A)) and returning + * a Source representing the range [A + fraction*(B-A), B). + * + *
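+     * A sketch of that offset-based pattern (all names here are illustrative, not part of this
+     * API):
+     * {@code
+     *   public synchronized OffsetSource splitAtFraction(double fraction) {
+     *     long splitOffset = start + (long) (fraction * (end - start));
+     *     if (splitOffset <= lastReturnedOffset || splitOffset >= end) {
+     *       return null;  // Split not possible here; the reader is left unchanged.
+     *     }
+     *     OffsetSource residual = source.forSubrange(splitOffset, end);
+     *     this.source = source.forSubrange(start, splitOffset);
+     *     return residual;
+     *   }
+     * }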

    This method should return {@code null} if the split cannot be performed for this fraction + * while satisfying the semantics above. E.g., a reader that reads a range of offsets + * in a file should return {@code null} if it is already past the position in its range + * corresponding to the given fraction. In this case, the method MUST have no effect + * (the reader must behave as if the method hadn't been called at all). + * + *

    Statefulness
    + * Since this method (if successful) affects the reader's source, in subsequent invocations + * "fraction" should be interpreted relative to the new current source. + * + *
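+     * For instance (with illustrative offsets), if the current source covers {@code [0, 100)} and
+     * {@code splitAtFraction(0.5)} succeeds, the current source becomes {@code [0, 50)}, and a
+     * later {@code splitAtFraction(0.5)} would refer to offset 25 of that new range.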
    Thread safety and blocking
+ * This method will be called concurrently with other methods (however, there will not be multiple + * concurrent invocations of this method itself), and it is critical for it to be implemented + * in a thread-safe way (otherwise data loss is possible). + * + *

    It is also very important that this method always completes quickly. In particular, + * it should not perform or wait on any blocking operations such as I/O, RPCs etc. Violating + * this requirement may stall completion of the work item or even cause it to fail. + * + *

    It is incorrect to make both this method and {@link #start}/{@link #advance} + * {@code synchronized}, because those methods can perform blocking operations, and then + * this method would have to wait for those calls to complete. + * + *

    {@link com.google.cloud.dataflow.sdk.io.range.RangeTracker} makes it easy to implement + * this method safely and correctly. + * + *

    By default, returns null to indicate that splitting is not possible. + */ + public BoundedSource splitAtFraction(double fraction) { + return null; + } + + /** + * By default, returns the minimum possible timestamp. + */ + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return BoundedWindow.TIMESTAMP_MIN_VALUE; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CompressedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CompressedSource.java new file mode 100644 index 000000000000..e3dca9168043 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CompressedSource.java @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.common.base.Preconditions; +import com.google.common.io.ByteStreams; +import com.google.common.primitives.Ints; + +import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; + +import java.io.IOException; +import java.io.PushbackInputStream; +import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.NoSuchElementException; +import java.util.zip.GZIPInputStream; + +/** + * A Source that reads from compressed files. A {@code CompressedSources} wraps a delegate + * {@link FileBasedSource} that is able to read the decompressed file format. + * + *

    For example, use the following to read from a gzip-compressed XML file: + * + *

     {@code
+ * XmlSource<T> mySource = XmlSource.from(...);
+ * PCollection<T> collection = p.apply(Read.from(CompressedSource
    + *     .from(mySource)
    + *     .withDecompression(CompressedSource.CompressionMode.GZIP)));
    + * } 
    + * + *

    Supported compression algorithms are {@link CompressionMode#GZIP} and + * {@link CompressionMode#BZIP2}. User-defined compression types are supported by implementing + * {@link DecompressingChannelFactory}. + * + *
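+ * For example, a user-defined factory could handle raw DEFLATE data using only JDK classes
+ * (a sketch; {@code mySource} stands in for an existing delegate {@link FileBasedSource}):
+ * {@code
+ * CompressedSource<T> source = CompressedSource.from(mySource)
+ *     .withDecompression(new CompressedSource.DecompressingChannelFactory() {
+ *       @Override
+ *       public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel)
+ *           throws IOException {
+ *         return Channels.newChannel(
+ *             new java.util.zip.InflaterInputStream(Channels.newInputStream(channel)));
+ *       }
+ *     });
+ * }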

    By default, the compression algorithm is selected from those supported in + * {@link CompressionMode} based on the file name provided to the source, namely + * {@code ".bz2"} indicates {@link CompressionMode#BZIP2} and {@code ".gz"} indicates + * {@link CompressionMode#GZIP}. If the file name does not match any of the supported + * algorithms, it is assumed to be uncompressed data. + * + * @param The type to read from the compressed file. + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public class CompressedSource extends FileBasedSource { + /** + * Factory interface for creating channels that decompress the content of an underlying channel. + */ + public static interface DecompressingChannelFactory extends Serializable { + /** + * Given a channel, create a channel that decompresses the content read from the channel. + * @throws IOException + */ + public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel) + throws IOException; + } + + /** + * Factory interface for creating channels that decompress the content of an underlying channel, + * based on both the channel and the file name. + */ + private static interface FileNameBasedDecompressingChannelFactory + extends DecompressingChannelFactory { + /** + * Given a channel, create a channel that decompresses the content read from the channel. + * @throws IOException + */ + ReadableByteChannel createDecompressingChannel(String fileName, ReadableByteChannel channel) + throws IOException; + + /** + * Given a file name, returns true if the file name matches any supported compression + * scheme. + */ + boolean isCompressed(String fileName); + } + + /** + * Default compression types supported by the {@code CompressedSource}. + */ + public enum CompressionMode implements DecompressingChannelFactory { + /** + * Reads a byte channel assuming it is compressed with gzip. + */ + GZIP { + @Override + public boolean matches(String fileName) { + return fileName.toLowerCase().endsWith(".gz"); + } + + @Override + public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel) + throws IOException { + // Determine if the input stream is gzipped. The input stream returned from the + // GCS connector may already be decompressed; GCS does this based on the + // content-encoding property. + PushbackInputStream stream = new PushbackInputStream(Channels.newInputStream(channel), 2); + byte[] headerBytes = new byte[2]; + int bytesRead = ByteStreams.read( + stream /* source */, headerBytes /* dest */, 0 /* offset */, 2 /* len */); + stream.unread(headerBytes, 0, bytesRead); + if (bytesRead >= 2) { + byte zero = 0x00; + int header = Ints.fromBytes(zero, zero, headerBytes[1], headerBytes[0]); + if (header == GZIPInputStream.GZIP_MAGIC) { + return Channels.newChannel(new GzipCompressorInputStream(stream)); + } + } + return Channels.newChannel(stream); + } + }, + + /** + * Reads a byte channel assuming it is compressed with bzip2. + */ + BZIP2 { + @Override + public boolean matches(String fileName) { + return fileName.toLowerCase().endsWith(".bz2"); + } + + @Override + public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel) + throws IOException { + return Channels.newChannel( + new BZip2CompressorInputStream(Channels.newInputStream(channel))); + } + }; + + /** + * Returns {@code true} if the given file name implies that the contents are compressed + * according to the compression embodied by this factory. 
+ */ + public abstract boolean matches(String fileName); + + @Override + public abstract ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel) + throws IOException; + } + + /** + * Reads a byte channel detecting compression according to the file name. If the filename + * is not any other known {@link CompressionMode}, it is presumed to be uncompressed. + */ + private static class DecompressAccordingToFilename + implements FileNameBasedDecompressingChannelFactory { + + @Override + public ReadableByteChannel createDecompressingChannel( + String fileName, ReadableByteChannel channel) throws IOException { + for (CompressionMode type : CompressionMode.values()) { + if (type.matches(fileName)) { + return type.createDecompressingChannel(channel); + } + } + // Uncompressed + return channel; + } + + @Override + public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel) { + throw new UnsupportedOperationException( + String.format("%s does not support createDecompressingChannel(%s) but only" + + " createDecompressingChannel(%s,%s)", + getClass().getSimpleName(), + String.class.getSimpleName(), + ReadableByteChannel.class.getSimpleName(), + ReadableByteChannel.class.getSimpleName())); + } + + @Override + public boolean isCompressed(String fileName) { + for (CompressionMode type : CompressionMode.values()) { + if (type.matches(fileName)) { + return true; + } + } + return false; + } + } + + private final FileBasedSource sourceDelegate; + private final DecompressingChannelFactory channelFactory; + + /** + * Creates a {@link Read} transform that reads from that reads from the underlying + * {@link FileBasedSource} {@code sourceDelegate} after decompressing it with a {@link + * DecompressingChannelFactory}. + */ + public static Read.Bounded readFromSource( + FileBasedSource sourceDelegate, DecompressingChannelFactory channelFactory) { + return Read.from(new CompressedSource<>(sourceDelegate, channelFactory)); + } + + /** + * Creates a {@code CompressedSource} from an underlying {@code FileBasedSource}. The type + * of compression used will be based on the file name extension unless explicitly + * configured via {@link CompressedSource#withDecompression}. + */ + public static CompressedSource from(FileBasedSource sourceDelegate) { + return new CompressedSource<>(sourceDelegate, new DecompressAccordingToFilename()); + } + + /** + * Return a {@code CompressedSource} that is like this one but will decompress its underlying file + * with the given {@link DecompressingChannelFactory}. + */ + public CompressedSource withDecompression(DecompressingChannelFactory channelFactory) { + return new CompressedSource<>(this.sourceDelegate, channelFactory); + } + + /** + * Creates a {@code CompressedSource} from a delegate file based source and a decompressing + * channel factory. + */ + private CompressedSource( + FileBasedSource sourceDelegate, DecompressingChannelFactory channelFactory) { + super(sourceDelegate.getFileOrPatternSpec(), Long.MAX_VALUE); + this.sourceDelegate = sourceDelegate; + this.channelFactory = channelFactory; + } + + /** + * Creates a {@code CompressedSource} for an individual file. Used by {@link + * CompressedSource#createForSubrangeOfFile}. 
+ */ + private CompressedSource(FileBasedSource sourceDelegate, + DecompressingChannelFactory channelFactory, String filePatternOrSpec, long minBundleSize, + long startOffset, long endOffset) { + super(filePatternOrSpec, minBundleSize, startOffset, endOffset); + Preconditions.checkArgument( + startOffset == 0, + "CompressedSources must start reading at offset 0. Requested offset: " + startOffset); + this.sourceDelegate = sourceDelegate; + this.channelFactory = channelFactory; + } + + /** + * Validates that the delegate source is a valid source and that the channel factory is not null. + */ + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(sourceDelegate); + sourceDelegate.validate(); + Preconditions.checkNotNull(channelFactory); + } + + /** + * Creates a {@code CompressedSource} for a subrange of a file. Called by superclass to create a + * source for a single file. + */ + @Override + protected FileBasedSource createForSubrangeOfFile(String fileName, long start, long end) { + return new CompressedSource<>(sourceDelegate.createForSubrangeOfFile(fileName, start, end), + channelFactory, fileName, Long.MAX_VALUE, start, end); + } + + /** + * Determines whether a single file represented by this source is splittable. Returns true + * if we are using the default decompression factory and and it determines + * from the requested file name that the file is not compressed. + */ + @Override + protected final boolean isSplittable() throws Exception { + if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) { + FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory = + (FileNameBasedDecompressingChannelFactory) channelFactory; + return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec()); + } + return true; + } + + /** + * Creates a {@code FileBasedReader} to read a single file. + * + *

    Uses the delegate source to create a single file reader for the delegate source. + * Utilizes the default decompression channel factory to not wrap the source reader + * if the file name does not represent a compressed file allowing for splitting of + * the source. + */ + @Override + protected final FileBasedReader createSingleFileReader(PipelineOptions options) { + if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) { + FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory = + (FileNameBasedDecompressingChannelFactory) channelFactory; + if (!fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec())) { + return sourceDelegate.createSingleFileReader(options); + } + } + return new CompressedReader( + this, sourceDelegate.createSingleFileReader(options)); + } + + /** + * Returns whether the delegate source produces sorted keys. + */ + @Override + public final boolean producesSortedKeys(PipelineOptions options) throws Exception { + return sourceDelegate.producesSortedKeys(options); + } + + /** + * Returns the delegate source's default output coder. + */ + @Override + public final Coder getDefaultOutputCoder() { + return sourceDelegate.getDefaultOutputCoder(); + } + + public final DecompressingChannelFactory getChannelFactory() { + return channelFactory; + } + + /** + * Reader for a {@link CompressedSource}. Decompresses its input and uses a delegate + * reader to read elements from the decompressed input. + * @param The type of records read from the source. + */ + public static class CompressedReader extends FileBasedReader { + + private final FileBasedReader readerDelegate; + private final CompressedSource source; + private int numRecordsRead; + + /** + * Create a {@code CompressedReader} from a {@code CompressedSource} and delegate reader. + */ + public CompressedReader(CompressedSource source, FileBasedReader readerDelegate) { + super(source); + this.source = source; + this.readerDelegate = readerDelegate; + } + + /** + * Gets the current record from the delegate reader. + */ + @Override + public T getCurrent() throws NoSuchElementException { + return readerDelegate.getCurrent(); + } + + /** + * Returns true only for the first record; compressed sources cannot be split. + */ + @Override + protected final boolean isAtSplitPoint() { + // We have to return true for the first record, but not for the state before reading it, + // and not for the state after reading any other record. Hence == rather than >= or <=. + // This is required because FileBasedReader is intended for readers that can read a range + // of offsets in a file and where the range can be split in parts. CompressedReader, + // however, is a degenerate case because it cannot be split, but it has to satisfy the + // semantics of offsets and split points anyway. + return numRecordsRead == 1; + } + + /** + * Creates a decompressing channel from the input channel and passes it to its delegate reader's + * {@link FileBasedReader#startReading(ReadableByteChannel)}. 
+ */ + @Override + protected final void startReading(ReadableByteChannel channel) throws IOException { + if (source.getChannelFactory() instanceof FileNameBasedDecompressingChannelFactory) { + FileNameBasedDecompressingChannelFactory channelFactory = + (FileNameBasedDecompressingChannelFactory) source.getChannelFactory(); + readerDelegate.startReading(channelFactory.createDecompressingChannel( + getCurrentSource().getFileOrPatternSpec(), + channel)); + } else { + readerDelegate.startReading(source.getChannelFactory().createDecompressingChannel( + channel)); + } + } + + /** + * Reads the next record via the delegate reader. + */ + @Override + protected final boolean readNextRecord() throws IOException { + if (!readerDelegate.readNextRecord()) { + return false; + } + ++numRecordsRead; + return true; + } + + /** + * Returns the delegate reader's current offset in the decompressed input. + */ + @Override + protected final long getCurrentOffset() { + return readerDelegate.getCurrentOffset(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java new file mode 100644 index 000000000000..2938534168ab --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java @@ -0,0 +1,386 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * A source that produces longs. When used as a {@link BoundedSource}, {@link CountingSource} + * starts at {@code 0} and counts up to a specified maximum. When used as an + * {@link UnboundedSource}, it counts up to {@link Long#MAX_VALUE} and then never produces more + * output. (In practice, this limit should never be reached.) + * + *

    The bounded {@link CountingSource} is implemented based on {@link OffsetBasedSource} and + * {@link OffsetBasedSource.OffsetBasedReader}, so it performs efficient initial splitting and it + * supports dynamic work rebalancing. + * + *

+ * <p>To produce a bounded {@code PCollection<Long>}, use {@link CountingSource#upTo(long)}:
+ *

+ * <pre>{@code
+ * Pipeline p = ...
+ * BoundedSource<Long> source = CountingSource.upTo(1000);
+ * PCollection<Long> bounded = p.apply(Read.from(source));
+ * }</pre>
    + * + *
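To sanity-check the bounded variant, the generated range can be aggregated directly. This is only a sketch; it assumes Sum.longsGlobally() from the SDK's transforms package and PipelineOptionsFactory from its options package.

    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
    PCollection<Long> bounded = p.apply(Read.from(CountingSource.upTo(5)));  // emits 0,1,2,3,4
    PCollection<Long> total = bounded.apply(Sum.longsGlobally());            // single element: 10
    p.run();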

+ * <p>To produce an unbounded {@code PCollection<Long>}, use {@link CountingSource#unbounded} or
+ * {@link CountingSource#unboundedWithTimestampFn}:
+ *

+ * <pre>{@code
+ * Pipeline p = ...
+ *
+ * // To create an unbounded source that uses processing time as the element timestamp.
+ * UnboundedSource<Long, CounterMark> source = CountingSource.unbounded();
+ * // Or, to create an unbounded source that uses a provided function to set the element timestamp.
+ * UnboundedSource<Long, CounterMark> source = CountingSource.unboundedWithTimestampFn(someFn);
+ *
+ * PCollection<Long> unbounded = p.apply(Read.from(source));
+ * }</pre>
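Because the unbounded variant never finishes, any grouping or aggregation over it first needs a windowing strategy. A sketch, continuing from the pipeline p above and assuming the SDK's Window and FixedWindows transforms together with Joda-Time's Duration:

    // Window the endless count into fixed one-minute windows before any aggregation.
    PCollection<Long> unbounded = p.apply(Read.from(CountingSource.unbounded()));
    PCollection<Long> windowed =
        unbounded.apply(Window.<Long>into(FixedWindows.of(Duration.standardMinutes(1))));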
    + */ +public class CountingSource { + /** + * Creates a {@link BoundedSource} that will produce the specified number of elements, + * from {@code 0} to {@code numElements - 1}. + */ + public static BoundedSource upTo(long numElements) { + checkArgument(numElements > 0, "numElements (%s) must be greater than 0", numElements); + return new BoundedCountingSource(0, numElements); + } + + /** + * Creates an {@link UnboundedSource} that will produce numbers starting from {@code 0} up to + * {@link Long#MAX_VALUE}. + * + *

    After {@link Long#MAX_VALUE}, the source never produces more output. (In practice, this + * limit should never be reached.) + * + *

    Elements in the resulting {@link PCollection PCollection<Long>} will have timestamps + * corresponding to processing time at element generation, provided by {@link Instant#now}. + */ + public static UnboundedSource unbounded() { + return unboundedWithTimestampFn(new NowTimestampFn()); + } + + /** + * Creates an {@link UnboundedSource} that will produce numbers starting from {@code 0} up to + * {@link Long#MAX_VALUE}, with element timestamps supplied by the specified function. + * + *

    After {@link Long#MAX_VALUE}, the source never produces more output. (In practice, this + * limit should never be reached.) + * + *

    Note that the timestamps produced by {@code timestampFn} may not decrease. + */ + public static UnboundedSource unboundedWithTimestampFn( + SerializableFunction timestampFn) { + return new UnboundedCountingSource(0, 1, timestampFn); + } + + ///////////////////////////////////////////////////////////////////////////////////////////// + + /** Prevent instantiation. */ + private CountingSource() {} + + + /** + * A function that returns {@link Instant#now} as the timestamp for each generated element. + */ + private static class NowTimestampFn implements SerializableFunction { + @Override + public Instant apply(Long input) { + return Instant.now(); + } + } + + /** + * An implementation of {@link CountingSource} that produces a bounded {@link PCollection}. + * It is implemented on top of {@link OffsetBasedSource} (with associated reader + * {@link BoundedCountingReader}) and performs efficient initial splitting and supports dynamic + * work rebalancing. + */ + private static class BoundedCountingSource extends OffsetBasedSource { + /** + * Creates a {@link BoundedCountingSource} that generates the numbers in the specified + * {@code [start, end)} range. + */ + public BoundedCountingSource(long start, long end) { + super(start, end, 1 /* can be split every 1 offset */); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + + @Override + public long getBytesPerOffset() { + return 8; + } + + @Override + public long getMaxEndOffset(PipelineOptions options) throws Exception { + return getEndOffset(); + } + + @Override + public OffsetBasedSource createSourceForSubrange(long start, long end) { + return new BoundedCountingSource(start, end); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return true; + } + + @Override + public com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader createReader( + PipelineOptions options) throws IOException { + return new BoundedCountingReader(this); + } + + @Override + public Coder getDefaultOutputCoder() { + return VarLongCoder.of(); + } + } + + /** + * The reader associated with {@link BoundedCountingSource}. + * + * @see BoundedCountingSource + */ + private static class BoundedCountingReader extends OffsetBasedSource.OffsetBasedReader { + private long current; + + public BoundedCountingReader(OffsetBasedSource source) { + super(source); + } + + @Override + protected long getCurrentOffset() throws NoSuchElementException { + return current; + } + + @Override + public synchronized BoundedCountingSource getCurrentSource() { + return (BoundedCountingSource) super.getCurrentSource(); + } + + @Override + public Long getCurrent() throws NoSuchElementException { + return current; + } + + @Override + protected boolean startImpl() throws IOException { + current = getCurrentSource().getStartOffset(); + return true; + } + + @Override + protected boolean advanceImpl() throws IOException { + current++; + return true; + } + + @Override + public void close() throws IOException {} + } + + /** + * An implementation of {@link CountingSource} that produces an unbounded {@link PCollection}. + */ + private static class UnboundedCountingSource extends UnboundedSource { + /** The first number (>= 0) generated by this {@link UnboundedCountingSource}. */ + private final long start; + /** The interval between numbers generated by this {@link UnboundedCountingSource}. */ + private final long stride; + /** The function used to produce timestamps for the generated elements. 
*/ + private final SerializableFunction timestampFn; + + /** + * Creates an {@link UnboundedSource} that will produce numbers starting from {@code 0} up to + * {@link Long#MAX_VALUE}, with element timestamps supplied by the specified function. + * + *

    After {@link Long#MAX_VALUE}, the source never produces more output. (In practice, this + * limit should never be reached.) + * + *

    Note that the timestamps produced by {@code timestampFn} may not decrease. + */ + public UnboundedCountingSource( + long start, long stride, SerializableFunction timestampFn) { + this.start = start; + this.stride = stride; + this.timestampFn = timestampFn; + } + + /** + * Splits an unbounded source {@code desiredNumSplits} ways by giving each split every + * {@code desiredNumSplits}th element that this {@link UnboundedCountingSource} + * produces. + * + *

    E.g., if a source produces all even numbers {@code [0, 2, 4, 6, 8, ...)} and we want to + * split into 3 new sources, then the new sources will produce numbers that are 6 apart and + * are offset at the start by the original stride: {@code [0, 6, 12, ...)}, + * {@code [2, 8, 14, ...)}, and {@code [4, 10, 16, ...)}. + */ + @Override + public List> generateInitialSplits( + int desiredNumSplits, PipelineOptions options) throws Exception { + // Using Javadoc example, stride 2 with 3 splits becomes stride 6. + long newStride = stride * desiredNumSplits; + + ImmutableList.Builder splits = ImmutableList.builder(); + for (int i = 0; i < desiredNumSplits; ++i) { + // Starts offset by the original stride. Using Javadoc example, this generates starts of + // 0, 2, and 4. + splits.add(new UnboundedCountingSource(start + i * stride, newStride, timestampFn)); + } + return splits.build(); + } + + @Override + public UnboundedReader createReader( + PipelineOptions options, CounterMark checkpointMark) { + return new UnboundedCountingReader(this, checkpointMark); + } + + @Override + public Coder getCheckpointMarkCoder() { + return AvroCoder.of(CountingSource.CounterMark.class); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return VarLongCoder.of(); + } + } + + /** + * The reader associated with {@link UnboundedCountingSource}. + * + * @see UnboundedCountingSource + */ + private static class UnboundedCountingReader extends UnboundedReader { + private UnboundedCountingSource source; + private long current; + private Instant currentTimestamp; + + public UnboundedCountingReader(UnboundedCountingSource source, CounterMark mark) { + this.source = source; + if (mark == null) { + // Because we have not emitted an element yet, and start() calls advance, we need to + // "un-advance" so that start() produces the correct output. + this.current = source.start - source.stride; + } else { + this.current = mark.getLastEmitted(); + } + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + // Overflow-safe check that (current + source.stride) <= LONG.MAX_VALUE. Else, stop producing. + if (Long.MAX_VALUE - source.stride < current) { + return false; + } + current += source.stride; + currentTimestamp = source.timestampFn.apply(current); + return true; + } + + @Override + public Instant getWatermark() { + return source.timestampFn.apply(current); + } + + @Override + public CounterMark getCheckpointMark() { + return new CounterMark(current); + } + + @Override + public UnboundedSource getCurrentSource() { + return source; + } + + @Override + public Long getCurrent() throws NoSuchElementException { + return current; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return currentTimestamp; + } + + @Override + public void close() throws IOException {} + } + + /** + * The checkpoint for an unbounded {@link CountingSource} is simply the last value produced. The + * associated source object encapsulates the information needed to produce the next value. + */ + @DefaultCoder(AvroCoder.class) + public static class CounterMark implements UnboundedSource.CheckpointMark { + /** The last value emitted. */ + private final long lastEmitted; + + /** + * Creates a checkpoint mark reflecting the last emitted value. + */ + public CounterMark(long lastEmitted) { + this.lastEmitted = lastEmitted; + } + + /** + * Returns the last value emitted by the reader. 
+ */ + public long getLastEmitted() { + return lastEmitted; + } + + ///////////////////////////////////////////////////////////////////////////////////// + + @SuppressWarnings("unused") // For AvroCoder + private CounterMark() { + this.lastEmitted = 0L; + } + + @Override + public void finalizeCheckpoint() throws IOException {} + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIO.java new file mode 100644 index 000000000000..f618bc9d63bc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/DatastoreIO.java @@ -0,0 +1,957 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.api.services.datastore.DatastoreV1.PropertyFilter.Operator.EQUAL; +import static com.google.api.services.datastore.DatastoreV1.PropertyOrder.Direction.DESCENDING; +import static com.google.api.services.datastore.DatastoreV1.QueryResultBatch.MoreResultsType.NOT_FINISHED; +import static com.google.api.services.datastore.client.DatastoreHelper.getPropertyMap; +import static com.google.api.services.datastore.client.DatastoreHelper.makeFilter; +import static com.google.api.services.datastore.client.DatastoreHelper.makeOrder; +import static com.google.api.services.datastore.client.DatastoreHelper.makeValue; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.Sleeper; +import com.google.api.services.datastore.DatastoreV1.CommitRequest; +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.EntityResult; +import com.google.api.services.datastore.DatastoreV1.Key; +import com.google.api.services.datastore.DatastoreV1.Key.PathElement; +import com.google.api.services.datastore.DatastoreV1.PartitionId; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.api.services.datastore.DatastoreV1.QueryResultBatch; +import com.google.api.services.datastore.DatastoreV1.RunQueryRequest; +import com.google.api.services.datastore.DatastoreV1.RunQueryResponse; +import com.google.api.services.datastore.client.Datastore; +import com.google.api.services.datastore.client.DatastoreException; +import com.google.api.services.datastore.client.DatastoreFactory; +import com.google.api.services.datastore.client.DatastoreHelper; +import com.google.api.services.datastore.client.DatastoreOptions; +import com.google.api.services.datastore.client.QuerySplitter; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; +import 
com.google.cloud.dataflow.sdk.coders.EntityCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.io.Sink.WriteOperation; +import com.google.cloud.dataflow.sdk.io.Sink.Writer; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineWorkerPoolOptions; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.AttemptBoundedExponentialBackOff; +import com.google.cloud.dataflow.sdk.util.RetryHttpRequestInitializer; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + *

    {@link DatastoreIO} provides an API to Read and Write {@link PCollection PCollections} of + * Google Cloud Datastore + * {@link Entity} objects. + * + *

+ * <p>Google Cloud Datastore is a fully managed NoSQL data storage service.
+ * An {@code Entity} is an object in Datastore, analogous to a row in a traditional
+ * database table.
+ *

    This API currently requires an authentication workaround. To use {@link DatastoreIO}, users + * must use the {@code gcloud} command line tool to get credentials for Datastore: + *

    + * $ gcloud auth login
    + * 
    + * + *

    To read a {@link PCollection} from a query to Datastore, use {@link DatastoreIO#source} and + * its methods {@link DatastoreIO.Source#withDataset} and {@link DatastoreIO.Source#withQuery} to + * specify the dataset to query and the query to read from. You can optionally provide a namespace + * to query within using {@link DatastoreIO.Source#withNamespace} or a Datastore host using + * {@link DatastoreIO.Source#withHost}. + * + *

    For example: + * + *

+ * <pre> {@code
+ * // Read a query from Datastore
+ * PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
+ * Query query = ...;
+ * String dataset = "...";
+ * String host = "...";
+ *
+ * Pipeline p = Pipeline.create(options);
+ * PCollection<Entity> entities = p.apply(
+ *     Read.from(DatastoreIO.source()
+ *         .withDataset(dataset)
+ *         .withQuery(query)
+ *         .withHost(host)));
+ * } </pre>
    + * + *

    or: + * + *

+ * <pre> {@code
+ * // Read a query from Datastore using the default namespace and host
+ * PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
+ * Query query = ...;
+ * String dataset = "...";
+ *
+ * Pipeline p = Pipeline.create(options);
+ * PCollection<Entity> entities = p.apply(DatastoreIO.readFrom(dataset, query));
+ * p.run();
+ * } </pre>
    + * + *

    Note: Normally, a Cloud Dataflow job will read from Cloud Datastore in parallel across + * many workers. However, when the {@link Query} is configured with a limit using + * {@link com.google.api.services.datastore.DatastoreV1.Query.Builder#setLimit(int)}, then + * all returned results will be read by a single Dataflow worker in order to ensure correct data. + * + *
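For example, a capped read might be configured as follows; the kind name and the limit of 100 are placeholders, and because the query carries a limit it will be read by a single worker rather than split:

    Query.Builder q = Query.newBuilder();
    q.addKindBuilder().setName("MyKind");  // placeholder kind
    q.setLimit(100);                       // a limit forces the read onto a single worker
    PCollection<Entity> capped = p.apply(
        Read.from(DatastoreIO.source().withDataset(dataset).withQuery(q.build())));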

    To write a {@link PCollection} to a Datastore, use {@link DatastoreIO#writeTo}, + * specifying the datastore to write to: + * + *

+ * <pre> {@code
+ * PCollection<Entity> entities = ...;
+ * entities.apply(DatastoreIO.writeTo(dataset));
+ * p.run();
+ * } </pre>
    + * + *

    To optionally change the host that is used to write to the Datastore, use {@link + * DatastoreIO#sink} to build a {@link DatastoreIO.Sink} and write to it using the {@link Write} + * transform: + * + *

+ * <pre> {@code
+ * PCollection<Entity> entities = ...;
+ * entities.apply(Write.to(DatastoreIO.sink().withDataset(dataset).withHost(host)));
+ * } </pre>
    + * + *

    {@link Entity Entities} in the {@code PCollection} to be written must have complete + * {@link Key Keys}. Complete {@code Keys} specify the {@code name} and {@code id} of the + * {@code Entity}, where incomplete {@code Keys} do not. A {@code namespace} other than the + * project default may be written to by specifying it in the {@code Entity} {@code Keys}. + * + *

+ * <pre>{@code
+ * Key.Builder keyBuilder = DatastoreHelper.makeKey(...);
+ * keyBuilder.getPartitionIdBuilder().setNamespace(namespace);
+ * }</pre>
    + * + *

    {@code Entities} will be committed as upsert (update or insert) mutations. Please read + * Entities, Properties, and + * Keys for more information about {@code Entity} keys. + * + *
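As a concrete illustration of a complete key, a sketch follows; it assumes the varargs DatastoreHelper.makeKey(kind, name) helper from the Datastore client library, the same helper used in the namespace example above.

    // A key whose last path element carries a kind plus a name (or id) is complete.
    Key.Builder key = DatastoreHelper.makeKey("MyKind", "someName");
    Entity entity = Entity.newBuilder().setKey(key).build();
    // Entities built this way and written via DatastoreIO.writeTo(dataset) are upserted.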

    Permissions

+ * Permission requirements depend on the {@code PipelineRunner} that is used to execute the
+ * Dataflow job. Please refer to the documentation of the corresponding {@code PipelineRunner}s
+ * for more details.
+ *

    Please see Cloud Datastore Sign Up + * for security and permission related information specific to Datastore. + * + * @see com.google.cloud.dataflow.sdk.runners.PipelineRunner + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public class DatastoreIO { + public static final String DEFAULT_HOST = "https://www.googleapis.com"; + + /** + * Datastore has a limit of 500 mutations per batch operation, so we flush + * changes to Datastore every 500 entities. + */ + public static final int DATASTORE_BATCH_UPDATE_LIMIT = 500; + + /** + * Returns an empty {@link DatastoreIO.Source} builder with the default {@code host}. + * Configure the {@code dataset}, {@code query}, and {@code namespace} using + * {@link DatastoreIO.Source#withDataset}, {@link DatastoreIO.Source#withQuery}, + * and {@link DatastoreIO.Source#withNamespace}. + * + * @deprecated the name and return type do not match. Use {@link #source()}. + */ + @Deprecated + public static Source read() { + return source(); + } + + /** + * Returns an empty {@link DatastoreIO.Source} builder with the default {@code host}. + * Configure the {@code dataset}, {@code query}, and {@code namespace} using + * {@link DatastoreIO.Source#withDataset}, {@link DatastoreIO.Source#withQuery}, + * and {@link DatastoreIO.Source#withNamespace}. + * + *

    The resulting {@link Source} object can be passed to {@link Read} to create a + * {@code PTransform} that will read from Datastore. + */ + public static Source source() { + return new Source(DEFAULT_HOST, null, null, null); + } + + /** + * Returns a {@code PTransform} that reads Datastore entities from the query + * against the given dataset. + */ + public static Read.Bounded readFrom(String datasetId, Query query) { + return Read.from(new Source(DEFAULT_HOST, datasetId, query, null)); + } + + /** + * Returns a {@code PTransform} that reads Datastore entities from the query + * against the given dataset and host. + * + * @deprecated prefer {@link #source()} with {@link Source#withHost}, {@link Source#withDataset}, + * {@link Source#withQuery}s. + */ + @Deprecated + public static Read.Bounded readFrom(String host, String datasetId, Query query) { + return Read.from(new Source(host, datasetId, query, null)); + } + + /** + * A {@link Source} that reads the result rows of a Datastore query as {@code Entity} objects. + */ + public static class Source extends BoundedSource { + public String getHost() { + return host; + } + + public String getDataset() { + return datasetId; + } + + public Query getQuery() { + return query; + } + + @Nullable + public String getNamespace() { + return namespace; + } + + public Source withDataset(String datasetId) { + checkNotNull(datasetId, "datasetId"); + return new Source(host, datasetId, query, namespace); + } + + /** + * Returns a new {@link Source} that reads the results of the specified query. + * + *

    Does not modify this object. + * + *

    Note: Normally, a Cloud Dataflow job will read from Cloud Datastore in parallel + * across many workers. However, when the {@link Query} is configured with a limit using + * {@link com.google.api.services.datastore.DatastoreV1.Query.Builder#setLimit(int)}, then all + * returned results will be read by a single Dataflow worker in order to ensure correct data. + */ + public Source withQuery(Query query) { + checkNotNull(query, "query"); + checkArgument(!query.hasLimit() || query.getLimit() > 0, + "Invalid query limit %s: must be positive", query.getLimit()); + return new Source(host, datasetId, query, namespace); + } + + public Source withHost(String host) { + checkNotNull(host, "host"); + return new Source(host, datasetId, query, namespace); + } + + public Source withNamespace(@Nullable String namespace) { + return new Source(host, datasetId, query, namespace); + } + + @Override + public Coder getDefaultOutputCoder() { + return EntityCoder.of(); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) { + // TODO: Perhaps this can be implemented by inspecting the query. + return false; + } + + @Override + public List splitIntoBundles(long desiredBundleSizeBytes, PipelineOptions options) + throws Exception { + // Users may request a limit on the number of results. We can currently support this by + // simply disabling parallel reads and using only a single split. + if (query.hasLimit()) { + return ImmutableList.of(this); + } + + long numSplits; + try { + numSplits = Math.round(((double) getEstimatedSizeBytes(options)) / desiredBundleSizeBytes); + } catch (Exception e) { + // Fallback in case estimated size is unavailable. TODO: fix this, it's horrible. + + // 1. Try Dataflow's numWorkers, which will be 0 for other workers. + DataflowPipelineWorkerPoolOptions poolOptions = + options.as(DataflowPipelineWorkerPoolOptions.class); + if (poolOptions.getNumWorkers() > 0) { + LOG.warn("Estimated size of unavailable, using the number of workers {}", + poolOptions.getNumWorkers(), e); + numSplits = poolOptions.getNumWorkers(); + } else { + // 2. Default to 12 in the unknown case. + numSplits = 12; + } + } + + // If the desiredBundleSize or number of workers results in 1 split, simply return + // a source that reads from the original query. + if (numSplits <= 1) { + return ImmutableList.of(this); + } + + List datastoreSplits; + try { + datastoreSplits = getSplitQueries(Ints.checkedCast(numSplits), options); + } catch (IllegalArgumentException | DatastoreException e) { + LOG.warn("Unable to parallelize the given query: {}", query, e); + return ImmutableList.of(this); + } + + ImmutableList.Builder splits = ImmutableList.builder(); + for (Query splitQuery : datastoreSplits) { + splits.add(new Source(host, datasetId, splitQuery, namespace)); + } + return splits.build(); + } + + @Override + public BoundedReader createReader(PipelineOptions pipelineOptions) throws IOException { + return new DatastoreReader(this, getDatastore(pipelineOptions)); + } + + @Override + public void validate() { + Preconditions.checkNotNull(host, "host"); + Preconditions.checkNotNull(query, "query"); + Preconditions.checkNotNull(datasetId, "datasetId"); + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + // Datastore provides no way to get a good estimate of how large the result of a query + // will be. 
As a rough approximation, we attempt to fetch the statistics of the whole + // entity kind being queried, using the __Stat_Kind__ system table, assuming exactly 1 kind + // is specified in the query. + // + // See https://cloud.google.com/datastore/docs/concepts/stats + if (mockEstimateSizeBytes != null) { + return mockEstimateSizeBytes; + } + + Datastore datastore = getDatastore(options); + if (query.getKindCount() != 1) { + throw new UnsupportedOperationException( + "Can only estimate size for queries specifying exactly 1 kind."); + } + String ourKind = query.getKind(0).getName(); + long latestTimestamp = queryLatestStatisticsTimestamp(datastore); + Query.Builder query = Query.newBuilder(); + if (namespace == null) { + query.addKindBuilder().setName("__Stat_Kind__"); + } else { + query.addKindBuilder().setName("__Ns_Stat_Kind__"); + } + query.setFilter(makeFilter( + makeFilter("kind_name", EQUAL, makeValue(ourKind)).build(), + makeFilter("timestamp", EQUAL, makeValue(latestTimestamp)).build())); + RunQueryRequest request = makeRequest(query.build()); + + long now = System.currentTimeMillis(); + RunQueryResponse response = datastore.runQuery(request); + LOG.info("Query for per-kind statistics took {}ms", System.currentTimeMillis() - now); + + QueryResultBatch batch = response.getBatch(); + if (batch.getEntityResultCount() == 0) { + throw new NoSuchElementException( + "Datastore statistics for kind " + ourKind + " unavailable"); + } + Entity entity = batch.getEntityResult(0).getEntity(); + return getPropertyMap(entity).get("entity_bytes").getIntegerValue(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("host", host) + .add("dataset", datasetId) + .add("query", query) + .add("namespace", namespace) + .toString(); + } + + /////////////////////////////////////////////////////////////////////////////////////////// + + private static final Logger LOG = LoggerFactory.getLogger(Source.class); + private final String host; + /** Not really nullable, but it may be {@code null} for in-progress {@code Source}s. */ + @Nullable + private final String datasetId; + /** Not really nullable, but it may be {@code null} for in-progress {@code Source}s. */ + @Nullable + private final Query query; + @Nullable + private final String namespace; + + /** For testing only. TODO: This could be much cleaner with dependency injection. */ + @Nullable + private QuerySplitter mockSplitter; + @Nullable + private Long mockEstimateSizeBytes; + + /** + * Note that only {@code namespace} is really {@code @Nullable}. The other parameters may be + * {@code null} as a matter of build order, but if they are {@code null} at instantiation time, + * an error will be thrown. + */ + private Source( + String host, @Nullable String datasetId, @Nullable Query query, + @Nullable String namespace) { + this.host = checkNotNull(host, "host"); + this.datasetId = datasetId; + this.query = query; + this.namespace = namespace; + } + + /** + * A helper function to get the split queries, taking into account the optional + * {@code namespace} and whether there is a mock splitter. + */ + private List getSplitQueries(int numSplits, PipelineOptions options) + throws DatastoreException { + // If namespace is set, include it in the split request so splits are calculated accordingly. + PartitionId.Builder partitionBuilder = PartitionId.newBuilder(); + if (namespace != null) { + partitionBuilder.setNamespace(namespace); + } + + if (mockSplitter != null) { + // For testing. 
+ return mockSplitter.getSplits(query, partitionBuilder.build(), numSplits, null); + } + + return DatastoreHelper.getQuerySplitter().getSplits( + query, partitionBuilder.build(), numSplits, getDatastore(options)); + } + + /** + * Builds a {@link RunQueryRequest} from the {@code query}, using the properties set on this + * {@code Source}. For example, sets the {@code namespace} for the request. + */ + private RunQueryRequest makeRequest(Query query) { + RunQueryRequest.Builder requestBuilder = RunQueryRequest.newBuilder().setQuery(query); + if (namespace != null) { + requestBuilder.getPartitionIdBuilder().setNamespace(namespace); + } + return requestBuilder.build(); + } + + /** + * Datastore system tables with statistics are periodically updated. This method fetches + * the latest timestamp of statistics update using the {@code __Stat_Total__} table. + */ + private long queryLatestStatisticsTimestamp(Datastore datastore) throws DatastoreException { + Query.Builder query = Query.newBuilder(); + query.addKindBuilder().setName("__Stat_Total__"); + query.addOrder(makeOrder("timestamp", DESCENDING)); + query.setLimit(1); + RunQueryRequest request = makeRequest(query.build()); + + long now = System.currentTimeMillis(); + RunQueryResponse response = datastore.runQuery(request); + LOG.info("Query for latest stats timestamp of dataset {} took {}ms", datasetId, + System.currentTimeMillis() - now); + QueryResultBatch batch = response.getBatch(); + if (batch.getEntityResultCount() == 0) { + throw new NoSuchElementException( + "Datastore total statistics for dataset " + datasetId + " unavailable"); + } + Entity entity = batch.getEntityResult(0).getEntity(); + return getPropertyMap(entity).get("timestamp").getTimestampMicrosecondsValue(); + } + + private Datastore getDatastore(PipelineOptions pipelineOptions) { + DatastoreOptions.Builder builder = + new DatastoreOptions.Builder().host(host).dataset(datasetId).initializer( + new RetryHttpRequestInitializer()); + + Credential credential = pipelineOptions.as(GcpOptions.class).getGcpCredential(); + if (credential != null) { + builder.credential(credential); + } + return DatastoreFactory.get().create(builder.build()); + } + + /** For testing only. */ + Source withMockSplitter(QuerySplitter splitter) { + Source res = new Source(host, datasetId, query, namespace); + res.mockSplitter = splitter; + res.mockEstimateSizeBytes = mockEstimateSizeBytes; + return res; + } + + /** For testing only. */ + Source withMockEstimateSizeBytes(Long estimateSizeBytes) { + Source res = new Source(host, datasetId, query, namespace); + res.mockSplitter = mockSplitter; + res.mockEstimateSizeBytes = estimateSizeBytes; + return res; + } + } + + ///////////////////// Write Class ///////////////////////////////// + + /** + * Returns a new {@link DatastoreIO.Sink} builder using the default host. + * You need to further configure it using {@link DatastoreIO.Sink#withDataset}, and optionally + * {@link DatastoreIO.Sink#withHost} before using it in a {@link Write} transform. + * + *

    For example: {@code p.apply(Write.to(DatastoreIO.sink().withDataset(dataset)));} + */ + public static Sink sink() { + return new Sink(DEFAULT_HOST, null); + } + + /** + * Returns a new {@link Write} transform that will write to a {@link Sink}. + * + *

    For example: {@code p.apply(DatastoreIO.writeTo(dataset));} + */ + public static Write.Bound writeTo(String datasetId) { + return Write.to(sink().withDataset(datasetId)); + } + + /** + * A {@link Sink} that writes a {@link PCollection} containing + * {@link Entity Entities} to a Datastore kind. + * + */ + public static class Sink extends com.google.cloud.dataflow.sdk.io.Sink { + final String host; + final String datasetId; + + /** + * Returns a {@link Sink} that is like this one, but will write to the specified dataset. + */ + public Sink withDataset(String datasetId) { + checkNotNull(datasetId, "datasetId"); + return new Sink(host, datasetId); + } + + /** + * Returns a {@link Sink} that is like this one, but will use the given host. If not specified, + * the {@link DatastoreIO#DEFAULT_HOST default host} will be used. + */ + public Sink withHost(String host) { + checkNotNull(host, "host"); + return new Sink(host, datasetId); + } + + /** + * Constructs a Sink with given host and dataset. + */ + protected Sink(String host, String datasetId) { + this.host = checkNotNull(host, "host"); + this.datasetId = datasetId; + } + + /** + * Ensures the host and dataset are set. + */ + @Override + public void validate(PipelineOptions options) { + Preconditions.checkNotNull( + host, "Host is a required parameter. Please use withHost to set the host."); + Preconditions.checkNotNull( + datasetId, + "Dataset ID is a required parameter. Please use withDataset to to set the datasetId."); + } + + @Override + public DatastoreWriteOperation createWriteOperation(PipelineOptions options) { + return new DatastoreWriteOperation(this); + } + } + + /** + * A {@link WriteOperation} that will manage a parallel write to a Datastore sink. + */ + private static class DatastoreWriteOperation + extends WriteOperation { + private static final Logger LOG = LoggerFactory.getLogger(DatastoreWriteOperation.class); + + private final DatastoreIO.Sink sink; + + public DatastoreWriteOperation(DatastoreIO.Sink sink) { + this.sink = sink; + } + + @Override + public Coder getWriterResultCoder() { + return SerializableCoder.of(DatastoreWriteResult.class); + } + + @Override + public void initialize(PipelineOptions options) throws Exception {} + + /** + * Finalizes the write. Logs the number of entities written to the Datastore. + */ + @Override + public void finalize(Iterable writerResults, PipelineOptions options) + throws Exception { + long totalEntities = 0; + for (DatastoreWriteResult result : writerResults) { + totalEntities += result.entitiesWritten; + } + LOG.info("Wrote {} elements.", totalEntities); + } + + @Override + public DatastoreWriter createWriter(PipelineOptions options) throws Exception { + DatastoreOptions.Builder builder = + new DatastoreOptions.Builder() + .host(sink.host) + .dataset(sink.datasetId) + .initializer(new RetryHttpRequestInitializer()); + Credential credential = options.as(GcpOptions.class).getGcpCredential(); + if (credential != null) { + builder.credential(credential); + } + Datastore datastore = DatastoreFactory.get().create(builder.build()); + + return new DatastoreWriter(this, datastore); + } + + @Override + public DatastoreIO.Sink getSink() { + return sink; + } + } + + /** + * {@link Writer} that writes entities to a Datastore Sink. Entities are written in batches, + * where the maximum batch size is {@link DatastoreIO#DATASTORE_BATCH_UPDATE_LIMIT}. Entities + * are committed as upsert mutations (either update if the key already exists, or insert if it is + * a new key). 
If an entity does not have a complete key (i.e., it has no name or id), the bundle + * will fail. + * + *

    See + * Datastore: Entities, Properties, and Keys for information about entity keys and upsert + * mutations. + * + *

    Commits are non-transactional. If a commit fails because of a conflict over an entity + * group, the commit will be retried (up to {@link DatastoreIO#DATASTORE_BATCH_UPDATE_LIMIT} + * times). + * + *

    Visible for testing purposes. + */ + static class DatastoreWriter extends Writer { + private static final Logger LOG = LoggerFactory.getLogger(DatastoreWriter.class); + private final DatastoreWriteOperation writeOp; + private final Datastore datastore; + private long totalWritten = 0; + + // Visible for testing. + final List entities = new ArrayList<>(); + + /** + * Since a bundle is written in batches, we should retry the commit of a batch in order to + * prevent transient errors from causing the bundle to fail. + */ + private static final int MAX_RETRIES = 5; + + /** + * Initial backoff time for exponential backoff for retry attempts. + */ + private static final int INITIAL_BACKOFF_MILLIS = 5000; + + /** + * Returns true if a Datastore key is complete. A key is complete if its last element + * has either an id or a name. + */ + static boolean isValidKey(Key key) { + List elementList = key.getPathElementList(); + if (elementList.isEmpty()) { + return false; + } + PathElement lastElement = elementList.get(elementList.size() - 1); + return (lastElement.hasId() || lastElement.hasName()); + } + + // Visible for testing + DatastoreWriter(DatastoreWriteOperation writeOp, Datastore datastore) { + this.writeOp = writeOp; + this.datastore = datastore; + } + + @Override + public void open(String uId) throws Exception {} + + /** + * Writes an entity to the Datastore. Writes are batched, up to {@link + * DatastoreIO#DATASTORE_BATCH_UPDATE_LIMIT}. If an entity does not have a complete key, an + * {@link IllegalArgumentException} will be thrown. + */ + @Override + public void write(Entity value) throws Exception { + // Verify that the entity to write has a complete key. + if (!isValidKey(value.getKey())) { + throw new IllegalArgumentException( + "Entities to be written to the Datastore must have complete keys"); + } + + entities.add(value); + + if (entities.size() >= DatastoreIO.DATASTORE_BATCH_UPDATE_LIMIT) { + flushBatch(); + } + } + + /** + * Flushes any pending batch writes and returns a DatastoreWriteResult. + */ + @Override + public DatastoreWriteResult close() throws Exception { + if (entities.size() > 0) { + flushBatch(); + } + return new DatastoreWriteResult(totalWritten); + } + + @Override + public DatastoreWriteOperation getWriteOperation() { + return writeOp; + } + + /** + * Writes a batch of entities to the Datastore. + * + *

    If a commit fails, it will be retried (up to {@link DatastoreWriter#MAX_RETRIES} + * times). All entities in the batch will be committed again, even if the commit was partially + * successful. If the retry limit is exceeded, the last exception from the Datastore will be + * thrown. + * + * @throws DatastoreException if the commit fails or IOException or InterruptedException if + * backing off between retries fails. + */ + private void flushBatch() throws DatastoreException, IOException, InterruptedException { + LOG.debug("Writing batch of {} entities", entities.size()); + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backoff = new AttemptBoundedExponentialBackOff(MAX_RETRIES, INITIAL_BACKOFF_MILLIS); + + while (true) { + // Batch upsert entities. + try { + CommitRequest.Builder commitRequest = CommitRequest.newBuilder(); + commitRequest.getMutationBuilder().addAllUpsert(entities); + commitRequest.setMode(CommitRequest.Mode.NON_TRANSACTIONAL); + datastore.commit(commitRequest.build()); + + // Break if the commit threw no exception. + break; + + } catch (DatastoreException exception) { + // Only log the code and message for potentially-transient errors. The entire exception + // will be propagated upon the last retry. + LOG.error("Error writing to the Datastore ({}): {}", exception.getCode(), + exception.getMessage()); + if (!BackOffUtils.next(sleeper, backoff)) { + LOG.error("Aborting after {} retries.", MAX_RETRIES); + throw exception; + } + } + } + totalWritten += entities.size(); + LOG.debug("Successfully wrote {} entities", entities.size()); + entities.clear(); + } + } + + private static class DatastoreWriteResult implements Serializable { + final long entitiesWritten; + + public DatastoreWriteResult(long recordsWritten) { + this.entitiesWritten = recordsWritten; + } + } + + /** + * A {@link Source.Reader} over the records from a query of the datastore. + * + *

    Timestamped records are currently not supported. + * All records implicitly have the timestamp of {@code BoundedWindow.TIMESTAMP_MIN_VALUE}. + */ + public static class DatastoreReader extends BoundedSource.BoundedReader { + private final Source source; + + /** + * Datastore to read from. + */ + private final Datastore datastore; + + /** + * True if more results may be available. + */ + private boolean moreResults; + + /** + * Iterator over records. + */ + private java.util.Iterator entities; + + /** + * Current batch of query results. + */ + private QueryResultBatch currentBatch; + + /** + * Maximum number of results to request per query. + * + *

    Must be set, or it may result in an I/O error when querying + * Cloud Datastore. + */ + private static final int QUERY_BATCH_LIMIT = 500; + + /** + * Remaining user-requested limit on the number of sources to return. If the user did not set a + * limit, then this variable will always have the value {@link Integer#MAX_VALUE} and will never + * be decremented. + */ + private int userLimit; + + private Entity currentEntity; + + /** + * Returns a DatastoreReader with Source and Datastore object set. + * + * @param datastore a datastore connection to use. + */ + public DatastoreReader(Source source, Datastore datastore) { + this.source = source; + this.datastore = datastore; + // If the user set a limit on the query, remember it. Otherwise pin to MAX_VALUE. + userLimit = source.query.hasLimit() ? source.query.getLimit() : Integer.MAX_VALUE; + } + + @Override + public Entity getCurrent() { + return currentEntity; + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + if (entities == null || (!entities.hasNext() && moreResults)) { + try { + entities = getIteratorAndMoveCursor(); + } catch (DatastoreException e) { + throw new IOException(e); + } + } + + if (entities == null || !entities.hasNext()) { + currentEntity = null; + return false; + } + + currentEntity = entities.next().getEntity(); + return true; + } + + @Override + public void close() throws IOException { + // Nothing + } + + @Override + public DatastoreIO.Source getCurrentSource() { + return source; + } + + @Override + public DatastoreIO.Source splitAtFraction(double fraction) { + // Not supported. + return null; + } + + @Override + public Double getFractionConsumed() { + // Not supported. + return null; + } + + /** + * Returns an iterator over the next batch of records for the query + * and updates the cursor to get the next batch as needed. + * Query has specified limit and offset from InputSplit. + */ + private Iterator getIteratorAndMoveCursor() throws DatastoreException { + Query.Builder query = source.query.toBuilder().clone(); + query.setLimit(Math.min(userLimit, QUERY_BATCH_LIMIT)); + if (currentBatch != null && currentBatch.hasEndCursor()) { + query.setStartCursor(currentBatch.getEndCursor()); + } + + RunQueryRequest request = source.makeRequest(query.build()); + RunQueryResponse response = datastore.runQuery(request); + + currentBatch = response.getBatch(); + + // MORE_RESULTS_AFTER_LIMIT is not implemented yet: + // https://groups.google.com/forum/#!topic/gcd-discuss/iNs6M1jA2Vw, so + // use result count to determine if more results might exist. + int numFetch = currentBatch.getEntityResultCount(); + if (source.query.hasLimit()) { + verify(userLimit >= numFetch, + "Expected userLimit %s >= numFetch %s, because query limit %s should be <= userLimit", + userLimit, numFetch, query.getLimit()); + userLimit -= numFetch; + } + moreResults = + // User-limit does not exist (so userLimit == MAX_VALUE) and/or has not been satisfied. + (userLimit > 0) + // All indications from the API are that there are/may be more results. + && ((numFetch == QUERY_BATCH_LIMIT) || (currentBatch.getMoreResults() == NOT_FINISHED)); + + // May receive a batch of 0 results if the number of records is a multiple + // of the request limit. 
+ if (numFetch == 0) { + return null; + } + + return currentBatch.getEntityResultList().iterator(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java new file mode 100644 index 000000000000..dda500c36906 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java @@ -0,0 +1,864 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.api.client.googleapis.batch.BatchRequest; +import com.google.api.client.googleapis.batch.json.JsonBatchCallback; +import com.google.api.client.googleapis.json.GoogleJsonError; +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.services.storage.Storage; +import com.google.api.services.storage.StorageRequest; +import com.google.api.services.storage.model.StorageObject; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.FileIOChannelFactory; +import com.google.cloud.dataflow.sdk.util.GcsIOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.hadoop.util.ApiErrorExtractor; +import com.google.common.base.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Serializable; +import java.nio.channels.WritableByteChannel; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Abstract {@link Sink} for file-based output. 
An implementation of FileBasedSink writes file-based + * output and defines the format of output files (how values are written, headers/footers, MIME + * type, etc.). + * + *

    At pipeline construction time, the methods of FileBasedSink are called to validate the sink + * and to create a {@link Sink.WriteOperation} that manages the process of writing to the sink. + * + *

The process of writing to a file-based sink is as follows (a minimal end-to-end sketch appears after this list): + *

      + *
1. an optional, subclass-defined initialization, + *
2. a parallel write of bundles to temporary files, and finally, + *
3. a rename of these temporary files to their final output filenames. + *
    + * + *
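For illustration only (this sketch is not part of the change; the class names, the line-oriented format, and the generic signatures such as FileBasedSink&lt;String&gt; are assumed rather than taken from this file), a minimal sink wired through these three steps might look like:

import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;

import com.google.cloud.dataflow.sdk.io.FileBasedSink;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.common.base.Preconditions;

/** Hypothetical sink that writes each String element on its own line. */
class LineSink extends FileBasedSink<String> {

  public LineSink(String baseOutputFilename) {
    super(baseOutputFilename, "txt");
  }

  @Override
  public void validate(PipelineOptions options) {
    // Optional construction-time check (the initialization step is subclass-defined).
    Preconditions.checkArgument(!getBaseOutputFilename().isEmpty(), "empty output filename");
  }

  @Override
  public FileBasedWriteOperation<String> createWriteOperation(PipelineOptions options) {
    return new LineWriteOperation(this);
  }

  /** Manages the parallel write of bundles and the final rename (steps 2 and 3). */
  private static class LineWriteOperation extends FileBasedWriteOperation<String> {
    LineWriteOperation(LineSink sink) {
      super(sink);
    }

    @Override
    public FileBasedWriter<String> createWriter(PipelineOptions options) {
      return new LineWriter(this);
    }
  }

  /** Writes one bundle to a temporary file; the base class handles opening and closing. */
  private static class LineWriter extends FileBasedWriter<String> {
    private OutputStream out;

    LineWriter(LineWriteOperation writeOperation) {
      super(writeOperation);
    }

    @Override
    protected void prepareWrite(WritableByteChannel channel) {
      out = Channels.newOutputStream(channel);
    }

    @Override
    public void write(String value) throws Exception {
      out.write((value + "\n").getBytes(StandardCharsets.UTF_8));
    }
  }
}

In a pipeline, such a sink would typically be attached with the SDK's Write transform (for example, p.apply(Write.to(new LineSink("gs://my-bucket/out")))); that transform is assumed here and is outside the scope of this file.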

    Supported file systems are those registered with {@link IOChannelUtils}. + * + * @param the type of values written to the sink. + */ +public abstract class FileBasedSink extends Sink { + /** + * Base filename for final output files. + */ + protected final String baseOutputFilename; + + /** + * The extension to be used for the final output files. + */ + protected final String extension; + + /** + * Naming template for output files. See {@link ShardNameTemplate} for a description of + * possible naming templates. Default is {@link ShardNameTemplate#INDEX_OF_MAX}. + */ + protected final String fileNamingTemplate; + + /** + * Construct a FileBasedSink with the given base output filename and extension. + */ + public FileBasedSink(String baseOutputFilename, String extension) { + this(baseOutputFilename, extension, ShardNameTemplate.INDEX_OF_MAX); + } + + /** + * Construct a FileBasedSink with the given base output filename, extension, and file naming + * template. + * + *

    See {@link ShardNameTemplate} for a description of file naming templates. + */ + public FileBasedSink(String baseOutputFilename, String extension, String fileNamingTemplate) { + this.baseOutputFilename = baseOutputFilename; + this.extension = extension; + this.fileNamingTemplate = fileNamingTemplate; + } + + /** + * Returns the base output filename for this file based sink. + */ + public String getBaseOutputFilename() { + return baseOutputFilename; + } + + /** + * Perform pipeline-construction-time validation. The default implementation is a no-op. + * Subclasses should override to ensure the sink is valid and can be written to. It is recommended + * to use {@link Preconditions} in the implementation of this method. + */ + @Override + public void validate(PipelineOptions options) {} + + /** + * Return a subclass of {@link FileBasedSink.FileBasedWriteOperation} that will manage the write + * to the sink. + */ + @Override + public abstract FileBasedWriteOperation createWriteOperation(PipelineOptions options); + + /** + * Abstract {@link Sink.WriteOperation} that manages the process of writing to a + * {@link FileBasedSink}. + * + *

The primary responsibility of the FileBasedWriteOperation is the management of output + * files. During a write, {@link FileBasedSink.FileBasedWriter}s write bundles to temporary file + * locations. After the bundles have been written, + *

      + *
    1. {@link FileBasedSink.FileBasedWriteOperation#finalize} is given a list of the temporary + * files containing the output bundles. + *
    2. During finalize, these temporary files are copied to final output locations and named + * according to a file naming template. + *
    3. Finally, any temporary files that were created during the write are removed. + *
    + * + *

Subclass implementations of FileBasedWriteOperation must implement + * {@link FileBasedSink.FileBasedWriteOperation#createWriter} to return a concrete + * FileBasedWriter. + * + *

    Temporary and Output File Naming:

    During the write, bundles are written to temporary + * files using the baseTemporaryFilename that can be provided via the constructor of + * FileBasedWriteOperation. These temporary files will be named + * {@code {baseTemporaryFilename}-temp-{bundleId}}, where bundleId is the unique id of the bundle. + * For example, if baseTemporaryFilename is "gs://my-bucket/my_temp_output", the output for a + * bundle with bundle id 15723 will be "gs://my-bucket/my_temp_output-temp-15723". + * + *

    Final output files are written to baseOutputFilename with the format + * {@code {baseOutputFilename}-0000i-of-0000n.{extension}} where n is the total number of bundles + * written and extension is the file extension. Both baseOutputFilename and extension are required + * constructor arguments. + * + *
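As a concrete illustration (values are examples only, the exact zero padding comes from the template, and this mirrors generateDestinationFilenames below), the default template yields names like these:

import com.google.cloud.dataflow.sdk.io.ShardNameTemplate;
import com.google.cloud.dataflow.sdk.util.IOChannelUtils;

class ShardNamingExample {
  public static void main(String[] args) {
    String base = "gs://my-bucket/out";               // example baseOutputFilename
    String template = ShardNameTemplate.INDEX_OF_MAX; // the default naming template
    for (int i = 0; i < 3; i++) {
      // Expected to print gs://my-bucket/out-00000-of-00003.txt, ...-00001-of-00003.txt, etc.
      System.out.println(IOChannelUtils.constructName(base, template, ".txt", i, 3));
    }
  }
}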

    Subclass implementations can change the file naming template by supplying a value for + * {@link FileBasedSink#fileNamingTemplate}. + * + *

    Temporary Bundle File Handling:

    + *

    {@link FileBasedSink.FileBasedWriteOperation#temporaryFileRetention} controls the behavior + * for managing temporary files. By default, temporary files will be removed. Subclasses can + * provide a different value to the constructor. + * + *

Note that in the case of permanent failure of a bundle's write, no cleanup of temporary + * files will occur. + * + *

    If there are no elements in the PCollection being written, no output will be generated. + * + * @param the type of values written to the sink. + */ + public abstract static class FileBasedWriteOperation extends WriteOperation { + private static final Logger LOG = LoggerFactory.getLogger(FileBasedWriteOperation.class); + + /** + * Options for handling of temporary output files. + */ + public enum TemporaryFileRetention { + KEEP, + REMOVE; + } + + /** + * The Sink that this WriteOperation will write to. + */ + protected final FileBasedSink sink; + + /** + * Option to keep or remove temporary output files. + */ + protected final TemporaryFileRetention temporaryFileRetention; + + /** + * Base filename used for temporary output files. Default is the baseOutputFilename. + */ + protected final String baseTemporaryFilename; + + /** + * Name separator for temporary files. Temporary files will be named + * {@code {baseTemporaryFilename}-temp-{bundleId}}. + */ + protected static final String TEMPORARY_FILENAME_SEPARATOR = "-temp-"; + + /** + * Build a temporary filename using the temporary filename separator with the given prefix and + * suffix. + */ + protected static final String buildTemporaryFilename(String prefix, String suffix) { + return prefix + FileBasedWriteOperation.TEMPORARY_FILENAME_SEPARATOR + suffix; + } + + /** + * Construct a FileBasedWriteOperation using the same base filename for both temporary and + * output files. + * + * @param sink the FileBasedSink that will be used to configure this write operation. + */ + public FileBasedWriteOperation(FileBasedSink sink) { + this(sink, sink.baseOutputFilename); + } + + /** + * Construct a FileBasedWriteOperation. + * + * @param sink the FileBasedSink that will be used to configure this write operation. + * @param baseTemporaryFilename the base filename to be used for temporary output files. + */ + public FileBasedWriteOperation(FileBasedSink sink, String baseTemporaryFilename) { + this(sink, baseTemporaryFilename, TemporaryFileRetention.REMOVE); + } + + /** + * Create a new FileBasedWriteOperation. + * + * @param sink the FileBasedSink that will be used to configure this write operation. + * @param baseTemporaryFilename the base filename to be used for temporary output files. + * @param temporaryFileRetention defines how temporary files are handled. + */ + public FileBasedWriteOperation(FileBasedSink sink, String baseTemporaryFilename, + TemporaryFileRetention temporaryFileRetention) { + this.sink = sink; + this.baseTemporaryFilename = baseTemporaryFilename; + this.temporaryFileRetention = temporaryFileRetention; + } + + /** + * Clients must implement to return a subclass of {@link FileBasedSink.FileBasedWriter}. This + * method must satisfy the restrictions placed on implementations of + * {@link Sink.WriteOperation#createWriter}. Namely, it must not mutate the state of the object. + */ + @Override + public abstract FileBasedWriter createWriter(PipelineOptions options) throws Exception; + + /** + * Initialization of the sink. Default implementation is a no-op. May be overridden by subclass + * implementations to perform initialization of the sink at pipeline runtime. This method must + * be idempotent and is subject to the same implementation restrictions as + * {@link Sink.WriteOperation#initialize}. + */ + @Override + public void initialize(PipelineOptions options) throws Exception {} + + /** + * Finalizes writing by copying temporary output files to their final location and optionally + * removing temporary files. + * + *

    Finalization may be overridden by subclass implementations to perform customized + * finalization (e.g., initiating some operation on output bundles, merging them, etc.). + * {@code writerResults} contains the filenames of written bundles. + * + *
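For example (a sketch only; the class, the logging, the hypothetical CsvWriter referenced below, and the generic signatures are illustrative and not part of this change), an overriding finalize typically still performs the copy and cleanup steps so that it stays idempotent:

import java.util.ArrayList;
import java.util.List;

import com.google.cloud.dataflow.sdk.io.FileBasedSink;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Hypothetical write operation whose finalize also logs the number of final files. */
class LoggingWriteOperation extends FileBasedSink.FileBasedWriteOperation<String> {
  private static final Logger LOG = LoggerFactory.getLogger(LoggingWriteOperation.class);

  LoggingWriteOperation(FileBasedSink<String> sink) {
    super(sink);
  }

  @Override
  public FileBasedSink.FileBasedWriter<String> createWriter(PipelineOptions options) {
    return new CsvWriter(this);  // hypothetical writer, sketched later in this file's documentation
  }

  @Override
  public void finalize(Iterable<FileBasedSink.FileResult> writerResults, PipelineOptions options)
      throws Exception {
    List<String> tempFiles = new ArrayList<>();
    for (FileBasedSink.FileResult result : writerResults) {
      tempFiles.add(result.getFilename());
    }
    // Both helpers below behave as copy-if-existing / remove-if-existing, so repeating this
    // method after a failure is safe.
    List<String> finalFiles = copyToOutputFiles(tempFiles, options);
    removeTemporaryFiles(options);
    LOG.info("Finalized {} output file(s).", finalFiles.size());
  }
}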

If subclasses override this method, they must guarantee that its implementation is + * idempotent, as it may be executed multiple times in the case of failure or for redundancy. It + * is a best practice to make this method atomic. + * + * @param writerResults the results of writes (FileResult). + */ + @Override + public void finalize(Iterable writerResults, PipelineOptions options) + throws Exception { + // Collect names of temporary files and rename them. + List files = new ArrayList<>(); + for (FileResult result : writerResults) { + LOG.debug("Temporary bundle output file {} will be copied.", result.getFilename()); + files.add(result.getFilename()); + } + copyToOutputFiles(files, options); + + // Optionally remove temporary files. + if (temporaryFileRetention == TemporaryFileRetention.REMOVE) { + removeTemporaryFiles(options); + } + } + + /** + * Copy temporary files to final output filenames using the file naming template. + * + *

    Can be called from subclasses that override {@link FileBasedWriteOperation#finalize}. + * + *

    Files will be named according to the file naming template. The order of the output files + * will be the same as the sorted order of the input filenames. In other words, if the input + * filenames are ["C", "A", "B"], baseOutputFilename is "file", the extension is ".txt", and + * the fileNamingTemplate is "-SSS-of-NNN", the contents of A will be copied to + * file-000-of-003.txt, the contents of B will be copied to file-001-of-003.txt, etc. + * + * @param filenames the filenames of temporary files. + * @return a list containing the names of final output files. + */ + protected final List copyToOutputFiles(List filenames, PipelineOptions options) + throws IOException { + int numFiles = filenames.size(); + List srcFilenames = new ArrayList<>(); + List destFilenames = generateDestinationFilenames(numFiles); + + // Sort files for copying. + srcFilenames.addAll(filenames); + Collections.sort(srcFilenames); + + if (numFiles > 0) { + LOG.debug("Copying {} files.", numFiles); + FileOperations fileOperations = + FileOperationsFactory.getFileOperations(destFilenames.get(0), options); + fileOperations.copy(srcFilenames, destFilenames); + } else { + LOG.info("No output files to write."); + } + + return destFilenames; + } + + /** + * Generate output bundle filenames. + */ + protected final List generateDestinationFilenames(int numFiles) { + List destFilenames = new ArrayList<>(); + String extension = getSink().extension; + String baseOutputFilename = getSink().baseOutputFilename; + String fileNamingTemplate = getSink().fileNamingTemplate; + + String suffix = getFileExtension(extension); + for (int i = 0; i < numFiles; i++) { + destFilenames.add(IOChannelUtils.constructName( + baseOutputFilename, fileNamingTemplate, suffix, i, numFiles)); + } + return destFilenames; + } + + /** + * Returns the file extension to be used. If the user did not request a file + * extension then this method returns the empty string. Otherwise this method + * adds a {@code "."} to the beginning of the users extension if one is not present. + */ + private String getFileExtension(String usersExtension) { + if (usersExtension == null || usersExtension.isEmpty()) { + return ""; + } + if (usersExtension.startsWith(".")) { + return usersExtension; + } + return "." + usersExtension; + } + + /** + * Removes temporary output files. Uses the temporary filename to find files to remove. + * + *

    Can be called from subclasses that override {@link FileBasedWriteOperation#finalize}. + * Note:If finalize is overridden and does not rename or otherwise finalize + * temporary files, this method will remove them. + */ + protected final void removeTemporaryFiles(PipelineOptions options) throws IOException { + String pattern = buildTemporaryFilename(baseTemporaryFilename, "*"); + LOG.debug("Finding temporary bundle output files matching {}.", pattern); + FileOperations fileOperations = FileOperationsFactory.getFileOperations(pattern, options); + IOChannelFactory factory = IOChannelUtils.getFactory(pattern); + Collection matches = factory.match(pattern); + LOG.debug("{} temporary files matched {}", matches.size(), pattern); + LOG.debug("Removing {} files.", matches.size()); + fileOperations.remove(matches); + } + + /** + * Provides a coder for {@link FileBasedSink.FileResult}. + */ + @Override + public Coder getWriterResultCoder() { + return SerializableCoder.of(FileResult.class); + } + + /** + * Returns the FileBasedSink for this write operation. + */ + @Override + public FileBasedSink getSink() { + return sink; + } + } + + /** + * Abstract {@link Sink.Writer} that writes a bundle to a {@link FileBasedSink}. Subclass + * implementations provide a method that can write a single value to a {@link WritableByteChannel} + * ({@link Sink.Writer#write}). + * + *

    Subclass implementations may also override methods that write headers and footers before and + * after the values in a bundle, respectively, as well as provide a MIME type for the output + * channel. + * + *
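For example, a writer that frames each bundle's values with a header and a footer might look like the following sketch (the CSV framing, the class name, and the write signature are assumed for illustration):

import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;

import com.google.cloud.dataflow.sdk.io.FileBasedSink;

/** Hypothetical writer that wraps each bundle's values with a CSV header and a footer line. */
class CsvWriter extends FileBasedSink.FileBasedWriter<String> {
  private OutputStream out;

  CsvWriter(FileBasedSink.FileBasedWriteOperation<String> writeOperation) {
    super(writeOperation);
  }

  @Override
  protected void prepareWrite(WritableByteChannel channel) {
    out = Channels.newOutputStream(channel);
  }

  @Override
  protected void writeHeader() throws Exception {
    writeLine("id,name");          // written once, before the bundle's values
  }

  @Override
  public void write(String row) throws Exception {
    writeLine(row);
  }

  @Override
  protected void writeFooter() throws Exception {
    writeLine("# end of shard");   // written once, after the bundle's values
  }

  private void writeLine(String line) throws Exception {
    out.write((line + "\n").getBytes(StandardCharsets.UTF_8));
  }
}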

    Multiple FileBasedWriter instances may be created on the same worker, and therefore any + * access to static members or methods should be thread safe. + * + * @param the type of values to write. + */ + public abstract static class FileBasedWriter extends Writer { + private static final Logger LOG = LoggerFactory.getLogger(FileBasedWriter.class); + + final FileBasedWriteOperation writeOperation; + + /** + * Unique id for this output bundle. + */ + private String id; + + /** + * The filename of the output bundle. Equal to the + * {@link FileBasedSink.FileBasedWriteOperation#TEMPORARY_FILENAME_SEPARATOR} and id appended to + * the baseName. + */ + private String filename; + + /** + * The channel to write to. + */ + private WritableByteChannel channel; + + /** + * The MIME type used in the creation of the output channel (if the file system supports it). + * + *

    GCS, for example, supports writing files with Content-Type metadata. + * + *
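For instance (building on the hypothetical CsvWriter sketched above; the MIME string is illustrative and generic signatures are assumed), a subclass can set the field in its constructor:

import com.google.cloud.dataflow.sdk.io.FileBasedSink;

/** Hypothetical writer that asks for its output channel to be created with a CSV content type. */
class TypedCsvWriter extends CsvWriter {
  TypedCsvWriter(FileBasedSink.FileBasedWriteOperation<String> writeOperation) {
    super(writeOperation);
    // Consulted when the channel is opened; object stores such as GCS may record it as metadata.
    this.mimeType = "text/csv";
  }
}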

    May be overridden. Default is {@link MimeTypes#TEXT}. See {@link MimeTypes} for other + * options. + */ + protected String mimeType = MimeTypes.TEXT; + + /** + * Construct a new FileBasedWriter with a base filename. + */ + public FileBasedWriter(FileBasedWriteOperation writeOperation) { + Preconditions.checkNotNull(writeOperation); + this.writeOperation = writeOperation; + } + + /** + * Called with the channel that a subclass will write its header, footer, and values to. + * Subclasses should either keep a reference to the channel provided or create and keep a + * reference to an appropriate object that they will use to write to it. + * + *

    Called before any subsequent calls to writeHeader, writeFooter, and write. + */ + protected abstract void prepareWrite(WritableByteChannel channel) throws Exception; + + /** + * Writes header at the beginning of output files. Nothing by default; subclasses may override. + */ + protected void writeHeader() throws Exception {} + + /** + * Writes footer at the end of output files. Nothing by default; subclasses may override. + */ + protected void writeFooter() throws Exception {} + + /** + * Opens the channel. + */ + @Override + public final void open(String uId) throws Exception { + this.id = uId; + filename = FileBasedWriteOperation.buildTemporaryFilename( + getWriteOperation().baseTemporaryFilename, uId); + LOG.debug("Opening {}.", filename); + channel = IOChannelUtils.create(filename, mimeType); + try { + prepareWrite(channel); + LOG.debug("Writing header to {}.", filename); + writeHeader(); + } catch (Exception e) { + // The caller shouldn't have to close() this Writer if it fails to open(), so close the + // channel if prepareWrite() or writeHeader() fails. + try { + LOG.error("Writing header to {} failed, closing channel.", filename); + channel.close(); + } catch (IOException closeException) { + // Log exception and mask it. + LOG.error("Closing channel for {} failed: {}", filename, closeException.getMessage()); + } + // Throw the exception that caused the write to fail. + throw e; + } + LOG.debug("Starting write of bundle {} to {}.", this.id, filename); + } + + /** + * Closes the channel and return the bundle result. + */ + @Override + public final FileResult close() throws Exception { + try (WritableByteChannel theChannel = channel) { + LOG.debug("Writing footer to {}.", filename); + writeFooter(); + } + FileResult result = new FileResult(filename); + LOG.debug("Result for bundle {}: {}", this.id, filename); + return result; + } + + /** + * Return the FileBasedWriteOperation that this Writer belongs to. + */ + @Override + public FileBasedWriteOperation getWriteOperation() { + return writeOperation; + } + } + + /** + * Result of a single bundle write. Contains the filename of the bundle. + */ + public static final class FileResult implements Serializable { + private final String filename; + + public FileResult(String filename) { + this.filename = filename; + } + + public String getFilename() { + return filename; + } + } + + // File system operations + // Warning: These class are purposefully private and will be replaced by more robust file I/O + // utilities. Not for use outside FileBasedSink. + + /** + * Factory for FileOperations. + */ + private static class FileOperationsFactory { + /** + * Return a FileOperations implementation based on which IOChannel would be used to write to a + * location specification (not necessarily a filename, as it may contain wildcards). + * + *

    Only supports File and GCS locations (currently, the only factories registered with + * IOChannelUtils). For other locations, an exception is thrown. + */ + public static FileOperations getFileOperations(String spec, PipelineOptions options) + throws IOException { + IOChannelFactory factory = IOChannelUtils.getFactory(spec); + if (factory instanceof GcsIOChannelFactory) { + return new GcsOperations(options); + } else if (factory instanceof FileIOChannelFactory) { + return new LocalFileOperations(); + } else { + throw new IOException("Unrecognized file system."); + } + } + } + + /** + * Copy and Remove operations for files. Operations behave like remove-if-existing and + * copy-if-existing and do not throw exceptions on file not found to enable retries of these + * operations in the case of transient error. + */ + private static interface FileOperations { + /** + * Copy a collection of files from one location to another. + * + *

    The number of source filenames must equal the number of destination filenames. + * + * @param srcFilenames the source filenames. + * @param destFilenames the destination filenames. + */ + public void copy(List srcFilenames, List destFilenames) throws IOException; + + /** + * Remove a collection of files. + */ + public void remove(Collection filenames) throws IOException; + } + + /** + * GCS file system operations. + */ + private static class GcsOperations implements FileOperations { + private static final Logger LOG = LoggerFactory.getLogger(GcsOperations.class); + + /** + * Maximum number of requests permitted in a GCS batch request. + */ + private static final int MAX_REQUESTS_PER_BATCH = 1000; + + private ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + private GcsOptions gcsOptions; + private Storage gcs; + private BatchHelper batchHelper; + + public GcsOperations(PipelineOptions options) { + gcsOptions = options.as(GcsOptions.class); + gcs = Transport.newStorageClient(gcsOptions).build(); + batchHelper = + new BatchHelper(gcs.getRequestFactory().getInitializer(), gcs, MAX_REQUESTS_PER_BATCH); + } + + @Override + public void copy(List srcFilenames, List destFilenames) throws IOException { + Preconditions.checkArgument( + srcFilenames.size() == destFilenames.size(), + String.format("Number of source files {} must equal number of destination files {}", + srcFilenames.size(), destFilenames.size())); + for (int i = 0; i < srcFilenames.size(); i++) { + final GcsPath sourcePath = GcsPath.fromUri(srcFilenames.get(i)); + final GcsPath destPath = GcsPath.fromUri(destFilenames.get(i)); + LOG.debug("Copying {} to {}", sourcePath, destPath); + Storage.Objects.Copy copyObject = gcs.objects().copy(sourcePath.getBucket(), + sourcePath.getObject(), destPath.getBucket(), destPath.getObject(), null); + batchHelper.queue(copyObject, new JsonBatchCallback() { + @Override + public void onSuccess(StorageObject obj, HttpHeaders responseHeaders) { + LOG.debug("Successfully copied {} to {}", sourcePath, destPath); + } + + @Override + public void onFailure(GoogleJsonError e, HttpHeaders responseHeaders) throws IOException { + // Do nothing on item not found. + if (!errorExtractor.itemNotFound(e)) { + throw new IOException(e.toString()); + } + LOG.debug("{} does not exist.", sourcePath); + } + }); + } + batchHelper.flush(); + } + + @Override + public void remove(Collection filenames) throws IOException { + for (String filename : filenames) { + final GcsPath path = GcsPath.fromUri(filename); + LOG.debug("Removing: " + path); + Storage.Objects.Delete deleteObject = + gcs.objects().delete(path.getBucket(), path.getObject()); + batchHelper.queue(deleteObject, new JsonBatchCallback() { + @Override + public void onSuccess(Void obj, HttpHeaders responseHeaders) throws IOException { + LOG.debug("Successfully removed {}", path); + } + + @Override + public void onFailure(GoogleJsonError e, HttpHeaders responseHeaders) throws IOException { + // Do nothing on item not found. + if (!errorExtractor.itemNotFound(e)) { + throw new IOException(e.toString()); + } + LOG.debug("{} does not exist.", path); + } + }); + } + batchHelper.flush(); + } + } + + /** + * File systems supported by {@link Files}. 
+ */ + private static class LocalFileOperations implements FileOperations { + private static final Logger LOG = LoggerFactory.getLogger(LocalFileOperations.class); + + @Override + public void copy(List srcFilenames, List destFilenames) throws IOException { + Preconditions.checkArgument( + srcFilenames.size() == destFilenames.size(), + String.format("Number of source files %d must equal number of destination files %d", + srcFilenames.size(), destFilenames.size())); + int numFiles = srcFilenames.size(); + for (int i = 0; i < numFiles; i++) { + String src = srcFilenames.get(i); + String dst = destFilenames.get(i); + LOG.debug("Copying {} to {}", src, dst); + copyOne(src, dst); + } + } + + private void copyOne(String source, String destination) throws IOException { + try { + // Copy the source file, replacing the existing destination. + Files.copy(Paths.get(source), Paths.get(destination), StandardCopyOption.REPLACE_EXISTING); + } catch (NoSuchFileException e) { + LOG.debug("{} does not exist.", source); + // Suppress exception if file does not exist. + } + } + + @Override + public void remove(Collection filenames) throws IOException { + for (String filename : filenames) { + LOG.debug("Removing file {}", filename); + removeOne(filename); + } + } + + private void removeOne(String filename) throws IOException { + // Delete the file if it exists. + boolean deleted = Files.deleteIfExists(Paths.get(filename)); + if (!deleted) { + LOG.debug("{} does not exist.", filename); + } + } + } + + /** + * BatchHelper abstracts out the logic for the maximum requests per batch for GCS. + * + *

    Copy of + * https://github.com/GoogleCloudPlatform/bigdata-interop/blob/master/gcs/src/main/java/com/google/cloud/hadoop/gcsio/BatchHelper.java + * + *

    Copied to prevent Dataflow from depending on the Hadoop-related dependencies that are not + * used in Dataflow. Hadoop-related dependencies will be removed from the Google Cloud Storage + * Connector (https://cloud.google.com/hadoop/google-cloud-storage-connector) so that this project + * and others may use the connector without introducing unnecessary dependencies. + * + *

    This class is not thread-safe; create a new BatchHelper instance per single-threaded logical + * grouping of requests. + */ + @NotThreadSafe + private static class BatchHelper { + /** + * Callback that causes a single StorageRequest to be added to the BatchRequest. + */ + protected static interface QueueRequestCallback { + void enqueue() throws IOException; + } + + private final List pendingBatchEntries; + private final BatchRequest batch; + + // Number of requests that can be queued into a single actual HTTP request + // before a sub-batch is sent. + private final long maxRequestsPerBatch; + + // Flag that indicates whether there is an in-progress flush. + private boolean flushing = false; + + /** + * Primary constructor, generally accessed only via the inner Factory class. + */ + public BatchHelper( + HttpRequestInitializer requestInitializer, Storage gcs, long maxRequestsPerBatch) { + this.pendingBatchEntries = new LinkedList<>(); + this.batch = gcs.batch(requestInitializer); + this.maxRequestsPerBatch = maxRequestsPerBatch; + } + + /** + * Adds an additional request to the batch, and possibly flushes the current contents of the + * batch if {@code maxRequestsPerBatch} has been reached. + */ + public void queue(final StorageRequest req, final JsonBatchCallback callback) + throws IOException { + QueueRequestCallback queueCallback = new QueueRequestCallback() { + @Override + public void enqueue() throws IOException { + req.queue(batch, callback); + } + }; + pendingBatchEntries.add(queueCallback); + + flushIfPossibleAndRequired(); + } + + // Flush our buffer if we have more pending entries than maxRequestsPerBatch + private void flushIfPossibleAndRequired() throws IOException { + if (pendingBatchEntries.size() > maxRequestsPerBatch) { + flushIfPossible(); + } + } + + // Flush our buffer if we are not already in a flush operation and we have data to flush. + private void flushIfPossible() throws IOException { + if (!flushing && pendingBatchEntries.size() > 0) { + flushing = true; + try { + while (batch.size() < maxRequestsPerBatch && pendingBatchEntries.size() > 0) { + QueueRequestCallback head = pendingBatchEntries.remove(0); + head.enqueue(); + } + + batch.execute(); + } finally { + flushing = false; + } + } + } + + + /** + * Sends any currently remaining requests in the batch; should be called at the end of any + * series of batched requests to ensure everything has been sent. + */ + public void flush() throws IOException { + flushIfPossible(); + } + } + + static class ReshardForWrite extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + return input + // TODO: This would need to be adapted to write per-window shards. 
+ .apply(Window.into(new GlobalWindows()) + .triggering(DefaultTrigger.of()) + .discardingFiredPanes()) + .apply("RandomKey", ParDo.of( + new DoFn>() { + transient long counter, step; + @Override + public void startBundle(Context c) { + counter = (long) (Math.random() * Long.MAX_VALUE); + step = 1 + 2 * (long) (Math.random() * Long.MAX_VALUE); + } + @Override + public void processElement(ProcessContext c) { + counter += step; + c.output(KV.of(counter, c.element())); + } + })) + .apply(GroupByKey.create()) + .apply("Ungroup", ParDo.of( + new DoFn>, T>() { + @Override + public void processElement(ProcessContext c) { + for (T item : c.element().getValue()) { + c.output(item); + } + } + })); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSource.java new file mode 100644 index 000000000000..5d32a9d08fb1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSource.java @@ -0,0 +1,648 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.IOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; +import java.util.NoSuchElementException; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; + +/** + * A common base class for all file-based {@link Source}s. Extend this class to implement your own + * file-based custom source. + * + *

A file-based {@code Source} is a {@code Source} backed by a file pattern defined as a Java + * glob, a single file, or an offset range for a single file. See {@link OffsetBasedSource} and + * {@link com.google.cloud.dataflow.sdk.io.range.RangeTracker} for semantics of offset ranges. + * + *

    This source stores a {@code String} that is an {@link IOChannelFactory} specification for a + * file or file pattern. There should be an {@code IOChannelFactory} defined for the file + * specification provided. Please refer to {@link IOChannelUtils} and {@link IOChannelFactory} for + * more information on this. + * + *

In addition to the methods left abstract from {@code BoundedSource}, subclasses must implement + * methods to create a sub-source and a reader for a range of a single file - + * {@link #createForSubrangeOfFile} and {@link #createSingleFileReader}. Please refer to + * {@link XmlSource} for an example implementation of {@code FileBasedSource}. + * + * @param Type of records represented by the source. + */ +public abstract class FileBasedSource extends OffsetBasedSource { + private static final Logger LOG = LoggerFactory.getLogger(FileBasedSource.class); + private static final float FRACTION_OF_FILES_TO_STAT = 0.01f; + + // Package-private for testing + static final int MAX_NUMBER_OF_FILES_FOR_AN_EXACT_STAT = 100; + + // Size of the thread pool to be used for performing file operations in parallel. + // Package-private for testing. + static final int THREAD_POOL_SIZE = 128; + + private final String fileOrPatternSpec; + private final Mode mode; + + /** + * A given {@code FileBasedSource} represents a file resource of one of these types. + */ + public enum Mode { + FILEPATTERN, + SINGLE_FILE_OR_SUBRANGE + } + + /** + * Create a {@code FileBasedSource} based on a file or a file pattern specification. This + * constructor must be used when creating a new {@code FileBasedSource} for a file pattern. + * + *
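For example, a hypothetical subclass for fixed-width 8-byte records might wire the two constructors and the two factory methods like this (generic signatures, the coder choice, and inherited abstract methods such as producesSortedKeys are assumed from the surrounding SDK; the matching FixedRecordReader is sketched later in this file):

import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.VarLongCoder;
import com.google.cloud.dataflow.sdk.io.FileBasedSource;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;

/** Hypothetical source of 8-byte big-endian long records. */
class FixedRecordSource extends FileBasedSource<Long> {

  /** File-pattern form, used when the source is first constructed. */
  public FixedRecordSource(String fileOrPatternSpec, long minBundleSize) {
    super(fileOrPatternSpec, minBundleSize);
  }

  /** Single-file or subrange form, used by createForSubrangeOfFile below. */
  public FixedRecordSource(String fileName, long minBundleSize, long start, long end) {
    super(fileName, minBundleSize, start, end);
  }

  @Override
  protected FileBasedSource<Long> createForSubrangeOfFile(String fileName, long start, long end) {
    return new FixedRecordSource(fileName, getMinBundleSize(), start, end);
  }

  @Override
  protected FileBasedReader<Long> createSingleFileReader(PipelineOptions options) {
    return new FixedRecordReader(this);  // hypothetical reader, sketched later in this file
  }

  @Override
  public Coder<Long> getDefaultOutputCoder() {
    return VarLongCoder.of();  // assumed coder from the SDK
  }

  @Override
  public boolean producesSortedKeys(PipelineOptions options) {
    return false;
  }
}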

    See {@link OffsetBasedSource} for a detailed description of {@code minBundleSize}. + * + * @param fileOrPatternSpec {@link IOChannelFactory} specification of file or file pattern + * represented by the {@link FileBasedSource}. + * @param minBundleSize minimum bundle size in bytes. + */ + public FileBasedSource(String fileOrPatternSpec, long minBundleSize) { + super(0, Long.MAX_VALUE, minBundleSize); + mode = Mode.FILEPATTERN; + this.fileOrPatternSpec = fileOrPatternSpec; + } + + /** + * Create a {@code FileBasedSource} based on a single file. This constructor must be used when + * creating a new {@code FileBasedSource} for a subrange of a single file. + * Additionally, this constructor must be used to create new {@code FileBasedSource}s when + * subclasses implement the method {@link #createForSubrangeOfFile}. + * + *

    See {@link OffsetBasedSource} for detailed descriptions of {@code minBundleSize}, + * {@code startOffset}, and {@code endOffset}. + * + * @param fileName {@link IOChannelFactory} specification of the file represented by the + * {@link FileBasedSource}. + * @param minBundleSize minimum bundle size in bytes. + * @param startOffset starting byte offset. + * @param endOffset ending byte offset. If the specified value {@code >= #getMaxEndOffset()} it + * implies {@code #getMaxEndOffSet()}. + */ + public FileBasedSource(String fileName, long minBundleSize, + long startOffset, long endOffset) { + super(startOffset, endOffset, minBundleSize); + mode = Mode.SINGLE_FILE_OR_SUBRANGE; + this.fileOrPatternSpec = fileName; + } + + public final String getFileOrPatternSpec() { + return fileOrPatternSpec; + } + + public final Mode getMode() { + return mode; + } + + @Override + public final FileBasedSource createSourceForSubrange(long start, long end) { + Preconditions.checkArgument(mode != Mode.FILEPATTERN, + "Cannot split a file pattern based source based on positions"); + Preconditions.checkArgument(start >= getStartOffset(), "Start offset value " + start + + " of the subrange cannot be smaller than the start offset value " + getStartOffset() + + " of the parent source"); + Preconditions.checkArgument(end <= getEndOffset(), "End offset value " + end + + " of the subrange cannot be larger than the end offset value " + getEndOffset() + + " of the parent source"); + + FileBasedSource source = createForSubrangeOfFile(fileOrPatternSpec, start, end); + if (start > 0 || end != Long.MAX_VALUE) { + Preconditions.checkArgument(source.getMode() == Mode.SINGLE_FILE_OR_SUBRANGE, + "Source created for the range [" + start + "," + end + ")" + + " must be a subrange source"); + } + return source; + } + + /** + * Creates and returns a new {@code FileBasedSource} of the same type as the current + * {@code FileBasedSource} backed by a given file and an offset range. When current source is + * being split, this method is used to generate new sub-sources. When creating the source + * subclasses must call the constructor {@link #FileBasedSource(String, long, long, long)} of + * {@code FileBasedSource} with corresponding parameter values passed here. + * + * @param fileName file backing the new {@code FileBasedSource}. + * @param start starting byte offset of the new {@code FileBasedSource}. + * @param end ending byte offset of the new {@code FileBasedSource}. May be Long.MAX_VALUE, + * in which case it will be inferred using {@link #getMaxEndOffset}. + */ + protected abstract FileBasedSource createForSubrangeOfFile( + String fileName, long start, long end); + + /** + * Creates and returns an instance of a {@code FileBasedReader} implementation for the current + * source assuming the source represents a single file. File patterns will be handled by + * {@code FileBasedSource} implementation automatically. + */ + protected abstract FileBasedReader createSingleFileReader( + PipelineOptions options); + + @Override + public final long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + // This implementation of method getEstimatedSizeBytes is provided to simplify subclasses. Here + // we perform the size estimation of files and file patterns using the interface provided by + // IOChannelFactory. + + IOChannelFactory factory = IOChannelUtils.getFactory(fileOrPatternSpec); + if (mode == Mode.FILEPATTERN) { + // TODO Implement a more efficient parallel/batch size estimation mechanism for file patterns. 
+ long startTime = System.currentTimeMillis(); + long totalSize = 0; + Collection inputs = factory.match(fileOrPatternSpec); + if (inputs.size() <= MAX_NUMBER_OF_FILES_FOR_AN_EXACT_STAT) { + totalSize = getExactTotalSizeOfFiles(inputs, factory); + LOG.debug("Size estimation of all files of pattern " + fileOrPatternSpec + " took " + + (System.currentTimeMillis() - startTime) + " ms"); + } else { + totalSize = getEstimatedSizeOfFilesBySampling(inputs, factory); + LOG.debug("Size estimation of pattern " + fileOrPatternSpec + " by sampling took " + + (System.currentTimeMillis() - startTime) + " ms"); + } + return totalSize; + } else { + long start = getStartOffset(); + long end = Math.min(getEndOffset(), getMaxEndOffset(options)); + return end - start; + } + } + + // Get the exact total size of the given set of files. + // Invokes multiple requests for size estimation in parallel using a thread pool. + // TODO: replace this with bulk request API when it is available. Will require updates + // to IOChannelFactory interface. + private static long getExactTotalSizeOfFiles( + Collection files, IOChannelFactory ioChannelFactory) throws Exception { + List> futures = new ArrayList<>(); + ListeningExecutorService service = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(THREAD_POOL_SIZE)); + long totalSize = 0; + try { + for (String file : files) { + futures.add(createFutureForSizeEstimation(file, ioChannelFactory, service)); + } + + for (Long val : Futures.allAsList(futures).get()) { + totalSize += val; + } + + return totalSize; + } finally { + service.shutdown(); + } + } + + private static ListenableFuture createFutureForSizeEstimation( + final String file, + final IOChannelFactory ioChannelFactory, + ListeningExecutorService service) { + return service.submit( + new Callable() { + @Override + public Long call() throws Exception { + return ioChannelFactory.getSizeBytes(file); + } + }); + } + + // Estimate the total size of the given set of files through sampling and extrapolation. + // Currently we use uniform sampling which requires a linear sampling size for a reasonable + // estimate. + // TODO: Implement a more efficient sampling mechanism. + private static long getEstimatedSizeOfFilesBySampling( + Collection files, IOChannelFactory ioChannelFactory) throws Exception { + int sampleSize = (int) (FRACTION_OF_FILES_TO_STAT * files.size()); + sampleSize = Math.max(MAX_NUMBER_OF_FILES_FOR_AN_EXACT_STAT, sampleSize); + + List selectedFiles = new ArrayList(files); + Collections.shuffle(selectedFiles); + selectedFiles = selectedFiles.subList(0, sampleSize); + + return files.size() * getExactTotalSizeOfFiles(selectedFiles, ioChannelFactory) + / selectedFiles.size(); + } + + private ListenableFuture>> createFutureForFileSplit( + final String file, + final long desiredBundleSizeBytes, + final PipelineOptions options, + ListeningExecutorService service) { + return service.submit(new Callable>>() { + @Override + public List> call() throws Exception { + return createForSubrangeOfFile(file, 0, Long.MAX_VALUE) + .splitIntoBundles(desiredBundleSizeBytes, options); + } + }); + } + + @Override + public final List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + // This implementation of method splitIntoBundles is provided to simplify subclasses. Here we + // split a FileBasedSource based on a file pattern to FileBasedSources based on full single + // files. 
For files that can be efficiently seeked, we further split FileBasedSources based on + // those files to FileBasedSources based on sub ranges of single files. + + if (mode == Mode.FILEPATTERN) { + long startTime = System.currentTimeMillis(); + List>>> futures = new ArrayList<>(); + + ListeningExecutorService service = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(THREAD_POOL_SIZE)); + try { + for (final String file : FileBasedSource.expandFilePattern(fileOrPatternSpec)) { + futures.add(createFutureForFileSplit(file, desiredBundleSizeBytes, options, service)); + } + List> splitResults = + ImmutableList.copyOf(Iterables.concat(Futures.allAsList(futures).get())); + LOG.debug( + "Splitting the source based on file pattern " + + fileOrPatternSpec + + " took " + + (System.currentTimeMillis() - startTime) + + " ms"); + return splitResults; + } finally { + service.shutdown(); + } + } else { + if (isSplittable()) { + List> splitResults = new ArrayList<>(); + for (OffsetBasedSource split : + super.splitIntoBundles(desiredBundleSizeBytes, options)) { + splitResults.add((FileBasedSource) split); + } + return splitResults; + } else { + LOG.debug("The source for file " + fileOrPatternSpec + + " is not split into sub-range based sources since the file is not seekable"); + return ImmutableList.of(this); + } + } + } + + /** + * Determines whether a file represented by this source is can be split into bundles. + * + *

    By default, a file is splittable if it is on a file system that supports efficient read + * seeking. Subclasses may override to provide different behavior. + */ + protected boolean isSplittable() throws Exception { + // We split a file-based source into subranges only if the file is efficiently seekable. + // If a file is not efficiently seekable it would be highly inefficient to create and read a + // source based on a subrange of that file. + IOChannelFactory factory = IOChannelUtils.getFactory(fileOrPatternSpec); + return factory.isReadSeekEfficient(fileOrPatternSpec); + } + + @Override + public final BoundedReader createReader(PipelineOptions options) throws IOException { + // Validate the current source prior to creating a reader for it. + this.validate(); + + if (mode == Mode.FILEPATTERN) { + long startTime = System.currentTimeMillis(); + Collection files = FileBasedSource.expandFilePattern(fileOrPatternSpec); + List> fileReaders = new ArrayList<>(); + for (String fileName : files) { + long endOffset; + try { + endOffset = IOChannelUtils.getFactory(fileName).getSizeBytes(fileName); + } catch (IOException e) { + LOG.warn("Failed to get size of " + fileName, e); + endOffset = Long.MAX_VALUE; + } + fileReaders.add( + createForSubrangeOfFile(fileName, 0, endOffset).createSingleFileReader(options)); + } + LOG.debug("Creating a reader for file pattern " + fileOrPatternSpec + " took " + + (System.currentTimeMillis() - startTime) + " ms"); + if (fileReaders.size() == 1) { + return fileReaders.get(0); + } + return new FilePatternReader(this, fileReaders); + } else { + return createSingleFileReader(options); + } + } + + @Override + public String toString() { + switch (mode) { + case FILEPATTERN: + return fileOrPatternSpec; + case SINGLE_FILE_OR_SUBRANGE: + return fileOrPatternSpec + " range " + super.toString(); + default: + throw new IllegalStateException("Unexpected mode: " + mode); + } + } + + @Override + public void validate() { + super.validate(); + switch (mode) { + case FILEPATTERN: + Preconditions.checkArgument(getStartOffset() == 0, + "FileBasedSource is based on a file pattern or a full single file " + + "but the starting offset proposed " + getStartOffset() + " is not zero"); + Preconditions.checkArgument(getEndOffset() == Long.MAX_VALUE, + "FileBasedSource is based on a file pattern or a full single file " + + "but the ending offset proposed " + getEndOffset() + " is not Long.MAX_VALUE"); + break; + case SINGLE_FILE_OR_SUBRANGE: + // Nothing more to validate. + break; + default: + throw new IllegalStateException("Unknown mode: " + mode); + } + } + + @Override + public final long getMaxEndOffset(PipelineOptions options) throws Exception { + if (mode == Mode.FILEPATTERN) { + throw new IllegalArgumentException("Cannot determine the exact end offset of a file pattern"); + } + if (getEndOffset() == Long.MAX_VALUE) { + IOChannelFactory factory = IOChannelUtils.getFactory(fileOrPatternSpec); + return factory.getSizeBytes(fileOrPatternSpec); + } else { + return getEndOffset(); + } + } + + protected static final Collection expandFilePattern(String fileOrPatternSpec) + throws IOException { + IOChannelFactory factory = IOChannelUtils.getFactory(fileOrPatternSpec); + Collection matches = factory.match(fileOrPatternSpec); + LOG.info("Matched {} files for pattern {}", matches.size(), fileOrPatternSpec); + return matches; + } + + /** + * A {@link Source.Reader reader} that implements code common to readers of + * {@code FileBasedSource}s. + * + *

    Seekability

    + * + *

    This reader uses a {@link ReadableByteChannel} created for the file represented by the + * corresponding source to efficiently move to the correct starting position defined in the + * source. Subclasses of this reader should implement {@link #startReading} to get access to this + * channel. If the source corresponding to the reader is for a subrange of a file the + * {@code ReadableByteChannel} provided is guaranteed to be an instance of the type + * {@link SeekableByteChannel}, which may be used by subclass to traverse back in the channel to + * determine the correct starting position. + * + *

    Reading Records

    + * + *

    Sequential reading is implemented using {@link #readNextRecord}. + * + *

Then {@code FileBasedReader} implements "reading a range [A, B)" in the following way (a minimal reader sketch follows this description). + *

      + *
    1. {@link #start} opens the file + *
    2. {@link #start} seeks the {@code SeekableByteChannel} to A (reading offset ranges for + * non-seekable files is not supported) and calls {@code startReading()} + *
    3. {@link #start} calls {@link #advance} once, which, via {@link #readNextRecord}, + * locates the first record which is at a split point AND its offset is at or after A. + * If this record is at or after B, {@link #advance} returns false and reading is finished. + *
4. if the previous {@code advance} call returned {@code true}, sequential reading starts and + * {@code advance()} will be called repeatedly. + *
    + * {@code advance()} calls {@code readNextRecord()} on the subclass, and stops (returns false) if + * the new record is at a split point AND the offset of the new record is at or after B. + * + *
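A minimal reader following this protocol, for the hypothetical fixed-width 8-byte records used in the FixedRecordSource sketch earlier in this file (the getCurrentOffset and getCurrent signatures are assumed from the base classes):

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.util.NoSuchElementException;

import com.google.cloud.dataflow.sdk.io.FileBasedSource;

/** Hypothetical reader for fixed-width 8-byte big-endian long records. */
class FixedRecordReader extends FileBasedSource.FileBasedReader<Long> {
  private static final int RECORD_SIZE = 8;

  private ReadableByteChannel channel;
  private final ByteBuffer buffer = ByteBuffer.allocate(RECORD_SIZE);
  private long nextRecordOffset;
  private long currentOffset;
  private Long current;

  FixedRecordReader(FileBasedSource<Long> source) {
    super(source);
  }

  @Override
  protected void startReading(ReadableByteChannel channel) {
    this.channel = channel;
    // The base class has already positioned the channel at the source's start offset (A).
    this.nextRecordOffset = getCurrentSource().getStartOffset();
  }

  @Override
  protected boolean readNextRecord() throws IOException {
    buffer.clear();
    while (buffer.hasRemaining()) {
      if (channel.read(buffer) == -1) {
        return false;  // end of channel before a complete record: reading is finished
      }
    }
    buffer.flip();
    current = buffer.getLong();
    currentOffset = nextRecordOffset;   // offset at which this record starts
    nextRecordOffset += RECORD_SIZE;    // fixed-width records: every record is a split point
    return true;
  }

  @Override
  protected long getCurrentOffset() {
    return currentOffset;
  }

  @Override
  public Long getCurrent() throws NoSuchElementException {
    return current;
  }
}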

    Thread Safety

    + * + *

    Since this class implements {@link Source.Reader} it guarantees thread safety. Abstract + * methods defined here will not be accessed by more than one thread concurrently. + */ + public abstract static class FileBasedReader extends OffsetBasedReader { + private ReadableByteChannel channel = null; + + /** + * Subclasses should not perform IO operations at the constructor. All IO operations should be + * delayed until the {@link #startReading} method is invoked. + */ + public FileBasedReader(FileBasedSource source) { + super(source); + Preconditions.checkArgument(source.getMode() != Mode.FILEPATTERN, + "FileBasedReader does not support reading file patterns"); + } + + @Override + public FileBasedSource getCurrentSource() { + return (FileBasedSource) super.getCurrentSource(); + } + + @Override + protected final boolean startImpl() throws IOException { + FileBasedSource source = getCurrentSource(); + IOChannelFactory factory = IOChannelUtils.getFactory(source.getFileOrPatternSpec()); + this.channel = factory.open(source.getFileOrPatternSpec()); + + if (channel instanceof SeekableByteChannel) { + SeekableByteChannel seekChannel = (SeekableByteChannel) channel; + seekChannel.position(source.getStartOffset()); + } else { + // Channel is not seekable. Must not be a subrange. + Preconditions.checkArgument(source.mode != Mode.SINGLE_FILE_OR_SUBRANGE, + "Subrange-based sources must only be defined for file types that support seekable " + + " read channels"); + Preconditions.checkArgument(source.getStartOffset() == 0, "Start offset " + + source.getStartOffset() + + " is not zero but channel for reading the file is not seekable."); + } + + startReading(channel); + + // Advance once to load the first record. + return advanceImpl(); + } + + @Override + protected final boolean advanceImpl() throws IOException { + return readNextRecord(); + } + + /** + * Closes any {@link ReadableByteChannel} created for the current reader. This implementation is + * idempotent. Any {@code close()} method introduced by a subclass must be idempotent and must + * call the {@code close()} method in the {@code FileBasedReader}. + */ + @Override + public void close() throws IOException { + if (channel != null) { + channel.close(); + } + } + + /** + * Performs any initialization of the subclass of {@code FileBasedReader} that involves IO + * operations. Will only be invoked once and before that invocation the base class will seek the + * channel to the source's starting offset. + * + *

The provided {@link ReadableByteChannel} is for the file represented by the source of this + * reader. A subclass may use the {@code channel} to build a higher-level IO abstraction, e.g., a + * BufferedReader or an XML parser. + * + *

    If the corresponding source is for a subrange of a file, {@code channel} is guaranteed to + * be an instance of the type {@link SeekableByteChannel}. + * + *

After this method is invoked, the base class will not read data from the channel or adjust its + * position; however, the base class remains responsible for properly closing the channel. + * + * @param channel a byte channel representing the file backing the reader. + */ + protected abstract void startReading(ReadableByteChannel channel) throws IOException; + + /** + * Reads the next record from the channel provided by {@link #startReading}. Methods + * {@link #getCurrent}, {@link #getCurrentOffset}, and {@link #isAtSplitPoint()} should return + * the corresponding information about the record read by the last invocation of this method. + * + *

    Note that this method will be called the same way for reading the first record in the + * source (file or offset range in the file) and for reading subsequent records. It is up to the + * subclass to do anything special for locating and reading the first record, if necessary. + * + * @return {@code true} if a record was successfully read, {@code false} if the end of the + * channel was reached before successfully reading a new record. + */ + protected abstract boolean readNextRecord() throws IOException; + } + + // An internal Reader implementation that concatenates a sequence of FileBasedReaders. + private class FilePatternReader extends BoundedReader { + private final FileBasedSource source; + private final List> fileReaders; + final ListIterator> fileReadersIterator; + FileBasedReader currentReader = null; + + public FilePatternReader(FileBasedSource source, List> fileReaders) { + this.source = source; + this.fileReaders = fileReaders; + this.fileReadersIterator = fileReaders.listIterator(); + } + + @Override + public boolean start() throws IOException { + return startNextNonemptyReader(); + } + + @Override + public boolean advance() throws IOException { + Preconditions.checkState(currentReader != null, "Call start() before advance()"); + if (currentReader.advance()) { + return true; + } + return startNextNonemptyReader(); + } + + private boolean startNextNonemptyReader() throws IOException { + while (fileReadersIterator.hasNext()) { + currentReader = fileReadersIterator.next(); + if (currentReader.start()) { + return true; + } + currentReader.close(); + } + return false; + } + + @Override + public T getCurrent() throws NoSuchElementException { + // A NoSuchElement will be thrown by the last FileBasedReader if getCurrent() is called after + // advance() returns false. + return currentReader.getCurrent(); + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + // A NoSuchElement will be thrown by the last FileBasedReader if getCurrentTimestamp() + // is called after advance() returns false. + return currentReader.getCurrentTimestamp(); + } + + @Override + public void close() throws IOException { + // Close all readers that may have not yet been closed. + // If this reader has not been started, currentReader is null. + if (currentReader != null) { + currentReader.close(); + } + while (fileReadersIterator.hasNext()) { + fileReadersIterator.next().close(); + } + } + + @Override + public FileBasedSource getCurrentSource() { + return source; + } + + @Override + public FileBasedSource splitAtFraction(double fraction) { + // Unsupported. TODO: implement. 
+ LOG.debug("Dynamic splitting of FilePatternReader is unsupported."); + return null; + } + + @Override + public Double getFractionConsumed() { + if (currentReader == null) { + return 0.0; + } + if (fileReaders.isEmpty()) { + return 1.0; + } + int index = fileReadersIterator.previousIndex(); + int numReaders = fileReaders.size(); + if (index == numReaders) { + return 1.0; + } + double before = 1.0 * index / numReaders; + double after = 1.0 * (index + 1) / numReaders; + Double fractionOfCurrentReader = currentReader.getFractionConsumed(); + if (fractionOfCurrentReader == null) { + return before; + } + return before + fractionOfCurrentReader * (after - before); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/OffsetBasedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/OffsetBasedSource.java new file mode 100644 index 000000000000..d581b80ca270 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/OffsetBasedSource.java @@ -0,0 +1,326 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.io.range.OffsetRangeTracker; +import com.google.cloud.dataflow.sdk.io.range.RangeTracker; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.common.base.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * A {@link BoundedSource} that uses offsets to define starting and ending positions. + * + *

    {@link OffsetBasedSource} is a common base class for all bounded sources where the input can + * be represented as a single range, and an input can be efficiently processed in parallel by + * splitting the range into a set of disjoint ranges whose union is the original range. This class + * should be used for sources that can be cheaply read starting at any given offset. + * {@link OffsetBasedSource} stores the range and implements splitting into bundles. + * + *
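For intuition, the range-splitting arithmetic implemented by splitIntoBundles later in this class behaves roughly like the following standalone sketch (offsets and sizes are illustrative only):

import java.util.ArrayList;
import java.util.List;

class BundleSplitSketch {
  public static void main(String[] args) {
    long start = 0, maxEnd = 1020;          // source range [0, 1020)
    long desired = 100, minBundleSize = 1;  // desired bundle size, in offset units

    List<long[]> bundles = new ArrayList<>();
    while (start < maxEnd) {
      long end = Math.min(start + desired, maxEnd);
      long remaining = maxEnd - end;
      if (remaining < desired / 4 || remaining < minBundleSize) {
        end = maxEnd;  // fold a too-small tail into the last bundle
      }
      bundles.add(new long[] {start, end});
      start = end;
    }
    // Prints [0, 100), [100, 200), ..., [800, 900), [900, 1020): the 20-offset tail is folded
    // into the final bundle because it is smaller than a quarter of the desired size.
    for (long[] b : bundles) {
      System.out.println("[" + b[0] + ", " + b[1] + ")");
    }
  }
}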

Extend {@link OffsetBasedSource} to implement your own offset-based custom source. + * {@link FileBasedSource}, which is a subclass of this, adds additional functionality useful for + * custom sources that are based on files. If possible, implementors should start from + * {@link FileBasedSource} instead of {@link OffsetBasedSource}. + * + *

    Consult {@link RangeTracker} for important semantics common to all sources defined by a range + * of positions of a certain type, including the semantics of split points + * ({@link OffsetBasedReader#isAtSplitPoint}). + * + * @param Type of records represented by the source. + * @see BoundedSource + * @see FileBasedSource + * @see RangeTracker + */ +public abstract class OffsetBasedSource extends BoundedSource { + private final long startOffset; + private final long endOffset; + private final long minBundleSize; + + /** + * @param startOffset starting offset (inclusive) of the source. Must be non-negative. + * + * @param endOffset ending offset (exclusive) of the source. Use {@link Long#MAX_VALUE} to + * indicate that the entire source after {@code startOffset} should be read. Must be + * {@code > startOffset}. + * + * @param minBundleSize minimum bundle size in offset units that should be used when splitting the + * source into sub-sources. This value may not be respected if the total + * range of the source is smaller than the specified {@code minBundleSize}. + * Must be non-negative. + */ + public OffsetBasedSource(long startOffset, long endOffset, long minBundleSize) { + this.startOffset = startOffset; + this.endOffset = endOffset; + this.minBundleSize = minBundleSize; + } + + /** + * Returns the starting offset of the source. + */ + public long getStartOffset() { + return startOffset; + } + + /** + * Returns the specified ending offset of the source. Any returned value greater than or equal to + * {@link #getMaxEndOffset(PipelineOptions)} should be treated as + * {@link #getMaxEndOffset(PipelineOptions)}. + */ + public long getEndOffset() { + return endOffset; + } + + /** + * Returns the minimum bundle size that should be used when splitting the source into sub-sources. + * This value may not be respected if the total range of the source is smaller than the specified + * {@code minBundleSize}. + */ + public long getMinBundleSize() { + return minBundleSize; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + long trueEndOffset = (endOffset == Long.MAX_VALUE) ? getMaxEndOffset(options) : endOffset; + return getBytesPerOffset() * (trueEndOffset - getStartOffset()); + } + + @Override + public List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + // Split the range into bundles based on the desiredBundleSizeBytes. Final bundle is adjusted to + // make sure that we do not end up with a too small bundle at the end. If the desired bundle + // size is smaller than the minBundleSize of the source then minBundleSize will be used instead. + + long desiredBundleSizeOffsetUnits = Math.max( + Math.max(1, desiredBundleSizeBytes / getBytesPerOffset()), + minBundleSize); + + List> subSources = new ArrayList<>(); + long start = startOffset; + long maxEnd = Math.min(endOffset, getMaxEndOffset(options)); + + while (start < maxEnd) { + long end = start + desiredBundleSizeOffsetUnits; + end = Math.min(end, maxEnd); + // Avoid having a too small bundle at the end and ensure that we respect minBundleSize. 
+ long remaining = maxEnd - end; + if ((remaining < desiredBundleSizeOffsetUnits / 4) || (remaining < minBundleSize)) { + end = maxEnd; + } + subSources.add(createSourceForSubrange(start, end)); + + start = end; + } + return subSources; + } + + @Override + public void validate() { + Preconditions.checkArgument( + this.startOffset >= 0, + "Start offset has value %s, must be non-negative", this.startOffset); + Preconditions.checkArgument( + this.endOffset >= 0, + "End offset has value %s, must be non-negative", this.endOffset); + Preconditions.checkArgument( + this.startOffset < this.endOffset, + "Start offset %s must be before end offset %s", + this.startOffset, this.endOffset); + Preconditions.checkArgument( + this.minBundleSize >= 0, + "minBundleSize has value %s, must be non-negative", + this.minBundleSize); + } + + @Override + public String toString() { + return "[" + startOffset + ", " + endOffset + ")"; + } + + /** + * Returns approximately how many bytes of data correspond to a single offset in this source. + * Used for translation between this source's range and methods defined in terms of bytes, such + * as {@link #getEstimatedSizeBytes} and {@link #splitIntoBundles}. + * + *
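As an illustration (hypothetical, not part of the SDK), a source whose offsets count fixed-width records rather than bytes might override this translation so that size estimates and bundle splitting remain meaningful; the 64-byte record width below is a made-up figure:

@Override
public long getBytesPerOffset() {
  return 64L;  // each offset advances past one 64-byte record
}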

    Defaults to {@code 1} byte, which is the common case for, e.g., file sources. + */ + public long getBytesPerOffset() { + return 1L; + } + + /** + * Returns the actual ending offset of the current source. The value returned by this function + * will be used to clip the end of the range {@code [startOffset, endOffset)} such that the + * range used is {@code [startOffset, min(endOffset, maxEndOffset))}. + * + *

    As an example in which {@link OffsetBasedSource} is used to implement a file source, suppose + * that this source was constructed with an {@code endOffset} of {@link Long#MAX_VALUE} to + * indicate that a file should be read to the end. Then {@link #getMaxEndOffset} should determine + * the actual, exact size of the file in bytes and return it. + */ + public abstract long getMaxEndOffset(PipelineOptions options) throws Exception; + + /** + * Returns an {@link OffsetBasedSource} for a subrange of the current source. The + * subrange {@code [start, end)} must be within the range {@code [startOffset, endOffset)} of + * the current source, i.e. {@code startOffset <= start < end <= endOffset}. + */ + public abstract OffsetBasedSource createSourceForSubrange(long start, long end); + + /** + * Whether this source should allow dynamic splitting of the offset ranges. + * + *

    True by default. Override this to return false if the source cannot + * support dynamic splitting correctly. If this returns false, + * {@link OffsetBasedSource.OffsetBasedReader#splitAtFraction} will refuse all split requests. + */ + public boolean allowsDynamicSplitting() { + return true; + } + + /** + * A {@link Source.Reader} that implements code common to readers of all + * {@link OffsetBasedSource}s. + * + *

    Subclasses have to implement: + *

      + *
    • The methods {@link #startImpl} and {@link #advanceImpl} for reading the + * first or subsequent records. + *
    • The methods {@link #getCurrent}, {@link #getCurrentOffset}, and optionally + * {@link #isAtSplitPoint} and {@link #getCurrentTimestamp} to access properties of + * the last record successfully read by {@link #startImpl} or {@link #advanceImpl}. + *
    + */ + public abstract static class OffsetBasedReader extends BoundedReader { + private static final Logger LOG = LoggerFactory.getLogger(OffsetBasedReader.class); + + private OffsetBasedSource source; + + /** The {@link OffsetRangeTracker} managing the range and current position of the source. */ + private final OffsetRangeTracker rangeTracker; + + /** + * @param source the {@link OffsetBasedSource} to be read by the current reader. + */ + public OffsetBasedReader(OffsetBasedSource source) { + this.source = source; + this.rangeTracker = new OffsetRangeTracker(source.getStartOffset(), source.getEndOffset()); + } + + /** + * Returns the starting offset of the {@link Source.Reader#getCurrent current record}, + * which has been read by the last successful {@link Source.Reader#start} or + * {@link Source.Reader#advance} call. + *

    If no such call has been made yet, the return value is unspecified. + *

    See {@link RangeTracker} for a description of offset semantics. + */ + protected abstract long getCurrentOffset() throws NoSuchElementException; + + /** + * Returns whether the current record is at a split point (i.e., whether the current record + * would be the first record to be read by a source with a specified start offset of + * {@link #getCurrentOffset}). + * + *

    See detailed documentation about split points in {@link RangeTracker}. + */ + protected boolean isAtSplitPoint() throws NoSuchElementException { + return true; + } + + @Override + public final boolean start() throws IOException { + return startImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint(), getCurrentOffset()); + } + + @Override + public final boolean advance() throws IOException { + return advanceImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint(), getCurrentOffset()); + } + + /** + * Initializes the {@link OffsetBasedSource.OffsetBasedReader} and advances to the first record, + * returning {@code true} if there is a record available to be read. This method will be + * invoked exactly once and may perform expensive setup operations that are needed to + * initialize the reader. + * + *

    This function is the {@code OffsetBasedReader} implementation of + * {@link BoundedReader#start}. The key difference is that the implementor can ignore the + * possibility that it should no longer produce the first record, either because it has exceeded + * the original {@code endOffset} assigned to the reader, or because a concurrent call to + * {@link #splitAtFraction} has changed the source to shrink the offset range being read. + * + * @see BoundedReader#start + */ + protected abstract boolean startImpl() throws IOException; + + /** + * Advances to the next record and returns {@code true}, or returns false if there is no next + * record. + * + *

    This function is the {@code OffsetBasedReader} implementation of + * {@link BoundedReader#advance}. The key difference is that the implementor can ignore the + * possibility that it should no longer produce the next record, either because it has exceeded + * the original {@code endOffset} assigned to the reader, or because a concurrent call to + * {@link #splitAtFraction} has changed the source to shrink the offset range being read. + * + * @see BoundedReader#advance + */ + protected abstract boolean advanceImpl() throws IOException; + + @Override + public synchronized OffsetBasedSource getCurrentSource() { + return source; + } + + @Override + public Double getFractionConsumed() { + return rangeTracker.getFractionConsumed(); + } + + @Override + public final synchronized OffsetBasedSource splitAtFraction(double fraction) { + if (!getCurrentSource().allowsDynamicSplitting()) { + return null; + } + if (rangeTracker.getStopPosition() == Long.MAX_VALUE) { + LOG.debug( + "Refusing to split unbounded OffsetBasedReader {} at fraction {}", + rangeTracker, fraction); + return null; + } + long splitOffset = rangeTracker.getPositionForFractionConsumed(fraction); + LOG.debug( + "Proposing to split OffsetBasedReader {} at fraction {} (offset {})", + rangeTracker, fraction, splitOffset); + if (!rangeTracker.trySplitAtPosition(splitOffset)) { + return null; + } + long start = source.getStartOffset(); + long end = source.getEndOffset(); + OffsetBasedSource primary = source.createSourceForSubrange(start, splitOffset); + OffsetBasedSource residual = source.createSourceForSubrange(splitOffset, end); + this.source = primary; + return residual; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubIO.java new file mode 100644 index 000000000000..653b31f059e4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubIO.java @@ -0,0 +1,1044 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.api.client.util.Clock; +import com.google.api.client.util.DateTime; +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.pubsub.model.AcknowledgeRequest; +import com.google.api.services.pubsub.model.PublishRequest; +import com.google.api.services.pubsub.model.PubsubMessage; +import com.google.api.services.pubsub.model.PullRequest; +import com.google.api.services.pubsub.model.PullResponse; +import com.google.api.services.pubsub.model.ReceivedMessage; +import com.google.api.services.pubsub.model.Subscription; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * Read and Write {@link PTransform}s for Cloud Pub/Sub streams. These transforms create + * and consume unbounded {@link PCollection PCollections}. + * + *

    Permissions

    + *

    Permission requirements depend on the {@link PipelineRunner} that is used to execute the + * Dataflow job. Please refer to the documentation of corresponding + * {@link PipelineRunner PipelineRunners} for more details. + */ +public class PubsubIO { + private static final Logger LOG = LoggerFactory.getLogger(PubsubIO.class); + + /** The default {@link Coder} used to translate to/from Cloud Pub/Sub messages. */ + public static final Coder DEFAULT_PUBSUB_CODER = StringUtf8Coder.of(); + + /** + * Project IDs must contain 6-63 lowercase letters, digits, or dashes. + * IDs must start with a letter and may not end with a dash. + * This regex isn't exact - this allows for patterns that would be rejected by + * the service, but this is sufficient for basic parsing of table references. + */ + private static final Pattern PROJECT_ID_REGEXP = + Pattern.compile("[a-z][-a-z0-9:.]{4,61}[a-z0-9]"); + + private static final Pattern SUBSCRIPTION_REGEXP = + Pattern.compile("projects/([^/]+)/subscriptions/(.+)"); + + private static final Pattern TOPIC_REGEXP = Pattern.compile("projects/([^/]+)/topics/(.+)"); + + private static final Pattern V1BETA1_SUBSCRIPTION_REGEXP = + Pattern.compile("/subscriptions/([^/]+)/(.+)"); + + private static final Pattern V1BETA1_TOPIC_REGEXP = Pattern.compile("/topics/([^/]+)/(.+)"); + + private static final Pattern PUBSUB_NAME_REGEXP = Pattern.compile("[a-zA-Z][-._~%+a-zA-Z0-9]+"); + + private static final int PUBSUB_NAME_MAX_LENGTH = 255; + + private static final String SUBSCRIPTION_RANDOM_TEST_PREFIX = "_random/"; + private static final String SUBSCRIPTION_STARTING_SIGNAL = "_starting_signal/"; + private static final String TOPIC_DEV_NULL_TEST_NAME = "/topics/dev/null"; + + private static void validateProjectName(String project) { + Matcher match = PROJECT_ID_REGEXP.matcher(project); + if (!match.matches()) { + throw new IllegalArgumentException( + "Illegal project name specified in Pubsub subscription: " + project); + } + } + + private static void validatePubsubName(String name) { + if (name.length() > PUBSUB_NAME_MAX_LENGTH) { + throw new IllegalArgumentException( + "Pubsub object name is longer than 255 characters: " + name); + } + + if (name.startsWith("goog")) { + throw new IllegalArgumentException("Pubsub object name cannot start with goog: " + name); + } + + Matcher match = PUBSUB_NAME_REGEXP.matcher(name); + if (!match.matches()) { + throw new IllegalArgumentException("Illegal Pubsub object name specified: " + name + + " Please see Javadoc for naming rules."); + } + } + + /** + * Returns the {@link Instant} that corresponds to the timestamp in the supplied + * {@link PubsubMessage} under the specified {@code ink label}. See + * {@link PubsubIO.Read#timestampLabel(String)} for details about how these messages are + * parsed. + * + *
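For illustration only (the attribute name "ts" is hypothetical), these are the two value formats that a timestamp attribute may carry, matching the parsing logic in assignMessageTimestamp below:

Map<String, String> attributes = new HashMap<>();
// Either milliseconds since the Unix epoch...
attributes.put("ts", "1446162101123");
// ...or an RFC 3339 string (this second put simply overwrites the first in this sketch).
attributes.put("ts", "2015-10-29T23:41:41.123Z");
PubsubMessage message = new PubsubMessage().setAttributes(attributes);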

    The {@link Clock} parameter is used to virtualize time for testing. + * + * @throws IllegalArgumentException if the timestamp label is provided, but there is no + * corresponding attribute in the message or the value provided is not a valid timestamp + * string. + * @see PubsubIO.Read#timestampLabel(String) + */ + @VisibleForTesting + protected static Instant assignMessageTimestamp( + PubsubMessage message, @Nullable String label, Clock clock) { + if (label == null) { + return new Instant(clock.currentTimeMillis()); + } + + // Extract message attributes, defaulting to empty map if null. + Map attributes = firstNonNull( + message.getAttributes(), ImmutableMap.of()); + + String timestampStr = attributes.get(label); + checkArgument(timestampStr != null && !timestampStr.isEmpty(), + "PubSub message is missing a timestamp in label: %s", label); + + long millisSinceEpoch; + try { + // Try parsing as milliseconds since epoch. Note there is no way to parse a string in + // RFC 3339 format here. + // Expected IllegalArgumentException if parsing fails; we use that to fall back to RFC 3339. + millisSinceEpoch = Long.parseLong(timestampStr); + } catch (IllegalArgumentException e) { + // Try parsing as RFC3339 string. DateTime.parseRfc3339 will throw an IllegalArgumentException + // if parsing fails, and the caller should handle. + millisSinceEpoch = DateTime.parseRfc3339(timestampStr).getValue(); + } + return new Instant(millisSinceEpoch); + } + + /** + * Class representing a Cloud Pub/Sub Subscription. + */ + public static class PubsubSubscription implements Serializable { + private enum Type { NORMAL, FAKE } + + private final Type type; + private final String project; + private final String subscription; + + private PubsubSubscription(Type type, String project, String subscription) { + this.type = type; + this.project = project; + this.subscription = subscription; + } + + /** + * Creates a class representing a Pub/Sub subscription from the specified subscription path. + * + *

    Cloud Pub/Sub subscription names should be of the form + * {@code projects/<project>/subscriptions/<subscription>}, where {@code <project>} is the name + * of the project the subscription belongs to. The {@code <subscription>} component must comply + * with the following requirements: + * + *

      + *
    • Can only contain lowercase letters, numbers, dashes ('-'), underscores ('_') and periods + * ('.').
    • + *
    • Must be between 3 and 255 characters.
    • + *
    • Must begin with a letter.
    • + *
    • Must end with a letter or a number.
    • + *
    • Cannot begin with {@code 'goog'} prefix.
    • + *
    + */ + public static PubsubSubscription fromPath(String path) { + if (path.startsWith(SUBSCRIPTION_RANDOM_TEST_PREFIX) + || path.startsWith(SUBSCRIPTION_STARTING_SIGNAL)) { + return new PubsubSubscription(Type.FAKE, "", path); + } + + String projectName, subscriptionName; + + Matcher v1beta1Match = V1BETA1_SUBSCRIPTION_REGEXP.matcher(path); + if (v1beta1Match.matches()) { + LOG.warn("Saw subscription in v1beta1 format. Subscriptions should be in the format " + + "projects//subscriptions/"); + projectName = v1beta1Match.group(1); + subscriptionName = v1beta1Match.group(2); + } else { + Matcher match = SUBSCRIPTION_REGEXP.matcher(path); + if (!match.matches()) { + throw new IllegalArgumentException("Pubsub subscription is not in " + + "projects//subscriptions/ format: " + path); + } + projectName = match.group(1); + subscriptionName = match.group(2); + } + + validateProjectName(projectName); + validatePubsubName(subscriptionName); + return new PubsubSubscription(Type.NORMAL, projectName, subscriptionName); + } + + /** + * Returns the string representation of this subscription as a path used in the Cloud Pub/Sub + * v1beta1 API. + * + * @deprecated the v1beta1 API for Cloud Pub/Sub is deprecated. + */ + @Deprecated + public String asV1Beta1Path() { + if (type == Type.NORMAL) { + return "/subscriptions/" + project + "/" + subscription; + } else { + return subscription; + } + } + + /** + * Returns the string representation of this subscription as a path used in the Cloud Pub/Sub + * v1beta2 API. + * + * @deprecated the v1beta2 API for Cloud Pub/Sub is deprecated. + */ + @Deprecated + public String asV1Beta2Path() { + if (type == Type.NORMAL) { + return "projects/" + project + "/subscriptions/" + subscription; + } else { + return subscription; + } + } + + /** + * Returns the string representation of this subscription as a path used in the Cloud Pub/Sub + * API. + */ + public String asPath() { + if (type == Type.NORMAL) { + return "projects/" + project + "/subscriptions/" + subscription; + } else { + return subscription; + } + } + } + + /** + * Class representing a Cloud Pub/Sub Topic. + */ + public static class PubsubTopic implements Serializable { + private enum Type { NORMAL, FAKE } + + private final Type type; + private final String project; + private final String topic; + + private PubsubTopic(Type type, String project, String topic) { + this.type = type; + this.project = project; + this.topic = topic; + } + + /** + * Creates a class representing a Cloud Pub/Sub topic from the specified topic path. + * + *

    Cloud Pub/Sub topic names should be of the form + * {@code /topics/<project>/<topic>}, where {@code <project>} is the name of + * the publishing project. The {@code <topic>} component must comply with + * the following requirements: + * + *

      + *
    • Can only contain lowercase letters, numbers, dashes ('-'), underscores ('_') and periods + * ('.').
    • + *
    • Must be between 3 and 255 characters.
    • + *
    • Must begin with a letter.
    • + *
    • Must end with a letter or a number.
    • + *
    • Cannot begin with 'goog' prefix.
    • + *
    + */ + public static PubsubTopic fromPath(String path) { + if (path.equals(TOPIC_DEV_NULL_TEST_NAME)) { + return new PubsubTopic(Type.FAKE, "", path); + } + + String projectName, topicName; + + Matcher v1beta1Match = V1BETA1_TOPIC_REGEXP.matcher(path); + if (v1beta1Match.matches()) { + LOG.warn("Saw topic in v1beta1 format. Topics should be in the format " + + "projects//topics/"); + projectName = v1beta1Match.group(1); + topicName = v1beta1Match.group(2); + } else { + Matcher match = TOPIC_REGEXP.matcher(path); + if (!match.matches()) { + throw new IllegalArgumentException( + "Pubsub topic is not in projects//topics/ format: " + path); + } + projectName = match.group(1); + topicName = match.group(2); + } + + validateProjectName(projectName); + validatePubsubName(topicName); + return new PubsubTopic(Type.NORMAL, projectName, topicName); + } + + /** + * Returns the string representation of this topic as a path used in the Cloud Pub/Sub + * v1beta1 API. + * + * @deprecated the v1beta1 API for Cloud Pub/Sub is deprecated. + */ + @Deprecated + public String asV1Beta1Path() { + if (type == Type.NORMAL) { + return "/topics/" + project + "/" + topic; + } else { + return topic; + } + } + + /** + * Returns the string representation of this topic as a path used in the Cloud Pub/Sub + * v1beta2 API. + * + * @deprecated the v1beta2 API for Cloud Pub/Sub is deprecated. + */ + @Deprecated + public String asV1Beta2Path() { + if (type == Type.NORMAL) { + return "projects/" + project + "/topics/" + topic; + } else { + return topic; + } + } + + /** + * Returns the string representation of this topic as a path used in the Cloud Pub/Sub + * API. + */ + public String asPath() { + if (type == Type.NORMAL) { + return "projects/" + project + "/topics/" + topic; + } else { + return topic; + } + } + } + + /** + * A {@link PTransform} that continuously reads from a Cloud Pub/Sub stream and + * returns a {@link PCollection} of {@link String Strings} containing the items from + * the stream. + * + *

    When running with a {@link PipelineRunner} that only supports bounded + * {@link PCollection PCollections} (such as {@link DirectPipelineRunner} or + * {@link DataflowPipelineRunner} without {@code --streaming}), only a bounded portion of the + * input Pub/Sub stream can be processed. As such, either {@link Bound#maxNumRecords(int)} or + * {@link Bound#maxReadTime(Duration)} must be set. + */ + public static class Read { + /** + * Creates and returns a transform for reading from Cloud Pub/Sub with the specified transform + * name. + */ + public static Bound named(String name) { + return new Bound<>(DEFAULT_PUBSUB_CODER).named(name); + } + + /** + * Creates and returns a transform for reading from a Cloud Pub/Sub topic. Mutually exclusive + * with {@link #subscription(String)}. + * + *
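A brief usage sketch (the topic path is hypothetical), showing a bounded read as required when the runner does not support unbounded PCollections:

Pipeline p = Pipeline.create();
PCollection<String> messages = p.apply(
    PubsubIO.Read.named("ReadFromPubsub")
        .topic("projects/my-project/topics/my-topic")
        .maxNumRecords(1000));  // or .maxReadTime(...) to bound the read by time instead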

    See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format + * of the {@code topic} string. + * + *

    Dataflow will start reading data published on this topic from the time the pipeline is + * started. Any data published on the topic before the pipeline is started will not be read by + * Dataflow. + */ + public static Bound topic(String topic) { + return new Bound<>(DEFAULT_PUBSUB_CODER).topic(topic); + } + + /** + * Creates and returns a transform for reading from a specific Cloud Pub/Sub subscription. + * Mutually exclusive with {@link #topic(String)}. + * + *

    See {@link PubsubIO.PubsubSubscription#fromPath(String)} for more details on the format + * of the {@code subscription} string. + */ + public static Bound subscription(String subscription) { + return new Bound<>(DEFAULT_PUBSUB_CODER).subscription(subscription); + } + + /** + * Creates and returns a transform reading from Cloud Pub/Sub where record timestamps are + * expected to be provided as Pub/Sub message attributes. The {@code timestampLabel} + * parameter specifies the name of the attribute that contains the timestamp. + * + *

    The timestamp value is expected to be represented in the attribute as either: + * + *

      + *
    • a numerical value representing the number of milliseconds since the Unix epoch. For + * example, if using the Joda time classes, {@link Instant#getMillis()} returns the correct + * value for this attribute. + *
    • a String in RFC 3339 format. For example, {@code 2015-10-29T23:41:41.123Z}. The + * sub-second component of the timestamp is optional, and digits beyond the first three + * (i.e., time units smaller than milliseconds) will be ignored. + *
    + * + *

    If {@code timestampLabel} is not provided, the system will generate record timestamps + * the first time it sees each record. All windowing will be done relative to these timestamps. + * + *

    By default, windows are emitted based on an estimate of when this source is likely + * done producing data for a given timestamp (referred to as the Watermark; see + * {@link AfterWatermark} for more details). Any late data will be handled by the trigger + * specified with the windowing strategy – by default it will be output immediately. + * + *
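A brief usage sketch, assuming a {@code Pipeline p} and a hypothetical attribute name "ts" and topic path:

PCollection<String> stamped = p.apply(
    PubsubIO.Read.timestampLabel("ts")
        .topic("projects/my-project/topics/my-topic"));
// Each element's timestamp is taken from the "ts" attribute of its Pub/Sub message,
// so windowing uses publish-side event time rather than arrival time.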

    Note that the system can guarantee that no late data will ever be seen when it assigns + * timestamps by arrival time (i.e. {@code timestampLabel} is not provided). + * + * @see RFC 3339 + */ + public static Bound timestampLabel(String timestampLabel) { + return new Bound<>(DEFAULT_PUBSUB_CODER).timestampLabel(timestampLabel); + } + + /** + * Creates and returns a transform for reading from Cloud Pub/Sub where unique record + * identifiers are expected to be provided as Pub/Sub message attributes. The {@code idLabel} + * parameter specifies the attribute name. The value of the attribute can be any string + * that uniquely identifies this record. + * + *

    If {@code idLabel} is not provided, Dataflow cannot guarantee that no duplicate data will + * be delivered on the Pub/Sub stream. In this case, deduplication of the stream will be + * strictly best effort. + */ + public static Bound idLabel(String idLabel) { + return new Bound<>(DEFAULT_PUBSUB_CODER).idLabel(idLabel); + } + + /** + * Creates and returns a transform for reading from Cloud Pub/Sub that uses the given + * {@link Coder} to decode Pub/Sub messages into a value of type {@code T}. + * + *
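For illustration, a sketch of reading non-String elements with a custom coder; {@code TableRowJsonCoder} (the SDK's coder for BigQuery {@code TableRow} objects) is used here only as an example and is an assumption not shown in this file:

PCollection<TableRow> rows = p.apply(
    PubsubIO.Read.topic("projects/my-project/topics/my-topic")
        .withCoder(TableRowJsonCoder.of()));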

    By default, uses {@link StringUtf8Coder}, which just + * returns the text lines as Java strings. + * + * @param the type of the decoded elements, and the elements + * of the resulting PCollection. + */ + public static Bound withCoder(Coder coder) { + return new Bound<>(coder); + } + + /** + * Creates and returns a transform for reading from Cloud Pub/Sub with a maximum number of + * records that will be read. The transform produces a bounded {@link PCollection}. + * + *

    Either this option or {@link #maxReadTime(Duration)} must be set in order to create a + * bounded source. + */ + public static Bound maxNumRecords(int maxNumRecords) { + return new Bound<>(DEFAULT_PUBSUB_CODER).maxNumRecords(maxNumRecords); + } + + /** + * Creates and returns a transform for reading from Cloud Pub/Sub with a maximum + * duration during which records will be read. The transform produces a bounded + * {@link PCollection}. + * + *

    Either this option or {@link #maxNumRecords(int)} must be set in order to create a bounded + * source. + */ + public static Bound maxReadTime(Duration maxReadTime) { + return new Bound<>(DEFAULT_PUBSUB_CODER).maxReadTime(maxReadTime); + } + + /** + * A {@link PTransform} that reads from a Cloud Pub/Sub source and returns + * a unbounded {@link PCollection} containing the items from the stream. + */ + public static class Bound extends PTransform> { + /** The Cloud Pub/Sub topic to read from. */ + @Nullable private final PubsubTopic topic; + + /** The Cloud Pub/Sub subscription to read from. */ + @Nullable private final PubsubSubscription subscription; + + /** The name of the message attribute to read timestamps from. */ + @Nullable private final String timestampLabel; + + /** The name of the message attribute to read unique message IDs from. */ + @Nullable private final String idLabel; + + /** The coder used to decode each record. */ + @Nullable private final Coder coder; + + /** Stop after reading this many records. */ + private final int maxNumRecords; + + /** Stop after reading for this much time. */ + @Nullable private final Duration maxReadTime; + + private Bound(Coder coder) { + this(null, null, null, null, coder, null, 0, null); + } + + private Bound(String name, PubsubSubscription subscription, PubsubTopic topic, + String timestampLabel, Coder coder, String idLabel, int maxNumRecords, + Duration maxReadTime) { + super(name); + this.subscription = subscription; + this.topic = topic; + this.timestampLabel = timestampLabel; + this.coder = coder; + this.idLabel = idLabel; + this.maxNumRecords = maxNumRecords; + this.maxReadTime = maxReadTime; + } + + /** + * Returns a transform that's like this one but with the given step name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>( + name, subscription, topic, timestampLabel, coder, idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but reading from the + * given subscription. + * + *

    See {@link PubsubIO.PubsubSubscription#fromPath(String)} for more details on the format + * of the {@code subscription} string. + * + *

    Multiple readers reading from the same subscription will each receive + * some arbitrary portion of the data. Most likely, separate readers should + * use their own subscriptions. + * + *

    Does not modify this object. + */ + public Bound subscription(String subscription) { + return new Bound<>(name, PubsubSubscription.fromPath(subscription), topic, timestampLabel, + coder, idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but that reads from the specified topic. + * + *

    See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the + * format of the {@code topic} string. + * + *

    Does not modify this object. + */ + public Bound topic(String topic) { + return new Bound<>(name, subscription, PubsubTopic.fromPath(topic), timestampLabel, coder, + idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but that reads message timestamps + * from the given message attribute. See {@link PubsubIO.Read#timestampLabel(String)} for + * more details on the format of the timestamp attribute. + * + *

    Does not modify this object. + */ + public Bound timestampLabel(String timestampLabel) { + return new Bound<>( + name, subscription, topic, timestampLabel, coder, idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but that reads unique message IDs + * from the given message attribute. See {@link PubsubIO.Read#idLabel(String)} for more + * details on the format of the ID attribute. + * + *

    Does not modify this object. + */ + public Bound idLabel(String idLabel) { + return new Bound<>( + name, subscription, topic, timestampLabel, coder, idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but that uses the given + * {@link Coder} to decode each record into a value of type {@code X}. + * + *

    Does not modify this object. + * + * @param the type of the decoded elements, and the + * elements of the resulting PCollection. + */ + public Bound withCoder(Coder coder) { + return new Bound<>( + name, subscription, topic, timestampLabel, coder, idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but will only read up to the specified + * maximum number of records from Cloud Pub/Sub. The transform produces a bounded + * {@link PCollection}. See {@link PubsubIO.Read#maxNumRecords(int)} for more details. + */ + public Bound maxNumRecords(int maxNumRecords) { + return new Bound<>( + name, subscription, topic, timestampLabel, coder, idLabel, maxNumRecords, maxReadTime); + } + + /** + * Returns a transform that's like this one but will only read during the specified + * duration from Cloud Pub/Sub. The transform produces a bounded {@link PCollection}. + * See {@link PubsubIO.Read#maxReadTime(Duration)} for more details. + */ + public Bound maxReadTime(Duration maxReadTime) { + return new Bound<>( + name, subscription, topic, timestampLabel, coder, idLabel, maxNumRecords, maxReadTime); + } + + @Override + public PCollection apply(PInput input) { + if (topic == null && subscription == null) { + throw new IllegalStateException("need to set either the topic or the subscription for " + + "a PubsubIO.Read transform"); + } + if (topic != null && subscription != null) { + throw new IllegalStateException("Can't set both the topic and the subscription for a " + + "PubsubIO.Read transform"); + } + + boolean boundedOutput = getMaxNumRecords() > 0 || getMaxReadTime() != null; + + if (boundedOutput) { + return input.getPipeline().begin() + .apply(Create.of((Void) null)).setCoder(VoidCoder.of()) + .apply(ParDo.of(new PubsubReader())).setCoder(coder); + } else { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED) + .setCoder(coder); + } + } + + @Override + protected Coder getDefaultOutputCoder() { + return coder; + } + + public PubsubTopic getTopic() { + return topic; + } + + public PubsubSubscription getSubscription() { + return subscription; + } + + public String getTimestampLabel() { + return timestampLabel; + } + + public Coder getCoder() { + return coder; + } + + public String getIdLabel() { + return idLabel; + } + + public int getMaxNumRecords() { + return maxNumRecords; + } + + public Duration getMaxReadTime() { + return maxReadTime; + } + + private class PubsubReader extends DoFn { + private static final int DEFAULT_PULL_SIZE = 100; + + @Override + public void processElement(ProcessContext c) throws IOException { + Pubsub pubsubClient = + Transport.newPubsubClient(c.getPipelineOptions().as(DataflowPipelineOptions.class)) + .build(); + + String subscription; + if (getSubscription() == null) { + String topic = getTopic().asPath(); + String[] split = topic.split("/"); + subscription = + "projects/" + split[1] + "/subscriptions/" + split[3] + "_dataflow_" + + new Random().nextLong(); + Subscription subInfo = new Subscription().setAckDeadlineSeconds(60).setTopic(topic); + try { + pubsubClient.projects().subscriptions().create(subscription, subInfo).execute(); + } catch (Exception e) { + throw new RuntimeException("Failed to create subscription: ", e); + } + } else { + subscription = getSubscription().asPath(); + } + + Instant endTime = (getMaxReadTime() == null) + ? 
new Instant(Long.MAX_VALUE) : Instant.now().plus(getMaxReadTime()); + + List messages = new ArrayList<>(); + + Throwable finallyBlockException = null; + try { + while ((getMaxNumRecords() == 0 || messages.size() < getMaxNumRecords()) + && Instant.now().isBefore(endTime)) { + PullRequest pullRequest = new PullRequest().setReturnImmediately(false); + if (getMaxNumRecords() > 0) { + pullRequest.setMaxMessages(getMaxNumRecords() - messages.size()); + } else { + pullRequest.setMaxMessages(DEFAULT_PULL_SIZE); + } + + PullResponse pullResponse = + pubsubClient.projects().subscriptions().pull(subscription, pullRequest).execute(); + List ackIds = new ArrayList<>(); + if (pullResponse.getReceivedMessages() != null) { + for (ReceivedMessage received : pullResponse.getReceivedMessages()) { + messages.add(received.getMessage()); + ackIds.add(received.getAckId()); + } + } + + if (ackIds.size() != 0) { + AcknowledgeRequest ackRequest = new AcknowledgeRequest().setAckIds(ackIds); + pubsubClient.projects() + .subscriptions() + .acknowledge(subscription, ackRequest) + .execute(); + } + } + } catch (IOException e) { + throw new RuntimeException("Unexpected exception while reading from Pubsub: ", e); + } finally { + if (getTopic() != null) { + try { + pubsubClient.projects().subscriptions().delete(subscription).execute(); + } catch (IOException e) { + finallyBlockException = new RuntimeException("Failed to delete subscription: ", e); + LOG.error("Failed to delete subscription: ", e); + } + } + } + if (finallyBlockException != null) { + Throwables.propagate(finallyBlockException); + } + + for (PubsubMessage message : messages) { + c.outputWithTimestamp( + CoderUtils.decodeFromByteArray(getCoder(), message.decodeData()), + assignMessageTimestamp(message, getTimestampLabel(), Clock.SYSTEM)); + } + } + } + } + + /** Disallow construction of utility class. */ + private Read() {} + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** Disallow construction of utility class. */ + private PubsubIO() {} + + /** + * A {@link PTransform} that continuously writes a + * {@link PCollection} of {@link String Strings} to a Cloud Pub/Sub stream. + */ + // TODO: Support non-String encodings. + public static class Write { + /** + * Creates a transform that writes to Pub/Sub with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(DEFAULT_PUBSUB_CODER).named(name); + } + + /** + * Creates a transform that publishes to the specified topic. + * + *

    See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the + * {@code topic} string. + */ + public static Bound topic(String topic) { + return new Bound<>(DEFAULT_PUBSUB_CODER).topic(topic); + } + + /** + * Creates a transform that writes to Pub/Sub, adding each record's timestamp to the published + * messages in an attribute with the specified name. The value of the attribute will be a number + * representing the number of milliseconds since the Unix epoch. For example, if using the Joda + * time classes, {@link Instant#Instant(long)} can be used to parse this value. + * + *
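A brief usage sketch, assuming a {@code PCollection<String> events} and a hypothetical attribute name "ts" and topic path:

events.apply(
    PubsubIO.Write.named("WriteToPubsub")
        .topic("projects/my-project/topics/my-topic")
        .timestampLabel("ts"));
// Each published message carries its element's timestamp, in milliseconds since the epoch,
// in the "ts" attribute; a downstream PubsubIO.Read.timestampLabel("ts") can recover it.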

    If the output from this sink is being read by another Dataflow source, then + * {@link PubsubIO.Read#timestampLabel(String)} can be used to ensure the other source reads + * these timestamps from the appropriate attribute. + */ + public static Bound timestampLabel(String timestampLabel) { + return new Bound<>(DEFAULT_PUBSUB_CODER).timestampLabel(timestampLabel); + } + + /** + * Creates a transform that writes to Pub/Sub, adding each record's unique identifier to the + * published messages in an attribute with the specified name. The value of the attribute is an + * opaque string. + * + *

    If the output from this sink is being read by another Dataflow source, then + * {@link PubsubIO.Read#idLabel(String)} can be used to ensure that the other source reads + * these unique identifiers from the appropriate attribute. + */ + public static Bound idLabel(String idLabel) { + return new Bound<>(DEFAULT_PUBSUB_CODER).idLabel(idLabel); + } + + /** + * Creates a transform that uses the given {@link Coder} to encode each of the + * elements of the input collection into an output message. + * + *

    By default, uses {@link StringUtf8Coder}, which writes input Java strings directly as + * records. + * + * @param the type of the elements of the input PCollection + */ + public static Bound withCoder(Coder coder) { + return new Bound<>(coder); + } + + /** + * A {@link PTransform} that writes an unbounded {@link PCollection} of {@link String Strings} + * to a Cloud Pub/Sub stream. + */ + public static class Bound extends PTransform, PDone> { + /** The Cloud Pub/Sub topic to publish to. */ + @Nullable private final PubsubTopic topic; + /** The name of the message attribute to publish message timestamps in. */ + @Nullable private final String timestampLabel; + /** The name of the message attribute to publish unique message IDs in. */ + @Nullable private final String idLabel; + private final Coder coder; + + private Bound(Coder coder) { + this(null, null, null, null, coder); + } + + private Bound( + String name, PubsubTopic topic, String timestampLabel, String idLabel, Coder coder) { + super(name); + this.topic = topic; + this.timestampLabel = timestampLabel; + this.idLabel = idLabel; + this.coder = coder; + } + + /** + * Returns a new transform that's like this one but with the specified step + * name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, topic, timestampLabel, idLabel, coder); + } + + /** + * Returns a new transform that's like this one but that writes to the specified + * topic. + * + *

    See {@link PubsubIO.PubsubTopic#fromPath(String)} for more details on the format of the + * {@code topic} string. + * + *

    Does not modify this object. + */ + public Bound topic(String topic) { + return new Bound<>(name, PubsubTopic.fromPath(topic), timestampLabel, idLabel, coder); + } + + /** + * Returns a new transform that's like this one but that publishes record timestamps + * to a message attribute with the specified name. See + * {@link PubsubIO.Write#timestampLabel(String)} for more details. + * + *

    Does not modify this object. + */ + public Bound timestampLabel(String timestampLabel) { + return new Bound<>(name, topic, timestampLabel, idLabel, coder); + } + + /** + * Returns a new transform that's like this one but that publishes unique record IDs + * to a message attribute with the specified name. See {@link PubsubIO.Write#idLabel(String)} + * for more details. + * + *

    Does not modify this object. + */ + public Bound idLabel(String idLabel) { + return new Bound<>(name, topic, timestampLabel, idLabel, coder); + } + + /** + * Returns a new transform that's like this one + * but that uses the given {@link Coder} to encode each of + * the elements of the input {@link PCollection} into an + * output record. + * + *

    Does not modify this object. + * + * @param the type of the elements of the input {@link PCollection} + */ + public Bound withCoder(Coder coder) { + return new Bound<>(name, topic, timestampLabel, idLabel, coder); + } + + @Override + public PDone apply(PCollection input) { + if (topic == null) { + throw new IllegalStateException("need to set the topic of a PubsubIO.Write transform"); + } + input.apply(ParDo.of(new PubsubWriter())); + return PDone.in(input.getPipeline()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + public PubsubTopic getTopic() { + return topic; + } + + public String getTimestampLabel() { + return timestampLabel; + } + + public String getIdLabel() { + return idLabel; + } + + public Coder getCoder() { + return coder; + } + + private class PubsubWriter extends DoFn { + private static final int MAX_PUBLISH_BATCH_SIZE = 100; + private transient List output; + private transient Pubsub pubsubClient; + + @Override + public void startBundle(Context c) { + this.output = new ArrayList<>(); + this.pubsubClient = + Transport.newPubsubClient(c.getPipelineOptions().as(DataflowPipelineOptions.class)) + .build(); + } + + @Override + public void processElement(ProcessContext c) throws IOException { + PubsubMessage message = + new PubsubMessage().encodeData(CoderUtils.encodeToByteArray(getCoder(), c.element())); + if (getTimestampLabel() != null) { + Map attributes = message.getAttributes(); + if (attributes == null) { + attributes = new HashMap<>(); + message.setAttributes(attributes); + } + attributes.put(getTimestampLabel(), String.valueOf(c.timestamp().getMillis())); + } + output.add(message); + + if (output.size() >= MAX_PUBLISH_BATCH_SIZE) { + publish(); + } + } + + @Override + public void finishBundle(Context c) throws IOException { + if (!output.isEmpty()) { + publish(); + } + } + + private void publish() throws IOException { + PublishRequest publishRequest = new PublishRequest().setMessages(output); + pubsubClient.projects().topics() + .publish(getTopic().asPath(), publishRequest) + .execute(); + output.clear(); + } + } + } + + /** Disallow construction of utility class. */ + private Write() {} + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Read.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Read.java new file mode 100644 index 000000000000..cde87696bbcc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Read.java @@ -0,0 +1,253 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.util.StringUtils.approximateSimpleName; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PInput; + +import org.joda.time.Duration; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A {@link PTransform} for reading from a {@link Source}. + * + *

    Usage example: + *

    + * Pipeline p = Pipeline.create();
    + * p.apply(Read.from(new MySource().withFoo("foo").withBar("bar"))
    + *             .named("foobar"));
    + * 
    + */ +public class Read { + /** + * Returns a new {@code Read} {@code PTransform} builder with the given name. + */ + public static Builder named(String name) { + return new Builder(name); + } + + /** + * Returns a new {@code Read.Bounded} {@code PTransform} reading from the given + * {@code BoundedSource}. + */ + public static Bounded from(BoundedSource source) { + return new Bounded<>(null, source); + } + + /** + * Returns a new {@code Read.Unbounded} {@code PTransform} reading from the given + * {@code UnboundedSource}. + */ + public static Unbounded from(UnboundedSource source) { + return new Unbounded<>(null, source); + } + + /** + * Helper class for building {@code Read} transforms. + */ + public static class Builder { + private final String name; + + private Builder(String name) { + this.name = name; + } + + /** + * Returns a new {@code Read.Bounded} {@code PTransform} reading from the given + * {@code BoundedSource}. + */ + public Bounded from(BoundedSource source) { + return new Bounded<>(name, source); + } + + /** + * Returns a new {@code Read.Unbounded} {@code PTransform} reading from the given + * {@code UnboundedSource}. + */ + public Unbounded from(UnboundedSource source) { + return new Unbounded<>(name, source); + } + } + + /** + * {@link PTransform} that reads from a {@link BoundedSource}. + */ + public static class Bounded extends PTransform> { + private final BoundedSource source; + + private Bounded(@Nullable String name, BoundedSource source) { + super(name); + this.source = SerializableUtils.ensureSerializable(source); + } + + /** + * Returns a new {@code Bounded} {@code PTransform} that's like this one but + * has the given name. + * + *

    Does not modify this object. + */ + public Bounded named(String name) { + return new Bounded(name, source); + } + + @Override + protected Coder getDefaultOutputCoder() { + return source.getDefaultOutputCoder(); + } + + @Override + public final PCollection apply(PInput input) { + source.validate(); + + return PCollection.createPrimitiveOutputInternal(input.getPipeline(), + WindowingStrategy.globalDefault(), IsBounded.BOUNDED) + .setCoder(getDefaultOutputCoder()); + } + + /** + * Returns the {@code BoundedSource} used to create this {@code Read} {@code PTransform}. + */ + public BoundedSource getSource() { + return source; + } + + @Override + public String getKindString() { + return "Read(" + approximateSimpleName(source.getClass()) + ")"; + } + + static { + registerDefaultTransformEvaluator(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static void registerDefaultTransformEvaluator() { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bounded.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bounded transform, DirectPipelineRunner.EvaluationContext context) { + evaluateReadHelper(transform, context); + } + + private void evaluateReadHelper( + Read.Bounded transform, DirectPipelineRunner.EvaluationContext context) { + try { + List> output = new ArrayList<>(); + BoundedSource source = transform.getSource(); + try (BoundedSource.BoundedReader reader = + source.createReader(context.getPipelineOptions())) { + for (boolean available = reader.start(); + available; + available = reader.advance()) { + output.add( + DirectPipelineRunner.ValueWithMetadata.of( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp()))); + } + } + context.setPCollectionValuesWithMetadata(context.getOutput(transform), output); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + } + } + + /** + * {@link PTransform} that reads from a {@link UnboundedSource}. + */ + public static class Unbounded extends PTransform> { + private final UnboundedSource source; + + private Unbounded(@Nullable String name, UnboundedSource source) { + super(name); + this.source = SerializableUtils.ensureSerializable(source); + } + + /** + * Returns a new {@code Unbounded} {@code PTransform} that's like this one but + * has the given name. + * + *

    Does not modify this object. + */ + public Unbounded named(String name) { + return new Unbounded(name, source); + } + + /** + * Returns a new {@link BoundedReadFromUnboundedSource} that reads a bounded amount + * of data from the given {@link UnboundedSource}. The bound is specified as a number + * of records to read. + * + *
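A brief usage sketch; {@code MyUnboundedSource} is a hypothetical {@code UnboundedSource} of Strings, and a {@code Pipeline p} is assumed:

PCollection<String> sample = p.apply(
    Read.from(new MyUnboundedSource())
        .withMaxNumRecords(1000));  // yields a bounded PCollection of at most 1000 records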

    This may take a long time to execute if the splits of this source are slow to read + * records. + */ + public BoundedReadFromUnboundedSource withMaxNumRecords(long maxNumRecords) { + return new BoundedReadFromUnboundedSource(source, maxNumRecords, null); + } + + /** + * Returns a new {@link BoundedReadFromUnboundedSource} that reads a bounded amount + * of data from the given {@link UnboundedSource}. The bound is specified as an amount + * of time to read for. Each split of the source will read for this much time. + */ + public BoundedReadFromUnboundedSource withMaxReadTime(Duration maxReadTime) { + return new BoundedReadFromUnboundedSource(source, Long.MAX_VALUE, maxReadTime); + } + + @Override + protected Coder getDefaultOutputCoder() { + return source.getDefaultOutputCoder(); + } + + @Override + public final PCollection apply(PInput input) { + source.validate(); + + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED); + } + + /** + * Returns the {@code UnboundedSource} used to create this {@code Read} {@code PTransform}. + */ + public UnboundedSource getSource() { + return source; + } + + @Override + public String getKindString() { + return "Read(" + approximateSimpleName(source.getClass()) + ")"; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/ShardNameTemplate.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/ShardNameTemplate.java new file mode 100644 index 000000000000..727001276809 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/ShardNameTemplate.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +/** + * Standard shard naming templates. + * + *

    Shard naming templates are strings that may contain placeholders for + * the shard number and shard count. When constructing a filename for a + * particular shard number, the upper-case letters 'S' and 'N' are replaced + * with the 0-padded shard number and shard count respectively. + * + *

    Left-padding of the numbers enables lexicographical sorting of the + * resulting filenames. If the shard number or count is too large for the + * space provided in the template, then the result may no longer sort + * lexicographically. For example, a shard template of "S-of-N", for 200 + * shards, will result in outputs named "0-of-200", ..., "10-of-200", + * "100-of-200", etc. + * + *

    Shard numbers start with 0, so the last shard number is the shard count + * minus one. For example, the template "-SSSSS-of-NNNNN" will be + * instantiated as "-00000-of-01000" for the first shard (shard 0) of a + * 1000-way sharded output. + * + *

    A shard name template is typically provided along with a name prefix + * and suffix, which allows constructing complex paths that have embedded + * shard information. For example, outputs in the form + * "gs://bucket/path-01-of-99.txt" could be constructed by providing the + * individual components: + * + *

    {@code
    + *   pipeline.apply(
    + *       TextIO.Write.to("gs://bucket/path")
    + *                   .withShardNameTemplate("-SS-of-NN")
    + *                   .withSuffix(".txt"))
    + * }
    + * + *

    In the example above, you could make parts of the output configurable + * by users without requiring them to specify all components of the output + * name. + * + *

    If a shard name template does not contain any repeating 'S', then + * the output shard count must be 1, as otherwise the same filename would be + * generated for multiple shards. + */ +public class ShardNameTemplate { + /** + * Shard name containing the index and max. + * + *

    Eg: [prefix]-00000-of-00100[suffix] and + * [prefix]-00001-of-00100[suffix] + */ + public static final String INDEX_OF_MAX = "-SSSSS-of-NNNNN"; + + /** + * Shard is a file within a directory. + * + *

    Eg: [prefix]/part-00000[suffix] and [prefix]/part-00001[suffix] + */ + public static final String DIRECTORY_CONTAINER = "/part-SSSSS"; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Sink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Sink.java new file mode 100644 index 000000000000..a5649ceb5e9e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Sink.java @@ -0,0 +1,252 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.io.Serializable; + +/** + * A {@code Sink} represents a resource that can be written to using the {@link Write} transform. + * + *

    A parallel write to a {@code Sink} consists of three phases: + *

      + *
    1. A sequential initialization phase (e.g., creating a temporary output directory, etc.) + *
    2. A parallel write phase where workers write bundles of records + *
    3. A sequential finalization phase (e.g., committing the writes, merging output files, + * etc.) + *
    + * + *

    The {@link Write} transform can be used in a Dataflow pipeline to perform this write. + * Specifically, a Write transform can be applied to a {@link PCollection} {@code p} by: + * + *

    {@code p.apply(Write.to(new MySink()));} + * + *

    Implementing a {@link Sink} and the corresponding write operations requires extending three + * abstract classes (a minimal illustrative skeleton follows the list below): + * + *

      + *
    • {@link Sink}: an immutable logical description of the location/resource to write to. + * Depending on the type of sink, it may contain fields such as the path to an output directory + * on a filesystem, a database table name, etc. Implementors of {@link Sink} must + * implement two methods: {@link Sink#validate} and {@link Sink#createWriteOperation}. + * {@link Sink#validate Validate} is called by the Write transform at pipeline creation, and should + * validate that the Sink can be written to. The createWriteOperation method is also called at + * pipeline creation, and should return a WriteOperation object that defines how to write to the + * Sink. Note that implementations of Sink must be serializable and Sinks must be immutable. + * + *
    • {@link WriteOperation}: The WriteOperation implements the initialization and + * finalization phases of a write. Implementors of {@link WriteOperation} must implement + * corresponding {@link WriteOperation#initialize} and {@link WriteOperation#finalize} methods. A + * WriteOperation must also implement {@link WriteOperation#createWriter} that creates Writers, + * {@link WriteOperation#getWriterResultCoder} that returns a {@link Coder} for the result of a + * parallel write, and a {@link WriteOperation#getSink} that returns the Sink that the write + * operation corresponds to. See below for more information about these methods and restrictions on + * their implementation. + * + *
    • {@link Writer}: A Writer writes a bundle of records. Writer defines four methods: + * {@link Writer#open}, which is called once at the start of writing a bundle; {@link Writer#write}, + * which writes a single record from the bundle; {@link Writer#close}, which is called once at the + * end of writing a bundle; and {@link Writer#getWriteOperation}, which returns the write operation + * that the writer belongs to. + *
    + * + *
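A minimal sketch of these three classes, for a hypothetical sink that writes each bundle of strings to its own file under an output directory; the class names, the local-filesystem layout, and the choice of a String writer result are illustrative assumptions, not part of this change:

import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.io.Sink;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;

import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;

/** Immutable, serializable description of the destination. */
class MyFileSink extends Sink<String> {
  final String outputDir;

  MyFileSink(String outputDir) { this.outputDir = outputDir; }

  @Override
  public void validate(PipelineOptions options) {
    if (outputDir == null || outputDir.isEmpty()) {
      throw new IllegalArgumentException("outputDir must be set");
    }
  }

  @Override
  public Sink.WriteOperation<String, ?> createWriteOperation(PipelineOptions options) {
    return new MyWriteOperation(this);
  }
}

/** Serializable; its writer result type is the name of the file a bundle wrote. */
class MyWriteOperation extends Sink.WriteOperation<String, String> {
  final MyFileSink sink;

  MyWriteOperation(MyFileSink sink) { this.sink = sink; }

  @Override
  public void initialize(PipelineOptions options) {
    new File(sink.outputDir).mkdirs();  // idempotent: harmless if called more than once
  }

  @Override
  public void finalize(Iterable<String> writerResults, PipelineOptions options) {
    // writerResults holds the file names returned by successful bundles; a fuller
    // finalize sketch appears further below.
  }

  @Override
  public Sink.Writer<String, String> createWriter(PipelineOptions options) {
    return new MyWriter(this);  // must not mutate this WriteOperation
  }

  @Override
  public Sink<String> getSink() { return sink; }

  @Override
  public Coder<String> getWriterResultCoder() { return StringUtf8Coder.of(); }
}

/** Writes one bundle; the bundle id keeps retried or redundant bundles from colliding. */
class MyWriter extends Sink.Writer<String, String> {
  private final MyWriteOperation op;
  private PrintWriter out;
  private String fileName;

  MyWriter(MyWriteOperation op) { this.op = op; }

  @Override
  public void open(String uId) throws Exception {
    fileName = op.sink.outputDir + "/" + uId + ".txt";
    out = new PrintWriter(new FileWriter(fileName));
  }

  @Override
  public void write(String value) { out.println(value); }

  @Override
  public String close() {
    out.close();
    return fileName;  // the writer result later handed to finalize
  }

  @Override
  public Sink.WriteOperation<String, String> getWriteOperation() { return op; }
}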

    WriteOperation

    + *

    {@link WriteOperation#initialize} and {@link WriteOperation#finalize} are conceptually called + * once: at the beginning and end of a Write transform. However, implementors must ensure that these + * methods are idempotent, as they may be called multiple times on different machines in the case of + * failure/retry or for redundancy. + * + *

    The finalize method of WriteOperation is passed an Iterable of a writer result type. This + * writer result type should encode the result of a write and, in most cases, some encoding of the + * unique bundle id. + * + *

    All implementations of {@link WriteOperation} must be serializable. + * + *

WriteOperation may have mutable state. For instance, {@link WriteOperation#initialize} may + * mutate the object state. These mutations will be visible in {@link WriteOperation#createWriter} + * and {@link WriteOperation#finalize} because the object will be serialized after initialize and + * deserialized before these calls. However, it is not serialized again after createWriter is + * called, as createWriter will be called within workers to create Writers for the bundles that are + * distributed to these workers. Therefore, createWriter should not mutate the WriteOperation state (as + * these mutations will not be visible in finalize). + *

    Bundle Ids:

    + *

    In order to ensure fault-tolerance, a bundle may be executed multiple times (e.g., in the + * event of failure/retry or for redundancy). However, exactly one of these executions will have its + * result passed to the WriteOperation's finalize method. Each call to {@link Writer#open} is passed + * a unique bundle id when it is called by the Write transform, so even redundant or retried + * bundles will have a unique way of identifying their output. + * + *

    The bundle id should be used to guarantee that a bundle's output is unique. This uniqueness + * guarantee is important; if a bundle is to be output to a file, for example, the name of the file + * must be unique to avoid conflicts with other Writers. The bundle id should be encoded in the + * writer result returned by the Writer and subsequently used by the WriteOperation's finalize + * method to identify the results of successful writes. + * + *

    For example, consider the scenario where a Writer writes files containing serialized records + * and the WriteOperation's finalization step is to merge or rename these output files. In this + * case, a Writer may use its unique id to name its output file (to avoid conflicts) and return the + * name of the file it wrote as its writer result. The WriteOperation will then receive an Iterable + * of output file names that it can then merge or rename using some bundle naming scheme. + * + *
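One hedged sketch of such a finalize step, written as a method inside a WriteOperation subclass; it assumes each Writer wrote to a temporary file ending in ".tmp" and returned that path as its writer result, and that java.nio.file.{Files, Path, Paths, StandardCopyOption} are imported:

@Override
public void finalize(Iterable<String> writerResults, PipelineOptions options) throws Exception {
  for (String tempFileName : writerResults) {
    Path temp = Paths.get(tempFileName);
    // Strip the assumed ".tmp" suffix to obtain the final name.
    Path dest = Paths.get(tempFileName.replaceAll("\\.tmp$", ""));
    if (Files.exists(temp)) {  // keeps finalize idempotent if it is retried
      Files.move(temp, dest, StandardCopyOption.REPLACE_EXISTING);
    }
  }
  // Temporary files not named in writerResults came from failed or superseded bundle
  // attempts whose results were never reported; they can be cleaned up here.
}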

    Writer Results:

    + *

    {@link WriteOperation}s and {@link Writer}s must agree on a writer result type that will be + * returned by a Writer after it writes a bundle. This type can be a client-defined object or an + * existing type; {@link WriteOperation#getWriterResultCoder} should return a {@link Coder} for the + * type. + * + *

    A note about thread safety: Any use of static members or methods in Writer should be thread + * safe, as different instances of Writer objects may be created in different threads on the same + * worker. + * + * @param the type that will be written to the Sink. + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public abstract class Sink implements Serializable { + /** + * Ensures that the sink is valid and can be written to before the write operation begins. One + * should use {@link com.google.common.base.Preconditions} to implement this method. + */ + public abstract void validate(PipelineOptions options); + + /** + * Returns an instance of a {@link WriteOperation} that can write to this Sink. + */ + public abstract WriteOperation createWriteOperation(PipelineOptions options); + + /** + * A {@link WriteOperation} defines the process of a parallel write of objects to a Sink. + * + *

    The {@code WriteOperation} defines how to perform initialization and finalization of a + * parallel write to a sink as well as how to create a {@link Sink.Writer} object that can write + * a bundle to the sink. + * + *

    Since operations in Dataflow may be run multiple times for redundancy or fault-tolerance, + * the initialization and finalization defined by a WriteOperation must be idempotent. + * + *

{@code WriteOperation}s may be mutable; a {@code WriteOperation} is serialized after the + * call to the {@code initialize} method and deserialized before calls to + * {@code createWriter} and {@code finalize}. However, it is not + * reserialized after {@code createWriter}, so {@code createWriter} should not mutate the + * state of the {@code WriteOperation}. + *

    See {@link Sink} for more detailed documentation about the process of writing to a Sink. + * + * @param The type of objects to write + * @param The result of a per-bundle write + */ + public abstract static class WriteOperation implements Serializable { + /** + * Performs initialization before writing to the sink. Called before writing begins. + */ + public abstract void initialize(PipelineOptions options) throws Exception; + + /** + * Given an Iterable of results from bundle writes, performs finalization after writing and + * closes the sink. Called after all bundle writes are complete. + * + *

    The results that are passed to finalize are those returned by bundles that completed + * successfully. Although bundles may have been run multiple times (for fault-tolerance), only + * one writer result will be passed to finalize for each bundle. An implementation of finalize + * should perform clean up of any failed and successfully retried bundles. Note that these + * failed bundles will not have their writer result passed to finalize, so finalize should be + * capable of locating any temporary/partial output written by failed bundles. + * + *

    A best practice is to make finalize atomic. If this is impossible given the semantics + * of the sink, finalize should be idempotent, as it may be called multiple times in the case of + * failure/retry or for redundancy. + * + *

    Note that the iteration order of the writer results is not guaranteed to be consistent if + * finalize is called multiple times. + * + * @param writerResults an Iterable of results from successful bundle writes. + */ + public abstract void finalize(Iterable writerResults, PipelineOptions options) + throws Exception; + + /** + * Creates a new {@link Sink.Writer} to write a bundle of the input to the sink. + * + *

    The bundle id that the writer will use to uniquely identify its output will be passed to + * {@link Writer#open}. + * + *

    Must not mutate the state of the WriteOperation. + */ + public abstract Writer createWriter(PipelineOptions options) throws Exception; + + /** + * Returns the Sink that this write operation writes to. + */ + public abstract Sink getSink(); + + /** + * Returns a coder for the writer result type. + */ + public Coder getWriterResultCoder() { + return null; + } + } + + /** + * A Writer writes a bundle of elements from a PCollection to a sink. {@link Writer#open} is + * called before writing begins and {@link Writer#close} is called after all elements in the + * bundle have been written. {@link Writer#write} writes an element to the sink. + * + *

    Note that any access to static members or methods of a Writer must be thread-safe, as + * multiple instances of a Writer may be instantiated in different threads on the same worker. + * + *

    See {@link Sink} for more detailed documentation about the process of writing to a Sink. + * + * @param The type of object to write + * @param The writer results type (e.g., the bundle's output filename, as String) + */ + public abstract static class Writer { + /** + * Performs bundle initialization. For example, creates a temporary file for writing or + * initializes any state that will be used across calls to {@link Writer#write}. + * + *

    The unique id that is given to open should be used to ensure that the writer's output does + * not interfere with the output of other Writers, as a bundle may be executed many times for + * fault tolerance. See {@link Sink} for more information about bundle ids. + */ + public abstract void open(String uId) throws Exception; + + /** + * Called for each value in the bundle. + */ + public abstract void write(T value) throws Exception; + + /** + * Finishes writing the bundle. Closes any resources used for writing the bundle. + * + *

    Returns a writer result that will be used in the {@link Sink.WriteOperation}'s + * finalization. The result should contain some way to identify the output of this bundle (using + * the bundle id). {@link WriteOperation#finalize} will use the writer result to identify + * successful writes. See {@link Sink} for more information about bundle ids. + * + * @return the writer result + */ + public abstract WriteT close() throws Exception; + + /** + * Returns the write operation this writer belongs to. + */ + public abstract WriteOperation getWriteOperation(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Source.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Source.java new file mode 100644 index 000000000000..4a020787f5c2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Source.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.Serializable; +import java.util.NoSuchElementException; + +/** + * Base class for defining input formats and creating a {@code Source} for reading the input. + * + *

    This class is not intended to be subclassed directly. Instead, to define + * a bounded source (a source which produces a finite amount of input), subclass + * {@link BoundedSource}; to define an unbounded source, subclass {@link UnboundedSource}. + * + *

    A {@code Source} passed to a {@code Read} transform must be + * {@code Serializable}. This allows the {@code Source} instance + * created in this "main program" to be sent (in serialized form) to + * remote worker machines and reconstituted for each batch of elements + * of the input {@code PCollection} being processed or for each source splitting + * operation. A {@code Source} can have instance variable state, and + * non-transient instance variable state will be serialized in the main program + * and then deserialized on remote worker machines. + * + *

    {@code Source} classes MUST be effectively immutable. The only acceptable use of + * mutable fields is to cache the results of expensive operations, and such fields MUST be + * marked {@code transient}. + * + *
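For instance, a hedged sketch of that field discipline; the class and its size computation are hypothetical, and a real source would normally extend BoundedSource or UnboundedSource rather than Source directly (Source is used here only to keep the sketch short):

import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.io.Source;

class MyLinesSource extends Source<String> {
  private final String filepattern;              // immutable configuration
  private transient Long cachedTotalSizeBytes;   // cache of an expensive computation only

  MyLinesSource(String filepattern) { this.filepattern = filepattern; }

  long getTotalSizeBytes() {
    if (cachedTotalSizeBytes == null) {
      cachedTotalSizeBytes = expensivelyComputeSize();  // recomputed after deserialization
    }
    return cachedTotalSizeBytes;
  }

  private long expensivelyComputeSize() { return 0L; }  // stand-in for a real lookup

  @Override
  public void validate() {
    if (filepattern == null || filepattern.isEmpty()) {
      throw new IllegalArgumentException("filepattern must be set");
    }
  }

  @Override
  public Coder<String> getDefaultOutputCoder() { return StringUtf8Coder.of(); }

  @Override
  public String toString() { return "MyLinesSource(" + filepattern + ")"; }  // used in error messages
}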

    {@code Source} objects should override {@link Object#toString}, as it will be + * used in important error and debugging messages. + * + * @param Type of elements read by the source. + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public abstract class Source implements Serializable { + /** + * Checks that this source is valid, before it can be used in a pipeline. + * + *

    It is recommended to use {@link com.google.common.base.Preconditions} for implementing + * this method. + */ + public abstract void validate(); + + /** + * Returns the default {@code Coder} to use for the data read from this source. + */ + public abstract Coder getDefaultOutputCoder(); + + /** + * The interface that readers of custom input sources must implement. + * + *

This interface is deliberately distinct from {@link java.util.Iterator} because + * the current model tends to be easier to program and more efficient in practice + * for iterating over sources such as files, databases, etc. (rather than pure collections). + *

    Reading data from the {@link Reader} must obey the following access pattern: + *

      + *
    • One call to {@link #start} + *
      • If {@link #start} returned true, any number of calls to {@code getCurrent}* + * methods
      + *
    • Repeatedly, a call to {@link #advance}. This may be called regardless + * of what the previous {@link #start}/{@link #advance} returned. + *
      • If {@link #advance} returned true, any number of calls to {@code getCurrent}* + * methods
      + *
    + * + *

    For example, if the reader is reading a fixed set of data: + *

    +   *   try {
    +   *     for (boolean available = reader.start(); available; available = reader.advance()) {
    +   *       T item = reader.getCurrent();
    +   *       Instant timestamp = reader.getCurrentTimestamp();
    +   *       ...
    +   *     }
    +   *   } finally {
    +   *     reader.close();
    +   *   }
    +   * 
    + * + *

    If the set of data being read is continually growing: + *

    +   *   try {
    +   *     boolean available = reader.start();
    +   *     while (true) {
    +   *       if (available) {
    +   *         T item = reader.getCurrent();
    +   *         Instant timestamp = reader.getCurrentTimestamp();
    +   *         ...
    +   *         resetExponentialBackoff();
    +   *       } else {
    +   *         exponentialBackoff();
    +   *       }
    +   *       available = reader.advance();
    +   *     }
    +   *   } finally {
    +   *     reader.close();
    +   *   }
    +   * 
    + * + *

    Note: this interface is a work-in-progress and may change. + * + *

    All {@code Reader} functions except {@link #getCurrentSource} do not need to be thread-safe; + * they may only be accessed by a single thread at once. However, {@link #getCurrentSource} needs + * to be thread-safe, and other functions should assume that its returned value can change + * asynchronously. + */ + public abstract static class Reader implements AutoCloseable { + /** + * Initializes the reader and advances the reader to the first record. + * + *

    This method should be called exactly once. The invocation should occur prior to calling + * {@link #advance} or {@link #getCurrent}. This method may perform expensive operations that + * are needed to initialize the reader. + * + * @return {@code true} if a record was read, {@code false} if there is no more input available. + */ + public abstract boolean start() throws IOException; + + /** + * Advances the reader to the next valid record. + * + *

    It is an error to call this without having called {@link #start} first. + * + * @return {@code true} if a record was read, {@code false} if there is no more input available. + */ + public abstract boolean advance() throws IOException; + + /** + * Returns the value of the data item that was read by the last {@link #start} or + * {@link #advance} call. The returned value must be effectively immutable and remain valid + * indefinitely. + * + *

    Multiple calls to this method without an intervening call to {@link #advance} should + * return the same result. + * + * @throws java.util.NoSuchElementException if {@link #start} was never called, or if + * the last {@link #start} or {@link #advance} returned {@code false}. + */ + public abstract T getCurrent() throws NoSuchElementException; + + /** + * Returns the timestamp associated with the current data item. + * + *

    If the source does not support timestamps, this should return + * {@code BoundedWindow.TIMESTAMP_MIN_VALUE}. + * + *

    Multiple calls to this method without an intervening call to {@link #advance} should + * return the same result. + * + * @throws NoSuchElementException if the reader is at the beginning of the input and + * {@link #start} or {@link #advance} wasn't called, or if the last {@link #start} or + * {@link #advance} returned {@code false}. + */ + public abstract Instant getCurrentTimestamp() throws NoSuchElementException; + + /** + * Closes the reader. The reader cannot be used after this method is called. + */ + @Override + public abstract void close() throws IOException; + + /** + * Returns a {@code Source} describing the same input that this {@code Reader} currently reads + * (including items already read). + * + *

    Usually, an implementation will simply return the immutable {@link Source} object from + * which the current {@link Reader} was constructed, or delegate to the base class. + * However, when using or implementing this method on a {@link BoundedSource.BoundedReader}, + * special considerations apply, see documentation for + * {@link BoundedSource.BoundedReader#getCurrentSource}. + */ + public abstract Source getCurrentSource(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java new file mode 100644 index 000000000000..d342f250b2e8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java @@ -0,0 +1,992 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.Read.Bounded; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MimeTypes; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; +import java.util.NoSuchElementException; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * {@link PTransform}s for reading and writing text files. + * + *

To read a {@link PCollection} from one or more text files, use {@link TextIO.Read}. + * You can instantiate a transform using {@link TextIO.Read#from(String)} to specify + * the path of the file(s) to read from (e.g., a local filename or + * filename pattern if running locally, or a Google Cloud Storage + * filename or filename pattern of the form + * {@code "gs://<bucket>/<filepath>"}). You may optionally call + * {@link TextIO.Read#named(String)} to specify the name of the pipeline step. + *

    By default, {@link TextIO.Read} returns a {@link PCollection} of {@link String Strings}, + * each corresponding to one line of an input UTF-8 text file. To convert directly from the raw + * bytes (split into lines delimited by '\n', '\r', or '\r\n') to another object of type {@code T}, + * supply a {@code Coder} using {@link TextIO.Read#withCoder(Coder)}. + * + *

    See the following examples: + * + *

    {@code
    + * Pipeline p = ...;
    + *
    + * // A simple Read of a local file (only runs locally):
+ * PCollection<String> lines =
    + *     p.apply(TextIO.Read.from("/local/path/to/file.txt"));
    + *
    + * // A fully-specified Read from a GCS file (runs locally and via the
    + * // Google Cloud Dataflow service):
+ * PCollection<Integer> numbers =
    + *     p.apply(TextIO.Read.named("ReadNumbers")
    + *                        .from("gs://my_bucket/path/to/numbers-*.txt")
    + *                        .withCoder(TextualIntegerCoder.of()));
    + * }
    + * + *

To write a {@link PCollection} to one or more text files, use + * {@link TextIO.Write}, calling {@link TextIO.Write#to(String)} to specify + * the path of the file to write to (e.g., a local filename or sharded + * filename pattern if running locally, or a Google Cloud Storage + * filename or sharded filename pattern of the form + * {@code "gs://<bucket>/<filepath>"}). You can optionally name the resulting transform using + * {@link TextIO.Write#named(String)}, and you can use {@link TextIO.Write#withCoder(Coder)} + * to specify the Coder to use to encode the Java values into text lines. + *

    Any existing files with the same names as generated output files + * will be overwritten. + * + *

    For example: + *

    {@code
    + * // A simple Write to a local file (only runs locally):
+ * PCollection<String> lines = ...;
    + * lines.apply(TextIO.Write.to("/path/to/file.txt"));
    + *
    + * // A fully-specified Write to a sharded GCS file (runs locally and via the
    + * // Google Cloud Dataflow service):
+ * PCollection<Integer> numbers = ...;
    + * numbers.apply(TextIO.Write.named("WriteNumbers")
    + *                           .to("gs://my_bucket/path/to/numbers")
    + *                           .withSuffix(".txt")
    + *                           .withCoder(TextualIntegerCoder.of()));
    + * }
    + * + *

    Permissions

    + *

    When run using the {@link DirectPipelineRunner}, your pipeline can read and write text files + * on your local drive and remote text files on Google Cloud Storage that you have access to using + * your {@code gcloud} credentials. When running in the Dataflow service using + * {@link DataflowPipelineRunner}, the pipeline can only read and write files from GCS. For more + * information about permissions, see the Cloud Dataflow documentation on + * Security and + * Permissions. + */ +public class TextIO { + /** The default coder, which returns each line of the input file as a string. */ + public static final Coder DEFAULT_TEXT_CODER = StringUtf8Coder.of(); + + /** + * A {@link PTransform} that reads from a text file (or multiple text + * files matching a pattern) and returns a {@link PCollection} containing + * the decoding of each of the lines of the text file(s). The + * default decoding just returns each line as a {@link String}, but you may call + * {@link #withCoder(Coder)} to change the return type. + */ + public static class Read { + /** + * Returns a transform for reading text files that uses the given step name. + */ + public static Bound named(String name) { + return new Bound<>(DEFAULT_TEXT_CODER).named(name); + } + + /** + * Returns a transform for reading text files that reads from the file(s) + * with the given filename or filename pattern. This can be a local path (if running locally), + * or a Google Cloud Storage filename or filename pattern of the form + * {@code "gs:///"} (if running locally or via the Google Cloud Dataflow + * service). Standard Java Filesystem glob patterns ("*", "?", "[..]") are supported. + */ + public static Bound from(String filepattern) { + return new Bound<>(DEFAULT_TEXT_CODER).from(filepattern); + } + + /** + * Returns a transform for reading text files that uses the given + * {@code Coder} to decode each of the lines of the file into a + * value of type {@code T}. + * + *

    By default, uses {@link StringUtf8Coder}, which just + * returns the text lines as Java strings. + * + * @param the type of the decoded elements, and the elements + * of the resulting PCollection + */ + public static Bound withCoder(Coder coder) { + return new Bound<>(coder); + } + + /** + * Returns a transform for reading text files that has GCS path validation on + * pipeline creation disabled. + * + *

    This can be useful in the case where the GCS input does not + * exist at the pipeline creation time, but is expected to be + * available at execution time. + */ + public static Bound withoutValidation() { + return new Bound<>(DEFAULT_TEXT_CODER).withoutValidation(); + } + + /** + * Returns a transform for reading text files that decompresses all input files + * using the specified compression type. + * + *

    If no compression type is specified, the default is {@link TextIO.CompressionType#AUTO}. + * In this mode, the compression type of the file is determined by its extension + * (e.g., {@code *.gz} is gzipped, {@code *.bz2} is bzipped, and all other extensions are + * uncompressed). + */ + public static Bound withCompressionType(TextIO.CompressionType compressionType) { + return new Bound<>(DEFAULT_TEXT_CODER).withCompressionType(compressionType); + } + + // TODO: strippingNewlines, etc. + + /** + * A {@link PTransform} that reads from one or more text files and returns a bounded + * {@link PCollection} containing one element for each line of the input files. + * + * @param the type of each of the elements of the resulting + * {@link PCollection}. By default, each line is returned as a {@link String}, however you + * may use {@link #withCoder(Coder)} to supply a {@code Coder} to produce a + * {@code PCollection} instead. + */ + public static class Bound extends PTransform> { + /** The filepattern to read from. */ + @Nullable private final String filepattern; + + /** The Coder to use to decode each line. */ + private final Coder coder; + + /** An option to indicate if input validation is desired. Default is true. */ + private final boolean validate; + + /** Option to indicate the input source's compression type. Default is AUTO. */ + private final TextIO.CompressionType compressionType; + + Bound(Coder coder) { + this(null, null, coder, true, TextIO.CompressionType.AUTO); + } + + private Bound(String name, String filepattern, Coder coder, boolean validate, + TextIO.CompressionType compressionType) { + super(name); + this.coder = coder; + this.filepattern = filepattern; + this.validate = validate; + this.compressionType = compressionType; + } + + /** + * Returns a new transform for reading from text files that's like this one but + * with the given step name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filepattern, coder, validate, compressionType); + } + + /** + * Returns a new transform for reading from text files that's like this one but + * that reads from the file(s) with the given name or pattern. See {@link TextIO.Read#from} + * for a description of filepatterns. + * + *

    Does not modify this object. + + */ + public Bound from(String filepattern) { + return new Bound<>(name, filepattern, coder, validate, compressionType); + } + + /** + * Returns a new transform for reading from text files that's like this one but + * that uses the given {@link Coder Coder} to decode each of the + * lines of the file into a value of type {@code X}. + * + *

    Does not modify this object. + * + * @param the type of the decoded elements, and the + * elements of the resulting PCollection + */ + public Bound withCoder(Coder coder) { + return new Bound<>(name, filepattern, coder, validate, compressionType); + } + + /** + * Returns a new transform for reading from text files that's like this one but + * that has GCS path validation on pipeline creation disabled. + * + *

    This can be useful in the case where the GCS input does not + * exist at the pipeline creation time, but is expected to be + * available at execution time. + * + *

    Does not modify this object. + */ + public Bound withoutValidation() { + return new Bound<>(name, filepattern, coder, false, compressionType); + } + + /** + * Returns a new transform for reading from text files that's like this one but + * reads from input sources using the specified compression type. + * + *

    If no compression type is specified, the default is {@link TextIO.CompressionType#AUTO}. + * See {@link TextIO.Read#withCompressionType} for more details. + * + *

    Does not modify this object. + */ + public Bound withCompressionType(TextIO.CompressionType compressionType) { + return new Bound<>(name, filepattern, coder, validate, compressionType); + } + + @Override + public PCollection apply(PInput input) { + if (filepattern == null) { + throw new IllegalStateException("need to set the filepattern of a TextIO.Read transform"); + } + + if (validate) { + try { + checkState( + !IOChannelUtils.getFactory(filepattern).match(filepattern).isEmpty(), + "Unable to find any files matching %s", + filepattern); + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to validate %s", filepattern), e); + } + } + + // Create a source specific to the requested compression type. + final Bounded read; + switch(compressionType) { + case UNCOMPRESSED: + read = com.google.cloud.dataflow.sdk.io.Read.from( + new TextSource(filepattern, coder)); + break; + case AUTO: + read = com.google.cloud.dataflow.sdk.io.Read.from( + CompressedSource.from(new TextSource(filepattern, coder))); + break; + case BZIP2: + read = com.google.cloud.dataflow.sdk.io.Read.from( + CompressedSource.from(new TextSource(filepattern, coder)) + .withDecompression(CompressedSource.CompressionMode.BZIP2)); + break; + case GZIP: + read = com.google.cloud.dataflow.sdk.io.Read.from( + CompressedSource.from(new TextSource(filepattern, coder)) + .withDecompression(CompressedSource.CompressionMode.GZIP)); + break; + default: + throw new IllegalArgumentException("Unknown compression mode: " + compressionType); + } + + PCollection pcol = input.getPipeline().apply("Read", read); + // Honor the default output coder that would have been used by this PTransform. + pcol.setCoder(getDefaultOutputCoder()); + return pcol; + } + + @Override + protected Coder getDefaultOutputCoder() { + return coder; + } + + public String getFilepattern() { + return filepattern; + } + + public boolean needsValidation() { + return validate; + } + + public TextIO.CompressionType getCompressionType() { + return compressionType; + } + } + + /** Disallow construction of utility classes. */ + private Read() {} + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@link PTransform} that writes a {@link PCollection} to text file (or + * multiple text files matching a sharding pattern), with each + * element of the input collection encoded into its own line. + */ + public static class Write { + /** + * Returns a transform for writing to text files with the given step name. + */ + public static Bound named(String name) { + return new Bound<>(DEFAULT_TEXT_CODER).named(name); + } + + /** + * Returns a transform for writing to text files that writes to the file(s) + * with the given prefix. This can be a local filename + * (if running locally), or a Google Cloud Storage filename of + * the form {@code "gs:///"} + * (if running locally or via the Google Cloud Dataflow service). + * + *

    The files written will begin with this prefix, followed by + * a shard identifier (see {@link Bound#withNumShards(int)}, and end + * in a common extension, if given by {@link Bound#withSuffix(String)}. + */ + public static Bound to(String prefix) { + return new Bound<>(DEFAULT_TEXT_CODER).to(prefix); + } + + /** + * Returns a transform for writing to text files that appends the specified suffix + * to the created files. + */ + public static Bound withSuffix(String nameExtension) { + return new Bound<>(DEFAULT_TEXT_CODER).withSuffix(nameExtension); + } + + /** + * Returns a transform for writing to text files that uses the provided shard count. + * + *

    Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + */ + public static Bound withNumShards(int numShards) { + return new Bound<>(DEFAULT_TEXT_CODER).withNumShards(numShards); + } + + /** + * Returns a transform for writing to text files that uses the given shard name + * template. + * + *

    See {@link ShardNameTemplate} for a description of shard templates. + */ + public static Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(DEFAULT_TEXT_CODER).withShardNameTemplate(shardTemplate); + } + + /** + * Returns a transform for writing to text files that forces a single file as + * output. + */ + public static Bound withoutSharding() { + return new Bound<>(DEFAULT_TEXT_CODER).withoutSharding(); + } + + /** + * Returns a transform for writing to text files that uses the given + * {@link Coder} to encode each of the elements of the input + * {@link PCollection} into an output text line. + * + *

    By default, uses {@link StringUtf8Coder}, which writes input + * Java strings directly as output lines. + * + * @param the type of the elements of the input {@link PCollection} + */ + public static Bound withCoder(Coder coder) { + return new Bound<>(coder); + } + + /** + * Returns a transform for writing to text files that has GCS path validation on + * pipeline creation disabled. + * + *

    This can be useful in the case where the GCS output location does + * not exist at the pipeline creation time, but is expected to be available + * at execution time. + */ + public static Bound withoutValidation() { + return new Bound<>(DEFAULT_TEXT_CODER).withoutValidation(); + } + + // TODO: appendingNewlines, header, footer, etc. + + /** + * A PTransform that writes a bounded PCollection to a text file (or + * multiple text files matching a sharding pattern), with each + * PCollection element being encoded into its own line. + * + * @param the type of the elements of the input PCollection + */ + public static class Bound extends PTransform, PDone> { + /** The prefix of each file written, combined with suffix and shardTemplate. */ + @Nullable private final String filenamePrefix; + /** The suffix of each file written, combined with prefix and shardTemplate. */ + private final String filenameSuffix; + + /** The Coder to use to decode each line. */ + private final Coder coder; + + /** Requested number of shards. 0 for automatic. */ + private final int numShards; + + /** The shard template of each file written, combined with prefix and suffix. */ + private final String shardTemplate; + + /** An option to indicate if output validation is desired. Default is true. */ + private final boolean validate; + + Bound(Coder coder) { + this(null, null, "", coder, 0, ShardNameTemplate.INDEX_OF_MAX, true); + } + + private Bound(String name, String filenamePrefix, String filenameSuffix, Coder coder, + int numShards, String shardTemplate, boolean validate) { + super(name); + this.coder = coder; + this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + this.numShards = numShards; + this.shardTemplate = shardTemplate; + this.validate = validate; + } + + /** + * Returns a transform for writing to text files that's like this one but + * with the given step name. + * + *

    Does not modify this object. + */ + public Bound named(String name) { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate, validate); + } + + /** + * Returns a transform for writing to text files that's like this one but + * that writes to the file(s) with the given filename prefix. + * + *

    See {@link TextIO.Write#to(String) Write.to(String)} for more information. + * + *

Does not modify this object. + */ + public Bound to(String filenamePrefix) { + validateOutputComponent(filenamePrefix); + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate, validate); + } + + /** + * Returns a transform for writing to text files that's like this one but + * that writes to the file(s) with the given filename suffix. + * + *

    Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withSuffix(String nameExtension) { + validateOutputComponent(nameExtension); + return new Bound<>(name, filenamePrefix, nameExtension, coder, numShards, + shardTemplate, validate); + } + + /** + * Returns a transform for writing to text files that's like this one but + * that uses the provided shard count. + * + *

    Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + *

    Does not modify this object. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + * @see ShardNameTemplate + */ + public Bound withNumShards(int numShards) { + Preconditions.checkArgument(numShards >= 0); + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate, validate); + } + + /** + * Returns a transform for writing to text files that's like this one but + * that uses the given shard name template. + * + *

    Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withShardNameTemplate(String shardTemplate) { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate, validate); + } + + /** + * Returns a transform for writing to text files that's like this one but + * that forces a single file as output. + * + *

    Constraining the number of shards is likely to reduce + * the performance of a pipeline. Using this setting is not recommended + * unless you truly require a single output file. + * + *

    This is a shortcut for + * {@code .withNumShards(1).withShardNameTemplate("")} + * + *

    Does not modify this object. + */ + public Bound withoutSharding() { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, 1, "", validate); + } + + /** + * Returns a transform for writing to text files that's like this one + * but that uses the given {@link Coder Coder} to encode each of + * the elements of the input {@link PCollection PCollection} into an + * output text line. Does not modify this object. + * + * @param the type of the elements of the input {@link PCollection} + */ + public Bound withCoder(Coder coder) { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate, validate); + } + + /** + * Returns a transform for writing to text files that's like this one but + * that has GCS output path validation on pipeline creation disabled. + * + *

    This can be useful in the case where the GCS output location does + * not exist at the pipeline creation time, but is expected to be + * available at execution time. + * + *

    Does not modify this object. + */ + public Bound withoutValidation() { + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, + shardTemplate, false); + } + + @Override + public PDone apply(PCollection input) { + if (filenamePrefix == null) { + throw new IllegalStateException( + "need to set the filename prefix of a TextIO.Write transform"); + } + + // Note that custom sinks currently do not expose sharding controls. + // Thus pipeline runner writers need to individually add support internally to + // apply user requested sharding limits. + return input.apply("Write", com.google.cloud.dataflow.sdk.io.Write.to( + new TextSink<>( + filenamePrefix, filenameSuffix, shardTemplate, coder))); + } + + /** + * Returns the current shard name template string. + */ + public String getShardNameTemplate() { + return shardTemplate; + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + + public String getFilenamePrefix() { + return filenamePrefix; + } + + public String getShardTemplate() { + return shardTemplate; + } + + public int getNumShards() { + return numShards; + } + + public String getFilenameSuffix() { + return filenameSuffix; + } + + public Coder getCoder() { + return coder; + } + + public boolean needsValidation() { + return validate; + } + } + } + + /** + * Possible text file compression types. + */ + public static enum CompressionType { + /** + * Automatically determine the compression type based on filename extension. + */ + AUTO(""), + /** + * Uncompressed (i.e., may be split). + */ + UNCOMPRESSED(""), + /** + * GZipped. + */ + GZIP(".gz"), + /** + * BZipped. + */ + BZIP2(".bz2"); + + private String filenameSuffix; + + private CompressionType(String suffix) { + this.filenameSuffix = suffix; + } + + /** + * Determine if a given filename matches a compression type based on its extension. + * @param filename the filename to match + * @return true iff the filename ends with the compression type's known extension. + */ + public boolean matches(String filename) { + return filename.toLowerCase().endsWith(filenameSuffix.toLowerCase()); + } + } + + // Pattern which matches old-style shard output patterns, which are now + // disallowed. + private static final Pattern SHARD_OUTPUT_PATTERN = Pattern.compile("@([0-9]+|\\*)"); + + private static void validateOutputComponent(String partialFilePattern) { + Preconditions.checkArgument( + !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(), + "Output name components are not allowed to contain @* or @N patterns: " + + partialFilePattern); + } + + ////////////////////////////////////////////////////////////////////////////// + + /** Disable construction of utility class. */ + private TextIO() {} + + /** + * A {@link FileBasedSource} which can decode records delimited by new line characters. + * + *

    This source splits the data into records using {@code UTF-8} {@code \n}, {@code \r}, or + * {@code \r\n} as the delimiter. This source is not strict and supports decoding the last record + * even if it is not delimited. Finally, no records are decoded if the stream is empty. + * + *

    This source supports reading from any arbitrary byte position within the stream. If the + * starting position is not {@code 0}, then bytes are skipped until the first delimiter is found + * representing the beginning of the first record to be decoded. + */ + @VisibleForTesting + static class TextSource extends FileBasedSource { + /** The Coder to use to decode each line. */ + private final Coder coder; + + @VisibleForTesting + TextSource(String fileSpec, Coder coder) { + super(fileSpec, 1L); + this.coder = coder; + } + + private TextSource(String fileName, long start, long end, Coder coder) { + super(fileName, 1L, start, end); + this.coder = coder; + } + + @Override + protected FileBasedSource createForSubrangeOfFile(String fileName, long start, long end) { + return new TextSource<>(fileName, start, end, coder); + } + + @Override + protected FileBasedReader createSingleFileReader(PipelineOptions options) { + return new TextBasedReader<>(this); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSource.FileBasedReader FileBasedReader} + * which can decode records delimited by new line characters. + * + * See {@link TextSource} for further details. + */ + @VisibleForTesting + static class TextBasedReader extends FileBasedReader { + private static final int READ_BUFFER_SIZE = 8192; + private final Coder coder; + private final ByteBuffer readBuffer = ByteBuffer.allocate(READ_BUFFER_SIZE); + private ByteString buffer; + private int startOfSeparatorInBuffer; + private int endOfSeparatorInBuffer; + private long startOfNextRecord; + private boolean eof; + private boolean elementIsPresent; + private T currentValue; + private ReadableByteChannel inChannel; + + private TextBasedReader(TextSource source) { + super(source); + coder = source.coder; + buffer = ByteString.EMPTY; + } + + @Override + protected long getCurrentOffset() throws NoSuchElementException { + if (!elementIsPresent) { + throw new NoSuchElementException(); + } + return startOfNextRecord; + } + + @Override + public T getCurrent() throws NoSuchElementException { + if (!elementIsPresent) { + throw new NoSuchElementException(); + } + return currentValue; + } + + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + this.inChannel = channel; + // If the first offset is greater than zero, we need to skip bytes until we see our + // first separator. + if (getCurrentSource().getStartOffset() > 0) { + checkState(channel instanceof SeekableByteChannel, + "%s only supports reading from a SeekableByteChannel when given a start offset" + + " greater than 0.", TextSource.class.getSimpleName()); + long requiredPosition = getCurrentSource().getStartOffset() - 1; + ((SeekableByteChannel) channel).position(requiredPosition); + findSeparatorBounds(); + buffer = buffer.substring(endOfSeparatorInBuffer); + startOfNextRecord = requiredPosition + endOfSeparatorInBuffer; + endOfSeparatorInBuffer = 0; + startOfSeparatorInBuffer = 0; + } + } + + /** + * Locates the start position and end position of the next delimiter. Will + * consume the channel till either EOF or the delimiter bounds are found. + * + *

    This fills the buffer and updates the positions as follows: + *

    {@code
    +       * ------------------------------------------------------
    +       * | element bytes | delimiter bytes | unconsumed bytes |
    +       * ------------------------------------------------------
    +       * 0            start of          end of              buffer
    +       *              separator         separator           size
    +       *              in buffer         in buffer
    +       * }
    + */ + private void findSeparatorBounds() throws IOException { + int bytePositionInBuffer = 0; + while (true) { + if (!tryToEnsureNumberOfBytesInBuffer(bytePositionInBuffer + 1)) { + startOfSeparatorInBuffer = endOfSeparatorInBuffer = bytePositionInBuffer; + break; + } + + byte currentByte = buffer.byteAt(bytePositionInBuffer); + + if (currentByte == '\n') { + startOfSeparatorInBuffer = bytePositionInBuffer; + endOfSeparatorInBuffer = startOfSeparatorInBuffer + 1; + break; + } else if (currentByte == '\r') { + startOfSeparatorInBuffer = bytePositionInBuffer; + endOfSeparatorInBuffer = startOfSeparatorInBuffer + 1; + + if (tryToEnsureNumberOfBytesInBuffer(bytePositionInBuffer + 2)) { + currentByte = buffer.byteAt(bytePositionInBuffer + 1); + if (currentByte == '\n') { + endOfSeparatorInBuffer += 1; + } + } + break; + } + + // Move to the next byte in buffer. + bytePositionInBuffer += 1; + } + } + + @Override + protected boolean readNextRecord() throws IOException { + startOfNextRecord += endOfSeparatorInBuffer; + findSeparatorBounds(); + + // If we have reached EOF file and consumed all of the buffer then we know + // that there are no more records. + if (eof && buffer.size() == 0) { + elementIsPresent = false; + return false; + } + + decodeCurrentElement(); + return true; + } + + /** + * Decodes the current element updating the buffer to only contain the unconsumed bytes. + * + * This invalidates the currently stored {@code startOfSeparatorInBuffer} and + * {@code endOfSeparatorInBuffer}. + */ + private void decodeCurrentElement() throws IOException { + ByteString dataToDecode = buffer.substring(0, startOfSeparatorInBuffer); + currentValue = coder.decode(dataToDecode.newInput(), Context.OUTER); + elementIsPresent = true; + buffer = buffer.substring(endOfSeparatorInBuffer); + } + + /** + * Returns false if we were unable to ensure the minimum capacity by consuming the channel. + */ + private boolean tryToEnsureNumberOfBytesInBuffer(int minCapacity) throws IOException { + // While we aren't at EOF or haven't fulfilled the minimum buffer capacity, + // attempt to read more bytes. + while (buffer.size() <= minCapacity && !eof) { + eof = inChannel.read(readBuffer) == -1; + readBuffer.flip(); + buffer = buffer.concat(ByteString.copyFrom(readBuffer)); + readBuffer.clear(); + } + // Return true if we were able to honor the minimum buffer capacity request + return buffer.size() >= minCapacity; + } + } + } + + /** + * A {@link FileBasedSink} for text files. Produces text files with the new line separator + * {@code '\n'} represented in {@code UTF-8} format as the record separator. + * Each record (including the last) is terminated. + */ + @VisibleForTesting + static class TextSink extends FileBasedSink { + private final Coder coder; + + @VisibleForTesting + TextSink( + String baseOutputFilename, String extension, String fileNameTemplate, Coder coder) { + super(baseOutputFilename, extension, fileNameTemplate); + this.coder = coder; + } + + @Override + public FileBasedSink.FileBasedWriteOperation createWriteOperation(PipelineOptions options) { + return new TextWriteOperation<>(this, coder); + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriteOperation + * FileBasedWriteOperation} for text files. 
+ */ + private static class TextWriteOperation extends FileBasedWriteOperation { + private final Coder coder; + + private TextWriteOperation(TextSink sink, Coder coder) { + super(sink); + this.coder = coder; + } + + @Override + public FileBasedWriter createWriter(PipelineOptions options) throws Exception { + return new TextWriter<>(this, coder); + } + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriter FileBasedWriter} + * for text files. + */ + private static class TextWriter extends FileBasedWriter { + private static final byte[] NEWLINE = "\n".getBytes(StandardCharsets.UTF_8); + private final Coder coder; + private OutputStream out; + + public TextWriter(FileBasedWriteOperation writeOperation, Coder coder) { + super(writeOperation); + this.mimeType = MimeTypes.TEXT; + this.coder = coder; + } + + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + out = Channels.newOutputStream(channel); + } + + @Override + public void write(T value) throws Exception { + coder.encode(value, out, Context.OUTER); + out.write(NEWLINE); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/UnboundedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/UnboundedSource.java new file mode 100644 index 000000000000..e585151c892a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/UnboundedSource.java @@ -0,0 +1,253 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * A {@link Source} that reads an unbounded amount of input and, because of that, supports + * some additional operations such as checkpointing, watermarks, and record ids. + * + *
      + *
    • Checkpointing allows sources to not re-read the same data again in the case of failures. + *
    • Watermarks allow for downstream parts of the pipeline to know up to what point + * in time the data is complete. + *
    • Record ids allow for efficient deduplication of input records; many streaming sources + * do not guarantee that a given record will only be read a single time. + *
    + * + *

    See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} and + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.Trigger} for more information on + * timestamps and watermarks. + * + * @param Type of records output by this source. + * @param Type of checkpoint marks used by the readers of this source. + */ +public abstract class UnboundedSource< + OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> extends Source { + /** + * Returns a list of {@code UnboundedSource} objects representing the instances of this source + * that should be used when executing the workflow. Each split should return a separate partition + * of the input data. + * + *

    For example, for a source reading from a growing directory of files, each split + * could correspond to a prefix of file names. + * + *

    Some sources are not splittable, such as reading from a single TCP stream. In that + * case, only a single split should be returned. + * + *

    Some data sources automatically partition their data among readers. For these types of + * inputs, {@code n} identical replicas of the top-level source can be returned. + * + *
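As a hedged sketch of that replica pattern, inside a hypothetical {@code MyQueueSource} extending {@code UnboundedSource<String, MyCheckpointMark>} whose backend assigns records to whichever reader pulls them (requires java.util.Collections and java.util.List):

@Override
public List<MyQueueSource> generateInitialSplits(int desiredNumSplits, PipelineOptions options) {
  // The backend load-balances across readers, so identical replicas are sufficient;
  // desiredNumSplits is treated as a hint for how much parallelism to expose.
  return Collections.nCopies(desiredNumSplits, this);
}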

    The size of the returned list should be as close to {@code desiredNumSplits} + * as possible, but does not have to match exactly. A low number of splits + * will limit the amount of parallelism in the source. + */ + public abstract List> generateInitialSplits( + int desiredNumSplits, PipelineOptions options) throws Exception; + + /** + * Create a new {@link UnboundedReader} to read from this source, resuming from the given + * checkpoint if present. + */ + public abstract UnboundedReader createReader( + PipelineOptions options, @Nullable CheckpointMarkT checkpointMark); + + /** + * Returns a {@link Coder} for encoding and decoding the checkpoints for this source, or + * null if the checkpoints do not need to be durably committed. + */ + @Nullable + public abstract Coder getCheckpointMarkCoder(); + + /** + * Returns whether this source requires explicit deduping. + * + *

This is needed if the underlying data source can return the same record multiple times, + * such as a queuing system with a pull-ack model. Sources where the records read are uniquely + * identified by the persisted state in the CheckpointMark do not need this. + */ + public boolean requiresDeduping() { + return false; + } + + /** + * A marker representing the progress and state of an + * {@link com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader}. + * + *

    For example, this could be offsets in a set of files being read. + */ + public interface CheckpointMark { + /** + * Perform any finalization that needs to happen after a bundle of data read from + * the source has been processed and committed. + * + *

    For example, this could be sending acknowledgement requests to an external + * data source such as Pub/Sub. + * + *
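    A rough sketch of such a mark for an acknowledgement-based system; AckClient and its acknowledge method are hypothetical placeholders, not an SDK or Pub/Sub API (java.io.IOException, java.io.Serializable, and java.util.List assumed imported):

        class AckingCheckpointMark implements UnboundedSource.CheckpointMark, Serializable {
          private final List<String> ackIds;            // ids read in the committed bundle
          private final transient AckClient ackClient;  // hypothetical client handle

          AckingCheckpointMark(List<String> ackIds, AckClient ackClient) {
            this.ackIds = ackIds;
            this.ackClient = ackClient;
          }

          @Override
          public void finalizeCheckpoint() throws IOException {
            // Acknowledge only after the runner has durably committed the bundle.
            ackClient.acknowledge(ackIds);
          }
        }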

    This may be called from any thread, potentially at the same time as calls to the + * {@code UnboundedReader} that created it. + */ + void finalizeCheckpoint() throws IOException; + } + + /** + * A {@code Reader} that reads an unbounded amount of input. + * + *

    A given {@code UnboundedReader} object will only be accessed by a single thread at once. + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + public abstract static class UnboundedReader extends Source.Reader { + private static final byte[] EMPTY = new byte[0]; + + /** + * Initializes the reader and advances the reader to the first record. + * + *

    This method should be called exactly once. The invocation should occur prior to calling + * {@link #advance} or {@link #getCurrent}. This method may perform expensive operations that + * are needed to initialize the reader. + * + *

    Returns {@code true} if a record was read, {@code false} if there is no more input + * currently available. Future calls to {@link #advance} may return {@code true} once more data + * is available. Regardless of the return value of {@code start}, {@code start} will not be + * called again on the same {@code UnboundedReader} object; it will only be called again when a + * new reader object is constructed for the same source, e.g. on recovery. + */ + @Override + public abstract boolean start() throws IOException; + + /** + * Advances the reader to the next valid record. + * + *
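    A simplified sketch of how a runner might drive this contract (exception handling, checkpointing, and watermark reporting are omitted, and process(...) is a placeholder):

        UnboundedSource.UnboundedReader<MyRecord> reader = source.createReader(options, null);
        boolean available = reader.start();
        while (true) {
          if (available) {
            process(reader.getCurrent(), reader.getCurrentTimestamp());
          } else {
            // No record is currently available; a real runner would checkpoint here
            // and call advance() again later.
            break;
          }
          available = reader.advance();
        }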

    Returns {@code true} if a record was read, {@code false} if there is no more input + * available. Future calls to {@link #advance} may return {@code true} once more data is + * available. + */ + @Override + public abstract boolean advance() throws IOException; + + /** + * Returns a unique identifier for the current record. This should be the same for each + * instance of the same logical record read from the underlying data source. + * + *

    It is only necessary to override this if {@link #requiresDeduping} has been overridden to + * return true. + * + *

    For example, this could be a hash of the record contents, or a logical ID present in + * the record. If this is generated as a hash of the record contents, it should be at least 16 + * bytes (128 bits) to avoid collisions. + * + *
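    One possible sketch, assuming the reader keeps the current record's encoded bytes in a field (currentRecordBytes is hypothetical); Hashing is Guava's com.google.common.hash.Hashing, which the SDK already depends on:

        @Override
        public byte[] getCurrentRecordId() throws NoSuchElementException {
          // A 256-bit content hash comfortably exceeds the recommended 128-bit minimum.
          return Hashing.sha256().hashBytes(currentRecordBytes).asBytes();
        }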

    This method has the same restrictions on when it can be called as {@link #getCurrent} and + * {@link #getCurrentTimestamp}. + * + * @throws NoSuchElementException if the reader is at the beginning of the input and + * {@link #start} or {@link #advance} wasn't called, or if the last {@link #start} or + * {@link #advance} returned {@code false}. + */ + public byte[] getCurrentRecordId() throws NoSuchElementException { + if (getCurrentSource().requiresDeduping()) { + throw new IllegalStateException( + "getCurrentRecordId() must be overridden if requiresDeduping returns true()"); + } + return EMPTY; + } + + /** + * Returns a timestamp before or at the timestamps of all future elements read by this reader. + * + *

    This can be approximate. If records are read that violate this guarantee, they will be + * considered late, which will affect how they will be processed. See + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} for more information on + * late data and how to handle it. + * + *

    However, this value should be as late as possible. Downstream windows may not be able + * to close until this watermark passes their end. + * + *

    For example, a source may know that the records it reads will be in timestamp order. In + * this case, the watermark can be the timestamp of the last record read. For a + * source that does not have natural timestamps, timestamps can be set to the time of + * reading, in which case the watermark is the current clock time. + * + *
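    A rough sketch for a source whose records arrive in timestamp order (lastRecordTimestamp is a hypothetical field updated by advance(); BoundedWindow is the SDK's com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow):

        @Override
        public Instant getWatermark() {
          // Before any record has been read, report the minimum timestamp so nothing is late.
          return lastRecordTimestamp == null
              ? BoundedWindow.TIMESTAMP_MIN_VALUE
              : lastRecordTimestamp;
        }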

    See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} and + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.Trigger} for more + * information on timestamps and watermarks. + * + *

    May be called after {@link #advance} or {@link #start} has returned false, but not before + * {@link #start} has been called. + */ + public abstract Instant getWatermark(); + + /** + * Returns a {@link CheckpointMark} representing the progress of this {@code UnboundedReader}. + * + *

    The elements read up until this is called will be processed together as a bundle. Once + * the result of this processing has been durably committed, + * {@link CheckpointMark#finalizeCheckpoint} will be called on the {@link CheckpointMark} + * object. + * + *

    The returned object should not be modified. + * + *

    May be called after {@link #advance} or {@link #start} has returned false, but not before + * {@link #start} has been called. + */ + public abstract CheckpointMark getCheckpointMark(); + + /** + * Constant representing an unknown amount of backlog. + */ + public static final long BACKLOG_UNKNOWN = -1L; + + /** + * Returns the size of the backlog of unread data in the underlying data source represented by + * this split of this source. + * + *

    One of this or {@link #getTotalBacklogBytes} should be overridden in order to allow the + * runner to scale the amount of resources allocated to the pipeline. + */ + public long getSplitBacklogBytes() { + return BACKLOG_UNKNOWN; + } + + /** + * Returns the size of the backlog of unread data in the underlying data source represented by + * all splits of this source. + * + *
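    A rough sketch for a source whose backing system exposes a per-partition backlog metric (queueClient and partition are hypothetical placeholders):

        @Override
        public long getSplitBacklogBytes() {
          long backlogBytes = queueClient.remainingBytes(partition);
          // Fall back to BACKLOG_UNKNOWN when the metric is unavailable.
          return backlogBytes >= 0 ? backlogBytes : BACKLOG_UNKNOWN;
        }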

    One of this or {@link #getSplitBacklogBytes} should be overridden in order to allow the + * runner to scale the amount of resources allocated to the pipeline. + */ + public long getTotalBacklogBytes() { + return BACKLOG_UNKNOWN; + } + + /** + * Returns the {@link UnboundedSource} that created this reader. This will not change over the + * life of the reader. + */ + @Override + public abstract UnboundedSource getCurrentSource(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Write.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Write.java new file mode 100644 index 000000000000..0b78b8384ea6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/Write.java @@ -0,0 +1,213 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.io.Sink.WriteOperation; +import com.google.cloud.dataflow.sdk.io.Sink.Writer; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; + +import org.joda.time.Instant; + +import java.util.UUID; + +/** + * A {@link PTransform} that writes to a {@link Sink}. A write begins with a sequential global + * initialization of a sink, followed by a parallel write, and ends with a sequential finalization + * of the write. The output of a write is {@link PDone}. In the case of an empty PCollection, only + * the global initialization and finalization will be performed. + * + *

    Currently, only batch workflows can contain Write transforms. + * + *

    Example usage: + * + *

    {@code p.apply(Write.to(new MySink(...)));} + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public class Write { + /** + * Creates a Write transform that writes to the given Sink. + */ + public static Bound to(Sink sink) { + return new Bound<>(sink); + } + + /** + * A {@link PTransform} that writes to a {@link Sink}. See {@link Write} and {@link Sink} for + * documentation about writing to Sinks. + */ + public static class Bound extends PTransform, PDone> { + private final Sink sink; + + private Bound(Sink sink) { + this.sink = sink; + } + + @Override + public PDone apply(PCollection input) { + PipelineOptions options = input.getPipeline().getOptions(); + sink.validate(options); + return createWrite(input, sink.createWriteOperation(options)); + } + + /** + * Returns the {@link Sink} associated with this PTransform. + */ + public Sink getSink() { + return sink; + } + + /** + * A write is performed as sequence of three {@link ParDo}'s. + * + *

    In the first, a do-once ParDo is applied to a singleton PCollection containing the Sink's + * {@link WriteOperation}. In this initialization ParDo, {@link WriteOperation#initialize} is + * called. The output of this ParDo is a singleton PCollection + * containing the WriteOperation. + * + *

    This singleton collection containing the WriteOperation is then used as a side input to a + * ParDo over the PCollection of elements to write. In this bundle-writing phase, + * {@link WriteOperation#createWriter} is called to obtain a {@link Writer}. + * {@link Writer#open} and {@link Writer#close} are called in {@link DoFn#startBundle} and + * {@link DoFn#finishBundle}, respectively, and the {@link Writer#write} method is called for every + * element in the bundle. The output of this ParDo is a PCollection of writer result + * objects (see {@link Sink} for a description of writer results), one for each bundle. + + *

    The final do-once ParDo uses the singleton collection of the WriteOperation as input and + * the collection of writer results as a side-input. In this ParDo, + * {@link WriteOperation#finalize} is called to finalize the write. + * + *

    If the write of any element in the PCollection fails, {@link Writer#close} will be called + * before the exception that caused the write to fail is propagated and the write result will be + * discarded. + * + *

    Since the {@link WriteOperation} is serialized after the initialization ParDo and + * deserialized in the bundle-writing and finalization phases, any state change to the + * WriteOperation object that occurs during initialization is visible in the latter phases. + * However, the WriteOperation is not serialized after the bundle-writing phase. This is why + * implementations should guarantee that {@link WriteOperation#createWriter} does not mutate + * WriteOperation). + */ + private PDone createWrite( + PCollection input, WriteOperation writeOperation) { + Pipeline p = input.getPipeline(); + + // A coder to use for the WriteOperation. + @SuppressWarnings("unchecked") + Coder> operationCoder = + (Coder>) SerializableCoder.of(writeOperation.getClass()); + + // A singleton collection of the WriteOperation, to be used as input to a ParDo to initialize + // the sink. + PCollection> operationCollection = + p.apply(Create.>of(writeOperation).withCoder(operationCoder)); + + // Initialize the resource in a do-once ParDo on the WriteOperation. + operationCollection = operationCollection + .apply("Initialize", ParDo.of( + new DoFn, WriteOperation>() { + @Override + public void processElement(ProcessContext c) throws Exception { + WriteOperation writeOperation = c.element(); + writeOperation.initialize(c.getPipelineOptions()); + // The WriteOperation is also the output of this ParDo, so it can have mutable + // state. + c.output(writeOperation); + } + })) + .setCoder(operationCoder); + + // Create a view of the WriteOperation to be used as a sideInput to the parallel write phase. + final PCollectionView> writeOperationView = + operationCollection.apply(View.>asSingleton()); + + // Perform the per-bundle writes as a ParDo on the input PCollection (with the WriteOperation + // as a side input) and collect the results of the writes in a PCollection. + // There is a dependency between this ParDo and the first (the WriteOperation PCollection + // as a side input), so this will happen after the initial ParDo. + PCollection results = input + .apply("WriteBundles", ParDo.of(new DoFn() { + // Writer that will write the records in this bundle. Lazily + // initialized in processElement. + private Writer writer = null; + + @Override + public void processElement(ProcessContext c) throws Exception { + // Lazily initialize the Writer + if (writer == null) { + WriteOperation writeOperation = c.sideInput(writeOperationView); + writer = writeOperation.createWriter(c.getPipelineOptions()); + writer.open(UUID.randomUUID().toString()); + } + try { + writer.write(c.element()); + } catch (Exception e) { + // Discard write result and close the write. + try { + writer.close(); + } catch (Exception closeException) { + // Do not mask the exception that caused the write to fail. + } + throw e; + } + } + + @Override + public void finishBundle(Context c) throws Exception { + if (writer != null) { + WriteT result = writer.close(); + // Output the result of the write. + c.outputWithTimestamp(result, Instant.now()); + } + } + }).withSideInputs(writeOperationView)) + .setCoder(writeOperation.getWriterResultCoder()) + .apply(Window.into(new GlobalWindows())); + + final PCollectionView> resultsView = + results.apply(View.asIterable()); + + // Finalize the write in another do-once ParDo on the singleton collection containing the + // Writer. The results from the per-bundle writes are given as an Iterable side input. + // The WriteOperation's state is the same as after its initialization in the first do-once + // ParDo. 
There is a dependency between this ParDo and the parallel write (the writer results + // collection as a side input), so it will happen after the parallel write. + @SuppressWarnings("unused") + final PCollection done = operationCollection + .apply("Finalize", ParDo.of(new DoFn, Integer>() { + @Override + public void processElement(ProcessContext c) throws Exception { + Iterable results = c.sideInput(resultsView); + WriteOperation writeOperation = c.element(); + writeOperation.finalize(results, c.getPipelineOptions()); + } + }).withSideInputs(resultsView)); + return PDone.in(input.getPipeline()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/XmlSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/XmlSink.java new file mode 100644 index 000000000000..b728c0a792f1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/XmlSink.java @@ -0,0 +1,310 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriteOperation; +import com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriter; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Preconditions; + +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import javax.xml.bind.JAXBContext; +import javax.xml.bind.JAXBException; +import javax.xml.bind.Marshaller; + +// CHECKSTYLE.OFF: JavadocStyle +/** + * A {@link Sink} that outputs records as XML-formatted elements. Writes a {@link PCollection} of + * records from JAXB-annotated classes to a single file location. + * + *

    Given a PCollection containing records of type T that can be marshalled to XML elements, this + * Sink will produce a single file consisting of a single root element that contains all of the + * elements in the PCollection. + * + *

    XML Sinks are created with a base filename to write to, a root element name that will be used + * for the root element of the output files, and a class to bind to an XML element. This class + * will be used in the marshalling of records in an input PCollection to their XML representation + * and must be able to be bound using JAXB annotations (checked at pipeline construction time). + * + *

    XML Sinks can be written to using the {@link Write} transform: + * + *

    + * p.apply(Write.to(
    + *      XmlSink.write()
    + *          .ofRecordClass(Type.class)
    + *          .withRootElement(root_element)
    + *          .toFilenamePrefix(output_filename)));
    + * 
    + * + *

    For example, consider the following class with JAXB annotations: + * + *

    + *  {@literal @}XmlRootElement(name = "word_count_result")
    + *  {@literal @}XmlType(propOrder = {"word", "frequency"})
    + *  public class WordFrequency {
    + *    private String word;
    + *    private long frequency;
    + *
    + *    public WordFrequency() { }
    + *
    + *    public WordFrequency(String word, long frequency) {
    + *      this.word = word;
    + *      this.frequency = frequency;
    + *    }
    + *
    + *    public void setWord(String word) {
    + *      this.word = word;
    + *    }
    + *
    + *    public void setFrequency(long frequency) {
    + *      this.frequency = frequency;
    + *    }
    + *
    + *    public long getFrequency() {
    + *      return frequency;
    + *    }
    + *
    + *    public String getWord() {
    + *      return word;
    + *    }
    + *  }
    + * 
    + * + *

    The following will produce XML output with a root element named "words" from a PCollection of + * WordFrequency objects: + *

    + * p.apply(Write.to(
    + *  XmlSink.write()
    + *      .ofRecordClass(WordFrequency.class)
    + *      .withRootElement("words")
    + *      .toFilenamePrefix(output_file)));
    + * 
    + * + *

    The output of which will look like: + *

    + * {@code
    + * <words>
    + *
    + *  <word_count_result>
    + *    <word>decreased</word>
    + *    <frequency>1</frequency>
    + *  </word_count_result>
    + *
    + *  <word_count_result>
    + *    <word>War</word>
    + *    <frequency>4</frequency>
    + *  </word_count_result>
    + *
    + *  <word_count_result>
    + *    <word>empress'</word>
    + *    <frequency>14</frequency>
    + *  </word_count_result>
    + *
    + *  <word_count_result>
    + *    <word>stoops</word>
    + *    <frequency>6</frequency>
    + *  </word_count_result>
    + *
    + *  ...
    + * </words>
    + * }
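    The same configuration can also be expressed in one step with the writeOf convenience method defined below (output_file is a placeholder, as above):

        p.apply(Write.to(XmlSink.writeOf(WordFrequency.class, "words", output_file)));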
    + */ +// CHECKSTYLE.ON: JavadocStyle +@SuppressWarnings("checkstyle:javadocstyle") +public class XmlSink { + protected static final String XML_EXTENSION = "xml"; + + /** + * Returns a builder for an XmlSink. You'll need to configure the class to bind, the root + * element name, and the output file prefix with {@link Bound#ofRecordClass}, {@link + * Bound#withRootElement}, and {@link Bound#toFilenamePrefix}, respectively. + */ + public static Bound write() { + return new Bound<>(null, null, null); + } + + /** + * Returns an XmlSink that writes objects as XML entities. + * + *

    Output files will have the name {@literal {baseOutputFilename}-0000i-of-0000n.xml} where n + * is the number of output bundles that the Dataflow service divides the output into. + * + * @param klass the class of the elements to write. + * @param rootElementName the enclosing root element. + * @param baseOutputFilename the output filename prefix. + */ + public static Bound writeOf( + Class klass, String rootElementName, String baseOutputFilename) { + return new Bound<>(klass, rootElementName, baseOutputFilename); + } + + /** + * A {@link FileBasedSink} that writes objects as XML elements. + */ + public static class Bound extends FileBasedSink { + final Class classToBind; + final String rootElementName; + + private Bound(Class classToBind, String rootElementName, String baseOutputFilename) { + super(baseOutputFilename, XML_EXTENSION); + this.classToBind = classToBind; + this.rootElementName = rootElementName; + } + + /** + * Returns an XmlSink that writes objects of the class specified as XML elements. + * + *

    The specified class must be able to be used to create a JAXB context. + */ + public Bound ofRecordClass(Class classToBind) { + return new Bound<>(classToBind, rootElementName, baseOutputFilename); + } + + /** + * Returns an XmlSink that writes to files with the given prefix. + * + *

    Output files will have the name {@literal {filenamePrefix}-0000i-of-0000n.xml} where n is + * the number of output bundles that the Dataflow service divides the output into. + */ + public Bound toFilenamePrefix(String baseOutputFilename) { + return new Bound<>(classToBind, rootElementName, baseOutputFilename); + } + + /** + * Returns an XmlSink that writes XML files with an enclosing root element of the + * supplied name. + */ + public Bound withRootElement(String rootElementName) { + return new Bound<>(classToBind, rootElementName, baseOutputFilename); + } + + /** + * Validates that the root element, class to bind to a JAXB context, and filenamePrefix have + * been set and that the class can be bound in a JAXB context. + */ + @Override + public void validate(PipelineOptions options) { + Preconditions.checkNotNull(classToBind, "Missing a class to bind to a JAXB context."); + Preconditions.checkNotNull(rootElementName, "Missing a root element name."); + Preconditions.checkNotNull(baseOutputFilename, "Missing a filename to write to."); + try { + JAXBContext.newInstance(classToBind); + } catch (JAXBException e) { + throw new RuntimeException("Error binding classes to a JAXB Context.", e); + } + } + + /** + * Creates an {@link XmlWriteOperation}. + */ + @Override + public XmlWriteOperation createWriteOperation(PipelineOptions options) { + return new XmlWriteOperation<>(this); + } + } + + /** + * {@link Sink.WriteOperation} for XML {@link Sink}s. + */ + protected static final class XmlWriteOperation extends FileBasedWriteOperation { + public XmlWriteOperation(XmlSink.Bound sink) { + super(sink); + } + + /** + * Creates a {@link XmlWriter} with a marshaller for the type it will write. + */ + @Override + public XmlWriter createWriter(PipelineOptions options) throws Exception { + JAXBContext context; + Marshaller marshaller; + context = JAXBContext.newInstance(getSink().classToBind); + marshaller = context.createMarshaller(); + marshaller.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.TRUE); + marshaller.setProperty(Marshaller.JAXB_FRAGMENT, Boolean.TRUE); + marshaller.setProperty(Marshaller.JAXB_ENCODING, "UTF-8"); + return new XmlWriter<>(this, marshaller); + } + + /** + * Return the XmlSink.Bound for this write operation. + */ + @Override + public XmlSink.Bound getSink() { + return (XmlSink.Bound) super.getSink(); + } + } + + /** + * A {@link Sink.Writer} that can write objects as XML elements. + */ + protected static final class XmlWriter extends FileBasedWriter { + final Marshaller marshaller; + private OutputStream os = null; + + public XmlWriter(XmlWriteOperation writeOperation, Marshaller marshaller) { + super(writeOperation); + this.marshaller = marshaller; + } + + /** + * Creates the output stream that elements will be written to. + */ + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + os = Channels.newOutputStream(channel); + } + + /** + * Writes the root element opening tag. + */ + @Override + protected void writeHeader() throws Exception { + String rootElementName = getWriteOperation().getSink().rootElementName; + os.write(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "<" + rootElementName + ">\n")); + } + + /** + * Writes the root element closing tag. + */ + @Override + protected void writeFooter() throws Exception { + String rootElementName = getWriteOperation().getSink().rootElementName; + os.write(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "\n")); + } + + /** + * Writes a value to the stream. 
+ */ + @Override + public void write(T value) throws Exception { + marshaller.marshal(value, os); + } + + /** + * Return the XmlWriteOperation this write belongs to. + */ + @Override + public XmlWriteOperation getWriteOperation() { + return (XmlWriteOperation) super.getWriteOperation(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/XmlSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/XmlSource.java new file mode 100644 index 000000000000..1ead39187d61 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/XmlSource.java @@ -0,0 +1,541 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.JAXBCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.common.base.Preconditions; + +import org.codehaus.stax2.XMLInputFactory2; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.SequenceInputStream; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.charset.StandardCharsets; +import java.util.NoSuchElementException; + +import javax.xml.bind.JAXBContext; +import javax.xml.bind.JAXBElement; +import javax.xml.bind.JAXBException; +import javax.xml.bind.Unmarshaller; +import javax.xml.bind.ValidationEvent; +import javax.xml.bind.ValidationEventHandler; +import javax.xml.stream.FactoryConfigurationError; +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamConstants; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; + +// CHECKSTYLE.OFF: JavadocStyle +/** + * A source that can be used to read XML files. This source reads one or more + * XML files and creates a {@code PCollection} of a given type. An Dataflow read transform can be + * created by passing an {@code XmlSource} object to {@code Read.from()}. Please note the + * example given below. + * + *

    The XML file must be of the following form, where {@code root} and {@code record} are XML + * element names that are defined by the user: + * + *

    + * {@code
    + * <root>
    + *  <record> ... </record>
    + *  <record> ... </record>
    + *  <record> ... </record>
    + * ...
    + *  <record> ... </record>
    + * </root>
    + * }
    + * 
    + * + *

    Basically, the XML document should contain a single root element with an inner list consisting + * entirely of record elements. The records may contain arbitrary XML content; however, that content + * must not contain the start {@code <record>} or end {@code </record>} tags. This + * restriction enables reading from large XML files in parallel from different offsets in the file. + + *

    Root and/or record elements may additionally contain an arbitrary number of XML attributes. + * Additionally, users must provide a class of a JAXB annotated Java type that can be used to convert + * records into Java objects and vice versa using JAXB marshalling/unmarshalling mechanisms. Reading + * the source will generate a {@code PCollection} of the given JAXB annotated Java type. + * Optionally, users may provide a minimum size of a bundle that should be created for the source. + + *

    The following example shows how to read from {@link XmlSource} in a Dataflow pipeline: + * + *

    + * {@code
    + * XmlSource<Record> source = XmlSource.<Record>from(file.toPath().toString())
    + *     .withRootElement("root")
    + *     .withRecordElement("record")
    + *     .withRecordClass(Record.class);
    + * PCollection<Record> output = p.apply(Read.from(source));
    + * }
    + * 
    + * + *

    Currently, only XML files that use single-byte characters are supported. Using a file that + * contains multi-byte characters may result in data loss or duplication. + * + *

    To use {@link XmlSource}: + *

      + *
    1. Explicitly declare a dependency on org.codehaus.woodstox:stax2-api. + *
    2. Include a compatible implementation on the classpath at run-time, + * such as org.codehaus.woodstox:woodstox-core-asl. + *
    + * + *

    These dependencies have been declared as optional in the Maven sdk/pom.xml file of + * Google Cloud Dataflow. + + *

    Permissions

    + * Permission requirements depend on the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner PipelineRunner} that is + * used to execute the Dataflow job. Please refer to the documentation of corresponding + * {@link PipelineRunner PipelineRunners} for more details. + * + * @param Type of the objects that represent the records of the XML file. The + * {@code PCollection} generated by this source will be of this type. + */ +// CHECKSTYLE.ON: JavadocStyle +public class XmlSource extends FileBasedSource { + + private static final String XML_VERSION = "1.1"; + private static final int DEFAULT_MIN_BUNDLE_SIZE = 8 * 1024; + private final String rootElement; + private final String recordElement; + private final Class recordClass; + + /** + * Creates an XmlSource for a single XML file or a set of XML files defined by a Java "glob" file + * pattern. Each XML file should be of the form defined in {@link XmlSource}. + */ + public static XmlSource from(String fileOrPatternSpec) { + return new XmlSource<>(fileOrPatternSpec, DEFAULT_MIN_BUNDLE_SIZE, null, null, null); + } + + /** + * Sets name of the root element of the XML document. This will be used to create a valid starting + * root element when initiating a bundle of records created from an XML document. This is a + * required parameter. + */ + public XmlSource withRootElement(String rootElement) { + return new XmlSource<>( + getFileOrPatternSpec(), getMinBundleSize(), rootElement, recordElement, recordClass); + } + + /** + * Sets name of the record element of the XML document. This will be used to determine offset of + * the first record of a bundle created from the XML document. This is a required parameter. + */ + public XmlSource withRecordElement(String recordElement) { + return new XmlSource<>( + getFileOrPatternSpec(), getMinBundleSize(), rootElement, recordElement, recordClass); + } + + /** + * Sets a JAXB annotated class that can be populated using a record of the provided XML file. This + * will be used when unmarshalling record objects from the XML file. This is a required + * parameter. + */ + public XmlSource withRecordClass(Class recordClass) { + return new XmlSource<>( + getFileOrPatternSpec(), getMinBundleSize(), rootElement, recordElement, recordClass); + } + + /** + * Sets a parameter {@code minBundleSize} for the minimum bundle size of the source. Please refer + * to {@link OffsetBasedSource} for the definition of minBundleSize. This is an optional + * parameter. 
+ */ + public XmlSource withMinBundleSize(long minBundleSize) { + return new XmlSource<>( + getFileOrPatternSpec(), minBundleSize, rootElement, recordElement, recordClass); + } + + private XmlSource(String fileOrPattern, long minBundleSize, String rootElement, + String recordElement, Class recordClass) { + super(fileOrPattern, minBundleSize); + this.rootElement = rootElement; + this.recordElement = recordElement; + this.recordClass = recordClass; + } + + private XmlSource(String fileOrPattern, long minBundleSize, long startOffset, long endOffset, + String rootElement, String recordElement, Class recordClass) { + super(fileOrPattern, minBundleSize, startOffset, endOffset); + this.rootElement = rootElement; + this.recordElement = recordElement; + this.recordClass = recordClass; + } + + @Override + protected FileBasedSource createForSubrangeOfFile(String fileName, long start, long end) { + return new XmlSource( + fileName, getMinBundleSize(), start, end, rootElement, recordElement, recordClass); + } + + @Override + protected FileBasedReader createSingleFileReader(PipelineOptions options) { + return new XMLReader(this); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull( + rootElement, "rootElement is null. Use builder method withRootElement() to set this."); + Preconditions.checkNotNull( + recordElement, + "recordElement is null. Use builder method withRecordElement() to set this."); + Preconditions.checkNotNull( + recordClass, "recordClass is null. Use builder method withRecordClass() to set this."); + } + + @Override + public Coder getDefaultOutputCoder() { + return JAXBCoder.of(recordClass); + } + + public String getRootElement() { + return rootElement; + } + + public String getRecordElement() { + return recordElement; + } + + public Class getRecordClass() { + return recordClass; + } + + /** + * A {@link Source.Reader} for reading JAXB annotated Java objects from an XML file. The XML + * file should be of the form defined at {@link XmlSource}. + * + *

    Timestamped values are currently unsupported - all values implicitly have the timestamp + * of {@code BoundedWindow.TIMESTAMP_MIN_VALUE}. + * + * @param Type of objects that will be read by the reader. + */ + private static class XMLReader extends FileBasedReader { + // The amount of bytes read from the channel to memory when determining the starting offset of + // the first record in a bundle. After matching to starting offset of the first record the + // remaining bytes read to this buffer and the bytes still not read from the channel are used to + // create the XML parser. + private static final int BUF_SIZE = 1024; + + // This should be the maximum number of bytes a character will encode to, for any encoding + // supported by XmlSource. Currently this is set to 4 since UTF-8 characters may be + // four bytes. + private static final int MAX_CHAR_BYTES = 4; + + // In order to support reading starting in the middle of an XML file, we construct an imaginary + // well-formed document (a header and root tag followed by the contents of the input starting at + // the record boundary) and feed it to the parser. Because of this, the offset reported by the + // XML parser is not the same as offset in the original file. They differ by a constant amount: + // offsetInOriginalFile = parser.getLocation().getCharacterOffset() + parserBaseOffset; + // Note that this is true only for files with single-byte characters. + // It appears that, as of writing, there does not exist a Java XML parser capable of correctly + // reporting byte offsets of elements in the presence of multi-byte characters. + private long parserBaseOffset = 0; + private boolean readingStarted = false; + + // If true, the current bundle does not contain any records. + private boolean emptyBundle = false; + + private Unmarshaller jaxbUnmarshaller = null; + private XMLStreamReader parser = null; + + private T currentRecord = null; + + // Byte offset of the current record in the XML file provided when creating the source. + private long currentByteOffset = 0; + + public XMLReader(XmlSource source) { + super(source); + + // Set up a JAXB Unmarshaller that can be used to unmarshall record objects. + try { + JAXBContext jaxbContext = JAXBContext.newInstance(getCurrentSource().recordClass); + jaxbUnmarshaller = jaxbContext.createUnmarshaller(); + + // Throw errors if validation fails. JAXB by default ignores validation errors. + jaxbUnmarshaller.setEventHandler(new ValidationEventHandler() { + @Override + public boolean handleEvent(ValidationEvent event) { + throw new RuntimeException(event.getMessage(), event.getLinkedException()); + } + }); + } catch (JAXBException e) { + throw new RuntimeException(e); + } + } + + @Override + public synchronized XmlSource getCurrentSource() { + return (XmlSource) super.getCurrentSource(); + } + + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + // This method determines the correct starting offset of the first record by reading bytes + // from the ReadableByteChannel. This implementation does not need the channel to be a + // SeekableByteChannel. + // The method tries to determine the first record element in the byte channel. The first + // record must start with the characters "' character + // * '/' character (to support empty records). 
+ // + // After this match this method creates the XML parser for parsing the XML document, + // feeding it a fake document consisting of an XML header and the tag followed + // by the contents of channel starting from tag may be never + // closed. + + // This stores any bytes that should be used prior to the remaining bytes of the channel when + // creating an XML parser object. + ByteArrayOutputStream preambleByteBuffer = new ByteArrayOutputStream(); + // A dummy declaration and root for the document with proper XML version and encoding. Without + // this XML parsing may fail or may produce incorrect results. + + byte[] dummyStartDocumentBytes = + ("" + + "<" + getCurrentSource().rootElement + ">").getBytes(StandardCharsets.UTF_8); + preambleByteBuffer.write(dummyStartDocumentBytes); + // Gets the byte offset (in the input file) of the first record in ReadableByteChannel. This + // method returns the offset and stores any bytes that should be used when creating the XML + // parser in preambleByteBuffer. + long offsetInFileOfRecordElement = + getFirstOccurenceOfRecordElement(channel, preambleByteBuffer); + if (offsetInFileOfRecordElement < 0) { + // Bundle has no records. So marking this bundle as an empty bundle. + emptyBundle = true; + return; + } else { + byte[] preambleBytes = preambleByteBuffer.toByteArray(); + currentByteOffset = offsetInFileOfRecordElement; + setUpXMLParser(channel, preambleBytes); + parserBaseOffset = offsetInFileOfRecordElement - dummyStartDocumentBytes.length; + } + readingStarted = true; + } + + // Gets the first occurrence of the next record within the given ReadableByteChannel. Puts + // any bytes read past the starting offset of the next record back to the preambleByteBuffer. + // If a record is found, returns the starting offset of the record, otherwise + // returns -1. + private long getFirstOccurenceOfRecordElement( + ReadableByteChannel channel, ByteArrayOutputStream preambleByteBuffer) throws IOException { + int byteIndexInRecordElementToMatch = 0; + // Index of the byte in the string " 0) { + buf.flip(); + while (buf.hasRemaining()) { + offsetInFileOfCurrentByte++; + byte b = buf.get(); + boolean reset = false; + if (recordStartBytesMatched) { + // We already matched "..." + // * "..." + // * "' || c == '/') { + fullyMatched = true; + // Add the recordStartBytes and charBytes to preambleByteBuffer since these were + // already read from the channel. + preambleByteBuffer.write(recordStartBytes); + preambleByteBuffer.write(charBytes); + // Also add the rest of the current buffer to preambleByteBuffer. + while (buf.hasRemaining()) { + preambleByteBuffer.write(buf.get()); + } + break outer; + } else { + // Matching was unsuccessful. Reset the buffer to include bytes read for the char. + ByteBuffer newbuf = ByteBuffer.allocate(BUF_SIZE); + newbuf.put(charBytes); + offsetInFileOfCurrentByte -= charBytes.length; + while (buf.hasRemaining()) { + newbuf.put(buf.get()); + } + newbuf.flip(); + buf = newbuf; + + // Ignore everything and try again starting from the current buffer. + reset = true; + } + } else if (b == recordStartBytes[byteIndexInRecordElementToMatch]) { + // Next byte matched. + if (!matchStarted) { + // Match was for the first byte, record the starting offset. + matchStarted = true; + startingOffsetInFileOfCurrentMatch = offsetInFileOfCurrentByte; + } + byteIndexInRecordElementToMatch++; + } else { + // Not a match. Ignore everything and try again starting at current point. 
+ reset = true; + } + if (reset) { + // Clear variables and try to match starting from the next byte. + byteIndexInRecordElementToMatch = 0; + startingOffsetInFileOfCurrentMatch = -1; + matchStarted = false; + recordStartBytesMatched = false; + charBytes = new byte[MAX_CHAR_BYTES]; + charBytesFound = 0; + } + if (byteIndexInRecordElementToMatch == recordStartBytes.length) { + // " jb = jaxbUnmarshaller.unmarshal(parser, getCurrentSource().recordClass); + currentRecord = jb.getValue(); + return true; + } catch (JAXBException | XMLStreamException e) { + throw new IOException(e); + } + } + + @Override + public T getCurrent() throws NoSuchElementException { + if (!readingStarted) { + throw new NoSuchElementException(); + } + return currentRecord; + } + + @Override + protected boolean isAtSplitPoint() { + // Every record is at a split point. + return true; + } + + @Override + protected long getCurrentOffset() { + return currentByteOffset; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java new file mode 100644 index 000000000000..c3f233f24990 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java @@ -0,0 +1,987 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io.bigtable; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.bigtable.v1.Mutation; +import com.google.bigtable.v1.Row; +import com.google.bigtable.v1.RowFilter; +import com.google.bigtable.v1.SampleRowKeysResponse; +import com.google.cloud.bigtable.config.BigtableOptions; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Proto2Coder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; +import com.google.cloud.dataflow.sdk.io.Sink.WriteOperation; +import com.google.cloud.dataflow.sdk.io.Sink.Writer; +import com.google.cloud.dataflow.sdk.io.range.ByteKey; +import com.google.cloud.dataflow.sdk.io.range.ByteKeyRange; +import com.google.cloud.dataflow.sdk.io.range.ByteKeyRangeTracker; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.DataflowReleaseInfo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.annotation.Nullable; + +/** + * A bounded source and sink for Google Cloud Bigtable. + * + *

    For more information, see the online documentation at + * Google Cloud Bigtable. + * + *

    Reading from Cloud Bigtable

    + * + *

    The Bigtable source returns a set of rows from a single table, returning a + * {@code PCollection<Row>}. + * + *

    To configure a Cloud Bigtable source, you must supply a table id and a {@link BigtableOptions} + * or builder configured with the project and other information necessary to identify the + * Bigtable cluster. A {@link RowFilter} may also optionally be specified using + * {@link BigtableIO.Read#withRowFilter}. For example: + * + *

    {@code
    + * BigtableOptions.Builder optionsBuilder =
    + *     new BigtableOptions.Builder()
    + *         .setProjectId("project")
    + *         .setClusterId("cluster")
    + *         .setZoneId("zone");
    + *
    + * Pipeline p = ...;
    + *
    + * // Scan the entire table.
    + * p.apply("read",
    + *     BigtableIO.read()
    + *         .withBigtableOptions(optionsBuilder)
    + *         .withTableId("table"));
    + *
    + * // Scan a subset of rows that match the specified row filter.
    + * p.apply("filtered read",
    + *     BigtableIO.read()
    + *         .withBigtableOptions(optionsBuilder)
    + *         .withTableId("table")
    + *         .withRowFilter(filter));
    + * }
    + * + *
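    Continuing the example above, a rough sketch of consuming the resulting rows; the DoFn simply re-emits each row key as a UTF-8 string and is illustrative only:

        PCollection<Row> rows = p.apply("read",
            BigtableIO.read()
                .withBigtableOptions(optionsBuilder)
                .withTableId("table"));

        rows.apply("keys", ParDo.of(new DoFn<Row, String>() {
          @Override
          public void processElement(ProcessContext c) {
            c.output(c.element().getKey().toStringUtf8());
          }
        }));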

    Writing to Cloud Bigtable

    + * + *

    The Bigtable sink executes a set of row mutations on a single table. It takes as input a + * {@link PCollection PCollection<KV<ByteString, Iterable<Mutation>>>}, where the + * {@link ByteString} is the key of the row being mutated, and each {@link Mutation} represents an + * idempotent transformation to that row. + * + *

    To configure a Cloud Bigtable sink, you must supply a table id and a {@link BigtableOptions} + * or builder configured with the project and other information necessary to identify the + * Bigtable cluster, for example: + * + *

    {@code
    + * BigtableOptions.Builder optionsBuilder =
    + *     new BigtableOptions.Builder()
    + *         .setProjectId("project")
    + *         .setClusterId("cluster")
    + *         .setZoneId("zone");
    + *
    + * PCollection<KV<ByteString, Iterable<Mutation>>> data = ...;
    + *
    + * data.apply("write",
    + *     BigtableIO.write()
    + *         .withBigtableOptions(optionsBuilder)
    + *         .withTableId("table"));
    + * }
    + * + *
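    Continuing the example above, a rough sketch of building one input element for the sink, i.e. a row key plus the mutations to apply to that row (the family, qualifier, and values are placeholders; ImmutableList is from Guava):

        Mutation setCell = Mutation.newBuilder()
            .setSetCell(Mutation.SetCell.newBuilder()
                .setFamilyName("cf")
                .setColumnQualifier(ByteString.copyFromUtf8("qualifier"))
                .setValue(ByteString.copyFromUtf8("value")))
            .build();

        KV<ByteString, Iterable<Mutation>> element =
            KV.<ByteString, Iterable<Mutation>>of(
                ByteString.copyFromUtf8("row-key"), ImmutableList.of(setCell));

    A PCollection of such elements is what BigtableIO.write() expects as its input.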

    Experimental

    + * + *

    This connector for Cloud Bigtable is considered experimental and may break or receive + * backwards-incompatible changes in future versions of the Cloud Dataflow SDK. Cloud Bigtable is + * in Beta, and thus it may introduce breaking changes in future revisions of its service or APIs. + * + *

    Permissions

    + * + *

    Permission requirements depend on the {@link PipelineRunner} that is used to execute the + * Dataflow job. Please refer to the documentation of corresponding + * {@link PipelineRunner PipelineRunners} for more details. + */ +@Experimental +public class BigtableIO { + private static final Logger logger = LoggerFactory.getLogger(BigtableIO.class); + + /** + * Creates an uninitialized {@link BigtableIO.Read}. Before use, the {@code Read} must be + * initialized with a + * {@link BigtableIO.Read#withBigtableOptions(BigtableOptions) BigtableOptions} that specifies + * the source Cloud Bigtable cluster, and a {@link BigtableIO.Read#withTableId tableId} that + * specifies which table to read. A {@link RowFilter} may also optionally be specified using + * {@link BigtableIO.Read#withRowFilter}. + */ + @Experimental + public static Read read() { + return new Read(null, "", null, null); + } + + /** + * Creates an uninitialized {@link BigtableIO.Write}. Before use, the {@code Write} must be + * initialized with a + * {@link BigtableIO.Write#withBigtableOptions(BigtableOptions) BigtableOptions} that specifies + * the destination Cloud Bigtable cluster, and a {@link BigtableIO.Write#withTableId tableId} that + * specifies which table to write. + */ + @Experimental + public static Write write() { + return new Write(null, "", null); + } + + /** + * A {@link PTransform} that reads from Google Cloud Bigtable. See the class-level Javadoc on + * {@link BigtableIO} for more information. + * + * @see BigtableIO + */ + @Experimental + public static class Read extends PTransform> { + /** + * Returns a new {@link BigtableIO.Read} that will read from the Cloud Bigtable cluster + * indicated by the given options, and using any other specified customizations. + * + *

    Does not modify this object. + */ + public Read withBigtableOptions(BigtableOptions options) { + checkNotNull(options, "options"); + return withBigtableOptions(options.toBuilder()); + } + + /** + * Returns a new {@link BigtableIO.Read} that will read from the Cloud Bigtable cluster + * indicated by the given options, and using any other specified customizations. + * + *

    Clones the given {@link BigtableOptions} builder so that any further changes + * will have no effect on the returned {@link BigtableIO.Read}. + * + *

    Does not modify this object. + */ + public Read withBigtableOptions(BigtableOptions.Builder optionsBuilder) { + checkNotNull(optionsBuilder, "optionsBuilder"); + // TODO: is there a better way to clone a Builder? Want it to be immune from user changes. + BigtableOptions.Builder clonedBuilder = optionsBuilder.build().toBuilder(); + BigtableOptions optionsWithAgent = clonedBuilder.setUserAgent(getUserAgent()).build(); + return new Read(optionsWithAgent, tableId, filter, bigtableService); + } + + /** + * Returns a new {@link BigtableIO.Read} that will filter the rows read from Cloud Bigtable + * using the given row filter. + * + *

    Does not modify this object. + */ + Read withRowFilter(RowFilter filter) { + checkNotNull(filter, "filter"); + return new Read(options, tableId, filter, bigtableService); + } + + /** + * Returns a new {@link BigtableIO.Read} that will read from the specified table. + * + *

    Does not modify this object. + */ + public Read withTableId(String tableId) { + checkNotNull(tableId, "tableId"); + return new Read(options, tableId, filter, bigtableService); + } + + /** + * Returns the Google Cloud Bigtable cluster being read from, and other parameters. + */ + public BigtableOptions getBigtableOptions() { + return options; + } + + /** + * Returns the table being read from. + */ + public String getTableId() { + return tableId; + } + + @Override + public PCollection apply(PInput input) { + BigtableSource source = + new BigtableSource(getBigtableService(), tableId, filter, ByteKeyRange.ALL_KEYS, null); + return input.getPipeline().apply(com.google.cloud.dataflow.sdk.io.Read.from(source)); + } + + @Override + public void validate(PInput input) { + checkArgument(options != null, "BigtableOptions not specified"); + checkArgument(!tableId.isEmpty(), "Table ID not specified"); + try { + checkArgument( + getBigtableService().tableExists(tableId), "Table %s does not exist", tableId); + } catch (IOException e) { + logger.warn("Error checking whether table {} exists; proceeding.", tableId, e); + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(Read.class) + .add("options", options) + .add("tableId", tableId) + .add("filter", filter) + .toString(); + } + + ///////////////////////////////////////////////////////////////////////////////////////// + /** + * Used to define the Cloud Bigtable cluster and any options for the networking layer. + * Cannot actually be {@code null} at validation time, but may start out {@code null} while + * source is being built. + */ + @Nullable private final BigtableOptions options; + private final String tableId; + @Nullable private final RowFilter filter; + @Nullable private final BigtableService bigtableService; + + private Read( + @Nullable BigtableOptions options, + String tableId, + @Nullable RowFilter filter, + @Nullable BigtableService bigtableService) { + this.options = options; + this.tableId = checkNotNull(tableId, "tableId"); + this.filter = filter; + this.bigtableService = bigtableService; + } + + /** + * Returns a new {@link BigtableIO.Read} that will read using the given Cloud Bigtable + * service implementation. + * + *

    This is used for testing. + * + *

    Does not modify this object. + */ + Read withBigtableService(BigtableService bigtableService) { + checkNotNull(bigtableService, "bigtableService"); + return new Read(options, tableId, filter, bigtableService); + } + + /** + * Helper function that either returns the mock Bigtable service supplied by + * {@link #withBigtableService} or creates and returns an implementation that talks to + * {@code Cloud Bigtable}. + */ + private BigtableService getBigtableService() { + if (bigtableService != null) { + return bigtableService; + } + return new BigtableServiceImpl(options); + } + } + + /** + * A {@link PTransform} that writes to Google Cloud Bigtable. See the class-level Javadoc on + * {@link BigtableIO} for more information. + * + * @see BigtableIO + */ + @Experimental + public static class Write + extends PTransform>>, PDone> { + /** + * Used to define the Cloud Bigtable cluster and any options for the networking layer. + * Cannot actually be {@code null} at validation time, but may start out {@code null} while + * source is being built. + */ + @Nullable private final BigtableOptions options; + private final String tableId; + @Nullable private final BigtableService bigtableService; + + private Write( + @Nullable BigtableOptions options, + String tableId, + @Nullable BigtableService bigtableService) { + this.options = options; + this.tableId = checkNotNull(tableId, "tableId"); + this.bigtableService = bigtableService; + } + + /** + * Returns a new {@link BigtableIO.Write} that will write to the Cloud Bigtable cluster + * indicated by the given options, and using any other specified customizations. + * + *

    Does not modify this object. + */ + public Write withBigtableOptions(BigtableOptions options) { + checkNotNull(options, "options"); + return withBigtableOptions(options.toBuilder()); + } + + /** + * Returns a new {@link BigtableIO.Write} that will write to the Cloud Bigtable cluster + * indicated by the given options, and using any other specified customizations. + * + *

    Clones the given {@link BigtableOptions} builder so that any further changes + * will have no effect on the returned {@link BigtableIO.Write}. + * + *

    Does not modify this object. + */ + public Write withBigtableOptions(BigtableOptions.Builder optionsBuilder) { + checkNotNull(optionsBuilder, "optionsBuilder"); + // TODO: is there a better way to clone a Builder? Want it to be immune from user changes. + BigtableOptions.Builder clonedBuilder = optionsBuilder.build().toBuilder(); + BigtableOptions optionsWithAgent = clonedBuilder.setUserAgent(getUserAgent()).build(); + return new Write(optionsWithAgent, tableId, bigtableService); + } + + /** + * Returns a new {@link BigtableIO.Write} that will write to the specified table. + * + *

    Does not modify this object. + */ + public Write withTableId(String tableId) { + checkNotNull(tableId, "tableId"); + return new Write(options, tableId, bigtableService); + } + + /** + * Returns the Google Cloud Bigtable cluster being written to, and other parameters. + */ + public BigtableOptions getBigtableOptions() { + return options; + } + + /** + * Returns the table being written to. + */ + public String getTableId() { + return tableId; + } + + @Override + public PDone apply(PCollection>> input) { + Sink sink = new Sink(tableId, getBigtableService()); + return input.apply(com.google.cloud.dataflow.sdk.io.Write.to(sink)); + } + + @Override + public void validate(PCollection>> input) { + checkArgument(options != null, "BigtableOptions not specified"); + checkArgument(!tableId.isEmpty(), "Table ID not specified"); + try { + checkArgument( + getBigtableService().tableExists(tableId), "Table %s does not exist", tableId); + } catch (IOException e) { + logger.warn("Error checking whether table {} exists; proceeding.", tableId, e); + } + } + + /** + * Returns a new {@link BigtableIO.Write} that will write using the given Cloud Bigtable + * service implementation. + * + *

    This is used for testing. + * + *

    Does not modify this object. + */ + Write withBigtableService(BigtableService bigtableService) { + checkNotNull(bigtableService, "bigtableService"); + return new Write(options, tableId, bigtableService); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(Write.class) + .add("options", options) + .add("tableId", tableId) + .toString(); + } + + /** + * Helper function that either returns the mock Bigtable service supplied by + * {@link #withBigtableService} or creates and returns an implementation that talks to + * {@code Cloud Bigtable}. + */ + private BigtableService getBigtableService() { + if (bigtableService != null) { + return bigtableService; + } + return new BigtableServiceImpl(options); + } + } + + ////////////////////////////////////////////////////////////////////////////////////////// + /** Disallow construction of utility class. */ + private BigtableIO() {} + + static class BigtableSource extends BoundedSource { + public BigtableSource( + BigtableService service, + String tableId, + @Nullable RowFilter filter, + ByteKeyRange range, + Long estimatedSizeBytes) { + this.service = service; + this.tableId = tableId; + this.filter = filter; + this.range = range; + this.estimatedSizeBytes = estimatedSizeBytes; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(BigtableSource.class) + .add("tableId", tableId) + .add("filter", filter) + .add("range", range) + .add("estimatedSizeBytes", estimatedSizeBytes) + .toString(); + } + + ////// Private state and internal implementation details ////// + private final BigtableService service; + @Nullable private final String tableId; + @Nullable private final RowFilter filter; + private final ByteKeyRange range; + @Nullable private Long estimatedSizeBytes; + @Nullable private transient List sampleRowKeys; + + protected BigtableSource withStartKey(ByteKey startKey) { + checkNotNull(startKey, "startKey"); + return new BigtableSource( + service, tableId, filter, range.withStartKey(startKey), estimatedSizeBytes); + } + + protected BigtableSource withEndKey(ByteKey endKey) { + checkNotNull(endKey, "endKey"); + return new BigtableSource( + service, tableId, filter, range.withEndKey(endKey), estimatedSizeBytes); + } + + protected BigtableSource withEstimatedSizeBytes(Long estimatedSizeBytes) { + checkNotNull(estimatedSizeBytes, "estimatedSizeBytes"); + return new BigtableSource(service, tableId, filter, range, estimatedSizeBytes); + } + + /** + * Makes an API call to the Cloud Bigtable service that gives information about tablet key + * boundaries and estimated sizes. We can use these samples to ensure that splits are on + * different tablets, and possibly generate sub-splits within tablets. + */ + private List getSampleRowKeys() throws IOException { + return service.getSampleRowKeys(this); + } + + @Override + public List splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + // Update the desiredBundleSizeBytes in order to limit the + // number of splits to maximumNumberOfSplits. + long maximumNumberOfSplits = 4000; + long sizeEstimate = getEstimatedSizeBytes(options); + desiredBundleSizeBytes = + Math.max(sizeEstimate / maximumNumberOfSplits, desiredBundleSizeBytes); + + // Delegate to testable helper. + return splitIntoBundlesBasedOnSamples(desiredBundleSizeBytes, getSampleRowKeys()); + } + + /** Helper that splits this source into bundles based on Cloud Bigtable sampled row keys. 
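As a quick illustration of the sizing rule in splitIntoBundles above (a standalone, hypothetical helper, not part of this change): the requested bundle size is raised so that even a very large table produces at most roughly 4000 bundles.

    // Mirrors the clamp in splitIntoBundles; illustrative only.
    static long clampDesiredBundleSize(long estimatedSizeBytes, long desiredBundleSizeBytes) {
      long maximumNumberOfSplits = 4000;
      return Math.max(estimatedSizeBytes / maximumNumberOfSplits, desiredBundleSizeBytes);
    }
    // Example: a 64 GiB size estimate with a 1 MiB request yields roughly 16 MiB bundles,
    // i.e. about 4000 splits instead of about 65,000.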
*/ + private List splitIntoBundlesBasedOnSamples( + long desiredBundleSizeBytes, List sampleRowKeys) { + // There are no regions, or no samples available. Just scan the entire range. + if (sampleRowKeys.isEmpty()) { + logger.info("Not splitting source {} because no sample row keys are available.", this); + return Collections.singletonList(this); + } + + logger.info( + "About to split into bundles of size {} with sampleRowKeys length {} first element {}", + desiredBundleSizeBytes, + sampleRowKeys.size(), + sampleRowKeys.get(0)); + + // Loop through all sampled responses and generate splits from the ones that overlap the + // scan range. The main complication is that we must track the end range of the previous + // sample to generate good ranges. + ByteKey lastEndKey = ByteKey.EMPTY; + long lastOffset = 0; + ImmutableList.Builder splits = ImmutableList.builder(); + for (SampleRowKeysResponse response : sampleRowKeys) { + ByteKey responseEndKey = ByteKey.of(response.getRowKey()); + long responseOffset = response.getOffsetBytes(); + checkState( + responseOffset >= lastOffset, + "Expected response byte offset %s to come after the last offset %s", + responseOffset, + lastOffset); + + if (!range.overlaps(ByteKeyRange.of(lastEndKey, responseEndKey))) { + // This region does not overlap the scan, so skip it. + lastOffset = responseOffset; + lastEndKey = responseEndKey; + continue; + } + + // Calculate the beginning of the split as the larger of startKey and the end of the last + // split. Unspecified start is smallest key so is correctly treated as earliest key. + ByteKey splitStartKey = lastEndKey; + if (splitStartKey.compareTo(range.getStartKey()) < 0) { + splitStartKey = range.getStartKey(); + } + + // Calculate the end of the split as the smaller of endKey and the end of this sample. Note + // that range.containsKey handles the case when range.getEndKey() is empty. + ByteKey splitEndKey = responseEndKey; + if (!range.containsKey(splitEndKey)) { + splitEndKey = range.getEndKey(); + } + + // We know this region overlaps the desired key range, and we know a rough estimate of its + // size. Split the key range into bundle-sized chunks and then add them all as splits. + long sampleSizeBytes = responseOffset - lastOffset; + List subSplits = + splitKeyRangeIntoBundleSizedSubranges( + sampleSizeBytes, + desiredBundleSizeBytes, + ByteKeyRange.of(splitStartKey, splitEndKey)); + splits.addAll(subSplits); + + // Move to the next region. + lastEndKey = responseEndKey; + lastOffset = responseOffset; + } + + // We must add one more region after the end of the samples if both these conditions hold: + // 1. we did not scan to the end yet (lastEndKey is concrete, not 0-length). + // 2. we want to scan to the end (endKey is empty) or farther (lastEndKey < endKey). + if (!lastEndKey.isEmpty() + && (range.getEndKey().isEmpty() || lastEndKey.compareTo(range.getEndKey()) < 0)) { + splits.add(this.withStartKey(lastEndKey).withEndKey(range.getEndKey())); + } + + List ret = splits.build(); + logger.info("Generated {} splits. First split: {}", ret.size(), ret.get(0)); + return ret; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws IOException { + // Delegate to testable helper. + if (estimatedSizeBytes == null) { + estimatedSizeBytes = getEstimatedSizeBytesBasedOnSamples(getSampleRowKeys()); + } + return estimatedSizeBytes; + } + + /** + * Computes the estimated size in bytes based on the total size of all samples that overlap + * the key range this source will scan. 
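For intuition about the computation that follows (numbers invented for illustration): the sample offsets are cumulative byte offsets, so the size attributed to a region is the difference between consecutive offsets, and only regions that overlap the scan range contribute to the estimate.

    // Illustrative arithmetic only. With cumulative sample offsets at 64 MiB, 128 MiB, and
    // 192 MiB, and a scan range overlapping only the middle region, the estimate is:
    long estimatedSizeBytes = (128L << 20) - (64L << 20);  // 67,108,864 bytes (64 MiB)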
+ */ + private long getEstimatedSizeBytesBasedOnSamples(List samples) { + long estimatedSizeBytes = 0; + long lastOffset = 0; + ByteKey currentStartKey = ByteKey.EMPTY; + // Compute the total estimated size as the size of each sample that overlaps the scan range. + // TODO: In future, Bigtable service may provide finer grained APIs, e.g., to sample given a + // filter or to sample on a given key range. + for (SampleRowKeysResponse response : samples) { + ByteKey currentEndKey = ByteKey.of(response.getRowKey()); + long currentOffset = response.getOffsetBytes(); + if (!currentStartKey.isEmpty() && currentStartKey.equals(currentEndKey)) { + // Skip an empty region. + lastOffset = currentOffset; + continue; + } else if (range.overlaps(ByteKeyRange.of(currentStartKey, currentEndKey))) { + estimatedSizeBytes += currentOffset - lastOffset; + } + currentStartKey = currentEndKey; + lastOffset = currentOffset; + } + return estimatedSizeBytes; + } + + /** + * Cloud Bigtable returns query results ordered by key. + */ + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return true; + } + + @Override + public BoundedReader createReader(PipelineOptions options) throws IOException { + return new BigtableReader(this, service); + } + + @Override + public void validate() { + checkArgument(!tableId.isEmpty(), "tableId cannot be empty"); + } + + @Override + public Coder getDefaultOutputCoder() { + return Proto2Coder.of(Row.class); + } + + /** Helper that splits the specified range in this source into bundles. */ + private List splitKeyRangeIntoBundleSizedSubranges( + long sampleSizeBytes, long desiredBundleSizeBytes, ByteKeyRange range) { + // Catch the trivial cases. Split is small enough already, or this is the last region. + logger.debug( + "Subsplit for sampleSizeBytes {} and desiredBundleSizeBytes {}", + sampleSizeBytes, + desiredBundleSizeBytes); + if (sampleSizeBytes <= desiredBundleSizeBytes) { + return Collections.singletonList( + this.withStartKey(range.getStartKey()).withEndKey(range.getEndKey())); + } + + checkArgument( + sampleSizeBytes > 0, "Sample size %s bytes must be greater than 0.", sampleSizeBytes); + checkArgument( + desiredBundleSizeBytes > 0, + "Desired bundle size %s bytes must be greater than 0.", + desiredBundleSizeBytes); + + int splitCount = (int) Math.ceil(((double) sampleSizeBytes) / (desiredBundleSizeBytes)); + List splitKeys = range.split(splitCount); + ImmutableList.Builder splits = ImmutableList.builder(); + Iterator keys = splitKeys.iterator(); + ByteKey prev = keys.next(); + while (keys.hasNext()) { + ByteKey next = keys.next(); + splits.add( + this + .withStartKey(prev) + .withEndKey(next) + .withEstimatedSizeBytes(sampleSizeBytes / splitCount)); + prev = next; + } + return splits.build(); + } + + public ByteKeyRange getRange() { + return range; + } + + public RowFilter getRowFilter() { + return filter; + } + + public String getTableId() { + return tableId; + } + } + + private static class BigtableReader extends BoundedReader { + // Thread-safety: source is protected via synchronization and is only accessed or modified + // inside a synchronized block (or constructor, which is the same). 
+ private BigtableSource source; + private BigtableService service; + private BigtableService.Reader reader; + private final ByteKeyRangeTracker rangeTracker; + private long recordsReturned; + + public BigtableReader(BigtableSource source, BigtableService service) { + this.source = source; + this.service = service; + rangeTracker = ByteKeyRangeTracker.of(source.getRange()); + } + + @Override + public boolean start() throws IOException { + reader = service.createReader(getCurrentSource()); + boolean hasRecord = + reader.start() + && rangeTracker.tryReturnRecordAt(true, ByteKey.of(reader.getCurrentRow().getKey())); + if (hasRecord) { + ++recordsReturned; + } + return hasRecord; + } + + @Override + public synchronized BigtableSource getCurrentSource() { + return source; + } + + @Override + public boolean advance() throws IOException { + boolean hasRecord = + reader.advance() + && rangeTracker.tryReturnRecordAt(true, ByteKey.of(reader.getCurrentRow().getKey())); + if (hasRecord) { + ++recordsReturned; + } + return hasRecord; + } + + @Override + public Row getCurrent() throws NoSuchElementException { + return reader.getCurrentRow(); + } + + @Override + public void close() throws IOException { + logger.info("Closing reader after reading {} records.", recordsReturned); + if (reader != null) { + reader.close(); + reader = null; + } + } + + @Override + public final Double getFractionConsumed() { + return rangeTracker.getFractionConsumed(); + } + + @Override + public final synchronized BigtableSource splitAtFraction(double fraction) { + ByteKey splitKey; + try { + splitKey = source.getRange().interpolateKey(fraction); + } catch (IllegalArgumentException e) { + logger.info("%s: Failed to interpolate key for fraction %s.", source.getRange(), fraction); + return null; + } + logger.debug( + "Proposing to split {} at fraction {} (key {})", rangeTracker, fraction, splitKey); + if (!rangeTracker.trySplitAtPosition(splitKey)) { + return null; + } + BigtableSource primary = source.withEndKey(splitKey); + BigtableSource residual = source.withStartKey(splitKey); + this.source = primary; + return residual; + } + } + + private static class Sink + extends com.google.cloud.dataflow.sdk.io.Sink>> { + + public Sink(String tableId, BigtableService bigtableService) { + this.tableId = checkNotNull(tableId, "tableId"); + this.bigtableService = checkNotNull(bigtableService, "bigtableService"); + } + + public String getTableId() { + return tableId; + } + + public BigtableService getBigtableService() { + return bigtableService; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(Sink.class) + .add("bigtableService", bigtableService) + .add("tableId", tableId) + .toString(); + } + + /////////////////////////////////////////////////////////////////////////////// + private final String tableId; + private final BigtableService bigtableService; + + @Override + public WriteOperation>, Long> createWriteOperation( + PipelineOptions options) { + return new BigtableWriteOperation(this); + } + + /** Does nothing, as it is redundant with {@link Write#validate}. 
*/ + @Override + public void validate(PipelineOptions options) {} + } + + private static class BigtableWriteOperation + extends WriteOperation>, Long> { + private final Sink sink; + + public BigtableWriteOperation(Sink sink) { + this.sink = sink; + } + + @Override + public Writer>, Long> createWriter(PipelineOptions options) + throws Exception { + return new BigtableWriter(this); + } + + @Override + public void initialize(PipelineOptions options) {} + + @Override + public void finalize(Iterable writerResults, PipelineOptions options) { + long count = 0; + for (Long value : writerResults) { + value += count; + } + logger.debug("Wrote {} elements to BigtableIO.Sink {}", sink); + } + + @Override + public Sink getSink() { + return sink; + } + + @Override + public Coder getWriterResultCoder() { + return VarLongCoder.of(); + } + } + + private static class BigtableWriter extends Writer>, Long> { + private final BigtableWriteOperation writeOperation; + private final Sink sink; + private BigtableService.Writer bigtableWriter; + private long recordsWritten; + private final ConcurrentLinkedQueue failures; + + public BigtableWriter(BigtableWriteOperation writeOperation) { + this.writeOperation = writeOperation; + this.sink = writeOperation.getSink(); + this.failures = new ConcurrentLinkedQueue<>(); + } + + @Override + public void open(String uId) throws Exception { + bigtableWriter = sink.getBigtableService().openForWriting(sink.getTableId()); + recordsWritten = 0; + } + + /** + * If any write has asynchronously failed, fail the bundle with a useful error. + */ + private void checkForFailures() throws IOException { + // Note that this function is never called by multiple threads and is the only place that + // we remove from failures, so this code is safe. + if (failures.isEmpty()) { + return; + } + + StringBuilder logEntry = new StringBuilder(); + int i = 0; + for (; i < 10 && !failures.isEmpty(); ++i) { + BigtableWriteException exc = failures.remove(); + logEntry.append("\n").append(exc.getMessage()); + if (exc.getCause() != null) { + logEntry.append(": ").append(exc.getCause().getMessage()); + } + } + String message = + String.format( + "At least %d errors occurred writing to Bigtable. First %d errors: %s", + i + failures.size(), + i, + logEntry.toString()); + logger.error(message); + throw new IOException(message); + } + + @Override + public void write(KV> rowMutations) throws Exception { + checkForFailures(); + Futures.addCallback( + bigtableWriter.writeRecord(rowMutations), new WriteExceptionCallback(rowMutations)); + ++recordsWritten; + } + + @Override + public Long close() throws Exception { + bigtableWriter.close(); + bigtableWriter = null; + checkForFailures(); + logger.info("Wrote {} records", recordsWritten); + return recordsWritten; + } + + @Override + public WriteOperation>, Long> getWriteOperation() { + return writeOperation; + } + + private class WriteExceptionCallback implements FutureCallback { + private final KV> value; + + public WriteExceptionCallback(KV> value) { + this.value = value; + } + + @Override + public void onFailure(Throwable cause) { + failures.add(new BigtableWriteException(value, cause)); + } + + @Override + public void onSuccess(Empty produced) {} + } + } + + /** + * An exception that puts information about the failed record being written in its message. 
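The failure handling implemented by BigtableWriter above boils down to a small pattern, sketched here as a self-contained fragment (hypothetical names, not part of this change): every asynchronous write registers a callback that records failures, and the next write() or close() rethrows them so the bundle fails.

    // Hypothetical fragment illustrating the asynchronous-failure pattern above.
    final ConcurrentLinkedQueue<Throwable> failures = new ConcurrentLinkedQueue<>();

    <T> void track(ListenableFuture<T> pendingWrite) {
      Futures.addCallback(pendingWrite, new FutureCallback<T>() {
        @Override public void onSuccess(T result) {}
        @Override public void onFailure(Throwable cause) { failures.add(cause); }
      });
    }

    void checkForFailures() throws IOException {
      Throwable first = failures.poll();
      if (first != null) {
        throw new IOException("At least one asynchronous write failed", first);
      }
    }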
+ */ + static class BigtableWriteException extends IOException { + public BigtableWriteException(KV> record, Throwable cause) { + super( + String.format( + "Error mutating row %s with mutations %s", + record.getKey().toStringUtf8(), + record.getValue()), + cause); + } + } + + /** + * A helper function to produce a Cloud Bigtable user agent string. + */ + private static String getUserAgent() { + String javaVersion = System.getProperty("java.specification.version"); + DataflowReleaseInfo info = DataflowReleaseInfo.getReleaseInfo(); + return String.format( + "%s/%s (%s); %s", + info.getName(), + info.getVersion(), + javaVersion, + "0.2.3" /* TODO get Bigtable client version directly from jar. */); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableService.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableService.java new file mode 100644 index 000000000000..85d706cb0a67 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableService.java @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.io.bigtable; + +import com.google.bigtable.v1.Mutation; +import com.google.bigtable.v1.Row; +import com.google.bigtable.v1.SampleRowKeysResponse; +import com.google.cloud.dataflow.sdk.io.bigtable.BigtableIO.BigtableSource; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * An interface for real or fake implementations of Cloud Bigtable. + */ +interface BigtableService extends Serializable { + + /** + * The interface of a class that can write to Cloud Bigtable. + */ + interface Writer { + /** + * Writes a single row transaction to Cloud Bigtable. The key of the {@code record} is the + * row key to be mutated and the iterable of mutations represent the changes to be made to the + * row. + * + * @throws IOException if there is an error submitting the write. + */ + ListenableFuture writeRecord(KV> record) + throws IOException; + + /** + * Closes the writer. + * + * @throws IOException if any writes did not succeed + */ + void close() throws IOException; + } + + /** + * The interface of a class that reads from Cloud Bigtable. + */ + interface Reader { + /** + * Reads the first element (including initialization, such as opening a network connection) and + * returns true if an element was found. + */ + boolean start() throws IOException; + + /** + * Attempts to read the next element, and returns true if an element has been read. + */ + boolean advance() throws IOException; + + /** + * Closes the reader. + * + * @throws IOException if there is an error. 
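A hypothetical caller (not part of this change) would drive a Reader like this: start() opens the connection and positions on the first row, getCurrentRow() is valid only after a successful start() or advance(), and close() releases the connection.

    // Illustrative read loop; 'service' and 'source' are assumed to exist.
    BigtableService.Reader reader = service.createReader(source);
    for (boolean more = reader.start(); more; more = reader.advance()) {
      Row row = reader.getCurrentRow();
      // ... process row ...
    }
    reader.close();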
+ */ + void close() throws IOException; + + /** + * Returns the last row read by a successful start() or advance(), or throws if there is no + * current row because the last such call was unsuccessful. + */ + Row getCurrentRow() throws NoSuchElementException; + } + + /** + * Returns {@code true} if the table with the give name exists. + */ + boolean tableExists(String tableId) throws IOException; + + /** + * Returns a {@link Reader} that will read from the specified source. + */ + Reader createReader(BigtableSource source) throws IOException; + + /** + * Returns a {@link Writer} that will write to the specified table. + */ + Writer openForWriting(String tableId) throws IOException; + + /** + * Returns a set of row keys sampled from the underlying table. These contain information about + * the distribution of keys within the table. + */ + List getSampleRowKeys(BigtableSource source) throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableServiceImpl.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableServiceImpl.java new file mode 100644 index 000000000000..5ab85827ec05 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableServiceImpl.java @@ -0,0 +1,241 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.io.bigtable; + +import com.google.bigtable.admin.table.v1.GetTableRequest; +import com.google.bigtable.v1.MutateRowRequest; +import com.google.bigtable.v1.Mutation; +import com.google.bigtable.v1.ReadRowsRequest; +import com.google.bigtable.v1.Row; +import com.google.bigtable.v1.RowRange; +import com.google.bigtable.v1.SampleRowKeysRequest; +import com.google.bigtable.v1.SampleRowKeysResponse; +import com.google.cloud.bigtable.config.BigtableOptions; +import com.google.cloud.bigtable.grpc.BigtableSession; +import com.google.cloud.bigtable.grpc.async.AsyncExecutor; +import com.google.cloud.bigtable.grpc.async.HeapSizeManager; +import com.google.cloud.bigtable.grpc.scanner.ResultScanner; +import com.google.cloud.dataflow.sdk.io.bigtable.BigtableIO.BigtableSource; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.base.MoreObjects; +import com.google.common.io.Closer; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * An implementation of {@link BigtableService} that actually communicates with the Cloud Bigtable + * service. 
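Because the service is defined as an interface, tests can substitute a fake implementation via withBigtableService. A hypothetical in-memory Writer that such a fake might return from openForWriting (names and imports are illustrative, not part of this change):

    // Hypothetical test double recording mutations in memory.
    class FakeWriter implements BigtableService.Writer {
      final List<KV<ByteString, Iterable<Mutation>>> written = new ArrayList<>();

      @Override
      public ListenableFuture<Empty> writeRecord(KV<ByteString, Iterable<Mutation>> record) {
        written.add(record);
        return Futures.immediateFuture(Empty.getDefaultInstance());
      }

      @Override
      public void close() {}
    }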
+ */ +class BigtableServiceImpl implements BigtableService { + private static final Logger logger = LoggerFactory.getLogger(BigtableService.class); + + public BigtableServiceImpl(BigtableOptions options) { + this.options = options; + } + + private final BigtableOptions options; + + @Override + public BigtableWriterImpl openForWriting(String tableId) throws IOException { + BigtableSession session = new BigtableSession(options); + String tableName = options.getClusterName().toTableNameStr(tableId); + return new BigtableWriterImpl(session, tableName); + } + + @Override + public boolean tableExists(String tableId) throws IOException { + if (!BigtableSession.isAlpnProviderEnabled()) { + logger.info( + "Skipping existence check for table {} (BigtableOptions {}) because ALPN is not" + + " configured.", + tableId, + options); + return true; + } + + try (BigtableSession session = new BigtableSession(options)) { + GetTableRequest getTable = + GetTableRequest.newBuilder() + .setName(options.getClusterName().toTableNameStr(tableId)) + .build(); + session.getTableAdminClient().getTable(getTable); + return true; + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() == Code.NOT_FOUND) { + return false; + } + String message = + String.format( + "Error checking whether table %s (BigtableOptions %s) exists", tableId, options); + logger.error(message, e); + throw new IOException(message, e); + } + } + + private class BigtableReaderImpl implements Reader { + private BigtableSession session; + private final BigtableSource source; + private ResultScanner results; + private Row currentRow; + + public BigtableReaderImpl(BigtableSession session, BigtableSource source) { + this.session = session; + this.source = source; + } + + @Override + public boolean start() throws IOException { + RowRange range = + RowRange.newBuilder() + .setStartKey(source.getRange().getStartKey().getValue()) + .setEndKey(source.getRange().getEndKey().getValue()) + .build(); + ReadRowsRequest.Builder requestB = + ReadRowsRequest.newBuilder() + .setRowRange(range) + .setTableName(options.getClusterName().toTableNameStr(source.getTableId())); + if (source.getRowFilter() != null) { + requestB.setFilter(source.getRowFilter()); + } + results = session.getDataClient().readRows(requestB.build()); + return advance(); + } + + @Override + public boolean advance() throws IOException { + currentRow = results.next(); + return (currentRow != null); + } + + @Override + public void close() throws IOException { + // Goal: by the end of this function, both results and session are null and closed, + // independent of what errors they throw or prior state. + + if (session == null) { + // Only possible when previously closed, so we know that results is also null. + return; + } + + // Session does not implement Closeable -- it's AutoCloseable. So we can't register it with + // the Closer, but we can use the Closer to simplify the error handling. 
+ try (Closer closer = Closer.create()) { + if (results != null) { + closer.register(results); + results = null; + } + + session.close(); + } finally { + session = null; + } + } + + @Override + public Row getCurrentRow() throws NoSuchElementException { + if (currentRow == null) { + throw new NoSuchElementException(); + } + return currentRow; + } + } + + private static class BigtableWriterImpl implements Writer { + private BigtableSession session; + private AsyncExecutor executor; + private final MutateRowRequest.Builder partialBuilder; + + public BigtableWriterImpl(BigtableSession session, String tableName) { + this.session = session; + this.executor = + new AsyncExecutor( + session.getDataClient(), + new HeapSizeManager( + AsyncExecutor.ASYNC_MUTATOR_MAX_MEMORY_DEFAULT, + AsyncExecutor.MAX_INFLIGHT_RPCS_DEFAULT)); + + partialBuilder = MutateRowRequest.newBuilder().setTableName(tableName); + } + + @Override + public void close() throws IOException { + try { + if (executor != null) { + executor.flush(); + executor = null; + } + } finally { + if (session != null) { + session.close(); + session = null; + } + } + } + + @Override + public ListenableFuture writeRecord(KV> record) + throws IOException { + MutateRowRequest r = + partialBuilder + .clone() + .setRowKey(record.getKey()) + .addAllMutations(record.getValue()) + .build(); + try { + return executor.mutateRowAsync(r); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Write interrupted", e); + } + } + } + + @Override + public String toString() { + return MoreObjects + .toStringHelper(BigtableServiceImpl.class) + .add("options", options) + .toString(); + } + + @Override + public Reader createReader(BigtableSource source) throws IOException { + BigtableSession session = new BigtableSession(options); + return new BigtableReaderImpl(session, source); + } + + @Override + public List getSampleRowKeys(BigtableSource source) throws IOException { + try (BigtableSession session = new BigtableSession(options)) { + SampleRowKeysRequest request = + SampleRowKeysRequest.newBuilder() + .setTableName(options.getClusterName().toTableNameStr(source.getTableId())) + .build(); + return session.getDataClient().sampleRowKeys(request); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/package-info.java new file mode 100644 index 000000000000..de0bd8609498 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/package-info.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines transforms for reading and writing common storage formats, including + * {@link com.google.cloud.dataflow.sdk.io.AvroIO}, + * {@link com.google.cloud.dataflow.sdk.io.BigQueryIO}, and + * {@link com.google.cloud.dataflow.sdk.io.TextIO}. + * + *

    The classes in this package provide {@code Read} transforms that create PCollections + * from existing storage: + *

    {@code
+ * PCollection<TableRow> inputData = pipeline.apply(
+ *     BigQueryIO.Read.named("Read")
+ *                    .from("clouddataflow-readonly:samples.weather_stations"));
    + * }
    + * and {@code Write} transforms that persist PCollections to external storage: + *
     {@code
+ * PCollection<String> numbers = ...;
    + * numbers.apply(TextIO.Write.named("WriteNumbers")
    + *                           .to("gs://my_bucket/path/to/numbers"));
    + * } 
    + */ +package com.google.cloud.dataflow.sdk.io; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKey.java new file mode 100644 index 000000000000..30772da793c3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKey.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.protobuf.ByteString; +import com.google.protobuf.ByteString.ByteIterator; + +import java.io.Serializable; + +/** + * A class representing a key consisting of an array of bytes. Arbitrary-length + * {@code byte[]} keys are typical in key-value stores such as Google Cloud Bigtable. + * + *

    Instances of {@link ByteKey} are immutable. + * + *

    {@link ByteKey} implements {@link Comparable Comparable<ByteKey>} by comparing the + * arrays in lexicographic order. The smallest {@link ByteKey} is a zero-length array; the successor + * to a key is the same key with an additional 0 byte appended; and keys have unbounded size. + * + *

    Note that the empty {@link ByteKey} compares smaller than all other keys, but some systems + * have the semantic that when an empty {@link ByteKey} is used as an upper bound, it represents + * the largest possible key. In these cases, implementors should use {@link #isEmpty} to test + * whether an upper bound key is empty. + */ +public final class ByteKey implements Comparable, Serializable { + /** An empty key. */ + public static final ByteKey EMPTY = ByteKey.of(); + + /** + * Creates a new {@link ByteKey} backed by the specified {@link ByteString}. + */ + public static ByteKey of(ByteString value) { + return new ByteKey(value); + } + + /** + * Creates a new {@link ByteKey} backed by a copy of the specified {@code byte[]}. + * + *

    Makes a copy of the underlying array. + */ + public static ByteKey copyFrom(byte[] bytes) { + return of(ByteString.copyFrom(bytes)); + } + + /** + * Creates a new {@link ByteKey} backed by a copy of the specified {@code int[]}. This method is + * primarily used as a convenience to create a {@link ByteKey} in code without casting down to + * signed Java {@link Byte bytes}: + * + *

    {@code
    +   * ByteKey key = ByteKey.of(0xde, 0xad, 0xbe, 0xef);
    +   * }
    + * + *

    Makes a copy of the input. + */ + public static ByteKey of(int... bytes) { + byte[] ret = new byte[bytes.length]; + for (int i = 0; i < bytes.length; ++i) { + ret[i] = (byte) (bytes[i] & 0xff); + } + return ByteKey.copyFrom(ret); + } + + /** + * Returns an immutable {@link ByteString} representing this {@link ByteKey}. + * + *

    Does not copy. + */ + public ByteString getValue() { + return value; + } + + /** + * Returns a newly-allocated {@code byte[]} representing this {@link ByteKey}. + * + *

    Copies the underlying {@code byte[]}. + */ + public byte[] getBytes() { + return value.toByteArray(); + } + + /** + * Returns {@code true} if the {@code byte[]} backing this {@link ByteKey} is of length 0. + */ + public boolean isEmpty() { + return value.isEmpty(); + } + + /** + * {@link ByteKey} implements {@link Comparable Comparable<ByteKey>} by comparing the + * arrays in lexicographic order. The smallest {@link ByteKey} is a zero-length array; the + * successor to a key is the same key with an additional 0 byte appended; and keys have unbounded + * size. + */ + @Override + public int compareTo(ByteKey other) { + checkNotNull(other, "other"); + ByteIterator thisIt = value.iterator(); + ByteIterator otherIt = other.value.iterator(); + while (thisIt.hasNext() && otherIt.hasNext()) { + // (byte & 0xff) converts [-128,127] bytes to [0,255] ints. + int cmp = (thisIt.nextByte() & 0xff) - (otherIt.nextByte() & 0xff); + if (cmp != 0) { + return cmp; + } + } + // If we get here, the prefix of both arrays is equal up to the shorter array. The array with + // more bytes is larger. + return value.size() - other.value.size(); + } + + //////////////////////////////////////////////////////////////////////////////////// + private final ByteString value; + + private ByteKey(ByteString value) { + this.value = value; + } + + /** Array used as a helper in {@link #toString}. */ + private static final char[] HEX = + new char[] {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; + + // Prints the key as a string "[deadbeef]". + @Override + public String toString() { + char[] encoded = new char[2 * value.size() + 2]; + encoded[0] = '['; + int cnt = 1; + ByteIterator iterator = value.iterator(); + while (iterator.hasNext()) { + byte b = iterator.nextByte(); + encoded[cnt] = HEX[(b & 0xF0) >>> 4]; + ++cnt; + encoded[cnt] = HEX[b & 0xF]; + ++cnt; + } + encoded[cnt] = ']'; + return new String(encoded); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (!(o instanceof ByteKey)) { + return false; + } + ByteKey other = (ByteKey) o; + return (other.value.size() == value.size()) && this.compareTo(other) == 0; + } + + @Override + public int hashCode() { + return value.hashCode(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRange.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRange.java new file mode 100644 index 000000000000..6f58d393f905 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRange.java @@ -0,0 +1,376 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
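To make the ordering rules above concrete, an illustrative fragment (not part of this change):

    ByteKey empty = ByteKey.EMPTY;        // [], the smallest possible key
    ByteKey a = ByteKey.of(0x80);         // [80]
    ByteKey b = ByteKey.of(0x80, 0x00);   // [8000], the immediate successor of [80]
    ByteKey c = ByteKey.of(0x81);         // [81]
    // empty < a < b < c: bytes compare as unsigned values, and a shorter key sorts before
    // every key that it prefixes.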
+ */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * A class representing a range of {@link ByteKey ByteKeys}. + * + *

    Instances of {@link ByteKeyRange} are immutable. + * + *

    A {@link ByteKeyRange} enforces the restriction that its start and end keys must form a valid, + * non-empty range {@code [startKey, endKey)} that is inclusive of the start key and exclusive of + * the end key. + * + *

    When the end key is empty, it is treated as the largest possible key. + * + *

    Interpreting {@link ByteKey} in a {@link ByteKeyRange}

    + * + *

    The primary role of {@link ByteKeyRange} is to provide functionality for + * {@link #estimateFractionForKey(ByteKey)}, {@link #interpolateKey(double)}, and + * {@link #split(int)}, which are used for Google Cloud Dataflow's + * Autoscaling + * and Dynamic Work Rebalancing features. + * + *

    {@link ByteKeyRange} implements these features by treating a {@link ByteKey}'s underlying + * {@code byte[]} as the binary expansion of floating point numbers in the range {@code [0.0, 1.0]}. + * For example, the keys {@code ByteKey.of(0x80)}, {@code ByteKey.of(0xc0)}, and + * {@code ByteKey.of(0xe0)} are interpreted as {@code 0.5}, {@code 0.75}, and {@code 0.875} + * respectively. The empty {@code ByteKey.EMPTY} is interpreted as {@code 0.0} when used as the + * start of a range and {@code 1.0} when used as the end key. + * + *

    Key interpolation, fraction estimation, and range splitting are all interpreted in these + * floating-point semantics. See the respective implementations for further details. Note: + * the underlying implementations of these functions use {@link BigInteger} and {@link BigDecimal}, + * so they can be slow and should not be called in hot loops. Dataflow's dynamic work + * rebalancing will only invoke these functions during periodic control operations, so they are not + * called on the critical path. + * + * @see ByteKey + */ +public final class ByteKeyRange implements Serializable { + private static final Logger logger = LoggerFactory.getLogger(ByteKeyRange.class); + + /** The range of all keys, with empty start and end keys. */ + public static final ByteKeyRange ALL_KEYS = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.EMPTY); + + /** + * Creates a new {@link ByteKeyRange} with the given start and end keys. + * + *

    Note that if {@code endKey} is empty, it is treated as the largest possible key. + * + * @see ByteKeyRange + * + * @throws IllegalArgumentException if {@code endKey} is less than or equal to {@code startKey}, + * unless {@code endKey} is empty indicating the maximum possible {@link ByteKey}. + */ + public static ByteKeyRange of(ByteKey startKey, ByteKey endKey) { + return new ByteKeyRange(startKey, endKey); + } + + /** + * Returns the {@link ByteKey} representing the lower bound of this {@link ByteKeyRange}. + */ + public ByteKey getStartKey() { + return startKey; + } + + /** + * Returns the {@link ByteKey} representing the upper bound of this {@link ByteKeyRange}. + * + *

    Note that if {@code endKey} is empty, it is treated as the largest possible key. + */ + public ByteKey getEndKey() { + return endKey; + } + + /** + * Returns {@code true} if the specified {@link ByteKey} is contained within this range. + */ + public Boolean containsKey(ByteKey key) { + return key.compareTo(startKey) >= 0 && endsAfterKey(key); + } + + /** + * Returns {@code true} if the specified {@link ByteKeyRange} overlaps this range. + */ + public Boolean overlaps(ByteKeyRange other) { + // If each range starts before the other range ends, then they must overlap. + // { [] } -- one range inside the other OR { [ } ] -- partial overlap. + return endsAfterKey(other.startKey) && other.endsAfterKey(startKey); + } + + /** + * Returns a list of up to {@code numSplits + 1} {@link ByteKey ByteKeys} in ascending order, + * where the keys have been interpolated to form roughly equal sub-ranges of this + * {@link ByteKeyRange}, assuming a uniform distribution of keys within this range. + * + *

    The first {@link ByteKey} in the result is guaranteed to be equal to {@link #getStartKey}, + * and the last {@link ByteKey} in the result is guaranteed to be equal to {@link #getEndKey}. + * Thus the resulting list exactly spans the same key range as this {@link ByteKeyRange}. + * + *

    Note that the number of keys returned is not always equal to {@code numSplits + 1}. + * Specifically, if this range is unsplittable (e.g., because the start and end keys are equal + * up to padding by zero bytes), the list returned will only contain the start and end key. + * + * @throws IllegalArgumentException if the specified number of splits is < 1 + * @see ByteKeyRange the ByteKeyRange class Javadoc for more information about split semantics. + */ + public List split(int numSplits) { + checkArgument(numSplits > 0, "numSplits %s must be a positive integer", numSplits); + + try { + ImmutableList.Builder ret = ImmutableList.builder(); + ret.add(startKey); + for (int i = 1; i < numSplits; ++i) { + ret.add(interpolateKey(i / (double) numSplits)); + } + ret.add(endKey); + return ret.build(); + } catch (IllegalStateException e) { + // The range is not splittable -- just return + return ImmutableList.of(startKey, endKey); + } + } + + /** + * Returns the fraction of this range {@code [startKey, endKey)} that is in the interval + * {@code [startKey, key)}. + * + * @throws IllegalArgumentException if {@code key} does not fall within this range + * @see ByteKeyRange the ByteKeyRange class Javadoc for more information about fraction semantics. + */ + public double estimateFractionForKey(ByteKey key) { + checkNotNull(key, "key"); + checkArgument(!key.isEmpty(), "Cannot compute fraction for an empty key"); + checkArgument( + key.compareTo(startKey) >= 0, "Expected key %s >= range start key %s", key, startKey); + + if (key.equals(endKey)) { + return 1.0; + } + checkArgument(containsKey(key), "Cannot compute fraction for %s outside this %s", key, this); + + byte[] startBytes = startKey.getBytes(); + byte[] endBytes = endKey.getBytes(); + byte[] keyBytes = key.getBytes(); + // If the endKey is unspecified, add a leading 1 byte to it and a leading 0 byte to all other + // keys, to get a concrete least upper bound for the desired range. + if (endKey.isEmpty()) { + startBytes = addHeadByte(startBytes, (byte) 0); + endBytes = addHeadByte(endBytes, (byte) 1); + keyBytes = addHeadByte(keyBytes, (byte) 0); + } + + // Pad to the longest of all 3 keys. + int paddedKeyLength = Math.max(Math.max(startBytes.length, endBytes.length), keyBytes.length); + BigInteger rangeStartInt = paddedPositiveInt(startBytes, paddedKeyLength); + BigInteger rangeEndInt = paddedPositiveInt(endBytes, paddedKeyLength); + BigInteger keyInt = paddedPositiveInt(keyBytes, paddedKeyLength); + + // Keys are equal subject to padding by 0. + BigInteger range = rangeEndInt.subtract(rangeStartInt); + if (range.equals(BigInteger.ZERO)) { + logger.warn( + "Using 0.0 as the default fraction for this near-empty range {} where start and end keys" + + " differ only by trailing zeros.", + this); + return 0.0; + } + + // Compute the progress (key-start)/(end-start) scaling by 2^64, dividing (which rounds), + // and then scaling down after the division. This gives ample precision when converted to + // double. + BigInteger progressScaled = keyInt.subtract(rangeStartInt).shiftLeft(64); + return progressScaled.divide(range).doubleValue() / Math.pow(2, 64); + } + + /** + * Returns a {@link ByteKey} {@code key} such that {@code [startKey, key)} represents + * approximately the specified fraction of the range {@code [startKey, endKey)}. The interpolation + * is computed assuming a uniform distribution of keys. + * + *

    For example, given the largest possible range (defined by empty start and end keys), the + * fraction {@code 0.5} will return the {@code ByteKey.of(0x80)}, which will also be returned for + * ranges {@code [0x40, 0xc0)} and {@code [0x6f, 0x91)}. + * + *

    The key returned will never be empty. + * + * @throws IllegalArgumentException if {@code fraction} is outside the range [0, 1) + * @throws IllegalStateException if this range cannot be interpolated + * @see ByteKeyRange the ByteKeyRange class Javadoc for more information about fraction semantics. + */ + public ByteKey interpolateKey(double fraction) { + checkArgument( + fraction >= 0.0 && fraction < 1.0, "Fraction %s must be in the range [0, 1)", fraction); + byte[] startBytes = startKey.getBytes(); + byte[] endBytes = endKey.getBytes(); + // If the endKey is unspecified, add a leading 1 byte to it and a leading 0 byte to all other + // keys, to get a concrete least upper bound for the desired range. + if (endKey.isEmpty()) { + startBytes = addHeadByte(startBytes, (byte) 0); + endBytes = addHeadByte(endBytes, (byte) 1); + } + + // Pad to the longest key. + int paddedKeyLength = Math.max(startBytes.length, endBytes.length); + BigInteger rangeStartInt = paddedPositiveInt(startBytes, paddedKeyLength); + BigInteger rangeEndInt = paddedPositiveInt(endBytes, paddedKeyLength); + + // If the keys are equal subject to padding by 0, we can't interpolate. + BigInteger range = rangeEndInt.subtract(rangeStartInt); + checkState( + !range.equals(BigInteger.ZERO), + "Refusing to interpolate for near-empty %s where start and end keys differ only by trailing" + + " zero bytes.", + this); + + // Add precision so that range is at least 53 (double mantissa length) bits long. This way, we + // can interpolate small ranges finely, e.g., split the range key 3 to key 4 into 1024 parts. + // We add precision to range by adding zero bytes to the end of the keys, aka shifting the + // underlying BigInteger left by a multiple of 8 bits. + int bytesNeeded = ((53 - range.bitLength()) + 7) / 8; + if (bytesNeeded > 0) { + range = range.shiftLeft(bytesNeeded * 8); + rangeStartInt = rangeStartInt.shiftLeft(bytesNeeded * 8); + paddedKeyLength += bytesNeeded; + } + + BigInteger interpolatedOffset = + new BigDecimal(range).multiply(BigDecimal.valueOf(fraction)).toBigInteger(); + + int outputKeyLength = endKey.isEmpty() ? (paddedKeyLength - 1) : paddedKeyLength; + return ByteKey.copyFrom( + fixupHeadZeros(rangeStartInt.add(interpolatedOffset).toByteArray(), outputKeyLength)); + } + + /** + * Returns new {@link ByteKeyRange} like this one, but with the specified start key. + */ + public ByteKeyRange withStartKey(ByteKey startKey) { + return new ByteKeyRange(startKey, endKey); + } + + /** + * Returns new {@link ByteKeyRange} like this one, but with the specified end key. 
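An illustrative fragment (not part of this change) tying the fraction semantics above together:

    ByteKeyRange all = ByteKeyRange.ALL_KEYS;
    ByteKey middle = all.interpolateKey(0.5);              // approximately [80]
    double fraction = all.estimateFractionForKey(middle);  // approximately 0.5
    List<ByteKey> boundaries = all.split(4);               // 5 keys: start, 3 interior keys, end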
+ */ + public ByteKeyRange withEndKey(ByteKey endKey) { + return new ByteKeyRange(startKey, endKey); + } + + //////////////////////////////////////////////////////////////////////////////////// + private final ByteKey startKey; + private final ByteKey endKey; + + private ByteKeyRange(ByteKey startKey, ByteKey endKey) { + this.startKey = checkNotNull(startKey, "startKey"); + this.endKey = checkNotNull(endKey, "endKey"); + checkArgument(endsAfterKey(startKey), "Start %s must be less than end %s", startKey, endKey); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(ByteKeyRange.class) + .add("startKey", startKey) + .add("endKey", endKey) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (!(o instanceof ByteKeyRange)) { + return false; + } + ByteKeyRange other = (ByteKeyRange) o; + return Objects.equals(startKey, other.startKey) && Objects.equals(endKey, other.endKey); + } + + @Override + public int hashCode() { + return Objects.hash(startKey, endKey); + } + + /** + * Returns a copy of the specified array with the specified byte added at the front. + */ + private static byte[] addHeadByte(byte[] array, byte b) { + byte[] ret = new byte[array.length + 1]; + ret[0] = b; + System.arraycopy(array, 0, ret, 1, array.length); + return ret; + } + + /** + * Ensures the array is exactly {@code size} bytes long. Returns the input array if the condition + * is met, otherwise either adds or removes zero bytes from the beginning of {@code array}. + */ + private static byte[] fixupHeadZeros(byte[] array, int size) { + int padding = size - array.length; + if (padding == 0) { + return array; + } + + if (padding < 0) { + // There is one zero byte at the beginning, added by BigInteger to make there be a sign + // bit when converting to bytes. + verify( + padding == -1, + "key %s: expected length %d with exactly one byte of padding, found %d", + ByteKey.copyFrom(array), + size, + -padding); + verify( + (array[0] == 0) && ((array[1] & 0x80) == 0x80), + "key %s: is 1 byte longer than expected, indicating BigInteger padding. Expect first byte" + + " to be zero with set MSB in second byte.", + ByteKey.copyFrom(array)); + return Arrays.copyOfRange(array, 1, array.length); + } + + byte[] ret = new byte[size]; + System.arraycopy(array, 0, ret, padding, array.length); + return ret; + } + + /** + * Returns {@code true} when the specified {@code key} is smaller this range's end key. The only + * semantic change from {@code (key.compareTo(getEndKey()) < 0)} is that the empty end key is + * treated as larger than all possible {@link ByteKey keys}. + */ + boolean endsAfterKey(ByteKey key) { + return endKey.isEmpty() || key.compareTo(endKey) < 0; + } + + /** Builds a BigInteger out of the specified array, padded to the desired byte length. */ + private static BigInteger paddedPositiveInt(byte[] bytes, int length) { + int bytePaddingNeeded = length - bytes.length; + checkArgument( + bytePaddingNeeded >= 0, "Required bytes.length {} < length {}", bytes.length, length); + BigInteger ret = new BigInteger(1, bytes); + return (bytePaddingNeeded == 0) ? 
ret : ret.shiftLeft(8 * bytePaddingNeeded); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTracker.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTracker.java new file mode 100644 index 000000000000..f6796cc5afb9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTracker.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static com.google.common.base.MoreObjects.toStringHelper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +/** + * A {@link RangeTracker} for {@link ByteKey ByteKeys} in {@link ByteKeyRange ByteKeyRanges}. + * + * @see ByteKey + * @see ByteKeyRange + */ +public final class ByteKeyRangeTracker implements RangeTracker { + private static final Logger logger = LoggerFactory.getLogger(ByteKeyRangeTracker.class); + + /** Instantiates a new {@link ByteKeyRangeTracker} with the specified range. */ + public static ByteKeyRangeTracker of(ByteKeyRange range) { + return new ByteKeyRangeTracker(range); + } + + @Override + public synchronized ByteKey getStartPosition() { + return range.getStartKey(); + } + + @Override + public synchronized ByteKey getStopPosition() { + return range.getEndKey(); + } + + @Override + public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, ByteKey recordStart) { + if (isAtSplitPoint && !range.containsKey(recordStart)) { + return false; + } + position = recordStart; + return true; + } + + @Override + public synchronized boolean trySplitAtPosition(ByteKey splitPosition) { + // Unstarted. + if (position == null) { + logger.warn( + "{}: Rejecting split request at {} because no records have been returned.", + this, + splitPosition); + return false; + } + + // Started, but not after current position. + if (splitPosition.compareTo(position) <= 0) { + logger.warn( + "{}: Rejecting split request at {} because it is not after current position {}.", + this, + splitPosition, + position); + return false; + } + + // Sanity check. 
+ if (!range.containsKey(splitPosition)) { + logger.warn( + "{}: Rejecting split request at {} because it is not within the range.", + this, + splitPosition); + return false; + } + + range = range.withEndKey(splitPosition); + return true; + } + + @Override + public synchronized double getFractionConsumed() { + if (position == null) { + return 0; + } + return range.estimateFractionForKey(position); + } + + /////////////////////////////////////////////////////////////////////////////// + private ByteKeyRange range; + @Nullable private ByteKey position; + + private ByteKeyRangeTracker(ByteKeyRange range) { + this.range = range; + this.position = null; + } + + @Override + public String toString() { + return toStringHelper(ByteKeyRangeTracker.class) + .add("range", range) + .add("position", position) + .toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/OffsetRangeTracker.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/OffsetRangeTracker.java new file mode 100644 index 000000000000..b23721749605 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/OffsetRangeTracker.java @@ -0,0 +1,182 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.io.range; + +import com.google.common.annotations.VisibleForTesting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link RangeTracker} for non-negative positions of type {@code long}. + */ +public class OffsetRangeTracker implements RangeTracker { + private static final Logger LOG = LoggerFactory.getLogger(OffsetRangeTracker.class); + + private final long startOffset; + private long stopOffset; + private long lastRecordStart = -1L; + private long offsetOfLastSplitPoint = -1L; + + /** + * Offset corresponding to infinity. This can only be used as the upper-bound of a range, and + * indicates reading all of the records until the end without specifying exactly what the end is. + * + *

    Infinite ranges cannot be split because it is impossible to estimate progress within them. + */ + public static final long OFFSET_INFINITY = Long.MAX_VALUE; + + /** + * Creates an {@code OffsetRangeTracker} for the specified range. + */ + public OffsetRangeTracker(long startOffset, long stopOffset) { + this.startOffset = startOffset; + this.stopOffset = stopOffset; + } + + @Override + public synchronized Long getStartPosition() { + return startOffset; + } + + @Override + public synchronized Long getStopPosition() { + return stopOffset; + } + + @Override + public boolean tryReturnRecordAt(boolean isAtSplitPoint, Long recordStart) { + return tryReturnRecordAt(isAtSplitPoint, recordStart.longValue()); + } + + public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, long recordStart) { + if (lastRecordStart == -1 && !isAtSplitPoint) { + throw new IllegalStateException( + String.format("The first record [starting at %d] must be at a split point", recordStart)); + } + if (recordStart < lastRecordStart) { + throw new IllegalStateException( + String.format( + "Trying to return record [starting at %d] " + + "which is before the last-returned record [starting at %d]", + recordStart, + lastRecordStart)); + } + if (isAtSplitPoint) { + if (offsetOfLastSplitPoint != -1L && recordStart == offsetOfLastSplitPoint) { + throw new IllegalStateException( + String.format( + "Record at a split point has same offset as the previous split point: " + + "previous split point at %d, current record starts at %d", + offsetOfLastSplitPoint, recordStart)); + } + if (recordStart >= stopOffset) { + return false; + } + offsetOfLastSplitPoint = recordStart; + } + + lastRecordStart = recordStart; + return true; + } + + @Override + public boolean trySplitAtPosition(Long splitOffset) { + return trySplitAtPosition(splitOffset.longValue()); + } + + public synchronized boolean trySplitAtPosition(long splitOffset) { + if (stopOffset == OFFSET_INFINITY) { + LOG.debug("Refusing to split {} at {}: stop position unspecified", this, splitOffset); + return false; + } + if (lastRecordStart == -1) { + LOG.debug("Refusing to split {} at {}: unstarted", this, splitOffset); + return false; + } + + // Note: technically it is correct to split at any position after the last returned + // split point, not just the last returned record. + // TODO: Investigate whether in practice this is useful or, rather, confusing. + if (splitOffset <= lastRecordStart) { + LOG.debug( + "Refusing to split {} at {}: already past proposed split position", this, splitOffset); + return false; + } + if (splitOffset < startOffset || splitOffset >= stopOffset) { + LOG.debug( + "Refusing to split {} at {}: proposed split position out of range", this, splitOffset); + return false; + } + LOG.debug("Agreeing to split {} at {}", this, splitOffset); + this.stopOffset = splitOffset; + return true; + } + + /** + * Returns a position {@code P} such that the range {@code [start, P)} represents approximately + * the given fraction of the range {@code [start, end)}. Assumes that the density of records + * in the range is approximately uniform. 
+ */ + public synchronized long getPositionForFractionConsumed(double fraction) { + if (stopOffset == OFFSET_INFINITY) { + throw new IllegalArgumentException( + "getPositionForFractionConsumed is not applicable to an unbounded range: " + this); + } + return (long) Math.ceil(startOffset + fraction * (stopOffset - startOffset)); + } + + @Override + public synchronized double getFractionConsumed() { + if (stopOffset == OFFSET_INFINITY) { + return 0.0; + } + if (lastRecordStart == -1) { + return 0.0; + } + // E.g., when reading [3, 6) and lastRecordStart is 4, that means we consumed 3,4 of 3,4,5 + // which is (4 - 3 + 1) / (6 - 3) = 67%. + // Also, clamp to at most 1.0 because the last consumed position can extend past the + // stop position. + return Math.min(1.0, 1.0 * (lastRecordStart - startOffset + 1) / (stopOffset - startOffset)); + } + + @Override + public synchronized String toString() { + String stopString = (stopOffset == OFFSET_INFINITY) ? "infinity" : String.valueOf(stopOffset); + if (lastRecordStart >= 0) { + return String.format( + "", + lastRecordStart, + startOffset, + stopString); + } else { + return String.format("", startOffset, stopString); + } + } + + /** + * Returns a copy of this tracker for testing purposes (to simplify testing methods with + * side effects). + */ + @VisibleForTesting + OffsetRangeTracker copy() { + OffsetRangeTracker res = new OffsetRangeTracker(startOffset, stopOffset); + res.lastRecordStart = this.lastRecordStart; + return res; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/RangeTracker.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/RangeTracker.java new file mode 100644 index 000000000000..84359f1aa1dc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/RangeTracker.java @@ -0,0 +1,220 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.io.range; + +/** + * A {@code RangeTracker} is a thread-safe helper object for implementing dynamic work rebalancing + * in position-based {@link com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader} + * subclasses. + * + *

    Usage of the RangeTracker class hierarchy

    + * The abstract {@code RangeTracker} interface should not be used per se - all users should use its + * subclasses directly. We declare it here because all subclasses have roughly the same interface + * and the same properties, to centralize the documentation. Currently we provide one + * implementation - {@link OffsetRangeTracker}. + * + *

    Position-based sources

+ * A position-based source is one that can be described by a range of positions of + * an ordered type, where the records returned by the reader are described by positions of the + * same type. + * + *

    In case a record occupies a range of positions in the source, the most important thing about + * the record is the position where it starts. + * + *

Defining the semantics of positions for a source is entirely up to the source class; however, + * the chosen definitions have to obey certain properties in order to make it possible to correctly + * split the source into parts, including dynamic splitting. Two main aspects need to be defined: + *

      + *
    • How to assign starting positions to records. + *
    • Which records should be read by a source with a range {@code [A, B)}. + *
    + * Moreover, reading a range must be efficient, i.e., the performance of reading a range + * should not significantly depend on the location of the range. For example, reading the range + * {@code [A, B)} should not require reading all data before {@code A}. + * + *

    The sections below explain exactly what properties these definitions must satisfy, and + * how to use a {@code RangeTracker} with a properly defined source. + * + *

    Properties of position-based sources

    + * The main requirement for position-based sources is associativity: reading records from + * {@code [A, B)} and records from {@code [B, C)} should give the same records as reading from + * {@code [A, C)}, where {@code A <= B <= C}. This property ensures that no matter how a range + * of positions is split into arbitrarily many sub-ranges, the total set of records described by + * them stays the same. + * + *

    The other important property is how the source's range relates to positions of records in + * the source. In many sources each record can be identified by a unique starting position. + * In this case: + *

      + *
    • All records returned by a source {@code [A, B)} must have starting positions + * in this range. + *
    • All but the last record should end within this range. The last record may or may not + * extend past the end of the range. + *
    • Records should not overlap. + *
    + * Such sources should define "read {@code [A, B)}" as "read from the first record starting at or + * after A, up to but not including the first record starting at or after B". + * + *

    Some examples of such sources include reading lines or CSV from a text file, reading keys and + * values from a BigTable, etc. + * + *
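The start-position rule and the associativity requirement above can be illustrated with a small, self-contained toy sketch (not SDK code; the class name and offsets are made up). It models only record start offsets and shows that reading [0, 30) and [30, 60) yields exactly the records of [0, 60):

import java.util.ArrayList;
import java.util.List;

class RangeSemanticsDemo {
  // Returns the start offsets of the records a range [a, b) includes:
  // exactly those records whose START offset falls in [a, b).
  static List<Long> recordsIn(long[] recordStarts, long a, long b) {
    List<Long> included = new ArrayList<>();
    for (long start : recordStarts) {
      if (start >= a && start < b) {
        included.add(start);
      }
    }
    return included;
  }

  public static void main(String[] args) {
    long[] starts = {0, 10, 25, 40, 55};
    System.out.println(recordsIn(starts, 0, 30));   // [0, 10, 25]
    System.out.println(recordsIn(starts, 30, 60));  // [40, 55]
    System.out.println(recordsIn(starts, 0, 60));   // [0, 10, 25, 40, 55] -- same records as the two sub-ranges combined
  }
}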

The concept of split points allows us to extend the definitions for dealing with sources + * where some records cannot be identified by a unique starting position. + * + *

    In all cases, all records returned by a source {@code [A, B)} must start at or after + * {@code A}. + * + *

    Split points

    + * + *

    Some sources may have records that are not directly addressable. For example, imagine a file + * format consisting of a sequence of compressed blocks. Each block can be assigned an offset, but + * records within the block cannot be directly addressed without decompressing the block. Let us + * refer to this hypothetical format as CBF (Compressed Blocks Format). + * + *

    Many such formats can still satisfy the associativity property. For example, in CBF, reading + * {@code [A, B)} can mean "read all the records in all blocks whose starting offset is in + * {@code [A, B)}". + * + *

    To support such complex formats, we introduce the notion of split points. We say that + * a record is a split point if there exists a position {@code A} such that the record is the first + * one to be returned when reading the range {@code [A, infinity)}. In CBF, the only split points + * would be the first records in each block. + * + *

    Split points allow us to define the meaning of a record's position and a source's range + * in all cases: + *

      + *
    • For a record that is at a split point, its position is defined to be the largest + * {@code A} such that reading a source with the range {@code [A, infinity)} returns this record; + *
    • Positions of other records are only required to be non-decreasing; + *
    • Reading the source {@code [A, B)} must return records starting from the first split point + * at or after {@code A}, up to but not including the first split point at or after {@code B}. + * In particular, this means that the first record returned by a source MUST always be + * a split point. + *
    • Positions of split points must be unique. + *
    + * As a result, for any decomposition of the full range of the source into position ranges, the + * total set of records will be the full set of records in the source, and each record + * will be read exactly once. + * + *
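A minimal sketch of the CBF-style split-point convention, using the OffsetRangeTracker added in this change (the block offsets and record counts are made up for illustration): every record in a block reports the block's offset as its position, and only the first record of each block is flagged as a split point.

import com.google.cloud.dataflow.sdk.io.range.OffsetRangeTracker;

class SplitPointDemo {
  public static void main(String[] args) {
    OffsetRangeTracker tracker = new OffsetRangeTracker(0, 300);
    long[] blockOffsets = {0, 100, 200};
    int recordsPerBlock = 3;

    reading:
    for (long blockOffset : blockOffsets) {
      for (int i = 0; i < recordsPerBlock; i++) {
        // Only the first record of the block is a split point; the rest share its offset.
        boolean isAtSplitPoint = (i == 0);
        if (!tracker.tryReturnRecordAt(isAtSplitPoint, blockOffset)) {
          break reading;  // the block starts at or beyond the (possibly reduced) stop offset
        }
      }
    }
    // Consumed through the record starting at 200: (200 - 0 + 1) / (300 - 0) = 0.67.
    System.out.println(tracker.getFractionConsumed());
  }
}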

    Consumed positions

    + * As the source is being read, and records read from it are being passed to the downstream + * transforms in the pipeline, we say that positions in the source are being consumed. + * When a reader has read a record (or promised to a caller that a record will be returned), + * positions up to and including the record's start position are considered consumed. + * + *

    Dynamic splitting can happen only at unconsumed positions. If the reader just + * returned a record at offset 42 in a file, dynamic splitting can happen only at offset 43 or + * beyond, as otherwise that record could be read twice (by the current reader and by a reader + * of the task starting at 43). + * + *

    Example

    + * The following example uses an {@link OffsetRangeTracker} to support dynamically splitting + * a source with integer positions (offsets). + *
     {@code
+ *   class MyReader implements BoundedReader {
+ *     private MySource currentSource;
+ *     private final OffsetRangeTracker tracker;
+ *     ...
+ *     MyReader(MySource source) {
+ *       this.currentSource = source;
+ *       this.tracker = new OffsetRangeTracker(source.getStartOffset(), source.getEndOffset());
+ *     }
    + *     ...
    + *     boolean start() {
    + *       ... (general logic for locating the first record) ...
    + *       if (!tracker.tryReturnRecordAt(true, recordStartOffset)) return false;
    + *       ... (any logic that depends on the record being returned, e.g. counting returned records)
    + *       return true;
    + *     }
    + *     boolean advance() {
    + *       ... (general logic for locating the next record) ...
    + *       if (!tracker.tryReturnRecordAt(isAtSplitPoint, recordStartOffset)) return false;
    + *       ... (any logic that depends on the record being returned, e.g. counting returned records)
    + *       return true;
    + *     }
    + *
    + *     double getFractionConsumed() {
    + *       return tracker.getFractionConsumed();
    + *     }
    + *   }
    + * } 
    + * + *
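The example above focuses on reading; the following standalone sketch (class name and offsets chosen for illustration) exercises the OffsetRangeTracker from this change directly, including a dynamic split proposed at roughly half of the range:

import com.google.cloud.dataflow.sdk.io.range.OffsetRangeTracker;

class OffsetRangeTrackerDemo {
  public static void main(String[] args) {
    // Track the offset range [0, 100).
    OffsetRangeTracker tracker = new OffsetRangeTracker(0, 100);

    // Return ten records; each starts at a distinct offset, so each is a split point.
    for (long offset = 0; offset < 10; offset++) {
      tracker.tryReturnRecordAt(true, offset);
    }

    // Propose a split at the offset corresponding to ~50% of the range.
    long splitOffset = tracker.getPositionForFractionConsumed(0.5);  // 50
    boolean accepted = tracker.trySplitAtPosition(splitOffset);      // true: 50 is unconsumed and in range

    System.out.println(accepted);                       // true
    System.out.println(tracker.getStopPosition());      // 50 (the primary range is now [0, 50))
    System.out.println(tracker.getFractionConsumed());  // (9 - 0 + 1) / (50 - 0) = 0.2
  }
}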

    Usage with different models of iteration

    + * When using this class to protect a + * {@link com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader}, follow the pattern + * described above. + * + *

    When using this class to protect iteration in the {@code hasNext()/next()} + * model, consider the record consumed when {@code hasNext()} is about to return true, rather than + * when {@code next()} is called, because {@code hasNext()} returning true is promising the caller + * that {@code next()} will have an element to return - so {@link #trySplitAtPosition} must not + * split the range in a way that would make the record promised by {@code hasNext()} belong to + * a different range. + * + *
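A minimal sketch of that hasNext()/next() pattern, again against the OffsetRangeTracker from this change (the iterator below is illustrative only and assumes one record per offset); note that it calls tryReturnRecordAt exactly once per record, at the moment hasNext() first promises it:

import com.google.cloud.dataflow.sdk.io.range.OffsetRangeTracker;

import java.util.Iterator;
import java.util.NoSuchElementException;

/** Illustrative iterator over record start offsets guarded by an OffsetRangeTracker. */
class OffsetIterator implements Iterator<Long> {
  private final OffsetRangeTracker tracker;
  private long nextOffset;  // next candidate record start (records assumed one offset apart)
  private Long pending;     // offset promised by hasNext() but not yet handed out

  OffsetIterator(long start, long end) {
    this.tracker = new OffsetRangeTracker(start, end);
    this.nextOffset = start;
  }

  @Override
  public boolean hasNext() {
    if (pending != null) {
      return true;  // already promised; tryReturnRecordAt must not be called again
    }
    if (!tracker.tryReturnRecordAt(true, nextOffset)) {
      return false;  // at or past the (possibly dynamically reduced) stop position
    }
    pending = nextOffset++;  // the record is now consumed from the tracker's point of view
    return true;
  }

  @Override
  public Long next() {
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    Long result = pending;
    pending = null;
    return result;
  }
}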

    Also note that implementations of {@code hasNext()} need to ensure + * that they call {@link #tryReturnRecordAt} only once even if {@code hasNext()} is called + * repeatedly, due to the requirement on uniqueness of split point positions. + * + * @param Type of positions used by the source to define ranges and identify records. + */ +public interface RangeTracker { + /** + * Returns the starting position of the current range, inclusive. + */ + PositionT getStartPosition(); + + /** + * Returns the ending position of the current range, exclusive. + */ + PositionT getStopPosition(); + + /** + * Atomically determines whether a record at the given position can be returned and updates + * internal state. In particular: + *

      + *
    • If {@code isAtSplitPoint} is {@code true}, and {@code recordStart} is outside the current + * range, returns {@code false}; + *
    • Otherwise, updates the last-consumed position to {@code recordStart} and returns + * {@code true}. + *
    + *

    This method MUST be called on all split point records. It may be called on every record. + */ + boolean tryReturnRecordAt(boolean isAtSplitPoint, PositionT recordStart); + + /** + * Atomically splits the current range [{@link #getStartPosition}, {@link #getStopPosition}) + * into a "primary" part [{@link #getStartPosition}, {@code splitPosition}) + * and a "residual" part [{@code splitPosition}, {@link #getStopPosition}), assuming the current + * last-consumed position is within [{@link #getStartPosition}, splitPosition) + * (i.e., {@code splitPosition} has not been consumed yet). + * + *

    Updates the current range to be the primary and returns {@code true}. This means that + * all further calls on the current object will interpret their arguments relative to the + * primary range. + * + *

    If the split position has already been consumed, or if no {@link #tryReturnRecordAt} call + * was made yet, returns {@code false}. The second condition is to prevent dynamic splitting + * during reader start-up. + */ + boolean trySplitAtPosition(PositionT splitPosition); + + /** + * Returns the approximate fraction of positions in the source that have been consumed by + * successful {@link #tryReturnRecordAt} calls, or 0.0 if no such calls have happened. + */ + double getFractionConsumed(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/package-info.java new file mode 100644 index 000000000000..beb77bf0add1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/range/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Provides thread-safe helpers for implementing dynamic work rebalancing in position-based + * bounded sources. + * + *

    See {@link com.google.cloud.dataflow.sdk.io.range.RangeTracker} to get started. + */ +package com.google.cloud.dataflow.sdk.io.range; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ApplicationNameOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ApplicationNameOptions.java new file mode 100644 index 000000000000..60d62d375489 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ApplicationNameOptions.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Options that allow setting the application name. + */ +public interface ApplicationNameOptions extends PipelineOptions { + /** + * Name of application, for display purposes. + * + *

    Defaults to the name of the class that constructs the {@link PipelineOptions} + * via the {@link PipelineOptionsFactory}. + */ + @Description("Name of application for display purposes. Defaults to the name of the class that " + + "constructs the PipelineOptions via the PipelineOptionsFactory.") + String getAppName(); + void setAppName(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BigQueryOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BigQueryOptions.java new file mode 100644 index 000000000000..ed4eb24bacb1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BigQueryOptions.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Properties needed when using BigQuery with the Dataflow SDK. + */ +@Description("Options that are used to configure BigQuery. See " + + "https://cloud.google.com/bigquery/what-is-bigquery for details on BigQuery.") +public interface BigQueryOptions extends ApplicationNameOptions, GcpOptions, + PipelineOptions, StreamingOptions { + @Description("Temporary dataset for BigQuery table operations. " + + "Supported values are \"bigquery.googleapis.com/{dataset}\"") + @Default.String("bigquery.googleapis.com/cloud_dataflow") + String getTempDatasetId(); + void setTempDatasetId(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BlockingDataflowPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BlockingDataflowPipelineOptions.java new file mode 100644 index 000000000000..43a46b029ce3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/BlockingDataflowPipelineOptions.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.io.PrintStream; + +/** + * Options that are used to configure the {@link BlockingDataflowPipelineRunner}. + */ +@Description("Configure options on the BlockingDataflowPipelineRunner.") +public interface BlockingDataflowPipelineOptions extends DataflowPipelineOptions { + /** + * Output stream for job status messages. 
+ */ + @Description("Where messages generated during execution of the Dataflow job will be output.") + @JsonIgnore + @Hidden + @Default.InstanceFactory(StandardOutputFactory.class) + PrintStream getJobMessageOutput(); + void setJobMessageOutput(PrintStream value); + + /** + * Returns a default of {@link System#out}. + */ + public static class StandardOutputFactory implements DefaultValueFactory { + @Override + public PrintStream create(PipelineOptions options) { + return System.out; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/CloudDebuggerOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/CloudDebuggerOptions.java new file mode 100644 index 000000000000..62be4c9ec2e2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/CloudDebuggerOptions.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; + +/** + * Options for controlling Cloud Debugger. + */ +@Description("[Experimental] Used to configure the Cloud Debugger") +@Experimental +@Hidden +public interface CloudDebuggerOptions { + + /** + * Whether to enable the Cloud Debugger snapshot agent for the current job. + */ + @Description("Whether to enable the Cloud Debugger snapshot agent for the current job.") + boolean getEnableCloudDebugger(); + void setEnableCloudDebugger(boolean enabled); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptions.java new file mode 100644 index 000000000000..e94b56df8714 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptions.java @@ -0,0 +1,242 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.api.services.dataflow.Dataflow; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.util.DataflowPathValidator; +import com.google.cloud.dataflow.sdk.util.GcsStager; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.PathValidator; +import com.google.cloud.dataflow.sdk.util.Stager; +import com.google.cloud.dataflow.sdk.util.Transport; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.List; +import java.util.Map; + +/** + * Internal. Options used to control execution of the Dataflow SDK for + * debugging and testing purposes. + */ +@Description("[Internal] Options used to control execution of the Dataflow SDK for " + + "debugging and testing purposes.") +@Hidden +public interface DataflowPipelineDebugOptions extends PipelineOptions { + + /** + * The list of backend experiments to enable. + * + *

    Dataflow provides a number of experimental features that can be enabled + * with this flag. + * + *

    Please sync with the Dataflow team before enabling any experiments. + */ + @Description("[Experimental] Dataflow provides a number of experimental features that can " + + "be enabled with this flag. Please sync with the Dataflow team before enabling any " + + "experiments.") + @Experimental + List getExperiments(); + void setExperiments(List value); + + /** + * The root URL for the Dataflow API. {@code dataflowEndpoint} can override this value + * if it contains an absolute URL, otherwise {@code apiRootUrl} will be combined with + * {@code dataflowEndpoint} to generate the full URL to communicate with the Dataflow API. + */ + @Description("The root URL for the Dataflow API. dataflowEndpoint can override this " + + "value if it contains an absolute URL, otherwise apiRootUrl will be combined with " + + "dataflowEndpoint to generate the full URL to communicate with the Dataflow API.") + @Default.String(Dataflow.DEFAULT_ROOT_URL) + String getApiRootUrl(); + void setApiRootUrl(String value); + + /** + * Dataflow endpoint to use. + * + *

    Defaults to the current version of the Google Cloud Dataflow + * API, at the time the current SDK version was released. + * + *

    If the string contains "://", then this is treated as a URL, + * otherwise {@link #getApiRootUrl()} is used as the root + * URL. + */ + @Description("The URL for the Dataflow API. If the string contains \"://\", this" + + " will be treated as the entire URL, otherwise will be treated relative to apiRootUrl.") + @Default.String(Dataflow.DEFAULT_SERVICE_PATH) + String getDataflowEndpoint(); + void setDataflowEndpoint(String value); + + /** + * The path to write the translated Dataflow job specification out to + * at job submission time. The Dataflow job specification will be represented in JSON + * format. + */ + @Description("The path to write the translated Dataflow job specification out to " + + "at job submission time. The Dataflow job specification will be represented in JSON " + + "format.") + String getDataflowJobFile(); + void setDataflowJobFile(String value); + + /** + * The class of the validator that should be created and used to validate paths. + * If pathValidator has not been set explicitly, an instance of this class will be + * constructed and used as the path validator. + */ + @Description("The class of the validator that should be created and used to validate paths. " + + "If pathValidator has not been set explicitly, an instance of this class will be " + + "constructed and used as the path validator.") + @Default.Class(DataflowPathValidator.class) + Class getPathValidatorClass(); + void setPathValidatorClass(Class validatorClass); + + /** + * The path validator instance that should be used to validate paths. + * If no path validator has been set explicitly, the default is to use the instance factory that + * constructs a path validator based upon the currently set pathValidatorClass. + */ + @JsonIgnore + @Description("The path validator instance that should be used to validate paths. " + + "If no path validator has been set explicitly, the default is to use the instance factory " + + "that constructs a path validator based upon the currently set pathValidatorClass.") + @Default.InstanceFactory(PathValidatorFactory.class) + PathValidator getPathValidator(); + void setPathValidator(PathValidator validator); + + /** + * The class responsible for staging resources to be accessible by workers + * during job execution. If stager has not been set explicitly, an instance of this class + * will be created and used as the resource stager. + */ + @Description("The class of the stager that should be created and used to stage resources. " + + "If stager has not been set explicitly, an instance of the this class will be created " + + "and used as the resource stager.") + @Default.Class(GcsStager.class) + Class getStagerClass(); + void setStagerClass(Class stagerClass); + + /** + * The resource stager instance that should be used to stage resources. + * If no stager has been set explicitly, the default is to use the instance factory + * that constructs a resource stager based upon the currently set stagerClass. + */ + @JsonIgnore + @Description("The resource stager instance that should be used to stage resources. " + + "If no stager has been set explicitly, the default is to use the instance factory " + + "that constructs a resource stager based upon the currently set stagerClass.") + @Default.InstanceFactory(StagerFactory.class) + Stager getStager(); + void setStager(Stager stager); + + /** + * An instance of the Dataflow client. Defaults to creating a Dataflow client + * using the current set of options. + */ + @JsonIgnore + @Description("An instance of the Dataflow client. 
Defaults to creating a Dataflow client " + + "using the current set of options.") + @Default.InstanceFactory(DataflowClientFactory.class) + Dataflow getDataflowClient(); + void setDataflowClient(Dataflow value); + + /** Returns the default Dataflow client built from the passed in PipelineOptions. */ + public static class DataflowClientFactory implements DefaultValueFactory { + @Override + public Dataflow create(PipelineOptions options) { + return Transport.newDataflowClient(options.as(DataflowPipelineOptions.class)).build(); + } + } + + /** + * Root URL for use with the Pubsub API. + */ + @Description("Root URL for use with the Pubsub API") + @Default.String("https://pubsub.googleapis.com") + String getPubsubRootUrl(); + void setPubsubRootUrl(String value); + + /** + * Whether to update the currently running pipeline with the same name as this one. + */ + @JsonIgnore + @Description("If set, replace the existing pipeline with the name specified by --jobName with " + + "this pipeline, preserving state.") + boolean getUpdate(); + void setUpdate(boolean value); + + /** + * Mapping of old PTranform names to new ones, specified as JSON + * {"oldName":"newName",...}. To mark a transform as deleted, make newName the + * empty string. + */ + @JsonIgnore + @Description( + "Mapping of old PTranform names to new ones, specified as JSON " + + "{\"oldName\":\"newName\",...}. To mark a transform as deleted, make newName the empty " + + "string.") + Map getTransformNameMapping(); + void setTransformNameMapping(Map value); + + /** + * Custom windmill_main binary to use with the streaming runner. + */ + @Description("Custom windmill_main binary to use with the streaming runner") + String getOverrideWindmillBinary(); + void setOverrideWindmillBinary(String value); + + /** + * Number of threads to use on the Dataflow worker harness. If left unspecified, + * the Dataflow service will compute an appropriate number of threads to use. + */ + @Description("Number of threads to use on the Dataflow worker harness. If left unspecified, " + + "the Dataflow service will compute an appropriate number of threads to use.") + int getNumberOfWorkerHarnessThreads(); + void setNumberOfWorkerHarnessThreads(int value); + + /** + * Creates a {@link PathValidator} object using the class specified in + * {@link #getPathValidatorClass()}. + */ + public static class PathValidatorFactory implements DefaultValueFactory { + @Override + public PathValidator create(PipelineOptions options) { + DataflowPipelineDebugOptions debugOptions = options.as(DataflowPipelineDebugOptions.class); + return InstanceBuilder.ofType(PathValidator.class) + .fromClass(debugOptions.getPathValidatorClass()) + .fromFactoryMethod("fromOptions") + .withArg(PipelineOptions.class, options) + .build(); + } + } + + /** + * Creates a {@link Stager} object using the class specified in + * {@link #getStagerClass()}. 
+ */ + public static class StagerFactory implements DefaultValueFactory { + @Override + public Stager create(PipelineOptions options) { + DataflowPipelineDebugOptions debugOptions = options.as(DataflowPipelineDebugOptions.class); + return InstanceBuilder.ofType(Stager.class) + .fromClass(debugOptions.getStagerClass()) + .fromFactoryMethod("fromOptions") + .withArg(PipelineOptions.class, options) + .build(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptions.java new file mode 100644 index 000000000000..a0f188af0785 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptions.java @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.runners.DataflowPipeline; +import com.google.common.base.MoreObjects; + +import org.joda.time.DateTimeUtils; +import org.joda.time.DateTimeZone; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +/** + * Options that can be used to configure the {@link DataflowPipeline}. + */ +@Description("Options that configure the Dataflow pipeline.") +public interface DataflowPipelineOptions extends + PipelineOptions, GcpOptions, ApplicationNameOptions, DataflowPipelineDebugOptions, + DataflowPipelineWorkerPoolOptions, BigQueryOptions, + GcsOptions, StreamingOptions, CloudDebuggerOptions, DataflowWorkerLoggingOptions, + DataflowProfilingOptions { + + static final String DATAFLOW_STORAGE_LOCATION = "Dataflow Storage Location"; + + @Description("Project id. Required when running a Dataflow in the cloud. " + + "See https://cloud.google.com/storage/docs/projects for further details.") + @Override + @Validation.Required + @Default.InstanceFactory(DefaultProjectFactory.class) + String getProject(); + @Override + void setProject(String value); + + /** + * GCS path for temporary files, e.g. gs://bucket/object + * + *

    Must be a valid Cloud Storage URL, beginning with the prefix "gs://" + * + *

    At least one of {@link #getTempLocation()} or {@link #getStagingLocation()} must be set. If + * {@link #getTempLocation()} is not set, then the Dataflow pipeline defaults to using + * {@link #getStagingLocation()}. + */ + @Description("GCS path for temporary files, eg \"gs://bucket/object\". " + + "Must be a valid Cloud Storage URL, beginning with the prefix \"gs://\". " + + "At least one of tempLocation or stagingLocation must be set. If tempLocation is unset, " + + "defaults to using stagingLocation.") + @Validation.Required(groups = {DATAFLOW_STORAGE_LOCATION}) + String getTempLocation(); + void setTempLocation(String value); + + /** + * GCS path for staging local files, e.g. gs://bucket/object + * + *

    Must be a valid Cloud Storage URL, beginning with the prefix "gs://" + * + *

    At least one of {@link #getTempLocation()} or {@link #getStagingLocation()} must be set. If + * {@link #getTempLocation()} is not set, then the Dataflow pipeline defaults to using + * {@link #getStagingLocation()}. + */ + @Description("GCS path for staging local files, e.g. \"gs://bucket/object\". " + + "Must be a valid Cloud Storage URL, beginning with the prefix \"gs://\". " + + "At least one of stagingLocation or tempLocation must be set. If stagingLocation is unset, " + + "defaults to using tempLocation.") + @Validation.Required(groups = {DATAFLOW_STORAGE_LOCATION}) + String getStagingLocation(); + void setStagingLocation(String value); + + /** + * The Dataflow job name is used as an idempotence key within the Dataflow service. + * If there is an existing job that is currently active, another active job with the same + * name will not be able to be created. Defaults to using the ApplicationName-UserName-Date. + */ + @Description("The Dataflow job name is used as an idempotence key within the Dataflow service. " + + "If there is an existing job that is currently active, another active job with the same " + + "name will not be able to be created. Defaults to using the ApplicationName-UserName-Date.") + @Default.InstanceFactory(JobNameFactory.class) + String getJobName(); + void setJobName(String value); + + /** + * Returns a normalized job name constructed from {@link ApplicationNameOptions#getAppName()}, the + * local system user name (if available), and the current time. The normalization makes sure that + * the job name matches the required pattern of [a-z]([-a-z0-9]*[a-z0-9])? and length limit of 40 + * characters. + * + *

    This job name factory is only able to generate one unique name per second per application + * and user combination. + */ + public static class JobNameFactory implements DefaultValueFactory { + private static final DateTimeFormatter FORMATTER = + DateTimeFormat.forPattern("MMddHHmmss").withZone(DateTimeZone.UTC); + + @Override + public String create(PipelineOptions options) { + String appName = options.as(ApplicationNameOptions.class).getAppName(); + String normalizedAppName = appName == null || appName.length() == 0 ? "dataflow" + : appName.toLowerCase() + .replaceAll("[^a-z0-9]", "0") + .replaceAll("^[^a-z]", "a"); + String userName = MoreObjects.firstNonNull(System.getProperty("user.name"), ""); + String normalizedUserName = userName.toLowerCase() + .replaceAll("[^a-z0-9]", "0"); + String datePart = FORMATTER.print(DateTimeUtils.currentTimeMillis()); + return normalizedAppName + "-" + normalizedUserName + "-" + datePart; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineWorkerPoolOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineWorkerPoolOptions.java new file mode 100644 index 000000000000..25d15890c7c3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineWorkerPoolOptions.java @@ -0,0 +1,242 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.List; + +/** + * Options that are used to configure the Dataflow pipeline worker pool. + */ +@Description("Options that are used to configure the Dataflow pipeline worker pool.") +public interface DataflowPipelineWorkerPoolOptions extends PipelineOptions { + /** + * Number of workers to use when executing the Dataflow job. Note that selection of an autoscaling + * algorithm other then {@code NONE} will affect the size of the worker pool. If left unspecified, + * the Dataflow service will determine the number of workers. + */ + @Description("Number of workers to use when executing the Dataflow job. Note that " + + "selection of an autoscaling algorithm other then \"NONE\" will affect the " + + "size of the worker pool. If left unspecified, the Dataflow service will " + + "determine the number of workers.") + int getNumWorkers(); + void setNumWorkers(int value); + + /** + * Type of autoscaling algorithm to use. + */ + @Experimental(Experimental.Kind.AUTOSCALING) + public enum AutoscalingAlgorithmType { + /** Use numWorkers machines. Do not autoscale the worker pool. */ + NONE("AUTOSCALING_ALGORITHM_NONE"), + + @Deprecated + BASIC("AUTOSCALING_ALGORITHM_BASIC"), + + /** Autoscale the workerpool based on throughput (up to maxNumWorkers). 
*/ + THROUGHPUT_BASED("AUTOSCALING_ALGORITHM_BASIC"); + + private final String algorithm; + + private AutoscalingAlgorithmType(String algorithm) { + this.algorithm = algorithm; + } + + /** Returns the string representation of this type. */ + public String getAlgorithm() { + return this.algorithm; + } + } + + /** + * [Experimental] The autoscaling algorithm to use for the workerpool. + * + *

      + *
• NONE: does not change the size of the worker pool. + *
• BASIC: autoscale the worker pool size up to maxNumWorkers until the job completes. + *
• THROUGHPUT_BASED: autoscale the workerpool based on throughput (up to maxNumWorkers). + *
    + */ + @Description("[Experimental] The autoscaling algorithm to use for the workerpool. " + + "NONE: does not change the size of the worker pool. " + + "BASIC (deprecated): autoscale the worker pool size up to maxNumWorkers until the job " + + "completes. " + + "THROUGHPUT_BASED: autoscale the workerpool based on throughput (up to maxNumWorkers).") + @Experimental(Experimental.Kind.AUTOSCALING) + AutoscalingAlgorithmType getAutoscalingAlgorithm(); + void setAutoscalingAlgorithm(AutoscalingAlgorithmType value); + + /** + * The maximum number of workers to use for the workerpool. This options limits the size of the + * workerpool for the lifetime of the job, including + * pipeline updates. + * If left unspecified, the Dataflow service will compute a ceiling. + */ + @Description("The maximum number of workers to use for the workerpool. This options limits the " + + "size of the workerpool for the lifetime of the job, including pipeline updates. " + + "If left unspecified, the Dataflow service will compute a ceiling.") + int getMaxNumWorkers(); + void setMaxNumWorkers(int value); + + /** + * Remote worker disk size, in gigabytes, or 0 to use the default size. + */ + @Description("Remote worker disk size, in gigabytes, or 0 to use the default size.") + int getDiskSizeGb(); + void setDiskSizeGb(int value); + + /** + * Docker container image that executes Dataflow worker harness, residing in Google Container + * Registry. + */ + @Default.InstanceFactory(WorkerHarnessContainerImageFactory.class) + @Description("Docker container image that executes Dataflow worker harness, residing in Google " + + " Container Registry.") + @Hidden + String getWorkerHarnessContainerImage(); + void setWorkerHarnessContainerImage(String value); + + /** + * Returns the default Docker container image that executes Dataflow worker harness, residing in + * Google Container Registry. + */ + public static class WorkerHarnessContainerImageFactory + implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + if (dataflowOptions.isStreaming()) { + return DataflowPipelineRunner.STREAMING_WORKER_HARNESS_CONTAINER_IMAGE; + } else { + return DataflowPipelineRunner.BATCH_WORKER_HARNESS_CONTAINER_IMAGE; + } + } + } + + /** + * GCE network for launching + * workers. + * + *

    Default is up to the Dataflow service. + */ + @Description("GCE network for launching workers. For more information, see the reference " + + "documentation https://cloud.google.com/compute/docs/networking. " + + "Default is up to the Dataflow service.") + String getNetwork(); + void setNetwork(String value); + + /** + * GCE availability zone for launching workers. + * + *

    Default is up to the Dataflow service. + */ + @Description("GCE availability zone for launching workers. See " + + "https://developers.google.com/compute/docs/zones for a list of valid options. " + + "Default is up to the Dataflow service.") + String getZone(); + void setZone(String value); + + /** + * Machine type to create Dataflow worker VMs as. + * + *

    See GCE machine types + * for a list of valid options. + * + *

    If unset, the Dataflow service will choose a reasonable default. + */ + @Description("Machine type to create Dataflow worker VMs as. See " + + "https://cloud.google.com/compute/docs/machine-types for a list of valid options. " + + "If unset, the Dataflow service will choose a reasonable default.") + String getWorkerMachineType(); + void setWorkerMachineType(String value); + + /** + * The policy for tearing down the workers spun up by the service. + */ + public enum TeardownPolicy { + /** + * All VMs created for a Dataflow job are deleted when the job finishes, regardless of whether + * it fails or succeeds. + */ + TEARDOWN_ALWAYS("TEARDOWN_ALWAYS"), + /** + * All VMs created for a Dataflow job are left running when the job finishes, regardless of + * whether it fails or succeeds. + */ + TEARDOWN_NEVER("TEARDOWN_NEVER"), + /** + * All VMs created for a Dataflow job are deleted when the job succeeds, but are left running + * when it fails. (This is typically used for debugging failing jobs by SSHing into the + * workers.) + */ + TEARDOWN_ON_SUCCESS("TEARDOWN_ON_SUCCESS"); + + private final String teardownPolicy; + + private TeardownPolicy(String teardownPolicy) { + this.teardownPolicy = teardownPolicy; + } + + public String getTeardownPolicyName() { + return this.teardownPolicy; + } + } + + /** + * The teardown policy for the VMs. + * + *

    If unset, the Dataflow service will choose a reasonable default. + */ + @Description("The teardown policy for the VMs. If unset, the Dataflow service will " + + "choose a reasonable default.") + TeardownPolicy getTeardownPolicy(); + void setTeardownPolicy(TeardownPolicy value); + + /** + * List of local files to make available to workers. + * + *

    Files are placed on the worker's classpath. + * + *

    The default value is the list of jars from the main program's classpath. + */ + @Description("Files to stage on GCS and make available to workers. " + + "Files are placed on the worker's classpath. " + + "The default value is all files from the classpath.") + @JsonIgnore + List getFilesToStage(); + void setFilesToStage(List value); + + /** + * Specifies what type of persistent disk should be used. The value should be a full or partial + * URL of a disk type resource, e.g., zones/us-central1-f/disks/pd-standard. For + * more information, see the + * API reference + * documentation for DiskTypes. + */ + @Description("Specifies what type of persistent disk should be used. The value should be a full " + + "or partial URL of a disk type resource, e.g., zones/us-central1-f/disks/pd-standard. For " + + "more information, see the API reference documentation for DiskTypes: " + + "https://cloud.google.com/compute/docs/reference/latest/diskTypes") + String getWorkerDiskType(); + void setWorkerDiskType(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowProfilingOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowProfilingOptions.java new file mode 100644 index 000000000000..8ad2ba2e5e0d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowProfilingOptions.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; + +import java.util.HashMap; + +/** + * Options for controlling profiling of pipeline execution. + */ +@Description("[Experimental] Used to configure profiling of the Dataflow pipeline") +@Experimental +@Hidden +public interface DataflowProfilingOptions { + + @Description("Whether to periodically dump profiling information to local disk.\n" + + "WARNING: Enabling this option may fill local disk with profiling information.") + boolean getEnableProfilingAgent(); + void setEnableProfilingAgent(boolean enabled); + + @Description( + "[INTERNAL] Additional configuration for the profiling agent. Not typically necessary.") + @Hidden + DataflowProfilingAgentConfiguration getProfilingAgentConfiguration(); + void setProfilingAgentConfiguration(DataflowProfilingAgentConfiguration configuration); + + /** + * Configuration the for profiling agent. + */ + public static class DataflowProfilingAgentConfiguration extends HashMap { + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerHarnessOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerHarnessOptions.java new file mode 100644 index 000000000000..e4b1d725701e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerHarnessOptions.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Options that are used exclusively within the Dataflow worker harness. + * These options have no effect at pipeline creation time. + */ +@Description("[Internal] Options that are used exclusively within the Dataflow worker harness. " + + "These options have no effect at pipeline creation time.") +@Hidden +public interface DataflowWorkerHarnessOptions extends DataflowPipelineOptions { + /** + * The identity of the worker running this pipeline. + */ + @Description("The identity of the worker running this pipeline.") + String getWorkerId(); + void setWorkerId(String value); + + /** + * The identity of the Dataflow job. + */ + @Description("The identity of the Dataflow job.") + String getJobId(); + void setJobId(String value); + + /** + * The size of the worker's in-memory cache, in megabytes. + * + *

    Currently, this cache is used for storing read values of side inputs. + */ + @Description("The size of the worker's in-memory cache, in megabytes.") + @Default.Integer(100) + Integer getWorkerCacheMb(); + void setWorkerCacheMb(Integer value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerLoggingOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerLoggingOptions.java new file mode 100644 index 000000000000..232887378489 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerLoggingOptions.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.options; + +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Options that are used to control logging configuration on the Dataflow worker. + */ +@Description("Options that are used to control logging configuration on the Dataflow worker.") +public interface DataflowWorkerLoggingOptions extends PipelineOptions { + /** + * The set of log levels that can be used on the Dataflow worker. + */ + public enum Level { + DEBUG, ERROR, INFO, TRACE, WARN + } + + /** + * This option controls the default log level of all loggers without a log level override. + */ + @Description("Controls the default log level of all loggers without a log level override.") + @Default.Enum("INFO") + Level getDefaultWorkerLogLevel(); + void setDefaultWorkerLogLevel(Level level); + + /** + * This option controls the log levels for specifically named loggers. + * + *

    Later options with equivalent names override earlier options. + * + *

    See {@link WorkerLogLevelOverrides} for more information on how to configure logging + * on a per {@link Class}, {@link Package}, or name basis. If used from the command line, + * the expected format is {"Name":"Level",...}, further details on + * {@link WorkerLogLevelOverrides#from}. + */ + @Description("This option controls the log levels for specifically named loggers. " + + "The expected format is {\"Name\":\"Level\",...}. The Dataflow worker uses " + + "java.util.logging, which supports a logging hierarchy based off of names that are '.' " + + "separated. For example, by specifying the value {\"a.b.c.Foo\":\"DEBUG\"}, the logger " + + "for the class 'a.b.c.Foo' will be configured to output logs at the DEBUG level. " + + "Similarly, by specifying the value {\"a.b.c\":\"WARN\"}, all loggers underneath the " + + "'a.b.c' package will be configured to output logs at the WARN level. Also, note that " + + "when multiple overrides are specified, the exact name followed by the closest parent " + + "takes precedence.") + WorkerLogLevelOverrides getWorkerLogLevelOverrides(); + void setWorkerLogLevelOverrides(WorkerLogLevelOverrides value); + + /** + * Defines a log level override for a specific class, package, or name. + * + *

    {@code java.util.logging} is used on the Dataflow worker harness and supports + * a logging hierarchy based off of names that are "." separated. It is a common + * pattern to have the logger for a given class share the same name as the class itself. + * Given the classes {@code a.b.c.Foo}, {@code a.b.c.Xyz}, and {@code a.b.Bar}, with + * loggers named {@code "a.b.c.Foo"}, {@code "a.b.c.Xyz"}, and {@code "a.b.Bar"} respectively, + * we can override the log levels: + *

      + *
    • for {@code Foo} by specifying the name {@code "a.b.c.Foo"} or the {@link Class} + * representing {@code a.b.c.Foo}. + *
    • for {@code Foo}, {@code Xyz}, and {@code Bar} by specifying the name {@code "a.b"} or + * the {@link Package} representing {@code a.b}. + *
    • for {@code Foo} and {@code Bar} by specifying both of their names or classes. + *
+ * Note that by specifying multiple overrides, the exact name followed by the closest parent + * takes precedence. + */ + public static class WorkerLogLevelOverrides extends HashMap<String, Level> { + /** + * Overrides the default log level for the passed in class. + * + *
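A rough sketch of how these overrides might be assembled and attached to the options (the logger names here are placeholders, not part of this change):

    import com.google.cloud.dataflow.sdk.options.DataflowWorkerLoggingOptions;
    import com.google.cloud.dataflow.sdk.options.DataflowWorkerLoggingOptions.Level;
    import com.google.cloud.dataflow.sdk.options.DataflowWorkerLoggingOptions.WorkerLogLevelOverrides;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

    public class WorkerLoggingExample {
      public static void main(String[] args) {
        DataflowWorkerLoggingOptions options =
            PipelineOptionsFactory.fromArgs(args).as(DataflowWorkerLoggingOptions.class);
        // Quiet everything by default, silence a chatty package, but debug one class in it.
        options.setDefaultWorkerLogLevel(Level.WARN);
        options.setWorkerLogLevelOverrides(new WorkerLogLevelOverrides()
            .addOverrideForName("a.b.c", Level.ERROR)        // closest parent applies to a.b.c.*
            .addOverrideForName("a.b.c.Foo", Level.DEBUG));  // exact name wins over the parent
      }
    }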

    This is equivalent to calling + * {@link #addOverrideForName(String, DataflowWorkerLoggingOptions.Level)} + * and passing in the {@link Class#getName() class name}. + */ + public WorkerLogLevelOverrides addOverrideForClass(Class klass, Level level) { + Preconditions.checkNotNull(klass, "Expected class to be not null."); + addOverrideForName(klass.getName(), level); + return this; + } + + /** + * Overrides the default log level for the passed in package. + * + *

    This is equivalent to calling + * {@link #addOverrideForName(String, DataflowWorkerLoggingOptions.Level)} + * and passing in the {@link Package#getName() package name}. + */ + public WorkerLogLevelOverrides addOverrideForPackage(Package pkg, Level level) { + Preconditions.checkNotNull(pkg, "Expected package to be not null."); + addOverrideForName(pkg.getName(), level); + return this; + } + + /** + * Overrides the default log level for the passed in name. + * + *

    Note that because of the hierarchical nature of logger names, this will + * override the log level of all loggers that have the passed in name or + * a parent logger that has the passed in name. + */ + public WorkerLogLevelOverrides addOverrideForName(String name, Level level) { + Preconditions.checkNotNull(name, "Expected name to be not null."); + Preconditions.checkNotNull(level, + "Expected level to be one of %s.", Arrays.toString(Level.values())); + put(name, level); + return this; + } + + /** + * Expects a map keyed by logger {@code Name}s with values representing {@code Level}s. + * The {@code Name} generally represents the fully qualified Java + * {@link Class#getName() class name}, or fully qualified Java + * {@link Package#getName() package name}, or custom logger name. The {@code Level} + * represents the log level and must be one of {@link Level}. + */ + @JsonCreator + public static WorkerLogLevelOverrides from(Map values) { + Preconditions.checkNotNull(values, "Expected values to be not null."); + WorkerLogLevelOverrides overrides = new WorkerLogLevelOverrides(); + for (Map.Entry entry : values.entrySet()) { + try { + overrides.addOverrideForName(entry.getKey(), Level.valueOf(entry.getValue())); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(String.format( + "Unsupported log level '%s' requested for %s. Must be one of %s.", + entry.getValue(), entry.getKey(), Arrays.toString(Level.values()))); + } + + } + return overrides; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Default.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Default.java new file mode 100644 index 000000000000..46ff682f5b9f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Default.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * {@link Default} represents a set of annotations that can be used to annotate getter properties + * on {@link PipelineOptions} with information representing the default value to be returned + * if no value is specified. + */ +public @interface Default { + /** + * This represents that the default of the option is the specified {@link java.lang.Class} value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Class { + java.lang.Class value(); + } + + /** + * This represents that the default of the option is the specified {@link java.lang.String} + * value. 
+ */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface String { + java.lang.String value(); + } + + /** + * This represents that the default of the option is the specified boolean primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Boolean { + boolean value(); + } + + /** + * This represents that the default of the option is the specified char primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Character { + char value(); + } + + /** + * This represents that the default of the option is the specified byte primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Byte { + byte value(); + } + /** + * This represents that the default of the option is the specified short primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Short { + short value(); + } + /** + * This represents that the default of the option is the specified int primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Integer { + int value(); + } + + /** + * This represents that the default of the option is the specified long primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Long { + long value(); + } + + /** + * This represents that the default of the option is the specified float primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Float { + float value(); + } + + /** + * This represents that the default of the option is the specified double primitive value. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Double { + double value(); + } + + /** + * This represents that the default of the option is the specified enum. + * The value should equal the enum's {@link java.lang.Enum#name() name}. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Enum { + java.lang.String value(); + } + + /** + * Value must be of type {@link DefaultValueFactory} and have a default constructor. + * Value is instantiated and then used as a factory to generate the default. + * + *

    See {@link DefaultValueFactory} for more details. + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface InstanceFactory { + java.lang.Class> value(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DefaultValueFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DefaultValueFactory.java new file mode 100644 index 000000000000..1faedb70d694 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DefaultValueFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * An interface used with the {@link Default.InstanceFactory} annotation to specify the class that + * will be an instance factory to produce default values for a given getter on + * {@link PipelineOptions}. When a property on a {@link PipelineOptions} is fetched, and is + * currently unset, the default value factory will be instantiated and invoked. + * + *
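For illustration only (the option and factory names below are invented, not part of this change), a getter whose default is computed at access time might look like:

    import com.google.cloud.dataflow.sdk.options.Default;
    import com.google.cloud.dataflow.sdk.options.DefaultValueFactory;
    import com.google.cloud.dataflow.sdk.options.PipelineOptions;

    public interface ExampleTempOptions extends PipelineOptions {
      @Default.InstanceFactory(TempDirFactory.class)
      String getTempDir();
      void setTempDir(String value);

      /** Instantiated and invoked only if getTempDir() is called while the property is unset. */
      public static class TempDirFactory implements DefaultValueFactory<String> {
        @Override
        public String create(PipelineOptions options) {
          return System.getProperty("java.io.tmpdir");
        }
      }
    }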

    Care must be taken to not produce an infinite loop when accessing other fields on the + * {@link PipelineOptions} object. + * + * @param The type of object this factory produces. + */ +public interface DefaultValueFactory { + /** + * Creates a default value for a getter marked with {@link Default.InstanceFactory}. + * + * @param options The current pipeline options. + * @return The default value to be used for the annotated getter. + */ + T create(PipelineOptions options); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Description.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Description.java new file mode 100644 index 000000000000..9ceaf586f595 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Description.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Descriptions are used to generate human readable output when the {@code --help} + * command is specified. Description annotations placed on interfaces that extend + * {@link PipelineOptions} will describe groups of related options. Description annotations + * placed on getter methods will be used to provide human readable information + * for the specific option. + */ +@Target({ElementType.METHOD, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface Description { + String value(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DirectPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DirectPipelineOptions.java new file mode 100644 index 000000000000..0867740fabd0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/DirectPipelineOptions.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +/** + * Options that can be used to configure the {@link DirectPipeline}. 
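A minimal sketch of configuring the interface declared just below (the seed value is arbitrary):

    import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

    public class DirectOptionsExample {
      public static void main(String[] args) {
        DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class);
        // Make pseudorandom behaviors reproducible and skip the serializability check.
        options.setDirectPipelineRunnerRandomSeed(42L);
        options.setTestSerializability(false);
      }
    }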
+ */ +public interface DirectPipelineOptions extends + ApplicationNameOptions, BigQueryOptions, GcsOptions, GcpOptions, + PipelineOptions, StreamingOptions { + + /** + * The random seed to use for pseudorandom behaviors in the {@link DirectPipelineRunner}. + * If not explicitly specified, a random seed will be generated. + */ + @JsonIgnore + @Description("The random seed to use for pseudorandom behaviors in the DirectPipelineRunner." + + " If not explicitly specified, a random seed will be generated.") + Long getDirectPipelineRunnerRandomSeed(); + void setDirectPipelineRunnerRandomSeed(Long value); + + /** + * Controls whether the runner should ensure that all of the elements of + * the pipeline, such as DoFns, can be serialized. + */ + @JsonIgnore + @Description("Controls whether the runner should ensure that all of the elements of the " + + "pipeline, such as DoFns, can be serialized.") + @Default.Boolean(true) + boolean isTestSerializability(); + void setTestSerializability(boolean testSerializability); + + /** + * Controls whether the runner should ensure that all of the elements of + * every {@link PCollection} can be encoded using the appropriate + * {@link Coder}. + */ + @JsonIgnore + @Description("Controls whether the runner should ensure that all of the elements of every " + + "PCollection can be encoded using the appropriate Coder.") + @Default.Boolean(true) + boolean isTestEncodability(); + void setTestEncodability(boolean testEncodability); + + /** + * Controls whether the runner should randomize the order of each + * {@link PCollection}. + */ + @JsonIgnore + @Description("Controls whether the runner should randomize the order of each PCollection.") + @Default.Boolean(true) + boolean isTestUnorderedness(); + void setTestUnorderedness(boolean testUnorderedness); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcpOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcpOptions.java new file mode 100644 index 000000000000..7b70f4c31a1d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcpOptions.java @@ -0,0 +1,291 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.googleapis.auth.oauth2.GoogleOAuthConstants; +import com.google.cloud.dataflow.sdk.util.CredentialFactory; +import com.google.cloud.dataflow.sdk.util.GcpCredentialFactory; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Files; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.util.Locale; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Options used to configure Google Cloud Platform project and credentials. + * + *

These options configure which of the following three different mechanisms for obtaining a + * credential is used: + *

      + *
1. + * It can fetch the + * application default credentials. + *
2. + * The user can specify a client secrets file and go through the OAuth2 + * webflow. The credential will then be cached in the user's home + * directory for reuse. + *
3. + * The user can specify a file containing a service account private key along + * with the service account name. + *
    + * + *
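For example, the service-account mechanism might be selected by supplying the two options described further below (every value in this sketch is a placeholder):

    import com.google.cloud.dataflow.sdk.options.GcpOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

    public class GcpCredentialExample {
      public static void main(String[] args) {
        GcpOptions options = PipelineOptionsFactory.fromArgs(new String[] {
            "--project=my-project",
            "--serviceAccountName=my-robot@developer.gserviceaccount.com",
            "--serviceAccountKeyfile=/path/to/key.p12"
        }).as(GcpOptions.class);
        // The credential itself is built lazily the first time getGcpCredential() is called.
        System.out.println(options.getProject());
      }
    }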

    The default mechanism is to use the + * + * application default credentials. The other options can be + * used by setting the corresponding properties. + */ +@Description("Options used to configure Google Cloud Platform project and credentials.") +public interface GcpOptions extends GoogleApiDebugOptions, PipelineOptions { + /** + * Project id to use when launching jobs. + */ + @Description("Project id. Required when running a Dataflow in the cloud. " + + "See https://cloud.google.com/storage/docs/projects for further details.") + @Default.InstanceFactory(DefaultProjectFactory.class) + String getProject(); + void setProject(String value); + + /** + * This option controls which file to use when attempting to create the credentials using the + * service account method. + * + *

This option, if specified, needs to be combined with the + * {@link GcpOptions#getServiceAccountName() serviceAccountName}. + */ + @JsonIgnore + @Description("Controls which file to use when attempting to create the credentials " + + "using the service account method. This option, if specified, needs to be combined with " + + "the serviceAccountName option.") + String getServiceAccountKeyfile(); + void setServiceAccountKeyfile(String value); + + /** + * This option controls which service account to use when attempting to create the credentials + * using the service account method. + * + *

    This option if specified, needs be combined with the + * {@link GcpOptions#getServiceAccountKeyfile() serviceAccountKeyfile}. + */ + @JsonIgnore + @Description("Controls which service account to use when attempting to create the credentials " + + "using the service account method. This option if specified, needs to be combined with " + + "the serviceAccountKeyfile option.") + String getServiceAccountName(); + void setServiceAccountName(String value); + + /** + * This option controls which file to use when attempting to create the credentials + * using the OAuth 2 webflow. After the OAuth2 webflow, the credentials will be stored + * within credentialDir. + */ + @JsonIgnore + @Description("This option controls which file to use when attempting to create the credentials " + + "using the OAuth 2 webflow. After the OAuth2 webflow, the credentials will be stored " + + "within credentialDir.") + String getSecretsFile(); + void setSecretsFile(String value); + + /** + * This option controls which credential store to use when creating the credentials + * using the OAuth 2 webflow. + */ + @Description("This option controls which credential store to use when creating the credentials " + + "using the OAuth 2 webflow.") + @Default.String("cloud_dataflow") + String getCredentialId(); + void setCredentialId(String value); + + /** + * Directory for storing dataflow credentials after execution of the OAuth 2 webflow. Defaults + * to using the $HOME/.store/data-flow directory. + */ + @Description("Directory for storing dataflow credentials after execution of the OAuth 2 webflow. " + + "Defaults to using the $HOME/.store/data-flow directory.") + @Default.InstanceFactory(CredentialDirFactory.class) + String getCredentialDir(); + void setCredentialDir(String value); + + /** + * Returns the default credential directory of ${user.home}/.store/data-flow. + */ + public static class CredentialDirFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + File home = new File(System.getProperty("user.home")); + File store = new File(home, ".store"); + File dataflow = new File(store, "data-flow"); + return dataflow.getPath(); + } + } + + /** + * The class of the credential factory that should be created and used to create + * credentials. If gcpCredential has not been set explicitly, an instance of this class will + * be constructed and used as a credential factory. + */ + @Description("The class of the credential factory that should be created and used to create " + + "credentials. If gcpCredential has not been set explicitly, an instance of this class will " + + "be constructed and used as a credential factory.") + @Default.Class(GcpCredentialFactory.class) + Class getCredentialFactoryClass(); + void setCredentialFactoryClass( + Class credentialFactoryClass); + + /** + * The credential instance that should be used to authenticate against GCP services. + * If no credential has been set explicitly, the default is to use the instance factory + * that constructs a credential based upon the currently set credentialFactoryClass. + */ + @JsonIgnore + @Description("The credential instance that should be used to authenticate against GCP services. 
" + + "If no credential has been set explicitly, the default is to use the instance factory " + + "that constructs a credential based upon the currently set credentialFactoryClass.") + @Default.InstanceFactory(GcpUserCredentialsFactory.class) + @Hidden + Credential getGcpCredential(); + void setGcpCredential(Credential value); + + /** + * Attempts to infer the default project based upon the environment this application + * is executing within. Currently this only supports getting the default project from gcloud. + */ + public static class DefaultProjectFactory implements DefaultValueFactory { + private static final Logger LOG = LoggerFactory.getLogger(DefaultProjectFactory.class); + + @Override + public String create(PipelineOptions options) { + try { + File configFile; + if (getEnvironment().containsKey("CLOUDSDK_CONFIG")) { + configFile = new File(getEnvironment().get("CLOUDSDK_CONFIG"), "properties"); + } else if (isWindows() && getEnvironment().containsKey("APPDATA")) { + configFile = new File(getEnvironment().get("APPDATA"), "gcloud/properties"); + } else { + // New versions of gcloud use this file + configFile = new File( + System.getProperty("user.home"), + ".config/gcloud/configurations/config_default"); + if (!configFile.exists()) { + // Old versions of gcloud use this file + configFile = new File(System.getProperty("user.home"), ".config/gcloud/properties"); + } + } + String section = null; + Pattern projectPattern = Pattern.compile("^project\\s*=\\s*(.*)$"); + Pattern sectionPattern = Pattern.compile("^\\[(.*)\\]$"); + for (String line : Files.readLines(configFile, StandardCharsets.UTF_8)) { + line = line.trim(); + if (line.isEmpty() || line.startsWith(";")) { + continue; + } + Matcher matcher = sectionPattern.matcher(line); + if (matcher.matches()) { + section = matcher.group(1); + } else if (section == null || section.equals("core")) { + matcher = projectPattern.matcher(line); + if (matcher.matches()) { + String project = matcher.group(1).trim(); + LOG.info("Inferred default GCP project '{}' from gcloud. If this is the incorrect " + + "project, please cancel this Pipeline and specify the command-line " + + "argument --project.", project); + return project; + } + } + } + } catch (IOException expected) { + LOG.debug("Failed to find default project.", expected); + } + // return null if can't determine + return null; + } + + /** + * Returns true if running on the Windows OS. + */ + private static boolean isWindows() { + return System.getProperty("os.name").toLowerCase(Locale.ENGLISH).contains("windows"); + } + + /** + * Used to mock out getting environment variables. + */ + @VisibleForTesting + Map getEnvironment() { + return System.getenv(); + } + } + + /** + * Attempts to load the GCP credentials. See + * {@link CredentialFactory#getCredential()} for more details. + */ + public static class GcpUserCredentialsFactory implements DefaultValueFactory { + @Override + public Credential create(PipelineOptions options) { + GcpOptions gcpOptions = options.as(GcpOptions.class); + try { + CredentialFactory factory = InstanceBuilder.ofType(CredentialFactory.class) + .fromClass(gcpOptions.getCredentialFactoryClass()) + .fromFactoryMethod("fromOptions") + .withArg(PipelineOptions.class, options) + .build(); + return factory.getCredential(); + } catch (IOException | GeneralSecurityException e) { + throw new RuntimeException("Unable to obtain credential", e); + } + } + } + + /** + * The token server URL to use for OAuth 2 authentication. 
Normally, the default is sufficient, + * but some specialized use cases may want to override this value. + */ + @Description("The token server URL to use for OAuth 2 authentication. Normally, the default " + + "is sufficient, but some specialized use cases may want to override this value.") + @Default.String(GoogleOAuthConstants.TOKEN_SERVER_URL) + @Hidden + String getTokenServerUrl(); + void setTokenServerUrl(String value); + + /** + * The authorization server URL to use for OAuth 2 authentication. Normally, the default is + * sufficient, but some specialized use cases may want to override this value. + */ + @Description("The authorization server URL to use for OAuth 2 authentication. Normally, the " + + "default is sufficient, but some specialized use cases may want to override this value.") + @Default.String(GoogleOAuthConstants.AUTHORIZATION_SERVER_URL) + @Hidden + String getAuthorizationServerEncodedUrl(); + void setAuthorizationServerEncodedUrl(String value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcsOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcsOptions.java new file mode 100644 index 000000000000..d2218075bd6e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GcsOptions.java @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.util.AppEngineEnvironment; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.hadoop.util.AbstractGoogleAsyncWriteChannel; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * Options used to configure Google Cloud Storage. + */ +public interface GcsOptions extends + ApplicationNameOptions, GcpOptions, PipelineOptions { + /** + * The GcsUtil instance that should be used to communicate with Google Cloud Storage. + */ + @JsonIgnore + @Description("The GcsUtil instance that should be used to communicate with Google Cloud Storage.") + @Default.InstanceFactory(GcsUtil.GcsUtilFactory.class) + @Hidden + GcsUtil getGcsUtil(); + void setGcsUtil(GcsUtil value); + + /** + * The ExecutorService instance to use to create threads, can be overridden to specify an + * ExecutorService that is compatible with the users environment. If unset, the + * default is to create an ExecutorService with an unbounded number of threads; this + * is compatible with Google AppEngine. + */ + @JsonIgnore + @Description("The ExecutorService instance to use to create multiple threads. Can be overridden " + + "to specify an ExecutorService that is compatible with the users environment. 
If unset, " + + "the default is to create an ExecutorService with an unbounded number of threads; this " + + "is compatible with Google AppEngine.") + @Default.InstanceFactory(ExecutorServiceFactory.class) + @Hidden + ExecutorService getExecutorService(); + void setExecutorService(ExecutorService value); + + /** + * GCS endpoint to use. If unspecified, uses the default endpoint. + */ + @JsonIgnore + @Hidden + @Description("The URL for the GCS API.") + String getGcsEndpoint(); + void setGcsEndpoint(String value); + + /** + * The buffer size (in bytes) to use when uploading files to GCS. Please see the documentation for + * {@link AbstractGoogleAsyncWriteChannel#setUploadBufferSize} for more information on the + * restrictions and performance implications of this value. + */ + @Description("The buffer size (in bytes) to use when uploading files to GCS. Please see the " + + "documentation for AbstractGoogleAsyncWriteChannel.setUploadBufferSize for more " + + "information on the restrictions and performance implications of this value.\n\n" + + "https://github.com/GoogleCloudPlatform/bigdata-interop/blob/master/util/src/main/java/" + + "com/google/cloud/hadoop/util/AbstractGoogleAsyncWriteChannel.java") + Integer getGcsUploadBufferSizeBytes(); + void setGcsUploadBufferSizeBytes(Integer bytes); + + /** + * Returns the default {@link ExecutorService} to use within the Dataflow SDK. The + * {@link ExecutorService} is compatible with AppEngine. + */ + public static class ExecutorServiceFactory implements DefaultValueFactory { + @SuppressWarnings("deprecation") // IS_APP_ENGINE is deprecated for internal use only. + @Override + public ExecutorService create(PipelineOptions options) { + ThreadFactoryBuilder threadFactoryBuilder = new ThreadFactoryBuilder(); + threadFactoryBuilder.setThreadFactory(MoreExecutors.platformThreadFactory()); + if (!AppEngineEnvironment.IS_APP_ENGINE) { + // AppEngine doesn't allow modification of threads to be daemon threads. + threadFactoryBuilder.setDaemon(true); + } + /* The SDK requires an unbounded thread pool because a step may create X writers + * each requiring their own thread to perform the writes otherwise a writer may + * block causing deadlock for the step because the writers buffer is full. + * Also, the MapTaskExecutor launches the steps in reverse order and completes + * them in forward order thus requiring enough threads so that each step's writers + * can be active. + */ + return new ThreadPoolExecutor( + 0, Integer.MAX_VALUE, // Allow an unlimited number of re-usable threads. + Long.MAX_VALUE, TimeUnit.NANOSECONDS, // Keep non-core threads alive forever. + new SynchronousQueue(), + threadFactoryBuilder.build()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GoogleApiDebugOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GoogleApiDebugOptions.java new file mode 100644 index 000000000000..eff679b405de --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/GoogleApiDebugOptions.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.api.client.googleapis.services.AbstractGoogleClient; +import com.google.api.client.googleapis.services.AbstractGoogleClientRequest; +import com.google.api.client.googleapis.services.GoogleClientRequestInitializer; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * These options configure debug settings for Google API clients created within the Dataflow SDK. + */ +public interface GoogleApiDebugOptions extends PipelineOptions { + /** + * This option enables tracing of API calls to Google services used within the + * Dataflow SDK. Values are expected in JSON format {"ApiName":"TraceDestination",...} + * where the {@code ApiName} represents the request classes canonical name. The + * {@code TraceDestination} is a logical trace consumer to whom the trace will be reported. + * Typically, "producer" is the right destination to use: this makes API traces available to the + * team offering the API. Note that by enabling this option, the contents of the requests to and + * from Google Cloud services will be made available to Google. For example, by specifying + * {"Dataflow":"producer"}, all calls to the Dataflow service will be made available + * to Google, specifically to the Google Cloud Dataflow team. + */ + @Description("This option enables tracing of API calls to Google services used within the " + + "Dataflow SDK. Values are expected in JSON format {\"ApiName\":\"TraceDestination\",...} " + + "where the ApiName represents the request classes canonical name. The TraceDestination is " + + "a logical trace consumer to whom the trace will be reported. Typically, \"producer\" is " + + "the right destination to use: this makes API traces available to the team offering the " + + "API. Note that by enabling this option, the contents of the requests to and from " + + "Google Cloud services will be made available to Google. For example, by specifying " + + "{\"Dataflow\":\"producer\"}, all calls to the Dataflow service will be made available to " + + "Google, specifically to the Google Cloud Dataflow team.") + GoogleApiTracer getGoogleApiTrace(); + void setGoogleApiTrace(GoogleApiTracer commands); + + /** + * A {@link GoogleClientRequestInitializer} that adds the trace destination to Google API calls. + */ + public static class GoogleApiTracer extends HashMap + implements GoogleClientRequestInitializer { + /** + * Creates a {@link GoogleApiTracer} that sets the trace destination on all + * calls that match the given client type. + */ + public GoogleApiTracer addTraceFor(AbstractGoogleClient client, String traceDestination) { + put(client.getClass().getCanonicalName(), traceDestination); + return this; + } + + /** + * Creates a {@link GoogleApiTracer} that sets the trace {@code traceDestination} on all + * calls that match for the given request type. 
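A sketch of enabling tracing from the command line in the JSON form described above ("producer" is the conventional destination named in the option's description):

    import com.google.cloud.dataflow.sdk.options.GoogleApiDebugOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

    public class ApiTraceExample {
      public static void main(String[] args) {
        // Traces calls from any Google client whose request class name contains "Dataflow".
        GoogleApiDebugOptions options = PipelineOptionsFactory.fromArgs(new String[] {
            "--googleApiTrace={\"Dataflow\":\"producer\"}"
        }).as(GoogleApiDebugOptions.class);
        System.out.println(options.getGoogleApiTrace());
      }
    }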
+ */ + public GoogleApiTracer addTraceFor( + AbstractGoogleClientRequest request, String traceDestination) { + put(request.getClass().getCanonicalName(), traceDestination); + return this; + } + + @Override + public void initialize(AbstractGoogleClientRequest request) throws IOException { + for (Map.Entry entry : this.entrySet()) { + if (request.getClass().getCanonicalName().contains(entry.getKey())) { + request.set("$trace", entry.getValue()); + } + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Hidden.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Hidden.java new file mode 100644 index 000000000000..6a487eb2f548 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Hidden.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Methods and/or interfaces annotated with {@code @Hidden} will be suppressed from + * being output when {@code --help} is specified on the command-line. + */ +@Target({ElementType.METHOD, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface Hidden { +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java new file mode 100644 index 000000000000..923033d5dadb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java @@ -0,0 +1,248 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.auto.service.AutoService; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.GoogleApiDebugOptions.GoogleApiTracer; +import com.google.cloud.dataflow.sdk.options.ProxyInvocationHandler.Deserializer; +import com.google.cloud.dataflow.sdk.options.ProxyInvocationHandler.Serializer; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.Context; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import java.lang.reflect.Proxy; +import java.util.ServiceLoader; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * PipelineOptions are used to configure Pipelines. You can extend {@link PipelineOptions} + * to create custom configuration options specific to your {@link Pipeline}, + * for both local execution and execution via a {@link PipelineRunner}. + * + *

    {@link PipelineOptions} and their subinterfaces represent a collection of properties + * which can be manipulated in a type safe manner. {@link PipelineOptions} is backed by a + * dynamic {@link Proxy} which allows for type safe manipulation of properties in an extensible + * fashion through plain old Java interfaces. + * + *

    {@link PipelineOptions} can be created with {@link PipelineOptionsFactory#create()} + * and {@link PipelineOptionsFactory#as(Class)}. They can be created + * from command-line arguments with {@link PipelineOptionsFactory#fromArgs(String[])}. + * They can be converted to another type by invoking {@link PipelineOptions#as(Class)} and + * can be accessed from within a {@link DoFn} by invoking + * {@link Context#getPipelineOptions()}. + * + *

    For example: + *

    {@code
    + * // The most common way to construct PipelineOptions is via command-line argument parsing:
    + * public static void main(String[] args) {
    + *   // Will parse the arguments passed into the application and construct a PipelineOptions
    + *   // Note that --help will print registered options, and --help=PipelineOptionsClassName
    + *   // will print out usage for the specific class.
    + *   PipelineOptions options =
    + *       PipelineOptionsFactory.fromArgs(args).create();
    + *
    + *   Pipeline p = Pipeline.create(options);
    + *   ...
    + *   p.run();
    + * }
    + *
    + * // To create options for the DirectPipeline:
    + * DirectPipelineOptions directPipelineOptions =
    + *     PipelineOptionsFactory.as(DirectPipelineOptions.class);
    + * directPipelineOptions.setStreaming(true);
    + *
    + * // To cast from one type to another using the as(Class) method:
    + * DataflowPipelineOptions dataflowPipelineOptions =
    + *     directPipelineOptions.as(DataflowPipelineOptions.class);
    + *
    + * // Options for the same property are shared between types
    + * // The statement below will print out "true"
    + * System.out.println(dataflowPipelineOptions.isStreaming());
    + *
    + * // Prints out registered options.
    + * PipelineOptionsFactory.printHelp(System.out);
    + *
    + * // Prints out options which are available to be set on DataflowPipelineOptions
    + * PipelineOptionsFactory.printHelp(System.out, DataflowPipelineOptions.class);
    + * }
    + * + *

    Defining Your Own PipelineOptions

    + * + * Defining your own {@link PipelineOptions} is the way for you to make configuration + * options available for both local execution and execution via a {@link PipelineRunner}. + * By having PipelineOptionsFactory as your command-line interpreter, you will provide + * a standardized way for users to interact with your application via the command-line. + * + *

    To define your own {@link PipelineOptions}, you create an interface which + * extends {@link PipelineOptions} and define getter/setter pairs. These + * getter/setter pairs define a collection of + * + * JavaBean properties. + * + *

    For example: + *

    {@code
    + *  // Creates a user defined property called "myProperty"
    + *  public interface MyOptions extends PipelineOptions {
    + *    String getMyProperty();
    + *    void setMyProperty(String value);
    + *  }
    + * }
    + * + *

    Note: Please see the section on Registration below when using custom property types. + * + *

    Restrictions

    + * + * Since PipelineOptions can be "cast" to multiple types dynamically using + * {@link PipelineOptions#as(Class)}, a property must conform to the following set of restrictions: + *
      + *
    • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
    • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
    • Every method must conform to being a getter or setter for a JavaBean. + *
• The derived interface of {@link PipelineOptions} must be composable with every interface + * registered with the PipelineOptionsFactory. + *
    • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. + *
    • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for + * this property must be annotated with {@link JsonIgnore @JsonIgnore}. + *
    + * + *

    Annotations For PipelineOptions

    + * + * {@link Description @Description} can be used to annotate an interface or a getter + * with useful information which is output when {@code --help} + * is invoked via {@link PipelineOptionsFactory#fromArgs(String[])}. + * + *

    {@link Default @Default} represents a set of annotations that can be used to annotate getter + * properties on {@link PipelineOptions} with information representing the default value to be + * returned if no value is specified. + * + *

    {@link Hidden @Hidden} hides an option from being listed when {@code --help} + * is invoked via {@link PipelineOptionsFactory#fromArgs(String[])}. + * + *

    {@link Validation @Validation} represents a set of annotations that can be used to annotate + * getter properties on {@link PipelineOptions} with information representing the validation + * criteria to be used when validating with the {@link PipelineOptionsValidator}. Validation + * will be performed if during construction of the {@link PipelineOptions}, + * {@link PipelineOptionsFactory#withValidation()} is invoked. + * + *

    {@link JsonIgnore @JsonIgnore} is used to prevent a property from being serialized and + * available during execution of {@link DoFn}. See the Serialization section below for more + * details. + * + *
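Putting these annotations together, a hypothetical options interface (all names invented for illustration) might read:

    import com.google.cloud.dataflow.sdk.options.Default;
    import com.google.cloud.dataflow.sdk.options.Description;
    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import com.google.cloud.dataflow.sdk.options.Validation;

    public interface ExampleJobOptions extends PipelineOptions {
      @Description("Number of output shards to produce.")
      @Default.Integer(10)
      Integer getNumShards();
      void setNumShards(Integer value);

      @Description("Output location; validation fails if it is left unset.")
      @Validation.Required
      String getOutputLocation();
      void setOutputLocation(String value);
    }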

    Registration Of PipelineOptions

+ * + * Registration of {@link PipelineOptions} by an application guarantees that the + * {@link PipelineOptions} is composable during execution of their {@link Pipeline} and + * meets the restrictions listed above; otherwise registration will fail. Registration + * also lists the registered {@link PipelineOptions} when {@code --help} + * is invoked via {@link PipelineOptionsFactory#fromArgs(String[])}. + * + *

Registration can be performed by invoking {@link PipelineOptionsFactory#register} within + * a user's application or via automatic registration by creating a {@link ServiceLoader} entry + * and a concrete implementation of the {@link PipelineOptionsRegistrar} interface. + * + *
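A sketch of the explicit form, reusing the hypothetical ExampleJobOptions interface from the earlier sketch (assumed to be in the same package):

    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

    public class RegistrationExample {
      public static void main(String[] args) {
        // Register before parsing so that --help and validation know about the custom interface.
        PipelineOptionsFactory.register(ExampleJobOptions.class);
        ExampleJobOptions options =
            PipelineOptionsFactory.fromArgs(args).withValidation().as(ExampleJobOptions.class);
        System.out.println(options.getOutputLocation());
      }
    }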

    It is optional but recommended to use one of the many build time tools such as + * {@link AutoService} to generate the necessary META-INF files automatically. + * + *

    A list of registered options can be fetched from + * {@link PipelineOptionsFactory#getRegisteredOptions()}. + * + *

    Serialization Of PipelineOptions

+ * + * {@link PipelineRunner}s require support for options to be serialized. Each property + * within {@link PipelineOptions} must either be serializable using Jackson's + * {@link ObjectMapper} or have its getter method annotated with + * {@link JsonIgnore @JsonIgnore}. + * + *

    Jackson supports serialization of many types and supports a useful set of + * annotations to aid in + * serialization of custom types. We point you to the public + * Jackson documentation when attempting + * to add serialization support for your custom types. See {@link GoogleApiTracer} for an + * example using the Jackson annotations to serialize and deserialize a custom type. + * + *
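As a further illustration (this type is not part of the SDK), a small custom property type can be made Jackson-serializable with a delegating @JsonCreator plus @JsonValue, so that its getter does not need @JsonIgnore:

    import com.fasterxml.jackson.annotation.JsonCreator;
    import com.fasterxml.jackson.annotation.JsonValue;

    /** Round-trips to and from a single JSON string such as "example.com:443". */
    public class EndpointAddress {
      private final String value;

      @JsonCreator
      public EndpointAddress(String value) {
        this.value = value;
      }

      @JsonValue
      public String asString() {
        return value;
      }
    }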

Note: It is an error to have the same property available in multiple interfaces with only + * some of them being annotated with {@link JsonIgnore @JsonIgnore}. It is also an error to mark a + * setter for a property with {@link JsonIgnore @JsonIgnore}. + */ +@JsonSerialize(using = Serializer.class) +@JsonDeserialize(using = Deserializer.class) +@ThreadSafe +public interface PipelineOptions { + /** + * Transforms this object into an object of type {@code <T>} saving each property + * that has been manipulated. {@code <T>} must extend {@link PipelineOptions}. + * + *

If {@code <T>} is not registered with the {@link PipelineOptionsFactory}, then we + * attempt to verify that {@code <T>} is composable with every interface that this + * instance of the {@code PipelineOptions} has seen. + * + * @param kls The class of the type to transform to. + * @return An object of type kls. + */ + <T extends PipelineOptions> T as(Class<T> kls); + + /** + * Makes a deep clone of this object, and transforms the cloned object into the specified + * type {@code kls}. See {@link #as} for more information about the conversion. + * + *

    Properties that are marked with {@code @JsonIgnore} will not be cloned. + */ + T cloneAs(Class kls); + + /** + * The pipeline runner that will be used to execute the pipeline. + * For registered runners, the class name can be specified, otherwise the fully + * qualified name needs to be specified. + */ + @Validation.Required + @Description("The pipeline runner that will be used to execute the pipeline. " + + "For registered runners, the class name can be specified, otherwise the fully " + + "qualified name needs to be specified.") + @Default.Class(DirectPipelineRunner.class) + Class> getRunner(); + void setRunner(Class> kls); + + /** + * Enumeration of the possible states for a given check. + */ + public static enum CheckEnabled { + OFF, + WARNING, + ERROR; + } + + /** + * Whether to check for stable unique names on each transform. This is necessary to + * support updating of pipelines. + */ + @Validation.Required + @Description("Whether to check for stable unique names on each transform. This is necessary to " + + "support updating of pipelines.") + @Default.Enum("WARNING") + CheckEnabled getStableUniqueNames(); + void setStableUniqueNames(CheckEnabled enabled); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java new file mode 100644 index 000000000000..e77b89f9a4ec --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java @@ -0,0 +1,1497 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.options.Validation.Required; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.util.common.ReflectHelpers; +import com.google.common.base.Function; +import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.base.Strings; +import com.google.common.base.Throwables; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Collections2; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.collect.SortedSetMultimap; +import com.google.common.collect.TreeMultimap; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.beans.BeanInfo; +import java.beans.IntrospectionException; +import java.beans.Introspector; +import java.beans.PropertyDescriptor; +import java.io.IOException; +import java.io.PrintStream; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Proxy; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.ServiceLoader; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeSet; + +import javax.annotation.Nullable; + +/** + * Constructs a {@link PipelineOptions} or any derived interface that is composable to any other + * derived interface of {@link PipelineOptions} via the {@link PipelineOptions#as} method. Being + * able to compose one derived interface of {@link PipelineOptions} to another has the following + * restrictions: + *

      + *
    • Any property with the same name must have the same return type for all derived interfaces + * of {@link PipelineOptions}. + *
    • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
    • Every method must conform to being a getter or setter for a JavaBean. + *
    • The derived interface of {@link PipelineOptions} must be composable with every interface + * registered with this factory. + *
    + * + *

    See the JavaBeans + * specification for more details as to what constitutes a property. + */ +public class PipelineOptionsFactory { + /** + * Creates and returns an object that implements {@link PipelineOptions}. + * This sets the {@link ApplicationNameOptions#getAppName() "appName"} to the calling + * {@link Class#getSimpleName() classes simple name}. + * + * @return An object that implements {@link PipelineOptions}. + */ + public static PipelineOptions create() { + return new Builder().as(PipelineOptions.class); + } + + /** + * Creates and returns an object that implements {@code }. + * This sets the {@link ApplicationNameOptions#getAppName() "appName"} to the calling + * {@link Class#getSimpleName() classes simple name}. + * + *

Note that {@code <T>} must be composable with every registered interface with this factory. + * See {@link PipelineOptionsFactory#validateWellFormed(Class, Set)} for more details. + * + * @return An object that implements {@code <T>}. + */ + public static <T extends PipelineOptions> T as(Class<T> klass) { + return new Builder().as(klass); + } + + /** + * Sets the command line arguments to parse when constructing the {@link PipelineOptions}. + * + *

    Example GNU style command line arguments: + *

    +   *   --project=MyProject (simple property, will set the "project" property to "MyProject")
    +   *   --readOnly=true (for boolean properties, will set the "readOnly" property to "true")
    +   *   --readOnly (shorthand for boolean properties, will set the "readOnly" property to "true")
    +   *   --x=1 --x=2 --x=3 (list style simple property, will set the "x" property to [1, 2, 3])
    +   *   --x=1,2,3 (shorthand list style simple property, will set the "x" property to [1, 2, 3])
    +   *   --complexObject='{"key1":"value1",...} (JSON format for all other complex types)
    +   * 
    + * + *

Simple properties are able to be bound to {@link String}, {@link Class}, enums and Java + * primitives {@code boolean}, {@code byte}, {@code short}, {@code int}, {@code long}, + * {@code float}, {@code double} and their primitive wrapper classes. + * + *

    Simple list style properties are able to be bound to {@code boolean[]}, {@code char[]}, + * {@code short[]}, {@code int[]}, {@code long[]}, {@code float[]}, {@code double[]}, + * {@code Class[]}, enum arrays, {@code String[]}, and {@code List}. + * + *
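For instance, the list-style flags shown earlier might be consumed by a hypothetical options interface like this (a sketch, not part of this change):

    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

    import java.util.Arrays;

    public class ParsingExample {
      public interface FlagOptions extends PipelineOptions {
        boolean isReadOnly();
        void setReadOnly(boolean value);

        int[] getX();
        void setX(int[] value);
      }

      public static void main(String[] args) {
        FlagOptions options = PipelineOptionsFactory.fromArgs(new String[] {
            "--readOnly", "--x=1,2,3"
        }).as(FlagOptions.class);
        System.out.println(options.isReadOnly() + " " + Arrays.toString(options.getX()));
      }
    }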

    JSON format is required for all other types. + * + *

By default, strict parsing is enabled and arguments must take the form + * {@code --booleanArgName} or {@code --argName=argValue}. Strict parsing can be disabled with + * {@link Builder#withoutStrictParsing()}. Empty or null arguments will be ignored whether + * or not strict parsing is enabled. + * + *

    Help information can be output to {@link System#out} by specifying {@code --help} as an + * argument. After help is printed, the application will exit. Specifying only {@code --help} + * will print out the list of + * {@link PipelineOptionsFactory#getRegisteredOptions() registered options} + * by invoking {@link PipelineOptionsFactory#printHelp(PrintStream)}. Specifying + * {@code --help=PipelineOptionsClassName} will print out detailed usage information about the + * specifically requested PipelineOptions by invoking + * {@link PipelineOptionsFactory#printHelp(PrintStream, Class)}. + */ + public static Builder fromArgs(String[] args) { + return new Builder().fromArgs(args); + } + + /** + * After creation we will validate that {@code } conforms to all the + * validation criteria. See + * {@link PipelineOptionsValidator#validate(Class, PipelineOptions)} for more details about + * validation. + */ + public Builder withValidation() { + return new Builder().withValidation(); + } + + /** A fluent {@link PipelineOptions} builder. */ + public static class Builder { + private final String defaultAppName; + private final String[] args; + private final boolean validation; + private final boolean strictParsing; + + // Do not allow direct instantiation + private Builder() { + this(null, false, true); + } + + private Builder(String[] args, boolean validation, + boolean strictParsing) { + this.defaultAppName = findCallersClassName(); + this.args = args; + this.validation = validation; + this.strictParsing = strictParsing; + } + + /** + * Sets the command line arguments to parse when constructing the {@link PipelineOptions}. + * + *

    Example GNU style command line arguments: + *

    +     *   --project=MyProject (simple property, will set the "project" property to "MyProject")
    +     *   --readOnly=true (for boolean properties, will set the "readOnly" property to "true")
    +     *   --readOnly (shorthand for boolean properties, will set the "readOnly" property to "true")
    +     *   --x=1 --x=2 --x=3 (list style simple property, will set the "x" property to [1, 2, 3])
    +     *   --x=1,2,3 (shorthand list style simple property, will set the "x" property to [1, 2, 3])
    +     *   --complexObject='{"key1":"value1",...} (JSON format for all other complex types)
    +     * 
    + * + *

Simple properties are able to be bound to {@link String}, {@link Class}, enums and Java + * primitives {@code boolean}, {@code byte}, {@code short}, {@code int}, {@code long}, + * {@code float}, {@code double} and their primitive wrapper classes. + * + *

Simple list style properties are able to be bound to {@code boolean[]}, {@code char[]}, + * {@code short[]}, {@code int[]}, {@code long[]}, {@code float[]}, {@code double[]}, + * {@code Class[]}, enum arrays, {@code String[]}, and {@code List<String>}. + * + *

    JSON format is required for all other types. + * + *

By default, strict parsing is enabled and arguments must conform to either + * {@code --booleanArgName} or {@code --argName=argValue}. Strict parsing can be disabled with + * {@link Builder#withoutStrictParsing()}. Empty or null arguments will be ignored whether + * or not strict parsing is enabled. + * + *

    Help information can be output to {@link System#out} by specifying {@code --help} as an + * argument. After help is printed, the application will exit. Specifying only {@code --help} + * will print out the list of + * {@link PipelineOptionsFactory#getRegisteredOptions() registered options} + * by invoking {@link PipelineOptionsFactory#printHelp(PrintStream)}. Specifying + * {@code --help=PipelineOptionsClassName} will print out detailed usage information about the + * specifically requested PipelineOptions by invoking + * {@link PipelineOptionsFactory#printHelp(PrintStream, Class)}. + */ + public Builder fromArgs(String[] args) { + Preconditions.checkNotNull(args, "Arguments should not be null."); + return new Builder(args, validation, strictParsing); + } + + /** + * After creation we will validate that {@link PipelineOptions} conforms to all the + * validation criteria from {@code }. See + * {@link PipelineOptionsValidator#validate(Class, PipelineOptions)} for more details about + * validation. + */ + public Builder withValidation() { + return new Builder(args, true, strictParsing); + } + + /** + * During parsing of the arguments, we will skip over improperly formatted and unknown + * arguments. + */ + public Builder withoutStrictParsing() { + return new Builder(args, validation, false); + } + + /** + * Creates and returns an object that implements {@link PipelineOptions} using the values + * configured on this builder during construction. + * + * @return An object that implements {@link PipelineOptions}. + */ + public PipelineOptions create() { + return as(PipelineOptions.class); + } + + /** + * Creates and returns an object that implements {@code } using the values configured on + * this builder during construction. + * + *

    Note that {@code } must be composable with every registered interface with this + * factory. See {@link PipelineOptionsFactory#validateWellFormed(Class, Set)} for more + * details. + * + * @return An object that implements {@code }. + */ + public T as(Class klass) { + Map initialOptions = Maps.newHashMap(); + + // Attempt to parse the arguments into the set of initial options to use + if (args != null) { + ListMultimap options = parseCommandLine(args, strictParsing); + LOG.debug("Provided Arguments: {}", options); + printHelpUsageAndExitIfNeeded(options, System.out, true /* exit */); + initialOptions = parseObjects(klass, options, strictParsing); + } + + // Create our proxy + ProxyInvocationHandler handler = new ProxyInvocationHandler(initialOptions); + T t = handler.as(klass); + + // Set the application name to the default if none was set. + ApplicationNameOptions appNameOptions = t.as(ApplicationNameOptions.class); + if (appNameOptions.getAppName() == null) { + appNameOptions.setAppName(defaultAppName); + } + + if (validation) { + PipelineOptionsValidator.validate(klass, t); + } + return t; + } + } + + /** + * Determines whether the generic {@code --help} was requested or help was + * requested for a specific class and invokes the appropriate + * {@link PipelineOptionsFactory#printHelp(PrintStream)} and + * {@link PipelineOptionsFactory#printHelp(PrintStream, Class)} variant. + * Prints to the specified {@link PrintStream}, and exits if requested. + * + *

    Visible for testing. + * {@code printStream} and {@code exit} used for testing. + */ + @SuppressWarnings("unchecked") + static boolean printHelpUsageAndExitIfNeeded(ListMultimap options, + PrintStream printStream, boolean exit) { + if (options.containsKey("help")) { + final String helpOption = Iterables.getOnlyElement(options.get("help")); + + // Print the generic help if only --help was specified. + if (Boolean.TRUE.toString().equals(helpOption)) { + printHelp(printStream); + if (exit) { + System.exit(0); + } else { + return true; + } + } + + // Otherwise attempt to print the specific help option. + try { + Class klass = Class.forName(helpOption); + if (!PipelineOptions.class.isAssignableFrom(klass)) { + throw new ClassNotFoundException("PipelineOptions of type " + klass + " not found."); + } + printHelp(printStream, (Class) klass); + } catch (ClassNotFoundException e) { + // If we didn't find an exact match, look for any that match the class name. + Iterable> matches = Iterables.filter( + getRegisteredOptions(), + new Predicate>() { + @Override + public boolean apply(Class input) { + if (helpOption.contains(".")) { + return input.getName().endsWith(helpOption); + } else { + return input.getSimpleName().equals(helpOption); + } + } + }); + try { + printHelp(printStream, Iterables.getOnlyElement(matches)); + } catch (NoSuchElementException exception) { + printStream.format("Unable to find option %s.%n", helpOption); + printHelp(printStream); + } catch (IllegalArgumentException exception) { + printStream.format("Multiple matches found for %s: %s.%n", helpOption, + Iterables.transform(matches, ReflectHelpers.CLASS_NAME)); + printHelp(printStream); + } + } + if (exit) { + System.exit(0); + } else { + return true; + } + } + return false; + } + + /** + * Returns the simple name of the calling class using the current threads stack. + */ + private static String findCallersClassName() { + Iterator elements = + Iterators.forArray(Thread.currentThread().getStackTrace()); + // First find the PipelineOptionsFactory/Builder class in the stack trace. + while (elements.hasNext()) { + StackTraceElement next = elements.next(); + if (PIPELINE_OPTIONS_FACTORY_CLASSES.contains(next.getClassName())) { + break; + } + } + // Then find the first instance after that is not the PipelineOptionsFactory/Builder class. + while (elements.hasNext()) { + StackTraceElement next = elements.next(); + if (!PIPELINE_OPTIONS_FACTORY_CLASSES.contains(next.getClassName())) { + try { + return Class.forName(next.getClassName()).getSimpleName(); + } catch (ClassNotFoundException e) { + break; + } + } + } + + return "unknown"; + } + + /** + * Stores the generated proxyClass and its respective {@link BeanInfo} object. + * + * @param The type of the proxyClass. 
+ */ + static class Registration { + private final Class proxyClass; + private final List propertyDescriptors; + + public Registration(Class proxyClass, List beanInfo) { + this.proxyClass = proxyClass; + this.propertyDescriptors = beanInfo; + } + + List getPropertyDescriptors() { + return propertyDescriptors; + } + + Class getProxyClass() { + return proxyClass; + } + } + + private static final Set> SIMPLE_TYPES = ImmutableSet.>builder() + .add(boolean.class) + .add(Boolean.class) + .add(char.class) + .add(Character.class) + .add(short.class) + .add(Short.class) + .add(int.class) + .add(Integer.class) + .add(long.class) + .add(Long.class) + .add(float.class) + .add(Float.class) + .add(double.class) + .add(Double.class) + .add(String.class) + .add(Class.class).build(); + private static final Logger LOG = LoggerFactory.getLogger(PipelineOptionsFactory.class); + @SuppressWarnings("rawtypes") + private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final Map>> SUPPORTED_PIPELINE_RUNNERS; + + /** Classes that are used as the boundary in the stack trace to find the callers class name. */ + private static final Set PIPELINE_OPTIONS_FACTORY_CLASSES = ImmutableSet.of( + PipelineOptionsFactory.class.getName(), + Builder.class.getName()); + + /** Methods that are ignored when validating the proxy class. */ + private static final Set IGNORED_METHODS; + + /** The set of options that have been registered and visible to the user. */ + private static final Set> REGISTERED_OPTIONS = + Sets.newConcurrentHashSet(); + + /** A cache storing a mapping from a given interface to its registration record. */ + private static final Map, Registration> INTERFACE_CACHE = + Maps.newConcurrentMap(); + + /** A cache storing a mapping from a set of interfaces to its registration record. */ + private static final Map>, Registration> COMBINED_CACHE = + Maps.newConcurrentMap(); + + /** The width at which options should be output. */ + private static final int TERMINAL_WIDTH = 80; + + /** + * Finds the appropriate {@code ClassLoader} to be used by the + * {@link ServiceLoader#load} call, which by default would use the context + * {@code ClassLoader}, which can be null. The fallback is as follows: context + * ClassLoader, class ClassLoader and finaly the system ClassLoader. + */ + static ClassLoader findClassLoader() { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + if (classLoader == null) { + classLoader = PipelineOptionsFactory.class.getClassLoader(); + } + if (classLoader == null) { + classLoader = ClassLoader.getSystemClassLoader(); + } + return classLoader; + } + + static { + try { + IGNORED_METHODS = ImmutableSet.builder() + .add(Object.class.getMethod("getClass")) + .add(Object.class.getMethod("wait")) + .add(Object.class.getMethod("wait", long.class)) + .add(Object.class.getMethod("wait", long.class, int.class)) + .add(Object.class.getMethod("notify")) + .add(Object.class.getMethod("notifyAll")) + .add(Proxy.class.getMethod("getInvocationHandler", Object.class)) + .build(); + } catch (NoSuchMethodException | SecurityException e) { + LOG.error("Unable to find expected method", e); + throw new ExceptionInInitializerError(e); + } + + ClassLoader classLoader = findClassLoader(); + + // Store the list of all available pipeline runners. 
+ ImmutableMap.Builder>> builder = + ImmutableMap.builder(); + Set pipelineRunnerRegistrars = + Sets.newTreeSet(ObjectsClassComparator.INSTANCE); + pipelineRunnerRegistrars.addAll( + Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class, classLoader))); + for (PipelineRunnerRegistrar registrar : pipelineRunnerRegistrars) { + for (Class> klass : registrar.getPipelineRunners()) { + builder.put(klass.getSimpleName(), klass); + } + } + SUPPORTED_PIPELINE_RUNNERS = builder.build(); + + // Load and register the list of all classes that extend PipelineOptions. + register(PipelineOptions.class); + Set pipelineOptionsRegistrars = + Sets.newTreeSet(ObjectsClassComparator.INSTANCE); + pipelineOptionsRegistrars.addAll( + Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class, classLoader))); + for (PipelineOptionsRegistrar registrar : pipelineOptionsRegistrars) { + for (Class klass : registrar.getPipelineOptions()) { + register(klass); + } + } + } + + /** + * This registers the interface with this factory. This interface must conform to the following + * restrictions: + *

      + *
    • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
    • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
    • Every method must conform to being a getter or setter for a JavaBean. + *
    • The derived interface of {@link PipelineOptions} must be composable with every interface + * registered with this factory. + *
+ * + * @param iface The interface object to manually register. + */ + public static synchronized void register(Class<? extends PipelineOptions> iface) { + Preconditions.checkNotNull(iface); + Preconditions.checkArgument(iface.isInterface(), "Only interface types are supported."); + + if (REGISTERED_OPTIONS.contains(iface)) { + return; + } + validateWellFormed(iface, REGISTERED_OPTIONS); + REGISTERED_OPTIONS.add(iface); + } + + /** + * Validates that the interface conforms to the following: + *
      + *
    • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
    • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
    • Every method must conform to being a getter or setter for a JavaBean. + *
    • The derived interface of {@link PipelineOptions} must be composable with every interface + * part of allPipelineOptionsClasses. + *
    • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. + *
    • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for + * this property must be annotated with {@link JsonIgnore @JsonIgnore}. + *
    + * + * @param iface The interface to validate. + * @param validatedPipelineOptionsInterfaces The set of validated pipeline options interfaces to + * validate against. + * @return A registration record containing the proxy class and bean info for iface. + */ + static synchronized Registration validateWellFormed( + Class iface, Set> validatedPipelineOptionsInterfaces) { + Preconditions.checkArgument(iface.isInterface(), "Only interface types are supported."); + + @SuppressWarnings("unchecked") + Set> combinedPipelineOptionsInterfaces = + FluentIterable.from(validatedPipelineOptionsInterfaces).append(iface).toSet(); + // Validate that the view of all currently passed in options classes is well formed. + if (!COMBINED_CACHE.containsKey(combinedPipelineOptionsInterfaces)) { + @SuppressWarnings("unchecked") + Class allProxyClass = + (Class) Proxy.getProxyClass(PipelineOptionsFactory.class.getClassLoader(), + combinedPipelineOptionsInterfaces.toArray(EMPTY_CLASS_ARRAY)); + try { + List propertyDescriptors = + validateClass(iface, validatedPipelineOptionsInterfaces, allProxyClass); + COMBINED_CACHE.put(combinedPipelineOptionsInterfaces, + new Registration(allProxyClass, propertyDescriptors)); + } catch (IntrospectionException e) { + throw Throwables.propagate(e); + } + } + + // Validate that the local view of the class is well formed. + if (!INTERFACE_CACHE.containsKey(iface)) { + @SuppressWarnings({"rawtypes", "unchecked"}) + Class proxyClass = (Class) Proxy.getProxyClass( + PipelineOptionsFactory.class.getClassLoader(), new Class[] {iface}); + try { + List propertyDescriptors = + validateClass(iface, validatedPipelineOptionsInterfaces, proxyClass); + INTERFACE_CACHE.put(iface, + new Registration(proxyClass, propertyDescriptors)); + } catch (IntrospectionException e) { + throw Throwables.propagate(e); + } + } + @SuppressWarnings("unchecked") + Registration result = (Registration) INTERFACE_CACHE.get(iface); + return result; + } + + public static Set> getRegisteredOptions() { + return Collections.unmodifiableSet(REGISTERED_OPTIONS); + } + + /** + * Outputs the set of registered options with the PipelineOptionsFactory + * with a description for each one if available to the output stream. This output + * is pretty printed and meant to be human readable. This method will attempt to + * format its output to be compatible with a terminal window. + */ + public static void printHelp(PrintStream out) { + Preconditions.checkNotNull(out); + out.println("The set of registered options are:"); + Set> sortedOptions = + new TreeSet<>(ClassNameComparator.INSTANCE); + sortedOptions.addAll(REGISTERED_OPTIONS); + for (Class kls : sortedOptions) { + out.format(" %s%n", kls.getName()); + } + out.format("%nUse --help= for detailed help. For example:%n" + + " --help=DataflowPipelineOptions %n" + + " --help=com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions%n"); + } + + /** + * Outputs the set of options available to be set for the passed in {@link PipelineOptions} + * interface. The output is in a human readable format. The format is: + *
    +   * OptionGroup:
    +   *     ... option group description ...
    +   *
+   *  --option1={@code <type>} or list of valid enum choices
    +   *     Default: value (if available, see {@link Default})
    +   *     ... option description ... (if available, see {@link Description})
    +   *     Required groups (if available, see {@link Required})
+   *  --option2={@code <type>} or list of valid enum choices
    +   *     Default: value (if available, see {@link Default})
    +   *     ... option description ... (if available, see {@link Description})
    +   *     Required groups (if available, see {@link Required})
    +   * 
    + * This method will attempt to format its output to be compatible with a terminal window. + */ + public static void printHelp(PrintStream out, Class iface) { + Preconditions.checkNotNull(out); + Preconditions.checkNotNull(iface); + validateWellFormed(iface, REGISTERED_OPTIONS); + + Iterable methods = ReflectHelpers.getClosureOfMethodsOnInterface(iface); + ListMultimap, Method> ifaceToMethods = ArrayListMultimap.create(); + for (Method method : methods) { + // Process only methods that are not marked as hidden. + if (method.getAnnotation(Hidden.class) == null) { + ifaceToMethods.put(method.getDeclaringClass(), method); + } + } + SortedSet> ifaces = new TreeSet<>(ClassNameComparator.INSTANCE); + // Keep interfaces that are not marked as hidden. + ifaces.addAll(Collections2.filter(ifaceToMethods.keySet(), new Predicate>() { + @Override + public boolean apply(Class input) { + return input.getAnnotation(Hidden.class) == null; + } + })); + for (Class currentIface : ifaces) { + Map propertyNamesToGetters = + getPropertyNamesToGetters(ifaceToMethods.get(currentIface)); + + // Don't output anything if there are no defined options + if (propertyNamesToGetters.isEmpty()) { + continue; + } + SortedSetMultimap requiredGroupNameToProperties = + getRequiredGroupNamesToProperties(propertyNamesToGetters); + + out.format("%s:%n", currentIface.getName()); + prettyPrintDescription(out, currentIface.getAnnotation(Description.class)); + + out.println(); + + List lists = Lists.newArrayList(propertyNamesToGetters.keySet()); + Collections.sort(lists, String.CASE_INSENSITIVE_ORDER); + for (String propertyName : lists) { + Method method = propertyNamesToGetters.get(propertyName); + String printableType = method.getReturnType().getSimpleName(); + if (method.getReturnType().isEnum()) { + printableType = Joiner.on(" | ").join(method.getReturnType().getEnumConstants()); + } + out.format(" --%s=<%s>%n", propertyName, printableType); + Optional defaultValue = getDefaultValueFromAnnotation(method); + if (defaultValue.isPresent()) { + out.format(" Default: %s%n", defaultValue.get()); + } + prettyPrintDescription(out, method.getAnnotation(Description.class)); + prettyPrintRequiredGroups(out, method.getAnnotation(Validation.Required.class), + requiredGroupNameToProperties); + } + out.println(); + } + } + + /** + * Output the requirement groups that the property is a member of, including all properties that + * satisfy the group requirement, breaking up long lines on white space characters and attempting + * to honor a line limit of {@code TERMINAL_WIDTH}. + */ + private static void prettyPrintRequiredGroups(PrintStream out, Required annotation, + SortedSetMultimap requiredGroupNameToProperties) { + if (annotation == null || annotation.groups() == null) { + return; + } + for (String group : annotation.groups()) { + SortedSet groupMembers = requiredGroupNameToProperties.get(group); + String requirement; + if (groupMembers.size() == 1) { + requirement = Iterables.getOnlyElement(groupMembers) + " is required."; + } else { + requirement = "At least one of " + groupMembers + " is required"; + } + terminalPrettyPrint(out, requirement.split("\\s+")); + } + } + + /** + * Outputs the value of the description, breaking up long lines on white space characters and + * attempting to honor a line limit of {@code TERMINAL_WIDTH}. 
+ */ + private static void prettyPrintDescription(PrintStream out, Description description) { + if (description == null || description.value() == null) { + return; + } + + String[] words = description.value().split("\\s+"); + terminalPrettyPrint(out, words); + } + + private static void terminalPrettyPrint(PrintStream out, String[] words) { + final String spacing = " "; + + if (words.length == 0) { + return; + } + + out.print(spacing); + int lineLength = spacing.length(); + for (int i = 0; i < words.length; ++i) { + out.print(" "); + out.print(words[i]); + lineLength += 1 + words[i].length(); + + // If the next word takes us over the terminal width, then goto the next line. + if (i + 1 != words.length && words[i + 1].length() + lineLength + 1 > TERMINAL_WIDTH) { + out.println(); + out.print(spacing); + lineLength = spacing.length(); + } + } + out.println(); + } + + /** + * Returns a string representation of the {@link Default} value on the passed in method. + */ + private static Optional getDefaultValueFromAnnotation(Method method) { + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof Default.Class) { + return Optional.of(((Default.Class) annotation).value().getSimpleName()); + } else if (annotation instanceof Default.String) { + return Optional.of(((Default.String) annotation).value()); + } else if (annotation instanceof Default.Boolean) { + return Optional.of(Boolean.toString(((Default.Boolean) annotation).value())); + } else if (annotation instanceof Default.Character) { + return Optional.of(Character.toString(((Default.Character) annotation).value())); + } else if (annotation instanceof Default.Byte) { + return Optional.of(Byte.toString(((Default.Byte) annotation).value())); + } else if (annotation instanceof Default.Short) { + return Optional.of(Short.toString(((Default.Short) annotation).value())); + } else if (annotation instanceof Default.Integer) { + return Optional.of(Integer.toString(((Default.Integer) annotation).value())); + } else if (annotation instanceof Default.Long) { + return Optional.of(Long.toString(((Default.Long) annotation).value())); + } else if (annotation instanceof Default.Float) { + return Optional.of(Float.toString(((Default.Float) annotation).value())); + } else if (annotation instanceof Default.Double) { + return Optional.of(Double.toString(((Default.Double) annotation).value())); + } else if (annotation instanceof Default.Enum) { + return Optional.of(((Default.Enum) annotation).value()); + } else if (annotation instanceof Default.InstanceFactory) { + return Optional.of(((Default.InstanceFactory) annotation).value().getSimpleName()); + } + } + return Optional.absent(); + } + + static Map>> getRegisteredRunners() { + return SUPPORTED_PIPELINE_RUNNERS; + } + + static List getPropertyDescriptors( + Set> interfaces) { + return COMBINED_CACHE.get(interfaces).getPropertyDescriptors(); + } + + /** + * Creates a set of Dataflow worker harness options based of a set of known system + * properties. This is meant to only be used from the Dataflow worker harness as a method to + * bootstrap the worker harness. + * + *

    For internal use only. + * + * @return A {@link DataflowWorkerHarnessOptions} object configured for the + * Dataflow worker harness. + */ + public static DataflowWorkerHarnessOptions createFromSystemPropertiesInternal() + throws IOException { + return createFromSystemProperties(); + } + + /** + * Creates a set of {@link DataflowWorkerHarnessOptions} based of a set of known system + * properties. This is meant to only be used from the Dataflow worker harness as a method to + * bootstrap the worker harness. + * + * @return A {@link DataflowWorkerHarnessOptions} object configured for the + * Dataflow worker harness. + * @deprecated for internal use only + */ + @Deprecated + public static DataflowWorkerHarnessOptions createFromSystemProperties() throws IOException { + ObjectMapper objectMapper = new ObjectMapper(); + DataflowWorkerHarnessOptions options; + if (System.getProperties().containsKey("sdk_pipeline_options")) { + String serializedOptions = System.getProperty("sdk_pipeline_options"); + LOG.info("Worker harness starting with: " + serializedOptions); + options = objectMapper.readValue(serializedOptions, PipelineOptions.class) + .as(DataflowWorkerHarnessOptions.class); + } else { + options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); + } + + // These values will not be known at job submission time and must be provided. + if (System.getProperties().containsKey("worker_id")) { + options.setWorkerId(System.getProperty("worker_id")); + } + if (System.getProperties().containsKey("job_id")) { + options.setJobId(System.getProperty("job_id")); + } + + return options; + } + + /** + * This method is meant to emulate the behavior of {@link Introspector#getBeanInfo(Class, int)} + * to construct the list of {@link PropertyDescriptor}. + * + *

    TODO: Swap back to using Introspector once the proxy class issue with AppEngine is + * resolved. + */ + private static List getPropertyDescriptors(Class beanClass) + throws IntrospectionException { + // The sorting is important to make this method stable. + SortedSet methods = Sets.newTreeSet(MethodComparator.INSTANCE); + methods.addAll(Arrays.asList(beanClass.getMethods())); + SortedMap propertyNamesToGetters = getPropertyNamesToGetters(methods); + List descriptors = Lists.newArrayList(); + + List mismatches = new ArrayList<>(); + /* + * Add all the getter/setter pairs to the list of descriptors removing the getter once + * it has been paired up. + */ + for (Method method : methods) { + String methodName = method.getName(); + if (!methodName.startsWith("set") + || method.getParameterTypes().length != 1 + || method.getReturnType() != void.class) { + continue; + } + String propertyName = Introspector.decapitalize(methodName.substring(3)); + Method getterMethod = propertyNamesToGetters.remove(propertyName); + + // Validate that the getter and setter property types are the same. + if (getterMethod != null) { + Class getterPropertyType = getterMethod.getReturnType(); + Class setterPropertyType = method.getParameterTypes()[0]; + if (getterPropertyType != setterPropertyType) { + TypeMismatch mismatch = new TypeMismatch(); + mismatch.propertyName = propertyName; + mismatch.getterPropertyType = getterPropertyType; + mismatch.setterPropertyType = setterPropertyType; + mismatches.add(mismatch); + continue; + } + } + + descriptors.add(new PropertyDescriptor( + propertyName, getterMethod, method)); + } + throwForTypeMismatches(mismatches); + + // Add the remaining getters with missing setters. + for (Map.Entry getterToMethod : propertyNamesToGetters.entrySet()) { + descriptors.add(new PropertyDescriptor( + getterToMethod.getKey(), getterToMethod.getValue(), null)); + } + return descriptors; + } + + private static class TypeMismatch { + private String propertyName; + private Class getterPropertyType; + private Class setterPropertyType; + } + + private static void throwForTypeMismatches(List mismatches) { + if (mismatches.size() == 1) { + TypeMismatch mismatch = mismatches.get(0); + throw new IllegalArgumentException(String.format( + "Type mismatch between getter and setter methods for property [%s]. " + + "Getter is of type [%s] whereas setter is of type [%s].", + mismatch.propertyName, + mismatch.getterPropertyType.getName(), + mismatch.setterPropertyType.getName())); + } else if (mismatches.size() > 1) { + StringBuilder builder = new StringBuilder( + String.format("Type mismatches between getters and setters detected:")); + for (TypeMismatch mismatch : mismatches) { + builder.append(String.format( + "%n - Property [%s]: Getter is of type [%s] whereas setter is of type [%s].", + mismatch.propertyName, + mismatch.getterPropertyType.getName(), + mismatch.setterPropertyType.getName())); + } + throw new IllegalArgumentException(builder.toString()); + } + } + + /** + * Returns a map of the property name to the getter method it represents. + * If there are duplicate methods with the same bean name, then it is indeterminate + * as to which method will be returned. 
+ */ + private static SortedMap getPropertyNamesToGetters(Iterable methods) { + SortedMap propertyNamesToGetters = Maps.newTreeMap(); + for (Method method : methods) { + String methodName = method.getName(); + if ((!methodName.startsWith("get") + && !methodName.startsWith("is")) + || method.getParameterTypes().length != 0 + || method.getReturnType() == void.class) { + continue; + } + String propertyName = Introspector.decapitalize( + methodName.startsWith("is") ? methodName.substring(2) : methodName.substring(3)); + propertyNamesToGetters.put(propertyName, method); + } + return propertyNamesToGetters; + } + + /** + * Returns a map of required groups of arguments to the properties that satisfy the requirement. + */ + private static SortedSetMultimap getRequiredGroupNamesToProperties( + Map propertyNamesToGetters) { + SortedSetMultimap result = TreeMultimap.create(); + for (Map.Entry propertyEntry : propertyNamesToGetters.entrySet()) { + Required requiredAnnotation = + propertyEntry.getValue().getAnnotation(Validation.Required.class); + if (requiredAnnotation != null) { + for (String groupName : requiredAnnotation.groups()) { + result.put(groupName, propertyEntry.getKey()); + } + } + } + return result; + } + + /** + * Validates that a given class conforms to the following properties: + *

      + *
    • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
    • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
    • Every method must conform to being a getter or setter for a JavaBean. + *
    • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. + *
    • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for + * this property must be annotated with {@link JsonIgnore @JsonIgnore}. + *
    + * + * @param iface The interface to validate. + * @param validatedPipelineOptionsInterfaces The set of validated pipeline options interfaces to + * validate against. + * @param klass The proxy class representing the interface. + * @return A list of {@link PropertyDescriptor}s representing all valid bean properties of + * {@code iface}. + * @throws IntrospectionException if invalid property descriptors. + */ + private static List validateClass(Class iface, + Set> validatedPipelineOptionsInterfaces, + Class klass) throws IntrospectionException { + Set methods = Sets.newHashSet(IGNORED_METHODS); + // Ignore static methods, "equals", "hashCode", "toString" and "as" on the generated class. + for (Method method : klass.getMethods()) { + if (Modifier.isStatic(method.getModifiers())) { + methods.add(method); + } + } + try { + methods.add(klass.getMethod("equals", Object.class)); + methods.add(klass.getMethod("hashCode")); + methods.add(klass.getMethod("toString")); + methods.add(klass.getMethod("as", Class.class)); + methods.add(klass.getMethod("cloneAs", Class.class)); + } catch (NoSuchMethodException | SecurityException e) { + throw Throwables.propagate(e); + } + + // Verify that there are no methods with the same name with two different return types. + Iterable interfaceMethods = FluentIterable + .from(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) + .toSortedSet(MethodComparator.INSTANCE); + SortedSetMultimap methodNameToMethodMap = + TreeMultimap.create(MethodNameComparator.INSTANCE, MethodComparator.INSTANCE); + for (Method method : interfaceMethods) { + methodNameToMethodMap.put(method, method); + } + List multipleDefinitions = Lists.newArrayList(); + for (Map.Entry> entry + : methodNameToMethodMap.asMap().entrySet()) { + Set> returnTypes = FluentIterable.from(entry.getValue()) + .transform(ReturnTypeFetchingFunction.INSTANCE).toSet(); + SortedSet collidingMethods = FluentIterable.from(entry.getValue()) + .toSortedSet(MethodComparator.INSTANCE); + if (returnTypes.size() > 1) { + MultipleDefinitions defs = new MultipleDefinitions(); + defs.method = entry.getKey(); + defs.collidingMethods = collidingMethods; + multipleDefinitions.add(defs); + } + } + throwForMultipleDefinitions(iface, multipleDefinitions); + + // Verify that there is no getter with a mixed @JsonIgnore annotation and verify + // that no setter has @JsonIgnore. 
+ Iterable allInterfaceMethods = FluentIterable + .from(ReflectHelpers.getClosureOfMethodsOnInterfaces(validatedPipelineOptionsInterfaces)) + .append(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) + .toSortedSet(MethodComparator.INSTANCE); + SortedSetMultimap methodNameToAllMethodMap = + TreeMultimap.create(MethodNameComparator.INSTANCE, MethodComparator.INSTANCE); + for (Method method : allInterfaceMethods) { + methodNameToAllMethodMap.put(method, method); + } + + List descriptors = getPropertyDescriptors(klass); + + List incompletelyIgnoredGetters = new ArrayList<>(); + List ignoredSetters = new ArrayList<>(); + + for (PropertyDescriptor descriptor : descriptors) { + if (descriptor.getReadMethod() == null + || descriptor.getWriteMethod() == null + || IGNORED_METHODS.contains(descriptor.getReadMethod()) + || IGNORED_METHODS.contains(descriptor.getWriteMethod())) { + continue; + } + SortedSet getters = methodNameToAllMethodMap.get(descriptor.getReadMethod()); + SortedSet gettersWithJsonIgnore = Sets.filter(getters, JsonIgnorePredicate.INSTANCE); + + Iterable getterClassNames = FluentIterable.from(getters) + .transform(MethodToDeclaringClassFunction.INSTANCE) + .transform(ReflectHelpers.CLASS_NAME); + Iterable gettersWithJsonIgnoreClassNames = FluentIterable.from(gettersWithJsonIgnore) + .transform(MethodToDeclaringClassFunction.INSTANCE) + .transform(ReflectHelpers.CLASS_NAME); + + if (!(gettersWithJsonIgnore.isEmpty() || getters.size() == gettersWithJsonIgnore.size())) { + InconsistentlyIgnoredGetters err = new InconsistentlyIgnoredGetters(); + err.descriptor = descriptor; + err.getterClassNames = getterClassNames; + err.gettersWithJsonIgnoreClassNames = gettersWithJsonIgnoreClassNames; + incompletelyIgnoredGetters.add(err); + } + if (!incompletelyIgnoredGetters.isEmpty()) { + continue; + } + + SortedSet settersWithJsonIgnore = + Sets.filter(methodNameToAllMethodMap.get(descriptor.getWriteMethod()), + JsonIgnorePredicate.INSTANCE); + + Iterable settersWithJsonIgnoreClassNames = FluentIterable.from(settersWithJsonIgnore) + .transform(MethodToDeclaringClassFunction.INSTANCE) + .transform(ReflectHelpers.CLASS_NAME); + + if (!settersWithJsonIgnore.isEmpty()) { + IgnoredSetter ignored = new IgnoredSetter(); + ignored.descriptor = descriptor; + ignored.settersWithJsonIgnoreClassNames = settersWithJsonIgnoreClassNames; + ignoredSetters.add(ignored); + } + } + throwForGettersWithInconsistentJsonIgnore(incompletelyIgnoredGetters); + throwForSettersWithJsonIgnore(ignoredSetters); + + List missingBeanMethods = new ArrayList<>(); + // Verify that each property has a matching read and write method. + for (PropertyDescriptor propertyDescriptor : descriptors) { + if (!(IGNORED_METHODS.contains(propertyDescriptor.getWriteMethod()) + || propertyDescriptor.getReadMethod() != null)) { + MissingBeanMethod method = new MissingBeanMethod(); + method.property = propertyDescriptor; + method.methodType = "getter"; + missingBeanMethods.add(method); + continue; + } + if (!(IGNORED_METHODS.contains(propertyDescriptor.getReadMethod()) + || propertyDescriptor.getWriteMethod() != null)) { + MissingBeanMethod method = new MissingBeanMethod(); + method.property = propertyDescriptor; + method.methodType = "setter"; + missingBeanMethods.add(method); + continue; + } + methods.add(propertyDescriptor.getReadMethod()); + methods.add(propertyDescriptor.getWriteMethod()); + } + throwForMissingBeanMethod(iface, missingBeanMethods); + + // Verify that no additional methods are on an interface that aren't a bean property. 
+ SortedSet unknownMethods = new TreeSet<>(MethodComparator.INSTANCE); + unknownMethods.addAll(Sets.difference(Sets.newHashSet(klass.getMethods()), methods)); + Preconditions.checkArgument(unknownMethods.isEmpty(), + "Methods %s on [%s] do not conform to being bean properties.", + FluentIterable.from(unknownMethods).transform(ReflectHelpers.METHOD_FORMATTER), + iface.getName()); + + return descriptors; + } + + private static class MultipleDefinitions { + private Method method; + private SortedSet collidingMethods; + } + + private static void throwForMultipleDefinitions( + Class iface, List definitions) { + if (definitions.size() == 1) { + MultipleDefinitions errDef = definitions.get(0); + throw new IllegalArgumentException(String.format( + "Method [%s] has multiple definitions %s with different return types for [%s].", + errDef.method.getName(), errDef.collidingMethods, iface.getName())); + } else if (definitions.size() > 1) { + StringBuilder errorBuilder = new StringBuilder(String.format( + "Interface [%s] has Methods with multiple definitions with different return types:", + iface.getName())); + for (MultipleDefinitions errDef : definitions) { + errorBuilder.append(String.format( + "%n - Method [%s] has multiple definitions %s", + errDef.method.getName(), + errDef.collidingMethods)); + } + throw new IllegalArgumentException(errorBuilder.toString()); + } + } + + private static class InconsistentlyIgnoredGetters { + PropertyDescriptor descriptor; + Iterable getterClassNames; + Iterable gettersWithJsonIgnoreClassNames; + } + + private static void throwForGettersWithInconsistentJsonIgnore( + List getters) { + if (getters.size() == 1) { + InconsistentlyIgnoredGetters getter = getters.get(0); + throw new IllegalArgumentException(String.format( + "Expected getter for property [%s] to be marked with @JsonIgnore on all %s, " + + "found only on %s", + getter.descriptor.getName(), getter.getterClassNames, + getter.gettersWithJsonIgnoreClassNames)); + } else if (getters.size() > 1) { + StringBuilder errorBuilder = + new StringBuilder("Property getters are inconsistently marked with @JsonIgnore:"); + for (InconsistentlyIgnoredGetters getter : getters) { + errorBuilder.append( + String.format("%n - Expected for property [%s] to be marked on all %s, " + + "found only on %s", + getter.descriptor.getName(), getter.getterClassNames, + getter.gettersWithJsonIgnoreClassNames)); + } + throw new IllegalArgumentException(errorBuilder.toString()); + } + } + + private static class IgnoredSetter { + PropertyDescriptor descriptor; + Iterable settersWithJsonIgnoreClassNames; + } + + private static void throwForSettersWithJsonIgnore(List setters) { + if (setters.size() == 1) { + IgnoredSetter setter = setters.get(0); + throw new IllegalArgumentException( + String.format("Expected setter for property [%s] to not be marked with @JsonIgnore on %s", + setter.descriptor.getName(), setter.settersWithJsonIgnoreClassNames)); + } else if (setters.size() > 1) { + StringBuilder builder = new StringBuilder("Found setters marked with @JsonIgnore:"); + for (IgnoredSetter setter : setters) { + builder.append( + String.format("%n - Setter for property [%s] should not be marked with @JsonIgnore " + + "on %s", + setter.descriptor.getName(), setter.settersWithJsonIgnoreClassNames)); + } + throw new IllegalArgumentException(builder.toString()); + } + } + + private static class MissingBeanMethod { + String methodType; + PropertyDescriptor property; + } + + private static void throwForMissingBeanMethod( + Class iface, List 
missingBeanMethods) { + if (missingBeanMethods.size() == 1) { + MissingBeanMethod missingBeanMethod = missingBeanMethods.get(0); + throw new IllegalArgumentException( + String.format("Expected %s for property [%s] of type [%s] on [%s].", + missingBeanMethod.methodType, missingBeanMethod.property.getName(), + missingBeanMethod.property.getPropertyType().getName(), iface.getName())); + } else if (missingBeanMethods.size() > 1) { + StringBuilder builder = new StringBuilder(String.format( + "Found missing property methods on [%s]:", iface.getName())); + for (MissingBeanMethod method : missingBeanMethods) { + builder.append( + String.format("%n - Expected %s for property [%s] of type [%s]", method.methodType, + method.property.getName(), method.property.getPropertyType().getName())); + } + throw new IllegalArgumentException(builder.toString()); + } + } + + /** A {@link Comparator} that uses the classes name to compare them. */ + private static class ClassNameComparator implements Comparator> { + static final ClassNameComparator INSTANCE = new ClassNameComparator(); + @Override + public int compare(Class o1, Class o2) { + return o1.getName().compareTo(o2.getName()); + } + } + + /** A {@link Comparator} that uses the object's classes canonical name to compare them. */ + private static class ObjectsClassComparator implements Comparator { + static final ObjectsClassComparator INSTANCE = new ObjectsClassComparator(); + @Override + public int compare(Object o1, Object o2) { + return o1.getClass().getCanonicalName().compareTo(o2.getClass().getCanonicalName()); + } + } + + /** A {@link Comparator} that uses the generic method signature to sort them. */ + private static class MethodComparator implements Comparator { + static final MethodComparator INSTANCE = new MethodComparator(); + @Override + public int compare(Method o1, Method o2) { + return o1.toGenericString().compareTo(o2.toGenericString()); + } + } + + /** A {@link Comparator} that uses the methods name to compare them. */ + static class MethodNameComparator implements Comparator { + static final MethodNameComparator INSTANCE = new MethodNameComparator(); + @Override + public int compare(Method o1, Method o2) { + return o1.getName().compareTo(o2.getName()); + } + } + + /** A {@link Function} that gets the method's return type. */ + private static class ReturnTypeFetchingFunction implements Function> { + static final ReturnTypeFetchingFunction INSTANCE = new ReturnTypeFetchingFunction(); + @Override + public Class apply(Method input) { + return input.getReturnType(); + } + } + + /** A {@link Function} with returns the declaring class for the method. */ + private static class MethodToDeclaringClassFunction implements Function> { + static final MethodToDeclaringClassFunction INSTANCE = new MethodToDeclaringClassFunction(); + @Override + public Class apply(Method input) { + return input.getDeclaringClass(); + } + } + + /** + * A {@link Predicate} that returns true if the method is annotated with + * {@link JsonIgnore @JsonIgnore}. + */ + static class JsonIgnorePredicate implements Predicate { + static final JsonIgnorePredicate INSTANCE = new JsonIgnorePredicate(); + @Override + public boolean apply(Method input) { + return input.isAnnotationPresent(JsonIgnore.class); + } + } + + /** + * Splits string arguments based upon expected pattern of --argName=value. + * + *

    Example GNU style command line arguments: + * + *

    +   *   --project=MyProject (simple property, will set the "project" property to "MyProject")
    +   *   --readOnly=true (for boolean properties, will set the "readOnly" property to "true")
    +   *   --readOnly (shorthand for boolean properties, will set the "readOnly" property to "true")
    +   *   --x=1 --x=2 --x=3 (list style simple property, will set the "x" property to [1, 2, 3])
    +   *   --x=1,2,3 (shorthand list style simple property, will set the "x" property to [1, 2, 3])
    +   *   --complexObject='{"key1":"value1",...} (JSON format for all other complex types)
    +   * 
    + * + *

Simple properties are able to be bound to {@link String}, {@link Class}, enums and Java + * primitives {@code boolean}, {@code byte}, {@code short}, {@code int}, {@code long}, + * {@code float}, {@code double} and their primitive wrapper classes. + * + *

Simple list style properties are able to be bound to {@code boolean[]}, {@code char[]}, + * {@code short[]}, {@code int[]}, {@code long[]}, {@code float[]}, {@code double[]}, + * {@code Class[]}, enum arrays, {@code String[]}, and {@code List<String>}. + * + *

    JSON format is required for all other types. + * + *
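As an illustration of the JSON case (the property name and map type below are hypothetical):

    // Any other type is converted from a JSON string by the ObjectMapper, for example
    //   --labels={"env":"test","team":"data"}    (quote the value as your shell requires)
    public interface LabelOptions extends PipelineOptions {
      java.util.Map<String, String> getLabels();
      void setLabels(java.util.Map<String, String> value);
    }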

    If strict parsing is enabled, options must start with '--', and not have an empty argument + * name or value based upon the positioning of the '='. Empty or null arguments will be ignored + * whether or not strict parsing is enabled. + */ + private static ListMultimap parseCommandLine( + String[] args, boolean strictParsing) { + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + for (String arg : args) { + if (Strings.isNullOrEmpty(arg)) { + continue; + } + try { + Preconditions.checkArgument(arg.startsWith("--"), + "Argument '%s' does not begin with '--'", arg); + int index = arg.indexOf("="); + // Make sure that '=' isn't the first character after '--' or the last character + Preconditions.checkArgument(index != 2, + "Argument '%s' starts with '--=', empty argument name not allowed", arg); + if (index > 0) { + builder.put(arg.substring(2, index), arg.substring(index + 1, arg.length())); + } else { + builder.put(arg.substring(2), "true"); + } + } catch (IllegalArgumentException e) { + if (strictParsing) { + throw e; + } else { + LOG.warn("Strict parsing is disabled, ignoring option '{}' because {}", + arg, e.getMessage()); + } + } + } + return builder.build(); + } + + /** + * Using the parsed string arguments, we convert the strings to the expected + * return type of the methods that are found on the passed-in class. + * + *

    For any return type that is expected to be an array or a collection, we further + * split up each string on ','. + * + *

We special case the "runner" option. It is mapped to the class of the {@link PipelineRunner} + * based off of the {@link PipelineRunner}'s simple class name or fully qualified class name. + * + *
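For example, assuming the DirectPipelineRunner shipped with this SDK is registered via a PipelineRunnerRegistrar on the classpath, it is selected by its simple class name, which is then mapped to the runner class before the options proxy is created:

    --runner=DirectPipelineRunner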

    If strict parsing is enabled, unknown options or options that cannot be converted to + * the expected java type using an {@link ObjectMapper} will be ignored. + */ + private static Map parseObjects( + Class klass, ListMultimap options, boolean strictParsing) { + Map propertyNamesToGetters = Maps.newHashMap(); + PipelineOptionsFactory.validateWellFormed(klass, REGISTERED_OPTIONS); + @SuppressWarnings("unchecked") + Iterable propertyDescriptors = + PipelineOptionsFactory.getPropertyDescriptors( + FluentIterable.from(getRegisteredOptions()).append(klass).toSet()); + for (PropertyDescriptor descriptor : propertyDescriptors) { + propertyNamesToGetters.put(descriptor.getName(), descriptor.getReadMethod()); + } + Map convertedOptions = Maps.newHashMap(); + for (final Map.Entry> entry : options.asMap().entrySet()) { + try { + // Search for close matches for missing properties. + // Either off by one or off by two character errors. + if (!propertyNamesToGetters.containsKey(entry.getKey())) { + SortedSet closestMatches = new TreeSet( + Sets.filter(propertyNamesToGetters.keySet(), new Predicate() { + @Override + public boolean apply(@Nullable String input) { + return StringUtils.getLevenshteinDistance(entry.getKey(), input) <= 2; + } + })); + switch (closestMatches.size()) { + case 0: + throw new IllegalArgumentException( + String.format("Class %s missing a property named '%s'.", + klass, entry.getKey())); + case 1: + throw new IllegalArgumentException( + String.format("Class %s missing a property named '%s'. Did you mean '%s'?", + klass, entry.getKey(), Iterables.getOnlyElement(closestMatches))); + default: + throw new IllegalArgumentException( + String.format("Class %s missing a property named '%s'. Did you mean one of %s?", + klass, entry.getKey(), closestMatches)); + } + } + + Method method = propertyNamesToGetters.get(entry.getKey()); + // Only allow empty argument values for String, String Array, and Collection. 
+ Class returnType = method.getReturnType(); + JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); + if ("runner".equals(entry.getKey())) { + String runner = Iterables.getOnlyElement(entry.getValue()); + Preconditions.checkArgument(SUPPORTED_PIPELINE_RUNNERS.containsKey(runner), + "Unknown 'runner' specified '%s', supported pipeline runners %s", + runner, Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + } else if ((returnType.isArray() && (SIMPLE_TYPES.contains(returnType.getComponentType()) + || returnType.getComponentType().isEnum())) + || Collection.class.isAssignableFrom(returnType)) { + // Split any strings with "," + List values = FluentIterable.from(entry.getValue()) + .transformAndConcat(new Function>() { + @Override + public Iterable apply(String input) { + return Arrays.asList(input.split(",")); + } + }).toList(); + + if (returnType.isArray() && !returnType.getComponentType().equals(String.class)) { + for (String value : values) { + Preconditions.checkArgument(!value.isEmpty(), + "Empty argument value is only allowed for String, String Array, and Collection," + + " but received: " + returnType); + } + } + convertedOptions.put(entry.getKey(), MAPPER.convertValue(values, type)); + } else if (SIMPLE_TYPES.contains(returnType) || returnType.isEnum()) { + String value = Iterables.getOnlyElement(entry.getValue()); + Preconditions.checkArgument(returnType.equals(String.class) || !value.isEmpty(), + "Empty argument value is only allowed for String, String Array, and Collection," + + " but received: " + returnType); + convertedOptions.put(entry.getKey(), MAPPER.convertValue(value, type)); + } else { + String value = Iterables.getOnlyElement(entry.getValue()); + Preconditions.checkArgument(returnType.equals(String.class) || !value.isEmpty(), + "Empty argument value is only allowed for String, String Array, and Collection," + + " but received: " + returnType); + try { + convertedOptions.put(entry.getKey(), MAPPER.readValue(value, type)); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse JSON value " + value, e); + } + } + } catch (IllegalArgumentException e) { + if (strictParsing) { + throw e; + } else { + LOG.warn("Strict parsing is disabled, ignoring option '{}' with value '{}' because {}", + entry.getKey(), entry.getValue(), e.getMessage()); + } + } + } + return convertedOptions; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsRegistrar.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsRegistrar.java new file mode 100644 index 000000000000..1678541bef0c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsRegistrar.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.auto.service.AutoService; + +import java.util.ServiceLoader; + +/** + * {@link PipelineOptions} creators can have their {@link PipelineOptions} automatically + * registered with this SDK by creating a {@link ServiceLoader} entry + * and a concrete implementation of this interface. + * + *

Note that automatic registration of any {@link PipelineOptions} requires users to + * conform to the limitations discussed on {@link PipelineOptionsFactory#register(Class)}. + * + *
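A sketch of such a registrar for a hypothetical user-defined MyOptions interface; the AutoService annotation used here is the build-time tool recommended in the next paragraph and generates the META-INF/services entry for the ServiceLoader:

    import com.google.auto.service.AutoService;
    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar;
    import com.google.common.collect.ImmutableList;

    /** Hypothetical registrar; MyOptions is a user-defined sub-interface of PipelineOptions. */
    @AutoService(PipelineOptionsRegistrar.class)
    public class MyOptionsRegistrar implements PipelineOptionsRegistrar {
      @Override
      public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() {
        return ImmutableList.<Class<? extends PipelineOptions>>of(MyOptions.class);
      }
    }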

    It is optional but recommended to use one of the many build time tools such as + * {@link AutoService} to generate the necessary META-INF files automatically. + */ +public interface PipelineOptionsRegistrar { + Iterable> getPipelineOptions(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidator.java new file mode 100644 index 000000000000..b5612c40a325 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidator.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.options.Validation.Required; +import com.google.cloud.dataflow.sdk.util.common.ReflectHelpers; +import com.google.common.base.Preconditions; +import com.google.common.collect.Collections2; +import com.google.common.collect.Ordering; +import com.google.common.collect.SortedSetMultimap; +import com.google.common.collect.TreeMultimap; + +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.Collection; + +/** + * Validates that the {@link PipelineOptions} conforms to all the {@link Validation} criteria. + */ +public class PipelineOptionsValidator { + /** + * Validates that the passed {@link PipelineOptions} conforms to all the validation criteria from + * the passed in interface. + * + *
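A sketch of the check this validator performs, using a hypothetical options interface with one required property:

    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator;
    import com.google.cloud.dataflow.sdk.options.Validation;

    public class ValidationExample {
      /** Hypothetical interface; getInputFile() must be non-null after construction. */
      public interface RequiredOptions extends PipelineOptions {
        @Validation.Required
        String getInputFile();
        void setInputFile(String value);
      }

      public static void main(String[] args) {
        // Passes: the required property was supplied.
        RequiredOptions ok = PipelineOptionsValidator.validate(
            RequiredOptions.class,
            PipelineOptionsFactory.fromArgs(new String[] {"--inputFile=in.txt"})
                .as(RequiredOptions.class));

        // Throws IllegalArgumentException: getInputFile() is still null.
        PipelineOptionsValidator.validate(RequiredOptions.class, PipelineOptionsFactory.create());
      }
    }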

    Note that the interface requested must conform to the validation criteria specified on + * {@link PipelineOptions#as(Class)}. + * + * @param klass The interface to fetch validation criteria from. + * @param options The {@link PipelineOptions} to validate. + * @return The type + */ + public static T validate(Class klass, PipelineOptions options) { + Preconditions.checkNotNull(klass); + Preconditions.checkNotNull(options); + Preconditions.checkArgument(Proxy.isProxyClass(options.getClass())); + Preconditions.checkArgument(Proxy.getInvocationHandler(options) + instanceof ProxyInvocationHandler); + + // Ensure the methods for T are registered on the ProxyInvocationHandler + T asClassOptions = options.as(klass); + + ProxyInvocationHandler handler = + (ProxyInvocationHandler) Proxy.getInvocationHandler(asClassOptions); + + SortedSetMultimap requiredGroups = TreeMultimap.create( + Ordering.natural(), PipelineOptionsFactory.MethodNameComparator.INSTANCE); + for (Method method : ReflectHelpers.getClosureOfMethodsOnInterface(klass)) { + Required requiredAnnotation = method.getAnnotation(Validation.Required.class); + if (requiredAnnotation != null) { + if (requiredAnnotation.groups().length > 0) { + for (String requiredGroup : requiredAnnotation.groups()) { + requiredGroups.put(requiredGroup, method); + } + } else { + Preconditions.checkArgument(handler.invoke(asClassOptions, method, null) != null, + "Missing required value for [" + method + ", \"" + getDescription(method) + "\"]. "); + } + } + } + + for (String requiredGroup : requiredGroups.keySet()) { + if (!verifyGroup(handler, asClassOptions, requiredGroups.get(requiredGroup))) { + throw new IllegalArgumentException("Missing required value for group [" + requiredGroup + + "]. At least one of the following properties " + + Collections2.transform( + requiredGroups.get(requiredGroup), ReflectHelpers.METHOD_FORMATTER) + + " required. Run with --help=" + klass.getSimpleName() + " for more information."); + } + } + + return asClassOptions; + } + + private static boolean verifyGroup(ProxyInvocationHandler handler, PipelineOptions options, + Collection requiredGroup) { + for (Method m : requiredGroup) { + if (handler.invoke(options, m, null) != null) { + return true; + } + } + return false; + } + + private static String getDescription(Method method) { + Description description = method.getAnnotation(Description.class); + return description == null ? "" : description.value(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandler.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandler.java new file mode 100644 index 000000000000..527f712ca49f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandler.java @@ -0,0 +1,441 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory.JsonIgnorePredicate; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory.Registration; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.common.ReflectHelpers; +import com.google.common.base.Defaults; +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ClassToInstanceMap; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.collect.MutableClassToInstanceMap; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.node.ObjectNode; + +import java.beans.PropertyDescriptor; +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.lang.reflect.Type; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * Represents and {@link InvocationHandler} for a {@link Proxy}. The invocation handler uses bean + * introspection of the proxy class to store and retrieve values based off of the property name. + * + *
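A small hedged sketch of the naming contract this relies on (the interface and property below are invented; the PipelineOptionsFactory.create factory call is assumed rather than shown in this change):

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

public class PropertyStorageExample {
  /** Hypothetical options view; the getter/setter pair defines the property "tempBucket". */
  public interface StorageOptions extends PipelineOptions {
    String getTempBucket();
    void setTempBucket(String value);
  }

  public static void main(String[] args) {
    PipelineOptions base = PipelineOptionsFactory.create(); // assumed factory call
    StorageOptions storage = base.as(StorageOptions.class);
    storage.setTempBucket("gs://my-temp-bucket");

    // Both views share the same invocation handler, so the value stored under the
    // property name "tempBucket" is visible through either of them.
    System.out.println(base.as(StorageOptions.class).getTempBucket());
    // toString() pretty-prints every property that has been set so far.
    System.out.println(storage);
  }
}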

    Unset properties return the value given by the {@code @Default} metadata on the getter. If there + * is no {@code @Default} annotation on the getter, then the default value for the expected + * return type, as defined by the Java Language Specification, is returned. + * + *
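A hedged sketch of that default binding (the interface and property names are invented, and the factory call is assumed rather than shown in this change):

import com.google.cloud.dataflow.sdk.options.Default;
import com.google.cloud.dataflow.sdk.options.Description;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

public class DefaultBindingExample {
  /** Hypothetical options interface mixing annotated and unannotated getters. */
  public interface RetryOptions extends PipelineOptions {
    @Description("Number of retry attempts.")
    @Default.Integer(3)
    int getRetryCount();
    void setRetryCount(int value);

    @Description("Optional label; no @Default, so the JLS default (null) is returned.")
    String getLabel();
    void setLabel(String value);
  }

  public static void main(String[] args) {
    RetryOptions options = PipelineOptionsFactory.as(RetryOptions.class); // assumed factory call
    System.out.println(options.getRetryCount()); // 3, lazily bound from @Default.Integer
    System.out.println(options.getLabel());      // null, the Java Language Specification default
    options.setRetryCount(5);
    System.out.println(options.getRetryCount()); // 5, an explicit value wins over the default
  }
}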

    In addition to the getter/setter pairs, this proxy invocation handler supports + * {@link Object#equals(Object)}, {@link Object#hashCode()}, {@link Object#toString()} and + * {@link PipelineOptions#as(Class)}. + */ +@ThreadSafe +class ProxyInvocationHandler implements InvocationHandler { + private static final ObjectMapper MAPPER = new ObjectMapper(); + /** + * No two instances of this class are considered equivalent hence we generate a random hash code + * between 0 and {@link Integer#MAX_VALUE}. + */ + private final int hashCode = (int) (Math.random() * Integer.MAX_VALUE); + private final Set> knownInterfaces; + private final ClassToInstanceMap interfaceToProxyCache; + private final Map options; + private final Map jsonOptions; + private final Map gettersToPropertyNames; + private final Map settersToPropertyNames; + + ProxyInvocationHandler(Map options) { + this(options, Maps.newHashMap()); + } + + private ProxyInvocationHandler(Map options, Map jsonOptions) { + this.options = options; + this.jsonOptions = jsonOptions; + this.knownInterfaces = new HashSet<>(PipelineOptionsFactory.getRegisteredOptions()); + gettersToPropertyNames = Maps.newHashMap(); + settersToPropertyNames = Maps.newHashMap(); + interfaceToProxyCache = MutableClassToInstanceMap.create(); + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + if (args == null && "toString".equals(method.getName())) { + return toString(); + } else if (args != null && args.length == 1 && "equals".equals(method.getName())) { + return equals(args[0]); + } else if (args == null && "hashCode".equals(method.getName())) { + return hashCode(); + } else if (args != null && "as".equals(method.getName()) && args[0] instanceof Class) { + @SuppressWarnings("unchecked") + Class clazz = (Class) args[0]; + return as(clazz); + } else if (args != null && "cloneAs".equals(method.getName()) && args[0] instanceof Class) { + @SuppressWarnings("unchecked") + Class clazz = (Class) args[0]; + return cloneAs(proxy, clazz); + } + String methodName = method.getName(); + synchronized (this) { + if (gettersToPropertyNames.keySet().contains(methodName)) { + String propertyName = gettersToPropertyNames.get(methodName); + if (!options.containsKey(propertyName)) { + // Lazy bind the default to the method. + Object value = jsonOptions.containsKey(propertyName) + ? getValueFromJson(propertyName, method) + : getDefault((PipelineOptions) proxy, method); + options.put(propertyName, value); + } + return options.get(propertyName); + } else if (settersToPropertyNames.containsKey(methodName)) { + options.put(settersToPropertyNames.get(methodName), args[0]); + return Void.TYPE; + } + } + throw new RuntimeException("Unknown method [" + method + "] invoked with args [" + + Arrays.toString(args) + "]."); + } + + /** + * Backing implementation for {@link PipelineOptions#as(Class)}. + * + * @param iface The interface that the returned object needs to implement. + * @return An object that implements the interface . 
+ */ + synchronized T as(Class iface) { + Preconditions.checkNotNull(iface); + Preconditions.checkArgument(iface.isInterface()); + if (!interfaceToProxyCache.containsKey(iface)) { + Registration registration = + PipelineOptionsFactory.validateWellFormed(iface, knownInterfaces); + List propertyDescriptors = registration.getPropertyDescriptors(); + Class proxyClass = registration.getProxyClass(); + gettersToPropertyNames.putAll(generateGettersToPropertyNames(propertyDescriptors)); + settersToPropertyNames.putAll(generateSettersToPropertyNames(propertyDescriptors)); + knownInterfaces.add(iface); + interfaceToProxyCache.putInstance(iface, + InstanceBuilder.ofType(proxyClass) + .fromClass(proxyClass) + .withArg(InvocationHandler.class, this) + .build()); + } + return interfaceToProxyCache.getInstance(iface); + } + + /** + * Backing implementation for {@link PipelineOptions#cloneAs(Class)}. + * + * @return A copy of the PipelineOptions. + */ + synchronized T cloneAs(Object proxy, Class iface) { + PipelineOptions clonedOptions; + try { + clonedOptions = MAPPER.readValue(MAPPER.writeValueAsBytes(proxy), PipelineOptions.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to serialize the pipeline options to JSON.", e); + } + for (Class knownIface : knownInterfaces) { + clonedOptions.as(knownIface); + } + return clonedOptions.as(iface); + } + + /** + * Returns true if the other object is a ProxyInvocationHandler or is a Proxy object and has the + * same ProxyInvocationHandler as this. + * + * @param obj The object to compare against this. + * @return true iff the other object is a ProxyInvocationHandler or is a Proxy object and has the + * same ProxyInvocationHandler as this. + */ + @Override + public boolean equals(Object obj) { + return obj != null && ((obj instanceof ProxyInvocationHandler && this == obj) + || (Proxy.isProxyClass(obj.getClass()) && this == Proxy.getInvocationHandler(obj))); + } + + /** + * Each instance of this ProxyInvocationHandler is unique and has a random hash code. + * + * @return A hash code that was generated randomly. + */ + @Override + public int hashCode() { + return hashCode; + } + + /** + * This will output all the currently set values. This is a relatively costly function + * as it will call {@code toString()} on each object that has been set and format + * the results in a readable format. + * + * @return A pretty printed string representation of this. + */ + @Override + public synchronized String toString() { + SortedMap sortedOptions = new TreeMap<>(); + // Add the options that we received from deserialization + sortedOptions.putAll(jsonOptions); + // Override with any programmatically set options. + sortedOptions.putAll(options); + + StringBuilder b = new StringBuilder(); + b.append("Current Settings:\n"); + for (Map.Entry entry : sortedOptions.entrySet()) { + b.append(" " + entry.getKey() + ": " + entry.getValue() + "\n"); + } + return b.toString(); + } + + /** + * Uses a Jackson {@link ObjectMapper} to attempt type conversion. + * + * @param method The method whose return type you would like to return. + * @param propertyName The name of the property that is being returned. + * @return An object matching the return type of the method passed in. 
+ */ + private Object getValueFromJson(String propertyName, Method method) { + try { + JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); + JsonNode jsonNode = jsonOptions.get(propertyName); + return MAPPER.readValue(jsonNode.toString(), type); + } catch (IOException e) { + throw new RuntimeException("Unable to parse representation", e); + } + } + + /** + * Returns a default value for the method based upon {@code @Default} metadata on the getter + * to return values. If there is no {@code @Default} annotation on the getter, then a default as + * per the Java Language Specification for the expected return type is returned. + * + * @param proxy The proxy object for which we are attempting to get the default. + * @param method The getter method that was invoked. + * @return The default value from an {@link Default} annotation if present, otherwise a default + * value as per the Java Language Specification. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + private Object getDefault(PipelineOptions proxy, Method method) { + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof Default.Class) { + return ((Default.Class) annotation).value(); + } else if (annotation instanceof Default.String) { + return ((Default.String) annotation).value(); + } else if (annotation instanceof Default.Boolean) { + return ((Default.Boolean) annotation).value(); + } else if (annotation instanceof Default.Character) { + return ((Default.Character) annotation).value(); + } else if (annotation instanceof Default.Byte) { + return ((Default.Byte) annotation).value(); + } else if (annotation instanceof Default.Short) { + return ((Default.Short) annotation).value(); + } else if (annotation instanceof Default.Integer) { + return ((Default.Integer) annotation).value(); + } else if (annotation instanceof Default.Long) { + return ((Default.Long) annotation).value(); + } else if (annotation instanceof Default.Float) { + return ((Default.Float) annotation).value(); + } else if (annotation instanceof Default.Double) { + return ((Default.Double) annotation).value(); + } else if (annotation instanceof Default.Enum) { + return Enum.valueOf((Class) method.getReturnType(), + ((Default.Enum) annotation).value()); + } else if (annotation instanceof Default.InstanceFactory) { + return InstanceBuilder.ofType(((Default.InstanceFactory) annotation).value()) + .build() + .create(proxy); + } + } + + /* + * We need to make sure that we return something appropriate for the return type. Thus we return + * a default value as defined by the JLS. + */ + return Defaults.defaultValue(method.getReturnType()); + } + + /** + * Returns a map from the getters method name to the name of the property based upon the passed in + * {@link PropertyDescriptor}s property descriptors. + * + * @param propertyDescriptors A list of {@link PropertyDescriptor}s to use when generating the + * map. + * @return A map of getter method name to property name. + */ + private static Map generateGettersToPropertyNames( + List propertyDescriptors) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (PropertyDescriptor descriptor : propertyDescriptors) { + if (descriptor.getReadMethod() != null) { + builder.put(descriptor.getReadMethod().getName(), descriptor.getName()); + } + } + return builder.build(); + } + + /** + * Returns a map from the setters method name to its matching getters method name based upon the + * passed in {@link PropertyDescriptor}s property descriptors. 
+ * + * @param propertyDescriptors A list of {@link PropertyDescriptor}s to use when generating the + * map. + * @return A map of setter method name to getter method name. + */ + private static Map generateSettersToPropertyNames( + List propertyDescriptors) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (PropertyDescriptor descriptor : propertyDescriptors) { + if (descriptor.getWriteMethod() != null) { + builder.put(descriptor.getWriteMethod().getName(), descriptor.getName()); + } + } + return builder.build(); + } + + static class Serializer extends JsonSerializer { + @Override + public void serialize(PipelineOptions value, JsonGenerator jgen, SerializerProvider provider) + throws IOException, JsonProcessingException { + ProxyInvocationHandler handler = (ProxyInvocationHandler) Proxy.getInvocationHandler(value); + synchronized (handler) { + // We first filter out any properties that have been modified since + // the last serialization of this PipelineOptions and then verify that + // they are all serializable. + Map filteredOptions = Maps.newHashMap(handler.options); + removeIgnoredOptions(handler.knownInterfaces, filteredOptions); + ensureSerializable(handler.knownInterfaces, filteredOptions); + + // Now we create the map of serializable options by taking the original + // set of serialized options (if any) and updating them with any properties + // instances that have been modified since the previous serialization. + Map serializableOptions = + Maps.newHashMap(handler.jsonOptions); + serializableOptions.putAll(filteredOptions); + jgen.writeStartObject(); + jgen.writeFieldName("options"); + jgen.writeObject(serializableOptions); + jgen.writeEndObject(); + } + } + + /** + * We remove all properties within the passed in options where there getter is annotated with + * {@link JsonIgnore @JsonIgnore} from the passed in options using the passed in interfaces. + */ + private void removeIgnoredOptions( + Set> interfaces, Map options) { + // Find all the method names that are annotated with JSON ignore. + Set jsonIgnoreMethodNames = FluentIterable.from( + ReflectHelpers.getClosureOfMethodsOnInterfaces(interfaces)) + .filter(JsonIgnorePredicate.INSTANCE).transform(new Function() { + @Override + public String apply(Method input) { + return input.getName(); + } + }).toSet(); + + // Remove all options that have the same method name as the descriptor. + for (PropertyDescriptor descriptor + : PipelineOptionsFactory.getPropertyDescriptors(interfaces)) { + if (jsonIgnoreMethodNames.contains(descriptor.getReadMethod().getName())) { + options.remove(descriptor.getName()); + } + } + } + + /** + * We use an {@link ObjectMapper} to verify that the passed in options are serializable + * and deserializable. + */ + private void ensureSerializable(Set> interfaces, + Map options) throws IOException { + // Construct a map from property name to the return type of the getter. + Map propertyToReturnType = Maps.newHashMap(); + for (PropertyDescriptor descriptor + : PipelineOptionsFactory.getPropertyDescriptors(interfaces)) { + if (descriptor.getReadMethod() != null) { + propertyToReturnType.put(descriptor.getName(), + descriptor.getReadMethod().getGenericReturnType()); + } + } + + // Attempt to serialize and deserialize each property. 
+ for (Map.Entry entry : options.entrySet()) { + try { + String serializedValue = MAPPER.writeValueAsString(entry.getValue()); + JavaType type = MAPPER.getTypeFactory() + .constructType(propertyToReturnType.get(entry.getKey())); + MAPPER.readValue(serializedValue, type); + } catch (Exception e) { + throw new IOException(String.format( + "Failed to serialize and deserialize property '%s' with value '%s'", + entry.getKey(), entry.getValue()), e); + } + } + } + } + + static class Deserializer extends JsonDeserializer { + @Override + public PipelineOptions deserialize(JsonParser jp, DeserializationContext ctxt) + throws IOException, JsonProcessingException { + ObjectNode objectNode = (ObjectNode) jp.readValueAsTree(); + ObjectNode optionsNode = (ObjectNode) objectNode.get("options"); + + Map fields = Maps.newHashMap(); + for (Iterator> iterator = optionsNode.fields(); + iterator.hasNext(); ) { + Map.Entry field = iterator.next(); + fields.put(field.getKey(), field.getValue()); + } + PipelineOptions options = + new ProxyInvocationHandler(Maps.newHashMap(), fields) + .as(PipelineOptions.class); + return options; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/StreamingOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/StreamingOptions.java new file mode 100644 index 000000000000..9563c589046c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/StreamingOptions.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +/** + * Options used to configure streaming. + */ +public interface StreamingOptions extends + ApplicationNameOptions, GcpOptions, PipelineOptions { + /** + * Set to true if running a streaming pipeline. + */ + @Description("Set to true if running a streaming pipeline.") + boolean isStreaming(); + void setStreaming(boolean value); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Validation.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Validation.java new file mode 100644 index 000000000000..20034f83e2d7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/Validation.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * {@link Validation} represents a set of annotations that can be used to annotate getter + * properties on {@link PipelineOptions} with information representing the validation criteria to + * be used when validating with the {@link PipelineOptionsValidator}. + */ +public @interface Validation { + /** + * This criteria specifies that the value must be not null. Note that this annotation + * should only be applied to methods that return nullable objects. + */ + @Target(value = ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + public @interface Required { + /** + * The groups that the annotated attribute is a member of. A member can be in 0 or more groups. + * Members not in any groups are considered to be in a group consisting exclusively of + * themselves. At least one member of a group must be non-null if the options are to be valid. + */ + String[] groups() default {}; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/package-info.java new file mode 100644 index 000000000000..cef995f11591 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.options.PipelineOptions} for + * configuring pipeline execution. + * + *
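The {@code groups} attribute of {@code Validation.Required} defined above combines with the validator shown earlier: at least one member of each group must be set. A hedged sketch (names invented, factory call assumed):

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator;
import com.google.cloud.dataflow.sdk.options.Validation;

public class RequiredGroupExample {
  /** Hypothetical options: at least one member of the "input" group must be set. */
  public interface InputOptions extends PipelineOptions {
    @Validation.Required(groups = {"input"})
    String getInputFile();
    void setInputFile(String value);

    @Validation.Required(groups = {"input"})
    String getInputTable();
    void setInputTable(String value);
  }

  public static void main(String[] args) {
    InputOptions options = PipelineOptionsFactory.as(InputOptions.class); // assumed factory call
    options.setInputTable("my_dataset.my_table");
    // Passes: the "input" group is satisfied by inputTable even though inputFile is unset.
    PipelineOptionsValidator.validate(InputOptions.class, options);
  }
}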

    {@link com.google.cloud.dataflow.sdk.options.PipelineOptions} encapsulates the various + * parameters that describe how a pipeline should be run. {@code PipelineOptions} are created + * using a {@link com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory}. + */ +package com.google.cloud.dataflow.sdk.options; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/package-info.java new file mode 100644 index 000000000000..5567f038ece3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/package-info.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Provides a simple, powerful model for building both batch and + * streaming parallel data processing + * {@link com.google.cloud.dataflow.sdk.Pipeline}s. + * + *

    To use the Google Cloud Dataflow SDK, you build a + * {@link com.google.cloud.dataflow.sdk.Pipeline}, which manages a graph of + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s + * and the {@link com.google.cloud.dataflow.sdk.values.PCollection}s that + * the PTransforms consume and produce. + * + *
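A hedged, minimal sketch of that flow (the Create transform and the exact DoFn callback signature are assumed from the wider SDK rather than shown in this change):

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.values.PCollection;

public class MinimalPipeline {
  public static void main(String[] args) {
    PipelineOptions options = PipelineOptionsFactory.create(); // assumed factory call
    Pipeline p = Pipeline.create(options);

    // PTransforms consume and produce PCollections; here a Create source feeds a ParDo.
    PCollection<String> words = p.apply(Create.of("hello", "world"));
    words.apply(ParDo.of(new DoFn<String, String>() {
      @Override
      public void processElement(ProcessContext c) {
        c.output(c.element().toUpperCase());
      }
    }));

    p.run(); // the PipelineRunner configured in the options decides where and how this executes
  }
}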

    Each Pipeline has a + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} to specify + * where and how it should run after pipeline construction is complete. + * + */ +package com.google.cloud.dataflow.sdk; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorPipelineExtractor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorPipelineExtractor.java new file mode 100644 index 000000000000..ab87f2ea22e2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorPipelineExtractor.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AggregatorRetriever; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.SetMultimap; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; + +/** + * Retrieves {@link Aggregator Aggregators} at each {@link ParDo} and returns a {@link Map} of + * {@link Aggregator} to the {@link PTransform PTransforms} in which it is present. + */ +public class AggregatorPipelineExtractor { + private final Pipeline pipeline; + + /** + * Creates an {@code AggregatorPipelineExtractor} for the given {@link Pipeline}. + */ + public AggregatorPipelineExtractor(Pipeline pipeline) { + this.pipeline = pipeline; + } + + /** + * Returns a {@link Map} between each {@link Aggregator} in the {@link Pipeline} to the {@link + * PTransform PTransforms} in which it is used. 
+ */ + public Map, Collection>> getAggregatorSteps() { + HashMultimap, PTransform> aggregatorSteps = HashMultimap.create(); + pipeline.traverseTopologically(new AggregatorVisitor(aggregatorSteps)); + return aggregatorSteps.asMap(); + } + + private static class AggregatorVisitor implements PipelineVisitor { + private final SetMultimap, PTransform> aggregatorSteps; + + public AggregatorVisitor(SetMultimap, PTransform> aggregatorSteps) { + this.aggregatorSteps = aggregatorSteps; + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) {} + + @Override + public void leaveCompositeTransform(TransformTreeNode node) {} + + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + addStepToAggregators(transform, getAggregators(transform)); + } + + private Collection> getAggregators(PTransform transform) { + if (transform != null) { + if (transform instanceof ParDo.Bound) { + return AggregatorRetriever.getAggregators(((ParDo.Bound) transform).getFn()); + } else if (transform instanceof ParDo.BoundMulti) { + return AggregatorRetriever.getAggregators(((ParDo.BoundMulti) transform).getFn()); + } + } + return Collections.emptyList(); + } + + private void addStepToAggregators( + PTransform transform, Collection> aggregators) { + for (Aggregator aggregator : aggregators) { + aggregatorSteps.put(aggregator, transform); + } + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) {} + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorRetrievalException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorRetrievalException.java new file mode 100644 index 000000000000..90162aded2ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorRetrievalException.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; + +/** + * Signals that an exception has occurred while retrieving {@link Aggregator}s. + */ +public class AggregatorRetrievalException extends Exception { + /** + * Constructs a new {@code AggregatorRetrievalException} with the specified detail message and + * cause. + */ + public AggregatorRetrievalException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorValues.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorValues.java new file mode 100644 index 000000000000..21f02821c5dd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/AggregatorValues.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; + +import java.util.Collection; +import java.util.Map; + +/** + * A collection of values associated with an {@link Aggregator}. Aggregators declared in a + * {@link DoFn} are emitted on a per-{@code DoFn}-application basis. + * + * @param the output type of the aggregator + */ +public abstract class AggregatorValues { + /** + * Get the values of the {@link Aggregator} at all steps it was used. + */ + public Collection getValues() { + return getValuesAtSteps().values(); + } + + /** + * Get the values of the {@link Aggregator} by the user name at each step it was used. + */ + public abstract Map getValuesAtSteps(); + + /** + * Get the total value of this {@link Aggregator} by applying the specified {@link CombineFn}. + */ + public T getTotalValue(CombineFn combineFn) { + return combineFn.apply(getValues()); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunner.java new file mode 100644 index 000000000000..95e3dfeb91f3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunner.java @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.BlockingDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +/** + * A {@link PipelineRunner} that's like {@link DataflowPipelineRunner} + * but that waits for the launched job to finish. + * + *
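A hedged sketch of selecting this runner through options (PipelineOptionsFactory.fromArgs and PipelineOptions#setRunner are assumed from the wider SDK; project, staging location, and other required settings are expected to arrive via the command line):

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner;

public class BlockingRunnerSetup {
  public static void main(String[] args) {
    DataflowPipelineOptions options =
        PipelineOptionsFactory.fromArgs(args).as(DataflowPipelineOptions.class); // assumed factory call
    options.setRunner(BlockingDataflowPipelineRunner.class); // setRunner assumed, not part of this change

    Pipeline p = Pipeline.create(options);
    // ... apply transforms here ...
    p.run(); // blocks until the job reaches a terminal state
  }
}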

    Prints out job status updates and console messages while it waits. + * + *

    Returns the final job state, or throws an exception if the job + * fails or cannot be monitored. + * + *
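Since each outcome surfaces as its own exception type (all defined later in this change), a hedged sketch of reacting to them, with the pipeline and options passed in from elsewhere:

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner;
import com.google.cloud.dataflow.sdk.runners.DataflowJobCancelledException;
import com.google.cloud.dataflow.sdk.runners.DataflowJobExecutionException;
import com.google.cloud.dataflow.sdk.runners.DataflowJobUpdatedException;
import com.google.cloud.dataflow.sdk.runners.DataflowPipelineJob;

public class BlockingRunExample {
  /** Runs the pipeline and reports how it ended; returns the job only if it finished as DONE. */
  static DataflowPipelineJob runAndReport(Pipeline pipeline, PipelineOptions options) {
    BlockingDataflowPipelineRunner runner = BlockingDataflowPipelineRunner.fromOptions(options);
    try {
      return runner.run(pipeline); // returns only once the job reaches DONE
    } catch (DataflowJobUpdatedException e) {
      System.out.println("Job was updated; replaced by " + e.getReplacedByJob().getJobId());
      throw e;
    } catch (DataflowJobCancelledException e) {
      System.out.println("Job " + e.getJob().getJobId() + " was cancelled.");
      throw e;
    } catch (DataflowJobExecutionException e) {
      System.err.println("Job " + e.getJob().getJobId() + " failed.");
      throw e;
    }
  }
}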

    Permissions

    + * When reading from a Dataflow source or writing to a Dataflow sink using + * {@code BlockingDataflowPipelineRunner}, the Google Cloud services account and the Google Compute + * Engine service account of the GCP project running the Dataflow job will need access to the + * corresponding source/sink. + * + *

    Please see Google Cloud + * Dataflow Security and Permissions for more details. + */ +public class BlockingDataflowPipelineRunner extends + PipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(BlockingDataflowPipelineRunner.class); + + // Defaults to an infinite wait period. + // TODO: make this configurable after removal of option map. + private static final long BUILTIN_JOB_TIMEOUT_SEC = -1L; + + private final DataflowPipelineRunner dataflowPipelineRunner; + private final BlockingDataflowPipelineOptions options; + + protected BlockingDataflowPipelineRunner( + DataflowPipelineRunner internalRunner, + BlockingDataflowPipelineOptions options) { + this.dataflowPipelineRunner = internalRunner; + this.options = options; + } + + /** + * Constructs a runner from the provided options. + */ + public static BlockingDataflowPipelineRunner fromOptions( + PipelineOptions options) { + BlockingDataflowPipelineOptions dataflowOptions = + PipelineOptionsValidator.validate(BlockingDataflowPipelineOptions.class, options); + DataflowPipelineRunner dataflowPipelineRunner = + DataflowPipelineRunner.fromOptions(dataflowOptions); + + return new BlockingDataflowPipelineRunner(dataflowPipelineRunner, dataflowOptions); + } + + /** + * {@inheritDoc} + * + * @throws DataflowJobExecutionException if there is an exception during job execution. + * @throws DataflowServiceException if there is an exception retrieving information about the job. + */ + @Override + public DataflowPipelineJob run(Pipeline p) { + final DataflowPipelineJob job = dataflowPipelineRunner.run(p); + + // We ignore the potential race condition here (Ctrl-C after job submission but before the + // shutdown hook is registered). Even if we tried to do something smarter (eg., SettableFuture) + // the run method (which produces the job) could fail or be Ctrl-C'd before it had returned a + // job. The display of the command to cancel the job is best-effort anyways -- RPC's could fail, + // etc. If the user wants to verify the job was cancelled they should look at the job status. + Thread shutdownHook = new Thread() { + @Override + public void run() { + LOG.warn("Job is already running in Google Cloud Platform, Ctrl-C will not cancel it.\n" + + "To cancel the job in the cloud, run:\n> {}", + MonitoringUtil.getGcloudCancelCommand(options, job.getJobId())); + } + }; + + try { + Runtime.getRuntime().addShutdownHook(shutdownHook); + + @Nullable + State result; + try { + result = job.waitToFinish( + BUILTIN_JOB_TIMEOUT_SEC, TimeUnit.SECONDS, + new MonitoringUtil.PrintHandler(options.getJobMessageOutput())); + } catch (IOException | InterruptedException ex) { + LOG.debug("Exception caught while retrieving status for job {}", job.getJobId(), ex); + throw new DataflowServiceException( + job, "Exception caught while retrieving status for job " + job.getJobId(), ex); + } + + if (result == null) { + throw new DataflowServiceException( + job, "Timed out while retrieving status for job " + job.getJobId()); + } + + LOG.info("Job finished with status {}", result); + if (!result.isTerminal()) { + throw new IllegalStateException("Expected terminal state for job " + job.getJobId() + + ", got " + result); + } + + if (result == State.DONE) { + return job; + } else if (result == State.UPDATED) { + DataflowPipelineJob newJob = job.getReplacedByJob(); + LOG.info("Job {} has been updated and is running as the new job with id {}." 
+ + "To access the updated job on the Dataflow monitoring console, please navigate to {}", + job.getJobId(), + newJob.getJobId(), + MonitoringUtil.getJobMonitoringPageURL(newJob.getProjectId(), newJob.getJobId())); + throw new DataflowJobUpdatedException( + job, + String.format("Job %s updated; new job is %s.", job.getJobId(), newJob.getJobId()), + newJob); + } else if (result == State.CANCELLED) { + String message = String.format("Job %s cancelled by user", job.getJobId()); + LOG.info(message); + throw new DataflowJobCancelledException(job, message); + } else { + throw new DataflowJobExecutionException(job, "Job " + job.getJobId() + + " failed with status " + result); + } + } finally { + Runtime.getRuntime().removeShutdownHook(shutdownHook); + } + } + + @Override + public OutputT apply( + PTransform transform, InputT input) { + return dataflowPipelineRunner.apply(transform, input); + } + + /** + * Sets callbacks to invoke during execution. See {@link DataflowPipelineRunnerHooks}. + */ + @Experimental + public void setHooks(DataflowPipelineRunnerHooks hooks) { + this.dataflowPipelineRunner.setHooks(hooks); + } + + @Override + public String toString() { + return "BlockingDataflowPipelineRunner#" + options.getJobName(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobAlreadyExistsException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobAlreadyExistsException.java new file mode 100644 index 000000000000..1547f73efe4c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobAlreadyExistsException.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +/** + * An exception that is thrown if the unique job name constraint of the Dataflow + * service is broken because an existing job with the same job name is currently active. + * The {@link DataflowPipelineJob} contained within this exception contains information + * about the pre-existing job. + */ +public class DataflowJobAlreadyExistsException extends DataflowJobException { + /** + * Create a new {@code DataflowJobAlreadyExistsException} with the specified {@link + * DataflowPipelineJob} and message. + */ + public DataflowJobAlreadyExistsException( + DataflowPipelineJob job, String message) { + super(job, message, null); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobAlreadyUpdatedException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobAlreadyUpdatedException.java new file mode 100644 index 000000000000..d4ae4f514df3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobAlreadyUpdatedException.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +/** + * An exception that is thrown if the existing job has already been updated within the Dataflow + * service and is no longer able to be updated. The {@link DataflowPipelineJob} contained within + * this exception contains information about the pre-existing updated job. + */ +public class DataflowJobAlreadyUpdatedException extends DataflowJobException { + /** + * Create a new {@code DataflowJobAlreadyUpdatedException} with the specified {@link + * DataflowPipelineJob} and message. + */ + public DataflowJobAlreadyUpdatedException( + DataflowPipelineJob job, String message) { + super(job, message, null); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobCancelledException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobCancelledException.java new file mode 100644 index 000000000000..0d31726ee9e2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobCancelledException.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +/** + * Signals that a job run by a {@link BlockingDataflowPipelineRunner} was updated during execution. + */ +public class DataflowJobCancelledException extends DataflowJobException { + /** + * Create a new {@code DataflowJobAlreadyUpdatedException} with the specified {@link + * DataflowPipelineJob} and message. + */ + public DataflowJobCancelledException(DataflowPipelineJob job, String message) { + super(job, message, null); + } + + /** + * Create a new {@code DataflowJobAlreadyUpdatedException} with the specified {@link + * DataflowPipelineJob}, message, and cause. + */ + public DataflowJobCancelledException(DataflowPipelineJob job, String message, Throwable cause) { + super(job, message, cause); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobException.java new file mode 100644 index 000000000000..9e305d565c6f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobException.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * A {@link RuntimeException} that contains information about a {@link DataflowPipelineJob}. + */ +public abstract class DataflowJobException extends RuntimeException { + private final DataflowPipelineJob job; + + DataflowJobException(DataflowPipelineJob job, String message, @Nullable Throwable cause) { + super(message, cause); + this.job = Objects.requireNonNull(job); + } + + /** + * Returns the failed job. + */ + public DataflowPipelineJob getJob() { + return job; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobExecutionException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobExecutionException.java new file mode 100644 index 000000000000..ae6df0fa657a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobExecutionException.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import javax.annotation.Nullable; + +/** + * Signals that a job run by a {@link BlockingDataflowPipelineRunner} fails during execution, and + * provides access to the failed job. + */ +public class DataflowJobExecutionException extends DataflowJobException { + DataflowJobExecutionException(DataflowPipelineJob job, String message) { + this(job, message, null); + } + + DataflowJobExecutionException( + DataflowPipelineJob job, String message, @Nullable Throwable cause) { + super(job, message, cause); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobUpdatedException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobUpdatedException.java new file mode 100644 index 000000000000..1becdd7c1036 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowJobUpdatedException.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +/** + * Signals that a job run by a {@link BlockingDataflowPipelineRunner} was updated during execution. + */ +public class DataflowJobUpdatedException extends DataflowJobException { + private DataflowPipelineJob replacedByJob; + + /** + * Create a new {@code DataflowJobUpdatedException} with the specified original {@link + * DataflowPipelineJob}, message, and replacement {@link DataflowPipelineJob}. + */ + public DataflowJobUpdatedException( + DataflowPipelineJob job, String message, DataflowPipelineJob replacedByJob) { + this(job, message, replacedByJob, null); + } + + /** + * Create a new {@code DataflowJobUpdatedException} with the specified original {@link + * DataflowPipelineJob}, message, replacement {@link DataflowPipelineJob}, and cause. + */ + public DataflowJobUpdatedException( + DataflowPipelineJob job, String message, DataflowPipelineJob replacedByJob, Throwable cause) { + super(job, message, cause); + this.replacedByJob = replacedByJob; + } + + /** + * The new job that replaces the job terminated with this exception. + */ + public DataflowPipelineJob getReplacedByJob() { + return replacedByJob; + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipeline.java new file mode 100644 index 000000000000..5a78624f3c4e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipeline.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * A {@link DataflowPipeline} is a {@link Pipeline} that returns a + * {@link DataflowPipelineJob} when it is + * {@link com.google.cloud.dataflow.sdk.Pipeline#run()}. + * + *
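A hedged sketch of working with the returned job handle (it assumes MonitoringUtil.PrintHandler accepts a PrintStream such as System.out, and that the DataflowPipeline has already been built):

import com.google.cloud.dataflow.sdk.PipelineResult.State;
import com.google.cloud.dataflow.sdk.runners.DataflowPipeline;
import com.google.cloud.dataflow.sdk.runners.DataflowPipelineJob;
import com.google.cloud.dataflow.sdk.util.MonitoringUtil;

import java.io.IOException;
import java.util.concurrent.TimeUnit;

public class JobHandleExample {
  static void runAndMonitor(DataflowPipeline pipeline) throws IOException, InterruptedException {
    DataflowPipelineJob job = pipeline.run();

    // Poll for up to ten minutes, printing job messages as they arrive.
    State state = job.waitToFinish(10, TimeUnit.MINUTES, new MonitoringUtil.PrintHandler(System.out));

    if (state == null) {
      // Timed out or interrupted: the job may still be running, so ask the service to cancel it.
      job.cancel();
    } else {
      System.out.println("Job " + job.getJobId() + " finished in state " + state);
    }
  }
}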

    This is not intended for use by users of Cloud Dataflow. + * Instead, use {@link Pipeline#create(PipelineOptions)} to initialize a + * {@link Pipeline}. + */ +public class DataflowPipeline extends Pipeline { + + /** + * Creates and returns a new {@link DataflowPipeline} instance for tests. + */ + public static DataflowPipeline create(DataflowPipelineOptions options) { + return new DataflowPipeline(options); + } + + private DataflowPipeline(DataflowPipelineOptions options) { + super(DataflowPipelineRunner.fromOptions(options), options); + } + + @Override + public DataflowPipelineJob run() { + return (DataflowPipelineJob) super.run(); + } + + @Override + public DataflowPipelineRunner getRunner() { + return (DataflowPipelineRunner) super.getRunner(); + } + + @Override + public String toString() { + return "DataflowPipeline#" + getOptions().as(DataflowPipelineOptions.class).getJobName(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJob.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJob.java new file mode 100644 index 000000000000..e9f134c8489b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJob.java @@ -0,0 +1,389 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.JobMetrics; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.runners.dataflow.DataflowAggregatorTransforms; +import com.google.cloud.dataflow.sdk.runners.dataflow.DataflowMetricUpdateExtractor; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.util.AttemptAndTimeBoundedExponentialBackOff; +import com.google.cloud.dataflow.sdk.util.AttemptBoundedExponentialBackOff; +import com.google.cloud.dataflow.sdk.util.MapAggregatorValues; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +/** + * A DataflowPipelineJob represents a job submitted to Dataflow using + * {@link DataflowPipelineRunner}. + */ +public class DataflowPipelineJob implements PipelineResult { + private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineJob.class); + + /** + * The id for the job. + */ + private String jobId; + + /** + * Google cloud project to associate this pipeline with. + */ + private String projectId; + + /** + * Client for the Dataflow service. This can be used to query the service + * for information about the job. + */ + private Dataflow dataflowClient; + + /** + * The state the job terminated in or {@code null} if the job has not terminated. + */ + @Nullable + private State terminalState = null; + + /** + * The job that replaced this one or {@code null} if the job has not been replaced. + */ + @Nullable + private DataflowPipelineJob replacedByJob = null; + + private DataflowAggregatorTransforms aggregatorTransforms; + + /** + * The Metric Updates retrieved after the job was in a terminal state. + */ + private List terminalMetricUpdates; + + /** + * The polling interval for job status and messages information. + */ + static final long MESSAGES_POLLING_INTERVAL = TimeUnit.SECONDS.toMillis(2); + static final long STATUS_POLLING_INTERVAL = TimeUnit.SECONDS.toMillis(2); + + /** + * The amount of polling attempts for job status and messages information. + */ + static final int MESSAGES_POLLING_ATTEMPTS = 10; + static final int STATUS_POLLING_ATTEMPTS = 5; + + /** + * Constructs the job. 
+ * + * @param projectId the project id + * @param jobId the job id + * @param dataflowClient the client for the Dataflow Service + */ + public DataflowPipelineJob(String projectId, String jobId, Dataflow dataflowClient, + DataflowAggregatorTransforms aggregatorTransforms) { + this.projectId = projectId; + this.jobId = jobId; + this.dataflowClient = dataflowClient; + this.aggregatorTransforms = aggregatorTransforms; + } + + /** + * Get the id of this job. + */ + public String getJobId() { + return jobId; + } + + /** + * Get the project this job exists in. + */ + public String getProjectId() { + return projectId; + } + + /** + * Returns a new {@link DataflowPipelineJob} for the job that replaced this one, if applicable. + * + * @throws IllegalStateException if called before the job has terminated or if the job terminated + * but was not updated + */ + public DataflowPipelineJob getReplacedByJob() { + if (terminalState == null) { + throw new IllegalStateException("getReplacedByJob() called before job terminated"); + } + if (replacedByJob == null) { + throw new IllegalStateException("getReplacedByJob() called for job that was not replaced"); + } + return replacedByJob; + } + + /** + * Get the Cloud Dataflow API Client used by this job. + */ + public Dataflow getDataflowClient() { + return dataflowClient; + } + + /** + * Waits for the job to finish and return the final status. + * + * @param timeToWait The time to wait in units timeUnit for the job to finish. + * Provide a value less than 1 ms for an infinite wait. + * @param timeUnit The unit of time for timeToWait. + * @param messageHandler If non null this handler will be invoked for each + * batch of messages received. + * @return The final state of the job or null on timeout or if the + * thread is interrupted. + * @throws IOException If there is a persistent problem getting job + * information. + * @throws InterruptedException + */ + @Nullable + public State waitToFinish( + long timeToWait, + TimeUnit timeUnit, + MonitoringUtil.JobMessagesHandler messageHandler) + throws IOException, InterruptedException { + return waitToFinish(timeToWait, timeUnit, messageHandler, Sleeper.DEFAULT, NanoClock.SYSTEM); + } + + /** + * Wait for the job to finish and return the final status. + * + * @param timeToWait The time to wait in units timeUnit for the job to finish. + * Provide a value less than 1 ms for an infinite wait. + * @param timeUnit The unit of time for timeToWait. + * @param messageHandler If non null this handler will be invoked for each + * batch of messages received. + * @param sleeper A sleeper to use to sleep between attempts. + * @param nanoClock A nanoClock used to time the total time taken. + * @return The final state of the job or null on timeout or if the + * thread is interrupted. + * @throws IOException If there is a persistent problem getting job + * information. + * @throws InterruptedException + */ + @Nullable + @VisibleForTesting + State waitToFinish( + long timeToWait, + TimeUnit timeUnit, + MonitoringUtil.JobMessagesHandler messageHandler, + Sleeper sleeper, + NanoClock nanoClock) + throws IOException, InterruptedException { + MonitoringUtil monitor = new MonitoringUtil(projectId, dataflowClient); + + long lastTimestamp = 0; + BackOff backoff = + timeUnit.toMillis(timeToWait) > 0 + ? 
new AttemptAndTimeBoundedExponentialBackOff( + MESSAGES_POLLING_ATTEMPTS, + MESSAGES_POLLING_INTERVAL, + timeUnit.toMillis(timeToWait), + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ATTEMPTS, + nanoClock) + : new AttemptBoundedExponentialBackOff( + MESSAGES_POLLING_ATTEMPTS, MESSAGES_POLLING_INTERVAL); + State state; + do { + // Get the state of the job before listing messages. This ensures we always fetch job + // messages after the job finishes to ensure we have all them. + state = getStateWithRetries(1, sleeper); + boolean hasError = state == State.UNKNOWN; + + if (messageHandler != null && !hasError) { + // Process all the job messages that have accumulated so far. + try { + List allMessages = monitor.getJobMessages( + jobId, lastTimestamp); + + if (!allMessages.isEmpty()) { + lastTimestamp = + fromCloudTime(allMessages.get(allMessages.size() - 1).getTime()).getMillis(); + messageHandler.process(allMessages); + } + } catch (GoogleJsonResponseException | SocketTimeoutException e) { + hasError = true; + LOG.warn("There were problems getting current job messages: {}.", e.getMessage()); + LOG.debug("Exception information:", e); + } + } + + if (!hasError) { + backoff.reset(); + // Check if the job is done. + if (state.isTerminal()) { + return state; + } + } + } while(BackOffUtils.next(sleeper, backoff)); + LOG.warn("No terminal state was returned. State value {}", state); + return null; // Timed out. + } + + /** + * Cancels the job. + * @throws IOException if there is a problem executing the cancel request. + */ + public void cancel() throws IOException { + Job content = new Job(); + content.setProjectId(projectId); + content.setId(jobId); + content.setRequestedState("JOB_STATE_CANCELLED"); + dataflowClient.projects().jobs() + .update(projectId, jobId, content) + .execute(); + } + + @Override + public State getState() { + if (terminalState != null) { + return terminalState; + } + + return getStateWithRetries(STATUS_POLLING_ATTEMPTS, Sleeper.DEFAULT); + } + + /** + * Attempts to get the state. Uses exponential backoff on failure up to the maximum number + * of passed in attempts. + * + * @param attempts The amount of attempts to make. + * @param sleeper Object used to do the sleeps between attempts. + * @return The state of the job or State.UNKNOWN in case of failure. + */ + @VisibleForTesting + State getStateWithRetries(int attempts, Sleeper sleeper) { + if (terminalState != null) { + return terminalState; + } + try { + Job job = getJobWithRetries(attempts, sleeper); + return MonitoringUtil.toState(job.getCurrentState()); + } catch (IOException exn) { + // The only IOException that getJobWithRetries is permitted to throw is the final IOException + // that caused the failure of retry. Other exceptions are wrapped in an unchecked exceptions + // and will propagate. + return State.UNKNOWN; + } + } + + /** + * Attempts to get the underlying {@link Job}. Uses exponential backoff on failure up to the + * maximum number of passed in attempts. + * + * @param attempts The amount of attempts to make. + * @param sleeper Object used to do the sleeps between attempts. + * @return The underlying {@link Job} object. + * @throws IOException When the maximum number of retries is exhausted, the last exception is + * thrown. 
+ */ + @VisibleForTesting + Job getJobWithRetries(int attempts, Sleeper sleeper) throws IOException { + AttemptBoundedExponentialBackOff backoff = + new AttemptBoundedExponentialBackOff(attempts, STATUS_POLLING_INTERVAL); + + // Retry loop ends in return or throw + while (true) { + try { + Job job = dataflowClient + .projects() + .jobs() + .get(projectId, jobId) + .execute(); + State currentState = MonitoringUtil.toState(job.getCurrentState()); + if (currentState.isTerminal()) { + terminalState = currentState; + replacedByJob = new DataflowPipelineJob( + getProjectId(), job.getReplacedByJobId(), dataflowClient, aggregatorTransforms); + } + return job; + } catch (IOException exn) { + LOG.warn("There were problems getting current job status: {}.", exn.getMessage()); + LOG.debug("Exception information:", exn); + + if (!nextBackOff(sleeper, backoff)) { + throw exn; + } + } + } + } + + /** + * Identical to {@link BackOffUtils#next} but without checked exceptions. + */ + private boolean nextBackOff(Sleeper sleeper, BackOff backoff) { + try { + return BackOffUtils.next(sleeper, backoff); + } catch (InterruptedException | IOException e) { + throw Throwables.propagate(e); + } + } + + @Override + public AggregatorValues getAggregatorValues(Aggregator aggregator) + throws AggregatorRetrievalException { + try { + return new MapAggregatorValues<>(fromMetricUpdates(aggregator)); + } catch (IOException e) { + throw new AggregatorRetrievalException( + "IOException when retrieving Aggregator values for Aggregator " + aggregator, e); + } + } + + private Map fromMetricUpdates(Aggregator aggregator) + throws IOException { + if (aggregatorTransforms.contains(aggregator)) { + List metricUpdates; + if (terminalMetricUpdates != null) { + metricUpdates = terminalMetricUpdates; + } else { + boolean terminal = getState().isTerminal(); + JobMetrics jobMetrics = + dataflowClient.projects().jobs().getMetrics(projectId, jobId).execute(); + metricUpdates = jobMetrics.getMetrics(); + if (terminal && jobMetrics.getMetrics() != null) { + terminalMetricUpdates = metricUpdates; + } + } + + return DataflowMetricUpdateExtractor.fromMetricUpdates( + aggregator, aggregatorTransforms, metricUpdates); + } else { + throw new IllegalArgumentException( + "Aggregator " + aggregator + " is not used in this pipeline"); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRegistrar.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRegistrar.java new file mode 100644 index 000000000000..0e4d4e96b928 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRegistrar.java @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.auto.service.AutoService; +import com.google.cloud.dataflow.sdk.options.BlockingDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar; +import com.google.common.collect.ImmutableList; + +/** + * Contains the {@link PipelineOptionsRegistrar} and {@link PipelineRunnerRegistrar} for + * the {@link DataflowPipeline}. + */ +public class DataflowPipelineRegistrar { + private DataflowPipelineRegistrar() { } + + /** + * Register the {@link DataflowPipelineOptions} and {@link BlockingDataflowPipelineOptions}. + */ + @AutoService(PipelineOptionsRegistrar.class) + public static class Options implements PipelineOptionsRegistrar { + @Override + public Iterable> getPipelineOptions() { + return ImmutableList.>of( + DataflowPipelineOptions.class, + BlockingDataflowPipelineOptions.class); + } + } + + /** + * Register the {@link DataflowPipelineRunner} and {@link BlockingDataflowPipelineRunner}. + */ + @AutoService(PipelineRunnerRegistrar.class) + public static class Runner implements PipelineRunnerRegistrar { + @Override + public Iterable>> getPipelineRunners() { + return ImmutableList.>>of( + DataflowPipelineRunner.class, + BlockingDataflowPipelineRunner.class); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java new file mode 100644 index 000000000000..6eb6c2f7ad96 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java @@ -0,0 +1,2947 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.StringUtils.approximatePTransformName; +import static com.google.cloud.dataflow.sdk.util.StringUtils.approximateSimpleName; +import static com.google.cloud.dataflow.sdk.util.WindowedValue.valueInEmptyWindows; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.ListJobsResponse; +import com.google.api.services.dataflow.model.WorkerPool; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.FileBasedSink; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.ShardNameTemplate; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.io.Write; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineDebugOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineWorkerPoolOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.options.StreamingOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.JobSpecification; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.runners.dataflow.AssignWindows; +import com.google.cloud.dataflow.sdk.runners.dataflow.DataflowAggregatorTransforms; +import com.google.cloud.dataflow.sdk.runners.dataflow.PubsubIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.ReadTranslator; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat; +import 
com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.IsmRecord; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.IsmRecordCoder; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.MetadataKeyCoder; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterPane; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.DataflowReleaseInfo; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.PCollectionViews; +import com.google.cloud.dataflow.sdk.util.PathValidator; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Reshuffle; +import com.google.cloud.dataflow.sdk.util.SystemDoFnInternal; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.ValueWithRecordId; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; +import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import com.google.common.base.Utf8; +import com.google.common.collect.ForwardingMap; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import 
com.google.common.collect.Iterables; +import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.DateTimeUtils; +import org.joda.time.DateTimeZone; +import org.joda.time.Duration; +import org.joda.time.format.DateTimeFormat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintWriter; +import java.io.Serializable; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +/** + * A {@link PipelineRunner} that executes the operations in the + * pipeline by first translating them to the Dataflow representation + * using the {@link DataflowPipelineTranslator} and then submitting + * them to a Dataflow service for execution. + * + *
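+ * <p>A minimal construction sketch (the project id and staging bucket below are placeholders,
+ * and {@code args} is assumed to be the command-line arguments; in practice these values are
+ * usually supplied entirely via {@code PipelineOptionsFactory.fromArgs}):
+ * <pre>{@code
+ * DataflowPipelineOptions options =
+ *     PipelineOptionsFactory.fromArgs(args).as(DataflowPipelineOptions.class);
+ * options.setRunner(DataflowPipelineRunner.class);
+ * options.setProject("my-project-id");                   // placeholder
+ * options.setStagingLocation("gs://my-bucket/staging");  // placeholder
+ *
+ * Pipeline p = Pipeline.create(options);
+ * // ... apply transforms ...
+ * DataflowPipelineJob job = (DataflowPipelineJob) p.run();
+ * }</pre>
+ *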

+ * <h3>Permissions</h3>
+ *
+ * <p>When reading from a Dataflow source or writing to a Dataflow sink using
+ * {@code DataflowPipelineRunner}, the Google cloud services account and the Google Compute Engine
+ * service account of the GCP project running the Dataflow job will need access to the
+ * corresponding source/sink.
+ *
+ * <p>
    Please see Google Cloud + * Dataflow Security and Permissions for more details. + */ +public class DataflowPipelineRunner extends PipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineRunner.class); + + /** Provided configuration options. */ + private final DataflowPipelineOptions options; + + /** Client for the Dataflow service. This is used to actually submit jobs. */ + private final Dataflow dataflowClient; + + /** Translator for this DataflowPipelineRunner, based on options. */ + private final DataflowPipelineTranslator translator; + + /** Custom transforms implementations. */ + private final Map, Class> overrides; + + /** A set of user defined functions to invoke at different points in execution. */ + private DataflowPipelineRunnerHooks hooks; + + // Environment version information. + private static final String ENVIRONMENT_MAJOR_VERSION = "4"; + + // Default Docker container images that execute Dataflow worker harness, residing in Google + // Container Registry, separately for Batch and Streaming. + public static final String BATCH_WORKER_HARNESS_CONTAINER_IMAGE + = "dataflow.gcr.io/v1beta3/java-batch:github-20160225-00"; + public static final String STREAMING_WORKER_HARNESS_CONTAINER_IMAGE + = "dataflow.gcr.io/v1beta3/java-streaming:github-20160225-00"; + + // The limit of CreateJob request size. + private static final int CREATE_JOB_REQUEST_LIMIT_BYTES = 10 * 1024 * 1024; + + private final Set> pcollectionsRequiringIndexedFormat; + + /** + * Project IDs must contain lowercase letters, digits, or dashes. + * IDs must start with a letter and may not end with a dash. + * This regex isn't exact - this allows for patterns that would be rejected by + * the service, but this is sufficient for basic validation of project IDs. + */ + public static final String PROJECT_ID_REGEXP = "[a-z][-a-z0-9:.]+[a-z0-9]"; + + /** + * Construct a runner from the provided options. + * + * @param options Properties that configure the runner. + * @return The newly created runner. + */ + public static DataflowPipelineRunner fromOptions(PipelineOptions options) { + // (Re-)register standard IO factories. Clobbers any prior credentials. + IOChannelUtils.registerStandardIOFactories(options); + + DataflowPipelineOptions dataflowOptions = + PipelineOptionsValidator.validate(DataflowPipelineOptions.class, options); + ArrayList missing = new ArrayList<>(); + + if (dataflowOptions.getAppName() == null) { + missing.add("appName"); + } + if (missing.size() > 0) { + throw new IllegalArgumentException( + "Missing required values: " + Joiner.on(',').join(missing)); + } + + PathValidator validator = dataflowOptions.getPathValidator(); + if (dataflowOptions.getStagingLocation() != null) { + validator.validateOutputFilePrefixSupported(dataflowOptions.getStagingLocation()); + } + if (dataflowOptions.getTempLocation() != null) { + validator.validateOutputFilePrefixSupported(dataflowOptions.getTempLocation()); + } + if (Strings.isNullOrEmpty(dataflowOptions.getTempLocation())) { + dataflowOptions.setTempLocation(dataflowOptions.getStagingLocation()); + } else if (Strings.isNullOrEmpty(dataflowOptions.getStagingLocation())) { + try { + dataflowOptions.setStagingLocation( + IOChannelUtils.resolve(dataflowOptions.getTempLocation(), "staging")); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to resolve PipelineOptions.stagingLocation " + + "from PipelineOptions.tempLocation. 
Please set the staging location explicitly.", e); + } + } + + if (dataflowOptions.getFilesToStage() == null) { + dataflowOptions.setFilesToStage(detectClassPathResourcesToStage( + DataflowPipelineRunner.class.getClassLoader())); + LOG.info("PipelineOptions.filesToStage was not specified. " + + "Defaulting to files from the classpath: will stage {} files. " + + "Enable logging at DEBUG level to see which files will be staged.", + dataflowOptions.getFilesToStage().size()); + LOG.debug("Classpath elements: {}", dataflowOptions.getFilesToStage()); + } + + // Verify jobName according to service requirements. + String jobName = dataflowOptions.getJobName().toLowerCase(); + Preconditions.checkArgument( + jobName.matches("[a-z]([-a-z0-9]*[a-z0-9])?"), + "JobName invalid; the name must consist of only the characters " + + "[-a-z0-9], starting with a letter and ending with a letter " + + "or number"); + + // Verify project + String project = dataflowOptions.getProject(); + if (project.matches("[0-9]*")) { + throw new IllegalArgumentException("Project ID '" + project + + "' invalid. Please make sure you specified the Project ID, not project number."); + } else if (!project.matches(PROJECT_ID_REGEXP)) { + throw new IllegalArgumentException("Project ID '" + project + + "' invalid. Please make sure you specified the Project ID, not project description."); + } + + DataflowPipelineDebugOptions debugOptions = + dataflowOptions.as(DataflowPipelineDebugOptions.class); + // Verify the number of worker threads is a valid value + if (debugOptions.getNumberOfWorkerHarnessThreads() < 0) { + throw new IllegalArgumentException("Number of worker harness threads '" + + debugOptions.getNumberOfWorkerHarnessThreads() + + "' invalid. Please make sure the value is non-negative."); + } + + return new DataflowPipelineRunner(dataflowOptions); + } + + @VisibleForTesting protected DataflowPipelineRunner(DataflowPipelineOptions options) { + this.options = options; + this.dataflowClient = options.getDataflowClient(); + this.translator = DataflowPipelineTranslator.fromOptions(options); + this.pcollectionsRequiringIndexedFormat = new HashSet<>(); + this.ptransformViewsWithNonDeterministicKeyCoders = new HashSet<>(); + + if (options.isStreaming()) { + overrides = ImmutableMap., Class>builder() + .put(Combine.GloballyAsSingletonView.class, StreamingCombineGloballyAsSingletonView.class) + .put(Create.Values.class, StreamingCreate.class) + .put(View.AsMap.class, StreamingViewAsMap.class) + .put(View.AsMultimap.class, StreamingViewAsMultimap.class) + .put(View.AsSingleton.class, StreamingViewAsSingleton.class) + .put(View.AsList.class, StreamingViewAsList.class) + .put(View.AsIterable.class, StreamingViewAsIterable.class) + .put(Write.Bound.class, StreamingWrite.class) + .put(PubsubIO.Write.Bound.class, StreamingPubsubIOWrite.class) + .put(Read.Unbounded.class, StreamingUnboundedRead.class) + .put(Read.Bounded.class, UnsupportedIO.class) + .put(AvroIO.Read.Bound.class, UnsupportedIO.class) + .put(AvroIO.Write.Bound.class, UnsupportedIO.class) + .put(BigQueryIO.Read.Bound.class, UnsupportedIO.class) + .put(TextIO.Read.Bound.class, UnsupportedIO.class) + .put(TextIO.Write.Bound.class, UnsupportedIO.class) + .put(Window.Bound.class, AssignWindows.class) + .build(); + } else { + ImmutableMap.Builder, Class> builder = ImmutableMap., Class>builder(); + builder.put(Read.Unbounded.class, UnsupportedIO.class); + builder.put(Window.Bound.class, AssignWindows.class); + builder.put(Write.Bound.class, BatchWrite.class); + 
builder.put(AvroIO.Write.Bound.class, BatchAvroIOWrite.class); + builder.put(TextIO.Write.Bound.class, BatchTextIOWrite.class); + if (options.getExperiments() == null + || !options.getExperiments().contains("disable_ism_side_input")) { + builder.put(View.AsMap.class, BatchViewAsMap.class); + builder.put(View.AsMultimap.class, BatchViewAsMultimap.class); + builder.put(View.AsSingleton.class, BatchViewAsSingleton.class); + builder.put(View.AsList.class, BatchViewAsList.class); + builder.put(View.AsIterable.class, BatchViewAsIterable.class); + } + overrides = builder.build(); + } + } + + /** + * Applies the given transform to the input. For transforms with customized definitions + * for the Dataflow pipeline runner, the application is intercepted and modified here. + */ + @Override + public OutputT apply( + PTransform transform, InputT input) { + + if (Combine.GroupedValues.class.equals(transform.getClass()) + || GroupByKey.class.equals(transform.getClass())) { + + // For both Dataflow runners (streaming and batch), GroupByKey and GroupedValues are + // primitives. Returning a primitive output instead of the expanded definition + // signals to the translator that translation is necessary. + @SuppressWarnings("unchecked") + PCollection pc = (PCollection) input; + @SuppressWarnings("unchecked") + OutputT outputT = (OutputT) PCollection.createPrimitiveOutputInternal( + pc.getPipeline(), + transform instanceof GroupByKey + ? ((GroupByKey) transform).updateWindowingStrategy(pc.getWindowingStrategy()) + : pc.getWindowingStrategy(), + pc.isBounded()); + return outputT; + } else if (Window.Bound.class.equals(transform.getClass())) { + /* + * TODO: make this the generic way overrides are applied (using super.apply() rather than + * Pipeline.applyTransform(); this allows the apply method to be replaced without inserting + * additional nodes into the graph. + */ + // casting to wildcard + @SuppressWarnings("unchecked") + OutputT windowed = (OutputT) applyWindow((Window.Bound) transform, (PCollection) input); + return windowed; + } else if (Flatten.FlattenPCollectionList.class.equals(transform.getClass()) + && ((PCollectionList) input).size() == 0) { + return (OutputT) Pipeline.applyTransform(input, Create.of()); + } else if (overrides.containsKey(transform.getClass())) { + // It is the responsibility of whoever constructs overrides to ensure this is type safe. 
+ @SuppressWarnings("unchecked") + Class> transformClass = + (Class>) transform.getClass(); + + @SuppressWarnings("unchecked") + Class> customTransformClass = + (Class>) overrides.get(transform.getClass()); + + PTransform customTransform = + InstanceBuilder.ofType(customTransformClass) + .withArg(DataflowPipelineRunner.class, this) + .withArg(transformClass, transform) + .build(); + + return Pipeline.applyTransform(input, customTransform); + } else { + return super.apply(transform, input); + } + } + + private PCollection applyWindow( + Window.Bound intitialTransform, PCollection initialInput) { + // types are matched at compile time + @SuppressWarnings("unchecked") + Window.Bound transform = (Window.Bound) intitialTransform; + @SuppressWarnings("unchecked") + PCollection input = (PCollection) initialInput; + return super.apply(new AssignWindows<>(transform), input); + } + + @Override + public DataflowPipelineJob run(Pipeline pipeline) { + logWarningIfPCollectionViewHasNonDeterministicKeyCoder(pipeline); + + LOG.info("Executing pipeline on the Dataflow Service, which will have billing implications " + + "related to Google Compute Engine usage and other Google Cloud Services."); + + List packages = options.getStager().stageFiles(); + JobSpecification jobSpecification = + translator.translate(pipeline, this, packages); + Job newJob = jobSpecification.getJob(); + + // Set a unique client_request_id in the CreateJob request. + // This is used to ensure idempotence of job creation across retried + // attempts to create a job. Specifically, if the service returns a job with + // a different client_request_id, it means the returned one is a different + // job previously created with the same job name, and that the job creation + // has been effectively rejected. The SDK should return + // Error::Already_Exists to user in that case. + int randomNum = new Random().nextInt(9000) + 1000; + String requestId = DateTimeFormat.forPattern("YYYYMMddHHmmssmmm").withZone(DateTimeZone.UTC) + .print(DateTimeUtils.currentTimeMillis()) + "_" + randomNum; + newJob.setClientRequestId(requestId); + + String version = DataflowReleaseInfo.getReleaseInfo().getVersion(); + System.out.println("Dataflow SDK version: " + version); + + newJob.getEnvironment().setUserAgent(DataflowReleaseInfo.getReleaseInfo()); + // The Dataflow Service may write to the temporary directory directly, so + // must be verified. + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + if (!Strings.isNullOrEmpty(options.getTempLocation())) { + newJob.getEnvironment().setTempStoragePrefix( + dataflowOptions.getPathValidator().verifyPath(options.getTempLocation())); + } + newJob.getEnvironment().setDataset(options.getTempDatasetId()); + newJob.getEnvironment().setExperiments(options.getExperiments()); + + // Set the Docker container image that executes Dataflow worker harness, residing in Google + // Container Registry. Translator is guaranteed to create a worker pool prior to this point. + String workerHarnessContainerImage = + options.as(DataflowPipelineWorkerPoolOptions.class) + .getWorkerHarnessContainerImage(); + for (WorkerPool workerPool : newJob.getEnvironment().getWorkerPools()) { + workerPool.setWorkerHarnessContainerImage(workerHarnessContainerImage); + } + + // Requirements about the service. 
+ Map environmentVersion = new HashMap<>(); + environmentVersion.put(PropertyNames.ENVIRONMENT_VERSION_MAJOR_KEY, ENVIRONMENT_MAJOR_VERSION); + newJob.getEnvironment().setVersion(environmentVersion); + // Default jobType is JAVA_BATCH_AUTOSCALING: A Java job with workers that the job can + // autoscale if specified. + String jobType = "JAVA_BATCH_AUTOSCALING"; + + if (options.isStreaming()) { + jobType = "STREAMING"; + } + environmentVersion.put(PropertyNames.ENVIRONMENT_VERSION_JOB_TYPE_KEY, jobType); + + if (hooks != null) { + hooks.modifyEnvironmentBeforeSubmission(newJob.getEnvironment()); + } + + if (!Strings.isNullOrEmpty(options.getDataflowJobFile())) { + try (PrintWriter printWriter = new PrintWriter( + new File(options.getDataflowJobFile()))) { + String workSpecJson = DataflowPipelineTranslator.jobToString(newJob); + printWriter.print(workSpecJson); + LOG.info("Printed workflow specification to {}", options.getDataflowJobFile()); + } catch (IllegalStateException ex) { + LOG.warn("Cannot translate workflow spec to json for debug."); + } catch (FileNotFoundException ex) { + LOG.warn("Cannot create workflow spec output file."); + } + } + + String jobIdToUpdate = null; + if (options.getUpdate()) { + jobIdToUpdate = getJobIdFromName(options.getJobName()); + newJob.setTransformNameMapping(options.getTransformNameMapping()); + newJob.setReplaceJobId(jobIdToUpdate); + } + Job jobResult; + try { + jobResult = dataflowClient + .projects() + .jobs() + .create(options.getProject(), newJob) + .execute(); + } catch (GoogleJsonResponseException e) { + String errorMessages = "Unexpected errors"; + if (e.getDetails() != null) { + if (Utf8.encodedLength(newJob.toString()) >= CREATE_JOB_REQUEST_LIMIT_BYTES) { + errorMessages = "The size of the serialized JSON representation of the pipeline " + + "exceeds the allowable limit. " + + "For more information, please check the FAQ link below:\n" + + "https://cloud.google.com/dataflow/faq"; + } else { + errorMessages = e.getDetails().getMessage(); + } + } + throw new RuntimeException("Failed to create a workflow job: " + errorMessages, e); + } catch (IOException e) { + throw new RuntimeException("Failed to create a workflow job", e); + } + + // Obtain all of the extractors from the PTransforms used in the pipeline so the + // DataflowPipelineJob has access to them. + AggregatorPipelineExtractor aggregatorExtractor = new AggregatorPipelineExtractor(pipeline); + Map, Collection>> aggregatorSteps = + aggregatorExtractor.getAggregatorSteps(); + + DataflowAggregatorTransforms aggregatorTransforms = + new DataflowAggregatorTransforms(aggregatorSteps, jobSpecification.getStepNames()); + + // Use a raw client for post-launch monitoring, as status calls may fail + // regularly and need not be retried automatically. + DataflowPipelineJob dataflowPipelineJob = + new DataflowPipelineJob(options.getProject(), jobResult.getId(), + Transport.newRawDataflowClient(options).build(), aggregatorTransforms); + + // If the service returned client request id, the SDK needs to compare it + // with the original id generated in the request, if they are not the same + // (i.e., the returned job is not created by this request), throw + // DataflowJobAlreadyExistsException or DataflowJobAlreadyUpdatedExcetpion + // depending on whether this is a reload or not. + if (jobResult.getClientRequestId() != null && !jobResult.getClientRequestId().isEmpty() + && !jobResult.getClientRequestId().equals(requestId)) { + // If updating a job. 
+ if (options.getUpdate()) { + throw new DataflowJobAlreadyUpdatedException(dataflowPipelineJob, + String.format("The job named %s with id: %s has already been updated into job id: %s " + + "and cannot be updated again.", + newJob.getName(), jobIdToUpdate, jobResult.getId())); + } else { + throw new DataflowJobAlreadyExistsException(dataflowPipelineJob, + String.format("There is already an active job named %s with id: %s. If you want " + + "to submit a second job, try again by setting a different name using --jobName.", + newJob.getName(), jobResult.getId())); + } + } + + LOG.info("To access the Dataflow monitoring console, please navigate to {}", + MonitoringUtil.getJobMonitoringPageURL(options.getProject(), jobResult.getId())); + System.out.println("Submitted job: " + jobResult.getId()); + + LOG.info("To cancel the job using the 'gcloud' tool, run:\n> {}", + MonitoringUtil.getGcloudCancelCommand(options, jobResult.getId())); + + return dataflowPipelineJob; + } + + /** + * Returns the DataflowPipelineTranslator associated with this object. + */ + public DataflowPipelineTranslator getTranslator() { + return translator; + } + + /** + * Sets callbacks to invoke during execution see {@code DataflowPipelineRunnerHooks}. + */ + @Experimental + public void setHooks(DataflowPipelineRunnerHooks hooks) { + this.hooks = hooks; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Outputs a warning about PCollection views without deterministic key coders. */ + private void logWarningIfPCollectionViewHasNonDeterministicKeyCoder(Pipeline pipeline) { + // We need to wait till this point to determine the names of the transforms since only + // at this time do we know the hierarchy of the transforms otherwise we could + // have just recorded the full names during apply time. + if (!ptransformViewsWithNonDeterministicKeyCoders.isEmpty()) { + final SortedSet ptransformViewNamesWithNonDeterministicKeyCoders = new TreeSet<>(); + pipeline.traverseTopologically(new PipelineVisitor() { + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + } + + @Override + public void visitTransform(TransformTreeNode node) { + if (ptransformViewsWithNonDeterministicKeyCoders.contains(node.getTransform())) { + ptransformViewNamesWithNonDeterministicKeyCoders.add(node.getFullName()); + } + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + if (ptransformViewsWithNonDeterministicKeyCoders.contains(node.getTransform())) { + ptransformViewNamesWithNonDeterministicKeyCoders.add(node.getFullName()); + } + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + }); + + LOG.warn("Unable to use indexed implementation for View.AsMap and View.AsMultimap for {} " + + "because the key coder is not deterministic. Falling back to singleton implementation " + + "which may cause memory and/or performance problems. Future major versions of " + + "Dataflow will require deterministic key coders.", + ptransformViewNamesWithNonDeterministicKeyCoders); + } + } + + /** + * Returns true if the passed in {@link PCollection} needs to be materialiazed using + * an indexed format. + */ + boolean doesPCollectionRequireIndexedFormat(PCollection pcol) { + return pcollectionsRequiringIndexedFormat.contains(pcol); + } + + /** + * Marks the passed in {@link PCollection} as requiring to be materialized using + * an indexed format. 
+ */ + private void addPCollectionRequiringIndexedFormat(PCollection pcol) { + pcollectionsRequiringIndexedFormat.add(pcol); + } + + /** A set of {@link View}s with non-deterministic key coders. */ + Set> ptransformViewsWithNonDeterministicKeyCoders; + + /** + * Records that the {@link PTransform} requires a deterministic key coder. + */ + private void recordViewUsesNonDeterministicKeyCoder(PTransform ptransform) { + ptransformViewsWithNonDeterministicKeyCoders.add(ptransform); + } + + /** + * A {@link GroupByKey} transform for the {@link DataflowPipelineRunner} which sorts + * values using the secondary key {@code K2}. + * + *
+ * <p>
    The {@link PCollection} created created by this {@link PTransform} will have values in + * the empty window. Care must be taken *afterwards* to either re-window + * (using {@link Window#into}) or only use {@link PTransform}s that do not depend on the + * values being within a window. + */ + static class GroupByKeyAndSortValuesOnly + extends PTransform>>, PCollection>>>> { + private GroupByKeyAndSortValuesOnly() { + } + + @Override + public PCollection>>> apply(PCollection>> input) { + PCollection>>> rval = + PCollection.>>>createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + IsBounded.BOUNDED); + + @SuppressWarnings({"unchecked", "rawtypes"}) + KvCoder> inputCoder = (KvCoder) input.getCoder(); + rval.setCoder( + KvCoder.of(inputCoder.getKeyCoder(), + IterableCoder.of(inputCoder.getValueCoder()))); + return rval; + } + } + + /** + * A {@link PTransform} that groups the values by a hash of the window's byte representation + * and sorts the values using the windows byte representation. + */ + private static class GroupByWindowHashAsKeyAndWindowAsSortKey extends + PTransform, PCollection>>>>> { + + /** + * A {@link DoFn} that for each element outputs a {@code KV} structure suitable for + * grouping by the hash of the window's byte representation and sorting the grouped values + * using the window's byte representation. + */ + @SystemDoFnInternal + private static class UseWindowHashAsKeyAndWindowAsSortKeyDoFn + extends DoFn>>> implements DoFn.RequiresWindowAccess { + + private final IsmRecordCoder ismCoderForHash; + private UseWindowHashAsKeyAndWindowAsSortKeyDoFn(IsmRecordCoder ismCoderForHash) { + this.ismCoderForHash = ismCoderForHash; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + @SuppressWarnings("unchecked") + W window = (W) c.window(); + c.output( + KV.of(ismCoderForHash.hash(ImmutableList.of(window)), + KV.of(window, + WindowedValue.of( + c.element(), + c.timestamp(), + c.window(), + c.pane())))); + } + } + + private final IsmRecordCoder ismCoderForHash; + private GroupByWindowHashAsKeyAndWindowAsSortKey(IsmRecordCoder ismCoderForHash) { + this.ismCoderForHash = ismCoderForHash; + } + + @Override + public PCollection>>>> apply(PCollection input) { + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + PCollection>>> rval = + input.apply(ParDo.of( + new UseWindowHashAsKeyAndWindowAsSortKeyDoFn(ismCoderForHash))); + rval.setCoder( + KvCoder.of( + VarIntCoder.of(), + KvCoder.of(windowCoder, + FullWindowedValueCoder.of(input.getCoder(), windowCoder)))); + return rval.apply(new GroupByKeyAndSortValuesOnly>()); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsSingleton View.AsSingleton} for the + * Dataflow runner in batch mode. + * + *

+ * <p>Creates a set of files in the {@link IsmFormat} sharded by the hash of the window's
+ * byte representation and with records having:
+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
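+ *
+ * <p>A usage sketch of the overridden transform (assuming {@code words} is a
+ * {@code PCollection<String>} in a batch pipeline using this runner with Ism side inputs
+ * enabled):
+ * <pre>{@code
+ * PCollection<Long> count = words.apply(Count.<String>globally());
+ * // View.AsSingleton is overridden by BatchViewAsSingleton, which materializes the
+ * // singleton value as Ism files keyed by window.
+ * PCollectionView<Long> countView = count.apply(View.<Long>asSingleton());
+ * }</pre>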
    + */ + static class BatchViewAsSingleton + extends PTransform, PCollectionView> { + + /** + * A {@link DoFn} that outputs {@link IsmRecord}s. These records are structured as follows: + *
+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
    + */ + static class IsmRecordForSingularValuePerWindowDoFn + extends DoFn>>>, + IsmRecord>> { + + @Override + public void processElement(ProcessContext c) throws Exception { + Iterator>> iterator = c.element().getValue().iterator(); + while (iterator.hasNext()) { + KV> next = iterator.next(); + c.output( + IsmRecord.of( + ImmutableList.of(next.getKey()), next.getValue())); + } + } + } + + private final DataflowPipelineRunner runner; + private final View.AsSingleton transform; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchViewAsSingleton(DataflowPipelineRunner runner, View.AsSingleton transform) { + this.runner = runner; + this.transform = transform; + } + + @Override + public PCollectionView apply(PCollection input) { + return BatchViewAsSingleton.applyForSingleton( + runner, + input, + new IsmRecordForSingularValuePerWindowDoFn(), + transform.hasDefaultValue(), + transform.defaultValue(), + input.getCoder()); + } + + static PCollectionView + applyForSingleton( + DataflowPipelineRunner runner, + PCollection input, + DoFn>>>, + IsmRecord>> doFn, + boolean hasDefault, + FinalT defaultValue, + Coder defaultValueCoder) { + + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + + @SuppressWarnings({"rawtypes", "unchecked"}) + PCollectionView view = PCollectionViews.singletonView( + input.getPipeline(), + (WindowingStrategy) input.getWindowingStrategy(), + hasDefault, + defaultValue, + defaultValueCoder); + + IsmRecordCoder> ismCoder = + coderForSingleton(windowCoder, defaultValueCoder); + + PCollection>> reifiedPerWindowAndSorted = input + .apply(new GroupByWindowHashAsKeyAndWindowAsSortKey(ismCoder)) + .apply(ParDo.of(doFn)); + reifiedPerWindowAndSorted.setCoder(ismCoder); + + runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); + return reifiedPerWindowAndSorted.apply( + CreatePCollectionView.>, ViewT>of(view)); + } + + @Override + protected String getKindString() { + return "BatchViewAsSingleton"; + } + + static IsmRecordCoder> coderForSingleton( + Coder windowCoder, Coder valueCoder) { + return IsmRecordCoder.of( + 1, // We hash using only the window + 0, // There are no metadata records + ImmutableList.>of(windowCoder), + FullWindowedValueCoder.of(valueCoder, windowCoder)); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsIterable View.AsIterable} for the + * Dataflow runner in batch mode. + * + *

+ * <p>Creates a set of {@code Ism} files sharded by the hash of the window's byte representation
+ * and with records having:
+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Key 2: Index offset within window</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
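+ *
+ * <p>A usage sketch of the overridden transform (assuming {@code words} and {@code other} are
+ * {@code PCollection<String>}s in a batch pipeline using this runner with Ism side inputs
+ * enabled):
+ * <pre>{@code
+ * // View.AsIterable is overridden by BatchViewAsIterable under this runner.
+ * final PCollectionView<Iterable<String>> wordsView = words.apply(View.<String>asIterable());
+ * other.apply(ParDo.withSideInputs(wordsView).of(new DoFn<String, String>() {
+ *   @Override
+ *   public void processElement(ProcessContext c) {
+ *     for (String w : c.sideInput(wordsView)) {
+ *       // ... use each element of the materialized side input ...
+ *     }
+ *   }
+ * }));
+ * }</pre>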
    + */ + static class BatchViewAsIterable + extends PTransform, PCollectionView>> { + + private final DataflowPipelineRunner runner; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchViewAsIterable(DataflowPipelineRunner runner, View.AsIterable transform) { + this.runner = runner; + } + + @Override + public PCollectionView> apply(PCollection input) { + PCollectionView> view = PCollectionViews.iterableView( + input.getPipeline(), input.getWindowingStrategy(), input.getCoder()); + return BatchViewAsList.applyForIterableLike(runner, input, view); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsList View.AsList} for the + * Dataflow runner in batch mode. + * + *

+ * <p>Creates a set of {@code Ism} files sharded by the hash of the window's byte representation
+ * and with records having:
+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Key 2: Index offset within window</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
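+ *
+ * <p>A usage sketch of the overridden transform (assuming {@code words} is a
+ * {@code PCollection<String>} in a batch pipeline using this runner with Ism side inputs
+ * enabled):
+ * <pre>{@code
+ * // View.AsList is overridden by BatchViewAsList under this runner.
+ * final PCollectionView<List<String>> wordsView = words.apply(View.<String>asList());
+ * // Within a downstream ParDo.withSideInputs(wordsView), the materialized list can be read:
+ * //   List<String> allWords = c.sideInput(wordsView);
+ * }</pre>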
    + */ + static class BatchViewAsList + extends PTransform, PCollectionView>> { + /** + * A {@link DoFn} which creates {@link IsmRecord}s assuming that each element is within the + * global window. Each {@link IsmRecord} has + *
+ * <ul>
+ *   <li>Key 1: Global window</li>
+ *   <li>Key 2: Index offset within window</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
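+ *
+ * <p>For instance, a bundle containing the elements {@code "a", "b", "c"} (in that order) would
+ * be emitted, roughly, as:
+ * <pre>
+ * IsmRecord(key components = [GlobalWindow, 0], value = WindowedValue("a"))
+ * IsmRecord(key components = [GlobalWindow, 1], value = WindowedValue("b"))
+ * IsmRecord(key components = [GlobalWindow, 2], value = WindowedValue("c"))
+ * </pre>
+ * The index restarts at zero for each bundle.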
    + */ + @SystemDoFnInternal + static class ToIsmRecordForGlobalWindowDoFn + extends DoFn>> { + + long indexInBundle; + @Override + public void startBundle(Context c) throws Exception { + indexInBundle = 0; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(IsmRecord.of( + ImmutableList.of(GlobalWindow.INSTANCE, indexInBundle), + WindowedValue.of( + c.element(), + c.timestamp(), + GlobalWindow.INSTANCE, + c.pane()))); + indexInBundle += 1; + } + } + + /** + * A {@link DoFn} which creates {@link IsmRecord}s comparing successive elements windows + * to locate the window boundaries. The {@link IsmRecord} has: + *
+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Key 2: Index offset within window</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
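+ *
+ * <p>For instance, a grouped value whose sorted entries carry windows {@code [W1, W1, W2]} would
+ * be emitted, roughly, as:
+ * <pre>
+ * IsmRecord(key components = [W1, 0], ...)
+ * IsmRecord(key components = [W1, 1], ...)
+ * IsmRecord(key components = [W2, 0], ...)   // index resets at the window boundary
+ * </pre>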
    + */ + @SystemDoFnInternal + static class ToIsmRecordForNonGlobalWindowDoFn + extends DoFn>>>, + IsmRecord>> { + + private final Coder windowCoder; + ToIsmRecordForNonGlobalWindowDoFn(Coder windowCoder) { + this.windowCoder = windowCoder; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + long elementsInWindow = 0; + Optional previousWindowStructuralValue = Optional.absent(); + for (KV> value : c.element().getValue()) { + Object currentWindowStructuralValue = windowCoder.structuralValue(value.getKey()); + // Compare to see if this is a new window so we can reset the index counter i + if (previousWindowStructuralValue.isPresent() + && !previousWindowStructuralValue.get().equals(currentWindowStructuralValue)) { + // Reset i since we have a new window. + elementsInWindow = 0; + } + c.output(IsmRecord.of( + ImmutableList.of(value.getKey(), elementsInWindow), + value.getValue())); + previousWindowStructuralValue = Optional.of(currentWindowStructuralValue); + elementsInWindow += 1; + } + } + } + + private final DataflowPipelineRunner runner; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchViewAsList(DataflowPipelineRunner runner, View.AsList transform) { + this.runner = runner; + } + + @Override + public PCollectionView> apply(PCollection input) { + PCollectionView> view = PCollectionViews.listView( + input.getPipeline(), input.getWindowingStrategy(), input.getCoder()); + return applyForIterableLike(runner, input, view); + } + + static PCollectionView applyForIterableLike( + DataflowPipelineRunner runner, + PCollection input, + PCollectionView view) { + + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + + IsmRecordCoder> ismCoder = coderForListLike(windowCoder, input.getCoder()); + + // If we are working in the global window, we do not need to do a GBK using the window + // as the key since all the elements of the input PCollection are already such. + // We just reify the windowed value while converting them to IsmRecords and generating + // an index based upon where we are within the bundle. Each bundle + // maps to one file exactly. 
+ if (input.getWindowingStrategy().getWindowFn() instanceof GlobalWindows) { + PCollection>> reifiedPerWindowAndSorted = + input.apply(ParDo.of(new ToIsmRecordForGlobalWindowDoFn())); + reifiedPerWindowAndSorted.setCoder(ismCoder); + + runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); + return reifiedPerWindowAndSorted.apply( + CreatePCollectionView.>, ViewT>of(view)); + } + + PCollection>> reifiedPerWindowAndSorted = input + .apply(new GroupByWindowHashAsKeyAndWindowAsSortKey(ismCoder)) + .apply(ParDo.of(new ToIsmRecordForNonGlobalWindowDoFn(windowCoder))); + reifiedPerWindowAndSorted.setCoder(ismCoder); + + runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); + return reifiedPerWindowAndSorted.apply( + CreatePCollectionView.>, ViewT>of(view)); + } + + @Override + protected String getKindString() { + return "BatchViewAsList"; + } + + static IsmRecordCoder> coderForListLike( + Coder windowCoder, Coder valueCoder) { + // TODO: swap to use a variable length long coder which has values which compare + // the same as their byte representation compare lexicographically within the key coder + return IsmRecordCoder.of( + 1, // We hash using only the window + 0, // There are no metadata records + ImmutableList.of(windowCoder, BigEndianLongCoder.of()), + FullWindowedValueCoder.of(valueCoder, windowCoder)); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsMap View.AsMap} for the + * Dataflow runner in batch mode. + * + *

+ * <p>Creates a set of {@code Ism} files sharded by the hash of the key's byte
+ * representation. Each record is structured as follows:
+ * <ul>
+ *   <li>Key 1: User key K</li>
+ *   <li>Key 2: Window</li>
+ *   <li>Key 3: 0L (constant)</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
+ *
+ * <p>Alongside the data records, there are the following metadata records:
+ * <ul>
+ *   <li>Key 1: Metadata Key</li>
+ *   <li>Key 2: Window</li>
+ *   <li>Key 3: Index [0, size of map]</li>
+ *   <li>Value: variable length long byte representation of the size of the map if the index is 0,
+ *       otherwise the byte representation of a key</li>
+ * </ul>
+ * The {@code [META, Window, 0]} record stores the number of unique keys per window, while
+ * {@code [META, Window, i]} for {@code i} in {@code [1, size of map]} stores the user's key.
+ * This allows one to access the size of the map by looking at {@code [META, Window, 0]}
+ * and to iterate over all the keys by accessing {@code [META, Window, i]} for {@code i} in
+ * {@code [1, size of map]}.
+ *
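+ * <p>A usage sketch of the overridden transform (assuming {@code wordCounts} is a
+ * {@code PCollection<KV<String, Long>>} with distinct keys, in a batch pipeline using this
+ * runner with Ism side inputs enabled):
+ * <pre>{@code
+ * // View.AsMap is overridden by BatchViewAsMap under this runner.
+ * final PCollectionView<Map<String, Long>> countsView =
+ *     wordCounts.apply(View.<String, Long>asMap());
+ * // Within a downstream ParDo.withSideInputs(countsView), individual entries can be looked up:
+ * //   Long count = c.sideInput(countsView).get("some-word");
+ * }</pre>
+ *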

    Note that in the case of a non-deterministic key coder, we fallback to using + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsSingleton View.AsSingleton} printing + * a warning to users to specify a deterministic key coder. + */ + static class BatchViewAsMap + extends PTransform>, PCollectionView>> { + + /** + * A {@link DoFn} which groups elements by window boundaries. For each group, + * the group of elements is transformed into a {@link TransformedMap}. + * The transformed {@code Map} is backed by a {@code Map>} + * and contains a function {@code WindowedValue -> V}. + * + *

+ * <p>Outputs {@link IsmRecord}s having:
+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Value: Transformed map containing a transform that removes the encapsulation
+ *       of the window around each value, {@code Map<K, WindowedValue<V>> -> Map<K, V>}.</li>
+ * </ul>
    + */ + static class ToMapDoFn + extends DoFn>>>>, + IsmRecord, + V>>>> { + + private final Coder windowCoder; + ToMapDoFn(Coder windowCoder) { + this.windowCoder = windowCoder; + } + + @Override + public void processElement(ProcessContext c) + throws Exception { + Optional previousWindowStructuralValue = Optional.absent(); + Optional previousWindow = Optional.absent(); + Map> map = new HashMap<>(); + for (KV>> kv : c.element().getValue()) { + Object currentWindowStructuralValue = windowCoder.structuralValue(kv.getKey()); + if (previousWindowStructuralValue.isPresent() + && !previousWindowStructuralValue.get().equals(currentWindowStructuralValue)) { + // Construct the transformed map containing all the elements since we + // are at a window boundary. + c.output(IsmRecord.of( + ImmutableList.of(previousWindow.get()), + valueInEmptyWindows(new TransformedMap<>(WindowedValueToValue.of(), map)))); + map = new HashMap<>(); + } + + // Verify that the user isn't trying to insert the same key multiple times. + checkState(!map.containsKey(kv.getValue().getValue().getKey()), + "Multiple values [%s, %s] found for single key [%s] within window [%s].", + map.get(kv.getValue().getValue().getKey()), + kv.getValue().getValue().getValue(), + kv.getKey()); + map.put(kv.getValue().getValue().getKey(), + kv.getValue().withValue(kv.getValue().getValue().getValue())); + previousWindowStructuralValue = Optional.of(currentWindowStructuralValue); + previousWindow = Optional.of(kv.getKey()); + } + + // The last value for this hash is guaranteed to be at a window boundary + // so we output a transformed map containing all the elements since the last + // window boundary. + c.output(IsmRecord.of( + ImmutableList.of(previousWindow.get()), + valueInEmptyWindows(new TransformedMap<>(WindowedValueToValue.of(), map)))); + } + } + + private final DataflowPipelineRunner runner; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchViewAsMap(DataflowPipelineRunner runner, View.AsMap transform) { + this.runner = runner; + } + + @Override + public PCollectionView> apply(PCollection> input) { + return this.applyInternal(input); + } + + private PCollectionView> + applyInternal(PCollection> input) { + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + try { + PCollectionView> view = PCollectionViews.mapView( + input.getPipeline(), input.getWindowingStrategy(), inputCoder); + return BatchViewAsMultimap.applyForMapLike(runner, input, view, true /* unique keys */); + } catch (NonDeterministicException e) { + runner.recordViewUsesNonDeterministicKeyCoder(this); + + // Since the key coder is not deterministic, we convert the map into a singleton + // and return a singleton view equivalent. + return applyForSingletonFallback(input); + } + } + + @Override + protected String getKindString() { + return "BatchViewAsMap"; + } + + /** Transforms the input {@link PCollection} into a singleton {@link Map} per window. 
*/ + private PCollectionView> + applyForSingletonFallback(PCollection> input) { + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + Coder, V>> transformCoder = + (Coder) SerializableCoder.of(WindowedValueToValue.class); + + Coder, V>> finalValueCoder = + TransformedMapCoder.of( + transformCoder, + MapCoder.of( + inputCoder.getKeyCoder(), + FullWindowedValueCoder.of(inputCoder.getValueCoder(), windowCoder))); + + TransformedMap, V> defaultValue = new TransformedMap<>( + WindowedValueToValue.of(), + ImmutableMap.>of()); + + return BatchViewAsSingleton., + TransformedMap, V>, + Map, + W> applyForSingleton( + runner, + input, + new ToMapDoFn(windowCoder), + true, + defaultValue, + finalValueCoder); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsMultimap View.AsMultimap} for the + * Dataflow runner in batch mode. + * + *

+ * <p>Creates a set of {@code Ism} files sharded by the hash of the key's byte
+ * representation. Each record is structured as follows:
+ * <ul>
+ *   <li>Key 1: User key K</li>
+ *   <li>Key 2: Window</li>
+ *   <li>Key 3: Index offset for a given key and window.</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
+ *
+ * <p>Alongside the data records, there are the following metadata records:
+ * <ul>
+ *   <li>Key 1: Metadata Key</li>
+ *   <li>Key 2: Window</li>
+ *   <li>Key 3: Index [0, size of map]</li>
+ *   <li>Value: variable length long byte representation of the size of the map if the index
+ *       is 0, otherwise the byte representation of a key</li>
+ * </ul>
+ * The {@code [META, Window, 0]} record stores the number of unique keys per window, while
+ * {@code [META, Window, i]} for {@code i} in {@code [1, size of map]} stores the user's key.
+ * This allows one to access the size of the map by looking at {@code [META, Window, 0]}
+ * and to iterate over all the keys by accessing {@code [META, Window, i]} for {@code i} in
+ * {@code [1, size of map]}.
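+ *
+ * <p>For example, a window {@code w} containing the two keys {@code a} and {@code b} yields
+ * the metadata records {@code [META, w, 0] -> 2}, {@code [META, w, 1] -> a} and
+ * {@code [META, w, 2] -> b}, alongside the data records for each key and window.
+ *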

    Note that in the case of a non-deterministic key coder, we fallback to using + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsSingleton View.AsSingleton} printing + * a warning to users to specify a deterministic key coder. + */ + static class BatchViewAsMultimap + extends PTransform>, PCollectionView>>> { + /** + * A {@link PTransform} that groups elements by the hash of window's byte representation + * if the input {@link PCollection} is not within the global window. Otherwise by the hash + * of the window and key's byte representation. This {@link PTransform} also sorts + * the values by the combination of the window and key's byte representations. + */ + private static class GroupByKeyHashAndSortByKeyAndWindow + extends PTransform>, + PCollection, WindowedValue>>>>> { + + @SystemDoFnInternal + private static class GroupByKeyHashAndSortByKeyAndWindowDoFn + extends DoFn, KV, WindowedValue>>> + implements DoFn.RequiresWindowAccess { + + private final IsmRecordCoder coder; + private GroupByKeyHashAndSortByKeyAndWindowDoFn(IsmRecordCoder coder) { + this.coder = coder; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + @SuppressWarnings("unchecked") + W window = (W) c.window(); + + c.output( + KV.of(coder.hash(ImmutableList.of(c.element().getKey())), + KV.of(KV.of(c.element().getKey(), window), + WindowedValue.of( + c.element().getValue(), + c.timestamp(), + (BoundedWindow) window, + c.pane())))); + } + } + + private final IsmRecordCoder coder; + public GroupByKeyHashAndSortByKeyAndWindow(IsmRecordCoder coder) { + this.coder = coder; + } + + @Override + public PCollection, WindowedValue>>>> + apply(PCollection> input) { + + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + @SuppressWarnings("unchecked") + KvCoder inputCoder = (KvCoder) input.getCoder(); + + PCollection, WindowedValue>>> keyedByHash; + keyedByHash = input.apply( + ParDo.of(new GroupByKeyHashAndSortByKeyAndWindowDoFn(coder))); + keyedByHash.setCoder( + KvCoder.of( + VarIntCoder.of(), + KvCoder.of(KvCoder.of(inputCoder.getKeyCoder(), windowCoder), + FullWindowedValueCoder.of(inputCoder.getValueCoder(), windowCoder)))); + + return keyedByHash.apply( + new GroupByKeyAndSortValuesOnly, WindowedValue>()); + } + } + + /** + * A {@link DoFn} which creates {@link IsmRecord}s comparing successive elements windows + * and keys to locate window and key boundaries. The main output {@link IsmRecord}s have: + *

+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Key 2: User key K</li>
+ *   <li>Key 3: Index offset for a given key and window.</li>
+ *   <li>Value: Windowed value</li>
+ * </ul>
+ *

+ * <p>Additionally, we output all the unique keys seen per window to {@code outputForEntrySet}
+ * and the unique key count per window to {@code outputForSize}.
+ *

    Finally, if this DoFn has been requested to perform unique key checking, it will + * throw an {@link IllegalStateException} if more than one key per window is found. + */ + static class ToIsmRecordForMapLikeDoFn + extends DoFn, WindowedValue>>>, + IsmRecord>> { + + private final TupleTag>> outputForSize; + private final TupleTag>> outputForEntrySet; + private final Coder windowCoder; + private final Coder keyCoder; + private final IsmRecordCoder> ismCoder; + private final boolean uniqueKeysExpected; + ToIsmRecordForMapLikeDoFn( + TupleTag>> outputForSize, + TupleTag>> outputForEntrySet, + Coder windowCoder, + Coder keyCoder, + IsmRecordCoder> ismCoder, + boolean uniqueKeysExpected) { + this.outputForSize = outputForSize; + this.outputForEntrySet = outputForEntrySet; + this.windowCoder = windowCoder; + this.keyCoder = keyCoder; + this.ismCoder = ismCoder; + this.uniqueKeysExpected = uniqueKeysExpected; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + long currentKeyIndex = 0; + // We use one based indexing while counting + long currentUniqueKeyCounter = 1; + Iterator, WindowedValue>> iterator = c.element().getValue().iterator(); + + KV, WindowedValue> currentValue = iterator.next(); + Object currentKeyStructuralValue = + keyCoder.structuralValue(currentValue.getKey().getKey()); + Object currentWindowStructuralValue = + windowCoder.structuralValue(currentValue.getKey().getValue()); + + while (iterator.hasNext()) { + KV, WindowedValue> nextValue = iterator.next(); + Object nextKeyStructuralValue = + keyCoder.structuralValue(nextValue.getKey().getKey()); + Object nextWindowStructuralValue = + windowCoder.structuralValue(nextValue.getKey().getValue()); + + outputDataRecord(c, currentValue, currentKeyIndex); + + final long nextKeyIndex; + final long nextUniqueKeyCounter; + + // Check to see if its a new window + if (!currentWindowStructuralValue.equals(nextWindowStructuralValue)) { + // The next value is a new window, so we output for size the number of unique keys + // seen and the last key of the window. We also reset the next key index the unique + // key counter. + outputMetadataRecordForSize(c, currentValue, currentUniqueKeyCounter); + outputMetadataRecordForEntrySet(c, currentValue); + + nextKeyIndex = 0; + nextUniqueKeyCounter = 1; + } else if (!currentKeyStructuralValue.equals(nextKeyStructuralValue)){ + // It is a new key within the same window so output the key for the entry set, + // reset the key index and increase the count of unique keys seen within this window. + outputMetadataRecordForEntrySet(c, currentValue); + + nextKeyIndex = 0; + nextUniqueKeyCounter = currentUniqueKeyCounter + 1; + } else if (!uniqueKeysExpected) { + // It is not a new key so we don't have to output the number of elements in this + // window or increase the unique key counter. All we do is increase the key index. 
+ + nextKeyIndex = currentKeyIndex + 1; + nextUniqueKeyCounter = currentUniqueKeyCounter; + } else { + throw new IllegalStateException(String.format( + "Unique keys are expected but found key %s with values %s and %s in window %s.", + currentValue.getKey().getKey(), + currentValue.getValue().getValue(), + nextValue.getValue().getValue(), + currentValue.getKey().getValue())); + } + + currentValue = nextValue; + currentWindowStructuralValue = nextWindowStructuralValue; + currentKeyStructuralValue = nextKeyStructuralValue; + currentKeyIndex = nextKeyIndex; + currentUniqueKeyCounter = nextUniqueKeyCounter; + } + + outputDataRecord(c, currentValue, currentKeyIndex); + outputMetadataRecordForSize(c, currentValue, currentUniqueKeyCounter); + // The last value for this hash is guaranteed to be at a window boundary + // so we output a record with the number of unique keys seen. + outputMetadataRecordForEntrySet(c, currentValue); + } + + /** This outputs the data record. */ + private void outputDataRecord( + ProcessContext c, KV, WindowedValue> value, long keyIndex) { + IsmRecord> ismRecord = IsmRecord.of( + ImmutableList.of( + value.getKey().getKey(), + value.getKey().getValue(), + keyIndex), + value.getValue()); + c.output(ismRecord); + } + + /** + * This outputs records which will be used to compute the number of keys for a given window. + */ + private void outputMetadataRecordForSize( + ProcessContext c, KV, WindowedValue> value, long uniqueKeyCount) { + c.sideOutput(outputForSize, + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), + value.getKey().getValue())), + KV.of(value.getKey().getValue(), uniqueKeyCount))); + } + + /** This outputs records which will be used to construct the entry set. */ + private void outputMetadataRecordForEntrySet( + ProcessContext c, KV, WindowedValue> value) { + c.sideOutput(outputForEntrySet, + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), + value.getKey().getValue())), + KV.of(value.getKey().getValue(), value.getKey().getKey()))); + } + } + + /** + * A {@link DoFn} which outputs a metadata {@link IsmRecord} per window of: + *

+ * <ul>
+ *   <li>Key 1: META key</li>
+ *   <li>Key 2: window</li>
+ *   <li>Key 3: 0L (constant)</li>
+ *   <li>Value: sum of values for window</li>
+ * </ul>
+ *

    This {@link DoFn} is meant to be used to compute the number of unique keys + * per window for map and multimap side inputs. + */ + static class ToIsmMetadataRecordForSizeDoFn + extends DoFn>>, IsmRecord>> { + private final Coder windowCoder; + ToIsmMetadataRecordForSizeDoFn(Coder windowCoder) { + this.windowCoder = windowCoder; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + Iterator> iterator = c.element().getValue().iterator(); + KV currentValue = iterator.next(); + Object currentWindowStructuralValue = windowCoder.structuralValue(currentValue.getKey()); + long size = 0; + while (iterator.hasNext()) { + KV nextValue = iterator.next(); + Object nextWindowStructuralValue = windowCoder.structuralValue(nextValue.getKey()); + + size += currentValue.getValue(); + if (!currentWindowStructuralValue.equals(nextWindowStructuralValue)) { + c.output(IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), currentValue.getKey(), 0L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), size))); + size = 0; + } + + currentValue = nextValue; + currentWindowStructuralValue = nextWindowStructuralValue; + } + + size += currentValue.getValue(); + // Output the final value since it is guaranteed to be on a window boundary. + c.output(IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), currentValue.getKey(), 0L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), size))); + } + } + + /** + * A {@link DoFn} which outputs a metadata {@link IsmRecord} per window and key pair of: + *

+ * <ul>
+ *   <li>Key 1: META key</li>
+ *   <li>Key 2: window</li>
+ *   <li>Key 3: index offset (1-based index)</li>
+ *   <li>Value: key</li>
+ * </ul>
+ *

    This {@link DoFn} is meant to be used to output index to key records + * per window for map and multimap side inputs. + */ + static class ToIsmMetadataRecordForKeyDoFn + extends DoFn>>, IsmRecord>> { + + private final Coder keyCoder; + private final Coder windowCoder; + ToIsmMetadataRecordForKeyDoFn(Coder keyCoder, Coder windowCoder) { + this.keyCoder = keyCoder; + this.windowCoder = windowCoder; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + Iterator> iterator = c.element().getValue().iterator(); + KV currentValue = iterator.next(); + Object currentWindowStructuralValue = windowCoder.structuralValue(currentValue.getKey()); + long elementsInWindow = 1; + while (iterator.hasNext()) { + KV nextValue = iterator.next(); + Object nextWindowStructuralValue = windowCoder.structuralValue(nextValue.getKey()); + + c.output(IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), currentValue.getKey(), elementsInWindow), + CoderUtils.encodeToByteArray(keyCoder, currentValue.getValue()))); + elementsInWindow += 1; + + if (!currentWindowStructuralValue.equals(nextWindowStructuralValue)) { + elementsInWindow = 1; + } + + currentValue = nextValue; + currentWindowStructuralValue = nextWindowStructuralValue; + } + + // Output the final value since it is guaranteed to be on a window boundary. + c.output(IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), currentValue.getKey(), elementsInWindow), + CoderUtils.encodeToByteArray(keyCoder, currentValue.getValue()))); + } + } + + /** + * A {@link DoFn} which partitions sets of elements by window boundaries. Within each + * partition, the set of elements is transformed into a {@link TransformedMap}. + * The transformed {@code Map>} is backed by a + * {@code Map>>} and contains a function + * {@code Iterable> -> Iterable}. + * + *

    Outputs {@link IsmRecord}s having: + *

+ * <ul>
+ *   <li>Key 1: Window</li>
+ *   <li>Value: Transformed map containing a transform that removes the encapsulation
+ *       of the window around each value,
+ *       {@code Map<K, Iterable<WindowedValue<V>>> -> Map<K, Iterable<V>>}.</li>
+ * </ul>
    + */ + static class ToMultimapDoFn + extends DoFn>>>>, + IsmRecord>, + Iterable>>>> { + + private final Coder windowCoder; + ToMultimapDoFn(Coder windowCoder) { + this.windowCoder = windowCoder; + } + + @Override + public void processElement(ProcessContext c) + throws Exception { + Optional previousWindowStructuralValue = Optional.absent(); + Optional previousWindow = Optional.absent(); + Multimap> multimap = HashMultimap.create(); + for (KV>> kv : c.element().getValue()) { + Object currentWindowStructuralValue = windowCoder.structuralValue(kv.getKey()); + if (previousWindowStructuralValue.isPresent() + && !previousWindowStructuralValue.get().equals(currentWindowStructuralValue)) { + // Construct the transformed map containing all the elements since we + // are at a window boundary. + @SuppressWarnings({"unchecked", "rawtypes"}) + Map>> resultMap = (Map) multimap.asMap(); + c.output(IsmRecord.>, + Iterable>>>of( + ImmutableList.of(previousWindow.get()), + valueInEmptyWindows( + new TransformedMap<>( + IterableWithWindowedValuesToIterable.of(), resultMap)))); + multimap = HashMultimap.create(); + } + + multimap.put(kv.getValue().getValue().getKey(), + kv.getValue().withValue(kv.getValue().getValue().getValue())); + previousWindowStructuralValue = Optional.of(currentWindowStructuralValue); + previousWindow = Optional.of(kv.getKey()); + } + + // The last value for this hash is guaranteed to be at a window boundary + // so we output a transformed map containing all the elements since the last + // window boundary. + @SuppressWarnings({"unchecked", "rawtypes"}) + Map>> resultMap = (Map) multimap.asMap(); + c.output(IsmRecord.>, + Iterable>>>of( + ImmutableList.of(previousWindow.get()), + valueInEmptyWindows( + new TransformedMap<>(IterableWithWindowedValuesToIterable.of(), resultMap)))); + } + } + + private final DataflowPipelineRunner runner; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchViewAsMultimap(DataflowPipelineRunner runner, View.AsMultimap transform) { + this.runner = runner; + } + + @Override + public PCollectionView>> apply(PCollection> input) { + return this.applyInternal(input); + } + + private PCollectionView>> + applyInternal(PCollection> input) { + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + try { + PCollectionView>> view = PCollectionViews.multimapView( + input.getPipeline(), input.getWindowingStrategy(), inputCoder); + + return applyForMapLike(runner, input, view, false /* unique keys not expected */); + } catch (NonDeterministicException e) { + runner.recordViewUsesNonDeterministicKeyCoder(this); + + // Since the key coder is not deterministic, we convert the map into a singleton + // and return a singleton view equivalent. + return applyForSingletonFallback(input); + } + } + + /** Transforms the input {@link PCollection} into a singleton {@link Map} per window. 
*/ + private PCollectionView>> + applyForSingletonFallback(PCollection> input) { + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + Coder>, Iterable>> transformCoder = + (Coder) SerializableCoder.of(IterableWithWindowedValuesToIterable.class); + + Coder>, Iterable>> finalValueCoder = + TransformedMapCoder.of( + transformCoder, + MapCoder.of( + inputCoder.getKeyCoder(), + IterableCoder.of( + FullWindowedValueCoder.of(inputCoder.getValueCoder(), windowCoder)))); + + TransformedMap>, Iterable> defaultValue = + new TransformedMap<>( + IterableWithWindowedValuesToIterable.of(), + ImmutableMap.>>of()); + + return BatchViewAsSingleton., + TransformedMap>, Iterable>, + Map>, + W> applyForSingleton( + runner, + input, + new ToMultimapDoFn(windowCoder), + true, + defaultValue, + finalValueCoder); + } + + private static PCollectionView applyForMapLike( + DataflowPipelineRunner runner, + PCollection> input, + PCollectionView view, + boolean uniqueKeysExpected) throws NonDeterministicException { + + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) + input.getWindowingStrategy().getWindowFn().windowCoder(); + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + + // If our key coder is deterministic, we can use the key portion of each KV + // part of a composite key containing the window , key and index. + inputCoder.getKeyCoder().verifyDeterministic(); + + IsmRecordCoder> ismCoder = + coderForMapLike(windowCoder, inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + + // Create the various output tags representing the main output containing the data stream + // and the side outputs containing the metadata about the size and entry set. + TupleTag>> mainOutputTag = new TupleTag<>(); + TupleTag>> outputForSizeTag = new TupleTag<>(); + TupleTag>> outputForEntrySetTag = new TupleTag<>(); + + // Process all the elements grouped by key hash, and sorted by key and then window + // outputting to all the outputs defined above. + PCollectionTuple outputTuple = input + .apply("GBKaSVForData", new GroupByKeyHashAndSortByKeyAndWindow(ismCoder)) + .apply(ParDo.of(new ToIsmRecordForMapLikeDoFn( + outputForSizeTag, outputForEntrySetTag, + windowCoder, inputCoder.getKeyCoder(), ismCoder, uniqueKeysExpected)) + .withOutputTags(mainOutputTag, + TupleTagList.of( + ImmutableList.>of(outputForSizeTag, + outputForEntrySetTag)))); + + // Set the coder on the main data output. + PCollection>> perHashWithReifiedWindows = + outputTuple.get(mainOutputTag); + perHashWithReifiedWindows.setCoder(ismCoder); + + // Set the coder on the metadata output for size and process the entries + // producing a [META, Window, 0L] record per window storing the number of unique keys + // for each window. 
+ PCollection>> outputForSize = outputTuple.get(outputForSizeTag); + outputForSize.setCoder( + KvCoder.of(VarIntCoder.of(), + KvCoder.of(windowCoder, VarLongCoder.of()))); + PCollection>> windowMapSizeMetadata = outputForSize + .apply("GBKaSVForSize", new GroupByKeyAndSortValuesOnly()) + .apply(ParDo.of(new ToIsmMetadataRecordForSizeDoFn(windowCoder))); + windowMapSizeMetadata.setCoder(ismCoder); + + // Set the coder on the metadata output destined to build the entry set and process the + // entries producing a [META, Window, Index] record per window key pair storing the key. + PCollection>> outputForEntrySet = + outputTuple.get(outputForEntrySetTag); + outputForEntrySet.setCoder( + KvCoder.of(VarIntCoder.of(), + KvCoder.of(windowCoder, inputCoder.getKeyCoder()))); + PCollection>> windowMapKeysMetadata = outputForEntrySet + .apply("GBKaSVForKeys", new GroupByKeyAndSortValuesOnly()) + .apply(ParDo.of( + new ToIsmMetadataRecordForKeyDoFn(inputCoder.getKeyCoder(), windowCoder))); + windowMapKeysMetadata.setCoder(ismCoder); + + // Set that all these outputs should be materialized using an indexed format. + runner.addPCollectionRequiringIndexedFormat(perHashWithReifiedWindows); + runner.addPCollectionRequiringIndexedFormat(windowMapSizeMetadata); + runner.addPCollectionRequiringIndexedFormat(windowMapKeysMetadata); + + PCollectionList>> outputs = + PCollectionList.of(ImmutableList.of( + perHashWithReifiedWindows, windowMapSizeMetadata, windowMapKeysMetadata)); + + return Pipeline.applyTransform(outputs, + Flatten.>>pCollections()) + .apply(CreatePCollectionView.>, + ViewT>of(view)); + } + + @Override + protected String getKindString() { + return "BatchViewAsMultimap"; + } + + static IsmRecordCoder> coderForMapLike( + Coder windowCoder, Coder keyCoder, Coder valueCoder) { + // TODO: swap to use a variable length long coder which has values which compare + // the same as their byte representation compare lexicographically within the key coder + return IsmRecordCoder.of( + 1, // We use only the key for hashing when producing value records + 2, // Since the key is not present, we add the window to the hash when + // producing metadata records + ImmutableList.of( + MetadataKeyCoder.of(keyCoder), + windowCoder, + BigEndianLongCoder.of()), + FullWindowedValueCoder.of(valueCoder, windowCoder)); + } + } + + /** + * A {@code Map} backed by a {@code Map} and a function that transforms + * {@code V1 -> V2}. + */ + static class TransformedMap + extends ForwardingMap { + private final Function transform; + private final Map originalMap; + private final Map transformedMap; + + private TransformedMap(Function transform, Map originalMap) { + this.transform = transform; + this.originalMap = Collections.unmodifiableMap(originalMap); + this.transformedMap = Maps.transformValues(originalMap, transform); + } + + @Override + protected Map delegate() { + return transformedMap; + } + } + + /** + * A {@link Coder} for {@link TransformedMap}s. 
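+ *
+ * <p>The transform function and the original backing map are encoded as this coder's two
+ * components; decoding rebuilds the lazily transformed view over the decoded map.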
+ */ + static class TransformedMapCoder + extends StandardCoder> { + private final Coder> transformCoder; + private final Coder> originalMapCoder; + + private TransformedMapCoder( + Coder> transformCoder, Coder> originalMapCoder) { + this.transformCoder = transformCoder; + this.originalMapCoder = originalMapCoder; + } + + public static TransformedMapCoder of( + Coder> transformCoder, Coder> originalMapCoder) { + return new TransformedMapCoder<>(transformCoder, originalMapCoder); + } + + @JsonCreator + public static TransformedMapCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + @SuppressWarnings("unchecked") + Coder> transformCoder = (Coder>) components.get(0); + @SuppressWarnings("unchecked") + Coder> originalMapCoder = (Coder>) components.get(1); + return of(transformCoder, originalMapCoder); + } + + @Override + public void encode(TransformedMap value, OutputStream outStream, + Coder.Context context) throws CoderException, IOException { + transformCoder.encode(value.transform, outStream, context.nested()); + originalMapCoder.encode(value.originalMap, outStream, context.nested()); + } + + @Override + public TransformedMap decode( + InputStream inStream, Coder.Context context) throws CoderException, IOException { + return new TransformedMap<>( + transformCoder.decode(inStream, context.nested()), + originalMapCoder.decode(inStream, context.nested())); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(transformCoder, originalMapCoder); + } + + @Override + public void verifyDeterministic() + throws com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException { + verifyDeterministic("Expected transform coder to be deterministic.", transformCoder); + verifyDeterministic("Expected map coder to be deterministic.", originalMapCoder); + } + } + + /** + * A {@link Function} which converts {@code WindowedValue} to {@code V}. + */ + private static class WindowedValueToValue implements + Function, V>, Serializable { + private static final WindowedValueToValue INSTANCE = new WindowedValueToValue<>(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static WindowedValueToValue of() { + return (WindowedValueToValue) INSTANCE; + } + + @Override + public V apply(WindowedValue input) { + return input.getValue(); + } + } + + /** + * A {@link Function} which converts {@code Iterable>} to {@code Iterable}. + */ + private static class IterableWithWindowedValuesToIterable implements + Function>, Iterable>, Serializable { + private static final IterableWithWindowedValuesToIterable INSTANCE = + new IterableWithWindowedValuesToIterable<>(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static IterableWithWindowedValuesToIterable of() { + return (IterableWithWindowedValuesToIterable) INSTANCE; + } + + @Override + public Iterable apply(Iterable> input) { + return Iterables.transform(input, WindowedValueToValue.of()); + } + } + + /** + * A {@link PTransform} that uses shuffle to create a fusion break. This allows pushing + * parallelism limits such as sharding controls further down the pipeline. + */ + private static class ReshardForWrite extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + return input + // TODO: This would need to be adapted to write per-window shards. 
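+         // The fusion break below keys each element with a value drawn from a randomly seeded
+         // arithmetic sequence, forces a shuffle with GroupByKey, and then discards the keys
+         // again in "Ungroup" to recover the original elements.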
+ .apply( + Window.into(new GlobalWindows()) + .triggering(DefaultTrigger.of()) + .discardingFiredPanes()) + .apply( + "RandomKey", + ParDo.of( + new DoFn>() { + transient long counter, step; + + @Override + public void startBundle(Context c) { + counter = (long) (Math.random() * Long.MAX_VALUE); + step = 1 + 2 * (long) (Math.random() * Long.MAX_VALUE); + } + + @Override + public void processElement(ProcessContext c) { + counter += step; + c.output(KV.of(counter, c.element())); + } + })) + .apply(GroupByKey.create()) + .apply( + "Ungroup", + ParDo.of( + new DoFn>, T>() { + @Override + public void processElement(ProcessContext c) { + for (T item : c.element().getValue()) { + c.output(item); + } + } + })); + } + } + + /** + * Specialized implementation which overrides + * {@link com.google.cloud.dataflow.sdk.io.Write.Bound Write.Bound} to provide Google + * Cloud Dataflow specific path validation of {@link FileBasedSink}s. + */ + private static class BatchWrite extends PTransform, PDone> { + private final DataflowPipelineRunner runner; + private final Write.Bound transform; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchWrite(DataflowPipelineRunner runner, Write.Bound transform) { + this.runner = runner; + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + if (transform.getSink() instanceof FileBasedSink) { + FileBasedSink sink = (FileBasedSink) transform.getSink(); + PathValidator validator = runner.options.getPathValidator(); + validator.validateOutputFilePrefixSupported(sink.getBaseOutputFilename()); + } + return transform.apply(input); + } + } + + /** + * Specialized implementation which overrides + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Write.Bound TextIO.Write.Bound} with + * a native sink instead of a custom sink as workaround until custom sinks + * have support for sharding controls. + */ + private static class BatchTextIOWrite extends PTransform, PDone> { + private final TextIO.Write.Bound transform; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchTextIOWrite(DataflowPipelineRunner runner, TextIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + if (transform.getNumShards() > 0) { + return input + .apply(new ReshardForWrite()) + .apply(new BatchTextIONativeWrite<>(transform)); + } else { + return transform.apply(input); + } + } + } + + /** + * This {@link PTransform} is used by the {@link DataflowPipelineTranslator} as a way + * to provide the native definition of the Text sink. + */ + private static class BatchTextIONativeWrite extends PTransform, PDone> { + private final TextIO.Write.Bound transform; + public BatchTextIONativeWrite(TextIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + return PDone.in(input.getPipeline()); + } + + static { + DataflowPipelineTranslator.registerTransformTranslator( + BatchTextIONativeWrite.class, new BatchTextIONativeWriteTranslator()); + } + } + + /** + * TextIO.Write.Bound support code for the Dataflow backend when applying parallelism limits + * through user requested sharding limits. 
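+ *
+ * <p>This translator is only reached when the user explicitly requests sharding; as an
+ * illustration (the output path and shard count are placeholders), a write such as
+ *
+ * <pre>{@code
+ * PCollection<String> lines = ...;
+ * lines.apply(TextIO.Write.to("gs://bucket/output").withNumShards(3));
+ * }</pre>
+ *
+ * is first resharded and then routed through this native sink definition.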
+ */ + private static class BatchTextIONativeWriteTranslator + implements TransformTranslator> { + @SuppressWarnings("unchecked") + @Override + public void translate(@SuppressWarnings("rawtypes") BatchTextIONativeWrite transform, + TranslationContext context) { + translateWriteHelper(transform, transform.transform, context); + } + + private void translateWriteHelper( + BatchTextIONativeWrite transform, + TextIO.Write.Bound originalTransform, + TranslationContext context) { + // Note that the original transform can not be used during add step/add input + // and is only passed in to get properties from it. + + checkState(originalTransform.getNumShards() > 0, + "Native TextSink is expected to only be used when sharding controls are required."); + + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + + // TODO: drop this check when server supports alternative templates. + switch (originalTransform.getShardTemplate()) { + case ShardNameTemplate.INDEX_OF_MAX: + break; // supported by server + case "": + // Empty shard template allowed - forces single output. + Preconditions.checkArgument(originalTransform.getNumShards() <= 1, + "Num shards must be <= 1 when using an empty sharding template"); + break; + default: + throw new UnsupportedOperationException("Shard template " + + originalTransform.getShardTemplate() + + " not yet supported by Dataflow service"); + } + + // TODO: How do we want to specify format and + // format-specific properties? + context.addInput(PropertyNames.FORMAT, "text"); + context.addInput(PropertyNames.FILENAME_PREFIX, originalTransform.getFilenamePrefix()); + context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, + originalTransform.getShardNameTemplate()); + context.addInput(PropertyNames.FILENAME_SUFFIX, originalTransform.getFilenameSuffix()); + context.addInput(PropertyNames.VALIDATE_SINK, originalTransform.needsValidation()); + context.addInput(PropertyNames.NUM_SHARDS, (long) originalTransform.getNumShards()); + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(originalTransform.getCoder())); + + } + } + + /** + * Specialized implementation which overrides + * {@link com.google.cloud.dataflow.sdk.io.AvroIO.Write.Bound AvroIO.Write.Bound} with + * a native sink instead of a custom sink as workaround until custom sinks + * have support for sharding controls. + */ + private static class BatchAvroIOWrite extends PTransform, PDone> { + private final AvroIO.Write.Bound transform; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchAvroIOWrite(DataflowPipelineRunner runner, AvroIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + if (transform.getNumShards() > 0) { + return input + .apply(new ReshardForWrite()) + .apply(new BatchAvroIONativeWrite<>(transform)); + } else { + return transform.apply(input); + } + } + } + + /** + * This {@link PTransform} is used by the {@link DataflowPipelineTranslator} as a way + * to provide the native definition of the Avro sink. 
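+ *
+ * <p>Note that {@code apply} only returns {@link PDone}; the actual sink step is produced by
+ * the {@code BatchAvroIONativeWriteTranslator} registered with the
+ * {@link DataflowPipelineTranslator}.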
+ */ + private static class BatchAvroIONativeWrite extends PTransform, PDone> { + private final AvroIO.Write.Bound transform; + public BatchAvroIONativeWrite(AvroIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + return PDone.in(input.getPipeline()); + } + + static { + DataflowPipelineTranslator.registerTransformTranslator( + BatchAvroIONativeWrite.class, new BatchAvroIONativeWriteTranslator()); + } + } + + /** + * AvroIO.Write.Bound support code for the Dataflow backend when applying parallelism limits + * through user requested sharding limits. + */ + private static class BatchAvroIONativeWriteTranslator + implements TransformTranslator> { + @SuppressWarnings("unchecked") + @Override + public void translate(@SuppressWarnings("rawtypes") BatchAvroIONativeWrite transform, + TranslationContext context) { + translateWriteHelper(transform, transform.transform, context); + } + + private void translateWriteHelper( + BatchAvroIONativeWrite transform, + AvroIO.Write.Bound originalTransform, + TranslationContext context) { + // Note that the original transform can not be used during add step/add input + // and is only passed in to get properties from it. + + checkState(originalTransform.getNumShards() > 0, + "Native AvroSink is expected to only be used when sharding controls are required."); + + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + + // TODO: drop this check when server supports alternative templates. + switch (originalTransform.getShardTemplate()) { + case ShardNameTemplate.INDEX_OF_MAX: + break; // supported by server + case "": + // Empty shard template allowed - forces single output. + Preconditions.checkArgument(originalTransform.getNumShards() <= 1, + "Num shards must be <= 1 when using an empty sharding template"); + break; + default: + throw new UnsupportedOperationException("Shard template " + + originalTransform.getShardTemplate() + + " not yet supported by Dataflow service"); + } + + context.addInput(PropertyNames.FORMAT, "avro"); + context.addInput(PropertyNames.FILENAME_PREFIX, originalTransform.getFilenamePrefix()); + context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, originalTransform.getShardTemplate()); + context.addInput(PropertyNames.FILENAME_SUFFIX, originalTransform.getFilenameSuffix()); + context.addInput(PropertyNames.VALIDATE_SINK, originalTransform.needsValidation()); + context.addInput(PropertyNames.NUM_SHARDS, (long) originalTransform.getNumShards()); + context.addEncodingInput( + WindowedValue.getValueOnlyCoder( + AvroCoder.of(originalTransform.getType(), originalTransform.getSchema()))); + } + } + + /** + * Specialized (non-)implementation for + * {@link com.google.cloud.dataflow.sdk.io.Write.Bound Write.Bound} + * for the Dataflow runner in streaming mode. + */ + private static class StreamingWrite extends PTransform, PDone> { + /** + * Builds an instance of this class from the overridden transform. 
+ */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingWrite(DataflowPipelineRunner runner, Write.Bound transform) { } + + @Override + public PDone apply(PCollection input) { + throw new UnsupportedOperationException( + "The Write transform is not supported by the Dataflow streaming runner."); + } + + @Override + protected String getKindString() { + return "StreamingWrite"; + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.io.PubsubIO.Write PubsubIO.Write} for the + * Dataflow runner in streaming mode. + * + *

    For internal use only. Subject to change at any time. + * + *

    Public so the {@link PubsubIOTranslator} can access. + */ + public static class StreamingPubsubIOWrite extends PTransform, PDone> { + private final PubsubIO.Write.Bound transform; + + /** + * Builds an instance of this class from the overridden transform. + */ + public StreamingPubsubIOWrite( + DataflowPipelineRunner runner, PubsubIO.Write.Bound transform) { + this.transform = transform; + } + + public PubsubIO.Write.Bound getOverriddenTransform() { + return transform; + } + + @Override + public PDone apply(PCollection input) { + return PDone.in(input.getPipeline()); + } + + @Override + protected String getKindString() { + return "StreamingPubsubIOWrite"; + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.io.Read.Unbounded Read.Unbounded} for the + * Dataflow runner in streaming mode. + * + *

    In particular, if an UnboundedSource requires deduplication, then features of WindmillSink + * are leveraged to do the deduplication. + */ + private static class StreamingUnboundedRead extends PTransform> { + private final UnboundedSource source; + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingUnboundedRead(DataflowPipelineRunner runner, Read.Unbounded transform) { + this.source = transform.getSource(); + } + + @Override + protected Coder getDefaultOutputCoder() { + return source.getDefaultOutputCoder(); + } + + @Override + public final PCollection apply(PInput input) { + source.validate(); + + if (source.requiresDeduping()) { + return Pipeline.applyTransform(input, new ReadWithIds(source)) + .apply(new Deduplicate()); + } else { + return Pipeline.applyTransform(input, new ReadWithIds(source)) + .apply(ValueWithRecordId.stripIds()); + } + } + + /** + * {@link PTransform} that reads {@code (record,recordId)} pairs from an + * {@link UnboundedSource}. + */ + private static class ReadWithIds + extends PTransform>> { + private final UnboundedSource source; + + private ReadWithIds(UnboundedSource source) { + this.source = source; + } + + @Override + public final PCollection> apply(PInput input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED); + } + + @Override + protected Coder> getDefaultOutputCoder() { + return ValueWithRecordId.ValueWithRecordIdCoder.of(source.getDefaultOutputCoder()); + } + + public UnboundedSource getSource() { + return source; + } + } + + @Override + public String getKindString() { + return "Read(" + approximateSimpleName(source.getClass()) + ")"; + } + + static { + DataflowPipelineTranslator.registerTransformTranslator( + ReadWithIds.class, new ReadWithIdsTranslator()); + } + + private static class ReadWithIdsTranslator + implements DataflowPipelineTranslator.TransformTranslator> { + @Override + public void translate(ReadWithIds transform, + DataflowPipelineTranslator.TranslationContext context) { + ReadTranslator.translateReadHelper(transform.getSource(), transform, context); + } + } + } + + /** + * Remove values with duplicate ids. + */ + private static class Deduplicate + extends PTransform>, PCollection> { + // Use a finite set of keys to improve bundling. Without this, the key space + // will be the space of ids which is potentially very large, which results in much + // more per-key overhead. + private static final int NUM_RESHARD_KEYS = 10000; + @Override + public PCollection apply(PCollection> input) { + return input + .apply(WithKeys.of(new SerializableFunction, Integer>() { + @Override + public Integer apply(ValueWithRecordId value) { + return Arrays.hashCode(value.getId()) % NUM_RESHARD_KEYS; + } + })) + // Reshuffle will dedup based on ids in ValueWithRecordId by passing the data through + // WindmillSink. + .apply(Reshuffle.>of()) + .apply(ParDo.named("StripIds").of( + new DoFn>, T>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getValue().getValue()); + } + })); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.Create.Values Create.Values} for the + * Dataflow runner in streaming mode. 
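+ *
+ * <p>The elements are encoded with the collection's {@link Coder} at pipeline construction
+ * time and re-emitted by a {@link DoFn} that is triggered by a single synthetic
+ * "starting signal" element, so the element values themselves need not be {@code Serializable}.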
+ */ + private static class StreamingCreate extends PTransform> { + private final Create.Values transform; + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingCreate(DataflowPipelineRunner runner, Create.Values transform) { + this.transform = transform; + } + + /** + * {@link DoFn} that outputs a single KV.of(null, null) kick off the {@link GroupByKey} + * in the streaming create implementation. + */ + private static class OutputNullKv extends DoFn> { + @Override + public void processElement(DoFn>.ProcessContext c) throws Exception { + c.output(KV.of((Void) null, (Void) null)); + } + } + + /** + * A {@link DoFn} which outputs the specified elements by first encoding them to bytes using + * the specified {@link Coder} so that they are serialized as part of the {@link DoFn} but + * need not implement {@code Serializable}. + */ + private static class OutputElements extends DoFn { + private final Coder coder; + private final List encodedElements; + + public OutputElements(Iterable elems, Coder coder) { + this.coder = coder; + this.encodedElements = new ArrayList<>(); + for (T t : elems) { + try { + encodedElements.add(CoderUtils.encodeToByteArray(coder, t)); + } catch (CoderException e) { + throw new IllegalArgumentException("Unable to encode value " + t + + " with coder " + coder, e); + } + } + } + + @Override + public void processElement(ProcessContext c) throws IOException { + for (byte[] encodedElement : encodedElements) { + c.output(CoderUtils.decodeFromByteArray(coder, encodedElement)); + } + } + } + + @Override + public PCollection apply(PInput input) { + try { + Coder coder = transform.getDefaultOutputCoder(input); + return Pipeline.applyTransform( + input, PubsubIO.Read.named("StartingSignal").subscription("_starting_signal/")) + .apply(ParDo.of(new OutputNullKv())) + .apply("GlobalSingleton", Window.>into(new GlobalWindows()) + .triggering(AfterPane.elementCountAtLeast(1)) + .withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()) + .apply(GroupByKey.create()) + // Go back to the default windowing strategy, so that our setting allowed lateness + // doesn't count as the user having set it. + .setWindowingStrategyInternal(WindowingStrategy.globalDefault()) + .apply(Window.>>into(new GlobalWindows())) + .apply(ParDo.of(new OutputElements<>(transform.getElements(), coder))) + .setCoder(coder).setIsBoundedInternal(IsBounded.BOUNDED); + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException("Unable to infer a coder and no Coder was specified. " + + "Please set a coder by invoking Create.withCoder() explicitly.", e); + } + } + + @Override + protected String getKindString() { + return "StreamingCreate"; + } + } + + /** + * A specialized {@link DoFn} for writing the contents of a {@link PCollection} + * to a streaming {@link PCollectionView} backend implementation. 
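+ *
+ * <p>Each bundle of elements is wrapped in {@link WindowedValue}s and written to the view's
+ * tag via {@code windowingInternals().writePCollectionViewData}.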
+ */ + private static class StreamingPCollectionViewWriterFn + extends DoFn, T> implements DoFn.RequiresWindowAccess { + private final PCollectionView view; + private final Coder dataCoder; + + public static StreamingPCollectionViewWriterFn create( + PCollectionView view, Coder dataCoder) { + return new StreamingPCollectionViewWriterFn(view, dataCoder); + } + + private StreamingPCollectionViewWriterFn(PCollectionView view, Coder dataCoder) { + this.view = view; + this.dataCoder = dataCoder; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + List> output = new ArrayList<>(); + for (T elem : c.element()) { + output.add(WindowedValue.of(elem, c.timestamp(), c.window(), c.pane())); + } + + c.windowingInternals().writePCollectionViewData( + view.getTagInternal(), output, dataCoder); + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsMap View.AsMap} + * for the Dataflow runner in streaming mode. + */ + private static class StreamingViewAsMap + extends PTransform>, PCollectionView>> { + private final DataflowPipelineRunner runner; + + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingViewAsMap(DataflowPipelineRunner runner, View.AsMap transform) { + this.runner = runner; + } + + @Override + public PCollectionView> apply(PCollection> input) { + PCollectionView> view = + PCollectionViews.mapView( + input.getPipeline(), + input.getWindowingStrategy(), + input.getCoder()); + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + try { + inputCoder.getKeyCoder().verifyDeterministic(); + } catch (NonDeterministicException e) { + runner.recordViewUsesNonDeterministicKeyCoder(this); + } + + return input + .apply(Combine.globally(new Concatenate>()).withoutDefaults()) + .apply(ParDo.of(StreamingPCollectionViewWriterFn.create(view, input.getCoder()))) + .apply(View.CreatePCollectionView., Map>of(view)); + } + + @Override + protected String getKindString() { + return "StreamingViewAsMap"; + } + } + + /** + * Specialized expansion for {@link + * com.google.cloud.dataflow.sdk.transforms.View.AsMultimap View.AsMultimap} for the + * Dataflow runner in streaming mode. + */ + private static class StreamingViewAsMultimap + extends PTransform>, PCollectionView>>> { + private final DataflowPipelineRunner runner; + + /** + * Builds an instance of this class from the overridden transform. 
+ */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingViewAsMultimap(DataflowPipelineRunner runner, View.AsMultimap transform) { + this.runner = runner; + } + + @Override + public PCollectionView>> apply(PCollection> input) { + PCollectionView>> view = + PCollectionViews.multimapView( + input.getPipeline(), + input.getWindowingStrategy(), + input.getCoder()); + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder inputCoder = (KvCoder) input.getCoder(); + try { + inputCoder.getKeyCoder().verifyDeterministic(); + } catch (NonDeterministicException e) { + runner.recordViewUsesNonDeterministicKeyCoder(this); + } + + return input + .apply(Combine.globally(new Concatenate>()).withoutDefaults()) + .apply(ParDo.of(StreamingPCollectionViewWriterFn.create(view, input.getCoder()))) + .apply(View.CreatePCollectionView., Map>>of(view)); + } + + @Override + protected String getKindString() { + return "StreamingViewAsMultimap"; + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsList View.AsList} for the + * Dataflow runner in streaming mode. + */ + private static class StreamingViewAsList + extends PTransform, PCollectionView>> { + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingViewAsList(DataflowPipelineRunner runner, View.AsList transform) {} + + @Override + public PCollectionView> apply(PCollection input) { + PCollectionView> view = + PCollectionViews.listView( + input.getPipeline(), + input.getWindowingStrategy(), + input.getCoder()); + + return input.apply(Combine.globally(new Concatenate()).withoutDefaults()) + .apply(ParDo.of(StreamingPCollectionViewWriterFn.create(view, input.getCoder()))) + .apply(View.CreatePCollectionView.>of(view)); + } + + @Override + protected String getKindString() { + return "StreamingViewAsList"; + } + } + + /** + * Specialized implementation for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsIterable View.AsIterable} for the + * Dataflow runner in streaming mode. + */ + private static class StreamingViewAsIterable + extends PTransform, PCollectionView>> { + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingViewAsIterable(DataflowPipelineRunner runner, View.AsIterable transform) { } + + @Override + public PCollectionView> apply(PCollection input) { + PCollectionView> view = + PCollectionViews.iterableView( + input.getPipeline(), + input.getWindowingStrategy(), + input.getCoder()); + + return input.apply(Combine.globally(new Concatenate()).withoutDefaults()) + .apply(ParDo.of(StreamingPCollectionViewWriterFn.create(view, input.getCoder()))) + .apply(View.CreatePCollectionView.>of(view)); + } + + @Override + protected String getKindString() { + return "StreamingViewAsIterable"; + } + } + + private static class WrapAsList extends DoFn> { + @Override + public void processElement(ProcessContext c) { + c.output(Arrays.asList(c.element())); + } + } + + /** + * Specialized expansion for + * {@link com.google.cloud.dataflow.sdk.transforms.View.AsSingleton View.AsSingleton} for the + * Dataflow runner in streaming mode. 
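+ *
+ * <p>The view is built with a {@code SingletonCombine} that throws if more than one element
+ * reaches it, and that falls back to the transform's default value (when one was supplied)
+ * for an empty {@link PCollection}.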
+ */ + private static class StreamingViewAsSingleton + extends PTransform, PCollectionView> { + private View.AsSingleton transform; + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingViewAsSingleton(DataflowPipelineRunner runner, View.AsSingleton transform) { + this.transform = transform; + } + + @Override + public PCollectionView apply(PCollection input) { + Combine.Globally combine = Combine.globally( + new SingletonCombine<>(transform.hasDefaultValue(), transform.defaultValue())); + if (!transform.hasDefaultValue()) { + combine = combine.withoutDefaults(); + } + return input.apply(combine.asSingletonView()); + } + + @Override + protected String getKindString() { + return "StreamingViewAsSingleton"; + } + + private static class SingletonCombine extends Combine.BinaryCombineFn { + private boolean hasDefaultValue; + private T defaultValue; + + SingletonCombine(boolean hasDefaultValue, T defaultValue) { + this.hasDefaultValue = hasDefaultValue; + this.defaultValue = defaultValue; + } + + @Override + public T apply(T left, T right) { + throw new IllegalArgumentException("PCollection with more than one element " + + "accessed as a singleton view. Consider using Combine.globally().asSingleton() to " + + "combine the PCollection into a single value"); + } + + @Override + public T identity() { + if (hasDefaultValue) { + return defaultValue; + } else { + throw new IllegalArgumentException( + "Empty PCollection accessed as a singleton view. " + + "Consider setting withDefault to provide a default value"); + } + } + } + } + + private static class StreamingCombineGloballyAsSingletonView + extends PTransform, PCollectionView> { + Combine.GloballyAsSingletonView transform; + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public StreamingCombineGloballyAsSingletonView( + DataflowPipelineRunner runner, + Combine.GloballyAsSingletonView transform) { + this.transform = transform; + } + + @Override + public PCollectionView apply(PCollection input) { + PCollection combined = + input.apply(Combine.globally(transform.getCombineFn()) + .withoutDefaults() + .withFanout(transform.getFanout())); + + PCollectionView view = PCollectionViews.singletonView( + combined.getPipeline(), + combined.getWindowingStrategy(), + transform.getInsertDefault(), + transform.getInsertDefault() + ? transform.getCombineFn().defaultValue() : null, + combined.getCoder()); + return combined + .apply(ParDo.of(new WrapAsList())) + .apply(ParDo.of(StreamingPCollectionViewWriterFn.create(view, combined.getCoder()))) + .apply(View.CreatePCollectionView.of(view)); + } + + @Override + protected String getKindString() { + return "StreamingCombineGloballyAsSingletonView"; + } + } + + /** + * Combiner that combines {@code T}s into a single {@code List} containing all inputs. + * + *

    For internal use by {@link StreamingViewAsMap}, {@link StreamingViewAsMultimap}, + * {@link StreamingViewAsList}, {@link StreamingViewAsIterable}. + * They require the input {@link PCollection} fits in memory. + * For a large {@link PCollection} this is expected to crash! + * + * @param the type of elements to concatenate. + */ + private static class Concatenate extends CombineFn, List> { + @Override + public List createAccumulator() { + return new ArrayList(); + } + + @Override + public List addInput(List accumulator, T input) { + accumulator.add(input); + return accumulator; + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + List result = createAccumulator(); + for (List accumulator : accumulators) { + result.addAll(accumulator); + } + return result; + } + + @Override + public List extractOutput(List accumulator) { + return accumulator; + } + + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); + } + + @Override + public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); + } + } + + /** + * Specialized expansion for unsupported IO transforms that throws an error. + */ + private static class UnsupportedIO + extends PTransform { + private PTransform transform; + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, AvroIO.Read.Bound transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, BigQueryIO.Read.Bound transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, TextIO.Read.Bound transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, Read.Bounded transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, Read.Unbounded transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, AvroIO.Write.Bound transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, TextIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public OutputT apply(InputT input) { + String mode = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming() + ? 
"streaming" : "batch"; + throw new UnsupportedOperationException( + String.format("The DataflowPipelineRunner in %s mode does not support %s.", + mode, approximatePTransformName(transform.getClass()))); + } + } + + @Override + public String toString() { + return "DataflowPipelineRunner#" + options.getJobName(); + } + + /** + * Attempts to detect all the resources the class loader has access to. This does not recurse + * to class loader parents stopping it from pulling in resources from the system class loader. + * + * @param classLoader The URLClassLoader to use to detect resources to stage. + * @throws IllegalArgumentException If either the class loader is not a URLClassLoader or one + * of the resources the class loader exposes is not a file resource. + * @return A list of absolute paths to the resources the class loader uses. + */ + protected static List detectClassPathResourcesToStage(ClassLoader classLoader) { + if (!(classLoader instanceof URLClassLoader)) { + String message = String.format("Unable to use ClassLoader to detect classpath elements. " + + "Current ClassLoader is %s, only URLClassLoaders are supported.", classLoader); + LOG.error(message); + throw new IllegalArgumentException(message); + } + + List files = new ArrayList<>(); + for (URL url : ((URLClassLoader) classLoader).getURLs()) { + try { + files.add(new File(url.toURI()).getAbsolutePath()); + } catch (IllegalArgumentException | URISyntaxException e) { + String message = String.format("Unable to convert url (%s) to file.", url); + LOG.error(message); + throw new IllegalArgumentException(message, e); + } + } + return files; + } + + /** + * Finds the id for the running job of the given name. + */ + private String getJobIdFromName(String jobName) { + try { + ListJobsResponse listResult; + String token = null; + do { + listResult = dataflowClient.projects().jobs() + .list(options.getProject()) + .setPageToken(token) + .execute(); + token = listResult.getNextPageToken(); + for (Job job : listResult.getJobs()) { + if (job.getName().equals(jobName) + && MonitoringUtil.toState(job.getCurrentState()).equals(State.RUNNING)) { + return job.getId(); + } + } + } while (token != null); + } catch (GoogleJsonResponseException e) { + throw new RuntimeException( + "Got error while looking up jobs: " + + (e.getDetails() != null ? e.getDetails().getMessage() : e), e); + } catch (IOException e) { + throw new RuntimeException("Got error while looking up jobs: ", e); + } + + throw new IllegalArgumentException("Could not find running job named " + jobName); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerHooks.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerHooks.java new file mode 100644 index 000000000000..b9a02935dee3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerHooks.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.api.services.dataflow.model.Environment; +import com.google.cloud.dataflow.sdk.annotations.Experimental; + +/** + * An instance of this class can be passed to the + * {@link DataflowPipelineRunner} to add user defined hooks to be + * invoked at various times during pipeline execution. + */ +@Experimental +public class DataflowPipelineRunnerHooks { + /** + * Allows the user to modify the environment of their job before their job is submitted + * to the service for execution. + * + * @param environment The environment of the job. Users can make change to this instance in order + * to change the environment with which their job executes on the service. + */ + public void modifyEnvironmentBeforeSubmission(Environment environment) {} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java new file mode 100644 index 000000000000..ae3a40310372 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java @@ -0,0 +1,1086 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; +import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.addList; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addObject; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.api.services.dataflow.model.AutoscalingSettings; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Disk; +import com.google.api.services.dataflow.model.Environment; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.Step; +import com.google.api.services.dataflow.model.WorkerPool; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.StreamingOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.GroupByKeyAndSortValuesOnly; +import com.google.cloud.dataflow.sdk.runners.dataflow.BigQueryIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.PubsubIOTranslator; +import com.google.cloud.dataflow.sdk.runners.dataflow.ReadTranslator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.AppliedCombineFn; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.DoFnInfo; +import com.google.cloud.dataflow.sdk.util.OutputReference; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import 
com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * {@link DataflowPipelineTranslator} knows how to translate {@link Pipeline} objects + * into Cloud Dataflow Service API {@link Job}s. + */ +@SuppressWarnings({"rawtypes", "unchecked"}) +public class DataflowPipelineTranslator { + // Must be kept in sync with their internal counterparts. + private static final Logger LOG = LoggerFactory.getLogger(DataflowPipelineTranslator.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + + /** + * A map from {@link PTransform} subclass to the corresponding + * {@link TransformTranslator} to use to translate that transform. + * + *
+ * <p>
    A static map that contains system-wide defaults. + */ + private static Map transformTranslators = + new HashMap<>(); + + /** Provided configuration options. */ + private final DataflowPipelineOptions options; + + /** + * Constructs a translator from the provided options. + * + * @param options Properties that configure the translator. + * + * @return The newly created translator. + */ + public static DataflowPipelineTranslator fromOptions( + DataflowPipelineOptions options) { + return new DataflowPipelineTranslator(options); + } + + private DataflowPipelineTranslator(DataflowPipelineOptions options) { + this.options = options; + } + + /** + * Translates a {@link Pipeline} into a {@code JobSpecification}. + */ + public JobSpecification translate( + Pipeline pipeline, + DataflowPipelineRunner runner, + List packages) { + + Translator translator = new Translator(pipeline, runner); + Job result = translator.translate(packages); + return new JobSpecification(result, Collections.unmodifiableMap(translator.stepNames)); + } + + /** + * The result of a job translation. + * + *

    Used to pass the result {@link Job} and any state that was used to construct the job that + * may be of use to other classes (eg the {@link PTransform} to StepName mapping). + */ + public static class JobSpecification { + private final Job job; + private final Map, String> stepNames; + + public JobSpecification(Job job, Map, String> stepNames) { + this.job = job; + this.stepNames = stepNames; + } + + public Job getJob() { + return job; + } + + /** + * Returns the mapping of {@link AppliedPTransform AppliedPTransforms} to the internal step + * name for that {@code AppliedPTransform}. + */ + public Map, String> getStepNames() { + return stepNames; + } + } + + /** + * Renders a {@link Job} as a string. + */ + public static String jobToString(Job job) { + try { + return MAPPER.writerWithDefaultPrettyPrinter().writeValueAsString(job); + } catch (JsonProcessingException exc) { + throw new IllegalStateException("Failed to render Job as String.", exc); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Records that instances of the specified PTransform class + * should be translated by default by the corresponding + * {@link TransformTranslator}. + */ + public static void registerTransformTranslator( + Class transformClass, + TransformTranslator transformTranslator) { + if (transformTranslators.put(transformClass, transformTranslator) != null) { + throw new IllegalArgumentException( + "defining multiple translators for " + transformClass); + } + } + + /** + * Returns the {@link TransformTranslator} to use for instances of the + * specified PTransform class, or null if none registered. + */ + public + TransformTranslator getTransformTranslator(Class transformClass) { + return transformTranslators.get(transformClass); + } + + /** + * A {@link TransformTranslator} knows how to translate + * a particular subclass of {@link PTransform} for the + * Cloud Dataflow service. It does so by + * mutating the {@link TranslationContext}. + */ + public interface TransformTranslator { + public void translate(TransformT transform, + TranslationContext context); + } + + /** + * The interface provided to registered callbacks for interacting + * with the {@link DataflowPipelineRunner}, including reading and writing the + * values of {@link PCollection}s and side inputs ({@link PCollectionView}s). + */ + public interface TranslationContext { + /** + * Returns the configured pipeline options. + */ + DataflowPipelineOptions getPipelineOptions(); + + /** + * Returns the input of the currently being translated transform. + */ + InputT getInput(PTransform transform); + + /** + * Returns the output of the currently being translated transform. + */ + OutputT getOutput(PTransform transform); + + /** + * Returns the full name of the currently being translated transform. + */ + String getFullName(PTransform transform); + + /** + * Adds a step to the Dataflow workflow for the given transform, with + * the given Dataflow step type. + * This step becomes "current" for the purpose of {@link #addInput} and + * {@link #addOutput}. + */ + public void addStep(PTransform transform, String type); + + /** + * Adds a pre-defined step to the Dataflow workflow. The given PTransform should be + * consistent with the Step, in terms of input, output and coder types. + * + *

    This is a low-level operation, when using this method it is up to + * the caller to ensure that names do not collide. + */ + public void addStep(PTransform transform, Step step); + + /** + * Sets the encoding for the current Dataflow step. + */ + public void addEncodingInput(Coder value); + + /** + * Adds an input with the given name and value to the current + * Dataflow step. + */ + public void addInput(String name, Boolean value); + + /** + * Adds an input with the given name and value to the current + * Dataflow step. + */ + public void addInput(String name, String value); + + /** + * Adds an input with the given name and value to the current + * Dataflow step. + */ + public void addInput(String name, Long value); + + /** + * Adds an input with the given name to the previously added Dataflow + * step, coming from the specified input PValue. + */ + public void addInput(String name, PInput value); + + /** + * Adds an input that is a dictionary of strings to objects. + */ + public void addInput(String name, Map elements); + + /** + * Adds an input that is a list of objects. + */ + public void addInput(String name, List> elements); + + /** + * Adds an output with the given name to the previously added + * Dataflow step, producing the specified output {@code PValue}, + * including its {@code Coder} if a {@code TypedPValue}. If the + * {@code PValue} is a {@code PCollection}, wraps its coder inside + * a {@code WindowedValueCoder}. + */ + public void addOutput(String name, PValue value); + + /** + * Adds an output with the given name to the previously added + * Dataflow step, producing the specified output {@code PValue}, + * including its {@code Coder} if a {@code TypedPValue}. If the + * {@code PValue} is a {@code PCollection}, wraps its coder inside + * a {@code ValueOnlyCoder}. + */ + public void addValueOnlyOutput(String name, PValue value); + + /** + * Adds an output with the given name to the previously added + * CollectionToSingleton Dataflow step, consuming the specified + * input {@code PValue} and producing the specified output + * {@code PValue}. This step requires special treatment for its + * output encoding. + */ + public void addCollectionToSingletonOutput(String name, + PValue inputValue, + PValue outputValue); + + /** + * Encode a PValue reference as an output reference. + */ + public OutputReference asOutputReference(PValue value); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Translates a Pipeline into the Dataflow representation. + */ + class Translator implements PipelineVisitor, TranslationContext { + /** The Pipeline to translate. */ + private final Pipeline pipeline; + + /** The runner which will execute the pipeline. */ + private final DataflowPipelineRunner runner; + + /** The Cloud Dataflow Job representation. */ + private final Job job = new Job(); + + /** + * Translator is stateful, as addProperty calls refer to the current step. + */ + private Step currentStep; + + /** + * A Map from AppliedPTransform to their unique Dataflow step names. + */ + private final Map, String> stepNames = new HashMap<>(); + + /** + * A Map from PValues to their output names used by their producer + * Dataflow steps. + */ + private final Map outputNames = new HashMap<>(); + + /** + * A Map from PValues to the Coders used for them. + */ + private final Map> outputCoders = new HashMap<>(); + + /** + * The transform currently being applied. 
+ */ + private AppliedPTransform currentTransform; + + /** + * Constructs a Translator that will translate the specified + * Pipeline into Dataflow objects. + */ + public Translator(Pipeline pipeline, DataflowPipelineRunner runner) { + this.pipeline = pipeline; + this.runner = runner; + } + + /** + * Translates this Translator's pipeline onto its writer. + * @return a Job definition filled in with the type of job, the environment, + * and the job steps. + */ + public Job translate(List packages) { + job.setName(options.getJobName().toLowerCase()); + + Environment environment = new Environment(); + job.setEnvironment(environment); + + try { + environment.setSdkPipelineOptions( + MAPPER.readValue(MAPPER.writeValueAsBytes(options), Map.class)); + } catch (IOException e) { + throw new IllegalArgumentException( + "PipelineOptions specified failed to serialize to JSON.", e); + } + + WorkerPool workerPool = new WorkerPool(); + + if (options.getTeardownPolicy() != null) { + workerPool.setTeardownPolicy(options.getTeardownPolicy().getTeardownPolicyName()); + } + + if (options.isStreaming()) { + job.setType("JOB_TYPE_STREAMING"); + } else { + job.setType("JOB_TYPE_BATCH"); + workerPool.setDiskType(options.getWorkerDiskType()); + } + + if (options.getWorkerMachineType() != null) { + workerPool.setMachineType(options.getWorkerMachineType()); + } + + workerPool.setPackages(packages); + workerPool.setNumWorkers(options.getNumWorkers()); + + if (options.isStreaming()) { + // Use separate data disk for streaming. + Disk disk = new Disk(); + disk.setDiskType(options.getWorkerDiskType()); + workerPool.setDataDisks(Collections.singletonList(disk)); + } + if (!Strings.isNullOrEmpty(options.getZone())) { + workerPool.setZone(options.getZone()); + } + if (!Strings.isNullOrEmpty(options.getNetwork())) { + workerPool.setNetwork(options.getNetwork()); + } + if (options.getDiskSizeGb() > 0) { + workerPool.setDiskSizeGb(options.getDiskSizeGb()); + } + AutoscalingSettings settings = new AutoscalingSettings(); + if (options.getAutoscalingAlgorithm() != null) { + settings.setAlgorithm(options.getAutoscalingAlgorithm().getAlgorithm()); + } + settings.setMaxNumWorkers(options.getMaxNumWorkers()); + workerPool.setAutoscalingSettings(settings); + + List workerPools = new LinkedList<>(); + + workerPools.add(workerPool); + environment.setWorkerPools(workerPools); + + pipeline.traverseTopologically(this); + return job; + } + + @Override + public DataflowPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public InputT getInput(PTransform transform) { + return (InputT) getCurrentTransform(transform).getInput(); + } + + @Override + public OutputT getOutput(PTransform transform) { + return (OutputT) getCurrentTransform(transform).getOutput(); + } + + @Override + public String getFullName(PTransform transform) { + return getCurrentTransform(transform).getFullName(); + } + + private AppliedPTransform getCurrentTransform(PTransform transform) { + checkArgument( + currentTransform != null && currentTransform.getTransform() == transform, + "can only be called with current transform"); + return currentTransform; + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + TransformTranslator translator = + getTransformTranslator(transform.getClass()); + if (translator == null) { + throw new 
IllegalStateException( + "no translator registered for " + transform); + } + LOG.debug("Translating {}", transform); + currentTransform = AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform); + translator.translate(transform, this); + currentTransform = null; + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + LOG.debug("Checking translation of {}", value); + if (value.getProducingTransformInternal() == null) { + throw new RuntimeException( + "internal error: expecting a PValue " + + "to have a producingTransform"); + } + if (!producer.isCompositeNode()) { + // Primitive transforms are the only ones assigned step names. + asOutputReference(value); + } + } + + @Override + public void addStep(PTransform transform, String type) { + String stepName = genStepName(); + if (stepNames.put(getCurrentTransform(transform), stepName) != null) { + throw new IllegalArgumentException( + transform + " already has a name specified"); + } + // Start the next "steps" list item. + List steps = job.getSteps(); + if (steps == null) { + steps = new LinkedList<>(); + job.setSteps(steps); + } + + currentStep = new Step(); + currentStep.setName(stepName); + currentStep.setKind(type); + steps.add(currentStep); + addInput(PropertyNames.USER_NAME, getFullName(transform)); + } + + @Override + public void addStep(PTransform transform, Step original) { + Step step = original.clone(); + String stepName = step.getName(); + if (stepNames.put(getCurrentTransform(transform), stepName) != null) { + throw new IllegalArgumentException(transform + " already has a name specified"); + } + + Map properties = step.getProperties(); + if (properties != null) { + @Nullable List> outputInfoList = null; + try { + // TODO: This should be done via a Structs accessor. 
+ @Nullable List> list = + (List>) properties.get(PropertyNames.OUTPUT_INFO); + outputInfoList = list; + } catch (Exception e) { + throw new RuntimeException("Inconsistent dataflow pipeline translation", e); + } + if (outputInfoList != null && outputInfoList.size() > 0) { + Map firstOutputPort = outputInfoList.get(0); + @Nullable String name; + try { + name = getString(firstOutputPort, PropertyNames.OUTPUT_NAME); + } catch (Exception e) { + name = null; + } + if (name != null) { + registerOutputName(getOutput(transform), name); + } + } + } + + List steps = job.getSteps(); + if (steps == null) { + steps = new LinkedList<>(); + job.setSteps(steps); + } + currentStep = step; + steps.add(step); + } + + @Override + public void addEncodingInput(Coder coder) { + CloudObject encoding = SerializableUtils.ensureSerializable(coder); + addObject(getProperties(), PropertyNames.ENCODING, encoding); + } + + @Override + public void addInput(String name, Boolean value) { + addBoolean(getProperties(), name, value); + } + + @Override + public void addInput(String name, String value) { + addString(getProperties(), name, value); + } + + @Override + public void addInput(String name, Long value) { + addLong(getProperties(), name, value); + } + + @Override + public void addInput(String name, Map elements) { + addDictionary(getProperties(), name, elements); + } + + @Override + public void addInput(String name, List> elements) { + addList(getProperties(), name, elements); + } + + @Override + public void addInput(String name, PInput value) { + if (value instanceof PValue) { + addInput(name, asOutputReference((PValue) value)); + } else { + throw new IllegalStateException("Input must be a PValue"); + } + } + + @Override + public void addOutput(String name, PValue value) { + Coder coder; + if (value instanceof TypedPValue) { + coder = ((TypedPValue) value).getCoder(); + if (value instanceof PCollection) { + // Wrap the PCollection element Coder inside a WindowedValueCoder. + coder = WindowedValue.getFullCoder( + coder, + ((PCollection) value).getWindowingStrategy().getWindowFn().windowCoder()); + } + } else { + // No output coder to encode. + coder = null; + } + addOutput(name, value, coder); + } + + @Override + public void addValueOnlyOutput(String name, PValue value) { + Coder coder; + if (value instanceof TypedPValue) { + coder = ((TypedPValue) value).getCoder(); + if (value instanceof PCollection) { + // Wrap the PCollection element Coder inside a ValueOnly + // WindowedValueCoder. + coder = WindowedValue.getValueOnlyCoder(coder); + } + } else { + // No output coder to encode. + coder = null; + } + addOutput(name, value, coder); + } + + @Override + public void addCollectionToSingletonOutput(String name, + PValue inputValue, + PValue outputValue) { + Coder inputValueCoder = + Preconditions.checkNotNull(outputCoders.get(inputValue)); + // The inputValueCoder for the input PCollection should be some + // WindowedValueCoder of the input PCollection's element + // coder. + Preconditions.checkState( + inputValueCoder instanceof WindowedValue.WindowedValueCoder); + // The outputValueCoder for the output should be an + // IterableCoder of the inputValueCoder. This is a property + // of the backend "CollectionToSingleton" step. 
+ Coder outputValueCoder = IterableCoder.of(inputValueCoder); + addOutput(name, outputValue, outputValueCoder); + } + + /** + * Adds an output with the given name to the previously added + * Dataflow step, producing the specified output {@code PValue} + * with the given {@code Coder} (if not {@code null}). + */ + private void addOutput(String name, PValue value, Coder valueCoder) { + registerOutputName(value, name); + + Map properties = getProperties(); + @Nullable List> outputInfoList = null; + try { + // TODO: This should be done via a Structs accessor. + outputInfoList = (List>) properties.get(PropertyNames.OUTPUT_INFO); + } catch (Exception e) { + throw new RuntimeException("Inconsistent dataflow pipeline translation", e); + } + if (outputInfoList == null) { + outputInfoList = new ArrayList<>(); + // TODO: This should be done via a Structs accessor. + properties.put(PropertyNames.OUTPUT_INFO, outputInfoList); + } + + Map outputInfo = new HashMap<>(); + addString(outputInfo, PropertyNames.OUTPUT_NAME, name); + addString(outputInfo, PropertyNames.USER_NAME, value.getName()); + if (value instanceof PCollection + && runner.doesPCollectionRequireIndexedFormat((PCollection) value)) { + addBoolean(outputInfo, PropertyNames.USE_INDEXED_FORMAT, true); + } + if (valueCoder != null) { + // Verify that encoding can be decoded, in order to catch serialization + // failures as early as possible. + CloudObject encoding = SerializableUtils.ensureSerializable(valueCoder); + addObject(outputInfo, PropertyNames.ENCODING, encoding); + outputCoders.put(value, valueCoder); + } + + outputInfoList.add(outputInfo); + } + + @Override + public OutputReference asOutputReference(PValue value) { + AppliedPTransform transform = + value.getProducingTransformInternal(); + String stepName = stepNames.get(transform); + if (stepName == null) { + throw new IllegalArgumentException(transform + " doesn't have a name specified"); + } + + String outputName = outputNames.get(value); + if (outputName == null) { + throw new IllegalArgumentException( + "output " + value + " doesn't have a name specified"); + } + + return new OutputReference(stepName, outputName); + } + + private Map getProperties() { + Map properties = currentStep.getProperties(); + if (properties == null) { + properties = new HashMap<>(); + currentStep.setProperties(properties); + } + return properties; + } + + /** + * Returns a fresh Dataflow step name. + */ + private String genStepName() { + return "s" + (stepNames.size() + 1); + } + + /** + * Records the name of the given output PValue, + * within its producing transform. 
+ */ + private void registerOutputName(POutput value, String name) { + if (outputNames.put(value, name) != null) { + throw new IllegalArgumentException( + "output " + value + " already has a name specified"); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public String toString() { + return "DataflowPipelineTranslator#" + hashCode(); + } + + + /////////////////////////////////////////////////////////////////////////// + + static { + registerTransformTranslator( + View.CreatePCollectionView.class, + new TransformTranslator() { + @Override + public void translate( + View.CreatePCollectionView transform, + TranslationContext context) { + translateTyped(transform, context); + } + + private void translateTyped( + View.CreatePCollectionView transform, + TranslationContext context) { + context.addStep(transform, "CollectionToSingleton"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + context.addCollectionToSingletonOutput( + PropertyNames.OUTPUT, + context.getInput(transform), + context.getOutput(transform)); + } + }); + + DataflowPipelineTranslator.registerTransformTranslator( + Combine.GroupedValues.class, + new DataflowPipelineTranslator.TransformTranslator() { + @Override + public void translate( + Combine.GroupedValues transform, + DataflowPipelineTranslator.TranslationContext context) { + translateHelper(transform, context); + } + + private void translateHelper( + final Combine.GroupedValues transform, + DataflowPipelineTranslator.TranslationContext context) { + context.addStep(transform, "CombineValues"); + translateInputs(context.getInput(transform), transform.getSideInputs(), context); + + AppliedCombineFn fn = + transform.getAppliedFn( + context.getInput(transform).getPipeline().getCoderRegistry(), + context.getInput(transform).getCoder(), + context.getInput(transform).getWindowingStrategy()); + + context.addEncodingInput(fn.getAccumulatorCoder()); + context.addInput( + PropertyNames.SERIALIZED_FN, + byteArrayToJsonString(serializeToByteArray(fn))); + context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + }); + + registerTransformTranslator( + Create.Values.class, + new TransformTranslator() { + @Override + public void translate( + Create.Values transform, + TranslationContext context) { + createHelper(transform, context); + } + + private void createHelper( + Create.Values transform, + TranslationContext context) { + context.addStep(transform, "CreateCollection"); + + Coder coder = context.getOutput(transform).getCoder(); + List elements = new LinkedList<>(); + for (T elem : transform.getElements()) { + byte[] encodedBytes; + try { + encodedBytes = encodeToByteArray(coder, elem); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. 
+ throw new IllegalArgumentException( + "Unable to encode element '" + elem + "' of transform '" + transform + + "' using coder '" + coder + "'.", + exn); + } + String encodedJson = byteArrayToJsonString(encodedBytes); + assert Arrays.equals(encodedBytes, + jsonStringToByteArray(encodedJson)); + elements.add(CloudObject.forString(encodedJson)); + } + context.addInput(PropertyNames.ELEMENT, elements); + context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + }); + + registerTransformTranslator( + Flatten.FlattenPCollectionList.class, + new TransformTranslator() { + @Override + public void translate( + Flatten.FlattenPCollectionList transform, + TranslationContext context) { + flattenHelper(transform, context); + } + + private void flattenHelper( + Flatten.FlattenPCollectionList transform, + TranslationContext context) { + context.addStep(transform, "Flatten"); + + List inputs = new LinkedList<>(); + for (PCollection input : context.getInput(transform).getAll()) { + inputs.add(context.asOutputReference(input)); + } + context.addInput(PropertyNames.INPUTS, inputs); + context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + }); + + registerTransformTranslator( + GroupByKeyAndSortValuesOnly.class, + new TransformTranslator() { + @Override + public void translate( + GroupByKeyAndSortValuesOnly transform, + TranslationContext context) { + groupByKeyAndSortValuesHelper(transform, context); + } + + private void groupByKeyAndSortValuesHelper( + GroupByKeyAndSortValuesOnly transform, + TranslationContext context) { + context.addStep(transform, "GroupByKey"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + context.addInput(PropertyNames.SORT_VALUES, true); + + // TODO: Add support for combiner lifting once the need arises. + context.addInput( + PropertyNames.DISALLOW_COMBINER_LIFTING, true); + } + }); + + registerTransformTranslator( + GroupByKey.class, + new TransformTranslator() { + @Override + public void translate( + GroupByKey transform, + TranslationContext context) { + groupByKeyHelper(transform, context); + } + + private void groupByKeyHelper( + GroupByKey transform, + TranslationContext context) { + context.addStep(transform, "GroupByKey"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + + WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + boolean isStreaming = + context.getPipelineOptions().as(StreamingOptions.class).isStreaming(); + boolean disallowCombinerLifting = + !windowingStrategy.getWindowFn().isNonMerging() + || (isStreaming && !transform.fewKeys()) + // TODO: Allow combiner lifting on the non-default trigger, as appropriate. 
+ || !(windowingStrategy.getTrigger().getSpec() instanceof DefaultTrigger); + context.addInput( + PropertyNames.DISALLOW_COMBINER_LIFTING, disallowCombinerLifting); + context.addInput( + PropertyNames.SERIALIZED_FN, + byteArrayToJsonString(serializeToByteArray(windowingStrategy))); + } + }); + + registerTransformTranslator( + ParDo.BoundMulti.class, + new TransformTranslator() { + @Override + public void translate( + ParDo.BoundMulti transform, + TranslationContext context) { + translateMultiHelper(transform, context); + } + + private void translateMultiHelper( + ParDo.BoundMulti transform, + TranslationContext context) { + context.addStep(transform, "ParallelDo"); + translateInputs(context.getInput(transform), transform.getSideInputs(), context); + translateFn(transform.getFn(), context.getInput(transform).getWindowingStrategy(), + transform.getSideInputs(), context.getInput(transform).getCoder(), context); + translateOutputs(context.getOutput(transform), context); + } + }); + + registerTransformTranslator( + ParDo.Bound.class, + new TransformTranslator() { + @Override + public void translate( + ParDo.Bound transform, + TranslationContext context) { + translateSingleHelper(transform, context); + } + + private void translateSingleHelper( + ParDo.Bound transform, + TranslationContext context) { + context.addStep(transform, "ParallelDo"); + translateInputs(context.getInput(transform), transform.getSideInputs(), context); + translateFn( + transform.getFn(), + context.getInput(transform).getWindowingStrategy(), + transform.getSideInputs(), context.getInput(transform).getCoder(), context); + context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + }); + + registerTransformTranslator( + Window.Bound.class, + new DataflowPipelineTranslator.TransformTranslator() { + @Override + public void translate( + Window.Bound transform, TranslationContext context) { + translateHelper(transform, context); + } + + private void translateHelper( + Window.Bound transform, TranslationContext context) { + context.addStep(transform, "Bucket"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + + WindowingStrategy strategy = context.getOutput(transform).getWindowingStrategy(); + byte[] serializedBytes = serializeToByteArray(strategy); + String serializedJson = byteArrayToJsonString(serializedBytes); + assert Arrays.equals(serializedBytes, + jsonStringToByteArray(serializedJson)); + context.addInput(PropertyNames.SERIALIZED_FN, serializedJson); + } + }); + + /////////////////////////////////////////////////////////////////////////// + // IO Translation. 
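// The IO registrations below reuse the same registerTransformTranslator(...) hook
// as the core transforms registered above. As a minimal sketch of that pattern,
// a translator for a hypothetical primitive (the "FooIO.Read" transform and the
// "FooIORead" step kind are made-up names used only for illustration) would
// emit a step and register its output in the same way:
//
//   registerTransformTranslator(
//       FooIO.Read.class,
//       new TransformTranslator<FooIO.Read>() {
//         @Override
//         public void translate(FooIO.Read transform, TranslationContext context) {
//           // Create the Dataflow step for this primitive and wire up its output,
//           // mirroring the translators registered earlier in this block.
//           context.addStep(transform, "FooIORead");
//           context.addOutput(PropertyNames.OUTPUT, context.getOutput(transform));
//         }
//       });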
+ + registerTransformTranslator( + BigQueryIO.Read.Bound.class, new BigQueryIOTranslator.ReadTranslator()); + registerTransformTranslator( + BigQueryIO.Write.Bound.class, new BigQueryIOTranslator.WriteTranslator()); + + registerTransformTranslator( + PubsubIO.Read.Bound.class, new PubsubIOTranslator.ReadTranslator()); + registerTransformTranslator( + DataflowPipelineRunner.StreamingPubsubIOWrite.class, + new PubsubIOTranslator.WriteTranslator()); + + registerTransformTranslator(Read.Bounded.class, new ReadTranslator()); + } + + private static void translateInputs( + PCollection input, + List> sideInputs, + TranslationContext context) { + context.addInput(PropertyNames.PARALLEL_INPUT, input); + translateSideInputs(sideInputs, context); + } + + // Used for ParDo + private static void translateSideInputs( + List> sideInputs, + TranslationContext context) { + Map nonParInputs = new HashMap<>(); + + for (PCollectionView view : sideInputs) { + nonParInputs.put( + view.getTagInternal().getId(), + context.asOutputReference(view)); + } + + context.addInput(PropertyNames.NON_PARALLEL_INPUTS, nonParInputs); + } + + private static void translateFn( + DoFn fn, + WindowingStrategy windowingStrategy, + Iterable> sideInputs, + Coder inputCoder, + TranslationContext context) { + context.addInput(PropertyNames.USER_FN, fn.getClass().getName()); + context.addInput( + PropertyNames.SERIALIZED_FN, + byteArrayToJsonString(serializeToByteArray( + new DoFnInfo(fn, windowingStrategy, sideInputs, inputCoder)))); + } + + private static void translateOutputs( + PCollectionTuple outputs, + TranslationContext context) { + for (Map.Entry, PCollection> entry + : outputs.getAll().entrySet()) { + TupleTag tag = entry.getKey(); + PCollection output = entry.getValue(); + context.addOutput(tag.getId(), output); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowServiceException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowServiceException.java new file mode 100644 index 000000000000..6e8301b13af7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowServiceException.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import javax.annotation.Nullable; + +/** + * Signals there was an error retrieving information about a job from the Cloud Dataflow Service. 
+ */ +public class DataflowServiceException extends DataflowJobException { + DataflowServiceException(DataflowPipelineJob job, String message) { + this(job, message, null); + } + + DataflowServiceException(DataflowPipelineJob job, String message, @Nullable Throwable cause) { + super(job, message, cause); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipeline.java new file mode 100644 index 000000000000..5217a908138e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipeline.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; + +/** + * A {@link DirectPipeline} is a {@link Pipeline} that returns + * {@link DirectPipelineRunner.EvaluationResults} when it is + * {@link com.google.cloud.dataflow.sdk.Pipeline#run()}. + */ +public class DirectPipeline extends Pipeline { + + /** + * Creates and returns a new DirectPipeline instance for tests. + */ + public static DirectPipeline createForTest() { + DirectPipelineRunner runner = DirectPipelineRunner.createForTest(); + return new DirectPipeline(runner, runner.getPipelineOptions()); + } + + private DirectPipeline(DirectPipelineRunner runner, DirectPipelineOptions options) { + super(runner, options); + } + + @Override + public DirectPipelineRunner.EvaluationResults run() { + return (DirectPipelineRunner.EvaluationResults) super.run(); + } + + @Override + public DirectPipelineRunner getRunner() { + return (DirectPipelineRunner) super.getRunner(); + } + + @Override + public String toString() { + return "DirectPipeline#" + hashCode(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRegistrar.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRegistrar.java new file mode 100644 index 000000000000..f2dd40cbdb70 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRegistrar.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.auto.service.AutoService; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar; +import com.google.common.collect.ImmutableList; + +/** + * Contains the {@link PipelineOptionsRegistrar} and {@link PipelineRunnerRegistrar} for + * the {@link DirectPipeline}. + */ +public class DirectPipelineRegistrar { + private DirectPipelineRegistrar() { } + + /** + * Register the {@link DirectPipelineRunner}. + */ + @AutoService(PipelineRunnerRegistrar.class) + public static class Runner implements PipelineRunnerRegistrar { + @Override + public Iterable>> getPipelineRunners() { + return ImmutableList.>>of(DirectPipelineRunner.class); + } + } + + /** + * Register the {@link DirectPipelineOptions}. + */ + @AutoService(PipelineOptionsRegistrar.class) + public static class Options implements PipelineOptionsRegistrar { + @Override + public Iterable> getPipelineOptions() { + return ImmutableList.>of(DirectPipelineOptions.class); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java new file mode 100644 index 000000000000..872cfef7fb9a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java @@ -0,0 +1,1156 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.FileBasedSink; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions.CheckEnabled; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Partition; +import com.google.cloud.dataflow.sdk.transforms.Partition.PartitionFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.AppliedCombineFn; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.MapAggregatorValues; +import com.google.cloud.dataflow.sdk.util.PerKeyCombineFnRunner; +import com.google.cloud.dataflow.sdk.util.PerKeyCombineFnRunners; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import com.google.common.base.Function; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * Executes the operations in the pipeline directly, in this process, without + * any optimization. Useful for small local execution and tests. + * + *

+ * <p>Throws an exception from {@link #run} if execution fails.
+ *
+ * <h3>Permissions</h3>
+ * When reading from a Dataflow source or writing to a Dataflow sink using
+ * {@code DirectPipelineRunner}, the Cloud Platform account that you configured with the
+ * gcloud executable will need access to the
+ * corresponding source/sink.
+ *
+ * <p>
    Please see Google Cloud + * Dataflow Security and Permissions for more details. + */ +@SuppressWarnings({"rawtypes", "unchecked"}) +public class DirectPipelineRunner + extends PipelineRunner { + private static final Logger LOG = LoggerFactory.getLogger(DirectPipelineRunner.class); + + /** + * A source of random data, which can be seeded if determinism is desired. + */ + private Random rand; + + /** + * A map from PTransform class to the corresponding + * TransformEvaluator to use to evaluate that transform. + * + *

    A static map that contains system-wide defaults. + */ + private static Map defaultTransformEvaluators = + new HashMap<>(); + + /** + * A map from PTransform class to the corresponding + * TransformEvaluator to use to evaluate that transform. + * + *

    An instance map that contains bindings for this DirectPipelineRunner. + * Bindings in this map override those in the default map. + */ + private Map localTransformEvaluators = + new HashMap<>(); + + /** + * Records that instances of the specified PTransform class + * should be evaluated by default by the corresponding + * TransformEvaluator. + */ + public static > + void registerDefaultTransformEvaluator( + Class transformClass, + TransformEvaluator transformEvaluator) { + if (defaultTransformEvaluators.put(transformClass, transformEvaluator) + != null) { + throw new IllegalArgumentException( + "defining multiple evaluators for " + transformClass); + } + } + + /** + * Records that instances of the specified PTransform class + * should be evaluated by the corresponding TransformEvaluator. + * Overrides any bindings specified by + * {@link #registerDefaultTransformEvaluator}. + */ + public > + void registerTransformEvaluator( + Class transformClass, + TransformEvaluator transformEvaluator) { + if (localTransformEvaluators.put(transformClass, transformEvaluator) + != null) { + throw new IllegalArgumentException( + "defining multiple evaluators for " + transformClass); + } + } + + /** + * Returns the TransformEvaluator to use for instances of the + * specified PTransform class, or null if none registered. + */ + public > + TransformEvaluator getTransformEvaluator(Class transformClass) { + TransformEvaluator transformEvaluator = + localTransformEvaluators.get(transformClass); + if (transformEvaluator == null) { + transformEvaluator = defaultTransformEvaluators.get(transformClass); + } + return transformEvaluator; + } + + /** + * Constructs a DirectPipelineRunner from the given options. + */ + public static DirectPipelineRunner fromOptions(PipelineOptions options) { + DirectPipelineOptions directOptions = + PipelineOptionsValidator.validate(DirectPipelineOptions.class, options); + LOG.debug("Creating DirectPipelineRunner"); + return new DirectPipelineRunner(directOptions); + } + + /** + * Constructs a runner with default properties for testing. + * + * @return The newly created runner. + */ + public static DirectPipelineRunner createForTest() { + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setStableUniqueNames(CheckEnabled.ERROR); + options.setGcpCredential(new TestCredential()); + return new DirectPipelineRunner(options); + } + + /** + * Enable runtime testing to verify that all functions and {@link Coder} + * instances can be serialized. + * + *

    Enabled by default. + * + *

    This method modifies the {@code DirectPipelineRunner} instance and + * returns itself. + */ + public DirectPipelineRunner withSerializabilityTesting(boolean enable) { + this.testSerializability = enable; + return this; + } + + /** + * Enable runtime testing to verify that all values can be encoded. + * + *

    Enabled by default. + * + *

    This method modifies the {@code DirectPipelineRunner} instance and + * returns itself. + */ + public DirectPipelineRunner withEncodabilityTesting(boolean enable) { + this.testEncodability = enable; + return this; + } + + /** + * Enable runtime testing to verify that functions do not depend on order + * of the elements. + * + *

    This is accomplished by randomizing the order of elements. + * + *

    Enabled by default. + * + *

    This method modifies the {@code DirectPipelineRunner} instance and + * returns itself. + */ + public DirectPipelineRunner withUnorderednessTesting(boolean enable) { + this.testUnorderedness = enable; + return this; + } + + @Override + public OutputT apply( + PTransform transform, InputT input) { + if (transform instanceof Combine.GroupedValues) { + return (OutputT) applyTestCombine((Combine.GroupedValues) transform, (PCollection) input); + } else if (transform instanceof TextIO.Write.Bound) { + return (OutputT) applyTextIOWrite((TextIO.Write.Bound) transform, (PCollection) input); + } else if (transform instanceof AvroIO.Write.Bound) { + return (OutputT) applyAvroIOWrite((AvroIO.Write.Bound) transform, (PCollection) input); + } else { + return super.apply(transform, input); + } + } + + private PCollection> applyTestCombine( + Combine.GroupedValues transform, + PCollection>> input) { + + PCollection> output = input + .apply(ParDo.of(TestCombineDoFn.create(transform, input, testSerializability, rand)) + .withSideInputs(transform.getSideInputs())); + + try { + output.setCoder(transform.getDefaultOutputCoder(input)); + } catch (CannotProvideCoderException exc) { + // let coder inference occur later, if it can + } + return output; + } + + private static class ElementProcessingOrderPartitionFn implements PartitionFn { + private int elementNumber; + @Override + public int partitionFor(T elem, int numPartitions) { + return elementNumber++ % numPartitions; + } + } + + /** + * Applies TextIO.Write honoring user requested sharding controls (i.e. withNumShards) + * by applying a partition function based upon the number of shards the user requested. + */ + private static class DirectTextIOWrite extends PTransform, PDone> { + private final TextIO.Write.Bound transform; + + private DirectTextIOWrite(TextIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + checkState(transform.getNumShards() > 1, + "DirectTextIOWrite is expected to only be used when sharding controls are required."); + + // Evenly distribute all the elements across the partitions. + PCollectionList partitionedElements = + input.apply(Partition.of(transform.getNumShards(), + new ElementProcessingOrderPartitionFn())); + + // For each input PCollection partition, create a write transform that represents + // one of the specific shards. + for (int i = 0; i < transform.getNumShards(); ++i) { + /* + * This logic mirrors the file naming strategy within + * {@link FileBasedSink#generateDestinationFilenames()} + */ + String outputFilename = IOChannelUtils.constructName( + transform.getFilenamePrefix(), + transform.getShardNameTemplate(), + getFileExtension(transform.getFilenameSuffix()), + i, + transform.getNumShards()); + + String transformName = String.format("%s(Shard:%s)", transform.getName(), i); + partitionedElements.get(i).apply(transformName, + transform.withNumShards(1).withShardNameTemplate("").withSuffix("").to(outputFilename)); + } + return PDone.in(input.getPipeline()); + } + } + + /** + * Returns the file extension to be used. If the user did not request a file + * extension then this method returns the empty string. Otherwise this method + * adds a {@code "."} to the beginning of the users extension if one is not present. + * + *

    This is copied from {@link FileBasedSink} to not expose it. + */ + private static String getFileExtension(String usersExtension) { + if (usersExtension == null || usersExtension.isEmpty()) { + return ""; + } + if (usersExtension.startsWith(".")) { + return usersExtension; + } + return "." + usersExtension; + } + + /** + * Apply the override for TextIO.Write.Bound if the user requested sharding controls + * greater than one. + */ + private PDone applyTextIOWrite(TextIO.Write.Bound transform, PCollection input) { + if (transform.getNumShards() <= 1) { + // By default, the DirectPipelineRunner outputs to only 1 shard. Since the user never + // requested sharding controls greater than 1, we default to outputting to 1 file. + return super.apply(transform.withNumShards(1), input); + } + return input.apply(new DirectTextIOWrite<>(transform)); + } + + /** + * Applies AvroIO.Write honoring user requested sharding controls (i.e. withNumShards) + * by applying a partition function based upon the number of shards the user requested. + */ + private static class DirectAvroIOWrite extends PTransform, PDone> { + private final AvroIO.Write.Bound transform; + + private DirectAvroIOWrite(AvroIO.Write.Bound transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection input) { + checkState(transform.getNumShards() > 1, + "DirectAvroIOWrite is expected to only be used when sharding controls are required."); + + // Evenly distribute all the elements across the partitions. + PCollectionList partitionedElements = + input.apply(Partition.of(transform.getNumShards(), + new ElementProcessingOrderPartitionFn())); + + // For each input PCollection partition, create a write transform that represents + // one of the specific shards. + for (int i = 0; i < transform.getNumShards(); ++i) { + /* + * This logic mirrors the file naming strategy within + * {@link FileBasedSink#generateDestinationFilenames()} + */ + String outputFilename = IOChannelUtils.constructName( + transform.getFilenamePrefix(), + transform.getShardNameTemplate(), + getFileExtension(transform.getFilenameSuffix()), + i, + transform.getNumShards()); + + String transformName = String.format("%s(Shard:%s)", transform.getName(), i); + partitionedElements.get(i).apply(transformName, + transform.withNumShards(1).withShardNameTemplate("").withSuffix("").to(outputFilename)); + } + return PDone.in(input.getPipeline()); + } + } + + /** + * Apply the override for AvroIO.Write.Bound if the user requested sharding controls + * greater than one. + */ + private PDone applyAvroIOWrite(AvroIO.Write.Bound transform, PCollection input) { + if (transform.getNumShards() <= 1) { + // By default, the DirectPipelineRunner outputs to only 1 shard. Since the user never + // requested sharding controls greater than 1, we default to outputting to 1 file. + return super.apply(transform.withNumShards(1), input); + } + return input.apply(new DirectAvroIOWrite<>(transform)); + } + + /** + * The implementation may split the {@link KeyedCombineFn} into ADD, MERGE and EXTRACT phases ( + * see {@code com.google.cloud.dataflow.sdk.runners.worker.CombineValuesFn}). In order to emulate + * this for the {@link DirectPipelineRunner} and provide an experience closer to the service, go + * through heavy serializability checks for the equivalent of the results of the ADD phase, but + * after the {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} shuffle, and the MERGE + * phase. 
Doing these checks ensure that not only is the accumulator coder serializable, but + * the accumulator coder can actually serialize the data in question. + */ + public static class TestCombineDoFn + extends DoFn>, KV> { + private final PerKeyCombineFnRunner fnRunner; + private final Coder accumCoder; + private final boolean testSerializability; + private final Random rand; + + public static TestCombineDoFn create( + Combine.GroupedValues transform, + PCollection>> input, + boolean testSerializability, + Random rand) { + + AppliedCombineFn fn = transform.getAppliedFn( + input.getPipeline().getCoderRegistry(), input.getCoder(), input.getWindowingStrategy()); + + return new TestCombineDoFn( + PerKeyCombineFnRunners.create(fn.getFn()), + fn.getAccumulatorCoder(), + testSerializability, + rand); + } + + public TestCombineDoFn( + PerKeyCombineFnRunner fnRunner, + Coder accumCoder, + boolean testSerializability, + Random rand) { + this.fnRunner = fnRunner; + this.accumCoder = accumCoder; + this.testSerializability = testSerializability; + this.rand = rand; + + // Check that this does not crash, specifically to catch anonymous CustomCoder subclasses. + this.accumCoder.getEncodingId(); + } + + @Override + public void processElement(ProcessContext c) throws Exception { + K key = c.element().getKey(); + Iterable values = c.element().getValue(); + List groupedPostShuffle = + ensureSerializableByCoder(ListCoder.of(accumCoder), + addInputsRandomly(fnRunner, key, values, rand, c), + "After addInputs of KeyedCombineFn " + fnRunner.fn().toString()); + AccumT merged = + ensureSerializableByCoder(accumCoder, + fnRunner.mergeAccumulators(key, groupedPostShuffle, c), + "After mergeAccumulators of KeyedCombineFn " + fnRunner.fn().toString()); + // Note: The serializability of KV is ensured by the + // runner itself, since it's a transform output. + c.output(KV.of(key, fnRunner.extractOutput(key, merged, c))); + } + + /** + * Create a random list of accumulators from the given list of values. + * + *
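+ * <p>For instance, for a summing combine over the per-key values {@code [1, 2, 3, 4, 5]},
+ * one possible outcome is the accumulator list {@code [1, 14]}: the first element always
+ * becomes its own accumulator, and here the remaining values {@code 2 + 3 + 4 + 5} happened
+ * to stay together. The list is then shuffled, so the grouping and order vary between runs.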
+ * <p>
    Visible for testing purposes only. + */ + public static List addInputsRandomly( + PerKeyCombineFnRunner fnRunner, + K key, + Iterable values, + Random random, + DoFn.ProcessContext c) { + List out = new ArrayList(); + int i = 0; + AccumT accumulator = fnRunner.createAccumulator(key, c); + boolean hasInput = false; + + for (InputT value : values) { + accumulator = fnRunner.addInput(key, accumulator, value, c); + hasInput = true; + + // For each index i, flip a 1/2^i weighted coin for whether to + // create a new accumulator after index i is added, i.e. [0] + // is guaranteed, [1] is an even 1/2, [2] is 1/4, etc. The + // goal is to partition the inputs into accumulators, and make + // the accumulators potentially lumpy. Also compact about half + // of the accumulators. + if (i == 0 || random.nextInt(1 << Math.min(i, 30)) == 0) { + if (i % 2 == 0) { + accumulator = fnRunner.compact(key, accumulator, c); + } + out.add(accumulator); + accumulator = fnRunner.createAccumulator(key, c); + hasInput = false; + } + i++; + } + if (hasInput) { + out.add(accumulator); + } + + Collections.shuffle(out, random); + return out; + } + + public T ensureSerializableByCoder( + Coder coder, T value, String errorContext) { + if (testSerializability) { + return SerializableUtils.ensureSerializableByCoder( + coder, value, errorContext); + } + return value; + } + } + + @Override + public EvaluationResults run(Pipeline pipeline) { + LOG.info("Executing pipeline using the DirectPipelineRunner."); + + Evaluator evaluator = new Evaluator(rand); + evaluator.run(pipeline); + + // Log all counter values for debugging purposes. + for (Counter counter : evaluator.getCounters()) { + LOG.info("Final aggregator value: {}", counter); + } + + LOG.info("Pipeline execution complete."); + + return evaluator; + } + + /** + * An evaluator of a PTransform. + */ + public interface TransformEvaluator { + public void evaluate(TransformT transform, + EvaluationContext context); + } + + /** + * The interface provided to registered callbacks for interacting + * with the {@code DirectPipelineRunner}, including reading and writing the + * values of {@link PCollection}s and {@link PCollectionView}s. + */ + public interface EvaluationResults extends PipelineResult { + /** + * Retrieves the value of the given PCollection. + * Throws an exception if the PCollection's value hasn't already been set. + */ + List getPCollection(PCollection pc); + + /** + * Retrieves the windowed value of the given PCollection. + * Throws an exception if the PCollection's value hasn't already been set. + */ + List> getPCollectionWindowedValues(PCollection pc); + + /** + * Retrieves the values of each PCollection in the given + * PCollectionList. Throws an exception if the PCollectionList's + * value hasn't already been set. + */ + List> getPCollectionList(PCollectionList pcs); + + /** + * Retrieves the values indicated by the given {@link PCollectionView}. + * Note that within the {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context} + * implementation a {@link PCollectionView} should convert from this representation to a + * suitable side input value. + */ + Iterable> getPCollectionView(PCollectionView view); + } + + /** + * An immutable (value, timestamp) pair, along with other metadata necessary + * for the implementation of {@code DirectPipelineRunner}. + */ + public static class ValueWithMetadata { + /** + * Returns a new {@code ValueWithMetadata} with the {@code WindowedValue}. + * Key is null. 
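+ *
+ * <p>Illustrative usage (the element value and key are placeholders):
+ * <pre>{@code
+ * ValueWithMetadata<String> vm =
+ *     ValueWithMetadata.of(WindowedValue.valueInGlobalWindow("elem")).withKey("key-A");
+ * }</pre>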
+ */ + public static ValueWithMetadata of(WindowedValue windowedValue) { + return new ValueWithMetadata<>(windowedValue, null); + } + + /** + * Returns a new {@code ValueWithMetadata} with the implicit key associated + * with this value set. The key is the last key grouped by in the chain of + * productions that produced this element. + * These keys are used internally by {@link DirectPipelineRunner} for keeping + * persisted state separate across keys. + */ + public ValueWithMetadata withKey(Object key) { + return new ValueWithMetadata<>(windowedValue, key); + } + + /** + * Returns a new {@code ValueWithMetadata} that is a copy of this one, but with + * a different value. + */ + public ValueWithMetadata withValue(T value) { + return new ValueWithMetadata(windowedValue.withValue(value), getKey()); + } + + /** + * Returns the {@code WindowedValue} associated with this element. + */ + public WindowedValue getWindowedValue() { + return windowedValue; + } + + /** + * Returns the value associated with this element. + * + * @see #withValue + */ + public V getValue() { + return windowedValue.getValue(); + } + + /** + * Returns the timestamp associated with this element. + */ + public Instant getTimestamp() { + return windowedValue.getTimestamp(); + } + + /** + * Returns the collection of windows this element has been placed into. May + * be null if the {@code PCollection} this element is in has not yet been + * windowed. + * + * @see #getWindows() + */ + public Collection getWindows() { + return windowedValue.getWindows(); + } + + + /** + * Returns the key associated with this element. May be null if the + * {@code PCollection} this element is in is not keyed. + * + * @see #withKey + */ + public Object getKey() { + return key; + } + + //////////////////////////////////////////////////////////////////////////// + + private final Object key; + private final WindowedValue windowedValue; + + private ValueWithMetadata(WindowedValue windowedValue, + Object key) { + this.windowedValue = windowedValue; + this.key = key; + } + } + + /** + * The interface provided to registered callbacks for interacting + * with the {@code DirectPipelineRunner}, including reading and writing the + * values of {@link PCollection}s and {@link PCollectionView}s. + */ + public interface EvaluationContext extends EvaluationResults { + /** + * Returns the configured pipeline options. + */ + DirectPipelineOptions getPipelineOptions(); + + /** + * Returns the input of the currently being processed transform. + */ + InputT getInput(PTransform transform); + + /** + * Returns the output of the currently being processed transform. + */ + OutputT getOutput(PTransform transform); + + /** + * Sets the value of the given PCollection, where each element also has a timestamp + * and collection of windows. + * Throws an exception if the PCollection's value has already been set. + */ + void setPCollectionValuesWithMetadata( + PCollection pc, List> elements); + + /** + * Sets the value of the given PCollection, where each element also has a timestamp + * and collection of windows. + * Throws an exception if the PCollection's value has already been set. + */ + void setPCollectionWindowedValue(PCollection pc, List> elements); + + /** + * Shorthand for setting the value of a PCollection where the elements do not have + * timestamps or windows. + * Throws an exception if the PCollection's value has already been set. 
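+ *
+ * <p>A transform evaluator might use this roughly as follows (sketch; the transform,
+ * collections, and element type are placeholders):
+ * <pre>{@code
+ * PCollection<String> in = context.getInput(transform);
+ * PCollection<String> out = context.getOutput(transform);
+ * List<String> results = new ArrayList<>();
+ * for (String element : context.getPCollection(in)) {
+ *   results.add(element.toUpperCase());
+ * }
+ * context.setPCollection(out, results);
+ * }</pre>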
+ */ + void setPCollection(PCollection pc, List elements); + + /** + * Retrieves the value of the given PCollection, along with element metadata + * such as timestamps and windows. + * Throws an exception if the PCollection's value hasn't already been set. + */ + List> getPCollectionValuesWithMetadata(PCollection pc); + + /** + * Sets the value associated with the given {@link PCollectionView}. + * Throws an exception if the {@link PCollectionView}'s value has already been set. + */ + void setPCollectionView( + PCollectionView pc, + Iterable> value); + + /** + * Ensures that the element is encodable and decodable using the + * TypePValue's coder, by encoding it and decoding it, and + * returning the result. + */ + T ensureElementEncodable(TypedPValue pvalue, T element); + + /** + * If the evaluation context is testing unorderedness, + * randomly permutes the order of the elements, in a + * copy if !inPlaceAllowed, and returns the permuted list, + * otherwise returns the argument unchanged. + */ + List randomizeIfUnordered(List elements, + boolean inPlaceAllowed); + + /** + * If the evaluation context is testing serializability, ensures + * that the argument function is serializable and deserializable + * by encoding it and then decoding it, and returning the result. + * Otherwise returns the argument unchanged. + */ + FunctionT ensureSerializable(FunctionT fn); + + /** + * If the evaluation context is testing serializability, ensures + * that the argument Coder is serializable and deserializable + * by encoding it and then decoding it, and returning the result. + * Otherwise returns the argument unchanged. + */ + Coder ensureCoderSerializable(Coder coder); + + /** + * If the evaluation context is testing serializability, ensures + * that the given data is serializable and deserializable with the + * given Coder by encoding it and then decoding it, and returning + * the result. Otherwise returns the argument unchanged. + * + *
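+ *
+ * <p>Illustrative call (the coder, value, and error context shown are placeholders):
+ * <pre>{@code
+ * String checked = context.ensureSerializableByCoder(
+ *     StringUtf8Coder.of(), "some element", "Within MyTransform");
+ * }</pre>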
+ * <p>
    Error context is prefixed to any thrown exceptions. + */ + T ensureSerializableByCoder(Coder coder, + T data, String errorContext); + + /** + * Returns a mutator, which can be used to add additional counters to + * this EvaluationContext. + */ + CounterSet.AddCounterMutator getAddCounterMutator(); + + /** + * Gets the step name for this transform. + */ + public String getStepName(PTransform transform); + } + + + ///////////////////////////////////////////////////////////////////////////// + + class Evaluator implements PipelineVisitor, EvaluationContext { + /** + * A map from PTransform to the step name of that transform. This is the internal name for the + * transform (e.g. "s2"). + */ + private final Map, String> stepNames = new HashMap<>(); + private final Map store = new HashMap<>(); + private final CounterSet counters = new CounterSet(); + private AppliedPTransform currentTransform; + + private Map, Collection>> aggregatorSteps = null; + + /** + * A map from PTransform to the full name of that transform. This is the user name of the + * transform (e.g. "RemoveDuplicates/Combine/GroupByKey"). + */ + private final Map, String> fullNames = new HashMap<>(); + + private Random rand; + + public Evaluator() { + this(new Random()); + } + + public Evaluator(Random rand) { + this.rand = rand; + } + + public void run(Pipeline pipeline) { + pipeline.traverseTopologically(this); + aggregatorSteps = new AggregatorPipelineExtractor(pipeline).getAggregatorSteps(); + } + + @Override + public DirectPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public InputT getInput(PTransform transform) { + checkArgument(currentTransform != null && currentTransform.getTransform() == transform, + "can only be called with current transform"); + return (InputT) currentTransform.getInput(); + } + + @Override + public OutputT getOutput(PTransform transform) { + checkArgument(currentTransform != null && currentTransform.getTransform() == transform, + "can only be called with current transform"); + return (OutputT) currentTransform.getOutput(); + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + fullNames.put(transform, node.getFullName()); + TransformEvaluator evaluator = + getTransformEvaluator(transform.getClass()); + if (evaluator == null) { + throw new IllegalStateException( + "no evaluator registered for " + transform); + } + LOG.debug("Evaluating {}", transform); + currentTransform = AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform); + evaluator.evaluate(transform, this); + currentTransform = null; + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + LOG.debug("Checking evaluation of {}", value); + if (value.getProducingTransformInternal() == null) { + throw new RuntimeException( + "internal error: expecting a PValue " + + "to have a producingTransform"); + } + if (!producer.isCompositeNode()) { + // Verify that primitive transform outputs are already computed. + getPValue(value); + } + } + + /** + * Sets the value of the given PValue. + * Throws an exception if the PValue's value has already been set. 
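+ *
+ * <p>The store is write-once per {@link PValue}; for example (illustrative):
+ * <pre>{@code
+ * setPValue(pc, elements);       // succeeds
+ * setPValue(pc, otherElements);  // throws IllegalStateException
+ * }</pre>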
+ */ + void setPValue(PValue pvalue, Object contents) { + if (store.containsKey(pvalue)) { + throw new IllegalStateException( + "internal error: setting the value of " + pvalue + + " more than once"); + } + store.put(pvalue, contents); + } + + /** + * Retrieves the value of the given PValue. + * Throws an exception if the PValue's value hasn't already been set. + */ + Object getPValue(PValue pvalue) { + if (!store.containsKey(pvalue)) { + throw new IllegalStateException( + "internal error: getting the value of " + pvalue + + " before it has been computed"); + } + return store.get(pvalue); + } + + /** + * Convert a list of T to a list of {@code ValueWithMetadata}, with a timestamp of 0 + * and null windows. + */ + List> toValueWithMetadata(List values) { + List> result = new ArrayList<>(values.size()); + for (T value : values) { + result.add(ValueWithMetadata.of(WindowedValue.valueInGlobalWindow(value))); + } + return result; + } + + /** + * Convert a list of {@code WindowedValue} to a list of {@code ValueWithMetadata}. + */ + List> toValueWithMetadataFromWindowedValue( + List> values) { + List> result = new ArrayList<>(values.size()); + for (WindowedValue value : values) { + result.add(ValueWithMetadata.of(value)); + } + return result; + } + + @Override + public void setPCollection(PCollection pc, List elements) { + setPCollectionValuesWithMetadata(pc, toValueWithMetadata(elements)); + } + + @Override + public void setPCollectionWindowedValue( + PCollection pc, List> elements) { + setPCollectionValuesWithMetadata(pc, toValueWithMetadataFromWindowedValue(elements)); + } + + @Override + public void setPCollectionValuesWithMetadata( + PCollection pc, List> elements) { + LOG.debug("Setting {} = {}", pc, elements); + ensurePCollectionEncodable(pc, elements); + setPValue(pc, elements); + } + + @Override + public void setPCollectionView( + PCollectionView view, + Iterable> value) { + LOG.debug("Setting {} = {}", view, value); + setPValue(view, value); + } + + /** + * Retrieves the value of the given {@link PCollection}. + * Throws an exception if the {@link PCollection}'s value hasn't already been set. + */ + @Override + public List getPCollection(PCollection pc) { + List result = new ArrayList<>(); + for (ValueWithMetadata elem : getPCollectionValuesWithMetadata(pc)) { + result.add(elem.getValue()); + } + return result; + } + + @Override + public List> getPCollectionWindowedValues(PCollection pc) { + return Lists.transform( + getPCollectionValuesWithMetadata(pc), + new Function, WindowedValue>() { + @Override + public WindowedValue apply(ValueWithMetadata input) { + return input.getWindowedValue(); + }}); + } + + @Override + public List> getPCollectionValuesWithMetadata(PCollection pc) { + List> elements = (List>) getPValue(pc); + elements = randomizeIfUnordered(elements, false /* not inPlaceAllowed */); + LOG.debug("Getting {} = {}", pc, elements); + return elements; + } + + @Override + public List> getPCollectionList(PCollectionList pcs) { + List> elementsList = new ArrayList<>(); + for (PCollection pc : pcs.getAll()) { + elementsList.add(getPCollection(pc)); + } + return elementsList; + } + + /** + * Retrieves the value indicated by the given {@link PCollectionView}. + * Note that within the {@link DoFnContext} a {@link PCollectionView} + * converts from this representation to a suitable side input value. 
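+ *
+ * <p>Illustrative round trip with the corresponding setter (the view and its
+ * windowed contents are placeholders):
+ * <pre>{@code
+ * setPCollectionView(view, windowedContents);
+ * Iterable<WindowedValue<?>> raw = getPCollectionView(view);
+ * }</pre>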
+ */ + @Override + public Iterable> getPCollectionView(PCollectionView view) { + Iterable> value = (Iterable>) getPValue(view); + LOG.debug("Getting {} = {}", view, value); + return value; + } + + /** + * If {@code testEncodability}, ensures that the {@link PCollection}'s coder and elements are + * encodable and decodable by encoding them and decoding them, and returning the result. + * Otherwise returns the argument elements. + */ + List> ensurePCollectionEncodable( + PCollection pc, List> elements) { + ensureCoderSerializable(pc.getCoder()); + if (!testEncodability) { + return elements; + } + List> elementsCopy = new ArrayList<>(elements.size()); + for (ValueWithMetadata element : elements) { + elementsCopy.add( + element.withValue(ensureElementEncodable(pc, element.getValue()))); + } + return elementsCopy; + } + + @Override + public T ensureElementEncodable(TypedPValue pvalue, T element) { + return ensureSerializableByCoder( + pvalue.getCoder(), element, "Within " + pvalue.toString()); + } + + @Override + public List randomizeIfUnordered(List elements, + boolean inPlaceAllowed) { + if (!testUnorderedness) { + return elements; + } + List elementsCopy = new ArrayList<>(elements); + Collections.shuffle(elementsCopy, rand); + return elementsCopy; + } + + @Override + public FunctionT ensureSerializable(FunctionT fn) { + if (!testSerializability) { + return fn; + } + return SerializableUtils.ensureSerializable(fn); + } + + @Override + public Coder ensureCoderSerializable(Coder coder) { + if (testSerializability) { + SerializableUtils.ensureSerializable(coder); + } + return coder; + } + + @Override + public T ensureSerializableByCoder( + Coder coder, T value, String errorContext) { + if (testSerializability) { + return SerializableUtils.ensureSerializableByCoder( + coder, value, errorContext); + } + return value; + } + + @Override + public CounterSet.AddCounterMutator getAddCounterMutator() { + return counters.getAddCounterMutator(); + } + + @Override + public String getStepName(PTransform transform) { + String stepName = stepNames.get(transform); + if (stepName == null) { + stepName = "s" + (stepNames.size() + 1); + stepNames.put(transform, stepName); + } + return stepName; + } + + /** + * Returns the CounterSet generated during evaluation, which includes + * user-defined Aggregators and may include system-defined counters. + */ + public CounterSet getCounters() { + return counters; + } + + /** + * Returns JobState.DONE in all situations. The Evaluator is not returned + * until the pipeline has been traversed, so it will either be returned + * after a successful run or the run call will terminate abnormally. 
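+ *
+ * <p>Caller-side sketch ({@code runner}, {@code pipeline}, and {@code outputCollection}
+ * are placeholders):
+ * <pre>{@code
+ * EvaluationResults results = runner.run(pipeline);
+ * // Traversal has finished by the time run() returns, so getState() reports DONE.
+ * List<String> outputElements = results.getPCollection(outputCollection);
+ * }</pre>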
+ */ + @Override + public State getState() { + return State.DONE; + } + + @Override + public AggregatorValues getAggregatorValues(Aggregator aggregator) { + Map stepValues = new HashMap<>(); + for (PTransform step : aggregatorSteps.get(aggregator)) { + String stepName = String.format("user-%s-%s", stepNames.get(step), aggregator.getName()); + String fullName = fullNames.get(step); + Counter counter = counters.getExistingCounter(stepName); + if (counter == null) { + throw new IllegalArgumentException( + "Aggregator " + aggregator + " is not used in this pipeline"); + } + stepValues.put(fullName, (T) counter.getAggregate()); + } + return new MapAggregatorValues<>(stepValues); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private final DirectPipelineOptions options; + private boolean testSerializability; + private boolean testEncodability; + private boolean testUnorderedness; + + /** Returns a new DirectPipelineRunner. */ + private DirectPipelineRunner(DirectPipelineOptions options) { + this.options = options; + // (Re-)register standard IO factories. Clobbers any prior credentials. + IOChannelUtils.registerStandardIOFactories(options); + long randomSeed; + if (options.getDirectPipelineRunnerRandomSeed() != null) { + randomSeed = options.getDirectPipelineRunnerRandomSeed(); + } else { + randomSeed = new Random().nextLong(); + } + + LOG.debug("DirectPipelineRunner using random seed {}.", randomSeed); + rand = new Random(randomSeed); + + testSerializability = options.isTestSerializability(); + testEncodability = options.isTestEncodability(); + testUnorderedness = options.isTestUnorderedness(); + } + + /** + * Get the options used in this {@link Pipeline}. + */ + public DirectPipelineOptions getPipelineOptions() { + return options; + } + + @Override + public String toString() { + return "DirectPipelineRunner#" + hashCode(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunner.java new file mode 100644 index 000000000000..26d8e1e66662 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunner.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.base.Preconditions; + +/** + * A {@link PipelineRunner} can execute, translate, or otherwise process a + * {@link Pipeline}. + * + * @param the type of the result of {@link #run}. + */ +public abstract class PipelineRunner { + + /** + * Constructs a runner from the provided options. + * + * @return The newly created runner. + */ + public static PipelineRunner fromOptions(PipelineOptions options) { + GcsOptions gcsOptions = PipelineOptionsValidator.validate(GcsOptions.class, options); + Preconditions.checkNotNull(options); + + // (Re-)register standard IO factories. Clobbers any prior credentials. + IOChannelUtils.registerStandardIOFactories(gcsOptions); + + @SuppressWarnings("unchecked") + PipelineRunner result = + InstanceBuilder.ofType(PipelineRunner.class) + .fromClass(options.getRunner()) + .fromFactoryMethod("fromOptions") + .withArg(PipelineOptions.class, options) + .build(); + return result; + } + + /** + * Processes the given Pipeline, returning the results. + */ + public abstract ResultT run(Pipeline pipeline); + + /** + * Applies a transform to the given input, returning the output. + * + *
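+ * <p>A runner that wants to substitute its own implementation of a particular transform
+ * can intercept it here, in the style of the {@code DirectPipelineRunner} overrides
+ * (sketch; {@code SomeTransform} and {@code applyOverride} are hypothetical):
+ * <pre>{@code
+ * @Override
+ * public <OutputT extends POutput, InputT extends PInput> OutputT apply(
+ *     PTransform<InputT, OutputT> transform, InputT input) {
+ *   if (transform instanceof SomeTransform) {
+ *     return (OutputT) applyOverride((SomeTransform) transform, input);
+ *   }
+ *   return super.apply(transform, input);
+ * }
+ * }</pre>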
+ * <p>
    The default implementation calls PTransform.apply(input), but can be overridden + * to customize behavior for a particular runner. + */ + public OutputT apply( + PTransform transform, InputT input) { + return transform.apply(input); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerRegistrar.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerRegistrar.java new file mode 100644 index 000000000000..1ca33466a129 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerRegistrar.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.auto.service.AutoService; +import java.util.ServiceLoader; + +/** + * {@link PipelineRunner} creators have the ability to automatically have their + * {@link PipelineRunner} registered with this SDK by creating a {@link ServiceLoader} entry + * and a concrete implementation of this interface. + * + *
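+ * <p>A hypothetical registrar might look like the following sketch ({@code MyRunner} and
+ * {@code MyRunnerRegistrar} are illustrative names):
+ * <pre>{@code
+ * @AutoService(PipelineRunnerRegistrar.class)
+ * public class MyRunnerRegistrar implements PipelineRunnerRegistrar {
+ *   @Override
+ *   public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
+ *     return Arrays.<Class<? extends PipelineRunner<?>>>asList(MyRunner.class);
+ *   }
+ * }
+ * }</pre>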
+ * <p>
    Note that automatic registration of any + * {@link com.google.cloud.dataflow.sdk.options.PipelineOptions} requires users + * conform to the limit that each {@link PipelineRunner}'s + * {@link Class#getSimpleName() simple name} must be unique. + * + *
+ * <p>
    It is optional but recommended to use one of the many build time tools such as + * {@link AutoService} to generate the necessary META-INF files automatically. + */ +public interface PipelineRunnerRegistrar { + /** + * Get the set of {@link PipelineRunner PipelineRunners} to register. + */ + public Iterable>> getPipelineRunners(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/RecordingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/RecordingPipelineVisitor.java new file mode 100644 index 000000000000..ca02b39d1307 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/RecordingPipelineVisitor.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.ArrayList; +import java.util.List; + +/** + * Provides a simple {@link com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor} + * that records the transformation tree. + * + *
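+ * <p>Typical use in a test looks roughly like this (sketch; {@code pipeline} is a placeholder):
+ * <pre>{@code
+ * RecordingPipelineVisitor visitor = new RecordingPipelineVisitor();
+ * pipeline.traverseTopologically(visitor);
+ * // visitor.transforms and visitor.values now list the visited transforms and values.
+ * }</pre>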
+ * <p>
    Provided for internal unit tests. + */ +public class RecordingPipelineVisitor implements Pipeline.PipelineVisitor { + + public final List> transforms = new ArrayList<>(); + public final List values = new ArrayList<>(); + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + } + + @Override + public void visitTransform(TransformTreeNode node) { + transforms.add(node.getTransform()); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + values.add(value); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformHierarchy.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformHierarchy.java new file mode 100644 index 000000000000..d62192d1e788 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformHierarchy.java @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; + +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.Set; + +/** + * Captures information about a collection of transformations and their + * associated {@link PValue}s. + */ +public class TransformHierarchy { + private final Deque transformStack = new LinkedList<>(); + private final Map producingTransformNode = new HashMap<>(); + + /** + * Create a {@code TransformHierarchy} containing a root node. + */ + public TransformHierarchy() { + // First element in the stack is the root node, holding all child nodes. + transformStack.add(new TransformTreeNode(null, null, "", null)); + } + + /** + * Returns the last TransformTreeNode on the stack. + */ + public TransformTreeNode getCurrent() { + return transformStack.peek(); + } + + /** + * Add a TransformTreeNode to the stack. + */ + public void pushNode(TransformTreeNode current) { + transformStack.push(current); + } + + /** + * Removes the last TransformTreeNode from the stack. + */ + public void popNode() { + transformStack.pop(); + Preconditions.checkState(!transformStack.isEmpty()); + } + + /** + * Adds an input to the given node. + * + *
+ * <p>
    This forces the producing node to be finished. + */ + public void addInput(TransformTreeNode node, PInput input) { + for (PValue i : input.expand()) { + TransformTreeNode producer = producingTransformNode.get(i); + if (producer == null) { + throw new IllegalStateException("Producer unknown for input: " + i); + } + + producer.finishSpecifying(); + node.addInputProducer(i, producer); + } + } + + /** + * Sets the output of a transform node. + */ + public void setOutput(TransformTreeNode producer, POutput output) { + producer.setOutput(output); + + for (PValue o : output.expand()) { + producingTransformNode.put(o, producer); + } + } + + /** + * Visits all nodes in the transform hierarchy, in transitive order. + */ + public void visit(Pipeline.PipelineVisitor visitor, + Set visitedNodes) { + transformStack.peekFirst().visit(visitor, visitedNodes); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformTreeNode.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformTreeNode.java new file mode 100644 index 000000000000..2649458e347f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/TransformTreeNode.java @@ -0,0 +1,252 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Provides internal tracking of transform relationships with helper methods + * for initialization and ordered visitation. + */ +public class TransformTreeNode { + private final TransformTreeNode enclosingNode; + + // The PTransform for this node, which may be a composite PTransform. + // The root of a TransformHierarchy is represented as a TransformTreeNode + // with a null transform field. + private final PTransform transform; + + private final String fullName; + + // Nodes for sub-transforms of a composite transform. + private final Collection parts = new ArrayList<>(); + + // Inputs to the transform, in expanded form and mapped to the producer + // of the input. + private final Map inputs = new HashMap<>(); + + // Input to the transform, in unexpanded form. + private final PInput input; + + // TODO: track which outputs need to be exported to parent. + // Output of the transform, in unexpanded form. + private POutput output; + + private boolean finishedSpecifying = false; + + /** + * Creates a new TransformTreeNode with the given parent and transform. + * + *
+ * <p>
    EnclosingNode and transform may both be null for + * a root-level node, which holds all other nodes. + * + * @param enclosingNode the composite node containing this node + * @param transform the PTransform tracked by this node + * @param fullName the fully qualified name of the transform + * @param input the unexpanded input to the transform + */ + public TransformTreeNode(@Nullable TransformTreeNode enclosingNode, + @Nullable PTransform transform, + String fullName, + @Nullable PInput input) { + this.enclosingNode = enclosingNode; + this.transform = transform; + Preconditions.checkArgument((enclosingNode == null && transform == null) + || (enclosingNode != null && transform != null), + "EnclosingNode and transform must both be specified, or both be null"); + this.fullName = fullName; + this.input = input; + } + + /** + * Returns the transform associated with this transform node. + */ + public PTransform getTransform() { + return transform; + } + + /** + * Returns the enclosing composite transform node, or null if there is none. + */ + public TransformTreeNode getEnclosingNode() { + return enclosingNode; + } + + /** + * Adds a composite operation to the transform node. + * + *
+ * <p>
    As soon as a node is added, the transform node is considered a + * composite operation instead of a primitive transform. + */ + public void addComposite(TransformTreeNode node) { + parts.add(node); + } + + /** + * Returns true if this node represents a composite transform that does not perform + * processing of its own, but merely encapsulates a sub-pipeline (which may be empty). + * + *
+ * <p>
    Note that a node may be composite with no sub-transforms if it returns its input directly + * extracts a component of a tuple, or other operations that occur at pipeline assembly time. + */ + public boolean isCompositeNode() { + return !parts.isEmpty() || returnsOthersOutput() || isRootNode(); + } + + private boolean returnsOthersOutput() { + PTransform transform = getTransform(); + for (PValue output : getExpandedOutputs()) { + if (!output.getProducingTransformInternal().getTransform().equals(transform)) { + return true; + } + } + return false; + } + + public boolean isRootNode() { + return transform == null; + } + + public String getFullName() { + return fullName; + } + + /** + * Adds an input to the transform node. + */ + public void addInputProducer(PValue expandedInput, TransformTreeNode producer) { + Preconditions.checkState(!finishedSpecifying); + inputs.put(expandedInput, producer); + } + + /** + * Returns the transform input, in unexpanded form. + */ + public PInput getInput() { + return input; + } + + /** + * Returns a mapping of inputs to the producing nodes for all inputs to + * the transform. + */ + public Map getInputs() { + return Collections.unmodifiableMap(inputs); + } + + /** + * Adds an output to the transform node. + */ + public void setOutput(POutput output) { + Preconditions.checkState(!finishedSpecifying); + Preconditions.checkState(this.output == null); + this.output = output; + } + + /** + * Returns the transform output, in unexpanded form. + */ + public POutput getOutput() { + return output; + } + + /** + * Returns the transform outputs, in expanded form. + */ + public Collection getExpandedOutputs() { + if (output != null) { + return output.expand(); + } else { + return Collections.emptyList(); + } + } + + /** + * Visit the transform node. + * + *
+ * <p>
    Provides an ordered visit of the input values, the primitive + * transform (or child nodes for composite transforms), then the + * output values. + */ + public void visit(Pipeline.PipelineVisitor visitor, + Set visitedValues) { + if (!finishedSpecifying) { + finishSpecifying(); + } + + // Visit inputs. + for (Map.Entry entry : inputs.entrySet()) { + if (visitedValues.add(entry.getKey())) { + visitor.visitValue(entry.getKey(), entry.getValue()); + } + } + + if (isCompositeNode()) { + visitor.enterCompositeTransform(this); + for (TransformTreeNode child : parts) { + child.visit(visitor, visitedValues); + } + visitor.leaveCompositeTransform(this); + } else { + visitor.visitTransform(this); + } + + // Visit outputs. + for (PValue pValue : getExpandedOutputs()) { + if (visitedValues.add(pValue)) { + visitor.visitValue(pValue, this); + } + } + } + + /** + * Finish specifying a transform. + * + *
+ * <p>
    All inputs are finished first, then the transform, then + * all outputs. + */ + public void finishSpecifying() { + if (finishedSpecifying) { + return; + } + finishedSpecifying = true; + + for (TransformTreeNode input : inputs.values()) { + if (input != null) { + input.finishSpecifying(); + } + } + + if (output != null) { + output.finishSpecifyingOutput(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/AssignWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/AssignWindows.java new file mode 100644 index 000000000000..093783de8583 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/AssignWindows.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * A primitive {@link PTransform} that implements the {@link Window#into(WindowFn)} + * {@link PTransform}. + * + * For an application of {@link Window#into(WindowFn)} that changes the {@link WindowFn}, applies + * a primitive {@link PTransform} in the Dataflow service. + * + * For an application of {@link Window#into(WindowFn)} that does not change the {@link WindowFn}, + * applies an identity {@link ParDo} and sets the windowing strategy of the output + * {@link PCollection}. + * + * For internal use only. + * + * @param the type of input element + */ +public class AssignWindows extends PTransform, PCollection> { + private final Window.Bound transform; + + /** + * Builds an instance of this class from the overriden transform. + */ + @SuppressWarnings("unused") // Used via reflection + public AssignWindows(Window.Bound transform) { + this.transform = transform; + } + + @Override + public PCollection apply(PCollection input) { + WindowingStrategy outputStrategy = + transform.getOutputStrategyInternal(input.getWindowingStrategy()); + if (transform.getWindowFn() != null) { + // If the windowFn changed, we create a primitive, and run the AssignWindows operation here. + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), outputStrategy, input.isBounded()); + } else { + // If the windowFn didn't change, we just run a pass-through transform and then set the + // new windowing strategy. 
+ return input.apply(ParDo.named("Identity").of(new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + c.output(c.element()); + } + })).setWindowingStrategyInternal(outputStrategy); + } + } + + @Override + public void validate(PCollection input) { + transform.validate(input); + } + + @Override + protected Coder getDefaultOutputCoder(PCollection input) { + return input.getCoder(); + } + + @Override + protected String getKindString() { + return "Window.Into()"; + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/BigQueryIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/BigQueryIOTranslator.java new file mode 100644 index 000000000000..538901c722c0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/BigQueryIOTranslator.java @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.api.client.json.JsonFactory; +import com.google.api.services.bigquery.model.TableReference; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * BigQuery transform support code for the Dataflow backend. + */ +public class BigQueryIOTranslator { + private static final JsonFactory JSON_FACTORY = Transport.getJsonFactory(); + private static final Logger LOG = LoggerFactory.getLogger(BigQueryIOTranslator.class); + + /** + * Implements BigQueryIO Read translation for the Dataflow backend. + */ + public static class ReadTranslator + implements DataflowPipelineTranslator.TransformTranslator { + + @Override + public void translate( + BigQueryIO.Read.Bound transform, DataflowPipelineTranslator.TranslationContext context) { + // Actual translation. + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, "bigquery"); + context.addInput(PropertyNames.BIGQUERY_EXPORT_FORMAT, "FORMAT_AVRO"); + + if (transform.getQuery() != null) { + context.addInput(PropertyNames.BIGQUERY_QUERY, transform.getQuery()); + context.addInput(PropertyNames.BIGQUERY_FLATTEN_RESULTS, transform.getFlattenResults()); + } else { + TableReference table = transform.getTable(); + if (table.getProjectId() == null) { + // If user does not specify a project we assume the table to be located in the project + // that owns the Dataflow job. 
+ String projectIdFromOptions = context.getPipelineOptions().getProject(); + LOG.warn(String.format(BigQueryIO.SET_PROJECT_FROM_OPTIONS_WARNING, table.getDatasetId(), + table.getDatasetId(), table.getTableId(), projectIdFromOptions)); + table.setProjectId(projectIdFromOptions); + } + + context.addInput(PropertyNames.BIGQUERY_TABLE, table.getTableId()); + context.addInput(PropertyNames.BIGQUERY_DATASET, table.getDatasetId()); + if (table.getProjectId() != null) { + context.addInput(PropertyNames.BIGQUERY_PROJECT, table.getProjectId()); + } + } + context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + } + + /** + * Implements BigQueryIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator + implements DataflowPipelineTranslator.TransformTranslator { + + @Override + public void translate(BigQueryIO.Write.Bound transform, + DataflowPipelineTranslator.TranslationContext context) { + if (context.getPipelineOptions().isStreaming()) { + // Streaming is handled by the streaming runner. + throw new AssertionError( + "BigQueryIO is specified to use streaming write in batch mode."); + } + + TableReference table = transform.getTable(); + + // Actual translation. + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.FORMAT, "bigquery"); + context.addInput(PropertyNames.BIGQUERY_TABLE, + table.getTableId()); + context.addInput(PropertyNames.BIGQUERY_DATASET, + table.getDatasetId()); + if (table.getProjectId() != null) { + context.addInput(PropertyNames.BIGQUERY_PROJECT, table.getProjectId()); + } + if (transform.getSchema() != null) { + try { + context.addInput(PropertyNames.BIGQUERY_SCHEMA, + JSON_FACTORY.toString(transform.getSchema())); + } catch (IOException exn) { + throw new IllegalArgumentException("Invalid table schema.", exn); + } + } + context.addInput( + PropertyNames.BIGQUERY_CREATE_DISPOSITION, + transform.getCreateDisposition().name()); + context.addInput( + PropertyNames.BIGQUERY_WRITE_DISPOSITION, + transform.getWriteDisposition().name()); + // Set sink encoding to TableRowJsonCoder. + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(TableRowJsonCoder.of())); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/CustomSources.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/CustomSources.java new file mode 100644 index 000000000000..81606931954b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/CustomSources.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import static com.google.api.client.util.Base64.encodeBase64String; +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.addStringList; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.api.services.dataflow.model.SourceMetadata; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Source; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.ByteString; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + + +/** + * A helper class for supporting sources defined as {@code Source}. + * + *
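+ * <p>Illustrative use, converting a user-defined source to its cloud representation
+ * ({@code mySource} and {@code options} are placeholders; the method is declared to
+ * throw {@code Exception}):
+ * <pre>{@code
+ * com.google.api.services.dataflow.model.Source cloudSource =
+ *     CustomSources.serializeToCloudSource(mySource, options);
+ * }</pre>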
+ * <p>
    Provides a bridge between the high-level {@code Source} API and the + * low-level {@code CloudSource} class. + */ +public class CustomSources { + private static final String SERIALIZED_SOURCE = "serialized_source"; + @VisibleForTesting static final String SERIALIZED_SOURCE_SPLITS = "serialized_source_splits"; + /** + * The current limit on the size of a ReportWorkItemStatus RPC to Google Cloud Dataflow, which + * includes the initial splits, is 20 MB. + */ + public static final long DATAFLOW_SPLIT_RESPONSE_API_SIZE_BYTES = 20 * (1 << 20); + + private static final Logger LOG = LoggerFactory.getLogger(CustomSources.class); + + private static final ByteString firstSplitKey = ByteString.copyFromUtf8("0000000000000001"); + + public static boolean isFirstUnboundedSourceSplit(ByteString splitKey) { + return splitKey.equals(firstSplitKey); + } + + private static int getDesiredNumUnboundedSourceSplits(DataflowPipelineOptions options) { + if (options.getMaxNumWorkers() > 0) { + return options.getMaxNumWorkers(); + } else if (options.getNumWorkers() > 0) { + return options.getNumWorkers() * 3; + } else { + return 20; + } + } + + public static com.google.api.services.dataflow.model.Source serializeToCloudSource( + Source source, PipelineOptions options) throws Exception { + com.google.api.services.dataflow.model.Source cloudSource = + new com.google.api.services.dataflow.model.Source(); + // We ourselves act as the SourceFormat. + cloudSource.setSpec(CloudObject.forClass(CustomSources.class)); + addString( + cloudSource.getSpec(), SERIALIZED_SOURCE, encodeBase64String(serializeToByteArray(source))); + + SourceMetadata metadata = new SourceMetadata(); + if (source instanceof BoundedSource) { + BoundedSource boundedSource = (BoundedSource) source; + try { + metadata.setProducesSortedKeys(boundedSource.producesSortedKeys(options)); + } catch (Exception e) { + LOG.warn("Failed to check if the source produces sorted keys: " + source, e); + } + + // Size estimation is best effort so we continue even if it fails here. + try { + metadata.setEstimatedSizeBytes(boundedSource.getEstimatedSizeBytes(options)); + } catch (Exception e) { + LOG.warn("Size estimation of the source failed: " + source, e); + } + } else if (source instanceof UnboundedSource) { + UnboundedSource unboundedSource = (UnboundedSource) source; + metadata.setInfinite(true); + List encodedSplits = new ArrayList<>(); + int desiredNumSplits = + getDesiredNumUnboundedSourceSplits(options.as(DataflowPipelineOptions.class)); + for (UnboundedSource split : + unboundedSource.generateInitialSplits(desiredNumSplits, options)) { + encodedSplits.add(encodeBase64String(serializeToByteArray(split))); + } + checkArgument(!encodedSplits.isEmpty(), "UnboundedSources must have at least one split"); + addStringList(cloudSource.getSpec(), SERIALIZED_SOURCE_SPLITS, encodedSplits); + } else { + throw new IllegalArgumentException("Unexpected source kind: " + source.getClass()); + } + + cloudSource.setMetadata(metadata); + return cloudSource; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DataflowAggregatorTransforms.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DataflowAggregatorTransforms.java new file mode 100644 index 000000000000..e1d73019fec7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DataflowAggregatorTransforms.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; + +/** + * A mapping relating {@link Aggregator}s and the {@link PTransform} in which they are used. + */ +public class DataflowAggregatorTransforms { + private final Map, Collection>> aggregatorTransforms; + private final Multimap, AppliedPTransform> transformAppliedTransforms; + private final BiMap, String> appliedStepNames; + + public DataflowAggregatorTransforms( + Map, Collection>> aggregatorTransforms, + Map, String> transformStepNames) { + this.aggregatorTransforms = aggregatorTransforms; + appliedStepNames = HashBiMap.create(transformStepNames); + + transformAppliedTransforms = HashMultimap.create(); + for (AppliedPTransform appliedTransform : transformStepNames.keySet()) { + transformAppliedTransforms.put(appliedTransform.getTransform(), appliedTransform); + } + } + + /** + * Returns true if the provided {@link Aggregator} is used in the constructing {@link Pipeline}. + */ + public boolean contains(Aggregator aggregator) { + return aggregatorTransforms.containsKey(aggregator); + } + + /** + * Gets the step names in which the {@link Aggregator} is used. + */ + public Collection getAggregatorStepNames(Aggregator aggregator) { + Collection names = new HashSet<>(); + Collection> transforms = aggregatorTransforms.get(aggregator); + for (PTransform transform : transforms) { + for (AppliedPTransform applied : transformAppliedTransforms.get(transform)) { + names.add(appliedStepNames.get(applied)); + } + } + return names; + } + + /** + * Gets the {@link PTransform} that was assigned the provided step name. + */ + public AppliedPTransform getAppliedTransformForStepName(String stepName) { + return appliedStepNames.inverse().get(stepName); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DataflowMetricUpdateExtractor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DataflowMetricUpdateExtractor.java new file mode 100644 index 000000000000..13016dd4938b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/DataflowMetricUpdateExtractor.java @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.api.services.dataflow.model.MetricStructuredName; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Methods for extracting the values of an {@link Aggregator} from a collection of {@link + * MetricUpdate MetricUpdates}. + */ +public final class DataflowMetricUpdateExtractor { + private static final String STEP_NAME_CONTEXT_KEY = "step"; + private static final String IS_TENTATIVE_KEY = "tentative"; + + private DataflowMetricUpdateExtractor() { + // Do not instantiate. + } + + /** + * Extract the values of the provided {@link Aggregator} at each {@link PTransform} it was used in + * according to the provided {@link DataflowAggregatorTransforms} from the given list of {@link + * MetricUpdate MetricUpdates}. + */ + public static Map fromMetricUpdates(Aggregator aggregator, + DataflowAggregatorTransforms aggregatorTransforms, List metricUpdates) { + Map results = new HashMap<>(); + if (metricUpdates == null) { + return results; + } + + String aggregatorName = aggregator.getName(); + Collection aggregatorSteps = aggregatorTransforms.getAggregatorStepNames(aggregator); + + for (MetricUpdate metricUpdate : metricUpdates) { + MetricStructuredName metricStructuredName = metricUpdate.getName(); + Map context = metricStructuredName.getContext(); + if (metricStructuredName.getName().equals(aggregatorName) && context != null + && aggregatorSteps.contains(context.get(STEP_NAME_CONTEXT_KEY))) { + AppliedPTransform transform = + aggregatorTransforms.getAppliedTransformForStepName( + context.get(STEP_NAME_CONTEXT_KEY)); + String fullName = transform.getFullName(); + // Prefer the tentative (fresher) value if it exists. 
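+ // A tentative update always overwrites whatever was previously recorded for this step; a
+ // committed update is only recorded for a step that has no value yet.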
+ if (Boolean.parseBoolean(context.get(IS_TENTATIVE_KEY)) || !results.containsKey(fullName)) { + results.put(fullName, toValue(aggregator, metricUpdate)); + } + } + } + + return results; + + } + + private static OutputT toValue( + Aggregator aggregator, MetricUpdate metricUpdate) { + CombineFn combineFn = aggregator.getCombineFn(); + Class outputType = combineFn.getOutputType().getRawType(); + + if (outputType.equals(Long.class)) { + @SuppressWarnings("unchecked") + OutputT asLong = (OutputT) Long.valueOf(toNumber(metricUpdate).longValue()); + return asLong; + } + if (outputType.equals(Integer.class)) { + @SuppressWarnings("unchecked") + OutputT asInt = (OutputT) Integer.valueOf(toNumber(metricUpdate).intValue()); + return asInt; + } + if (outputType.equals(Double.class)) { + @SuppressWarnings("unchecked") + OutputT asDouble = (OutputT) Double.valueOf(toNumber(metricUpdate).doubleValue()); + return asDouble; + } + throw new UnsupportedOperationException( + "Unsupported Output Type " + outputType + " in aggregator " + aggregator); + } + + private static Number toNumber(MetricUpdate update) { + if (update.getScalar() instanceof Number) { + return (Number) update.getScalar(); + } + throw new IllegalArgumentException( + "Metric Update " + update + " does not have a numeric scalar"); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java new file mode 100644 index 000000000000..8b066ab065dd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +/** + * Pubsub transform support code for the Dataflow backend. + */ +public class PubsubIOTranslator { + + /** + * Implements PubsubIO Read translation for the Dataflow backend. 
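+ *
+ * <p>A sketch of how a translator of this kind is typically wired up; the registration call
+ * shown here is an assumption for illustration and is not part of this change:
+ * <pre>{@code
+ * DataflowPipelineTranslator.registerTransformTranslator(
+ *     PubsubIO.Read.Bound.class, new PubsubIOTranslator.ReadTranslator());
+ * }</pre>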
+ */ + public static class ReadTranslator implements TransformTranslator> { + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public void translate( + PubsubIO.Read.Bound transform, + TranslationContext context) { + translateReadHelper(transform, context); + } + + private void translateReadHelper( + PubsubIO.Read.Bound transform, + TranslationContext context) { + if (!context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException( + "PubsubIO.Read can only be used with the Dataflow streaming runner."); + } + + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, "pubsub"); + if (transform.getTopic() != null) { + context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic().asV1Beta1Path()); + } + if (transform.getSubscription() != null) { + context.addInput( + PropertyNames.PUBSUB_SUBSCRIPTION, transform.getSubscription().asV1Beta1Path()); + } + if (transform.getTimestampLabel() != null) { + context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, transform.getTimestampLabel()); + } + if (transform.getIdLabel() != null) { + context.addInput(PropertyNames.PUBSUB_ID_LABEL, transform.getIdLabel()); + } + context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + } + + /** + * Implements PubsubIO Write translation for the Dataflow backend. + */ + public static class WriteTranslator + implements TransformTranslator> { + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public void translate( + DataflowPipelineRunner.StreamingPubsubIOWrite transform, + TranslationContext context) { + translateWriteHelper(transform, context); + } + + private void translateWriteHelper( + DataflowPipelineRunner.StreamingPubsubIOWrite customTransform, + TranslationContext context) { + if (!context.getPipelineOptions().isStreaming()) { + throw new IllegalArgumentException( + "PubsubIO.Write is non-primitive for the Dataflow batch runner."); + } + + PubsubIO.Write.Bound transform = customTransform.getOverriddenTransform(); + + context.addStep(customTransform, "ParallelWrite"); + context.addInput(PropertyNames.FORMAT, "pubsub"); + context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic().asV1Beta1Path()); + if (transform.getTimestampLabel() != null) { + context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, transform.getTimestampLabel()); + } + if (transform.getIdLabel() != null) { + context.addInput(PropertyNames.PUBSUB_ID_LABEL, transform.getIdLabel()); + } + context.addEncodingInput(WindowedValue.getValueOnlyCoder(transform.getCoder())); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(customTransform)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/ReadTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/ReadTranslator.java new file mode 100644 index 000000000000..f110e84adc49 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/ReadTranslator.java @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; + +import com.google.api.services.dataflow.model.SourceMetadata; +import com.google.cloud.dataflow.sdk.io.FileBasedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.Source; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.HashMap; +import java.util.Map; + +/** + * Translator for the {@code Read} {@code PTransform} for the Dataflow back-end. + */ +public class ReadTranslator implements TransformTranslator> { + @Override + public void translate(Read.Bounded transform, TranslationContext context) { + translateReadHelper(transform.getSource(), transform, context); + } + + public static void translateReadHelper(Source source, + PTransform transform, + DataflowPipelineTranslator.TranslationContext context) { + try { + // TODO: Move this validation out of translation once IOChannelUtils is portable + // and can be reconstructed on the worker. + if (source instanceof FileBasedSource) { + String filePatternOrSpec = ((FileBasedSource) source).getFileOrPatternSpec(); + context.getPipelineOptions() + .getPathValidator() + .validateInputFilePatternSupported(filePatternOrSpec); + } + + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, PropertyNames.CUSTOM_SOURCE_FORMAT); + context.addInput( + PropertyNames.SOURCE_STEP_INPUT, + cloudSourceToDictionary( + CustomSources.serializeToCloudSource(source, context.getPipelineOptions()))); + context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + // Represents a cloud Source as a dictionary for encoding inside the {@code SOURCE_STEP_INPUT} + // property of CloudWorkflowStep.input. + private static Map cloudSourceToDictionary( + com.google.api.services.dataflow.model.Source source) { + // Do not translate encoding - the source's encoding is translated elsewhere + // to the step's output info. 
+ Map res = new HashMap<>(); + addDictionary(res, PropertyNames.SOURCE_SPEC, source.getSpec()); + if (source.getMetadata() != null) { + addDictionary(res, PropertyNames.SOURCE_METADATA, + cloudSourceMetadataToDictionary(source.getMetadata())); + } + if (source.getDoesNotNeedSplitting() != null) { + addBoolean( + res, PropertyNames.SOURCE_DOES_NOT_NEED_SPLITTING, source.getDoesNotNeedSplitting()); + } + return res; + } + + private static Map cloudSourceMetadataToDictionary(SourceMetadata metadata) { + Map res = new HashMap<>(); + if (metadata.getProducesSortedKeys() != null) { + addBoolean(res, PropertyNames.SOURCE_PRODUCES_SORTED_KEYS, metadata.getProducesSortedKeys()); + } + if (metadata.getEstimatedSizeBytes() != null) { + addLong(res, PropertyNames.SOURCE_ESTIMATED_SIZE_BYTES, metadata.getEstimatedSizeBytes()); + } + if (metadata.getInfinite() != null) { + addBoolean(res, PropertyNames.SOURCE_IS_INFINITE, metadata.getInfinite()); + } + return res; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/package-info.java new file mode 100644 index 000000000000..b6b2ce690165 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Implementation of the {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner}. + */ +package com.google.cloud.dataflow.sdk.runners.dataflow; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java new file mode 100644 index 000000000000..1c0279897aac --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.io.Read.Bounded; +import com.google.cloud.dataflow.sdk.io.Source.Reader; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.io.IOException; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nullable; + +/** + * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} + * for the {@link Bounded Read.Bounded} primitive {@link PTransform}. + */ +final class BoundedReadEvaluatorFactory implements TransformEvaluatorFactory { + /* + * An evaluator for a Source is stateful, to ensure data is not read multiple times. + * Evaluators are cached here to ensure that the reader is not restarted if the evaluator is + * retriggered. + */ + private final ConcurrentMap>> + sourceEvaluators = new ConcurrentHashMap<>(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + @Nullable CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) + throws IOException { + return getTransformEvaluator((AppliedPTransform) application, evaluationContext); + } + + private TransformEvaluator getTransformEvaluator( + final AppliedPTransform, Bounded> transform, + final InProcessEvaluationContext evaluationContext) + throws IOException { + BoundedReadEvaluator evaluator = + getTransformEvaluatorQueue(transform, evaluationContext).poll(); + if (evaluator == null) { + return EmptyTransformEvaluator.create(transform); + } + return evaluator; + } + + /** + * Get the queue of {@link TransformEvaluator TransformEvaluators} that produce elements for the + * provided application of {@link Bounded Read.Bounded}, initializing it if required. + * + *
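+ * Queues are keyed by the (transform application, evaluation context) pair via an
+ * {@link EvaluatorKey}, so evaluators are never shared across separate pipeline executions.
+ *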

    This method is thread-safe, and will only produce new evaluators if no other invocation has + * already done so. + */ + @SuppressWarnings("unchecked") + private Queue> getTransformEvaluatorQueue( + final AppliedPTransform, Bounded> transform, + final InProcessEvaluationContext evaluationContext) + throws IOException { + // Key by the application and the context the evaluation is occurring in (which call to + // Pipeline#run). + EvaluatorKey key = new EvaluatorKey(transform, evaluationContext); + Queue> evaluatorQueue = + (Queue>) sourceEvaluators.get(key); + if (evaluatorQueue == null) { + evaluatorQueue = new ConcurrentLinkedQueue<>(); + if (sourceEvaluators.putIfAbsent(key, evaluatorQueue) == null) { + // If no queue existed in the evaluators, add an evaluator to initialize the evaluator + // factory for this transform + BoundedReadEvaluator evaluator = + new BoundedReadEvaluator(transform, evaluationContext); + evaluatorQueue.offer(evaluator); + } else { + // otherwise return the existing Queue that arrived before us + evaluatorQueue = (Queue>) sourceEvaluators.get(key); + } + } + return evaluatorQueue; + } + + private static class BoundedReadEvaluator implements TransformEvaluator { + private final AppliedPTransform, Bounded> transform; + private final InProcessEvaluationContext evaluationContext; + private final Reader reader; + private boolean contentsRemaining; + + public BoundedReadEvaluator( + AppliedPTransform, Bounded> transform, + InProcessEvaluationContext evaluationContext) + throws IOException { + this.transform = transform; + this.evaluationContext = evaluationContext; + reader = + transform.getTransform().getSource().createReader(evaluationContext.getPipelineOptions()); + contentsRemaining = reader.start(); + } + + @Override + public void processElement(WindowedValue element) {} + + @Override + public InProcessTransformResult finishBundle() throws IOException { + UncommittedBundle output = evaluationContext.createRootBundle(transform.getOutput()); + while (contentsRemaining) { + output.add( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp())); + contentsRemaining = reader.advance(); + } + return StepTransformResult + .withHold(transform, BoundedWindow.TIMESTAMP_MAX_VALUE) + .addOutput(output) + .build(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/Clock.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/Clock.java new file mode 100644 index 000000000000..11e6ec1686d9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/Clock.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import org.joda.time.Instant; + +/** + * Access to the current time. + */ +public interface Clock { + /** + * Returns the current time as an {@link Instant}. 
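+ *
+ * <p>For instance, a clock backed by system time (an illustrative sketch, not part of this
+ * change) can be written as:
+ * <pre>{@code
+ * Clock systemClock = new Clock() {
+ *   @Override
+ *   public Instant now() {
+ *     return Instant.now();
+ *   }
+ * };
+ * }</pre>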
+ */ + Instant now(); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EmptyTransformEvaluator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EmptyTransformEvaluator.java new file mode 100644 index 000000000000..fc092377942d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EmptyTransformEvaluator.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +/** + * A {@link TransformEvaluator} that ignores all input and produces no output. The result of + * invoking {@link #finishBundle()} on this evaluator is to return an + * {@link InProcessTransformResult} with no elements and a timestamp hold equal to + * {@link BoundedWindow#TIMESTAMP_MIN_VALUE}. Because the result contains no elements, this hold + * will not affect the watermark. + */ +final class EmptyTransformEvaluator implements TransformEvaluator { + public static TransformEvaluator create(AppliedPTransform transform) { + return new EmptyTransformEvaluator(transform); + } + + private final AppliedPTransform transform; + + private EmptyTransformEvaluator(AppliedPTransform transform) { + this.transform = transform; + } + + @Override + public void processElement(WindowedValue element) throws Exception {} + + @Override + public InProcessTransformResult finishBundle() throws Exception { + return StepTransformResult.withHold(transform, BoundedWindow.TIMESTAMP_MIN_VALUE) + .build(); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java new file mode 100644 index 000000000000..745f8f2718a3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; + +import java.util.Objects; + +/** + * A (Transform, Pipeline Execution) key for stateful evaluators. + * + * Source evaluators are stateful to ensure data is not read multiple times. Evaluators are cached + * to ensure that the reader is not restarted if the evaluator is retriggered. An + * {@link EvaluatorKey} is used to ensure that multiple Pipelines can be executed without sharing + * the same evaluators. + */ +final class EvaluatorKey { + private final AppliedPTransform transform; + private final InProcessEvaluationContext context; + + public EvaluatorKey(AppliedPTransform transform, InProcessEvaluationContext context) { + this.transform = transform; + this.context = context; + } + + @Override + public int hashCode() { + return Objects.hash(transform, context); + } + + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof EvaluatorKey)) { + return false; + } + EvaluatorKey that = (EvaluatorKey) other; + return Objects.equals(this.transform, that.transform) + && Objects.equals(this.context, that.context); + } +} + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java new file mode 100644 index 000000000000..14428888e2b5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.Flatten.FlattenPCollectionList; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +/** + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the {@link Flatten} + * {@link PTransform}. 
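+ *
+ * <p>When invoked with a {@code null} input bundle (a {@link Flatten} of an empty
+ * {@link PCollectionList}), the returned evaluator produces a result with no output bundles.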
+ */ +class FlattenEvaluatorFactory implements TransformEvaluatorFactory { + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + return createInMemoryEvaluator((AppliedPTransform) application, inputBundle, evaluationContext); + } + + private TransformEvaluator createInMemoryEvaluator( + final AppliedPTransform< + PCollectionList, PCollection, FlattenPCollectionList> + application, + final CommittedBundle inputBundle, + final InProcessEvaluationContext evaluationContext) { + if (inputBundle == null) { + // it is impossible to call processElement on a flatten with no input bundle. A Flatten with + // no input bundle occurs as an output of Flatten.pcollections(PCollectionList.empty()) + return new FlattenEvaluator<>( + null, StepTransformResult.withoutHold(application).build()); + } + final UncommittedBundle outputBundle = + evaluationContext.createBundle(inputBundle, application.getOutput()); + final InProcessTransformResult result = + StepTransformResult.withoutHold(application).addOutput(outputBundle).build(); + return new FlattenEvaluator<>(outputBundle, result); + } + + private static class FlattenEvaluator implements TransformEvaluator { + private final UncommittedBundle outputBundle; + private final InProcessTransformResult result; + + public FlattenEvaluator( + UncommittedBundle outputBundle, InProcessTransformResult result) { + this.outputBundle = outputBundle; + this.result = result; + } + + @Override + public void processElement(WindowedValue element) { + outputBundle.add(element); + } + + @Override + public InProcessTransformResult finishBundle() { + return result; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ForwardingPTransform.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ForwardingPTransform.java new file mode 100644 index 000000000000..b736e35d3128 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ForwardingPTransform.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.TypedPValue; + +/** + * A base class for implementing {@link PTransform} overrides, which behave identically to the + * delegate transform but with overridden methods. Implementors are required to implement + * {@link #delegate()}, which returns the object to forward calls to, and {@link #apply(PInput)}. 
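+ *
+ * <p>A minimal sketch of a subclass; the name and type parameters are illustrative only:
+ * <pre>{@code
+ * class NamedOverride<T> extends ForwardingPTransform<PCollection<T>, PCollection<T>> {
+ *   private final PTransform<PCollection<T>, PCollection<T>> original;
+ *
+ *   NamedOverride(PTransform<PCollection<T>, PCollection<T>> original) {
+ *     this.original = original;
+ *   }
+ *
+ *   @Override
+ *   protected PTransform<PCollection<T>, PCollection<T>> delegate() {
+ *     return original;
+ *   }
+ *
+ *   @Override
+ *   public PCollection<T> apply(PCollection<T> input) {
+ *     return delegate().apply(input);
+ *   }
+ * }
+ * }</pre>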
+ */ +public abstract class ForwardingPTransform + extends PTransform { + protected abstract PTransform delegate(); + + @Override + public OutputT apply(InputT input) { + return delegate().apply(input); + } + + @Override + public void validate(InputT input) { + delegate().validate(input); + } + + @Override + public String getName() { + return delegate().getName(); + } + + @Override + public Coder getDefaultOutputCoder(InputT input, @SuppressWarnings("unused") + TypedPValue output) throws CannotProvideCoderException { + return delegate().getDefaultOutputCoder(input, output); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java new file mode 100644 index 000000000000..0347281749cb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java @@ -0,0 +1,252 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.StepTransformResult.Builder; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey.ReifyTimestampsAndWindows; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowViaWindowSetDoFn; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItem; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItemCoder; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItems; +import com.google.cloud.dataflow.sdk.util.SystemReduceFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.annotations.VisibleForTesting; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + 
+/** + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the {@link GroupByKey} + * {@link PTransform}. + */ +class GroupByKeyEvaluatorFactory implements TransformEvaluatorFactory { + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public TransformEvaluator forApplication( + AppliedPTransform application, + CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + return createEvaluator( + (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); + } + + private TransformEvaluator>> createEvaluator( + final AppliedPTransform< + PCollection>>, PCollection>, + InProcessGroupByKeyOnly> + application, + final CommittedBundle> inputBundle, + final InProcessEvaluationContext evaluationContext) { + return new GroupByKeyEvaluator(evaluationContext, inputBundle, application); + } + + private static class GroupByKeyEvaluator + implements TransformEvaluator>> { + private final InProcessEvaluationContext evaluationContext; + + private final CommittedBundle> inputBundle; + private final AppliedPTransform< + PCollection>>, PCollection>, + InProcessGroupByKeyOnly> + application; + private final Coder keyCoder; + private Map, List>> groupingMap; + + public GroupByKeyEvaluator( + InProcessEvaluationContext evaluationContext, + CommittedBundle> inputBundle, + AppliedPTransform< + PCollection>>, PCollection>, + InProcessGroupByKeyOnly> + application) { + this.evaluationContext = evaluationContext; + this.inputBundle = inputBundle; + this.application = application; + + PCollection>> input = application.getInput(); + keyCoder = getKeyCoder(input.getCoder()); + groupingMap = new HashMap<>(); + } + + private Coder getKeyCoder(Coder>> coder) { + if (!(coder instanceof KvCoder)) { + throw new IllegalStateException(); + } + @SuppressWarnings("unchecked") + Coder keyCoder = ((KvCoder>) coder).getKeyCoder(); + return keyCoder; + } + + @Override + public void processElement(WindowedValue>> element) { + KV> kv = element.getValue(); + K key = kv.getKey(); + byte[] encodedKey; + try { + encodedKey = encodeToByteArray(keyCoder, key); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. 
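+ // Grouping is performed on the encoded form of the key, so a key that cannot be encoded
+ // cannot be grouped.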
+ throw new IllegalArgumentException( + String.format("unable to encode key %s of input to %s using %s", key, this, keyCoder), + exn); + } + GroupingKey groupingKey = new GroupingKey<>(key, encodedKey); + List> values = groupingMap.get(groupingKey); + if (values == null) { + values = new ArrayList>(); + groupingMap.put(groupingKey, values); + } + values.add(kv.getValue()); + } + + @Override + public InProcessTransformResult finishBundle() { + Builder resultBuilder = StepTransformResult.withoutHold(application); + for (Map.Entry, List>> groupedEntry : + groupingMap.entrySet()) { + K key = groupedEntry.getKey().key; + KeyedWorkItem groupedKv = + KeyedWorkItems.elementsWorkItem(key, groupedEntry.getValue()); + UncommittedBundle> bundle = + evaluationContext.createKeyedBundle(inputBundle, key, application.getOutput()); + bundle.add(WindowedValue.valueInEmptyWindows(groupedKv)); + resultBuilder.addOutput(bundle); + } + return resultBuilder.build(); + } + + private static class GroupingKey { + private K key; + private byte[] encodedKey; + + public GroupingKey(K key, byte[] encodedKey) { + this.key = key; + this.encodedKey = encodedKey; + } + + @Override + public boolean equals(Object o) { + if (o instanceof GroupingKey) { + GroupingKey that = (GroupingKey) o; + return Arrays.equals(this.encodedKey, that.encodedKey); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(encodedKey); + } + } + } + + /** + * An in-memory implementation of the {@link GroupByKey} primitive as a composite + * {@link PTransform}. + */ + public static final class InProcessGroupByKey + extends ForwardingPTransform>, PCollection>>> { + private final GroupByKey original; + + public InProcessGroupByKey(GroupByKey from) { + this.original = from; + } + + @Override + public PTransform>, PCollection>>> delegate() { + return original; + } + + @Override + public PCollection>> apply(PCollection> input) { + KvCoder inputCoder = (KvCoder) input.getCoder(); + + // This operation groups by the combination of key and window, + // merging windows as needed, using the windows assigned to the + // key/value input elements and the window merge operation of the + // window function associated with the input PCollection. + WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + + // Use the default GroupAlsoByWindow implementation + DoFn, KV>> groupAlsoByWindow = + groupAlsoByWindow(windowingStrategy, inputCoder.getValueCoder()); + + // By default, implement GroupByKey via a series of lower-level operations. + return input + // Make each input element's timestamp and assigned windows + // explicit, in the value part. + .apply(new ReifyTimestampsAndWindows()) + + .apply(new InProcessGroupByKeyOnly()) + .setCoder(KeyedWorkItemCoder.of(inputCoder.getKeyCoder(), + inputCoder.getValueCoder(), input.getWindowingStrategy().getWindowFn().windowCoder())) + + // Group each key's values by window, merging windows as needed. + .apply("GroupAlsoByWindow", ParDo.of(groupAlsoByWindow)) + + // And update the windowing strategy as appropriate. 
+ .setWindowingStrategyInternal(original.updateWindowingStrategy(windowingStrategy)) + .setCoder( + KvCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getValueCoder()))); + } + + private + DoFn, KV>> groupAlsoByWindow( + final WindowingStrategy windowingStrategy, final Coder inputCoder) { + return GroupAlsoByWindowViaWindowSetDoFn.create( + windowingStrategy, SystemReduceFn.buffering(inputCoder)); + } + } + + /** + * An implementation primitive to use in the evaluation of a {@link GroupByKey} + * {@link PTransform}. + */ + public static final class InProcessGroupByKeyOnly + extends PTransform>>, PCollection>> { + @Override + public PCollection> apply(PCollection>> input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()); + } + + @VisibleForTesting + InProcessGroupByKeyOnly() {} + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java new file mode 100644 index 000000000000..e280e22d2bb9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java @@ -0,0 +1,1316 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ComparisonChain; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Ordering; +import com.google.common.collect.SortedMultiset; +import com.google.common.collect.TreeMultiset; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NavigableSet; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.atomic.AtomicReference; + +import javax.annotation.Nullable; + +/** + * Manages watermarks of {@link PCollection PCollections} and input and output watermarks of + * {@link AppliedPTransform AppliedPTransforms} to provide event-time and completion tracking for + * in-memory execution. {@link InMemoryWatermarkManager} is designed to update and return a + * consistent view of watermarks in the presence of concurrent updates. + * + *

    An {@link InMemoryWatermarkManager} is provided with the collection of root + * {@link AppliedPTransform AppliedPTransforms} and a map of {@link PCollection PCollections} to + * all the {@link AppliedPTransform AppliedPTransforms} that consume them at construction time. + * + *

    Whenever a root {@link AppliedPTransform transform} produces elements, the + * {@link InMemoryWatermarkManager} is provided with the produced elements and the output watermark + * of the producing {@link AppliedPTransform transform}. The + * {@link InMemoryWatermarkManager watermark manager} is responsible for computing the watermarks + * of all {@link AppliedPTransform transforms} that consume one or more + * {@link PCollection PCollections}. + * + *

Whenever a non-root {@link AppliedPTransform} finishes processing one or more in-flight
+ * elements (referred to as the input {@link CommittedBundle bundle}), the following occurs
+ * atomically:
+ * <ul>
+ * <li>All of the in-flight elements are removed from the collection of pending elements for the
+ * {@link AppliedPTransform}.</li>
+ * <li>All of the elements produced by the {@link AppliedPTransform} are added to the collection
+ * of pending elements for each {@link AppliedPTransform} that consumes them.</li>
+ * <li>The input watermark for the {@link AppliedPTransform} becomes the maximum value of
+ * <ul>
+ * <li>the previous input watermark</li>
+ * <li>the minimum of
+ * <ul>
+ * <li>the timestamps of all currently pending elements</li>
+ * <li>all input {@link PCollection} watermarks</li>
+ * </ul>
+ * </li>
+ * </ul>
+ * </li>
+ * <li>The output watermark for the {@link AppliedPTransform} becomes the maximum of
+ * <ul>
+ * <li>the previous output watermark</li>
+ * <li>the minimum of
+ * <ul>
+ * <li>the current input watermark</li>
+ * <li>the current watermark holds</li>
+ * </ul>
+ * </li>
+ * </ul>
+ * </li>
+ * <li>The watermark of the output {@link PCollection} can be advanced to the output watermark of
+ * the {@link AppliedPTransform}</li>
+ * <li>The watermark of all downstream {@link AppliedPTransform AppliedPTransforms} can be
+ * advanced.</li>
+ * </ul>
+ *

    The watermark of a {@link PCollection} is equal to the output watermark of the + * {@link AppliedPTransform} that produces it. + * + *

    The watermarks for a {@link PTransform} are updated as follows when output is committed:

+ * <pre>
+ * Watermark_In'  = MAX(Watermark_In, MIN(U(TS_Pending), U(Watermark_InputPCollection)))
+ * Watermark_Out' = MAX(Watermark_Out, MIN(Watermark_In', U(StateHold)))
+ * Watermark_PCollection = Watermark_Out_ProducingPTransform
+ * </pre>
+ *
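+ * <p>As a purely illustrative example of the arithmetic above: if the previous input watermark
+ * is 00:10, the earliest pending element has timestamp 00:12, and the upstream
+ * {@link PCollection} watermark is 00:15, then Watermark_In' = MAX(00:10, MIN(00:12, 00:15)) =
+ * 00:12; if the previous output watermark is 00:09 and the earliest state hold is 00:11, then
+ * Watermark_Out' = MAX(00:09, MIN(00:12, 00:11)) = 00:11.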
    + */ +public class InMemoryWatermarkManager { + /** + * The watermark of some {@link Pipeline} element, usually a {@link PTransform} or a + * {@link PCollection}. + * + *

    A watermark is a monotonically increasing value, which represents the point up to which the + * system believes it has received all of the data. Data that arrives with a timestamp that is + * before the watermark is considered late. {@link BoundedWindow#TIMESTAMP_MAX_VALUE} is a special + * timestamp which indicates we have received all of the data and there will be no more on-time or + * late data. This value is represented by {@link InMemoryWatermarkManager#THE_END_OF_TIME}. + */ + private static interface Watermark { + /** + * Returns the current value of this watermark. + */ + Instant get(); + + /** + * Refreshes the value of this watermark from its input watermarks and watermark holds. + * + * @return true if the value of the watermark has changed (and thus dependent watermark must + * also be updated + */ + WatermarkUpdate refresh(); + } + + /** + * The result of computing a {@link Watermark}. + */ + private static enum WatermarkUpdate { + /** The watermark is later than the value at the previous time it was computed. */ + ADVANCED(true), + /** The watermark is equal to the value at the previous time it was computed. */ + NO_CHANGE(false); + + private final boolean advanced; + + private WatermarkUpdate(boolean advanced) { + this.advanced = advanced; + } + + public boolean isAdvanced() { + return advanced; + } + + /** + * Returns the {@link WatermarkUpdate} that is a result of combining the two watermark updates. + * + * If either of the input {@link WatermarkUpdate WatermarkUpdates} were advanced, the result + * {@link WatermarkUpdate} has been advanced. + */ + public WatermarkUpdate union(WatermarkUpdate that) { + if (this.advanced) { + return this; + } + return that; + } + + /** + * Returns the {@link WatermarkUpdate} based on the former and current + * {@link Instant timestamps}. + */ + public static WatermarkUpdate fromTimestamps(Instant oldTime, Instant currentTime) { + if (currentTime.isAfter(oldTime)) { + return ADVANCED; + } + return NO_CHANGE; + } + } + + /** + * The input {@link Watermark} of an {@link AppliedPTransform}. + * + *

    At any point, the value of an {@link AppliedPTransformInputWatermark} is equal to the + * minimum watermark across all of its input {@link Watermark Watermarks}, and the minimum + * timestamp of all of the pending elements, restricted to be monotonically increasing. + * + *

    See {@link #refresh()} for more information. + */ + private static class AppliedPTransformInputWatermark implements Watermark { + private final Collection inputWatermarks; + private final SortedMultiset> pendingElements; + private final Map> objectTimers; + + private AtomicReference currentWatermark; + + public AppliedPTransformInputWatermark(Collection inputWatermarks) { + this.inputWatermarks = inputWatermarks; + this.pendingElements = TreeMultiset.create(PENDING_ELEMENT_COMPARATOR); + this.objectTimers = new HashMap<>(); + currentWatermark = new AtomicReference<>(BoundedWindow.TIMESTAMP_MIN_VALUE); + } + + @Override + public Instant get() { + return currentWatermark.get(); + } + + /** + * {@inheritDoc}. + * + *

When refresh is called, the value of the {@link AppliedPTransformInputWatermark} becomes
+ * equal to the maximum value of
+ * <ul>
+ * <li>the previous input watermark</li>
+ * <li>the minimum of
+ * <ul>
+ * <li>the timestamps of all currently pending elements</li>
+ * <li>all input {@link PCollection} watermarks</li>
+ * </ul>
+ * </li>
+ * </ul>
    + */ + @Override + public synchronized WatermarkUpdate refresh() { + Instant oldWatermark = currentWatermark.get(); + Instant minInputWatermark = BoundedWindow.TIMESTAMP_MAX_VALUE; + for (Watermark inputWatermark : inputWatermarks) { + minInputWatermark = INSTANT_ORDERING.min(minInputWatermark, inputWatermark.get()); + } + if (!pendingElements.isEmpty()) { + minInputWatermark = INSTANT_ORDERING.min( + minInputWatermark, pendingElements.firstEntry().getElement().getTimestamp()); + } + Instant newWatermark = INSTANT_ORDERING.max(oldWatermark, minInputWatermark); + currentWatermark.set(newWatermark); + return WatermarkUpdate.fromTimestamps(oldWatermark, newWatermark); + } + + private synchronized void addPendingElements(Iterable> newPending) { + for (WindowedValue pendingElement : newPending) { + pendingElements.add(pendingElement); + } + } + + private synchronized void removePendingElements( + Iterable> finishedElements) { + for (WindowedValue finishedElement : finishedElements) { + pendingElements.remove(finishedElement); + } + } + + private synchronized void updateTimers(TimerUpdate update) { + NavigableSet keyTimers = objectTimers.get(update.key); + if (keyTimers == null) { + keyTimers = new TreeSet<>(); + objectTimers.put(update.key, keyTimers); + } + for (TimerData timer : update.setTimers) { + if (TimeDomain.EVENT_TIME.equals(timer.getDomain())) { + keyTimers.add(timer); + } + } + for (TimerData timer : update.deletedTimers) { + if (TimeDomain.EVENT_TIME.equals(timer.getDomain())) { + keyTimers.remove(timer); + } + } + // We don't keep references to timers that have been fired and delivered via #getFiredTimers() + } + + private synchronized Map> extractFiredEventTimeTimers() { + return extractFiredTimers(currentWatermark.get(), objectTimers); + } + + @Override + public synchronized String toString() { + return MoreObjects.toStringHelper(AppliedPTransformInputWatermark.class) + .add("pendingElements", pendingElements) + .add("currentWatermark", currentWatermark) + .toString(); + } + } + + /** + * The output {@link Watermark} of an {@link AppliedPTransform}. + * + *

    The value of an {@link AppliedPTransformOutputWatermark} is equal to the minimum of the + * current watermark hold and the {@link AppliedPTransformInputWatermark} for the same + * {@link AppliedPTransform}, restricted to be monotonically increasing. See + * {@link #refresh()} for more information. + */ + private static class AppliedPTransformOutputWatermark implements Watermark { + private final Watermark inputWatermark; + private final PerKeyHolds holds; + private AtomicReference currentWatermark; + + public AppliedPTransformOutputWatermark(AppliedPTransformInputWatermark inputWatermark) { + this.inputWatermark = inputWatermark; + holds = new PerKeyHolds(); + currentWatermark = new AtomicReference<>(BoundedWindow.TIMESTAMP_MIN_VALUE); + } + + public synchronized void updateHold(Object key, Instant newHold) { + if (newHold == null) { + holds.removeHold(key); + } else { + holds.updateHold(key, newHold); + } + } + + @Override + public Instant get() { + return currentWatermark.get(); + } + + /** + * {@inheritDoc}. + * + *

When refresh is called, the value of the {@link AppliedPTransformOutputWatermark} becomes
+ * equal to the maximum value of:
+ * <ul>
+ * <li>the previous output watermark</li>
+ * <li>the minimum of
+ * <ul>
+ * <li>the current input watermark</li>
+ * <li>the current watermark holds</li>
+ * </ul>
+ * </li>
+ * </ul>
    + */ + @Override + public synchronized WatermarkUpdate refresh() { + Instant oldWatermark = currentWatermark.get(); + Instant newWatermark = INSTANT_ORDERING.min(inputWatermark.get(), holds.getMinHold()); + newWatermark = INSTANT_ORDERING.max(oldWatermark, newWatermark); + currentWatermark.set(newWatermark); + return WatermarkUpdate.fromTimestamps(oldWatermark, newWatermark); + } + + @Override + public synchronized String toString() { + return MoreObjects.toStringHelper(AppliedPTransformOutputWatermark.class) + .add("holds", holds) + .add("currentWatermark", currentWatermark) + .toString(); + } + } + + /** + * The input {@link TimeDomain#SYNCHRONIZED_PROCESSING_TIME} hold for an + * {@link AppliedPTransform}. + * + *

At any point, the hold value of a {@link SynchronizedProcessingTimeInputWatermark} is equal
+ * to the minimum across all pending bundles at the {@link AppliedPTransform} and all upstream
+ * {@link TimeDomain#SYNCHRONIZED_PROCESSING_TIME} watermarks. The value of the input
+ * synchronized processing time at any step is equal to the maximum of:
+ * <ul>
+ * <li>The most recently returned synchronized processing input time</li>
+ * <li>The minimum of
+ * <ul>
+ * <li>The current processing time</li>
+ * <li>The current synchronized processing time input hold</li>
+ * </ul>
+ * </li>
+ * </ul>
    + */ + private static class SynchronizedProcessingTimeInputWatermark implements Watermark { + private final Collection inputWms; + private final Collection> pendingBundles; + private final Map> processingTimers; + private final Map> synchronizedProcessingTimers; + + private final PriorityQueue pendingTimers; + + private AtomicReference earliestHold; + + public SynchronizedProcessingTimeInputWatermark(Collection inputWms) { + this.inputWms = inputWms; + this.pendingBundles = new HashSet<>(); + this.processingTimers = new HashMap<>(); + this.synchronizedProcessingTimers = new HashMap<>(); + this.pendingTimers = new PriorityQueue<>(); + Instant initialHold = BoundedWindow.TIMESTAMP_MAX_VALUE; + for (Watermark wm : inputWms) { + initialHold = INSTANT_ORDERING.min(initialHold, wm.get()); + } + earliestHold = new AtomicReference<>(initialHold); + } + + @Override + public Instant get() { + return earliestHold.get(); + } + + /** + * {@inheritDoc}. + * + *

When refresh is called, the value of the {@link SynchronizedProcessingTimeInputWatermark}
+ * becomes equal to the minimum value of
+ * <ul>
+ * <li>the timestamps of all currently pending bundles</li>
+ * <li>all input {@link PCollection} synchronized processing time watermarks</li>
+ * </ul>
+ *

    Note that this value is not monotonic, but the returned value for the synchronized + * processing time must be. + */ + @Override + public synchronized WatermarkUpdate refresh() { + Instant oldHold = earliestHold.get(); + Instant minTime = THE_END_OF_TIME.get(); + for (Watermark input : inputWms) { + minTime = INSTANT_ORDERING.min(minTime, input.get()); + } + for (CommittedBundle bundle : pendingBundles) { + // TODO: Track elements in the bundle by the processing time they were output instead of + // entire bundles. Requried to support arbitrarily splitting and merging bundles between + // steps + minTime = INSTANT_ORDERING.min(minTime, bundle.getSynchronizedProcessingOutputWatermark()); + } + earliestHold.set(minTime); + return WatermarkUpdate.fromTimestamps(oldHold, minTime); + } + + public synchronized void addPending(CommittedBundle bundle) { + pendingBundles.add(bundle); + } + + public synchronized void removePending(CommittedBundle bundle) { + pendingBundles.remove(bundle); + } + + /** + * Return the earliest timestamp of the earliest timer that has not been completed. This is + * either the earliest timestamp across timers that have not been completed, or the earliest + * timestamp across timers that have been delivered but have not been completed. + */ + public synchronized Instant getEarliestTimerTimestamp() { + Instant earliest = THE_END_OF_TIME.get(); + for (NavigableSet timers : processingTimers.values()) { + if (!timers.isEmpty()) { + earliest = INSTANT_ORDERING.min(timers.first().getTimestamp(), earliest); + } + } + for (NavigableSet timers : synchronizedProcessingTimers.values()) { + if (!timers.isEmpty()) { + earliest = INSTANT_ORDERING.min(timers.first().getTimestamp(), earliest); + } + } + if (!pendingTimers.isEmpty()) { + earliest = INSTANT_ORDERING.min(pendingTimers.peek().getTimestamp(), earliest); + } + return earliest; + } + + private synchronized void updateTimers(TimerUpdate update) { + for (TimerData completedTimer : update.completedTimers) { + pendingTimers.remove(completedTimer); + } + Map> timerMap = timerMap(update.key); + for (TimerData addedTimer : update.setTimers) { + NavigableSet timerQueue = timerMap.get(addedTimer.getDomain()); + if (timerQueue != null) { + timerQueue.add(addedTimer); + } + } + for (TimerData deletedTimer : update.deletedTimers) { + NavigableSet timerQueue = timerMap.get(deletedTimer.getDomain()); + if (timerQueue != null) { + timerQueue.remove(deletedTimer); + } + } + } + + private synchronized Map> extractFiredDomainTimers( + TimeDomain domain, Instant firingTime) { + Map> firedTimers; + switch (domain) { + case PROCESSING_TIME: + firedTimers = extractFiredTimers(firingTime, processingTimers); + break; + case SYNCHRONIZED_PROCESSING_TIME: + firedTimers = + extractFiredTimers( + INSTANT_ORDERING.min(firingTime, earliestHold.get()), + synchronizedProcessingTimers); + break; + default: + throw new IllegalArgumentException( + "Called getFiredTimers on a Synchronized Processing Time watermark" + + " and gave a non-processing time domain " + + domain); + } + for (Map.Entry> firedTimer : firedTimers.entrySet()) { + pendingTimers.addAll(firedTimer.getValue()); + } + return firedTimers; + } + + private Map> timerMap(Object key) { + NavigableSet processingQueue = processingTimers.get(key); + if (processingQueue == null) { + processingQueue = new TreeSet<>(); + processingTimers.put(key, processingQueue); + } + NavigableSet synchronizedProcessingQueue = + synchronizedProcessingTimers.get(key); + if (synchronizedProcessingQueue == null) { + 
synchronizedProcessingQueue = new TreeSet<>(); + synchronizedProcessingTimers.put(key, synchronizedProcessingQueue); + } + EnumMap> result = new EnumMap<>(TimeDomain.class); + result.put(TimeDomain.PROCESSING_TIME, processingQueue); + result.put(TimeDomain.SYNCHRONIZED_PROCESSING_TIME, synchronizedProcessingQueue); + return result; + } + + @Override + public synchronized String toString() { + return MoreObjects.toStringHelper(SynchronizedProcessingTimeInputWatermark.class) + .add("earliestHold", earliestHold) + .toString(); + } + } + + /** + * The output {@link TimeDomain#SYNCHRONIZED_PROCESSING_TIME} hold for an + * {@link AppliedPTransform}. + * + *

    At any point, the hold value of a {@link SynchronizedProcessingTimeOutputWatermark} is
 + * equal to the minimum across all incomplete timers at the {@link AppliedPTransform} and all
 + * upstream {@link TimeDomain#SYNCHRONIZED_PROCESSING_TIME} watermarks. The value of the output
 + * synchronized processing time at any step is equal to the maximum of the following
 + * (an illustrative sketch appears after this list):
 + *   • The most recently returned synchronized processing output time
 + *   • The minimum of:
 + *       • The current processing time
 + *       • The current synchronized processing time output hold
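A minimal sketch of the maximum/minimum rule above, using plain joda-time Instants. The class and method names are hypothetical and this is not SDK code; it only illustrates how the value handed to callers can stay monotonic even though the underlying hold is not.

    import org.joda.time.Instant;

    // Illustrative only: returns max(latestReturned, min(currentProcessingTime, currentHold)),
    // so the reported synchronized processing time never moves backwards.
    class MonotonicProcessingTime {
      private Instant latestReturned = new Instant(Long.MIN_VALUE);

      synchronized Instant get(Instant currentProcessingTime, Instant currentHold) {
        Instant candidate =
            currentProcessingTime.isBefore(currentHold) ? currentProcessingTime : currentHold;
        if (candidate.isAfter(latestReturned)) {
          latestReturned = candidate;
        }
        return latestReturned;
      }
    }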
    + */ + private static class SynchronizedProcessingTimeOutputWatermark implements Watermark { + private final SynchronizedProcessingTimeInputWatermark inputWm; + private AtomicReference latestRefresh; + + public SynchronizedProcessingTimeOutputWatermark( + SynchronizedProcessingTimeInputWatermark inputWm) { + this.inputWm = inputWm; + this.latestRefresh = new AtomicReference<>(BoundedWindow.TIMESTAMP_MIN_VALUE); + } + + @Override + public Instant get() { + return latestRefresh.get(); + } + + /** + * {@inheritDoc}. + * + *

    When refresh is called, the value of the {@link SynchronizedProcessingTimeOutputWatermark}
 + * becomes equal to the minimum value of the following (a small sketch appears after this list):
 + *   • the current input watermark
 + *   • all {@link TimeDomain#SYNCHRONIZED_PROCESSING_TIME} timers that are based on the input
 + *     watermark
 + *   • all {@link TimeDomain#PROCESSING_TIME} timers that are based on the input watermark
 + *
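A hedged illustration of the "minimum of" rule above. This is not SDK code: a NavigableSet of plain Instants stands in for the set of pending TimerData, and Long.MAX_VALUE stands in for the end-of-time watermark.

    import java.util.NavigableSet;
    import org.joda.time.Instant;

    class EarliestHoldSketch {
      // The refreshed value is the earliest of the input watermark and any pending timer.
      static Instant refreshedValue(Instant inputWatermark, NavigableSet<Instant> pendingTimerTimes) {
        Instant earliestTimer =
            pendingTimerTimes.isEmpty() ? new Instant(Long.MAX_VALUE) : pendingTimerTimes.first();
        return inputWatermark.isBefore(earliestTimer) ? inputWatermark : earliestTimer;
      }
    }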

    Note that this value is not monotonic, but the returned value for the synchronized + * processing time must be. + */ + @Override + public synchronized WatermarkUpdate refresh() { + // Hold the output synchronized processing time to the input watermark, which takes into + // account buffered bundles, and the earliest pending timer, which determines what to hold + // downstream timers to. + Instant oldRefresh = latestRefresh.get(); + Instant newTimestamp = + INSTANT_ORDERING.min(inputWm.get(), inputWm.getEarliestTimerTimestamp()); + latestRefresh.set(newTimestamp); + return WatermarkUpdate.fromTimestamps(oldRefresh, newTimestamp); + } + + @Override + public synchronized String toString() { + return MoreObjects.toStringHelper(SynchronizedProcessingTimeOutputWatermark.class) + .add("latestRefresh", latestRefresh) + .toString(); + } + } + + /** + * The {@code Watermark} that is after the latest time it is possible to represent in the global + * window. This is a distinguished value representing a complete {@link PTransform}. + */ + private static final Watermark THE_END_OF_TIME = new Watermark() { + @Override + public WatermarkUpdate refresh() { + // THE_END_OF_TIME is a distinguished value that cannot be advanced. + return WatermarkUpdate.NO_CHANGE; + } + + @Override + public Instant get() { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + }; + + private static final Ordering INSTANT_ORDERING = Ordering.natural(); + + /** + * An ordering that compares windowed values by timestamp, then arbitrarily. This ensures that + * {@link WindowedValue WindowedValues} will be sorted by timestamp, while two different + * {@link WindowedValue WindowedValues} with the same timestamp are not considered equal. + */ + private static final Ordering> PENDING_ELEMENT_COMPARATOR = + (new WindowedValueByTimestampComparator()).compound(Ordering.arbitrary()); + + /** + * For each (Object, PriorityQueue) pair in the provided map, remove each Timer that is before the + * latestTime argument and put in in the result with the same key, then remove all of the keys + * which have no more pending timers. + * + * The result collection retains ordering of timers (from earliest to latest). + */ + private static Map> extractFiredTimers( + Instant latestTime, Map> objectTimers) { + Map> result = new HashMap<>(); + Set emptyKeys = new HashSet<>(); + for (Map.Entry> pendingTimers : objectTimers.entrySet()) { + NavigableSet timers = pendingTimers.getValue(); + if (!timers.isEmpty() && timers.first().getTimestamp().isBefore(latestTime)) { + ArrayList keyFiredTimers = new ArrayList<>(); + result.put(pendingTimers.getKey(), keyFiredTimers); + while (!timers.isEmpty() && timers.first().getTimestamp().isBefore(latestTime)) { + keyFiredTimers.add(timers.first()); + timers.remove(timers.first()); + } + } + if (timers.isEmpty()) { + emptyKeys.add(pendingTimers.getKey()); + } + } + objectTimers.keySet().removeAll(emptyKeys); + return result; + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * The {@link Clock} providing the current time in the {@link TimeDomain#PROCESSING_TIME} domain. + */ + private final Clock clock; + + /** + * A map from each {@link PCollection} to all {@link AppliedPTransform PTransform applications} + * that consume that {@link PCollection}. + */ + private final Map>> consumers; + + /** + * The input and output watermark of each {@link AppliedPTransform}. 
+ */ + private final Map, TransformWatermarks> transformToWatermarks; + + /** + * Creates a new {@link InMemoryWatermarkManager}. All watermarks within the newly created + * {@link InMemoryWatermarkManager} start at {@link BoundedWindow#TIMESTAMP_MIN_VALUE}, the + * minimum watermark, with no watermark holds or pending elements. + * + * @param rootTransforms the root-level transforms of the {@link Pipeline} + * @param consumers a mapping between each {@link PCollection} in the {@link Pipeline} to the + * transforms that consume it as a part of their input + */ + public static InMemoryWatermarkManager create( + Clock clock, + Collection> rootTransforms, + Map>> consumers) { + return new InMemoryWatermarkManager(clock, rootTransforms, consumers); + } + + private InMemoryWatermarkManager( + Clock clock, + Collection> rootTransforms, + Map>> consumers) { + this.clock = clock; + this.consumers = consumers; + + transformToWatermarks = new HashMap<>(); + + for (AppliedPTransform rootTransform : rootTransforms) { + getTransformWatermark(rootTransform); + } + for (Collection> intermediateTransforms : consumers.values()) { + for (AppliedPTransform transform : intermediateTransforms) { + getTransformWatermark(transform); + } + } + } + + private TransformWatermarks getTransformWatermark(AppliedPTransform transform) { + TransformWatermarks wms = transformToWatermarks.get(transform); + if (wms == null) { + List inputCollectionWatermarks = getInputWatermarks(transform); + AppliedPTransformInputWatermark inputWatermark = + new AppliedPTransformInputWatermark(inputCollectionWatermarks); + AppliedPTransformOutputWatermark outputWatermark = + new AppliedPTransformOutputWatermark(inputWatermark); + + SynchronizedProcessingTimeInputWatermark inputProcessingWatermark = + new SynchronizedProcessingTimeInputWatermark(getInputProcessingWatermarks(transform)); + SynchronizedProcessingTimeOutputWatermark outputProcessingWatermark = + new SynchronizedProcessingTimeOutputWatermark(inputProcessingWatermark); + + wms = + new TransformWatermarks( + inputWatermark, outputWatermark, inputProcessingWatermark, outputProcessingWatermark); + transformToWatermarks.put(transform, wms); + } + return wms; + } + + private Collection getInputProcessingWatermarks( + AppliedPTransform transform) { + ImmutableList.Builder inputWmsBuilder = ImmutableList.builder(); + Collection inputs = transform.getInput().expand(); + if (inputs.isEmpty()) { + inputWmsBuilder.add(THE_END_OF_TIME); + } + for (PValue pvalue : inputs) { + Watermark producerOutputWatermark = + getTransformWatermark(pvalue.getProducingTransformInternal()) + .synchronizedProcessingOutputWatermark; + inputWmsBuilder.add(producerOutputWatermark); + } + return inputWmsBuilder.build(); + } + + private List getInputWatermarks(AppliedPTransform transform) { + ImmutableList.Builder inputWatermarksBuilder = ImmutableList.builder(); + Collection inputs = transform.getInput().expand(); + if (inputs.isEmpty()) { + inputWatermarksBuilder.add(THE_END_OF_TIME); + } + for (PValue pvalue : inputs) { + Watermark producerOutputWatermark = + getTransformWatermark(pvalue.getProducingTransformInternal()).outputWatermark; + inputWatermarksBuilder.add(producerOutputWatermark); + } + List inputCollectionWatermarks = inputWatermarksBuilder.build(); + return inputCollectionWatermarks; + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Gets the input and output watermarks for an {@link AppliedPTransform}. 
If the + * {@link AppliedPTransform PTransform} has not processed any elements, return a watermark of + * {@link BoundedWindow#TIMESTAMP_MIN_VALUE}. + * + * @return a snapshot of the input watermark and output watermark for the provided transform + */ + public TransformWatermarks getWatermarks(AppliedPTransform transform) { + return transformToWatermarks.get(transform); + } + + /** + * Updates the watermarks of a transform with one or more inputs. + * + *

    Each transform has two monotonically increasing watermarks: the input watermark, which can,
 + * at any time, be updated to equal:
 + *
 + *   MAX(CurrentInputWatermark, MIN(PendingElements, InputPCollectionWatermarks))
 + *
 + * and the output watermark, which can, at any time, be updated to equal:
 + *
 + *   MAX(CurrentOutputWatermark, MIN(InputWatermark, WatermarkHolds))
    . + * + * @param completed the input that has completed + * @param transform the transform that has completed processing the input + * @param outputs the bundles the transform has output + * @param earliestHold the earliest watermark hold in the transform's state. {@code null} if there + * is no hold + */ + public void updateWatermarks( + @Nullable CommittedBundle completed, + AppliedPTransform transform, + TimerUpdate timerUpdate, + Iterable> outputs, + @Nullable Instant earliestHold) { + updatePending(completed, transform, timerUpdate, outputs); + TransformWatermarks transformWms = transformToWatermarks.get(transform); + transformWms.setEventTimeHold(completed == null ? null : completed.getKey(), earliestHold); + refreshWatermarks(transform); + } + + private void refreshWatermarks(AppliedPTransform transform) { + TransformWatermarks myWatermarks = transformToWatermarks.get(transform); + WatermarkUpdate updateResult = myWatermarks.refresh(); + if (updateResult.isAdvanced()) { + for (PValue outputPValue : transform.getOutput().expand()) { + Collection> downstreamTransforms = consumers.get(outputPValue); + if (downstreamTransforms != null) { + for (AppliedPTransform downstreamTransform : downstreamTransforms) { + refreshWatermarks(downstreamTransform); + } + } + } + } + } + + /** + * Removes all of the completed Timers from the collection of pending timers, adds all new timers, + * and removes all deleted timers. Removes all elements consumed by the input bundle from the + * {@link PTransform PTransforms} collection of pending elements, and adds all elements produced + * by the {@link PTransform} to the pending queue of each consumer. + */ + private void updatePending( + CommittedBundle input, + AppliedPTransform transform, + TimerUpdate timerUpdate, + Iterable> outputs) { + TransformWatermarks completedTransform = transformToWatermarks.get(transform); + completedTransform.updateTimers(timerUpdate); + if (input != null) { + completedTransform.removePending(input); + } + + for (CommittedBundle bundle : outputs) { + for (AppliedPTransform consumer : consumers.get(bundle.getPCollection())) { + TransformWatermarks watermarks = transformToWatermarks.get(consumer); + watermarks.addPending(bundle); + } + } + } + + /** + * Returns a map of each {@link PTransform} that has pending timers to those timers. All of the + * pending timers will be removed from this {@link InMemoryWatermarkManager}. + */ + public Map, Map> extractFiredTimers() { + Map, Map> allTimers = new HashMap<>(); + for (Map.Entry, TransformWatermarks> watermarksEntry : + transformToWatermarks.entrySet()) { + Map keyFiredTimers = watermarksEntry.getValue().extractFiredTimers(); + if (!keyFiredTimers.isEmpty()) { + allTimers.put(watermarksEntry.getKey(), keyFiredTimers); + } + } + return allTimers; + } + + /** + * Returns true if, for any {@link TransformWatermarks} returned by + * {@link #getWatermarks(AppliedPTransform)}, the output watermark will be equal to + * {@link BoundedWindow#TIMESTAMP_MAX_VALUE}. + */ + public boolean isDone() { + for (Map.Entry, TransformWatermarks> watermarksEntry : + transformToWatermarks.entrySet()) { + Instant endOfTime = THE_END_OF_TIME.get(); + if (watermarksEntry.getValue().getOutputWatermark().isBefore(endOfTime)) { + return false; + } + } + return true; + } + + /** + * A (key, Instant) pair that holds the watermark. Holds are per-key, but the watermark is global, + * and as such the watermark manager must track holds and the release of holds on a per-key basis. + * + *

    The {@link #compareTo(KeyedHold)} method of {@link KeyedHold} is not consistent with equals, + * as the key is arbitrarily ordered via identity, rather than object equality. + */ + private static final class KeyedHold implements Comparable { + private static final Ordering KEY_ORDERING = Ordering.arbitrary().nullsLast(); + + private final Object key; + private final Instant timestamp; + + /** + * Create a new KeyedHold with the specified key and timestamp. + */ + public static KeyedHold of(Object key, Instant timestamp) { + return new KeyedHold(key, MoreObjects.firstNonNull(timestamp, THE_END_OF_TIME.get())); + } + + private KeyedHold(Object key, Instant timestamp) { + this.key = key; + this.timestamp = timestamp; + } + + @Override + public int compareTo(KeyedHold that) { + return ComparisonChain.start() + .compare(this.timestamp, that.timestamp) + .compare(this.key, that.key, KEY_ORDERING) + .result(); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, key); + } + + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof KeyedHold)) { + return false; + } + KeyedHold that = (KeyedHold) other; + return Objects.equals(this.timestamp, that.timestamp) && Objects.equals(this.key, that.key); + } + + /** + * Get the value of this {@link KeyedHold}. + */ + public Instant getTimestamp() { + return timestamp; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(KeyedHold.class) + .add("key", key) + .add("hold", timestamp) + .toString(); + } + } + + private static class PerKeyHolds { + private final Map keyedHolds; + private final PriorityQueue allHolds; + + private PerKeyHolds() { + this.keyedHolds = new HashMap<>(); + this.allHolds = new PriorityQueue<>(); + } + + /** + * Gets the minimum hold across all keys in this {@link PerKeyHolds}, or THE_END_OF_TIME if + * there are no holds within this {@link PerKeyHolds}. + */ + public Instant getMinHold() { + return allHolds.isEmpty() ? THE_END_OF_TIME.get() : allHolds.peek().getTimestamp(); + } + + /** + * Updates the hold of the provided key to the provided value, removing any other holds for + * the same key. + */ + public void updateHold(@Nullable Object key, Instant newHold) { + removeHold(key); + KeyedHold newKeyedHold = KeyedHold.of(key, newHold); + keyedHolds.put(key, newKeyedHold); + allHolds.offer(newKeyedHold); + } + + /** + * Removes the hold of the provided key. + */ + public void removeHold(Object key) { + KeyedHold oldHold = keyedHolds.get(key); + if (oldHold != null) { + allHolds.remove(oldHold); + } + } + } + + /** + * A reference to the input and output watermarks of an {@link AppliedPTransform}. 
+ */ + public class TransformWatermarks { + private final AppliedPTransformInputWatermark inputWatermark; + private final AppliedPTransformOutputWatermark outputWatermark; + + private final SynchronizedProcessingTimeInputWatermark synchronizedProcessingInputWatermark; + private final SynchronizedProcessingTimeOutputWatermark synchronizedProcessingOutputWatermark; + + private Instant latestSynchronizedInputWm; + private Instant latestSynchronizedOutputWm; + + private TransformWatermarks( + AppliedPTransformInputWatermark inputWatermark, + AppliedPTransformOutputWatermark outputWatermark, + SynchronizedProcessingTimeInputWatermark inputSynchProcessingWatermark, + SynchronizedProcessingTimeOutputWatermark outputSynchProcessingWatermark) { + this.inputWatermark = inputWatermark; + this.outputWatermark = outputWatermark; + + this.synchronizedProcessingInputWatermark = inputSynchProcessingWatermark; + this.synchronizedProcessingOutputWatermark = outputSynchProcessingWatermark; + this.latestSynchronizedInputWm = BoundedWindow.TIMESTAMP_MIN_VALUE; + this.latestSynchronizedOutputWm = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + /** + * Returns the input watermark of the {@link AppliedPTransform}. + */ + public Instant getInputWatermark() { + return inputWatermark.get(); + } + + /** + * Returns the output watermark of the {@link AppliedPTransform}. + */ + public Instant getOutputWatermark() { + return outputWatermark.get(); + } + + /** + * Returns the synchronized processing input time of the {@link AppliedPTransform}. + * + *

    The returned value is guaranteed to be monotonically increasing, and outside of the + * presence of holds, will increase as the system time progresses. + */ + public synchronized Instant getSynchronizedProcessingInputTime() { + latestSynchronizedInputWm = INSTANT_ORDERING.max( + latestSynchronizedInputWm, + INSTANT_ORDERING.min(clock.now(), synchronizedProcessingInputWatermark.get())); + return latestSynchronizedInputWm; + } + + /** + * Returns the synchronized processing output time of the {@link AppliedPTransform}. + * + *

    The returned value is guaranteed to be monotonically increasing, and outside of the + * presence of holds, will increase as the system time progresses. + */ + public synchronized Instant getSynchronizedProcessingOutputTime() { + latestSynchronizedOutputWm = INSTANT_ORDERING.max( + latestSynchronizedOutputWm, + INSTANT_ORDERING.min(clock.now(), synchronizedProcessingOutputWatermark.get())); + return latestSynchronizedOutputWm; + } + + private WatermarkUpdate refresh() { + inputWatermark.refresh(); + synchronizedProcessingInputWatermark.refresh(); + WatermarkUpdate eventOutputUpdate = outputWatermark.refresh(); + WatermarkUpdate syncOutputUpdate = synchronizedProcessingOutputWatermark.refresh(); + return eventOutputUpdate.union(syncOutputUpdate); + } + + private void setEventTimeHold(Object key, Instant newHold) { + outputWatermark.updateHold(key, newHold); + } + + private void removePending(CommittedBundle bundle) { + inputWatermark.removePendingElements(bundle.getElements()); + synchronizedProcessingInputWatermark.removePending(bundle); + } + + private void addPending(CommittedBundle bundle) { + inputWatermark.addPendingElements(bundle.getElements()); + synchronizedProcessingInputWatermark.addPending(bundle); + } + + private Map extractFiredTimers() { + Map> eventTimeTimers = inputWatermark.extractFiredEventTimeTimers(); + Map> processingTimers; + Map> synchronizedTimers; + if (inputWatermark.get().equals(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + processingTimers = synchronizedProcessingInputWatermark.extractFiredDomainTimers( + TimeDomain.PROCESSING_TIME, BoundedWindow.TIMESTAMP_MAX_VALUE); + synchronizedTimers = synchronizedProcessingInputWatermark.extractFiredDomainTimers( + TimeDomain.PROCESSING_TIME, BoundedWindow.TIMESTAMP_MAX_VALUE); + } else { + processingTimers = synchronizedProcessingInputWatermark.extractFiredDomainTimers( + TimeDomain.PROCESSING_TIME, clock.now()); + synchronizedTimers = synchronizedProcessingInputWatermark.extractFiredDomainTimers( + TimeDomain.SYNCHRONIZED_PROCESSING_TIME, getSynchronizedProcessingInputTime()); + } + Map>> groupedTimers = new HashMap<>(); + groupFiredTimers(groupedTimers, eventTimeTimers, processingTimers, synchronizedTimers); + + Map keyFiredTimers = new HashMap<>(); + for (Map.Entry>> firedTimers : + groupedTimers.entrySet()) { + keyFiredTimers.put(firedTimers.getKey(), new FiredTimers(firedTimers.getValue())); + } + return keyFiredTimers; + } + + @SafeVarargs + private final void groupFiredTimers( + Map>> groupedToMutate, + Map>... timersToGroup) { + for (Map> subGroup : timersToGroup) { + for (Map.Entry> newTimers : subGroup.entrySet()) { + Map> grouped = groupedToMutate.get(newTimers.getKey()); + if (grouped == null) { + grouped = new HashMap<>(); + groupedToMutate.put(newTimers.getKey(), grouped); + } + grouped.put(newTimers.getValue().get(0).getDomain(), newTimers.getValue()); + } + } + } + + private void updateTimers(TimerUpdate update) { + inputWatermark.updateTimers(update); + synchronizedProcessingInputWatermark.updateTimers(update); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(TransformWatermarks.class) + .add("inputWatermark", inputWatermark) + .add("outputWatermark", outputWatermark) + .add("inputProcessingTime", synchronizedProcessingInputWatermark) + .add("outputProcessingTime", synchronizedProcessingOutputWatermark) + .toString(); + } + } + + /** + * A collection of newly set, deleted, and completed timers. + * + *

    setTimers and deletedTimers are collections of {@link TimerData} that have been added to the + * {@link TimerInternals} of an executed step. completedTimers are timers that were delivered as + * the input to the executed step. + */ + public static class TimerUpdate { + private final Object key; + private final Iterable completedTimers; + + private final Iterable setTimers; + private final Iterable deletedTimers; + + /** + * Returns a TimerUpdate for a null key with no timers. + */ + public static TimerUpdate empty() { + return new TimerUpdate( + null, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList()); + } + + /** + * Creates a new {@link TimerUpdate} builder with the provided completed timers that needs the + * set and deleted timers to be added to it. + */ + public static TimerUpdateBuilder builder(Object key) { + return new TimerUpdateBuilder(key); + } + + /** + * A {@link TimerUpdate} builder that needs to be provided with set timers and deleted timers. + */ + public static final class TimerUpdateBuilder { + private final Object key; + private final Collection completedTimers; + private final Collection setTimers; + private final Collection deletedTimers; + + private TimerUpdateBuilder(Object key) { + this.key = key; + this.completedTimers = new HashSet<>(); + this.setTimers = new HashSet<>(); + this.deletedTimers = new HashSet<>(); + } + + /** + * Adds all of the provided timers to the collection of completed timers, and returns this + * {@link TimerUpdateBuilder}. + */ + public TimerUpdateBuilder withCompletedTimers(Iterable completedTimers) { + Iterables.addAll(this.completedTimers, completedTimers); + return this; + } + + /** + * Adds the provided timer to the collection of set timers, removing it from deleted timers if + * it has previously been deleted. Returns this {@link TimerUpdateBuilder}. + */ + public TimerUpdateBuilder setTimer(TimerData setTimer) { + deletedTimers.remove(setTimer); + setTimers.add(setTimer); + return this; + } + + /** + * Adds the provided timer to the collection of deleted timers, removing it from set timers if + * it has previously been set. Returns this {@link TimerUpdateBuilder}. + */ + public TimerUpdateBuilder deletedTimer(TimerData deletedTimer) { + deletedTimers.add(deletedTimer); + setTimers.remove(deletedTimer); + return this; + } + + /** + * Returns a new {@link TimerUpdate} with the most recently set completedTimers, setTimers, + * and deletedTimers. 
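For orientation, a minimal sketch of how this builder is meant to be driven after a step runs. It uses only the builder methods declared in this file; the surrounding class and the way the TimerData instances are obtained are illustrative assumptions, not part of this change.

    import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate;
    import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData;
    import java.util.Collections;

    class TimerUpdateSketch {
      // Records one timer that was delivered to the step and one timer the step newly set.
      static TimerUpdate afterStep(Object key, TimerData delivered, TimerData newlySet) {
        return TimerUpdate.builder(key)
            .withCompletedTimers(Collections.singleton(delivered))
            .setTimer(newlySet)
            .build();
      }
    }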
+ */ + public TimerUpdate build() { + return new TimerUpdate(key, ImmutableSet.copyOf(completedTimers), + ImmutableSet.copyOf(setTimers), ImmutableSet.copyOf(deletedTimers)); + } + } + + private TimerUpdate( + Object key, + Iterable completedTimers, + Iterable setTimers, + Iterable deletedTimers) { + this.key = key; + this.completedTimers = completedTimers; + this.setTimers = setTimers; + this.deletedTimers = deletedTimers; + } + + @VisibleForTesting + Object getKey() { + return key; + } + + @VisibleForTesting + Iterable getCompletedTimers() { + return completedTimers; + } + + @VisibleForTesting + Iterable getSetTimers() { + return setTimers; + } + + @VisibleForTesting + Iterable getDeletedTimers() { + return deletedTimers; + } + + @Override + public int hashCode() { + return Objects.hash(key, completedTimers, setTimers, deletedTimers); + } + + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof TimerUpdate)) { + return false; + } + TimerUpdate that = (TimerUpdate) other; + return Objects.equals(this.key, that.key) + && Objects.equals(this.completedTimers, that.completedTimers) + && Objects.equals(this.setTimers, that.setTimers) + && Objects.equals(this.deletedTimers, that.deletedTimers); + } + } + + /** + * A pair of {@link TimerData} and key which can be delivered to the appropriate + * {@link AppliedPTransform}. A timer fires at the transform that set it with a specific key when + * the time domain in which it lives progresses past a specified time, as determined by the + * {@link InMemoryWatermarkManager}. + */ + public static class FiredTimers { + private final Map> timers; + + private FiredTimers(Map> timers) { + this.timers = timers; + } + + /** + * Gets all of the timers that have fired within the provided {@link TimeDomain}. If no timers + * fired within the provided domain, return an empty collection. + * + *

    Timers within a {@link TimeDomain} are guaranteed to be in order of increasing timestamp. + */ + public Collection getTimers(TimeDomain domain) { + Collection domainTimers = timers.get(domain); + if (domainTimers == null) { + return Collections.emptyList(); + } + return domainTimers; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(FiredTimers.class).add("timers", timers).toString(); + } + } + + private static class WindowedValueByTimestampComparator extends Ordering> { + @Override + public int compare(WindowedValue o1, WindowedValue o2) { + return o1.getTimestamp().compareTo(o2.getTimestamp()); + } + } + + public Set> getCompletedTransforms() { + Set> result = new HashSet<>(); + for (Map.Entry, TransformWatermarks> wms : + transformToWatermarks.entrySet()) { + if (wms.getValue().getOutputWatermark().equals(THE_END_OF_TIME.get())) { + result.add(wms.getKey()); + } + } + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java new file mode 100644 index 000000000000..cc20161097e8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; + +import javax.annotation.Nullable; + +/** + * A {@link UncommittedBundle} that buffers elements in memory. + */ +public final class InProcessBundle implements UncommittedBundle { + private final PCollection pcollection; + private final boolean keyed; + private final Object key; + private boolean committed = false; + private ImmutableList.Builder> elements; + + /** + * Create a new {@link InProcessBundle} for the specified {@link PCollection} without a key. + */ + public static InProcessBundle unkeyed(PCollection pcollection) { + return new InProcessBundle(pcollection, false, null); + } + + /** + * Create a new {@link InProcessBundle} for the specified {@link PCollection} with the specified + * key. + * + * See {@link CommittedBundle#getKey()} and {@link CommittedBundle#isKeyed()} for more + * information. 
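A usage sketch of the bundle lifecycle described here: create an unkeyed bundle for a PCollection, add elements, and commit to obtain an immutable CommittedBundle. The wrapper class is hypothetical and WindowedValue.valueInGlobalWindow is used only for brevity; this is not part of the change.

    import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessBundle;
    import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle;
    import com.google.cloud.dataflow.sdk.util.WindowedValue;
    import com.google.cloud.dataflow.sdk.values.PCollection;
    import org.joda.time.Instant;

    class BundleSketch {
      // Buffers a single value into an unkeyed bundle and commits it at the given time.
      static CommittedBundle<String> commitOne(PCollection<String> pcollection, String value) {
        return InProcessBundle.unkeyed(pcollection)
            .add(WindowedValue.valueInGlobalWindow(value))
            .commit(Instant.now());
      }
    }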
+ */ + public static InProcessBundle keyed(PCollection pcollection, Object key) { + return new InProcessBundle(pcollection, true, key); + } + + private InProcessBundle(PCollection pcollection, boolean keyed, Object key) { + this.pcollection = pcollection; + this.keyed = keyed; + this.key = key; + this.elements = ImmutableList.builder(); + } + + @Override + public InProcessBundle add(WindowedValue element) { + checkState(!committed, "Can't add element %s to committed bundle %s", element, this); + elements.add(element); + return this; + } + + @Override + public CommittedBundle commit(final Instant synchronizedCompletionTime) { + checkState(!committed, "Can't commit already committed bundle %s", this); + committed = true; + final Iterable> committedElements = elements.build(); + return new CommittedBundle() { + @Override + @Nullable + public Object getKey() { + return key; + } + + @Override + public boolean isKeyed() { + return keyed; + } + + @Override + public Iterable> getElements() { + return committedElements; + } + + @Override + public PCollection getPCollection() { + return pcollection; + } + + @Override + public Instant getSynchronizedProcessingOutputWatermark() { + return synchronizedCompletionTime; + } + + @Override + public String toString() { + ToStringHelper toStringHelper = + MoreObjects.toStringHelper(this).add("pcollection", pcollection); + if (keyed) { + toStringHelper = toStringHelper.add("key", key); + } + return toStringHelper.add("elements", elements).toString(); + } + }; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundleOutputManager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundleOutputManager.java new file mode 100644 index 000000000000..406e2d46386d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundleOutputManager.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.util.DoFnRunners.OutputManager; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.Map; + +/** + * An {@link OutputManager} that outputs to {@link CommittedBundle Bundles} used by the + * {@link InProcessPipelineRunner}. 
+ */ +public class InProcessBundleOutputManager implements OutputManager { + private final Map, UncommittedBundle> bundles; + + public static InProcessBundleOutputManager create( + Map, UncommittedBundle> outputBundles) { + return new InProcessBundleOutputManager(outputBundles); + } + + public InProcessBundleOutputManager(Map, UncommittedBundle> bundles) { + this.bundles = bundles; + } + + @SuppressWarnings("unchecked") + @Override + public void output(TupleTag tag, WindowedValue output) { + @SuppressWarnings("rawtypes") + UncommittedBundle bundle = bundles.get(tag); + bundle.add(output); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessCreate.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessCreate.java new file mode 100644 index 000000000000..9023b7b2dc4b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessCreate.java @@ -0,0 +1,209 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Create.Values; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * An in-process implementation of the {@link Values Create.Values} {@link PTransform}, implemented + * using a {@link BoundedSource}. + * + * The coder is inferred via the {@link Values#getDefaultOutputCoder(PInput)} method on the original + * transform. 
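The Javadoc above notes that the coder is inferred from the original Create.Values; when inference is not possible, the transform asks the user to call Create.withCoder(). A minimal sketch of that explicit-coder path follows; the wrapper class and the choice of coder are illustrative assumptions.

    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import com.google.cloud.dataflow.sdk.transforms.Create;

    class ExplicitCoderSketch {
      // Builds a Create.Values whose output coder is fixed up front, so no inference is needed.
      static Create.Values<String> explicitlyCoded() {
        return Create.of("alpha", "beta").withCoder(StringUtf8Coder.of());
      }
    }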
+ */ +class InProcessCreate extends ForwardingPTransform> { + private final Create.Values original; + + public static InProcessCreate from(Create.Values original) { + return new InProcessCreate<>(original); + } + + private InProcessCreate(Values original) { + this.original = original; + } + + @Override + public PCollection apply(PInput input) { + Coder elementCoder; + try { + elementCoder = original.getDefaultOutputCoder(input); + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException( + "Unable to infer a coder and no Coder was specified. " + + "Please set a coder by invoking Create.withCoder() explicitly.", + e); + } + InMemorySource source; + try { + source = new InMemorySource<>(original.getElements(), elementCoder); + } catch (IOException e) { + throw Throwables.propagate(e); + } + PCollection result = input.getPipeline().apply(Read.from(source)); + result.setCoder(elementCoder); + return result; + } + + @Override + public PTransform> delegate() { + return original; + } + + @VisibleForTesting + static class InMemorySource extends BoundedSource { + private final Collection allElementsBytes; + private final long totalSize; + private final Coder coder; + + public InMemorySource(Iterable elements, Coder elemCoder) + throws CoderException, IOException { + allElementsBytes = new ArrayList<>(); + long totalSize = 0L; + for (T element : elements) { + byte[] bytes = CoderUtils.encodeToByteArray(elemCoder, element); + allElementsBytes.add(bytes); + totalSize += bytes.length; + } + this.totalSize = totalSize; + this.coder = elemCoder; + } + + /** + * Create a new source with the specified bytes. The new source owns the input element bytes, + * which must not be modified after this constructor is called. + */ + private InMemorySource(Collection elementBytes, long totalSize, Coder coder) { + this.allElementsBytes = ImmutableList.copyOf(elementBytes); + this.totalSize = totalSize; + this.coder = coder; + } + + @Override + public List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + ImmutableList.Builder> resultBuilder = ImmutableList.builder(); + long currentSourceSize = 0L; + List currentElems = new ArrayList<>(); + for (byte[] elemBytes : allElementsBytes) { + currentElems.add(elemBytes); + currentSourceSize += elemBytes.length; + if (currentSourceSize >= desiredBundleSizeBytes) { + resultBuilder.add(new InMemorySource<>(currentElems, currentSourceSize, coder)); + currentElems.clear(); + currentSourceSize = 0L; + } + } + if (!currentElems.isEmpty()) { + resultBuilder.add(new InMemorySource<>(currentElems, currentSourceSize, coder)); + } + return resultBuilder.build(); + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return totalSize; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public BoundedSource.BoundedReader createReader(PipelineOptions options) throws IOException { + return new BytesReader(); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + + private class BytesReader extends BoundedReader { + private final PeekingIterator iter; + /** + * Use an optional to distinguish between null next element (as Optional.absent()) and no next + * element (next is null). 
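The nullable-Optional pattern described above can be summarized in a standalone sketch (hypothetical class, not part of this change): a null field means the reader is exhausted, while Optional.absent() means the current element itself is null.

    import com.google.common.base.Optional;
    import javax.annotation.Nullable;

    class NullableCurrentSketch<T> {
      @Nullable private Optional<T> current;

      void markExhausted() { current = null; }
      void setCurrent(@Nullable T value) { current = Optional.fromNullable(value); }
      boolean hasCurrent() { return current != null; }
      @Nullable T getCurrent() { return current == null ? null : current.orNull(); }
    }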
+ */ + @Nullable private Optional next; + + public BytesReader() { + this.iter = Iterators.peekingIterator(allElementsBytes.iterator()); + } + + @Override + public BoundedSource getCurrentSource() { + return InMemorySource.this; + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + boolean hasNext = iter.hasNext(); + if (hasNext) { + next = Optional.fromNullable(CoderUtils.decodeFromByteArray(coder, iter.next())); + } else { + next = null; + } + return hasNext; + } + + @Override + @Nullable + public T getCurrent() throws NoSuchElementException { + if (next == null) { + throw new NoSuchElementException(); + } + return next.orNull(); + } + + @Override + public void close() throws IOException {} + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutionContext.java new file mode 100644 index 000000000000..43cd9eb573c6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutionContext.java @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TransformWatermarks; +import com.google.cloud.dataflow.sdk.util.BaseExecutionContext; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; + +/** + * Execution Context for the {@link InProcessPipelineRunner}. + * + * This implementation is not thread safe. A new {@link InProcessExecutionContext} must be created + * for each thread that requires it. + */ +class InProcessExecutionContext + extends BaseExecutionContext { + private final Clock clock; + private final Object key; + private final CopyOnAccessInMemoryStateInternals existingState; + private final TransformWatermarks watermarks; + + public InProcessExecutionContext(Clock clock, Object key, + CopyOnAccessInMemoryStateInternals existingState, TransformWatermarks watermarks) { + this.clock = clock; + this.key = key; + this.existingState = existingState; + this.watermarks = watermarks; + } + + @Override + protected InProcessStepContext createStepContext( + String stepName, String transformName, StateSampler stateSampler) { + return new InProcessStepContext(this, stepName, transformName); + } + + /** + * Step Context for the {@link InProcessPipelineRunner}. 
+ */ + public class InProcessStepContext + extends com.google.cloud.dataflow.sdk.util.BaseExecutionContext.StepContext { + private CopyOnAccessInMemoryStateInternals stateInternals; + private InProcessTimerInternals timerInternals; + + public InProcessStepContext( + ExecutionContext executionContext, String stepName, String transformName) { + super(executionContext, stepName, transformName); + } + + @Override + public CopyOnAccessInMemoryStateInternals stateInternals() { + if (stateInternals == null) { + stateInternals = CopyOnAccessInMemoryStateInternals.withUnderlying(key, existingState); + } + return stateInternals; + } + + @Override + public InProcessTimerInternals timerInternals() { + if (timerInternals == null) { + timerInternals = + InProcessTimerInternals.create(clock, watermarks, TimerUpdate.builder(key)); + } + return timerInternals; + } + + /** + * Commits the state of this step, and returns the committed state. If the step has not + * accessed any state, return null. + */ + public CopyOnAccessInMemoryStateInternals commitState() { + if (stateInternals != null) { + return stateInternals.commit(); + } + return null; + } + + /** + * Gets the timer update of the {@link TimerInternals} of this {@link InProcessStepContext}, + * which is empty if the {@link TimerInternals} were never accessed. + */ + public TimerUpdate getTimerUpdate() { + if (timerInternals == null) { + return TimerUpdate.empty(); + } + return timerInternals.getTimerUpdate(); + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java new file mode 100644 index 000000000000..d659d962f0e5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Options that can be used to configure the {@link InProcessPipelineRunner}. + */ +public interface InProcessPipelineOptions extends PipelineOptions {} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java new file mode 100644 index 000000000000..124de46b9476 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -0,0 +1,260 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; +import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey.GroupByKeyOnly; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.collect.ImmutableMap; + +import org.joda.time.Instant; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.annotation.Nullable; + +/** + * An In-Memory implementation of the Dataflow Programming Model. Supports Unbounded + * {@link PCollection PCollections}. + */ +@Experimental +public class InProcessPipelineRunner { + @SuppressWarnings({"rawtypes", "unused"}) + private static Map, Class> + defaultTransformOverrides = + ImmutableMap., Class>builder() + .put(GroupByKey.class, InProcessGroupByKey.class) + .put(CreatePCollectionView.class, InProcessCreatePCollectionView.class) + .build(); + + private static Map, TransformEvaluatorFactory> defaultEvaluatorFactories = + new ConcurrentHashMap<>(); + + /** + * Register a default transform evaluator. + */ + public static > void registerTransformEvaluatorFactory( + Class clazz, TransformEvaluatorFactory evaluator) { + checkArgument(defaultEvaluatorFactories.put(clazz, evaluator) == null, + "Defining a default factory %s to evaluate Transforms of type %s multiple times", evaluator, + clazz); + } + + /** + * Part of a {@link PCollection}. Elements are output to a bundle, which will cause them to be + * executed by {@link PTransform PTransforms} that consume the {@link PCollection} this bundle is + * a part of at a later point. This is an uncommitted bundle and can have elements added to it. + * + * @param the type of elements that can be added to this bundle + */ + public static interface UncommittedBundle { + /** + * Outputs an element to this bundle. 
+ * + * @param element the element to add to this bundle + * @return this bundle + */ + UncommittedBundle add(WindowedValue element); + + /** + * Commits this {@link UncommittedBundle}, returning an immutable {@link CommittedBundle} + * containing all of the elements that were added to it. The {@link #add(WindowedValue)} method + * will throw an {@link IllegalStateException} if called after a call to commit. + * @param synchronizedProcessingTime the synchronized processing time at which this bundle was + * committed + */ + CommittedBundle commit(Instant synchronizedProcessingTime); + } + + /** + * Part of a {@link PCollection}. Elements are output to an {@link UncommittedBundle}, which will + * eventually committed. Committed elements are executed by the {@link PTransform PTransforms} + * that consume the {@link PCollection} this bundle is + * a part of at a later point. + * @param the type of elements contained within this bundle + */ + public static interface CommittedBundle { + + /** + * @return the PCollection that the elements of this bundle belong to + */ + PCollection getPCollection(); + + /** + * Returns weather this bundle is keyed. A bundle that is part of a {@link PCollection} that + * occurs after a {@link GroupByKey} is keyed by the result of the last {@link GroupByKey}. + */ + boolean isKeyed(); + + /** + * Returns the (possibly null) key that was output in the most recent {@link GroupByKey} in the + * execution of this bundle. + */ + @Nullable Object getKey(); + + /** + * @return an {@link Iterable} containing all of the elements that have been added to this + * {@link CommittedBundle} + */ + Iterable> getElements(); + + /** + * Returns the processing time output watermark at the time the producing {@link PTransform} + * committed this bundle. Downstream synchronized processing time watermarks cannot progress + * past this point before consuming this bundle. + * + *

    This value is no greater than the earliest incomplete processing time or synchronized + * processing time {@link TimerData timer} at the time this bundle was committed, including any + * timers that fired to produce this bundle. + */ + Instant getSynchronizedProcessingOutputWatermark(); + } + + /** + * A {@link PCollectionViewWriter} is responsible for writing contents of a {@link PCollection} to + * a storage mechanism that can be read from while constructing a {@link PCollectionView}. + * @param the type of elements the input {@link PCollection} contains. + * @param the type of the PCollectionView this writer writes to. + */ + public static interface PCollectionViewWriter { + void add(Iterable> values); + } + + /** + * The evaluation context for the {@link InProcessPipelineRunner}. Contains state shared within + * the current evaluation. + */ + public static interface InProcessEvaluationContext { + /** + * Create a {@link UncommittedBundle} for use by a source. + */ + UncommittedBundle createRootBundle(PCollection output); + + /** + * Create a {@link UncommittedBundle} whose elements belong to the specified {@link + * PCollection}. + */ + UncommittedBundle createBundle(CommittedBundle input, PCollection output); + + /** + * Create a {@link UncommittedBundle} with the specified keys at the specified step. For use by + * {@link GroupByKeyOnly} {@link PTransform PTransforms}. + */ + UncommittedBundle createKeyedBundle( + CommittedBundle input, Object key, PCollection output); + + /** + * Create a bundle whose elements will be used in a PCollectionView. + */ + PCollectionViewWriter createPCollectionViewWriter( + PCollection> input, PCollectionView output); + + /** + * Get the options used by this {@link Pipeline}. + */ + InProcessPipelineOptions getPipelineOptions(); + + /** + * Get an {@link ExecutionContext} for the provided application. + */ + InProcessExecutionContext getExecutionContext( + AppliedPTransform application, @Nullable Object key); + + /** + * Get the Step Name for the provided application. + */ + String getStepName(AppliedPTransform application); + + /** + * @param sideInputs the {@link PCollectionView PCollectionViews} the result should be able to + * read + * @return a {@link SideInputReader} that can read all of the provided + * {@link PCollectionView PCollectionViews} + */ + SideInputReader createSideInputReader(List> sideInputs); + + /** + * Schedules a callback after the watermark for a {@link PValue} after the trigger for the + * specified window (with the specified windowing strategy) must have fired from the perspective + * of that {@link PValue}, as specified by the value of + * {@link Trigger#getWatermarkThatGuaranteesFiring(BoundedWindow)} for the trigger of the + * {@link WindowingStrategy}. + */ + void callAfterOutputMustHaveBeenProduced(PValue value, BoundedWindow window, + WindowingStrategy windowingStrategy, Runnable runnable); + + /** + * Create a {@link CounterSet} for this {@link Pipeline}. The {@link CounterSet} is independent + * of all other {@link CounterSet CounterSets} created by this call. + * + * The {@link InProcessEvaluationContext} is responsible for unifying the counters present in + * all created {@link CounterSet CounterSets} when the transforms that call this method + * complete. + */ + CounterSet createCounterSet(); + + /** + * Returns all of the counters that have been merged into this context via calls to + * {@link CounterSet#merge(CounterSet)}. 
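A hedged sketch of the counter-unification contract described above. It assumes only the CounterSet#merge(CounterSet) method referenced in the Javadoc and a no-argument CounterSet constructor; treat both as assumptions rather than confirmed API.

    import com.google.cloud.dataflow.sdk.util.common.CounterSet;

    class CounterUnificationSketch {
      // Folds per-bundle counter sets into a single set, mirroring the contract described above.
      static CounterSet unify(Iterable<CounterSet> perBundleCounters) {
        CounterSet merged = new CounterSet();
        for (CounterSet counters : perBundleCounters) {
          merged.merge(counters);
        }
        return merged;
      }
    }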
+ */ + CounterSet getCounters(); + } + + /** + * An executor that schedules and executes {@link AppliedPTransform AppliedPTransforms} for both + * source and intermediate {@link PTransform PTransforms}. + */ + public static interface InProcessExecutor { + /** + * @param root the root {@link AppliedPTransform} to schedule + */ + void scheduleRoot(AppliedPTransform root); + + /** + * @param consumer the {@link AppliedPTransform} to schedule + * @param bundle the input bundle to the consumer + */ + void scheduleConsumption(AppliedPTransform consumer, CommittedBundle bundle); + + /** + * Blocks until the job being executed enters a terminal state. A job is completed after all + * root {@link AppliedPTransform AppliedPTransforms} have completed, and all + * {@link CommittedBundle Bundles} have been consumed. Jobs may also terminate abnormally. + */ + void awaitCompletion(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java new file mode 100644 index 000000000000..bf9a2e1c53fe --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java @@ -0,0 +1,207 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.PCollectionViewWindow; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.base.MoreObjects; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.SettableFuture; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import javax.annotation.Nullable; + +/** + * An in-process container for {@link PCollectionView PCollectionViews}, which provides methods for + * constructing {@link SideInputReader SideInputReaders} which block until a side input is + * available and writing to a {@link PCollectionView}. 
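To show how the two responsibilities above fit together, a small wiring sketch that uses only the create and withViews signatures from this file. The wrapper class is hypothetical and would need to live in the same package, since the container class is package-private.

    import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext;
    import com.google.cloud.dataflow.sdk.util.SideInputReader;
    import com.google.cloud.dataflow.sdk.values.PCollectionView;
    import java.util.Collection;

    class SideInputWiringSketch {
      // Builds the shared container once, then hands each step a reader restricted to its views.
      static SideInputReader readerFor(
          InProcessEvaluationContext context,
          Collection<PCollectionView<?>> allViews,
          Collection<PCollectionView<?>> stepViews) {
        return InProcessSideInputContainer.create(context, allViews).withViews(stepViews);
      }
    }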
+ */ +class InProcessSideInputContainer { + private final InProcessEvaluationContext evaluationContext; + private final Collection> containedViews; + private final LoadingCache, + SettableFuture>>> viewByWindows; + + /** + * Create a new {@link InProcessSideInputContainer} with the provided views and the provided + * context. + */ + public static InProcessSideInputContainer create( + InProcessEvaluationContext context, Collection> containedViews) { + CacheLoader, SettableFuture>>> + loader = new CacheLoader, + SettableFuture>>>() { + @Override + public SettableFuture>> load( + PCollectionViewWindow view) { + return SettableFuture.create(); + } + }; + LoadingCache, SettableFuture>>> + viewByWindows = CacheBuilder.newBuilder().build(loader); + return new InProcessSideInputContainer(context, containedViews, viewByWindows); + } + + private InProcessSideInputContainer(InProcessEvaluationContext context, + Collection> containedViews, + LoadingCache, SettableFuture>>> + viewByWindows) { + this.evaluationContext = context; + this.containedViews = ImmutableSet.copyOf(containedViews); + this.viewByWindows = viewByWindows; + } + + /** + * Return a view of this {@link InProcessSideInputContainer} that contains only the views in + * the provided argument. The returned {@link InProcessSideInputContainer} is unmodifiable without + * casting, but will change as this {@link InProcessSideInputContainer} is modified. + */ + public SideInputReader withViews(Collection> newContainedViews) { + if (!containedViews.containsAll(newContainedViews)) { + Set> currentlyContained = ImmutableSet.copyOf(containedViews); + Set> newRequested = ImmutableSet.copyOf(newContainedViews); + throw new IllegalArgumentException("Can't create a SideInputReader with unknown views " + + Sets.difference(newRequested, currentlyContained)); + } + return new SideInputContainerSideInputReader(newContainedViews); + } + + /** + * Write the provided values to the provided view. + * + *

    The windowed values are first exploded. Then, for each window, the pane is determined; if + * that pane is later than the pane currently stored in this container for the window, all of + * the provided values become the new contents of the {@link PCollectionView} for that window. + * + *

    The provided iterable is expected to contain only a single window and pane. + */ + public void write(PCollectionView view, Iterable> values) + throws ExecutionException { + Map>> valuesPerWindow = new HashMap<>(); + for (WindowedValue value : values) { + for (BoundedWindow window : value.getWindows()) { + Collection> windowValues = valuesPerWindow.get(window); + if (windowValues == null) { + windowValues = new ArrayList<>(); + valuesPerWindow.put(window, windowValues); + } + windowValues.add(value); + } + } + for (Map.Entry>> windowValues : + valuesPerWindow.entrySet()) { + PCollectionViewWindow windowedView = PCollectionViewWindow.of(view, windowValues.getKey()); + SettableFuture>> future = viewByWindows.get(windowedView); + if (future.isDone()) { + try { + Iterator> existingValues = future.get().iterator(); + PaneInfo newPane = windowValues.getValue().iterator().next().getPane(); + // The current value may have no elements, if no elements were produced for the window, + // but we are receiving late data. + if (!existingValues.hasNext() + || newPane.getIndex() > existingValues.next().getPane().getIndex()) { + viewByWindows.invalidate(windowedView); + viewByWindows.get(windowedView).set(windowValues.getValue()); + } + } catch (InterruptedException e) { + // TODO: Handle meaningfully. This should never really happen when the result remains + // useful, but the result could be available and the thread can still be interrupted. + Thread.currentThread().interrupt(); + } + } else { + future.set(windowValues.getValue()); + } + } + } + + private final class SideInputContainerSideInputReader implements SideInputReader { + private final Collection> readerViews; + + private SideInputContainerSideInputReader(Collection> readerViews) { + this.readerViews = ImmutableSet.copyOf(readerViews); + } + + @Override + @Nullable + public T get(final PCollectionView view, final BoundedWindow window) { + checkArgument( + readerViews.contains(view), "calling get(PCollectionView) with unknown view: " + view); + PCollectionViewWindow windowedView = PCollectionViewWindow.of(view, window); + try { + final SettableFuture>> future = + viewByWindows.get(windowedView); + + WindowingStrategy windowingStrategy = view.getWindowingStrategyInternal(); + evaluationContext.callAfterOutputMustHaveBeenProduced( + view, window, windowingStrategy, new Runnable() { + @Override + public void run() { + // The requested window has closed without producing elements, so reflect that in + // the PCollectionView. If set has already been called, this will do nothing.
+ future.set(Collections.>emptyList()); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper("InProcessSideInputContainerEmptyCallback") + .add("view", view) + .add("window", window) + .toString(); + } + }); + // Safe covariant cast + @SuppressWarnings("unchecked") + Iterable> values = (Iterable>) future.get(); + return view.fromIterableInternal(values); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean contains(PCollectionView view) { + return readerViews.contains(view); + } + + @Override + public boolean isEmpty() { + return readerViews.isEmpty(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTimerInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTimerInternals.java new file mode 100644 index 000000000000..06ba7b82f432 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTimerInternals.java @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate.TimerUpdateBuilder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TransformWatermarks; +import com.google.cloud.dataflow.sdk.util.TimerInternals; + +import org.joda.time.Instant; + +import javax.annotation.Nullable; + +/** + * An implementation of {@link TimerInternals} where all relevant data exists in memory. 
+ */ +public class InProcessTimerInternals implements TimerInternals { + private final Clock processingTimeClock; + private final TransformWatermarks watermarks; + private final TimerUpdateBuilder timerUpdateBuilder; + + public static InProcessTimerInternals create( + Clock clock, TransformWatermarks watermarks, TimerUpdateBuilder timerUpdateBuilder) { + return new InProcessTimerInternals(clock, watermarks, timerUpdateBuilder); + } + + private InProcessTimerInternals( + Clock clock, TransformWatermarks watermarks, TimerUpdateBuilder timerUpdateBuilder) { + this.processingTimeClock = clock; + this.watermarks = watermarks; + this.timerUpdateBuilder = timerUpdateBuilder; + } + + @Override + public void setTimer(TimerData timerKey) { + timerUpdateBuilder.setTimer(timerKey); + } + + @Override + public void deleteTimer(TimerData timerKey) { + timerUpdateBuilder.deletedTimer(timerKey); + } + + public TimerUpdate getTimerUpdate() { + return timerUpdateBuilder.build(); + } + + @Override + public Instant currentProcessingTime() { + return processingTimeClock.now(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return watermarks.getSynchronizedProcessingInputTime(); + } + + @Override + @Nullable + public Instant currentInputWatermarkTime() { + return watermarks.getInputWatermark(); + } + + @Override + @Nullable + public Instant currentOutputWatermarkTime() { + return watermarks.getOutputWatermark(); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTransformResult.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTransformResult.java new file mode 100644 index 000000000000..3f9e94ad9f04 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTransformResult.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; + +import org.joda.time.Instant; + +import javax.annotation.Nullable; + +/** + * The result of evaluating an {@link AppliedPTransform} with a {@link TransformEvaluator}. + */ +public interface InProcessTransformResult { + /** + * Returns the {@link AppliedPTransform} that produced this result. + */ + AppliedPTransform getTransform(); + + /** + * Returns the {@link UncommittedBundle (uncommitted) Bundles} output by this transform. 
These + * will be committed by the evaluation context as part of completing this result. + */ + Iterable> getOutputBundles(); + + /** + * Returns the {@link CounterSet} used by this {@link PTransform}, or null if this transform did + * not use a {@link CounterSet}. + */ + @Nullable CounterSet getCounters(); + + /** + * Returns the Watermark Hold for the transform at the time this result was produced. + * + * If the transform does not set any watermark hold, returns + * {@link BoundedWindow#TIMESTAMP_MAX_VALUE}. + */ + Instant getWatermarkHold(); + + /** + * Returns the State used by the transform. + * + * If this evaluation did not access state, this may return null. + */ + CopyOnAccessInMemoryStateInternals getState(); + + /** + * Returns a TimerUpdateBuilder that was produced as a result of this evaluation. If the + * evaluation was triggered due to the delivery of one or more timers, those timers must be added + * to the builder before it is complete. + * + *

    If this evaluation did not add or remove any timers, returns an empty TimerUpdate. + */ + TimerUpdate getTimerUpdate(); + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/NanosOffsetClock.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/NanosOffsetClock.java new file mode 100644 index 000000000000..958e26d6ee41 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/NanosOffsetClock.java @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import org.joda.time.Instant; + +import java.util.concurrent.TimeUnit; + +/** + * A {@link Clock} that uses {@link System#nanoTime()} to track the progress of time. + */ +public class NanosOffsetClock implements Clock { + private final long baseMillis; + private final long nanosAtBaseMillis; + + public static NanosOffsetClock create() { + return new NanosOffsetClock(); + } + + private NanosOffsetClock() { + baseMillis = System.currentTimeMillis(); + nanosAtBaseMillis = System.nanoTime(); + } + + @Override + public Instant now() { + return new Instant( + baseMillis + (TimeUnit.MILLISECONDS.convert( + System.nanoTime() - nanosAtBaseMillis, TimeUnit.NANOSECONDS))); + } + + /** + * Creates instances of {@link NanosOffsetClock}. + */ + public static class Factory implements DefaultValueFactory { + @Override + public Clock create(PipelineOptions options) { + return new NanosOffsetClock(); + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoInProcessEvaluator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoInProcessEvaluator.java new file mode 100644 index 000000000000..2a21e8cbf5ed --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoInProcessEvaluator.java @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.DoFnRunners.OutputManager; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class ParDoInProcessEvaluator implements TransformEvaluator { + private final DoFnRunner fnRunner; + private final AppliedPTransform, ?, ?> transform; + private final CounterSet counters; + private final Collection> outputBundles; + private final InProcessStepContext stepContext; + + public ParDoInProcessEvaluator( + DoFnRunner fnRunner, + AppliedPTransform, ?, ?> transform, + CounterSet counters, + Collection> outputBundles, + InProcessStepContext stepContext) { + this.fnRunner = fnRunner; + this.transform = transform; + this.counters = counters; + this.outputBundles = outputBundles; + this.stepContext = stepContext; + } + + @Override + public void processElement(WindowedValue element) { + fnRunner.processElement(element); + } + + @Override + public InProcessTransformResult finishBundle() { + fnRunner.finishBundle(); + StepTransformResult.Builder resultBuilder; + CopyOnAccessInMemoryStateInternals state = stepContext.commitState(); + if (state != null) { + resultBuilder = + StepTransformResult.withHold(transform, state.getEarliestWatermarkHold()) + .withState(state); + } else { + resultBuilder = StepTransformResult.withoutHold(transform); + } + return resultBuilder + .addOutput(outputBundles) + .withTimerUpdate(stepContext.getTimerUpdate()) + .withCounters(counters) + .build(); + } + + static class BundleOutputManager implements OutputManager { + private final Map, UncommittedBundle> bundles; + private final Map, List> undeclaredOutputs; + + public static BundleOutputManager create(Map, UncommittedBundle> outputBundles) { + return new BundleOutputManager(outputBundles); + } + + private BundleOutputManager(Map, UncommittedBundle> bundles) { + this.bundles = bundles; + undeclaredOutputs = new HashMap<>(); + } + + @SuppressWarnings("unchecked") + @Override + public void output(TupleTag tag, WindowedValue output) { + @SuppressWarnings("rawtypes") + UncommittedBundle bundle = bundles.get(tag); + if (bundle == null) { + List undeclaredContents = undeclaredOutputs.get(tag); + if (undeclaredContents == null) { + undeclaredContents = new ArrayList(); + undeclaredOutputs.put(tag, undeclaredContents); + } + undeclaredContents.add(output); + } else { + bundle.add(output); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java new file mode 100644 index 000000000000..e3ae1a028c16 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 
2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.ParDoInProcessEvaluator.BundleOutputManager; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo.BoundMulti; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.DoFnRunners; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.HashMap; +import java.util.Map; + +/** + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the + * {@link BoundMulti} primitive {@link PTransform}. 
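For orientation, {@code ParDo.BoundMulti} is what a pipeline author obtains from {@code ParDo.withOutputTags(...)}; the factory below must route each tagged output to its own bundle. A hedged usage sketch against the public SDK API (the tags, the DoFn, and the {@code input} collection are made-up examples):

    // Illustrative multi-output ParDo; all names here are placeholders, not part of this change.
    final TupleTag<String> mainTag = new TupleTag<String>() {};
    final TupleTag<String> errorTag = new TupleTag<String>() {};

    PCollectionTuple results = input.apply(
        ParDo.withOutputTags(mainTag, TupleTagList.of(errorTag))
            .of(new DoFn<String, String>() {
              @Override
              public void processElement(ProcessContext c) {
                if (c.element().isEmpty()) {
                  c.sideOutput(errorTag, c.element()); // routed to the errorTag bundle
                } else {
                  c.output(c.element());               // routed to the main output bundle
                }
              }
            }));

    PCollection<String> mainOutput = results.get(mainTag);
    PCollection<String> errors = results.get(errorTag);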
+ */ +class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory { + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + return createMultiEvaluator((AppliedPTransform) application, inputBundle, evaluationContext); + } + + private static ParDoInProcessEvaluator createMultiEvaluator( + AppliedPTransform, PCollectionTuple, BoundMulti> application, + CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + PCollectionTuple output = application.getOutput(); + Map, PCollection> outputs = output.getAll(); + Map, UncommittedBundle> outputBundles = new HashMap<>(); + for (Map.Entry, PCollection> outputEntry : outputs.entrySet()) { + outputBundles.put( + outputEntry.getKey(), + evaluationContext.createBundle(inputBundle, outputEntry.getValue())); + } + InProcessExecutionContext executionContext = + evaluationContext.getExecutionContext(application, inputBundle.getKey()); + String stepName = evaluationContext.getStepName(application); + InProcessStepContext stepContext = + executionContext.getOrCreateStepContext(stepName, stepName, null); + + CounterSet counters = evaluationContext.createCounterSet(); + + DoFn fn = application.getTransform().getFn(); + DoFnRunner runner = + DoFnRunners.createDefault( + evaluationContext.getPipelineOptions(), + fn, + evaluationContext.createSideInputReader(application.getTransform().getSideInputs()), + BundleOutputManager.create(outputBundles), + application.getTransform().getMainOutputTag(), + application.getTransform().getSideOutputTags().getAll(), + stepContext, + counters.getAddCounterMutator(), + application.getInput().getWindowingStrategy()); + + runner.startBundle(); + + return new ParDoInProcessEvaluator<>( + runner, application, counters, outputBundles.values(), stepContext); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java new file mode 100644 index 000000000000..cd79c219bd67 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.ParDoInProcessEvaluator.BundleOutputManager; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo.Bound; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.DoFnRunners; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.Collections; + +/** + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the + * {@link Bound ParDo.Bound} primitive {@link PTransform}. + */ +class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory { + @Override + public TransformEvaluator forApplication( + final AppliedPTransform application, + CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + return createSingleEvaluator((AppliedPTransform) application, inputBundle, evaluationContext); + } + + private static ParDoInProcessEvaluator createSingleEvaluator( + @SuppressWarnings("rawtypes") AppliedPTransform, PCollection, + Bound> application, + CommittedBundle inputBundle, InProcessEvaluationContext evaluationContext) { + TupleTag mainOutputTag = new TupleTag<>("out"); + UncommittedBundle outputBundle = + evaluationContext.createBundle(inputBundle, application.getOutput()); + + InProcessExecutionContext executionContext = + evaluationContext.getExecutionContext(application, inputBundle.getKey()); + String stepName = evaluationContext.getStepName(application); + InProcessStepContext stepContext = + executionContext.getOrCreateStepContext(stepName, stepName, null); + + CounterSet counters = evaluationContext.createCounterSet(); + + DoFnRunner runner = + DoFnRunners.createDefault( + evaluationContext.getPipelineOptions(), + application.getTransform().getFn(), + evaluationContext.createSideInputReader(application.getTransform().getSideInputs()), + BundleOutputManager.create( + Collections., UncommittedBundle>singletonMap( + mainOutputTag, outputBundle)), + mainOutputTag, + Collections.>emptyList(), + stepContext, + counters.getAddCounterMutator(), + application.getInput().getWindowingStrategy()); + + runner.startBundle(); + return new ParDoInProcessEvaluator( + runner, + application, + counters, + Collections.>singleton(outputBundle), + stepContext); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepTransformResult.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepTransformResult.java new file mode 100644 index 000000000000..3c4ee29d96f4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepTransformResult.java @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * An immutable {@link InProcessTransformResult}. + */ +public class StepTransformResult implements InProcessTransformResult { + private final AppliedPTransform transform; + private final Iterable> bundles; + private final CopyOnAccessInMemoryStateInternals state; + private final TimerUpdate timerUpdate; + private final CounterSet counters; + private final Instant watermarkHold; + + private StepTransformResult( + AppliedPTransform transform, + Iterable> outputBundles, + CopyOnAccessInMemoryStateInternals state, + TimerUpdate timerUpdate, + CounterSet counters, + Instant watermarkHold) { + this.transform = transform; + this.bundles = outputBundles; + this.state = state; + this.timerUpdate = timerUpdate; + this.counters = counters; + this.watermarkHold = watermarkHold; + } + + @Override + public Iterable> getOutputBundles() { + return bundles; + } + + @Override + public CounterSet getCounters() { + return counters; + } + + @Override + public AppliedPTransform getTransform() { + return transform; + } + + @Override + public Instant getWatermarkHold() { + return watermarkHold; + } + + @Override + public CopyOnAccessInMemoryStateInternals getState() { + return state; + } + + @Override + public TimerUpdate getTimerUpdate() { + return timerUpdate; + } + + public static Builder withHold(AppliedPTransform transform, Instant watermarkHold) { + return new Builder(transform, watermarkHold); + } + + public static Builder withoutHold(AppliedPTransform transform) { + return new Builder(transform, BoundedWindow.TIMESTAMP_MAX_VALUE); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(StepTransformResult.class) + .add("transform", transform) + .toString(); + } + + /** + * A builder for creating instances of {@link StepTransformResult}. 
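A typical call chain for the builder declared just below, mirroring how {@code ParDoInProcessEvaluator#finishBundle()} earlier in this change assembles its result ({@code transform}, {@code state}, {@code outputBundles}, {@code timerUpdate}, and {@code counters} are placeholders for real evaluator state):

    // Sketch of builder usage; placeholder variables stand in for real evaluator state.
    InProcessTransformResult result =
        StepTransformResult.withHold(transform, state.getEarliestWatermarkHold())
            .withState(state)
            .addOutput(outputBundles)
            .withTimerUpdate(timerUpdate)
            .withCounters(counters)
            .build();

    // A transform with no watermark hold that produced a single output bundle:
    InProcessTransformResult simple =
        StepTransformResult.withoutHold(transform).addOutput(outputBundle).build();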
+ */ + public static class Builder { + private final AppliedPTransform transform; + private final ImmutableList.Builder> bundlesBuilder; + private CopyOnAccessInMemoryStateInternals state; + private TimerUpdate timerUpdate; + private CounterSet counters; + private final Instant watermarkHold; + + private Builder(AppliedPTransform transform, Instant watermarkHold) { + this.transform = transform; + this.watermarkHold = watermarkHold; + this.bundlesBuilder = ImmutableList.builder(); + this.timerUpdate = TimerUpdate.builder(null).build(); + } + + public StepTransformResult build() { + return new StepTransformResult( + transform, + bundlesBuilder.build(), + state, + timerUpdate, + counters, + watermarkHold); + } + + public Builder withCounters(CounterSet counters) { + this.counters = counters; + return this; + } + + public Builder withState(CopyOnAccessInMemoryStateInternals state) { + this.state = state; + return this; + } + + public Builder withTimerUpdate(TimerUpdate timerUpdate) { + this.timerUpdate = timerUpdate; + return this; + } + + public Builder addOutput( + UncommittedBundle outputBundle, UncommittedBundle... outputBundles) { + bundlesBuilder.add(outputBundle); + bundlesBuilder.add(outputBundles); + return this; + } + + public Builder addOutput(Collection> outputBundles) { + bundlesBuilder.addAll(outputBundles); + return this; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluator.java new file mode 100644 index 000000000000..270557d55c11 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluator.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +/** + * An evaluator of a specific application of a transform. Will be used for at least one + * {@link CommittedBundle}. + * + * @param the type of elements that will be passed to {@link #processElement} + */ +public interface TransformEvaluator { + /** + * Process an element in the input {@link CommittedBundle}. + * + * @param element the element to process + */ + void processElement(WindowedValue element) throws Exception; + + /** + * Finish processing the bundle of this {@link TransformEvaluator}. + * + * After {@link #finishBundle()} is called, the {@link TransformEvaluator} will not be reused, + * and no more elements will be processed. + * + * @return an {@link InProcessTransformResult} containing the results of this bundle evaluation. 
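To make the lifecycle concrete, a caller such as the {@code InProcessExecutor} described earlier drives an evaluator roughly as sketched below. This is an illustration, not code from this change; in particular, iterating the bundle's contents via {@code getElements()} is assumed for the sketch.

    // Hedged sketch of the per-bundle lifecycle.
    TransformEvaluator<T> evaluator =
        factory.forApplication(application, inputBundle, evaluationContext);
    for (WindowedValue<T> element : inputBundle.getElements()) {
      evaluator.processElement(element);
    }
    InProcessTransformResult result = evaluator.finishBundle();
    // The evaluator is now discarded; a fresh one is requested for the next bundle.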
+ */ + InProcessTransformResult finishBundle() throws Exception; +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java new file mode 100644 index 000000000000..3b672e0def5e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import javax.annotation.Nullable; + +/** + * A factory for creating instances of {@link TransformEvaluator} for the application of a + * {@link PTransform}. + */ +public interface TransformEvaluatorFactory { + /** + * Create a new {@link TransformEvaluator} for the application of the {@link PTransform}. + * + * Any work that must be done before input elements are processed (such as calling + * {@link DoFn#startBundle(DoFn.Context)}) must be done before the {@link TransformEvaluator} is + * made available to the caller. + * + * @throws Exception whenever constructing the underlying evaluator throws an exception + */ + TransformEvaluator forApplication( + AppliedPTransform application, @Nullable CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) throws Exception; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java new file mode 100644 index 000000000000..4beac337d604 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.io.Read.Unbounded; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.io.UnboundedSource.CheckpointMark; +import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import java.io.IOException; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nullable; + +/** + * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} + * for the {@link Unbounded Read.Unbounded} primitive {@link PTransform}. + */ +class UnboundedReadEvaluatorFactory implements TransformEvaluatorFactory { + /* + * An evaluator for a Source is stateful, to ensure the CheckpointMark is properly persisted. + * Evaluators are cached here to ensure that the checkpoint mark is appropriately reused + * and any splits are honored. + */ + private final ConcurrentMap>> + sourceEvaluators = new ConcurrentHashMap<>(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public TransformEvaluator forApplication(AppliedPTransform application, + @Nullable CommittedBundle inputBundle, InProcessEvaluationContext evaluationContext) { + return getTransformEvaluator((AppliedPTransform) application, evaluationContext); + } + + private TransformEvaluator getTransformEvaluator( + final AppliedPTransform, Unbounded> transform, + final InProcessEvaluationContext evaluationContext) { + UnboundedReadEvaluator currentEvaluator = + getTransformEvaluatorQueue(transform, evaluationContext).poll(); + if (currentEvaluator == null) { + return EmptyTransformEvaluator.create(transform); + } + return currentEvaluator; + } + + /** + * Get the queue of {@link TransformEvaluator TransformEvaluators} that produce elements for the + * provided application of {@link Unbounded Read.Unbounded}, initializing it if required. + * + *
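The lazy queue initialization described here (and the thread-safety noted just below) follows the standard {@code ConcurrentMap.putIfAbsent} initialize-once idiom: racing threads may each build a candidate, but only the first installation wins and every caller then shares the winner. A minimal, self-contained illustration in plain Java ({@code QueueRegistry} and {@code queueFor} are made-up names):

    // Minimal illustration of the putIfAbsent initialize-once idiom used below.
    import java.util.Queue;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentLinkedQueue;
    import java.util.concurrent.ConcurrentMap;

    class QueueRegistry<K, V> {
      private final ConcurrentMap<K, Queue<V>> queues = new ConcurrentHashMap<>();

      Queue<V> queueFor(K key) {
        Queue<V> queue = queues.get(key);
        if (queue == null) {
          Queue<V> candidate = new ConcurrentLinkedQueue<>();
          Queue<V> existing = queues.putIfAbsent(key, candidate);
          // If another thread installed a queue first, use it; otherwise our candidate won.
          queue = (existing == null) ? candidate : existing;
        }
        return queue;
      }
    }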

    This method is thread-safe, and will only produce new evaluators if no other invocation has + * already done so. + */ + @SuppressWarnings("unchecked") + private Queue> getTransformEvaluatorQueue( + final AppliedPTransform, Unbounded> transform, + final InProcessEvaluationContext evaluationContext) { + // Key by the application and the context the evaluation is occurring in (which call to + // Pipeline#run). + EvaluatorKey key = new EvaluatorKey(transform, evaluationContext); + @SuppressWarnings("unchecked") + Queue> evaluatorQueue = + (Queue>) sourceEvaluators.get(key); + if (evaluatorQueue == null) { + evaluatorQueue = new ConcurrentLinkedQueue<>(); + if (sourceEvaluators.putIfAbsent(key, evaluatorQueue) == null) { + // If no queue existed in the evaluators, add an evaluator to initialize the evaluator + // factory for this transform + UnboundedReadEvaluator evaluator = + new UnboundedReadEvaluator(transform, evaluationContext, evaluatorQueue); + evaluatorQueue.offer(evaluator); + } else { + // otherwise return the existing Queue that arrived before us + evaluatorQueue = (Queue>) sourceEvaluators.get(key); + } + } + return evaluatorQueue; + } + + private static class UnboundedReadEvaluator implements TransformEvaluator { + private static final int ARBITRARY_MAX_ELEMENTS = 10; + private final AppliedPTransform, Unbounded> transform; + private final InProcessEvaluationContext evaluationContext; + private final Queue> evaluatorQueue; + private CheckpointMark checkpointMark; + + public UnboundedReadEvaluator( + AppliedPTransform, Unbounded> transform, + InProcessEvaluationContext evaluationContext, + Queue> evaluatorQueue) { + this.transform = transform; + this.evaluationContext = evaluationContext; + this.evaluatorQueue = evaluatorQueue; + this.checkpointMark = null; + } + + @Override + public void processElement(WindowedValue element) {} + + @Override + public InProcessTransformResult finishBundle() throws IOException { + UncommittedBundle output = evaluationContext.createRootBundle(transform.getOutput()); + UnboundedReader reader = + createReader( + transform.getTransform().getSource(), evaluationContext.getPipelineOptions()); + int numElements = 0; + if (reader.start()) { + do { + output.add( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp())); + numElements++; + } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance()); + } + checkpointMark = reader.getCheckpointMark(); + checkpointMark.finalizeCheckpoint(); + // TODO: When exercising create initial splits, make this the minimum watermark across all + // existing readers + StepTransformResult result = + StepTransformResult.withHold(transform, reader.getWatermark()) + .addOutput(output) + .build(); + evaluatorQueue.offer(this); + return result; + } + + private UnboundedReader createReader( + UnboundedSource source, PipelineOptions options) { + @SuppressWarnings("unchecked") + CheckpointMarkT mark = (CheckpointMarkT) checkpointMark; + return source.createReader(options, mark); + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java new file mode 100644 index 000000000000..f47cd1de986b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.Values; +import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import java.util.ArrayList; +import java.util.List; + +/** + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the + * {@link CreatePCollectionView} primitive {@link PTransform}. + * + *
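For context, {@code CreatePCollectionView} is the primitive underneath the public {@code View} transforms; a pipeline author reaches it indirectly, roughly as in the hedged sketch below ({@code words} and {@code other} are placeholder collections):

    // Illustrative: building a view via the public View API (which bottoms out in
    // CreatePCollectionView) and reading it as a side input.
    final PCollectionView<Iterable<String>> wordsView =
        words.apply(View.<String>asIterable());

    PCollection<Integer> counts =
        other.apply(ParDo.withSideInputs(wordsView).of(new DoFn<String, Integer>() {
          @Override
          public void processElement(ProcessContext c) {
            int count = 0;
            for (String ignored : c.sideInput(wordsView)) {
              count++; // iterate the materialized view for the element's window
            }
            c.output(count);
          }
        }));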

    The {@link ViewEvaluatorFactory} produces {@link TransformEvaluator TransformEvaluators} for + * the {@link WriteView} {@link PTransform}, which is part of the + * {@link InProcessCreatePCollectionView} composite transform. This transform is an override for the + * {@link CreatePCollectionView} transform that applies windowing and triggers before the view is + * written. + */ +class ViewEvaluatorFactory implements TransformEvaluatorFactory { + @SuppressWarnings({"rawtypes", "unchecked"}) + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + InProcessPipelineRunner.CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + return createEvaluator( + (AppliedPTransform) application, evaluationContext); + } + + private TransformEvaluator> createEvaluator( + final AppliedPTransform>, PCollectionView, WriteView> + application, + InProcessEvaluationContext context) { + PCollection> input = application.getInput(); + final PCollectionViewWriter writer = + context.createPCollectionViewWriter(input, application.getOutput()); + return new TransformEvaluator>() { + private final List> elements = new ArrayList<>(); + + @Override + public void processElement(WindowedValue> element) { + for (InT input : element.getValue()) { + elements.add(element.withValue(input)); + } + } + + @Override + public InProcessTransformResult finishBundle() { + writer.add(elements); + return StepTransformResult.withoutHold(application).build(); + } + }; + } + + /** + * An in-process override for {@link CreatePCollectionView}. + */ + public static class InProcessCreatePCollectionView + extends PTransform, PCollectionView> { + private final CreatePCollectionView og; + + private InProcessCreatePCollectionView(CreatePCollectionView og) { + this.og = og; + } + + @Override + public PCollectionView apply(PCollection input) { + return input.apply(WithKeys.of((Void) null)) + .setCoder(KvCoder.of(VoidCoder.of(), input.getCoder())) + .apply(GroupByKey.create()) + .apply(Values.>create()) + .apply(new WriteView(og)); + } + } + + /** + * An in-process implementation of the {@link CreatePCollectionView} primitive. + * + * This implementation requires the input {@link PCollection} to be an iterable, which is provided + * to {@link PCollectionView#fromIterableInternal(Iterable)}. + */ + public static final class WriteView + extends PTransform>, PCollectionView> { + private final CreatePCollectionView og; + + WriteView(CreatePCollectionView og) { + this.og = og; + } + + @Override + public PCollectionView apply(PCollection> input) { + return og.getView(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/package-info.java new file mode 100644 index 000000000000..d1aa6af192d7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/package-info.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +/** + * Defines runners for executing Pipelines in different modes, including + * {@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner} and + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner}. + * + *

    {@link com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner} executes a {@code Pipeline} + * locally, without contacting the Dataflow service. + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner} submits a + * {@code Pipeline} to the Dataflow service, which executes it on Dataflow-managed Compute Engine + * instances. {@code DataflowPipelineRunner} returns + * as soon as the {@code Pipeline} has been submitted. Use + * {@link com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner} to have execution + * updates printed to the console. + * + *
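As the closing sentence of this package description notes, the runner is selected through {@code PipelineOptions}; concretely, a user program typically does something like the following illustrative sketch (runner-specific options are omitted):

    // Illustrative runner selection via PipelineOptions; not part of this change.
    import com.google.cloud.dataflow.sdk.Pipeline;
    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
    import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner;

    public class RunnerSelectionExample {
      public static void main(String[] args) {
        PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
        options.setRunner(DirectPipelineRunner.class); // or DataflowPipelineRunner.class
        Pipeline p = Pipeline.create(options);
        // ... apply transforms to p ...
        p.run();
      }
    }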

    The runner is specified as part {@link com.google.cloud.dataflow.sdk.options.PipelineOptions}. + */ +package com.google.cloud.dataflow.sdk.runners; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/IsmFormat.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/IsmFormat.java new file mode 100644 index 000000000000..318de9b5b894 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/IsmFormat.java @@ -0,0 +1,946 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.worker; + +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.RandomAccessData; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * An Ism file is a prefix encoded composite key value file broken into shards. Each composite + * key is composed of a fixed number of component keys. A fixed number of those sub keys represent + * the shard key portion; see {@link IsmRecord} and {@link IsmRecordCoder} for further details + * around the data format. In addition to the data, there is a bloom filter, + * and multiple indices to allow for efficient retrieval. + * + *

    An Ism file is composed of these high level sections (in order):
      • shard block
      • bloom filter (See {@code ScalableBloomFilter} for details on encoding format)
      • shard index
      • footer (See {@link Footer} for details on encoding format)

    The shard block is composed of multiple copies of the following:
      • data block
      • data index

    The data block is composed of multiple copies of the following:
      • key prefix (See {@link KeyPrefix} for details on encoding format)
      • unshared key bytes
      • value bytes
      • optional 0x00 0x00 bytes followed by metadata bytes
        (if the following 0x00 0x00 bytes are not present, then there are no metadata bytes)
    Each key written into the data block must be in unsigned lexicographically increasing order
    and also its shard portion of the key must hash to the same shard id as all other keys
    within the same data block. The hashing function used is the 32-bit murmur3 algorithm,
    x86 variant (little-endian variant), using {@code 1225801234} as the seed value.

    The data index is composed of {@code N} copies of the following:
      • key prefix (See {@link KeyPrefix} for details on encoding format)
      • unshared key bytes
      • byte offset to key prefix in data block (variable length long coding)

    The shard index is composed of a {@link VarInt variable length integer} encoding representing + * the number of shard index records followed by that many shard index records. + * See {@link IsmShardCoder} for further details as to its encoding scheme. + */ +public class IsmFormat { + private static final int HASH_SEED = 1225801234; + private static final HashFunction HASH_FUNCTION = Hashing.murmur3_32(HASH_SEED); + static final int SHARD_BITS = 0x7F; // [0-127] shards + [128-255] metadata shards + + /** + * A record containing a composite key and either a value or metadata. The composite key + * must not contain the metadata key component place holder if producing a value record, and must + * contain the metadata component key place holder if producing a metadata record. + * + *
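A brief usage sketch of the factory methods declared below (the key components and value are arbitrary placeholders; a metadata record would instead be built with {@code IsmRecord.meta(...)} and key components containing the metadata key place holder):

    // Illustrative value record; the key components and the value are placeholders.
    IsmRecord<Long> record =
        IsmRecord.of(ImmutableList.<Object>of("user-key", 7), 42L);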

    The composite key is a fixed number of component keys where the first {@code N} component + * keys are used to create a shard id via hashing. See {@link IsmRecordCoder#hash(List)} for + * further details. + */ + public static class IsmRecord { + /** Returns an IsmRecord with the specified key components and value. */ + public static IsmRecord of(List keyComponents, V value) { + checkNotNull(keyComponents); + checkArgument(!keyComponents.isEmpty(), "Expected non-empty list of key components."); + checkArgument(!isMetadataKey(keyComponents), + "Expected key components to not contain metadata key."); + return new IsmRecord<>(keyComponents, value, null); + } + + public static IsmRecord meta(List keyComponents, byte[] metadata) { + checkNotNull(keyComponents); + checkNotNull(metadata); + checkArgument(!keyComponents.isEmpty(), "Expected non-empty list of key components."); + checkArgument(isMetadataKey(keyComponents), + "Expected key components to contain metadata key."); + return new IsmRecord(keyComponents, null, metadata); + } + + private final List keyComponents; + @Nullable + private final V value; + @Nullable + private final byte[] metadata; + private IsmRecord(List keyComponents, V value, byte[] metadata) { + this.keyComponents = keyComponents; + this.value = value; + this.metadata = metadata; + } + + /** Returns the list of key components. */ + public List getKeyComponents() { + return keyComponents; + } + + /** Returns the key component at the specified index. */ + public Object getKeyComponent(int index) { + return keyComponents.get(index); + } + + /** + * Returns the value. Throws {@link IllegalStateException} if this is not a + * value record. + */ + public V getValue() { + checkState(!isMetadataKey(keyComponents), + "This is a metadata record and not a value record."); + return value; + } + + /** + * Returns the metadata. Throws {@link IllegalStateException} if this is not a + * metadata record. + */ + public byte[] getMetadata() { + checkState(isMetadataKey(keyComponents), + "This is a value record and not a metadata record."); + return metadata; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof IsmRecord)) { + return false; + } + IsmRecord other = (IsmRecord) obj; + return Objects.equal(keyComponents, other.keyComponents) + && Objects.equal(value, other.value) + && Arrays.equals(metadata, other.metadata); + } + + @Override + public int hashCode() { + return Objects.hashCode(keyComponents, value, Arrays.hashCode(metadata)); + } + + @Override + public String toString() { + ToStringHelper builder = MoreObjects.toStringHelper(IsmRecord.class) + .add("keyComponents", keyComponents); + if (isMetadataKey(keyComponents)) { + builder.add("metadata", metadata); + } else { + builder.add("value", value); + } + return builder.toString(); + } + } + + /** A {@link Coder} for {@link IsmRecord}s. + * + *

    Note that this coder on its own will not produce an Ism file. This coder can be used + * to materialize a {@link PCollection} of {@link IsmRecord}s. Only when this coder + * is combined with an {@link IsmSink} will an Ism file be produced. + * + *

    The {@link IsmRecord} encoded format is:
 + * <ul>
 + *   <li>encoded key component 1 using key component coder 1</li>
 + *   <li>...</li>
 + *   <li>encoded key component N using key component coder N</li>
 + *   <li>encoded value using value coder</li>
 + * </ul>
    + */ + public static class IsmRecordCoder + extends StandardCoder> { + /** Returns an IsmRecordCoder with the specified key component coders, value coder. */ + public static IsmRecordCoder of( + int numberOfShardKeyCoders, + int numberOfMetadataShardKeyCoders, + List> keyComponentCoders, + Coder valueCoder) { + checkNotNull(keyComponentCoders); + checkArgument(keyComponentCoders.size() > 0); + checkArgument(numberOfShardKeyCoders > 0); + checkArgument(numberOfShardKeyCoders <= keyComponentCoders.size()); + checkArgument(numberOfMetadataShardKeyCoders <= keyComponentCoders.size()); + return new IsmRecordCoder<>( + numberOfShardKeyCoders, + numberOfMetadataShardKeyCoders, + keyComponentCoders, + valueCoder); + } + + /** + * Returns an IsmRecordCoder with the specified coders. Note that this method is not meant + * to be called by users but used by Jackson when decoding this coder. + */ + @JsonCreator + public static IsmRecordCoder of( + @JsonProperty(PropertyNames.NUM_SHARD_CODERS) int numberOfShardCoders, + @JsonProperty(PropertyNames.NUM_METADATA_SHARD_CODERS) int numberOfMetadataShardCoders, + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) List> components) { + Preconditions.checkArgument(components.size() >= 2, + "Expecting at least 2 components, got " + components.size()); + return of( + numberOfShardCoders, + numberOfMetadataShardCoders, + components.subList(0, components.size() - 1), + components.get(components.size() - 1)); + } + + private final int numberOfShardKeyCoders; + private final int numberOfMetadataShardKeyCoders; + private final List> keyComponentCoders; + private final Coder valueCoder; + + private IsmRecordCoder( + int numberOfShardKeyCoders, + int numberOfMetadataShardKeyCoders, + List> keyComponentCoders, Coder valueCoder) { + this.numberOfShardKeyCoders = numberOfShardKeyCoders; + this.numberOfMetadataShardKeyCoders = numberOfMetadataShardKeyCoders; + this.keyComponentCoders = keyComponentCoders; + this.valueCoder = valueCoder; + } + + /** Returns the list of key component coders. */ + public List> getKeyComponentCoders() { + return keyComponentCoders; + } + + /** Returns the key coder at the specified index. */ + public Coder getKeyComponentCoder(int index) { + return keyComponentCoders.get(index); + } + + /** Returns the value coder. 
*/ + public Coder getValueCoder() { + return valueCoder; + } + + @Override + public void encode(IsmRecord value, OutputStream outStream, + Coder.Context context) throws CoderException, IOException { + if (value.getKeyComponents().size() != keyComponentCoders.size()) { + throw new CoderException(String.format( + "Expected %s key component(s) but received key component(s) %s.", + keyComponentCoders.size(), value.getKeyComponents())); + } + for (int i = 0; i < keyComponentCoders.size(); ++i) { + getKeyComponentCoder(i).encode(value.getKeyComponent(i), outStream, context.nested()); + } + if (isMetadataKey(value.getKeyComponents())) { + ByteArrayCoder.of().encode(value.getMetadata(), outStream, context.nested()); + } else { + valueCoder.encode(value.getValue(), outStream, context.nested()); + } + } + + @Override + public IsmRecord decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + List keyComponents = new ArrayList<>(keyComponentCoders.size()); + for (Coder keyCoder : keyComponentCoders) { + keyComponents.add(keyCoder.decode(inStream, context.nested())); + } + if (isMetadataKey(keyComponents)) { + return IsmRecord.meta( + keyComponents, ByteArrayCoder.of().decode(inStream, context.nested())); + } else { + return IsmRecord.of(keyComponents, valueCoder.decode(inStream, context.nested())); + } + } + + int getNumberOfShardKeyCoders(List keyComponents) { + if (isMetadataKey(keyComponents)) { + return numberOfMetadataShardKeyCoders; + } else { + return numberOfShardKeyCoders; + } + } + + /** + * Computes the shard id for the given key component(s). + * + * The shard keys are encoded into their byte representations and hashed using the + * + * 32-bit murmur3 algorithm, x86 variant (little-endian variant), + * using {@code 1225801234} as the seed value. We ensure that shard ids for + * metadata keys and normal keys do not overlap. + */ + public int hash(List keyComponents) { + return encodeAndHash(keyComponents, new RandomAccessData(), new ArrayList()); + } + + /** + * Computes the shard id for the given key component(s). + * + * Mutates {@code keyBytes} such that when returned, contains the encoded + * version of the key components. + */ + int encodeAndHash(List keyComponents, RandomAccessData keyBytesToMutate) { + return encodeAndHash(keyComponents, keyBytesToMutate, new ArrayList()); + } + + /** + * Computes the shard id for the given key component(s). + * + * Mutates {@code keyBytes} such that when returned, contains the encoded + * version of the key components. Also, mutates {@code keyComponentByteOffsetsToMutate} to + * store the location where each key component's encoded byte representation ends within + * {@code keyBytes}. 
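To make the hashing scheme above concrete, here is an editorial sketch (not code from this diff) of how a shard id is derived once the shard key components have been encoded to bytes; it uses only Guava's hashing utilities, which this file already imports, and a stand-in byte array in place of the nested key encoding:

    // Sketch of the shard-id scheme: murmur3_32 (x86 variant) seeded with 1225801234,
    // masked to the low 7 bits, with metadata shards offset into [128, 255].
    HashFunction hashFunction = Hashing.murmur3_32(1225801234);
    byte[] encodedShardKeyBytes = "user42".getBytes(StandardCharsets.UTF_8); // stand-in bytes
    int shardId = hashFunction.hashBytes(encodedShardKeyBytes).asInt() & 0x7F; // value records: [0, 127]
    int metadataShardId = shardId + 0x7F + 1;                                  // metadata records: [128, 255]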
+ */ + int encodeAndHash( + List keyComponents, + RandomAccessData keyBytesToMutate, + List keyComponentByteOffsetsToMutate) { + checkNotNull(keyComponents); + checkArgument(keyComponents.size() <= keyComponentCoders.size(), + "Expected at most %s key component(s) but received %s.", + keyComponentCoders.size(), keyComponents); + + final int numberOfKeyCodersToUse; + final int shardOffset; + if (isMetadataKey(keyComponents)) { + numberOfKeyCodersToUse = numberOfMetadataShardKeyCoders; + shardOffset = SHARD_BITS + 1; + } else { + numberOfKeyCodersToUse = numberOfShardKeyCoders; + shardOffset = 0; + } + + checkArgument(numberOfKeyCodersToUse <= keyComponents.size(), + "Expected at least %s key component(s) but received %s.", + numberOfShardKeyCoders, keyComponents); + + try { + // Encode the shard portion + for (int i = 0; i < numberOfKeyCodersToUse; ++i) { + getKeyComponentCoder(i).encode( + keyComponents.get(i), keyBytesToMutate.asOutputStream(), Context.NESTED); + keyComponentByteOffsetsToMutate.add(keyBytesToMutate.size()); + } + int rval = HASH_FUNCTION.hashBytes( + keyBytesToMutate.array(), 0, keyBytesToMutate.size()).asInt() & SHARD_BITS; + rval += shardOffset; + + // Encode the remainder + for (int i = numberOfKeyCodersToUse; i < keyComponents.size(); ++i) { + getKeyComponentCoder(i).encode( + keyComponents.get(i), keyBytesToMutate.asOutputStream(), Context.NESTED); + keyComponentByteOffsetsToMutate.add(keyBytesToMutate.size()); + } + return rval; + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to hash %s with coder %s", keyComponents, this), e); + } + } + + @Override + public List> getCoderArguments() { + return ImmutableList.>builder() + .addAll(keyComponentCoders) + .add(valueCoder) + .build(); + } + + @Override + public CloudObject asCloudObject() { + CloudObject cloudObject = super.asCloudObject(); + addLong(cloudObject, PropertyNames.NUM_SHARD_CODERS, numberOfShardKeyCoders); + addLong(cloudObject, PropertyNames.NUM_METADATA_SHARD_CODERS, numberOfMetadataShardKeyCoders); + return cloudObject; + } + + @Override + public void verifyDeterministic() throws Coder.NonDeterministicException { + verifyDeterministic("Key component coders expected to be deterministic.", keyComponentCoders); + verifyDeterministic("Value coder expected to be deterministic.", valueCoder); + } + + @Override + public boolean consistentWithEquals() { + for (Coder keyComponentCoder : keyComponentCoders) { + if (!keyComponentCoder.consistentWithEquals()) { + return false; + } + } + return valueCoder.consistentWithEquals(); + } + + @Override + public Object structuralValue(IsmRecord record) throws Exception { + checkState(record.getKeyComponents().size() == keyComponentCoders.size(), + "Expected the number of key component coders %s " + + "to match the number of key components %s.", + keyComponentCoders.size(), record.getKeyComponents()); + + if (record != null && consistentWithEquals()) { + ArrayList keyComponentStructuralValues = new ArrayList<>(); + for (int i = 0; i < keyComponentCoders.size(); ++i) { + keyComponentStructuralValues.add( + getKeyComponentCoder(i).structuralValue(record.getKeyComponent(i))); + } + if (isMetadataKey(record.getKeyComponents())) { + return IsmRecord.meta(keyComponentStructuralValues, record.getMetadata()); + } else { + return IsmRecord.of(keyComponentStructuralValues, + valueCoder.structuralValue(record.getValue())); + } + } + return super.structuralValue(record); + } + } + + /** + * Validates that the key portion of the given coder is 
deterministic. + */ + static void validateCoderIsCompatible(IsmRecordCoder coder) { + for (Coder keyComponentCoder : coder.getKeyComponentCoders()) { + try { + keyComponentCoder.verifyDeterministic(); + } catch (NonDeterministicException e) { + throw new IllegalArgumentException( + String.format("Key component coder %s is expected to be deterministic.", + keyComponentCoder), e); + } + } + } + + /** Returns true if and only if any of the passed in key components represent a metadata key. */ + public static boolean isMetadataKey(List keyComponents) { + for (Object keyComponent : keyComponents) { + if (keyComponent == METADATA_KEY) { + return true; + } + } + return false; + } + + /** A marker object representing the wildcard metadata key component. */ + private static final Object METADATA_KEY = new Object() { + @Override + public String toString() { + return "META"; + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + @Override + public int hashCode() { + return -1248902349; + } + }; + + /** + * An object representing a wild card for a key component. + * Encoded using {@link MetadataKeyCoder}. + */ + public static Object getMetadataKey() { + return METADATA_KEY; + } + + /** + * A coder for metadata key component. Can be used to wrap key component coder allowing for + * the metadata key component to be used as a place holder instead of an actual key. + */ + public static class MetadataKeyCoder extends StandardCoder { + public static MetadataKeyCoder of(Coder keyCoder) { + checkNotNull(keyCoder); + return new MetadataKeyCoder<>(keyCoder); + } + + /** + * Returns an IsmRecordCoder with the specified coders. Note that this method is not meant + * to be called by users but used by Jackson when decoding this coder. + */ + @JsonCreator + public static MetadataKeyCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting one component, got " + components.size()); + return of(components.get(0)); + } + + private final Coder keyCoder; + + private MetadataKeyCoder(Coder keyCoder) { + this.keyCoder = keyCoder; + } + + public Coder getKeyCoder() { + return keyCoder; + } + + @Override + public void encode(K value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + if (value == METADATA_KEY) { + outStream.write(0); + } else { + outStream.write(1); + keyCoder.encode(value, outStream, context.nested()); + } + } + + @Override + public K decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + int marker = inStream.read(); + if (marker == 0) { + return (K) getMetadataKey(); + } else if (marker == 1) { + return keyCoder.decode(inStream, context.nested()); + } else { + throw new CoderException(String.format("Expected marker but got %s.", marker)); + } + } + + @Override + public List> getCoderArguments() { + return ImmutableList.>of(keyCoder); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic("Expected key coder to be deterministic", keyCoder); + } + } + + /** + * A shard descriptor containing shard id, the data block offset, and the index offset for the + * given shard. + */ + public static class IsmShard { + private final int id; + private final long blockOffset; + private final long indexOffset; + + /** Returns an IsmShard with the given id, block offset and no index offset. 
*/ + public static IsmShard of(int id, long blockOffset) { + IsmShard ismShard = new IsmShard(id, blockOffset, -1); + checkState(id >= 0, + "%s attempting to be written with negative shard id.", + ismShard); + checkState(blockOffset >= 0, + "%s attempting to be written with negative block offset.", + ismShard); + return ismShard; + } + + /** Returns an IsmShard with the given id, block offset, and index offset. */ + public static IsmShard of(int id, long blockOffset, long indexOffset) { + IsmShard ismShard = new IsmShard(id, blockOffset, indexOffset); + checkState(id >= 0, + "%s attempting to be written with negative shard id.", + ismShard); + checkState(blockOffset >= 0, + "%s attempting to be written with negative block offset.", + ismShard); + checkState(indexOffset >= 0, + "%s attempting to be written with negative index offset.", + ismShard); + return ismShard; + } + + private IsmShard(int id, long blockOffset, long indexOffset) { + this.id = id; + this.blockOffset = blockOffset; + this.indexOffset = indexOffset; + } + + /** Return the shard id. */ + public int getId() { + return id; + } + + /** Return the absolute position within the Ism file where the data block begins. */ + public long getBlockOffset() { + return blockOffset; + } + + /** + * Return the absolute position within the Ism file where the index block begins. + * Throws {@link IllegalStateException} if the index offset was never specified. + */ + public long getIndexOffset() { + checkState(indexOffset >= 0, + "Unable to fetch index offset because it was never specified."); + return indexOffset; + } + + /** Returns a new IsmShard like this one with the specified index offset. */ + public IsmShard withIndexOffset(long indexOffset) { + return of(id, blockOffset, indexOffset); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(IsmShard.class) + .add("id", id) + .add("blockOffset", blockOffset) + .add("indexOffset", indexOffset) + .toString(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof IsmShard)) { + return false; + } + IsmShard other = (IsmShard) obj; + return Objects.equal(id, other.id) + && Objects.equal(blockOffset, other.blockOffset) + && Objects.equal(indexOffset, other.indexOffset); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, blockOffset, indexOffset); + } + } + + /** + * A {@link ListCoder} wrapping a {@link IsmShardCoder} used to encode the shard index. + * See {@link ListCoder} for its encoding specification and {@link IsmShardCoder} for its + * encoding specification. + */ + public static final Coder> ISM_SHARD_INDEX_CODER = + ListCoder.of(IsmShardCoder.of()); + + /** + * A coder for {@link IsmShard}s. + * + * The shard descriptor is encoded as: + *
 + * <ul>
 + *   <li>id (variable length integer encoding)</li>
 + *   <li>blockOffset (variable length long encoding)</li>
 + *   <li>indexOffset (variable length long encoding)</li>
 + * </ul>
    + */ + public static class IsmShardCoder extends AtomicCoder { + private static final IsmShardCoder INSTANCE = new IsmShardCoder(); + + /** Returns an IsmShardCoder. */ + @JsonCreator + public static IsmShardCoder of() { + return INSTANCE; + } + + private IsmShardCoder() { + } + + @Override + public void encode(IsmShard value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + checkState(value.getIndexOffset() >= 0, + "%s attempting to be written without index offset.", + value); + VarIntCoder.of().encode(value.getId(), outStream, context.nested()); + VarLongCoder.of().encode(value.getBlockOffset(), outStream, context.nested()); + VarLongCoder.of().encode(value.getIndexOffset(), outStream, context.nested()); + } + + @Override + public IsmShard decode( + InputStream inStream, Coder.Context context) throws CoderException, IOException { + return IsmShard.of( + VarIntCoder.of().decode(inStream, context), + VarLongCoder.of().decode(inStream, context), + VarLongCoder.of().decode(inStream, context)); + } + + @Override + public boolean consistentWithEquals() { + return true; + } + } + + /** + * The prefix used before each key which contains the number of shared and unshared + * bytes from the previous key that was read. The key prefix along with the previous key + * and the unshared key bytes allows one to construct the current key by doing the following + * {@code currentKey = previousKey[0 : sharedBytes] + read(unsharedBytes)}. + * + *
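The reconstruction rule quoted above can be sketched as a small reader-side helper; this is an editorial illustration (the readKey helper and the use of Guava's ByteStreams are assumptions, not part of the diff), relying on the KeyPrefix accessors defined later in this section:

    // Hypothetical helper for: currentKey = previousKey[0 : sharedBytes] + read(unsharedBytes)
    static byte[] readKey(byte[] previousKey, KeyPrefix prefix, InputStream in) throws IOException {
      byte[] currentKey = new byte[prefix.getSharedKeySize() + prefix.getUnsharedKeySize()];
      // Reuse the shared prefix of the previously decoded key...
      System.arraycopy(previousKey, 0, currentKey, 0, prefix.getSharedKeySize());
      // ...then append the unshared bytes read from the stream.
      ByteStreams.readFully(in, currentKey, prefix.getSharedKeySize(), prefix.getUnsharedKeySize());
      return currentKey;
    }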

    The key prefix is encoded as:
 + * <ul>
 + *   <li>number of shared key bytes (variable length integer coding)</li>
 + *   <li>number of unshared key bytes (variable length integer coding)</li>
 + * </ul>
    + */ + static class KeyPrefix { + private final int sharedKeySize; + private final int unsharedKeySize; + + KeyPrefix(int sharedBytes, int unsharedBytes) { + this.sharedKeySize = sharedBytes; + this.unsharedKeySize = unsharedBytes; + } + + public int getSharedKeySize() { + return sharedKeySize; + } + + public int getUnsharedKeySize() { + return unsharedKeySize; + } + + @Override + public int hashCode() { + return Objects.hashCode(sharedKeySize, unsharedKeySize); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof KeyPrefix)) { + return false; + } + KeyPrefix keyPrefix = (KeyPrefix) other; + return sharedKeySize == keyPrefix.sharedKeySize + && unsharedKeySize == keyPrefix.unsharedKeySize; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("sharedKeySize", sharedKeySize) + .add("unsharedKeySize", unsharedKeySize) + .toString(); + } + } + + /** A {@link Coder} for {@link KeyPrefix}. */ + static final class KeyPrefixCoder extends AtomicCoder { + private static final KeyPrefixCoder INSTANCE = new KeyPrefixCoder(); + + @JsonCreator + public static KeyPrefixCoder of() { + return INSTANCE; + } + + @Override + public void encode(KeyPrefix value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + VarInt.encode(value.sharedKeySize, outStream); + VarInt.encode(value.unsharedKeySize, outStream); + } + + @Override + public KeyPrefix decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + return new KeyPrefix(VarInt.decodeInt(inStream), VarInt.decodeInt(inStream)); + } + + @Override + public boolean consistentWithEquals() { + return true; + } + + @Override + public boolean isRegisterByteSizeObserverCheap(KeyPrefix value, Coder.Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(KeyPrefix value, Coder.Context context) + throws Exception { + Preconditions.checkNotNull(value); + return VarInt.getLength(value.sharedKeySize) + VarInt.getLength(value.unsharedKeySize); + } + } + + /** + * The footer stores the relevant information required to locate the index and bloom filter. + * It also stores a version byte and the number of keys stored. + * + *

    The footer is encoded as the value containing:
 + * <ul>
 + *   <li>start of bloom filter offset (big endian long coding)</li>
 + *   <li>start of shard index position offset (big endian long coding)</li>
 + *   <li>number of keys in file (big endian long coding)</li>
 + *   <li>0x02 (version, as a single byte; see {@code Footer#VERSION} and the sketch following this list)</li>
 + * </ul>
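As referenced in the list above, an editorial sketch of the fixed-size footer encoding (package-private types, viewed as if from within this package; not part of the diff):

    // A Footer always encodes to Footer.FIXED_LENGTH = 3 * 8 + 1 = 25 bytes.
    Footer footer = new Footer(/* indexPosition */ 4096L, /* bloomFilterPosition */ 8192L, /* numberOfKeys */ 100L);
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    FooterCoder.of().encode(footer, out, Coder.Context.OUTER);
    // out.size() == 25: three big-endian longs followed by the version byte (0x02).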
    + */ + static class Footer { + static final int LONG_BYTES = 8; + static final int FIXED_LENGTH = 3 * LONG_BYTES + 1; + static final byte VERSION = 2; + + private final long indexPosition; + private final long bloomFilterPosition; + private final long numberOfKeys; + + Footer(long indexPosition, long bloomFilterPosition, long numberOfKeys) { + this.indexPosition = indexPosition; + this.bloomFilterPosition = bloomFilterPosition; + this.numberOfKeys = numberOfKeys; + } + + public long getIndexPosition() { + return indexPosition; + } + + public long getBloomFilterPosition() { + return bloomFilterPosition; + } + + public long getNumberOfKeys() { + return numberOfKeys; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof Footer)) { + return false; + } + Footer footer = (Footer) other; + return indexPosition == footer.indexPosition + && bloomFilterPosition == footer.bloomFilterPosition + && numberOfKeys == footer.numberOfKeys; + } + + @Override + public int hashCode() { + return Objects.hashCode(indexPosition, bloomFilterPosition, numberOfKeys); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("version", Footer.VERSION) + .add("indexPosition", indexPosition) + .add("bloomFilterPosition", bloomFilterPosition) + .add("numberOfKeys", numberOfKeys) + .toString(); + } + } + + /** A {@link Coder} for {@link Footer}. */ + static final class FooterCoder extends AtomicCoder
    { + private static final FooterCoder INSTANCE = new FooterCoder(); + + @JsonCreator + public static FooterCoder of() { + return INSTANCE; + } + + @Override + public void encode(Footer value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + DataOutputStream dataOut = new DataOutputStream(outStream); + dataOut.writeLong(value.indexPosition); + dataOut.writeLong(value.bloomFilterPosition); + dataOut.writeLong(value.numberOfKeys); + dataOut.write(Footer.VERSION); + } + + @Override + public Footer decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + DataInputStream dataIn = new DataInputStream(inStream); + Footer footer = new Footer(dataIn.readLong(), dataIn.readLong(), dataIn.readLong()); + int version = dataIn.read(); + if (version != Footer.VERSION) { + throw new IOException("Unknown version " + version + ". " + + "Only version 2 is currently supported."); + } + return footer; + } + + @Override + public boolean consistentWithEquals() { + return true; + } + + @Override + public boolean isRegisterByteSizeObserverCheap(Footer value, Coder.Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(Footer value, Coder.Context context) + throws Exception { + return Footer.FIXED_LENGTH; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/package-info.java new file mode 100644 index 000000000000..af0a345d6d95 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/worker/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Implementation of the harness that runs on each Google Compute Engine instance to coordinate + * execution of Pipeline code. + */ +@ParametersAreNonnullByDefault +package com.google.cloud.dataflow.sdk.runners.worker; + +import javax.annotation.ParametersAreNonnullByDefault; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/CoderProperties.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/CoderProperties.java new file mode 100644 index 000000000000..5705dc4c78b6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/CoderProperties.java @@ -0,0 +1,349 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.util.Structs; +import com.google.cloud.dataflow.sdk.util.UnownedInputStream; +import com.google.cloud.dataflow.sdk.util.UnownedOutputStream; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Properties for use in {@link Coder} tests. These are implemented with junit assertions + * rather than as predicates for the sake of error messages. + * + *
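For orientation, a sketch of how these properties are typically exercised from a JUnit test; this is an editorial example, not part of the diff, and VarIntCoder is an arbitrary choice of coder assumed to live in the SDK's coders package:

    // Hypothetical test exercising a few of the properties defined in this class.
    import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
    import org.junit.Test;

    public class VarIntCoderPropertiesTest {
      @Test
      public void coderProperties() throws Exception {
        // Round-trips the value through a serialized copy of the coder in every Context.
        CoderProperties.coderDecodeEncodeEqual(VarIntCoder.of(), 42);
        // Equal values must produce equal encodings for a deterministic coder.
        CoderProperties.coderDeterministic(VarIntCoder.of(), 42, 42);
        // The coder itself must survive Java serialization.
        CoderProperties.coderSerializable(VarIntCoder.of());
      }
    }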

    We serialize and deserialize the coder to make sure that any state information required by + * the coder is preserved. This causes tests written such that coders that lose information during + * serialization or change state during encoding/decoding will fail. + */ +public class CoderProperties { + + /** + * All the contexts, for use in test cases. + */ + public static final List ALL_CONTEXTS = Arrays.asList( + Coder.Context.OUTER, Coder.Context.NESTED); + + /** + * Verifies that for the given {@code Coder}, and values of + * type {@code T}, if the values are equal then the encoded bytes are equal, in any + * {@code Coder.Context}. + */ + public static void coderDeterministic( + Coder coder, T value1, T value2) + throws Exception { + for (Coder.Context context : ALL_CONTEXTS) { + coderDeterministicInContext(coder, context, value1, value2); + } + } + + /** + * Verifies that for the given {@code Coder}, {@code Coder.Context}, and values of + * type {@code T}, if the values are equal then the encoded bytes are equal. + */ + public static void coderDeterministicInContext( + Coder coder, Coder.Context context, T value1, T value2) + throws Exception { + + try { + coder.verifyDeterministic(); + } catch (NonDeterministicException e) { + fail("Expected that the coder is deterministic"); + } + assertThat("Expected that the passed in values are equal()", value1, equalTo(value2)); + assertThat( + encode(coder, context, value1), + equalTo(encode(coder, context, value2))); + } + + /** + * Verifies that for the given {@code Coder}, + * and value of type {@code T}, encoding followed by decoding yields an + * equal value of type {@code T}, in any {@code Coder.Context}. + */ + public static void coderDecodeEncodeEqual( + Coder coder, T value) + throws Exception { + for (Coder.Context context : ALL_CONTEXTS) { + coderDecodeEncodeEqualInContext(coder, context, value); + } + } + + /** + * Verifies that for the given {@code Coder}, {@code Coder.Context}, + * and value of type {@code T}, encoding followed by decoding yields an + * equal value of type {@code T}. + */ + public static void coderDecodeEncodeEqualInContext( + Coder coder, Coder.Context context, T value) + throws Exception { + assertThat(decodeEncode(coder, context, value), equalTo(value)); + } + + /** + * Verifies that for the given {@code Coder>}, + * and value of type {@code Collection}, encoding followed by decoding yields an + * equal value of type {@code Collection}, in any {@code Coder.Context}. + */ + public static > void coderDecodeEncodeContentsEqual( + Coder coder, CollectionT value) + throws Exception { + for (Coder.Context context : ALL_CONTEXTS) { + coderDecodeEncodeContentsEqualInContext(coder, context, value); + } + } + + /** + * Verifies that for the given {@code Coder>}, + * and value of type {@code Collection}, encoding followed by decoding yields an + * equal value of type {@code Collection}, in the given {@code Coder.Context}. 
+ */ + @SuppressWarnings("unchecked") + public static > void coderDecodeEncodeContentsEqualInContext( + Coder coder, Coder.Context context, CollectionT value) + throws Exception { + // Matchers.containsInAnyOrder() requires at least one element + Collection result = decodeEncode(coder, context, value); + if (value.isEmpty()) { + assertThat(result, emptyIterable()); + } else { + // This is the only Matchers.containInAnyOrder() overload that takes literal values + assertThat(result, containsInAnyOrder((T[]) value.toArray())); + } + } + + /** + * Verifies that for the given {@code Coder>}, + * and value of type {@code Collection}, encoding followed by decoding yields an + * equal value of type {@code Collection}, in any {@code Coder.Context}. + */ + public static > void coderDecodeEncodeContentsInSameOrder( + Coder coder, IterableT value) + throws Exception { + for (Coder.Context context : ALL_CONTEXTS) { + CoderProperties.coderDecodeEncodeContentsInSameOrderInContext( + coder, context, value); + } + } + + /** + * Verifies that for the given {@code Coder>}, + * and value of type {@code Iterable}, encoding followed by decoding yields an + * equal value of type {@code Collection}, in the given {@code Coder.Context}. + */ + @SuppressWarnings("unchecked") + public static > void + coderDecodeEncodeContentsInSameOrderInContext( + Coder coder, Coder.Context context, IterableT value) + throws Exception { + Iterable result = decodeEncode(coder, context, value); + // Matchers.contains() requires at least one element + if (Iterables.isEmpty(value)) { + assertThat(result, emptyIterable()); + } else { + // This is the only Matchers.contains() overload that takes literal values + assertThat(result, contains((T[]) Iterables.toArray(value, Object.class))); + } + } + + public static void coderSerializable(Coder coder) { + SerializableUtils.ensureSerializable(coder); + } + + public static void coderConsistentWithEquals( + Coder coder, T value1, T value2) + throws Exception { + + for (Coder.Context context : ALL_CONTEXTS) { + CoderProperties.coderConsistentWithEqualsInContext(coder, context, value1, value2); + } + } + + public static void coderConsistentWithEqualsInContext( + Coder coder, Coder.Context context, T value1, T value2) throws Exception { + + assertEquals( + value1.equals(value2), + Arrays.equals( + encode(coder, context, value1), + encode(coder, context, value2))); + } + + public static void coderHasEncodingId(Coder coder, String encodingId) throws Exception { + assertThat(coder.getEncodingId(), equalTo(encodingId)); + assertThat(Structs.getString(coder.asCloudObject(), PropertyNames.ENCODING_ID, ""), + equalTo(encodingId)); + } + + public static void coderAllowsEncoding(Coder coder, String encodingId) throws Exception { + assertThat(coder.getAllowedEncodings(), hasItem(encodingId)); + assertThat( + String.format("Expected to find \"%s\" in property \"%s\" of %s", + encodingId, PropertyNames.ALLOWED_ENCODINGS, coder.asCloudObject()), + Structs.getStrings( + coder.asCloudObject(), + PropertyNames.ALLOWED_ENCODINGS, + Collections.emptyList()), + hasItem(encodingId)); + } + + public static void structuralValueConsistentWithEquals( + Coder coder, T value1, T value2) + throws Exception { + + for (Coder.Context context : ALL_CONTEXTS) { + CoderProperties.structuralValueConsistentWithEqualsInContext( + coder, context, value1, value2); + } + } + + public static void structuralValueConsistentWithEqualsInContext( + Coder coder, Coder.Context context, T value1, T value2) throws Exception { + + assertEquals( + 
coder.structuralValue(value1).equals(coder.structuralValue(value2)), + Arrays.equals( + encode(coder, context, value1), + encode(coder, context, value2))); + } + + + private static final String DECODING_WIRE_FORMAT_MESSAGE = + "Decoded value from known wire format does not match expected value." + + " This probably means that this Coder no longer correctly decodes" + + " a prior wire format. Changing the wire formats this Coder can read" + + " should be avoided, as it is likely to cause breakage." + + " If you truly intend to change the backwards compatibility for this Coder " + + " then you must remove any now-unsupported encodings from getAllowedEncodings()."; + + public static void coderDecodesBase64(Coder coder, String base64Encoding, T value) + throws Exception { + assertThat(DECODING_WIRE_FORMAT_MESSAGE, CoderUtils.decodeFromBase64(coder, base64Encoding), + equalTo(value)); + } + + public static void coderDecodesBase64( + Coder coder, List base64Encodings, List values) throws Exception { + assertThat("List of base64 encodings has different size than List of values", + base64Encodings.size(), equalTo(values.size())); + + for (int i = 0; i < base64Encodings.size(); i++) { + coderDecodesBase64(coder, base64Encodings.get(i), values.get(i)); + } + } + + private static final String ENCODING_WIRE_FORMAT_MESSAGE = + "Encoded value does not match expected wire format." + + " Changing the wire format should be avoided, as it is likely to cause breakage." + + " If you truly intend to change the wire format for this Coder " + + " then you must update getEncodingId() to a new value and add any supported" + + " prior formats to getAllowedEncodings()." + + " See com.google.cloud.dataflow.sdk.coders.PrintBase64Encoding for how to generate" + + " new test data."; + + public static void coderEncodesBase64(Coder coder, T value, String base64Encoding) + throws Exception { + assertThat(ENCODING_WIRE_FORMAT_MESSAGE, CoderUtils.encodeToBase64(coder, value), + equalTo(base64Encoding)); + } + + public static void coderEncodesBase64( + Coder coder, List values, List base64Encodings) throws Exception { + assertThat("List of base64 encodings has different size than List of values", + base64Encodings.size(), equalTo(values.size())); + + for (int i = 0; i < base64Encodings.size(); i++) { + coderEncodesBase64(coder, values.get(i), base64Encodings.get(i)); + } + } + + @SuppressWarnings("unchecked") + public static > void coderDecodesBase64ContentsEqual( + Coder coder, String base64Encoding, IterableT expected) throws Exception { + + IterableT result = CoderUtils.decodeFromBase64(coder, base64Encoding); + if (Iterables.isEmpty(expected)) { + assertThat(ENCODING_WIRE_FORMAT_MESSAGE, result, emptyIterable()); + } else { + assertThat(ENCODING_WIRE_FORMAT_MESSAGE, result, + containsInAnyOrder((T[]) Iterables.toArray(expected, Object.class))); + } + } + + public static > void coderDecodesBase64ContentsEqual( + Coder coder, List base64Encodings, List expected) + throws Exception { + assertThat("List of base64 encodings has different size than List of values", + base64Encodings.size(), equalTo(expected.size())); + + for (int i = 0; i < base64Encodings.size(); i++) { + coderDecodesBase64ContentsEqual(coder, base64Encodings.get(i), expected.get(i)); + } + } + + ////////////////////////////////////////////////////////////////////////// + + @VisibleForTesting + static byte[] encode( + Coder coder, Coder.Context context, T value) throws CoderException, IOException { + @SuppressWarnings("unchecked") + Coder deserializedCoder = 
Serializer.deserialize(coder.asCloudObject(), Coder.class); + + ByteArrayOutputStream os = new ByteArrayOutputStream(); + deserializedCoder.encode(value, new UnownedOutputStream(os), context); + return os.toByteArray(); + } + + @VisibleForTesting + static T decode( + Coder coder, Coder.Context context, byte[] bytes) throws CoderException, IOException { + @SuppressWarnings("unchecked") + Coder deserializedCoder = Serializer.deserialize(coder.asCloudObject(), Coder.class); + + ByteArrayInputStream is = new ByteArrayInputStream(bytes); + return deserializedCoder.decode(new UnownedInputStream(is), context); + } + + private static T decodeEncode(Coder coder, Coder.Context context, T value) + throws CoderException, IOException { + return decode(coder, context, encode(coder, context, value)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/DataflowAssert.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/DataflowAssert.java new file mode 100644 index 000000000000..6c9643cc3c0e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/DataflowAssert.java @@ -0,0 +1,825 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.StreamingOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.common.base.Optional; +import com.google.common.collect.Iterables; +import 
com.google.common.collect.Lists; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * An assertion on the contents of a {@link PCollection} + * incorporated into the pipeline. Such an assertion + * can be checked no matter what kind of {@link PipelineRunner} is + * used. + * + *

    Note that the {@code DataflowAssert} call must precede the call + * to {@link Pipeline#run}. + * + *

    Examples of use:
 + * <pre>{@code
 + * Pipeline p = TestPipeline.create();
 + * ...
 + * PCollection<String> output =
 + *      input
 + *      .apply(ParDo.of(new TestDoFn()));
 + * DataflowAssert.that(output)
 + *     .containsInAnyOrder("out1", "out2", "out3");
 + * ...
 + * PCollection<Integer> ints = ...
 + * PCollection<Integer> sum =
 + *     ints
 + *     .apply(Combine.globally(new SumInts()));
 + * DataflowAssert.that(sum)
 + *     .is(42);
 + * ...
 + * p.run();
 + * }</pre>
 + *

    JUnit and Hamcrest must be linked in by any code that uses DataflowAssert. + */ +public class DataflowAssert { + + private static final Logger LOG = LoggerFactory.getLogger(DataflowAssert.class); + + static final String SUCCESS_COUNTER = "DataflowAssertSuccess"; + static final String FAILURE_COUNTER = "DataflowAssertFailure"; + + private static int assertCount = 0; + + // Do not instantiate. + private DataflowAssert() {} + + /** + * Constructs an {@link IterableAssert} for the elements of the provided + * {@link PCollection}. + */ + public static IterableAssert that(PCollection actual) { + return new IterableAssert<>( + new CreateActual>(actual, View.asIterable()), + actual.getPipeline()) + .setCoder(actual.getCoder()); + } + + /** + * Constructs an {@link IterableAssert} for the value of the provided + * {@link PCollection} which must contain a single {@code Iterable} + * value. + */ + public static IterableAssert + thatSingletonIterable(PCollection> actual) { + + List> maybeElementCoder = actual.getCoder().getCoderArguments(); + Coder tCoder; + try { + @SuppressWarnings("unchecked") + Coder tCoderTmp = (Coder) Iterables.getOnlyElement(maybeElementCoder); + tCoder = tCoderTmp; + } catch (NoSuchElementException | IllegalArgumentException exc) { + throw new IllegalArgumentException( + "DataflowAssert.thatSingletonIterable requires a PCollection>" + + " with a Coder> where getCoderArguments() yields a" + + " single Coder to apply to the elements."); + } + + @SuppressWarnings("unchecked") // Safe covariant cast + PCollection> actualIterables = (PCollection>) actual; + + return new IterableAssert<>( + new CreateActual, Iterable>( + actualIterables, View.>asSingleton()), + actual.getPipeline()) + .setCoder(tCoder); + } + + /** + * Constructs an {@link IterableAssert} for the value of the provided + * {@code PCollectionView PCollectionView>}. + */ + public static IterableAssert thatIterable(PCollectionView> actual) { + return new IterableAssert<>(new PreExisting>(actual), actual.getPipeline()); + } + + /** + * Constructs a {@link SingletonAssert} for the value of the provided + * {@code PCollection PCollection}, which must be a singleton. + */ + public static SingletonAssert thatSingleton(PCollection actual) { + return new SingletonAssert<>( + new CreateActual(actual, View.asSingleton()), actual.getPipeline()) + .setCoder(actual.getCoder()); + } + + /** + * Constructs a {@link SingletonAssert} for the value of the provided {@link PCollection}. + * + *

    Note that the actual value must be coded by a {@link KvCoder}, + * not just any {@code Coder}. + */ + public static SingletonAssert>> + thatMultimap(PCollection> actual) { + @SuppressWarnings("unchecked") + KvCoder kvCoder = (KvCoder) actual.getCoder(); + + return new SingletonAssert<>( + new CreateActual<>(actual, View.asMultimap()), actual.getPipeline()) + .setCoder(MapCoder.of(kvCoder.getKeyCoder(), IterableCoder.of(kvCoder.getValueCoder()))); + } + + /** + * Constructs a {@link SingletonAssert} for the value of the provided {@link PCollection}, + * which must have at most one value per key. + * + *

    Note that the actual value must be coded by a {@link KvCoder}, + * not just any {@code Coder}. + */ + public static SingletonAssert> thatMap(PCollection> actual) { + @SuppressWarnings("unchecked") + KvCoder kvCoder = (KvCoder) actual.getCoder(); + + return new SingletonAssert<>( + new CreateActual<>(actual, View.asMap()), actual.getPipeline()) + .setCoder(MapCoder.of(kvCoder.getKeyCoder(), kvCoder.getValueCoder())); + } + + //////////////////////////////////////////////////////////// + + /** + * An assertion about the contents of a {@link PCollectionView} yielding an {@code Iterable}. + */ + public static class IterableAssert implements Serializable { + private final Pipeline pipeline; + private final PTransform>> createActual; + private Optional> coder; + + protected IterableAssert( + PTransform>> createActual, Pipeline pipeline) { + this.createActual = createActual; + this.pipeline = pipeline; + this.coder = Optional.absent(); + } + + /** + * Sets the coder to use for elements of type {@code T}, as needed for internal purposes. + * + *
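A brief editorial sketch (not part of the diff) of how the thatMap entry point defined above might be used; the pipeline contents are arbitrary and Guava's ImmutableMap is assumed:

    // Hypothetical usage of DataflowAssert.thatMap with an explicit KvCoder.
    Pipeline p = TestPipeline.create();
    PCollection<KV<String, Integer>> counts = p
        .apply(Create.of(KV.of("a", 2), KV.of("b", 1))
            .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())));
    DataflowAssert.thatMap(counts)
        .isEqualTo(ImmutableMap.of("a", 2, "b", 1));
    p.run();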

    Returns this {@code IterableAssert}. + */ + public IterableAssert setCoder(Coder coderOrNull) { + this.coder = Optional.fromNullable(coderOrNull); + return this; + } + + /** + * Gets the coder, which may yet be absent. + */ + public Coder getCoder() { + if (coder.isPresent()) { + return coder.get(); + } else { + throw new IllegalStateException( + "Attempting to access the coder of an IterableAssert" + + " that has not been set yet."); + } + } + + /** + * Applies a {@link SerializableFunction} to check the elements of the {@code Iterable}. + * + *

    Returns this {@code IterableAssert}. + */ + public IterableAssert satisfies(SerializableFunction, Void> checkerFn) { + pipeline.apply( + "DataflowAssert$" + (assertCount++), + new OneSideInputAssert>(createActual, checkerFn)); + return this; + } + + /** + * Applies a {@link SerializableFunction} to check the elements of the {@code Iterable}. + * + *
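As an editorial illustration of the checker-function form just defined (not part of the diff; it reuses the output collection from the class-level example and assumes Hamcrest's startsWith matcher):

    // Hypothetical: verify a structural property of the elements rather than exact contents.
    DataflowAssert.that(output).satisfies(
        new SerializableFunction<Iterable<String>, Void>() {
          @Override
          public Void apply(Iterable<String> actual) {
            for (String s : actual) {
              org.junit.Assert.assertThat(s, org.hamcrest.Matchers.startsWith("out"));
            }
            return null; // returning null signals success; a thrown AssertionError fails the assertion
          }
        });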

    Returns this {@code IterableAssert}. + */ + public IterableAssert satisfies( + AssertRelation, Iterable> relation, + final Iterable expectedElements) { + pipeline.apply( + "DataflowAssert$" + (assertCount++), + new TwoSideInputAssert, Iterable>(createActual, + new CreateExpected>(expectedElements, coder, View.asIterable()), + relation)); + + return this; + } + + /** + * Applies a {@link SerializableMatcher} to check the elements of the {@code Iterable}. + * + *

    Returns this {@code IterableAssert}. + */ + IterableAssert satisfies(final SerializableMatcher> matcher) { + // Safe covariant cast. Could be elided by changing a lot of this file to use + // more flexible bounds. + @SuppressWarnings({"rawtypes", "unchecked"}) + SerializableFunction, Void> checkerFn = + (SerializableFunction) new MatcherCheckerFn<>(matcher); + pipeline.apply( + "DataflowAssert$" + (assertCount++), + new OneSideInputAssert>( + createActual, + checkerFn)); + return this; + } + + private static class MatcherCheckerFn implements SerializableFunction { + private SerializableMatcher matcher; + + public MatcherCheckerFn(SerializableMatcher matcher) { + this.matcher = matcher; + } + + @Override + public Void apply(T actual) { + assertThat(actual, matcher); + return null; + } + } + + /** + * Checks that the {@code Iterable} is empty. + * + *

    Returns this {@code IterableAssert}. + */ + public IterableAssert empty() { + return satisfies(new AssertContainsInAnyOrderRelation(), Collections.emptyList()); + } + + /** + * @throws UnsupportedOperationException always + * @deprecated {@link Object#equals(Object)} is not supported on DataflowAssert objects. + * If you meant to test object equality, use a variant of {@link #containsInAnyOrder} + * instead. + */ + @Deprecated + @Override + public boolean equals(Object o) { + throw new UnsupportedOperationException( + "If you meant to test object equality, use .containsInAnyOrder instead."); + } + + /** + * @throws UnsupportedOperationException always. + * @deprecated {@link Object#hashCode()} is not supported on DataflowAssert objects. + */ + @Deprecated + @Override + public int hashCode() { + throw new UnsupportedOperationException( + String.format("%s.hashCode() is not supported.", IterableAssert.class.getSimpleName())); + } + + /** + * Checks that the {@code Iterable} contains the expected elements, in any + * order. + * + *

    Returns this {@code IterableAssert}. + */ + public IterableAssert containsInAnyOrder(Iterable expectedElements) { + return satisfies(new AssertContainsInAnyOrderRelation(), expectedElements); + } + + /** + * Checks that the {@code Iterable} contains the expected elements, in any + * order. + * + *

    Returns this {@code IterableAssert}. + */ + @SafeVarargs + public final IterableAssert containsInAnyOrder(T... expectedElements) { + return satisfies( + new AssertContainsInAnyOrderRelation(), + Arrays.asList(expectedElements)); + } + + /** + * Checks that the {@code Iterable} contains elements that match the provided matchers, + * in any order. + * + *

    Returns this {@code IterableAssert}. + */ + @SafeVarargs + final IterableAssert containsInAnyOrder( + SerializableMatcher... elementMatchers) { + return satisfies(SerializableMatchers.containsInAnyOrder(elementMatchers)); + } + } + + /** + * An assertion about the single value of type {@code T} + * associated with a {@link PCollectionView}. + */ + public static class SingletonAssert implements Serializable { + private final Pipeline pipeline; + private final CreateActual createActual; + private Optional> coder; + + protected SingletonAssert( + CreateActual createActual, Pipeline pipeline) { + this.pipeline = pipeline; + this.createActual = createActual; + this.coder = Optional.absent(); + } + + /** + * Always throws an {@link UnsupportedOperationException}: users are probably looking for + * {@link #isEqualTo}. + */ + @Deprecated + @Override + public boolean equals(Object o) { + throw new UnsupportedOperationException( + String.format( + "tests for Java equality of the %s object, not the PCollection in question. " + + "Call a test method, such as isEqualTo.", + getClass().getSimpleName())); + } + + /** + * @throws UnsupportedOperationException always. + * @deprecated {@link Object#hashCode()} is not supported on DataflowAssert objects. + */ + @Deprecated + @Override + public int hashCode() { + throw new UnsupportedOperationException( + String.format("%s.hashCode() is not supported.", SingletonAssert.class.getSimpleName())); + } + + /** + * Sets the coder to use for elements of type {@code T}, as needed + * for internal purposes. + * + *

    Returns this {@code SingletonAssert}. + */ + public SingletonAssert setCoder(Coder coderOrNull) { + this.coder = Optional.fromNullable(coderOrNull); + return this; + } + + /** + * Gets the coder, which may yet be absent. + */ + public Coder getCoder() { + if (coder.isPresent()) { + return coder.get(); + } else { + throw new IllegalStateException( + "Attempting to access the coder of a SingletonAssert that has not been set yet."); + } + } + + /** + * Applies a {@link SerializableFunction} to check the value of this + * {@code SingletonAssert}'s view. + * + *

    Returns this {@code SingletonAssert}. + */ + public SingletonAssert satisfies(SerializableFunction checkerFn) { + pipeline.apply( + "DataflowAssert$" + (assertCount++), + new OneSideInputAssert(createActual, checkerFn)); + return this; + } + + /** + * Applies an {@link AssertRelation} to check the provided relation against the + * value of this assert and the provided expected value. + * + *

    Returns this {@code SingletonAssert}. + */ + public SingletonAssert satisfies( + AssertRelation relation, + final T expectedValue) { + pipeline.apply( + "DataflowAssert$" + (assertCount++), + new TwoSideInputAssert(createActual, + new CreateExpected(Arrays.asList(expectedValue), coder, View.asSingleton()), + relation)); + + return this; + } + + /** + * Checks that the value of this {@code SingletonAssert}'s view is equal + * to the expected value. + * + *

    Returns this {@code SingletonAssert}. + */ + public SingletonAssert isEqualTo(T expectedValue) { + return satisfies(new AssertIsEqualToRelation(), expectedValue); + } + + /** + * Checks that the value of this {@code SingletonAssert}'s view is not equal + * to the expected value. + * + *

    Returns this {@code SingletonAssert}. + */ + public SingletonAssert notEqualTo(T expectedValue) { + return satisfies(new AssertNotEqualToRelation(), expectedValue); + } + + /** + * Checks that the value of this {@code SingletonAssert}'s view is equal to + * the expected value. + * + * @deprecated replaced by {@link #isEqualTo} + */ + @Deprecated + public SingletonAssert is(T expectedValue) { + return isEqualTo(expectedValue); + } + + } + + //////////////////////////////////////////////////////////////////////// + + private static class CreateActual + extends PTransform> { + + private final transient PCollection actual; + private final transient PTransform, PCollectionView> actualView; + + private CreateActual(PCollection actual, + PTransform, PCollectionView> actualView) { + this.actual = actual; + this.actualView = actualView; + } + + @Override + public PCollectionView apply(PBegin input) { + final Coder coder = actual.getCoder(); + return actual + .apply(Window.into(new GlobalWindows())) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext context) throws CoderException { + context.output(CoderUtils.clone(coder, context.element())); + } + })) + .apply(actualView); + } + } + + private static class CreateExpected + extends PTransform> { + + private final Iterable elements; + private final Optional> coder; + private final transient PTransform, PCollectionView> view; + + private CreateExpected(Iterable elements, Optional> coder, + PTransform, PCollectionView> view) { + this.elements = elements; + this.coder = coder; + this.view = view; + } + + @Override + public PCollectionView apply(PBegin input) { + Create.Values createTransform = Create.of(elements); + if (coder.isPresent()) { + createTransform = createTransform.withCoder(coder.get()); + } + return input.apply(createTransform).apply(view); + } + } + + private static class PreExisting extends PTransform> { + + private final PCollectionView view; + + private PreExisting(PCollectionView view) { + this.view = view; + } + + @Override + public PCollectionView apply(PBegin input) { + return view; + } + } + + /** + * An assertion checker that takes a single + * {@link PCollectionView PCollectionView<ActualT>} + * and an assertion over {@code ActualT}, and checks it within a dataflow + * pipeline. + * + *

    Note that the entire assertion must be serializable. If + * you need to make assertions involving multiple inputs + * that are each not serializable, use TwoSideInputAssert. + * + *

    This is generally useful for assertion functions that + * are serializable but whose underlying data may not have a coder. + */ + static class OneSideInputAssert + extends PTransform implements Serializable { + private final transient PTransform> createActual; + private final SerializableFunction checkerFn; + + public OneSideInputAssert( + PTransform> createActual, + SerializableFunction checkerFn) { + this.createActual = createActual; + this.checkerFn = checkerFn; + } + + @Override + public PDone apply(PBegin input) { + final PCollectionView actual = input.apply("CreateActual", createActual); + + input + .apply(Create.of((Void) null).withCoder(VoidCoder.of())) + .apply(ParDo.named("RunChecks").withSideInputs(actual) + .of(new CheckerDoFn<>(checkerFn, actual))); + + return PDone.in(input.getPipeline()); + } + } + + /** + * A {@link DoFn} that runs a checking {@link SerializableFunction} on the contents of + * a {@link PCollectionView}, and adjusts counters and thrown exceptions for use in testing. + */ + private static class CheckerDoFn extends DoFn { + private final SerializableFunction checkerFn; + private final Aggregator success = + createAggregator(SUCCESS_COUNTER, new Sum.SumIntegerFn()); + private final Aggregator failure = + createAggregator(FAILURE_COUNTER, new Sum.SumIntegerFn()); + private final PCollectionView actual; + + private CheckerDoFn( + SerializableFunction checkerFn, + PCollectionView actual) { + this.checkerFn = checkerFn; + this.actual = actual; + } + + @Override + public void processElement(ProcessContext c) { + try { + ActualT actualContents = c.sideInput(actual); + checkerFn.apply(actualContents); + success.addValue(1); + } catch (Throwable t) { + LOG.error("DataflowAssert failed expectations.", t); + failure.addValue(1); + // TODO: allow for metrics to propagate on failure when running a streaming pipeline + if (!c.getPipelineOptions().as(StreamingOptions.class).isStreaming()) { + throw t; + } + } + } + } + + /** + * An assertion checker that takes a {@link PCollectionView PCollectionView<ActualT>}, + * a {@link PCollectionView PCollectionView<ExpectedT>}, a relation + * over {@code A} and {@code B}, and checks that the relation holds + * within a dataflow pipeline. + * + *

    This is useful when either/both of {@code A} and {@code B} + * are not serializable, but have coders (provided + * by the underlying {@link PCollection}s). + */ + static class TwoSideInputAssert + extends PTransform implements Serializable { + + private final transient PTransform> createActual; + private final transient PTransform> createExpected; + private final AssertRelation relation; + + protected TwoSideInputAssert( + PTransform> createActual, + PTransform> createExpected, + AssertRelation relation) { + this.createActual = createActual; + this.createExpected = createExpected; + this.relation = relation; + } + + @Override + public PDone apply(PBegin input) { + final PCollectionView actual = input.apply("CreateActual", createActual); + final PCollectionView expected = input.apply("CreateExpected", createExpected); + + input + .apply(Create.of((Void) null).withCoder(VoidCoder.of())) + .apply(ParDo.named("RunChecks").withSideInputs(actual, expected) + .of(new CheckerDoFn<>(relation, actual, expected))); + + return PDone.in(input.getPipeline()); + } + + private static class CheckerDoFn extends DoFn { + private final Aggregator success = + createAggregator(SUCCESS_COUNTER, new Sum.SumIntegerFn()); + private final Aggregator failure = + createAggregator(FAILURE_COUNTER, new Sum.SumIntegerFn()); + private final AssertRelation relation; + private final PCollectionView actual; + private final PCollectionView expected; + + private CheckerDoFn(AssertRelation relation, + PCollectionView actual, PCollectionView expected) { + this.relation = relation; + this.actual = actual; + this.expected = expected; + } + + @Override + public void processElement(ProcessContext c) { + try { + ActualT actualContents = c.sideInput(actual); + ExpectedT expectedContents = c.sideInput(expected); + relation.assertFor(expectedContents).apply(actualContents); + success.addValue(1); + } catch (Throwable t) { + LOG.error("DataflowAssert failed expectations.", t); + failure.addValue(1); + // TODO: allow for metrics to propagate on failure when running a streaming pipeline + if (!c.getPipelineOptions().as(StreamingOptions.class).isStreaming()) { + throw t; + } + } + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@link SerializableFunction} that verifies that an actual value is equal to an + * expected value. + */ + private static class AssertIsEqualTo implements SerializableFunction { + private T expected; + + public AssertIsEqualTo(T expected) { + this.expected = expected; + } + + @Override + public Void apply(T actual) { + assertThat(actual, equalTo(expected)); + return null; + } + } + + /** + * A {@link SerializableFunction} that verifies that an actual value is not equal to an + * expected value. + */ + private static class AssertNotEqualTo implements SerializableFunction { + private T expected; + + public AssertNotEqualTo(T expected) { + this.expected = expected; + } + + @Override + public Void apply(T actual) { + assertThat(actual, not(equalTo(expected))); + return null; + } + } + + /** + * A {@link SerializableFunction} that verifies that an {@code Iterable} contains + * expected items in any order. + */ + private static class AssertContainsInAnyOrder + implements SerializableFunction, Void> { + private T[] expected; + + @SafeVarargs + public AssertContainsInAnyOrder(T... 
expected) { + this.expected = expected; + } + + @SuppressWarnings("unchecked") + public AssertContainsInAnyOrder(Collection expected) { + this((T[]) expected.toArray()); + } + + public AssertContainsInAnyOrder(Iterable expected) { + this(Lists.newArrayList(expected)); + } + + @Override + public Void apply(Iterable actual) { + assertThat(actual, containsInAnyOrder(expected)); + return null; + } + } + + //////////////////////////////////////////////////////////// + + /** + * A binary predicate between types {@code Actual} and {@code Expected}. + * Implemented as a method {@code assertFor(Expected)} which returns + * a {@code SerializableFunction} + * that should verify the assertion.. + */ + private static interface AssertRelation extends Serializable { + public SerializableFunction assertFor(ExpectedT input); + } + + /** + * An {@link AssertRelation} implementing the binary predicate that two objects are equal. + */ + private static class AssertIsEqualToRelation + implements AssertRelation { + @Override + public SerializableFunction assertFor(T expected) { + return new AssertIsEqualTo(expected); + } + } + + /** + * An {@link AssertRelation} implementing the binary predicate that two objects are not equal. + */ + private static class AssertNotEqualToRelation + implements AssertRelation { + @Override + public SerializableFunction assertFor(T expected) { + return new AssertNotEqualTo(expected); + } + } + + /** + * An {@code AssertRelation} implementing the binary predicate that two collections are equal + * modulo reordering. + */ + private static class AssertContainsInAnyOrderRelation + implements AssertRelation, Iterable> { + @Override + public SerializableFunction, Void> assertFor(Iterable expectedElements) { + return new AssertContainsInAnyOrder(expectedElements); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/RunnableOnService.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/RunnableOnService.java new file mode 100644 index 000000000000..60ab2e51b667 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/RunnableOnService.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +/** + * Category tag for tests that can be run on the + * {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner} if the + * {@code runIntegrationTestOnService} System property is set to true. + * Example usage: + *

    
+ * <pre>
+ *     {@literal @}Test
+ *     {@literal @}Category(RunnableOnService.class)
+ *     public void testParDo() {...}
+ * </pre>
    + */ +public interface RunnableOnService {} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SerializableMatcher.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SerializableMatcher.java new file mode 100644 index 000000000000..10f221e347e3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SerializableMatcher.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import org.hamcrest.Matcher; + +import java.io.Serializable; + +/** + * A {@link Matcher} that is also {@link Serializable}. + * + *

+ * <p>Such matchers can be used with {@link DataflowAssert}, which builds Dataflow pipelines
+ * such that these matchers may be serialized and executed remotely.
+ *
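+ * <p>For illustration only, a minimal matcher of this kind might look like the sketch below
+ * (the {@code IsEvenMatcher} name is hypothetical); the general recipe is spelled out in the
+ * next paragraph:
+ * <pre>{@code
+ * class IsEvenMatcher extends BaseMatcher<Integer> implements SerializableMatcher<Integer> {
+ *   public boolean matches(Object item) {
+ *     return item instanceof Integer && ((Integer) item) % 2 == 0;
+ *   }
+ *
+ *   public void describeTo(Description description) {
+ *     description.appendText("an even integer");
+ *   }
+ * }
+ * }</pre>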

    To create a {@code SerializableMatcher}, extend {@link org.hamcrest.BaseMatcher} + * and also implement this interface. + * + * @param The type of value matched. + */ +interface SerializableMatcher extends Matcher, Serializable { +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SerializableMatchers.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SerializableMatchers.java new file mode 100644 index 000000000000..da5171e21f50 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SerializableMatchers.java @@ -0,0 +1,1180 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.base.MoreObjects; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Static class for building and using {@link SerializableMatcher} instances. + * + *

+ * <p>Most matchers are wrappers for Hamcrest's {@link Matchers}; please refer to the
+ * documentation there. Values retained by a {@link SerializableMatcher} are required to be
+ * serializable, either via Java serialization or via a provided {@link Coder}.
+ *
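+ * <p>For example (a sketch; {@code MyRecord}, {@code recordCoder} and {@code expectedRecord}
+ * are illustrative names, not part of the SDK), the same equality check can retain its
+ * expected value either via Java serialization or via a {@link Coder}:
+ * <pre>{@code
+ * // Requires MyRecord to be Serializable.
+ * SerializableMatcher<MyRecord> viaJava = SerializableMatchers.equalTo(expectedRecord);
+ *
+ * // Requires only a Coder<MyRecord>; the expected value is encoded and decoded on demand.
+ * SerializableMatcher<MyRecord> viaCoder = SerializableMatchers.equalTo(recordCoder, expectedRecord);
+ * }</pre>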

+ * <p>The following matchers are novel to Dataflow:
+ * <ul>

      + *
+ * <li>{@link #kvWithKey} for matching just the key of a {@link KV} (see the sketch below).
+ * <li>{@link #kvWithValue} for matching just the value of a {@link KV}.
+ * <li>{@link #kv} for matching the key and value of a {@link KV}.
+ * </ul>
+ *
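+ * <p>For instance (expressions only, as a sketch using the methods of this class directly):
+ * <pre>{@code
+ * // Matches any KV whose key is "some key", regardless of value.
+ * SerializableMatchers.kvWithKey("some key")
+ *
+ * // Matches any KV whose value is 3, regardless of key.
+ * SerializableMatchers.kvWithValue(3)
+ * }</pre>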

    For example, to match a group from + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}, which has type + * {@code KV>} for some {@code K} and {@code V} and where the order of the iterable + * is undefined, use a matcher like + * {@code kv(equalTo("some key"), containsInAnyOrder(1, 2, 3))}. + */ +class SerializableMatchers implements Serializable { + + // Serializable only because of capture by anonymous inner classes + private SerializableMatchers() { } // not instantiable + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#allOf(Iterable)}. + */ + public static SerializableMatcher + allOf(Iterable> serializableMatchers) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe covariant cast + final Iterable> matchers = (Iterable) serializableMatchers; + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.allOf(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#allOf(Matcher[])}. + */ + @SafeVarargs + public static SerializableMatcher allOf(final SerializableMatcher... matchers) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.allOf(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#anyOf(Iterable)}. + */ + public static SerializableMatcher + anyOf(Iterable> serializableMatchers) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe covariant cast + final Iterable> matchers = (Iterable) serializableMatchers; + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.anyOf(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#anyOf(Matcher[])}. + */ + @SafeVarargs + public static SerializableMatcher anyOf(final SerializableMatcher... matchers) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.anyOf(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#anything()}. + */ + public static SerializableMatcher anything() { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.anything(); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContaining(Object[])}. + */ + @SafeVarargs + public static SerializableMatcher + arrayContaining(final T... items) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContaining(items); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContaining(Object[])}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. They are + * explicitly not required or expected to be serializable via Java serialization. + */ + @SafeVarargs + public static SerializableMatcher arrayContaining(Coder coder, T... items) { + + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContaining(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContaining(Matcher[])}. + */ + @SafeVarargs + public static SerializableMatcher + arrayContaining(final SerializableMatcher... matchers) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContaining(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContaining(List)}. + */ + public static SerializableMatcher + arrayContaining(List> serializableMatchers) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe covariant cast + final List> matchers = (List) serializableMatchers; + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContaining(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContainingInAnyOrder(Object[])}. + */ + @SafeVarargs + public static SerializableMatcher + arrayContainingInAnyOrder(final T... items) { + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContainingInAnyOrder(items); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContainingInAnyOrder(Object[])}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. They are + * explicitly not required or expected to be serializable via Java serialization. + */ + @SafeVarargs + public static SerializableMatcher arrayContainingInAnyOrder(Coder coder, T... items) { + + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContaining(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContainingInAnyOrder(Matcher[])}. + */ + @SafeVarargs + public static SerializableMatcher arrayContainingInAnyOrder( + final SerializableMatcher... matchers) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContainingInAnyOrder(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayContainingInAnyOrder(Collection)}. + */ + public static SerializableMatcher arrayContainingInAnyOrder( + Collection> serializableMatchers) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe covariant cast + final Collection> matchers = (Collection) serializableMatchers; + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayContainingInAnyOrder(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayWithSize(int)}. + */ + public static SerializableMatcher arrayWithSize(final int size) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayWithSize(size); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#arrayWithSize(Matcher)}. + */ + public static SerializableMatcher arrayWithSize( + final SerializableMatcher sizeMatcher) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.arrayWithSize(sizeMatcher); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#closeTo(double,double)}. + */ + public static SerializableMatcher closeTo(final double target, final double error) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.closeTo(target, error); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#contains(Object[])}. + */ + @SafeVarargs + public static SerializableMatcher> contains( + final T... items) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.contains(items); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#contains(Object[])}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. They are + * explicitly not required or expected to be serializable via Java serialization. + */ + @SafeVarargs + public static SerializableMatcher> + contains(Coder coder, T... items) { + + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.containsInAnyOrder(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#contains(Matcher[])}. + */ + @SafeVarargs + public static SerializableMatcher> contains( + final SerializableMatcher... matchers) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.contains(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#contains(List)}. + */ + public static SerializableMatcher> contains( + List> serializableMatchers) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe covariant cast + final List> matchers = (List) serializableMatchers; + + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.contains(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#containsInAnyOrder(Object[])}. + */ + @SafeVarargs + public static SerializableMatcher> + containsInAnyOrder(final T... items) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.containsInAnyOrder(items); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#containsInAnyOrder(Object[])}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + @SafeVarargs + public static SerializableMatcher> + containsInAnyOrder(Coder coder, T... items) { + + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.containsInAnyOrder(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#containsInAnyOrder(Matcher[])}. + */ + @SafeVarargs + public static SerializableMatcher> containsInAnyOrder( + final SerializableMatcher... matchers) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.containsInAnyOrder(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#containsInAnyOrder(Collection)}. + */ + public static SerializableMatcher> containsInAnyOrder( + Collection> serializableMatchers) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe covariant cast + final Collection> matchers = (Collection) serializableMatchers; + + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.containsInAnyOrder(matchers); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#containsString}. + */ + public static SerializableMatcher containsString(final String substring) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.containsString(substring); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#empty()}. + */ + public static SerializableMatcher> empty() { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.empty(); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#emptyArray()}. + */ + public static SerializableMatcher emptyArray() { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.emptyArray(); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#emptyIterable()}. + */ + public static SerializableMatcher> emptyIterable() { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.emptyIterable(); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#endsWith}. + */ + public static SerializableMatcher endsWith(final String substring) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.endsWith(substring); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#equalTo()}. + */ + public static SerializableMatcher equalTo(final T expected) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.equalTo(expected); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#equalTo()}. + * + *

    The expected value of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static SerializableMatcher equalTo(Coder coder, T expected) { + + final SerializableSupplier expectedSupplier = new SerializableViaCoder<>(coder, expected); + + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.equalTo(expectedSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#greaterThan()}. + */ + public static & Serializable> SerializableMatcher + greaterThan(final T target) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.greaterThan(target); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#greaterThan()}. + * + *

    The target value of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static & Serializable> SerializableMatcher + greaterThan(final Coder coder, T target) { + final SerializableSupplier targetSupplier = new SerializableViaCoder<>(coder, target); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.greaterThan(targetSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#greaterThanOrEqualTo()}. + */ + public static > SerializableMatcher greaterThanOrEqualTo( + final T target) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.greaterThanOrEqualTo(target); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#greaterThanOrEqualTo()}. + * + *

    The target value of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static & Serializable> SerializableMatcher + greaterThanOrEqualTo(final Coder coder, T target) { + final SerializableSupplier targetSupplier = new SerializableViaCoder<>(coder, target); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.greaterThanOrEqualTo(targetSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#hasItem(Object)}. + */ + public static SerializableMatcher> hasItem( + final T target) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.hasItem(target); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#hasItem(Object)}. + * + *

    The item of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static SerializableMatcher> hasItem(Coder coder, T target) { + final SerializableSupplier targetSupplier = new SerializableViaCoder<>(coder, target); + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.hasItem(targetSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#hasItem(Matcher)}. + */ + public static SerializableMatcher> hasItem( + final SerializableMatcher matcher) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.hasItem(matcher); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#hasSize(int)}. + */ + public static SerializableMatcher> hasSize(final int size) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.hasSize(size); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#hasSize(Matcher)}. + */ + public static SerializableMatcher> hasSize( + final SerializableMatcher sizeMatcher) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.hasSize(sizeMatcher); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#iterableWithSize(int)}. + */ + public static SerializableMatcher> iterableWithSize(final int size) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.iterableWithSize(size); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#iterableWithSize(Matcher)}. + */ + public static SerializableMatcher> iterableWithSize( + final SerializableMatcher sizeMatcher) { + return fromSupplier(new SerializableSupplier>>() { + @Override + public Matcher> get() { + return Matchers.iterableWithSize(sizeMatcher); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#isIn(Collection)}. + */ + public static SerializableMatcher + isIn(final Collection collection) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.isIn(collection); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#isIn(Collection)}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. + * They are explicitly not required or expected to be serializable via Java serialization. + */ + public static SerializableMatcher isIn(Coder coder, Collection collection) { + @SuppressWarnings("unchecked") + T[] items = (T[]) collection.toArray(); + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.isIn(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#isIn(Object[])}. + */ + public static SerializableMatcher isIn(final T[] items) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.isIn(items); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#isIn(Object[])}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. + * They are explicitly not required or expected to be serializable via Java serialization. + */ + public static SerializableMatcher isIn(Coder coder, T[] items) { + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.isIn(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#isOneOf}. + */ + @SafeVarargs + public static SerializableMatcher isOneOf(final T... elems) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.isOneOf(elems); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#isOneOf}. + * + *

    The items of type {@code T} will be serialized using the provided {@link Coder}. + * They are explicitly not required or expected to be serializable via Java serialization. + */ + @SafeVarargs + public static SerializableMatcher isOneOf(Coder coder, T... items) { + final SerializableSupplier itemsSupplier = + new SerializableArrayViaCoder<>(coder, items); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.isOneOf(itemsSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with the specified key. + */ + public static SerializableMatcher> + kvWithKey(K key) { + return new KvKeyMatcher(equalTo(key)); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with the specified key. + * + *

    The key of type {@code K} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static SerializableMatcher> + kvWithKey(Coder coder, K key) { + return new KvKeyMatcher(equalTo(coder, key)); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with matching key. + */ + public static SerializableMatcher> kvWithKey( + final SerializableMatcher keyMatcher) { + return new KvKeyMatcher(keyMatcher); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with the specified value. + */ + public static SerializableMatcher> + kvWithValue(V value) { + return new KvValueMatcher(equalTo(value)); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with the specified value. + * + *

    The value of type {@code V} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static SerializableMatcher> + kvWithValue(Coder coder, V value) { + return new KvValueMatcher(equalTo(coder, value)); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with matching value. + */ + public static SerializableMatcher> kvWithValue( + final SerializableMatcher valueMatcher) { + return new KvValueMatcher<>(valueMatcher); + } + + /** + * A {@link SerializableMatcher} that matches any {@link KV} with matching key and value. + */ + public static SerializableMatcher> kv( + final SerializableMatcher keyMatcher, + final SerializableMatcher valueMatcher) { + + return SerializableMatchers.>allOf( + SerializableMatchers.kvWithKey(keyMatcher), + SerializableMatchers.kvWithValue(valueMatcher)); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#lessThan()}. + */ + public static & Serializable> SerializableMatcher lessThan( + final T target) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.lessThan(target); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#lessThan()}. + * + *

    The target value of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static > SerializableMatcher + lessThan(Coder coder, T target) { + final SerializableSupplier targetSupplier = new SerializableViaCoder<>(coder, target); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.lessThan(targetSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#lessThanOrEqualTo()}. + */ + public static & Serializable> SerializableMatcher lessThanOrEqualTo( + final T target) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.lessThanOrEqualTo(target); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#lessThanOrEqualTo()}. + * + *

    The target value of type {@code T} will be serialized using the provided {@link Coder}. + * It is explicitly not required or expected to be serializable via Java serialization. + */ + public static > SerializableMatcher lessThanOrEqualTo( + Coder coder, T target) { + final SerializableSupplier targetSupplier = new SerializableViaCoder<>(coder, target); + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.lessThanOrEqualTo(targetSupplier.get()); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#not}. + */ + public static SerializableMatcher not(final SerializableMatcher matcher) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.not(matcher); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to + * {@link Matchers#nullValue}. + */ + public static SerializableMatcher nullValue() { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.nullValue(); + } + }); + } + + /** + * A {@link SerializableMatcher} with identical criteria to {@link Matchers#startsWith}. + */ + public static SerializableMatcher startsWith(final String substring) { + return fromSupplier(new SerializableSupplier>() { + @Override + public Matcher get() { + return Matchers.startsWith(substring); + } + }); + } + + private static class KvKeyMatcher + extends BaseMatcher> + implements SerializableMatcher> { + private final SerializableMatcher keyMatcher; + + public KvKeyMatcher(SerializableMatcher keyMatcher) { + this.keyMatcher = keyMatcher; + } + + @Override + public boolean matches(Object item) { + @SuppressWarnings("unchecked") + KV kvItem = (KV) item; + return keyMatcher.matches(kvItem.getKey()); + } + + @Override + public void describeMismatch(Object item, Description mismatchDescription) { + @SuppressWarnings("unchecked") + KV kvItem = (KV) item; + if (!keyMatcher.matches(kvItem.getKey())) { + mismatchDescription.appendText("key did not match: "); + keyMatcher.describeMismatch(kvItem.getKey(), mismatchDescription); + } + } + + @Override + public void describeTo(Description description) { + description.appendText("KV with key matching "); + keyMatcher.describeTo(description); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .addValue(keyMatcher) + .toString(); + } + } + + private static class KvValueMatcher + extends BaseMatcher> + implements SerializableMatcher> { + private final SerializableMatcher valueMatcher; + + public KvValueMatcher(SerializableMatcher valueMatcher) { + this.valueMatcher = valueMatcher; + } + + @Override + public boolean matches(Object item) { + @SuppressWarnings("unchecked") + KV kvItem = (KV) item; + return valueMatcher.matches(kvItem.getValue()); + } + + @Override + public void describeMismatch(Object item, Description mismatchDescription) { + @SuppressWarnings("unchecked") + KV kvItem = (KV) item; + if (!valueMatcher.matches(kvItem.getValue())) { + mismatchDescription.appendText("value did not match: "); + valueMatcher.describeMismatch(kvItem.getValue(), mismatchDescription); + } + } + + @Override + public void describeTo(Description description) { + description.appendText("KV with value matching "); + valueMatcher.describeTo(description); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .addValue(valueMatcher) + .toString(); + } + } + + /** + * Constructs a {@link 
SerializableMatcher} from a non-serializable {@link Matcher} via + * indirection through {@link SerializableSupplier}. + * + *

    To wrap a {@link Matcher} which is not serializable, provide a {@link SerializableSupplier} + * with a {@link SerializableSupplier#get()} method that returns a fresh instance of the + * {@link Matcher} desired. The resulting {@link SerializableMatcher} will behave according to + * the {@link Matcher} returned by {@link SerializableSupplier#get() get()} when it is invoked + * during matching (which may occur on another machine, such as a Dataflow worker). + * + * + * return fromSupplier(new SerializableSupplier>() { + * * @Override + * public Matcher get() { + * return new MyMatcherForT(); + * } + * }); + * + */ + public static SerializableMatcher fromSupplier( + SerializableSupplier> supplier) { + return new SerializableMatcherFromSupplier<>(supplier); + } + + /** + * Supplies values of type {@code T}, and is serializable. Thus, even if {@code T} is not + * serializable, the supplier can be serialized and provide a {@code T} wherever it is + * deserialized. + * + * @param the type of value supplied. + */ + public interface SerializableSupplier extends Serializable { + T get(); + } + + /** + * Since the delegate {@link Matcher} is not generally serializable, instead this takes a nullary + * SerializableFunction to return such a matcher. + */ + private static class SerializableMatcherFromSupplier extends BaseMatcher + implements SerializableMatcher { + + private SerializableSupplier> supplier; + + public SerializableMatcherFromSupplier(SerializableSupplier> supplier) { + this.supplier = supplier; + } + + @Override + public void describeTo(Description description) { + supplier.get().describeTo(description); + } + + @Override + public boolean matches(Object item) { + return supplier.get().matches(item); + } + + @Override + public void describeMismatch(Object item, Description mismatchDescription) { + supplier.get().describeMismatch(item, mismatchDescription); + } + } + + /** + * Wraps any value that can be encoded via a {@link Coder} to make it {@link Serializable}. + * This is not likely to be a good encoding, so should be used only for tests, where data + * volume is small and minor costs are not critical. + */ + private static class SerializableViaCoder implements SerializableSupplier { + /** Cached value that is not serialized. */ + @Nullable + private transient T value; + + /** The bytes of {@link #value} when encoded via {@link #coder}. */ + private byte[] encodedValue; + + private Coder coder; + + public SerializableViaCoder(Coder coder, T value) { + this.coder = coder; + this.value = value; + try { + this.encodedValue = CoderUtils.encodeToByteArray(coder, value); + } catch (CoderException exc) { + throw new RuntimeException("Error serializing via Coder", exc); + } + } + + @Override + public T get() { + if (value == null) { + try { + value = CoderUtils.decodeFromByteArray(coder, encodedValue); + } catch (CoderException exc) { + throw new RuntimeException("Error deserializing via Coder", exc); + } + } + return value; + } + } + + /** + * Wraps any array with values that can be encoded via a {@link Coder} to make it + * {@link Serializable}. This is not likely to be a good encoding, so should be used only for + * tests, where data volume is small and minor costs are not critical. + */ + private static class SerializableArrayViaCoder implements SerializableSupplier { + /** Cached value that is not serialized. */ + @Nullable + private transient T[] value; + + /** The bytes of {@link #value} when encoded via {@link #coder}. 
*/ + private byte[] encodedValue; + + private Coder> coder; + + public SerializableArrayViaCoder(Coder elementCoder, T[] value) { + this.coder = ListCoder.of(elementCoder); + this.value = value; + try { + this.encodedValue = CoderUtils.encodeToByteArray(coder, Arrays.asList(value)); + } catch (CoderException exc) { + throw UserCodeException.wrap(exc); + } + } + + @Override + public T[] get() { + if (value == null) { + try { + @SuppressWarnings("unchecked") + T[] decoded = (T[]) CoderUtils.decodeFromByteArray(coder, encodedValue).toArray(); + value = decoded; + } catch (CoderException exc) { + throw new RuntimeException("Error deserializing via Coder", exc); + } + } + return value; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SourceTestUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SourceTestUtils.java new file mode 100644 index 000000000000..b8f9b0b4233e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/SourceTestUtils.java @@ -0,0 +1,642 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Source; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.junit.Assert; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * Helper functions and test harnesses for checking correctness of {@link Source} + * implementations. + * + *

+ * <p>Contains a few lightweight utilities (e.g. reading items from a source or a reader,
+ * such as {@link #readFromSource} and {@link #readFromUnstartedReader}), as well as
+ * heavyweight property-testing and stress-testing harnesses that help achieve a large
+ * amount of test coverage with little code. The most notable ones are:
+ * <ul>

      + *
+ * <li>{@link #assertSourcesEqualReferenceSource} helps test that the data read
+ * by the union of sources produced by {@link BoundedSource#splitIntoBundles}
+ * is the same as the data read by the original source.
+ * <li>If your source implements dynamic work rebalancing, use the
+ * {@code assertSplitAtFraction} family of functions; they test the behavior of
+ * {@link BoundedSource.BoundedReader#splitAtFraction}, in particular, that
+ * various consistency properties are respected and the total set of data read
+ * by the source is preserved when splits happen.
+ * Use {@link #assertSplitAtFractionBehavior} to test individual cases
+ * of {@code splitAtFraction} and use {@link #assertSplitAtFractionExhaustive}
+ * as a heavyweight stress test, including concurrency. We strongly recommend
+ * using both.
+ * </ul>
+ * For example usages, see the unit tests of classes such as
+ * {@link com.google.cloud.dataflow.sdk.io.AvroSource} or
+ * {@link com.google.cloud.dataflow.sdk.io.XmlSource}.
+ *
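+ * <p>As a rough, illustrative sketch of how the harnesses above are typically combined in a
+ * test (the {@code source} and {@code options} variables, and the bundle size passed to
+ * {@code splitIntoBundles}, are assumptions of this sketch):
+ * <pre>{@code
+ * // Read everything once through the source itself.
+ * List<String> allRecords = SourceTestUtils.readFromSource(source, options);
+ *
+ * // The union of the split bundles must contain exactly the same records.
+ * SourceTestUtils.assertSourcesEqualReferenceSource(
+ *     source, source.splitIntoBundles(1024, options), options);
+ *
+ * // One individual dynamic-split case: read 10 items, then split at 50%.
+ * SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent(source, 10, 0.5, options);
+ * }</pre>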

    Like {@link DataflowAssert}, requires JUnit and Hamcrest to be present in the classpath. + */ +public class SourceTestUtils { + // A wrapper around a value of type T that compares according to the structural + // value provided by a Coder, but prints both the original and structural value, + // to help get good error messages from JUnit equality assertion failures and such. + private static class ReadableStructuralValue { + private T originalValue; + private Object structuralValue; + + public ReadableStructuralValue(T originalValue, Object structuralValue) { + this.originalValue = originalValue; + this.structuralValue = structuralValue; + } + + @Override + public int hashCode() { + return Objects.hashCode(structuralValue); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof ReadableStructuralValue)) { + return false; + } + return Objects.equals(structuralValue, ((ReadableStructuralValue) obj).structuralValue); + } + + @Override + public String toString() { + return String.format("[%s (structural %s)]", originalValue, structuralValue); + } + } + + /** + * Testing utilities below depend on standard assertions and matchers to compare elements read by + * sources. In general the elements may not implement {@code equals}/{@code hashCode} properly, + * however every source has a {@link Coder} and every {@code Coder} can + * produce a {@link Coder#structuralValue} whose {@code equals}/{@code hashCode} is + * consistent with equality of encoded format. + * So we use this {@link Coder#structuralValue} to compare elements read by sources. + */ + public static List> createStructuralValues( + Coder coder, List list) + throws Exception { + List> result = new ArrayList<>(); + for (T elem : list) { + result.add(new ReadableStructuralValue<>(elem, coder.structuralValue(elem))); + } + return result; + } + + /** + * Reads all elements from the given {@link BoundedSource}. + */ + public static List readFromSource(BoundedSource source, PipelineOptions options) + throws IOException { + try (BoundedSource.BoundedReader reader = source.createReader(options)) { + return readFromUnstartedReader(reader); + } + } + + /** + * Reads all elements from the given unstarted {@link Source.Reader}. + */ + public static List readFromUnstartedReader(Source.Reader reader) throws IOException { + return readRemainingFromReader(reader, false); + } + + /** + * Reads all elements from the given started {@link Source.Reader}. + */ + public static List readFromStartedReader(Source.Reader reader) throws IOException { + return readRemainingFromReader(reader, true); + } + + /** + * Read elements from a {@link Source.Reader} until n elements are read. + */ + public static List readNItemsFromUnstartedReader(Source.Reader reader, int n) + throws IOException { + return readNItemsFromReader(reader, n, false); + } + + /** + * Read elements from a {@link Source.Reader} that has already had {@link Source.Reader#start} + * called on it, until n elements are read. + */ + public static List readNItemsFromStartedReader(Source.Reader reader, int n) + throws IOException { + return readNItemsFromReader(reader, n, true); + } + + /** + * Read elements from a {@link Source.Reader} until n elements are read. + * + *
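+ * <p>For example, through the public entry point {@code readNItemsFromUnstartedReader}, which
+ * delegates here (a sketch; {@code reader} is assumed to be a freshly created, unstarted
+ * {@code Source.Reader<String>}):
+ * <pre>{@code
+ * List<String> firstTen = SourceTestUtils.readNItemsFromUnstartedReader(reader, 10);
+ * }</pre>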

    There must be at least n elements remaining in the reader, except for + * the case when n is {@code Integer.MAX_VALUE}, which means "read all + * remaining elements". + */ + private static List readNItemsFromReader(Source.Reader reader, int n, boolean started) + throws IOException { + List res = new ArrayList<>(); + for (int i = 0; i < n; i++) { + boolean shouldStart = (i == 0 && !started); + boolean more = shouldStart ? reader.start() : reader.advance(); + if (n != Integer.MAX_VALUE) { + assertTrue(more); + } + if (!more) { + break; + } + res.add(reader.getCurrent()); + } + return res; + } + + /** + * Read all remaining elements from a {@link Source.Reader}. + */ + public static List readRemainingFromReader(Source.Reader reader, boolean started) + throws IOException { + return readNItemsFromReader(reader, Integer.MAX_VALUE, started); + } + + /** + * Given a reference {@code Source} and a list of {@code Source}s, assert that the union of + * the records read from the list of sources is equal to the records read from the reference + * source. + */ + public static void assertSourcesEqualReferenceSource( + BoundedSource referenceSource, + List> sources, + PipelineOptions options) + throws Exception { + Coder coder = referenceSource.getDefaultOutputCoder(); + List referenceRecords = readFromSource(referenceSource, options); + List bundleRecords = new ArrayList<>(); + for (BoundedSource source : sources) { + assertThat( + "Coder type for source " + + source + + " is not compatible with Coder type for referenceSource " + + referenceSource, + source.getDefaultOutputCoder(), + equalTo(coder)); + List elems = readFromSource(source, options); + bundleRecords.addAll(elems); + } + List> bundleValues = + createStructuralValues(coder, bundleRecords); + List> referenceValues = + createStructuralValues(coder, referenceRecords); + assertThat(bundleValues, containsInAnyOrder(referenceValues.toArray())); + } + + /** + * Assert that a {@code Reader} returns a {@code Source} that, when read from, produces the same + * records as the reader. + */ + public static void assertUnstartedReaderReadsSameAsItsSource( + BoundedSource.BoundedReader reader, PipelineOptions options) throws Exception { + Coder coder = reader.getCurrentSource().getDefaultOutputCoder(); + List expected = readFromUnstartedReader(reader); + List actual = readFromSource(reader.getCurrentSource(), options); + List> expectedStructural = createStructuralValues(coder, expected); + List> actualStructural = createStructuralValues(coder, actual); + assertThat(actualStructural, containsInAnyOrder(expectedStructural.toArray())); + } + + /** + * Expected outcome of + * {@link com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader#splitAtFraction}. + */ + public enum ExpectedSplitOutcome { + /** + * The operation must succeed and the results must be consistent. + */ + MUST_SUCCEED_AND_BE_CONSISTENT, + /** + * The operation must fail (return {@code null}). + */ + MUST_FAIL, + /** + * The operation must either fail, or succeed and the results be consistent. + */ + MUST_BE_CONSISTENT_IF_SUCCEEDS + } + + /** + * Contains two values: the number of items in the primary source, and the number of items in + * the residual source, -1 if split failed. 
+ */ + private static class SplitAtFractionResult { + public int numPrimaryItems; + public int numResidualItems; + + public SplitAtFractionResult(int numPrimaryItems, int numResidualItems) { + this.numPrimaryItems = numPrimaryItems; + this.numResidualItems = numResidualItems; + } + } + + /** + * Asserts that the {@code source}'s reader either fails to {@code splitAtFraction(fraction)} + * after reading {@code numItemsToReadBeforeSplit} items, or succeeds in a way that is + * consistent according to {@link #assertSplitAtFractionSucceedsAndConsistent}. + *
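+ * <p>For example (the {@code source} and {@code options} values are illustrative only): read
+ * 10 items, then require that a split at 50% succeeds and is consistent:
+ * <pre>{@code
+ * SourceTestUtils.assertSplitAtFractionBehavior(
+ *     source, 10, 0.5, ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT, options);
+ * }</pre>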

    Returns SplitAtFractionResult. + */ + + public static SplitAtFractionResult assertSplitAtFractionBehavior( + BoundedSource source, + int numItemsToReadBeforeSplit, + double splitFraction, + ExpectedSplitOutcome expectedOutcome, + PipelineOptions options) + throws Exception { + return assertSplitAtFractionBehaviorImpl( + source, readFromSource(source, options), numItemsToReadBeforeSplit, splitFraction, + expectedOutcome, options); + } + + /** + * Compares two lists elementwise and throws a detailed assertion failure optimized for + * human reading in case they are unequal. + */ + private static void assertListsEqualInOrder( + String message, String expectedLabel, List expected, String actualLabel, List actual) { + int i = 0; + for (; i < expected.size() && i < actual.size(); ++i) { + if (!Objects.equals(expected.get(i), actual.get(i))) { + Assert.fail(String.format( + "%s: %s and %s have %d items in common and then differ. " + + "Item in %s (%d more): %s, item in %s (%d more): %s", + message, expectedLabel, actualLabel, i, + expectedLabel, expected.size() - i - 1, expected.get(i), + actualLabel, actual.size() - i - 1, actual.get(i))); + } + } + if (i < expected.size() /* but i == actual.size() */) { + Assert.fail(String.format( + "%s: %s has %d more items after matching all %d from %s. First 5: %s", + message, expectedLabel, expected.size() - actual.size(), actual.size(), actualLabel, + expected.subList(actual.size(), Math.min(expected.size(), actual.size() + 5)))); + } else if (i < actual.size() /* but i == expected.size() */) { + Assert.fail(String.format( + "%s: %s has %d more items after matching all %d from %s. First 5: %s", + message, actualLabel, actual.size() - expected.size(), expected.size(), expectedLabel, + actual.subList(expected.size(), Math.min(actual.size(), expected.size() + 5)))); + } else { + // All is well. + } + } + + private static SourceTestUtils.SplitAtFractionResult assertSplitAtFractionBehaviorImpl( + BoundedSource source, List expectedItems, int numItemsToReadBeforeSplit, + double splitFraction, ExpectedSplitOutcome expectedOutcome, PipelineOptions options) + throws Exception { + try (BoundedSource.BoundedReader reader = source.createReader(options)) { + BoundedSource originalSource = reader.getCurrentSource(); + List currentItems = readNItemsFromUnstartedReader(reader, numItemsToReadBeforeSplit); + BoundedSource residual = reader.splitAtFraction(splitFraction); + if (residual != null) { + assertFalse( + String.format( + "Primary source didn't change after a successful split of %s at %f " + + "after reading %d items. " + + "Was the source object mutated instead of creating a new one? " + + "Source objects MUST be immutable.", + source, splitFraction, numItemsToReadBeforeSplit), + reader.getCurrentSource() == originalSource); + assertFalse( + String.format( + "Residual source equal to original source after a successful split of %s at %f " + + "after reading %d items. " + + "Was the source object mutated instead of creating a new one? " + + "Source objects MUST be immutable.", + source, splitFraction, numItemsToReadBeforeSplit), + reader.getCurrentSource() == residual); + } + // Failure cases are: must succeed but fails; must fail but succeeds. 
+ switch (expectedOutcome) { + case MUST_SUCCEED_AND_BE_CONSISTENT: + assertNotNull( + "Failed to split reader of source: " + + source + + " at " + + splitFraction + + " after reading " + + numItemsToReadBeforeSplit + + " items", + residual); + break; + case MUST_FAIL: + assertEquals(null, residual); + break; + case MUST_BE_CONSISTENT_IF_SUCCEEDS: + // Nothing. + break; + } + currentItems.addAll(readRemainingFromReader(reader, numItemsToReadBeforeSplit > 0)); + BoundedSource primary = reader.getCurrentSource(); + return verifySingleSplitAtFractionResult( + source, expectedItems, currentItems, primary, residual, + numItemsToReadBeforeSplit, splitFraction, options); + } + } + + private static SourceTestUtils.SplitAtFractionResult verifySingleSplitAtFractionResult( + BoundedSource source, List expectedItems, List currentItems, + BoundedSource primary, BoundedSource residual, + int numItemsToReadBeforeSplit, double splitFraction, PipelineOptions options) + throws Exception { + List primaryItems = readFromSource(primary, options); + if (residual != null) { + List residualItems = readFromSource(residual, options); + List totalItems = new ArrayList<>(); + totalItems.addAll(primaryItems); + totalItems.addAll(residualItems); + String errorMsgForPrimarySourceComp = + String.format( + "Continued reading after split yielded different items than primary source: " + + "split at %s after reading %s items, original source: %s, primary source: %s", + splitFraction, + numItemsToReadBeforeSplit, + source, + primary); + String errorMsgForTotalSourceComp = + String.format( + "Items in primary and residual sources after split do not add up to items " + + "in the original source. Split at %s after reading %s items; " + + "original source: %s, primary: %s, residual: %s", + splitFraction, + numItemsToReadBeforeSplit, + source, + primary, + residual); + Coder coder = primary.getDefaultOutputCoder(); + List> primaryValues = + createStructuralValues(coder, primaryItems); + List> currentValues = + createStructuralValues(coder, currentItems); + List> expectedValues = + createStructuralValues(coder, expectedItems); + List> totalValues = + createStructuralValues(coder, totalItems); + assertListsEqualInOrder( + errorMsgForPrimarySourceComp, "current", currentValues, "primary", primaryValues); + assertListsEqualInOrder( + errorMsgForTotalSourceComp, "total", expectedValues, "primary+residual", totalValues); + return new SplitAtFractionResult(primaryItems.size(), residualItems.size()); + } + return new SplitAtFractionResult(primaryItems.size(), -1); + } + + /** + * Verifies some consistency properties of + * {@link BoundedSource.BoundedReader#splitAtFraction} on the given source. Equivalent to + * the following pseudocode: + *

    +   *   Reader reader = source.createReader();
    +   *   read N items from reader;
    +   *   Source residual = reader.splitAtFraction(splitFraction);
    +   *   Source primary = reader.getCurrentSource();
    +   *   assert: items in primary == items we read so far
    +   *                               + items we'll get by continuing to read from reader;
    +   *   assert: items in original source == items in primary + items in residual
    +   * 
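A minimal usage sketch of the split-at-fraction assertions defined in this class, assuming a JUnit test with the usual SDK imports; the LineSource class and the GCS path are hypothetical, not part of this change:

@Test
public void testSplitAtFraction() throws Exception {
  PipelineOptions options = PipelineOptionsFactory.create();
  BoundedSource<String> source = new LineSource("gs://bucket/file.txt"); // hypothetical source under test
  // A split at 50% after reading 10 items must succeed and satisfy both asserts above.
  SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent(source, 10, 0.5, options);
  // A split at a fraction the reader has already consumed must fail cleanly.
  SourceTestUtils.assertSplitAtFractionFails(source, 20, 0.1, options);
  // Or exercise every read position and every interesting fraction, including concurrent splits.
  SourceTestUtils.assertSplitAtFractionExhaustive(source, options);
}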
    + */ + public static void assertSplitAtFractionSucceedsAndConsistent( + BoundedSource source, + int numItemsToReadBeforeSplit, + double splitFraction, + PipelineOptions options) + throws Exception { + assertSplitAtFractionBehavior( + source, + numItemsToReadBeforeSplit, + splitFraction, + ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT, + options); + } + + /** + * Asserts that the {@code source}'s reader fails to {@code splitAtFraction(fraction)} + * after reading {@code numItemsToReadBeforeSplit} items. + */ + public static void assertSplitAtFractionFails( + BoundedSource source, + int numItemsToReadBeforeSplit, + double splitFraction, + PipelineOptions options) + throws Exception { + assertSplitAtFractionBehavior( + source, numItemsToReadBeforeSplit, splitFraction, ExpectedSplitOutcome.MUST_FAIL, options); + } + + private static class SplitFractionStatistics { + List successfulFractions = new ArrayList<>(); + List nonTrivialFractions = new ArrayList<>(); + } + + /** + * Asserts that given a start position, + * {@link BoundedSource.BoundedReader#splitAtFraction} at every interesting fraction (halfway + * between two fractions that differ by at least one item) can be called successfully and the + * results are consistent if a split succeeds. + */ + private static void assertSplitAtFractionBinary( + BoundedSource source, + List expectedItems, + int numItemsToBeReadBeforeSplit, + double leftFraction, + SplitAtFractionResult leftResult, + double rightFraction, + SplitAtFractionResult rightResult, + PipelineOptions options, + SplitFractionStatistics stats) + throws Exception { + if (rightFraction - leftFraction < 0.001) { + // Do not recurse too deeply. Otherwise we will end up in infinite + // recursion, e.g., while trying to find the exact minimal fraction s.t. + // split succeeds. A precision of 0.001 when looking for such a fraction + // ought to be enough for everybody. + return; + } + double middleFraction = (rightFraction + leftFraction) / 2; + if (leftResult == null) { + leftResult = assertSplitAtFractionBehaviorImpl( + source, expectedItems, numItemsToBeReadBeforeSplit, leftFraction, + ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS, options); + } + if (rightResult == null) { + rightResult = assertSplitAtFractionBehaviorImpl( + source, expectedItems, numItemsToBeReadBeforeSplit, rightFraction, + ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS, options); + } + SplitAtFractionResult middleResult = assertSplitAtFractionBehaviorImpl( + source, expectedItems, numItemsToBeReadBeforeSplit, middleFraction, + ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS, options); + if (middleResult.numResidualItems != -1) { + stats.successfulFractions.add(middleFraction); + } + if (middleResult.numResidualItems > 0) { + stats.nonTrivialFractions.add(middleFraction); + } + // Two split fractions are equivalent if they yield the same number of + // items in primary vs. residual source. Left and right are already not + // equivalent. Recurse into [left, middle) and [right, middle) respectively + // if middle is not equivalent to left or right. 
+ if (leftResult.numPrimaryItems != middleResult.numPrimaryItems) { + assertSplitAtFractionBinary( + source, expectedItems, numItemsToBeReadBeforeSplit, + leftFraction, leftResult, middleFraction, middleResult, options, stats); + } + if (rightResult.numPrimaryItems != middleResult.numPrimaryItems) { + assertSplitAtFractionBinary( + source, expectedItems, numItemsToBeReadBeforeSplit, + middleFraction, middleResult, rightFraction, rightResult, options, stats); + } + } + + /** + * Asserts that for each possible start position, + * {@link BoundedSource.BoundedReader#splitAtFraction} at every interesting fraction (halfway + * between two fractions that differ by at least one item) can be called successfully and the + * results are consistent if a split succeeds. Verifies multithreaded splitting as well. + */ + public static void assertSplitAtFractionExhaustive( + BoundedSource source, PipelineOptions options) throws Exception { + List expectedItems = readFromSource(source, options); + assertFalse("Empty source", expectedItems.isEmpty()); + assertFalse("Source reads a single item", expectedItems.size() == 1); + List> allNonTrivialFractions = new ArrayList<>(); + { + boolean anySuccessfulFractions = false; + boolean anyNonTrivialFractions = false; + for (int i = 0; i < expectedItems.size(); i++) { + SplitFractionStatistics stats = new SplitFractionStatistics(); + assertSplitAtFractionBinary(source, expectedItems, i, + 0.0, null, 1.0, null, options, stats); + if (!stats.successfulFractions.isEmpty()) { + anySuccessfulFractions = true; + } + if (!stats.nonTrivialFractions.isEmpty()) { + anyNonTrivialFractions = true; + } + allNonTrivialFractions.add(stats.nonTrivialFractions); + } + assertTrue( + "splitAtFraction test completed vacuously: no successful split fractions found", + anySuccessfulFractions); + assertTrue( + "splitAtFraction test completed vacuously: no non-trivial split fractions found", + anyNonTrivialFractions); + } + { + // Perform a stress test of "racy" concurrent splitting: + // for every position (number of items read), try to split at the minimum nontrivial + // split fraction for that position concurrently with reading the record at that position. + // To ensure that the test is non-vacuous, make sure that the splitting succeeds + // at least once and fails at least once. + ExecutorService executor = Executors.newFixedThreadPool(2); + for (int i = 0; i < expectedItems.size(); i++) { + double minNonTrivialFraction = 2.0; // Greater than any possible fraction. + for (double fraction : allNonTrivialFractions.get(i)) { + minNonTrivialFraction = Math.min(minNonTrivialFraction, fraction); + } + if (minNonTrivialFraction == 2.0) { + // This will not happen all the time because otherwise the test above would + // detect vacuousness. 
+ continue; + } + boolean haveSuccess = false, haveFailure = false; + while (!haveSuccess || !haveFailure) { + if (assertSplitAtFractionConcurrent( + executor, source, expectedItems, i, minNonTrivialFraction, options)) { + haveSuccess = true; + } else { + haveFailure = true; + } + } + } + } + } + + private static boolean assertSplitAtFractionConcurrent( + ExecutorService executor, BoundedSource source, List expectedItems, + final int numItemsToReadBeforeSplitting, final double fraction, PipelineOptions options) + throws Exception { + @SuppressWarnings("resource") // Closed in readerThread + final BoundedSource.BoundedReader reader = source.createReader(options); + final CountDownLatch unblockSplitter = new CountDownLatch(1); + Future> readerThread = + executor.submit( + new Callable>() { + @Override + public List call() throws Exception { + try { + List items = + readNItemsFromUnstartedReader(reader, numItemsToReadBeforeSplitting); + unblockSplitter.countDown(); + items.addAll(readRemainingFromReader(reader, numItemsToReadBeforeSplitting > 0)); + return items; + } finally { + reader.close(); + } + } + }); + Future, BoundedSource>> splitterThread = executor.submit( + new Callable, BoundedSource>>() { + @Override + public KV, BoundedSource> call() throws Exception { + unblockSplitter.await(); + BoundedSource residual = reader.splitAtFraction(fraction); + if (residual == null) { + return null; + } + return KV.of(reader.getCurrentSource(), residual); + } + }); + List currentItems = readerThread.get(); + KV, BoundedSource> splitSources = splitterThread.get(); + if (splitSources == null) { + return false; + } + SplitAtFractionResult res = verifySingleSplitAtFractionResult( + source, expectedItems, currentItems, splitSources.getKey(), splitSources.getValue(), + numItemsToReadBeforeSplitting, fraction, options); + return (res.numResidualItems > 0); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineOptions.java new file mode 100644 index 000000000000..1afb6910446c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineOptions.java @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.options.BlockingDataflowPipelineOptions; + +/** + * A set of options used to configure the {@link TestPipeline}. 
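A sketch of how these options are typically built and wired to the test runner; the flag values are illustrative, and the fromArgs/as/setRunner calls follow the same pattern that TestPipeline.testingPipelineOptions() uses later in this change:

String[] args = {
    "--project=my-cloud-project",               // illustrative project id
    "--stagingLocation=gs://my-bucket/staging"  // illustrative staging location
};
TestDataflowPipelineOptions options =
    PipelineOptionsFactory.fromArgs(args).as(TestDataflowPipelineOptions.class);
options.setRunner(TestDataflowPipelineRunner.class);
Pipeline p = Pipeline.create(options);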
+ */ +public interface TestDataflowPipelineOptions extends BlockingDataflowPipelineOptions { + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunner.java new file mode 100644 index 000000000000..9fff070f884c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunner.java @@ -0,0 +1,220 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.JobMetrics; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowJobExecutionException; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineJob; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil.JobMessagesHandler; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.base.Optional; +import com.google.common.base.Throwables; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +/** + * {@link TestDataflowPipelineRunner} is a pipeline runner that wraps a + * {@link DataflowPipelineRunner} when running tests against the {@link TestPipeline}. + * + * @see TestPipeline + */ +public class TestDataflowPipelineRunner extends PipelineRunner { + private static final String TENTATIVE_COUNTER = "tentative"; + private static final Logger LOG = LoggerFactory.getLogger(TestDataflowPipelineRunner.class); + + private final TestDataflowPipelineOptions options; + private final DataflowPipelineRunner runner; + private int expectedNumberOfAssertions = 0; + + TestDataflowPipelineRunner(TestDataflowPipelineOptions options) { + this.options = options; + this.runner = DataflowPipelineRunner.fromOptions(options); + } + + /** + * Constructs a runner from the provided options. 
+ */ + public static TestDataflowPipelineRunner fromOptions( + PipelineOptions options) { + TestDataflowPipelineOptions dataflowOptions = options.as(TestDataflowPipelineOptions.class); + + return new TestDataflowPipelineRunner(dataflowOptions); + } + + @Override + public DataflowPipelineJob run(Pipeline pipeline) { + return run(pipeline, runner); + } + + DataflowPipelineJob run(Pipeline pipeline, DataflowPipelineRunner runner) { + + final JobMessagesHandler messageHandler = + new MonitoringUtil.PrintHandler(options.getJobMessageOutput()); + final DataflowPipelineJob job; + try { + job = runner.run(pipeline); + } catch (DataflowJobExecutionException ex) { + throw new IllegalStateException("The dataflow failed."); + } + + LOG.info("Running Dataflow job {} with {} expected assertions.", + job.getJobId(), expectedNumberOfAssertions); + + try { + final Optional result; + if (options.isStreaming()) { + Future> resultFuture = options.getExecutorService().submit( + new Callable>() { + @Override + public Optional call() throws Exception { + try { + for (;;) { + Optional result = checkForSuccess(job); + if (result.isPresent()) { + return result; + } + Thread.sleep(10000L); + } + } finally { + LOG.info("Cancelling Dataflow job {}", job.getJobId()); + job.cancel(); + } + } + }); + State finalState = job.waitToFinish(10L, TimeUnit.MINUTES, new JobMessagesHandler() { + @Override + public void process(List messages) { + messageHandler.process(messages); + for (JobMessage message : messages) { + if (message.getMessageImportance() != null + && message.getMessageImportance().equals("JOB_MESSAGE_ERROR")) { + LOG.info("Dataflow job {} threw exception, cancelling. Exception was: {}", + job.getJobId(), message.getMessageText()); + try { + job.cancel(); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + } + } + }); + if (finalState == null || finalState == State.RUNNING) { + LOG.info("Dataflow job {} took longer than 10 minutes to complete, cancelling.", + job.getJobId()); + job.cancel(); + } + result = resultFuture.get(); + } else { + job.waitToFinish(-1, TimeUnit.SECONDS, messageHandler); + result = checkForSuccess(job); + } + if (!result.isPresent()) { + throw new IllegalStateException( + "The dataflow did not output a success or failure metric."); + } else if (!result.get()) { + throw new IllegalStateException("The dataflow failed."); + } + } catch (Exception e) { + Throwables.propagateIfPossible(e); + throw Throwables.propagate(e); + } + return job; + } + + @Override + public OutputT apply( + PTransform transform, InputT input) { + if (transform instanceof DataflowAssert.OneSideInputAssert + || transform instanceof DataflowAssert.TwoSideInputAssert) { + expectedNumberOfAssertions += 1; + } + + return runner.apply(transform, input); + } + + Optional checkForSuccess(DataflowPipelineJob job) + throws IOException { + State state = job.getState(); + if (state == State.FAILED || state == State.CANCELLED) { + LOG.info("The pipeline failed"); + return Optional.of(false); + } + + JobMetrics metrics = job.getDataflowClient().projects().jobs() + .getMetrics(job.getProjectId(), job.getJobId()).execute(); + + if (metrics == null || metrics.getMetrics() == null) { + LOG.warn("Metrics not present for Dataflow job {}.", job.getJobId()); + } else { + int successes = 0; + int failures = 0; + for (MetricUpdate metric : metrics.getMetrics()) { + if (metric.getName() == null || metric.getName().getContext() == null + || !metric.getName().getContext().containsKey(TENTATIVE_COUNTER)) { + // Don't double count 
using the non-tentative version of the metric. + continue; + } + if (DataflowAssert.SUCCESS_COUNTER.equals(metric.getName().getName())) { + successes += ((BigDecimal) metric.getScalar()).intValue(); + } else if (DataflowAssert.FAILURE_COUNTER.equals(metric.getName().getName())) { + failures += ((BigDecimal) metric.getScalar()).intValue(); + } + } + + if (failures > 0) { + LOG.info("Found result while running Dataflow job {}. Found {} success, {} failures out of " + + "{} expected assertions.", job.getJobId(), successes, failures, + expectedNumberOfAssertions); + return Optional.of(false); + } else if (successes >= expectedNumberOfAssertions) { + LOG.info("Found result while running Dataflow job {}. Found {} success, {} failures out of " + + "{} expected assertions.", job.getJobId(), successes, failures, + expectedNumberOfAssertions); + return Optional.of(true); + } + + LOG.info("Running Dataflow job {}. Found {} success, {} failures out of {} expected " + + "assertions.", job.getJobId(), successes, failures, expectedNumberOfAssertions); + } + + return Optional.absent(); + } + + @Override + public String toString() { + return "TestDataflowPipelineRunner#" + options.getAppName(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestPipeline.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestPipeline.java new file mode 100644 index 000000000000..a05a7785d9e1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/TestPipeline.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions.CheckEnabled; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.common.base.Optional; +import com.google.common.collect.Iterators; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; +import java.util.Iterator; + +import javax.annotation.Nullable; + +/** + * A creator of test pipelines that can be used inside of tests that can be + * configured to run locally or against the live service. + * + *

    It is recommended to tag hand-selected tests for this purpose using the + * RunnableOnService Category annotation, as each test run against the service + * will spin up and tear down a single VM. + * + *

    In order to run tests on the dataflow pipeline service, the following + * conditions must be met: + *

      + *
    • runIntegrationTestOnService System property must be set to true. + *
    • System property "projectName" must be set to your Cloud project. + *
    • System property "temp_gcs_directory" must be set to a valid GCS bucket. + *
    • Jars containing the SDK and test classes must be added to the test classpath. + *
    + * + *
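A sketch of supplying these properties to the test JVM; the values are illustrative, and only runIntegrationTestOnService is read directly by TestPipeline in this file (the remaining properties are presumably consumed by the surrounding integration-test harness):

// For example, as flags on the test invocation:
//   -DrunIntegrationTestOnService=true
//   -DprojectName=my-cloud-project
//   -Dtemp_gcs_directory=gs://my-bucket/tmp
// With runIntegrationTestOnService=true, testingPipelineOptions() below selects
// TestDataflowPipelineRunner as the runner.
Pipeline p = TestPipeline.create();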

    Use {@link DataflowAssert} for tests, as it integrates with this test + * harness in both direct and remote execution modes. For example: + * + *

    {@code
    + * Pipeline p = TestPipeline.create();
+ * PCollection<Integer> output = ...
    + *
    + * DataflowAssert.that(output)
    + *     .containsInAnyOrder(1, 2, 3, 4);
    + * p.run();
    + * }
    + * + */ +public class TestPipeline extends Pipeline { + private static final String PROPERTY_DATAFLOW_OPTIONS = "dataflowOptions"; + private static final ObjectMapper MAPPER = new ObjectMapper(); + + /** + * Creates and returns a new test pipeline. + * + *

    Use {@link DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + */ + public static TestPipeline create() { + return fromOptions(testingPipelineOptions()); + } + + public static TestPipeline fromOptions(PipelineOptions options) { + return new TestPipeline(PipelineRunner.fromOptions(options), options); + } + + /** + * Returns whether a {@link TestPipeline} supports dynamic work rebalancing, and thus tests + * of dynamic work rebalancing are expected to pass. + */ + public boolean supportsDynamicWorkRebalancing() { + return getRunner() instanceof DataflowPipelineRunner; + } + + private TestPipeline(PipelineRunner runner, PipelineOptions options) { + super(runner, options); + } + + /** + * Runs this {@link TestPipeline}, unwrapping any {@code AssertionError} + * that is raised during testing. + */ + @Override + public PipelineResult run() { + try { + return super.run(); + } catch (RuntimeException exc) { + Throwable cause = exc.getCause(); + if (cause instanceof AssertionError) { + throw (AssertionError) cause; + } else { + throw exc; + } + } + } + + @Override + public String toString() { + return "TestPipeline#" + getOptions().as(ApplicationNameOptions.class).getAppName(); + } + + /** + * Creates {@link PipelineOptions} for testing. + */ + public static PipelineOptions testingPipelineOptions() { + try { + @Nullable String systemDataflowOptions = System.getProperty(PROPERTY_DATAFLOW_OPTIONS); + PipelineOptions options = + systemDataflowOptions == null + ? PipelineOptionsFactory.create() + : PipelineOptionsFactory.fromArgs( + MAPPER.readValue( + System.getProperty(PROPERTY_DATAFLOW_OPTIONS), String[].class)) + .as(PipelineOptions.class); + + options.as(ApplicationNameOptions.class).setAppName(getAppName()); + if (isIntegrationTest()) { + // TODO: adjust everyone's integration test frameworks to set the runner class via the + // pipeline options via PROPERTY_DATAFLOW_OPTIONS + options.setRunner(TestDataflowPipelineRunner.class); + } else { + options.as(GcpOptions.class).setGcpCredential(new TestCredential()); + } + options.setStableUniqueNames(CheckEnabled.ERROR); + return options; + } catch (IOException e) { + throw new RuntimeException("Unable to instantiate test options from system property " + + PROPERTY_DATAFLOW_OPTIONS + ":" + System.getProperty(PROPERTY_DATAFLOW_OPTIONS), e); + } + } + + /** + * Returns whether a {@link TestPipeline} should be treated as an integration test. + */ + private static boolean isIntegrationTest() { + return Boolean.parseBoolean(System.getProperty("runIntegrationTestOnService")); + } + + /** Returns the class + method name of the test, or a default name. */ + private static String getAppName() { + Optional stackTraceElement = findCallersStackTrace(); + if (stackTraceElement.isPresent()) { + String methodName = stackTraceElement.get().getMethodName(); + String className = stackTraceElement.get().getClassName(); + if (className.contains(".")) { + className = className.substring(className.lastIndexOf(".") + 1); + } + return className + "-" + methodName; + } + return "UnitTest"; + } + + /** Returns the {@link StackTraceElement} of the calling class. */ + private static Optional findCallersStackTrace() { + Iterator elements = + Iterators.forArray(Thread.currentThread().getStackTrace()); + // First find the TestPipeline class in the stack trace. 
+ while (elements.hasNext()) { + StackTraceElement next = elements.next(); + if (TestPipeline.class.getName().equals(next.getClassName())) { + break; + } + } + // Then find the first instance after that is not the TestPipeline + while (elements.hasNext()) { + StackTraceElement next = elements.next(); + if (!TestPipeline.class.getName().equals(next.getClassName())) { + return Optional.of(next); + } + } + return Optional.absent(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/WindowFnTestUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/WindowFnTestUtils.java new file mode 100644 index 000000000000..dc0baf52b81d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/WindowFnTestUtils.java @@ -0,0 +1,325 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; + +import org.joda.time.Instant; +import org.joda.time.ReadableInstant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * A utility class for testing {@link WindowFn}s. + */ +public class WindowFnTestUtils { + + /** + * Creates a Set of elements to be used as expected output in + * {@link #runWindowFn}. + */ + public static Set set(long... timestamps) { + Set result = new HashSet<>(); + for (long timestamp : timestamps) { + result.add(timestampValue(timestamp)); + } + return result; + } + + /** + * Runs the {@link WindowFn} over the provided input, returning a map + * of windows to the timestamps in those windows. 
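A small sketch of exercising runWindowFn, assuming a JUnit test with the usual SDK, Joda-Time, and JUnit imports; the 10 ms fixed window size and the timestamps are illustrative:

@Test
public void testFixedWindowsAssignment() throws Exception {
  Map<IntervalWindow, Set<String>> actual = WindowFnTestUtils.runWindowFn(
      FixedWindows.of(Duration.millis(10)), Arrays.asList(1L, 2L, 15L));
  // Timestamps 1 and 2 fall in [0, 10); timestamp 15 falls in [10, 20).
  assertEquals(
      WindowFnTestUtils.set(1, 2),
      actual.get(new IntervalWindow(new Instant(0), new Instant(10))));
  assertEquals(
      WindowFnTestUtils.set(15),
      actual.get(new IntervalWindow(new Instant(10), new Instant(20))));
}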
+ */ + public static Map> runWindowFn( + WindowFn windowFn, + List timestamps) throws Exception { + + final TestWindowSet windowSet = new TestWindowSet(); + for (final Long timestamp : timestamps) { + for (W window : windowFn.assignWindows( + new TestAssignContext(new Instant(timestamp), windowFn))) { + windowSet.put(window, timestampValue(timestamp)); + } + windowFn.mergeWindows(new TestMergeContext(windowSet, windowFn)); + } + Map> actual = new HashMap<>(); + for (W window : windowSet.windows()) { + actual.put(window, windowSet.get(window)); + } + return actual; + } + + public static Collection assignedWindows( + WindowFn windowFn, long timestamp) throws Exception { + return windowFn.assignWindows(new TestAssignContext(new Instant(timestamp), windowFn)); + } + + private static String timestampValue(long timestamp) { + return "T" + new Instant(timestamp); + } + + /** + * Test implementation of AssignContext. + */ + private static class TestAssignContext + extends WindowFn.AssignContext { + private Instant timestamp; + + public TestAssignContext(Instant timestamp, WindowFn windowFn) { + windowFn.super(); + this.timestamp = timestamp; + } + + @Override + public T element() { + return null; + } + + @Override + public Instant timestamp() { + return timestamp; + } + + @Override + public Collection windows() { + return null; + } + } + + /** + * Test implementation of MergeContext. + */ + private static class TestMergeContext + extends WindowFn.MergeContext { + private TestWindowSet windowSet; + + public TestMergeContext( + TestWindowSet windowSet, WindowFn windowFn) { + windowFn.super(); + this.windowSet = windowSet; + } + + @Override + public Collection windows() { + return windowSet.windows(); + } + + @Override + public void merge(Collection toBeMerged, W mergeResult) { + windowSet.merge(toBeMerged, mergeResult); + } + } + + /** + * A WindowSet useful for testing WindowFns that simply + * collects the placed elements into multisets. + */ + private static class TestWindowSet { + + private Map> elements = new HashMap<>(); + + public void put(W window, V value) { + Set all = elements.get(window); + if (all == null) { + all = new HashSet<>(); + elements.put(window, all); + } + all.add(value); + } + + public void merge(Collection otherWindows, W window) { + if (otherWindows.isEmpty()) { + return; + } + Set merged = new HashSet<>(); + if (elements.containsKey(window) && !otherWindows.contains(window)) { + merged.addAll(elements.get(window)); + } + for (W w : otherWindows) { + if (!elements.containsKey(w)) { + throw new IllegalArgumentException("Tried to merge a non-existent window:" + w); + } + merged.addAll(elements.get(w)); + elements.remove(w); + } + elements.put(window, merged); + } + + public Collection windows() { + return elements.keySet(); + } + + // For testing. + + public Set get(W window) { + return elements.get(window); + } + } + + /** + * Assigns the given {@code timestamp} to windows using the specified {@code windowFn}, and + * verifies that result of {@code windowFn.getOutputTimestamp} for each window is within the + * proper bound. 
+ */ + public static void validateNonInterferingOutputTimes( + WindowFn windowFn, long timestamp) throws Exception { + Collection windows = WindowFnTestUtils.assignedWindows(windowFn, timestamp); + + Instant instant = new Instant(timestamp); + for (W window : windows) { + Instant outputTimestamp = windowFn.getOutputTimeFn().assignOutputTime(instant, window); + assertFalse("getOutputTime must be greater than or equal to input timestamp", + outputTimestamp.isBefore(instant)); + assertFalse("getOutputTime must be less than or equal to the max timestamp", + outputTimestamp.isAfter(window.maxTimestamp())); + } + } + + /** + * Assigns the given {@code timestamp} to windows using the specified {@code windowFn}, and + * verifies that result of {@link WindowFn#getOutputTime windowFn.getOutputTime} for later windows + * (as defined by {@code maxTimestamp} won't prevent the watermark from passing the end of earlier + * windows. + * + *

    This verifies that overlapping windows don't interfere at all. Depending on the + * {@code windowFn} this may be stricter than desired. + */ + public static void validateGetOutputTimestamp( + WindowFn windowFn, long timestamp) throws Exception { + Collection windows = WindowFnTestUtils.assignedWindows(windowFn, timestamp); + List sortedWindows = new ArrayList<>(windows); + Collections.sort(sortedWindows, new Comparator() { + @Override + public int compare(BoundedWindow o1, BoundedWindow o2) { + return o1.maxTimestamp().compareTo(o2.maxTimestamp()); + } + }); + + Instant instant = new Instant(timestamp); + Instant endOfPrevious = null; + for (W window : sortedWindows) { + Instant outputTimestamp = windowFn.getOutputTimeFn().assignOutputTime(instant, window); + if (endOfPrevious == null) { + // If this is the first window, the output timestamp can be anything, as long as it is in + // the valid range. + assertFalse("getOutputTime must be greater than or equal to input timestamp", + outputTimestamp.isBefore(instant)); + assertFalse("getOutputTime must be less than or equal to the max timestamp", + outputTimestamp.isAfter(window.maxTimestamp())); + } else { + // If this is a later window, the output timestamp must be after the end of the previous + // window + assertTrue("getOutputTime must be greater than the end of the previous window", + outputTimestamp.isAfter(endOfPrevious)); + assertFalse("getOutputTime must be less than or equal to the max timestamp", + outputTimestamp.isAfter(window.maxTimestamp())); + } + endOfPrevious = window.maxTimestamp(); + } + } + + /** + * Verifies that later-ending merged windows from any of the timestamps hold up output of + * earlier-ending windows, using the provided {@link WindowFn} and {@link OutputTimeFn}. + * + *

    Given a list of lists of timestamps, where each list is expected to merge into a single + * window with end times in ascending order, assigns and merges windows for each list (as though + * each were a separate key/user session). Then maps each timestamp in the list according to + * {@link OutputTimeFn#assignOutputTime outputTimeFn.assignOutputTime()} and + * {@link OutputTimeFn#combine outputTimeFn.combine()}. + * + *
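A sketch of calling this check for session windows; Sessions and the 10 ms gap are illustrative, and the OutputTimeFns.outputAtEndOfWindow() factory name is assumed from the SDK's OutputTimeFns utilities rather than shown in this change:

// Each inner list is expected to merge into a single session; later sessions end later.
List<List<Long>> timestampsPerSession = Arrays.asList(
    Arrays.asList(0L, 5L),     // merges into a session ending around 15
    Arrays.asList(30L, 32L));  // merges into a session ending around 42
WindowFnTestUtils.validateGetOutputTimestamps(
    Sessions.withGapDuration(Duration.millis(10)),
    OutputTimeFns.outputAtEndOfWindow(),  // assumed factory method
    timestampsPerSession);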

    Verifies that a overlapping windows do not hold each other up via the watermark. + */ + public static + void validateGetOutputTimestamps( + WindowFn windowFn, + OutputTimeFn outputTimeFn, + List> timestampsPerWindow) throws Exception { + + // Assign windows to each timestamp, then merge them, storing the merged windows in + // a list in corresponding order to timestampsPerWindow + final List windows = new ArrayList<>(); + for (List timestampsForWindow : timestampsPerWindow) { + final Set windowsToMerge = new HashSet<>(); + + for (long timestamp : timestampsForWindow) { + windowsToMerge.addAll( + WindowFnTestUtils.assignedWindows(windowFn, timestamp)); + } + + windowFn.mergeWindows(windowFn.new MergeContext() { + @Override + public Collection windows() { + return windowsToMerge; + } + + @Override + public void merge(Collection toBeMerged, W mergeResult) throws Exception { + windows.add(mergeResult); + } + }); + } + + // Map every list of input timestamps to an output timestamp + final List combinedOutputTimestamps = new ArrayList<>(); + for (int i = 0; i < timestampsPerWindow.size(); ++i) { + List timestampsForWindow = timestampsPerWindow.get(i); + W window = windows.get(i); + + List outputInstants = new ArrayList<>(); + for (long inputTimestamp : timestampsForWindow) { + outputInstants.add(outputTimeFn.assignOutputTime(new Instant(inputTimestamp), window)); + } + + combinedOutputTimestamps.add(OutputTimeFns.combineOutputTimes(outputTimeFn, outputInstants)); + } + + // Consider windows in increasing order of max timestamp; ensure the output timestamp is after + // the max timestamp of the previous + @Nullable W earlierEndingWindow = null; + for (int i = 0; i < windows.size(); ++i) { + W window = windows.get(i); + ReadableInstant outputTimestamp = combinedOutputTimestamps.get(i); + + if (earlierEndingWindow != null) { + assertThat(outputTimestamp, + greaterThan((ReadableInstant) earlierEndingWindow.maxTimestamp())); + } + + earlierEndingWindow = window; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/package-info.java new file mode 100644 index 000000000000..d6f075d097cc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/testing/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines utilities for unit testing Dataflow pipelines. The tests for the {@code PTransform}s and + * examples included the Dataflow SDK provide examples of using these utilities. 
+ */ +package com.google.cloud.dataflow.sdk.testing; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Aggregator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Aggregator.java new file mode 100644 index 000000000000..7e56ddac0dc7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Aggregator.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; + +/** + * An {@code Aggregator} enables monitoring of values of type {@code InputT}, + * to be combined across all bundles. + * + *

    Aggregators are created by calling {@link DoFn#createAggregator}, + * typically from the {@link DoFn} constructor. Elements can be added to the + * {@code Aggregator} by calling {@link Aggregator#addValue}. + * + *

Aggregators are visible in the monitoring UI, along with their current value, when the pipeline + * is run using DataflowPipelineRunner or BlockingDataflowPipelineRunner. Aggregators may not + * become visible until the system begins executing the ParDo transform that created them + * and/or their initial value is changed. + * + *

    Example: + *

     {@code
    + * class MyDoFn extends DoFn {
+ *   private Aggregator<Integer, Integer> myAggregator;
    + *
    + *   public MyDoFn() {
    + *     myAggregator = createAggregator("myAggregator", new Sum.SumIntegerFn());
    + *   }
    + *
    + *   @Override
    + *   public void processElement(ProcessContext c) {
    + *     myAggregator.addValue(1);
    + *   }
    + * }
    + * } 
    + * + * @param the type of input values + * @param the type of output values + */ +public interface Aggregator { + + /** + * Adds a new value into the Aggregator. + */ + void addValue(InputT value); + + /** + * Returns the name of the Aggregator. + */ + String getName(); + + /** + * Returns the {@link CombineFn}, which combines input elements in the + * aggregator. + */ + CombineFn getCombineFn(); + + // TODO: Consider the following additional API conveniences: + // - In addition to createAggregator(), consider adding getAggregator() to + // avoid the need to store the aggregator locally in a DoFn, i.e., create + // if not already present. + // - Add a shortcut for the most common aggregator: + // c.createAggregator("name", new Sum.SumIntegerFn()). +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/AggregatorRetriever.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/AggregatorRetriever.java new file mode 100644 index 000000000000..4bbea85f52a0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/AggregatorRetriever.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import java.util.Collection; + +/** + * An internal class for extracting {@link Aggregator Aggregators} from {@link DoFn DoFns}. + */ +public final class AggregatorRetriever { + private AggregatorRetriever() { + // do not instantiate + } + + /** + * Returns the {@link Aggregator Aggregators} created by the provided {@link DoFn}. + */ + public static Collection> getAggregators(DoFn fn) { + return fn.getAggregators(); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/AppliedPTransform.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/AppliedPTransform.java new file mode 100644 index 000000000000..7b3d87dfcf8b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/AppliedPTransform.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; + +/** + * Represents the application of a {@link PTransform} to a specific input to produce + * a specific output. + * + *

    For internal use. + * + * @param transform input type + * @param transform output type + * @param transform type + */ +public class AppliedPTransform + > { + + private final String fullName; + private final InputT input; + private final OutputT output; + private final TransformT transform; + + private AppliedPTransform(String fullName, InputT input, OutputT output, TransformT transform) { + this.input = input; + this.output = output; + this.transform = transform; + this.fullName = fullName; + } + + public static > + AppliedPTransform of( + String fullName, InputT input, OutputT output, TransformT transform) { + return new AppliedPTransform(fullName, input, output, transform); + } + + public String getFullName() { + return fullName; + } + + public InputT getInput() { + return input; + } + + public OutputT getOutput() { + return output; + } + + public TransformT getTransform() { + return transform; + } + + @Override + public int hashCode() { + return Objects.hashCode(getFullName(), getInput(), getOutput(), getTransform()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof AppliedPTransform) { + AppliedPTransform that = (AppliedPTransform) other; + return Objects.equal(this.getFullName(), that.getFullName()) + && Objects.equal(this.getInput(), that.getInput()) + && Objects.equal(this.getOutput(), that.getOutput()) + && Objects.equal(this.getTransform(), that.getTransform()); + } else { + return false; + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("fullName", getFullName()) + .add("input", getInput()) + .add("output", getOutput()) + .add("transform", getTransform()) + .toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantiles.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantiles.java new file mode 100644 index 000000000000..57dd51009b8f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantiles.java @@ -0,0 +1,766 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn.Accumulator; +import com.google.cloud.dataflow.sdk.util.WeightedValue; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.UnmodifiableIterator; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.PriorityQueue; + +import javax.annotation.Nullable; + +/** + * {@code PTransform}s for getting an idea of a {@code PCollection}'s + * data distribution using approximate {@code N}-tiles (e.g. quartiles, + * percentiles, etc.), either globally or per-key. + */ +public class ApproximateQuantiles { + private ApproximateQuantiles() { + // do not instantiate + } + + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection>} whose single value is a + * {@code List} of the approximate {@code N}-tiles of the elements + * of the input {@code PCollection}. This gives an idea of the + * distribution of the input elements. + * + *

    The computed {@code List} is of size {@code numQuantiles}, + * and contains the input elements' minimum value, + * {@code numQuantiles-2} intermediate values, and maximum value, in + * sorted order, using the given {@code Comparator} to order values. + * To compute traditional {@code N}-tiles, one should use + * {@code ApproximateQuantiles.globally(compareFn, N+1)}. + * + *

    If there are fewer input elements than {@code numQuantiles}, + * then the result {@code List} will contain all the input elements, + * in sorted order. + * + *

    The argument {@code Comparator} must be {@code Serializable}. + * + *

    Example of use: + *

     {@code
+   * PCollection<String> pc = ...;
+   * PCollection<List<String>> quantiles =
    +   *     pc.apply(ApproximateQuantiles.globally(stringCompareFn, 11));
    +   * } 
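The stringCompareFn referenced above is not defined in this change; one Serializable comparator that would fit is sketched below, declared static (or top level) so it serializes cleanly. Note that the method declared after this javadoc takes the quantile count first:

// Any Comparator that is also Serializable can be used.
static class StringCompareFn implements Comparator<String>, Serializable {
  @Override
  public int compare(String a, String b) {
    return a.compareTo(b);
  }
}
// Argument order per the globally(int, ComparatorT) signature:
PCollection<List<String>> quantiles =
    pc.apply(ApproximateQuantiles.globally(11, new StringCompareFn()));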
    + * + * @param the type of the elements in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + * @param compareFn the function to use to order the elements + */ + public static & Serializable> + PTransform, PCollection>> globally( + int numQuantiles, ComparatorT compareFn) { + return Combine.globally( + ApproximateQuantilesCombineFn.create(numQuantiles, compareFn)); + } + + /** + * Like {@link #globally(int, Comparator)}, but sorts using the + * elements' natural ordering. + * + * @param the type of the elements in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + */ + public static > + PTransform, PCollection>> globally(int numQuantiles) { + return Combine.globally( + ApproximateQuantilesCombineFn.create(numQuantiles)); + } + + /** + * Returns a {@code PTransform} that takes a + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to a {@code List} of the approximate + * {@code N}-tiles of the values associated with that key in the + * input {@code PCollection}. This gives an idea of the + * distribution of the input values for each key. + * + *

    Each of the computed {@code List}s is of size {@code numQuantiles}, + * and contains the input values' minimum value, + * {@code numQuantiles-2} intermediate values, and maximum value, in + * sorted order, using the given {@code Comparator} to order values. + * To compute traditional {@code N}-tiles, one should use + * {@code ApproximateQuantiles.perKey(compareFn, N+1)}. + * + *

    If a key has fewer than {@code numQuantiles} values + * associated with it, then that key's output {@code List} will + * contain all the key's input values, in sorted order. + * + *

    The argument {@code Comparator} must be {@code Serializable}. + * + *

    Example of use: + *

     {@code
    +   * PCollection> pc = ...;
    +   * PCollection>> quantilesPerKey =
    +   *     pc.apply(ApproximateQuantiles.perKey(stringCompareFn, 11));
    +   * } 
    + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + * @param compareFn the function to use to order the elements + */ + public static & Serializable> + PTransform>, PCollection>>> + perKey(int numQuantiles, ComparatorT compareFn) { + return Combine.perKey( + ApproximateQuantilesCombineFn.create(numQuantiles, compareFn) + .asKeyedFn()); + } + + /** + * Like {@link #perKey(int, Comparator)}, but sorts + * values using the their natural ordering. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param numQuantiles the number of elements in the resulting + * quantile values {@code List} + */ + public static > + PTransform>, PCollection>>> + perKey(int numQuantiles) { + return Combine.perKey( + ApproximateQuantilesCombineFn.create(numQuantiles) + .asKeyedFn()); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * The {@code ApproximateQuantilesCombineFn} combiner gives an idea + * of the distribution of a collection of values using approximate + * {@code N}-tiles. The output of this combiner is a {@code List} + * of size {@code numQuantiles}, containing the input values' + * minimum value, {@code numQuantiles-2} intermediate values, and + * maximum value, in sorted order, so for traditional + * {@code N}-tiles, one should use + * {@code ApproximateQuantilesCombineFn#create(N+1)}. + * + *

    If there are fewer values to combine than + * {@code numQuantiles}, then the result {@code List} will contain all the + * values being combined, in sorted order. + * + *

    Values are ordered using either a specified + * {@code Comparator} or the values' natural ordering. + * + *

    To evaluate the quantiles we use the "New Algorithm" described here: + *

    +   *   [MRL98] Manku, Rajagopalan & Lindsay, "Approximate Medians and other
    +   *   Quantiles in One Pass and with Limited Memory", Proc. 1998 ACM
    +   *   SIGMOD, Vol 27, No 2, p 426-435, June 1998.
    +   *   http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1&type=pdf
    +   * 
    + * + *

    The default error bound is {@code 1 / N}, though in practice + * the accuracy tends to be much better.

    See + * {@link #create(int, Comparator, long, double)} for + * more information about the meaning of {@code epsilon}, and + * {@link #withEpsilon} for a convenient way to adjust it. + * + * @param the type of the values being combined + */ + public static class ApproximateQuantilesCombineFn + & Serializable> + extends AccumulatingCombineFn, List> { + + /** + * The cost (in time and space) to compute quantiles to a given + * accuracy is a function of the total number of elements in the + * data set. If an estimate is not known or specified, we use + * this as an upper bound. If this is too low, errors may exceed + * the requested tolerance; if too high, efficiency may be + * non-optimal. The impact is logarithmic with respect to this + * value, so this default should be fine for most uses. + */ + public static final long DEFAULT_MAX_NUM_ELEMENTS = (long) 1e9; + + /** The comparison function to use. */ + private final ComparatorT compareFn; + + /** + * Number of quantiles to produce. The size of the final output + * list, including the minimum and maximum, is numQuantiles. + */ + private final int numQuantiles; + + /** The size of the buffers, corresponding to k in the referenced paper. */ + private final int bufferSize; + + /** The number of buffers, corresponding to b in the referenced paper. */ + private final int numBuffers; + + private final long maxNumElements; + + private ApproximateQuantilesCombineFn( + int numQuantiles, + ComparatorT compareFn, + int bufferSize, + int numBuffers, + long maxNumElements) { + Preconditions.checkArgument(numQuantiles >= 2); + Preconditions.checkArgument(bufferSize >= 2); + Preconditions.checkArgument(numBuffers >= 2); + this.numQuantiles = numQuantiles; + this.compareFn = compareFn; + this.bufferSize = bufferSize; + this.numBuffers = numBuffers; + this.maxNumElements = maxNumElements; + } + + /** + * Returns an approximate quantiles combiner with the given + * {@code compareFn} and desired number of quantiles. A total of + * {@code numQuantiles} elements will appear in the output list, + * including the minimum and maximum. + * + *

    The {@code Comparator} must be {@code Serializable}. + * + *

    The default error bound is {@code 1 / numQuantiles}, which + * holds as long as the number of elements is less than + * {@link #DEFAULT_MAX_NUM_ELEMENTS}. + */ + public static & Serializable> + ApproximateQuantilesCombineFn create( + int numQuantiles, ComparatorT compareFn) { + return create( + numQuantiles, compareFn, DEFAULT_MAX_NUM_ELEMENTS, 1.0 / numQuantiles); + } + + /** + * Like {@link #create(int, Comparator)}, but sorts values using their natural ordering. + */ + public static > + ApproximateQuantilesCombineFn> create(int numQuantiles) { + return create(numQuantiles, new Top.Largest()); + } + + /** + * Returns an {@code ApproximateQuantilesCombineFn} that's like + * this one except that it uses the specified {@code epsilon} + * value. Does not modify this combiner. + * + *

    See {@link #create(int, Comparator, long, + * double)} for more information about the meaning of + * {@code epsilon}. + */ + public ApproximateQuantilesCombineFn withEpsilon(double epsilon) { + return create(numQuantiles, compareFn, maxNumElements, epsilon); + } + + /** + * Returns an {@code ApproximateQuantilesCombineFn} that's like + * this one except that it uses the specified {@code maxNumElements} + * value. Does not modify this combiner. + * + *
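A sketch of tightening the error bound on a globally combined quantile estimate; latencies is an illustrative PCollection<Long>, and ApproximateQuantilesCombineFn is the nested class defined here (imported from ApproximateQuantiles):

// 101 output values approximate percentiles; the default bound would be 1/101.
PCollection<List<Long>> percentiles = latencies.apply(
    Combine.globally(
        ApproximateQuantilesCombineFn.<Long>create(101).withEpsilon(0.001)));

Per the class comment above, the tighter bound is only guaranteed while the input stays below DEFAULT_MAX_NUM_ELEMENTS unless withMaxInputSize is adjusted as well.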

    See {@link #create(int, Comparator, long, double)} for more + * information about the meaning of {@code maxNumElements}. + */ + public ApproximateQuantilesCombineFn withMaxInputSize( + long maxNumElements) { + return create(numQuantiles, compareFn, maxNumElements, maxNumElements); + } + + /** + * Creates an approximate quantiles combiner with the given + * {@code compareFn} and desired number of quantiles. A total of + * {@code numQuantiles} elements will appear in the output list, + * including the minimum and maximum. + * + *

    The {@code Comparator} must be {@code Serializable}. + * + *

    The default error bound is {@code epsilon}, which holds as long + * as the number of elements is less than {@code maxNumElements}. + * Specifically, if one considers the input as a sorted list x_1, ..., x_N, + * then the distance between the each exact quantile x_c and its + * approximation x_c' is bounded by {@code |c - c'| < epsilon * N}. + * Note that these errors are worst-case scenarios; in practice the accuracy + * tends to be much better. + */ + public static & Serializable> + ApproximateQuantilesCombineFn create( + int numQuantiles, + ComparatorT compareFn, + long maxNumElements, + double epsilon) { + // Compute optimal b and k. + int b = 2; + while ((b - 2) * (1 << (b - 2)) < epsilon * maxNumElements) { + b++; + } + b--; + int k = Math.max(2, (int) Math.ceil(maxNumElements / (1 << (b - 1)))); + return new ApproximateQuantilesCombineFn( + numQuantiles, compareFn, k, b, maxNumElements); + } + + @Override + public QuantileState createAccumulator() { + return QuantileState.empty(compareFn, numQuantiles, numBuffers, bufferSize); + } + + @Override + public Coder> getAccumulatorCoder( + CoderRegistry registry, Coder elementCoder) { + return new QuantileStateCoder<>(compareFn, elementCoder); + } + } + + /** + * Compact summarization of a collection on which quantiles can be estimated. + */ + static class QuantileState & Serializable> + implements Accumulator, List> { + + private ComparatorT compareFn; + private int numQuantiles; + private int numBuffers; + private int bufferSize; + + @Nullable + private T min; + + @Nullable + private T max; + + /** + * The set of buffers, ordered by level from smallest to largest. + */ + private PriorityQueue> buffers; + + /** + * The algorithm requires that the manipulated buffers always be filled + * to capacity to perform the collapse operation. This operation can + * be extended to buffers of varying sizes by introducing the notion of + * fractional weights, but it's easier to simply combine the remainders + * from all shards into new, full buffers and then take them into account + * when computing the final output. + */ + private List unbufferedElements = Lists.newArrayList(); + + private QuantileState( + ComparatorT compareFn, + int numQuantiles, + @Nullable T min, + @Nullable T max, + int numBuffers, + int bufferSize, + Collection unbufferedElements, + Collection> buffers) { + this.compareFn = compareFn; + this.numQuantiles = numQuantiles; + this.numBuffers = numBuffers; + this.bufferSize = bufferSize; + this.buffers = new PriorityQueue<>(numBuffers + 1); + this.min = min; + this.max = max; + this.unbufferedElements.addAll(unbufferedElements); + this.buffers.addAll(buffers); + } + + public static & Serializable> + QuantileState empty( + ComparatorT compareFn, int numQuantiles, int numBuffers, int bufferSize) { + return new QuantileState( + compareFn, + numQuantiles, + null, /* min */ + null, /* max */ + numBuffers, + bufferSize, + Collections.emptyList(), + Collections.>emptyList()); + } + + public static & Serializable> + QuantileState singleton( + ComparatorT compareFn, int numQuantiles, T elem, int numBuffers, int bufferSize) { + return new QuantileState( + compareFn, + numQuantiles, + elem, /* min */ + elem, /* max */ + numBuffers, + bufferSize, + Collections.singletonList(elem), + Collections.>emptyList()); + } + + /** + * Add a new element to the collection being summarized by this state. 
+ */ + @Override + public void addInput(T elem) { + if (isEmpty()) { + min = max = elem; + } else if (compareFn.compare(elem, min) < 0) { + min = elem; + } else if (compareFn.compare(elem, max) > 0) { + max = elem; + } + addUnbuffered(elem); + } + + /** + * Add a new buffer to the unbuffered list, creating a new buffer and + * collapsing if needed. + */ + private void addUnbuffered(T elem) { + unbufferedElements.add(elem); + if (unbufferedElements.size() == bufferSize) { + Collections.sort(unbufferedElements, compareFn); + buffers.add(new QuantileBuffer(unbufferedElements)); + unbufferedElements = Lists.newArrayListWithCapacity(bufferSize); + collapseIfNeeded(); + } + } + + /** + * Updates this as if adding all elements seen by other. + * + *

    Note that this ignores the {@code Comparator} of the other {@link QuantileState}. In + * practice, they should generally be equal, but this method tolerates a mismatch. + */ + @Override + public void mergeAccumulator(QuantileState other) { + if (other.isEmpty()) { + return; + } + if (min == null || compareFn.compare(other.min, min) < 0) { + min = other.min; + } + if (max == null || compareFn.compare(other.max, max) > 0) { + max = other.max; + } + for (T elem : other.unbufferedElements) { + addUnbuffered(elem); + } + buffers.addAll(other.buffers); + collapseIfNeeded(); + } + + public boolean isEmpty() { + return unbufferedElements.size() == 0 && buffers.size() == 0; + } + + private void collapseIfNeeded() { + while (buffers.size() > numBuffers) { + List> toCollapse = Lists.newArrayList(); + toCollapse.add(buffers.poll()); + toCollapse.add(buffers.poll()); + int minLevel = toCollapse.get(1).level; + while (!buffers.isEmpty() && buffers.peek().level == minLevel) { + toCollapse.add(buffers.poll()); + } + buffers.add(collapse(toCollapse)); + } + } + + private QuantileBuffer collapse( + Iterable> buffers) { + int newLevel = 0; + long newWeight = 0; + for (QuantileBuffer buffer : buffers) { + // As presented in the paper, there should always be at least two + // buffers of the same (minimal) level to collapse, but it is possible + // to violate this condition when combining buffers from independently + // computed shards. If they differ we take the max. + newLevel = Math.max(newLevel, buffer.level + 1); + newWeight += buffer.weight; + } + List newElements = + interpolate(buffers, bufferSize, newWeight, offset(newWeight)); + return new QuantileBuffer<>(newLevel, newWeight, newElements); + } + + /** + * If the weight is even, we must round up or down. Alternate between these two options to + * avoid a bias. + */ + private long offset(long newWeight) { + if (newWeight % 2 == 1) { + return (newWeight + 1) / 2; + } else { + offsetJitter = 2 - offsetJitter; + return (newWeight + offsetJitter) / 2; + } + } + + /** For alternating between biasing up and down in the above even weight collapse operation. */ + private int offsetJitter = 0; + + + /** + * Emulates taking the ordered union of all elements in buffers, repeated + * according to their weight, and picking out the (k * step + offset)-th + * elements of this list for {@code 0 <= k < count}. + */ + private List interpolate(Iterable> buffers, + int count, double step, double offset) { + List>> iterators = Lists.newArrayList(); + for (QuantileBuffer buffer : buffers) { + iterators.add(buffer.sizedIterator()); + } + // Each of the buffers is already sorted by element. + Iterator> sorted = Iterators.mergeSorted( + iterators, + new Comparator>() { + @Override + public int compare(WeightedValue a, WeightedValue b) { + return compareFn.compare(a.getValue(), b.getValue()); + } + }); + + List newElements = Lists.newArrayListWithCapacity(count); + WeightedValue weightedElement = sorted.next(); + double current = weightedElement.getWeight(); + for (int j = 0; j < count; j++) { + double target = j * step + offset; + while (current <= target && sorted.hasNext()) { + weightedElement = sorted.next(); + current += weightedElement.getWeight(); + } + newElements.add(weightedElement.getValue()); + } + return newElements; + } + + /** + * Outputs numQuantiles elements consisting of the minimum, maximum, and + * numQuantiles - 2 evenly spaced intermediate elements. + * + *
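A worked sketch of the spacing computed in {@code extractOutput} below, assuming 100 weighted elements in total and {@code numQuantiles = 5}:

    long totalCount = 100;
    int numQuantiles = 5;
    double step   = 1.0 * totalCount / (numQuantiles - 1);        // 25.0
    double offset = (1.0 * totalCount - 1) / (numQuantiles - 1);  // 24.75
    // interpolate(all, numQuantiles - 2, step, offset) then picks elements near
    // ranks 24.75, 49.75 and 74.75, and the exact min and max are added around them.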

    Returns the empty list if no elements have been added. + */ + @Override + public List extractOutput() { + if (isEmpty()) { + return Lists.newArrayList(); + } + long totalCount = unbufferedElements.size(); + for (QuantileBuffer buffer : buffers) { + totalCount += bufferSize * buffer.weight; + } + List> all = Lists.newArrayList(buffers); + if (!unbufferedElements.isEmpty()) { + Collections.sort(unbufferedElements, compareFn); + all.add(new QuantileBuffer<>(unbufferedElements)); + } + double step = 1.0 * totalCount / (numQuantiles - 1); + double offset = (1.0 * totalCount - 1) / (numQuantiles - 1); + List quantiles = interpolate(all, numQuantiles - 2, step, offset); + quantiles.add(0, min); + quantiles.add(max); + return quantiles; + } + } + + /** + * A single buffer in the sense of the referenced algorithm. + */ + private static class QuantileBuffer implements Comparable> { + private int level; + private long weight; + private List elements; + + public QuantileBuffer(List elements) { + this(0, 1, elements); + } + + public QuantileBuffer(int level, long weight, List elements) { + this.level = level; + this.weight = weight; + this.elements = elements; + } + + @Override + public int compareTo(QuantileBuffer other) { + return this.level - other.level; + } + + @Override + public String toString() { + return "QuantileBuffer[" + + "level=" + level + + ", weight=" + + weight + ", elements=" + elements + "]"; + } + + public Iterator> sizedIterator() { + return new UnmodifiableIterator>() { + Iterator iter = elements.iterator(); + @Override + public boolean hasNext() { + return iter.hasNext(); + } + @Override public WeightedValue next() { + return WeightedValue.of(iter.next(), weight); + } + }; + } + } + + /** + * Coder for QuantileState. + */ + private static class QuantileStateCoder & Serializable> + extends CustomCoder> { + private final ComparatorT compareFn; + private final Coder elementCoder; + private final Coder> elementListCoder; + private final Coder intCoder = BigEndianIntegerCoder.of(); + + public QuantileStateCoder(ComparatorT compareFn, Coder elementCoder) { + this.compareFn = compareFn; + this.elementCoder = elementCoder; + this.elementListCoder = ListCoder.of(elementCoder); + } + + @Override + public void encode( + QuantileState state, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + intCoder.encode(state.numQuantiles, outStream, nestedContext); + intCoder.encode(state.bufferSize, outStream, nestedContext); + elementCoder.encode(state.min, outStream, nestedContext); + elementCoder.encode(state.max, outStream, nestedContext); + elementListCoder.encode( + state.unbufferedElements, outStream, nestedContext); + BigEndianIntegerCoder.of().encode( + state.buffers.size(), outStream, nestedContext); + for (QuantileBuffer buffer : state.buffers) { + encodeBuffer(buffer, outStream, nestedContext); + } + } + + @Override + public QuantileState decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + int numQuantiles = intCoder.decode(inStream, nestedContext); + int bufferSize = intCoder.decode(inStream, nestedContext); + T min = elementCoder.decode(inStream, nestedContext); + T max = elementCoder.decode(inStream, nestedContext); + List unbufferedElements = + elementListCoder.decode(inStream, nestedContext); + int numBuffers = + BigEndianIntegerCoder.of().decode(inStream, nestedContext); + List> buffers = new 
ArrayList<>(numBuffers); + for (int i = 0; i < numBuffers; i++) { + buffers.add(decodeBuffer(inStream, nestedContext)); + } + return new QuantileState( + compareFn, numQuantiles, min, max, numBuffers, bufferSize, unbufferedElements, buffers); + } + + private void encodeBuffer( + QuantileBuffer buffer, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + DataOutputStream outData = new DataOutputStream(outStream); + outData.writeInt(buffer.level); + outData.writeLong(buffer.weight); + elementListCoder.encode(buffer.elements, outStream, context); + } + + private QuantileBuffer decodeBuffer( + InputStream inStream, Coder.Context context) + throws IOException, CoderException { + DataInputStream inData = new DataInputStream(inStream); + return new QuantileBuffer<>( + inData.readInt(), + inData.readLong(), + elementListCoder.decode(inStream, context)); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the + * encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + QuantileState state, + ElementByteSizeObserver observer, + Coder.Context context) + throws Exception { + Coder.Context nestedContext = context.nested(); + elementCoder.registerByteSizeObserver( + state.min, observer, nestedContext); + elementCoder.registerByteSizeObserver( + state.max, observer, nestedContext); + elementListCoder.registerByteSizeObserver( + state.unbufferedElements, observer, nestedContext); + + BigEndianIntegerCoder.of().registerByteSizeObserver( + state.buffers.size(), observer, nestedContext); + for (QuantileBuffer buffer : state.buffers) { + observer.update(4L + 8); + + elementListCoder.registerByteSizeObserver( + buffer.elements, observer, nestedContext); + } + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "QuantileState.ElementCoder must be deterministic", + elementCoder); + verifyDeterministic( + "QuantileState.ElementListCoder must be deterministic", + elementListCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUnique.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUnique.java new file mode 100644 index 000000000000..3c936a2b13a0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUnique.java @@ -0,0 +1,419 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.hash.Hashing; +import com.google.common.hash.HashingOutputStream; +import com.google.common.io.ByteStreams; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.PriorityQueue; + +/** + * {@code PTransform}s for estimating the number of distinct elements + * in a {@code PCollection}, or the number of distinct values + * associated with each key in a {@code PCollection} of {@code KV}s. + */ +public class ApproximateUnique { + + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection} containing a single value + * that is an estimate of the number of distinct elements in the + * input {@code PCollection}. + * + *

The {@code sampleSize} parameter controls the estimation + * error. The error is about {@code 2 / sqrt(sampleSize)}, so for + * {@code ApproximateUnique.globally(10000)} the estimation error is + * about 2%. Similarly, for {@code ApproximateUnique.globally(16)} the + * estimation error is about 50%. If there are fewer than + * {@code sampleSize} distinct elements then the returned result + * will be exact with extremely high probability (the chance of a + * hash collision is about {@code sampleSize^2 / 2^65}). + * + *
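A short worked sketch of that relationship, mirroring the {@code sampleSizeFromEstimationError} helper defined at the bottom of this file:

    double errorFor10000 = 2.0 / Math.sqrt(10000);  // 0.02, i.e. about 2%
    double errorFor16 = 2.0 / Math.sqrt(16);        // 0.5, i.e. about 50%
    // Inverting the bound: sampleSize ~= 4 / error^2.
    long sampleFor2Percent = Math.round(Math.ceil(4.0 / Math.pow(0.02, 2.0)));  // 10000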

    This transform approximates the number of elements in a set + * by computing the top {@code sampleSize} hash values, and using + * that to extrapolate the size of the entire set of hash values by + * assuming the rest of the hash values are as densely distributed + * as the top {@code sampleSize}. + * + *
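A minimal sketch of that extrapolation, leaving out the {@code log1p} collision correction applied in {@code ApproximateUniqueCombineFn.extractOutput}; {@code sampleSize} and {@code smallestSampleHash} are assumed inputs:

    // The sampleSize largest hashes occupy (smallestSampleHash, Long.MAX_VALUE];
    // assume the rest of the hash space is equally dense.
    double hashSpaceSize = Long.MAX_VALUE - (double) Long.MIN_VALUE;
    double sampleSpaceSize = Long.MAX_VALUE - (double) smallestSampleHash;
    double naiveEstimate = sampleSize * hashSpaceSize / sampleSpaceSize;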

    See also {@link #globally(double)}. + * + *

    Example of use: + *

     {@code
    +   * PCollection pc = ...;
    +   * PCollection approxNumDistinct =
    +   *     pc.apply(ApproximateUnique.globally(1000));
    +   * } 
    + * + * @param the type of the elements in the input {@code PCollection} + * @param sampleSize the number of entries in the statistical + * sample; the higher this number, the more accurate the + * estimate will be; should be {@code >= 16} + * @throws IllegalArgumentException if the {@code sampleSize} + * argument is too small + */ + public static Globally globally(int sampleSize) { + return new Globally<>(sampleSize); + } + + /** + * Like {@link #globally(int)}, but specifies the desired maximum + * estimation error instead of the sample size. + * + * @param the type of the elements in the input {@code PCollection} + * @param maximumEstimationError the maximum estimation error, which + * should be in the range {@code [0.01, 0.5]} + * @throws IllegalArgumentException if the + * {@code maximumEstimationError} argument is out of range + */ + public static Globally globally(double maximumEstimationError) { + return new Globally<>(maximumEstimationError); + } + + /** + * Returns a {@code PTransform} that takes a + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output element + * mapping each distinct key in the input {@code PCollection} to an + * estimate of the number of distinct values associated with that + * key in the input {@code PCollection}. + * + *

    See {@link #globally(int)} for an explanation of the + * {@code sampleSize} parameter. A separate sampling is computed + * for each distinct key of the input. + * + *

    See also {@link #perKey(double)}. + * + *

    Example of use: + *

     {@code
    +   * PCollection> pc = ...;
    +   * PCollection> approxNumDistinctPerKey =
    +   *     pc.apply(ApproximateUnique.perKey(1000));
    +   * } 
    + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param sampleSize the number of entries in the statistical + * sample; the higher this number, the more accurate the + * estimate will be; should be {@code >= 16} + * @throws IllegalArgumentException if the {@code sampleSize} + * argument is too small + */ + public static PerKey perKey(int sampleSize) { + return new PerKey<>(sampleSize); + } + + /** + * Like {@link #perKey(int)}, but specifies the desired maximum + * estimation error instead of the sample size. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + * @param maximumEstimationError the maximum estimation error, which + * should be in the range {@code [0.01, 0.5]} + * @throws IllegalArgumentException if the + * {@code maximumEstimationError} argument is out of range + */ + public static PerKey perKey(double maximumEstimationError) { + return new PerKey<>(maximumEstimationError); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code PTransform} for estimating the number of distinct elements + * in a {@code PCollection}. + * + * @param the type of the elements in the input {@code PCollection} + */ + static class Globally extends PTransform, PCollection> { + + /** + * The number of entries in the statistical sample; the higher this number, + * the more accurate the estimate will be. + */ + private final long sampleSize; + + /** + * @see ApproximateUnique#globally(int) + */ + public Globally(int sampleSize) { + if (sampleSize < 16) { + throw new IllegalArgumentException( + "ApproximateUnique needs a sampleSize " + + ">= 16 for an estimation error <= 50%. " + + "In general, the estimation " + + "error is about 2 / sqrt(sampleSize)."); + } + this.sampleSize = sampleSize; + } + + /** + * @see ApproximateUnique#globally(double) + */ + public Globally(double maximumEstimationError) { + if (maximumEstimationError < 0.01 || maximumEstimationError > 0.5) { + throw new IllegalArgumentException( + "ApproximateUnique needs an " + + "estimation error between 1% (0.01) and 50% (0.5)."); + } + this.sampleSize = sampleSizeFromEstimationError(maximumEstimationError); + } + + @Override + public PCollection apply(PCollection input) { + Coder coder = input.getCoder(); + return input.apply( + Combine.globally( + new ApproximateUniqueCombineFn<>(sampleSize, coder))); + } + } + + /** + * {@code PTransform} for estimating the number of distinct values + * associated with each key in a {@code PCollection} of {@code KV}s. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + * @param the type of the values in the input {@code PCollection} + */ + static class PerKey + extends PTransform>, PCollection>> { + + private final long sampleSize; + + /** + * @see ApproximateUnique#perKey(int) + */ + public PerKey(int sampleSize) { + if (sampleSize < 16) { + throw new IllegalArgumentException( + "ApproximateUnique needs a " + + "sampleSize >= 16 for an estimation error <= 50%. 
In general, " + + "the estimation error is about 2 / sqrt(sampleSize)."); + } + this.sampleSize = sampleSize; + } + + /** + * @see ApproximateUnique#perKey(double) + */ + public PerKey(double estimationError) { + if (estimationError < 0.01 || estimationError > 0.5) { + throw new IllegalArgumentException( + "ApproximateUnique.PerKey needs an " + + "estimation error between 1% (0.01) and 50% (0.5)."); + } + this.sampleSize = sampleSizeFromEstimationError(estimationError); + } + + @Override + public PCollection> apply(PCollection> input) { + Coder> inputCoder = input.getCoder(); + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "ApproximateUnique.PerKey requires its input to use KvCoder"); + } + @SuppressWarnings("unchecked") + final Coder coder = ((KvCoder) inputCoder).getValueCoder(); + + return input.apply( + Combine.perKey(new ApproximateUniqueCombineFn<>( + sampleSize, coder).asKeyedFn())); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code CombineFn} that computes an estimate of the number of + * distinct values that were combined. + * + *

    Hashes input elements, computes the top {@code sampleSize} + * hash values, and uses those to extrapolate the size of the entire + * set of hash values by assuming the rest of the hash values are as + * densely distributed as the top {@code sampleSize}. + * + *
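The core data structure is a bounded min-heap of the largest hashes seen so far; a hedged standalone sketch of that idea (the {@code LargestUnique} helper below additionally filters duplicates):

    class LargestHashes {
      private final java.util.PriorityQueue<Long> heap = new java.util.PriorityQueue<>();
      private final int sampleSize;
      LargestHashes(int sampleSize) { this.sampleSize = sampleSize; }
      void offer(long hash) {
        if (heap.size() < sampleSize) {
          heap.add(hash);                  // still filling up
        } else if (hash > heap.peek()) {
          heap.poll();                     // evict the smallest retained hash
          heap.add(hash);
        }
      }
    }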

    Used to implement + * {@link #globally(int) ApproximatUnique.globally(...)} and + * {@link #perKey(int) ApproximatUnique.perKey(...)}. + * + * @param the type of the values being combined + */ + public static class ApproximateUniqueCombineFn extends + CombineFn { + + /** + * The size of the space of hashes returned by the hash function. + */ + static final double HASH_SPACE_SIZE = + Long.MAX_VALUE - (double) Long.MIN_VALUE; + + /** + * A heap utility class to efficiently track the largest added elements. + */ + public static class LargestUnique implements Serializable { + private PriorityQueue heap = new PriorityQueue<>(); + private final long sampleSize; + + /** + * Creates a heap to track the largest {@code sampleSize} elements. + * + * @param sampleSize the size of the heap + */ + public LargestUnique(long sampleSize) { + this.sampleSize = sampleSize; + } + + /** + * Adds a value to the heap, returning whether the value is (large enough + * to be) in the heap. + */ + public boolean add(Long value) { + if (heap.contains(value)) { + return true; + } else if (heap.size() < sampleSize) { + heap.add(value); + return true; + } else if (value > heap.element()) { + heap.remove(); + heap.add(value); + return true; + } else { + return false; + } + } + + /** + * Returns the values in the heap, ordered largest to smallest. + */ + public List extractOrderedList() { + // The only way to extract the order from the heap is element-by-element + // from smallest to largest. + Long[] array = new Long[heap.size()]; + for (int i = heap.size() - 1; i >= 0; i--) { + array[i] = heap.remove(); + } + return Arrays.asList(array); + } + } + + private final long sampleSize; + private final Coder coder; + + public ApproximateUniqueCombineFn(long sampleSize, Coder coder) { + this.sampleSize = sampleSize; + this.coder = coder; + } + + @Override + public LargestUnique createAccumulator() { + return new LargestUnique(sampleSize); + } + + @Override + public LargestUnique addInput(LargestUnique heap, T input) { + try { + heap.add(hash(input, coder)); + return heap; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @Override + public LargestUnique mergeAccumulators(Iterable heaps) { + Iterator iterator = heaps.iterator(); + LargestUnique heap = iterator.next(); + while (iterator.hasNext()) { + List largestHashes = iterator.next().extractOrderedList(); + for (long hash : largestHashes) { + if (!heap.add(hash)) { + break; // The remainder of this list is all smaller. + } + } + } + return heap; + } + + @Override + public Long extractOutput(LargestUnique heap) { + List largestHashes = heap.extractOrderedList(); + if (largestHashes.size() < sampleSize) { + return (long) largestHashes.size(); + } else { + long smallestSampleHash = largestHashes.get(largestHashes.size() - 1); + double sampleSpaceSize = Long.MAX_VALUE - (double) smallestSampleHash; + // This formula takes into account the possibility of hash collisions, + // which become more likely than not for 2^32 distinct elements. + // Note that log(1+x) ~ x for small x, so for sampleSize << maxHash + // log(1 - sampleSize/sampleSpace) / log(1 - 1/sampleSpace) ~ sampleSize + // and hence estimate ~ sampleSize * HASH_SPACE_SIZE / sampleSpace + // as one would expect. 
+ double estimate = Math.log1p(-sampleSize / sampleSpaceSize) + / Math.log1p(-1 / sampleSpaceSize) + * HASH_SPACE_SIZE / sampleSpaceSize; + return Math.round(estimate); + } + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, + Coder inputCoder) { + return SerializableCoder.of(LargestUnique.class); + } + + /** + * Encodes the given element using the given coder and hashes the encoding. + */ + static long hash(T element, Coder coder) throws CoderException, IOException { + try (HashingOutputStream stream = + new HashingOutputStream(Hashing.murmur3_128(), ByteStreams.nullOutputStream())) { + coder.encode(element, stream, Context.OUTER); + return stream.hash().asLong(); + } + } + } + + /** + * Computes the sampleSize based on the desired estimation error. + * + * @param estimationError should be bounded by [0.01, 0.5] + * @return the sample size needed for the desired estimation error + */ + static long sampleSizeFromEstimationError(double estimationError) { + return Math.round(Math.ceil(4.0 / Math.pow(estimationError, 2.0))); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java new file mode 100644 index 000000000000..cc0347a12432 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java @@ -0,0 +1,2252 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.DelegateCoder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.AbstractGlobalCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.AbstractPerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.RequiresContextInternal; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.AppliedCombineFn; +import com.google.cloud.dataflow.sdk.util.PerKeyCombineFnRunner; +import com.google.cloud.dataflow.sdk.util.PerKeyCombineFnRunners; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +/** + * {@code PTransform}s for combining {@code PCollection} elements + * globally and per-key. + * + *

    See the documentation + * for how to use the operations in this class. + */ +public class Combine { + private Combine() { + // do not instantiate + } + + /** + * Returns a {@link Globally Combine.Globally} {@code PTransform} + * that uses the given {@code SerializableFunction} to combine all + * the elements in each window of the input {@code PCollection} into a + * single value in the output {@code PCollection}. The types of the input + * elements and the output elements must be the same. + * + *

    If the input {@code PCollection} is windowed into {@link GlobalWindows}, + * a default value in the {@link GlobalWindow} will be output if the input + * {@code PCollection} is empty. To use this with inputs with other windowing, + * either {@link Globally#withoutDefaults} or {@link Globally#asSingletonView} + * must be called. + * + *
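A hedged usage sketch; the anonymous {@code SerializableFunction} below is illustrative, and any associative, commutative function over an {@code Iterable} would do:

    PCollection<Integer> pc = ...;
    PCollection<Integer> sum = pc.apply(
        Combine.globally(new SerializableFunction<Iterable<Integer>, Integer>() {
          @Override
          public Integer apply(Iterable<Integer> input) {
            int total = 0;
            for (Integer value : input) {
              total += value;
            }
            return total;
          }
        }));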

    See {@link Globally Combine.Globally} for more information. + */ + public static Globally globally( + SerializableFunction, V> combiner) { + return globally(IterableCombineFn.of(combiner)); + } + + /** + * Returns a {@link Globally Combine.Globally} {@code PTransform} + * that uses the given {@code GloballyCombineFn} to combine all + * the elements in each window of the input {@code PCollection} into a + * single value in the output {@code PCollection}. The types of the input + * elements and the output elements can differ. + * + *

    If the input {@code PCollection} is windowed into {@link GlobalWindows}, + * a default value in the {@link GlobalWindow} will be output if the input + * {@code PCollection} is empty. To use this with inputs with other windowing, + * either {@link Globally#withoutDefaults} or {@link Globally#asSingletonView} + * must be called. + * + *

    See {@link Globally Combine.Globally} for more information. + */ + public static Globally globally( + GlobalCombineFn fn) { + return new Globally<>(fn, true, 0); + } + + /** + * Returns a {@link PerKey Combine.PerKey} {@code PTransform} that + * first groups its input {@code PCollection} of {@code KV}s by keys and + * windows, then invokes the given function on each of the values lists to + * produce a combined value, and then returns a {@code PCollection} + * of {@code KV}s mapping each distinct key to its combined value for each + * window. + * + *

    Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * as the input. + * + *
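A hedged sketch of a per-key maximum using this overload; the anonymous function is illustrative:

    PCollection<KV<String, Integer>> input = ...;
    PCollection<KV<String, Integer>> maxPerKey = input.apply(
        Combine.<String, Integer>perKey(new SerializableFunction<Iterable<Integer>, Integer>() {
          @Override
          public Integer apply(Iterable<Integer> values) {
            int max = Integer.MIN_VALUE;
            for (Integer value : values) {
              max = Math.max(max, value);
            }
            return max;
          }
        }));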

    See {@link PerKey Combine.PerKey} for more information. + */ + public static PerKey perKey( + SerializableFunction, V> fn) { + return perKey(Combine.IterableCombineFn.of(fn)); + } + + /** + * Returns a {@link PerKey Combine.PerKey} {@code PTransform} that + * first groups its input {@code PCollection} of {@code KV}s by keys and + * windows, then invokes the given function on each of the values lists to + * produce a combined value, and then returns a {@code PCollection} + * of {@code KV}s mapping each distinct key to its combined value for each + * window. + * + *

    Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * as the input. + * + *

    See {@link PerKey Combine.PerKey} for more information. + */ + public static PerKey perKey( + GlobalCombineFn fn) { + return perKey(fn.asKeyedFn()); + } + + /** + * Returns a {@link PerKey Combine.PerKey} {@code PTransform} that + * first groups its input {@code PCollection} of {@code KV}s by keys and + * windows, then invokes the given function on each of the key/values-lists + * pairs to produce a combined value, and then returns a + * {@code PCollection} of {@code KV}s mapping each distinct key to + * its combined value for each window. + * + *

    Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * as the input. + * + *

    See {@link PerKey Combine.PerKey} for more information. + */ + public static PerKey perKey( + PerKeyCombineFn fn) { + return new PerKey<>(fn, false /*fewKeys*/); + } + + /** + * Returns a {@link PerKey Combine.PerKey}, and set fewKeys + * in {@link GroupByKey}. + */ + private static PerKey fewKeys( + PerKeyCombineFn fn) { + return new PerKey<>(fn, true /*fewKeys*/); + } + + /** + * Returns a {@link GroupedValues Combine.GroupedValues} + * {@code PTransform} that takes a {@code PCollection} of + * {@code KV}s where a key maps to an {@code Iterable} of values, e.g., + * the result of a {@code GroupByKey}, then uses the given + * {@code SerializableFunction} to combine all the values associated + * with a key, ignoring the key. The type of the input and + * output values must be the same. + * + *

    Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + *

    See {@link GroupedValues Combine.GroupedValues} for more information. + * + *
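A hedged sketch of the {@code GroupByKey}-then-{@code groupedValues} pattern this transform supports, equivalent to the per-key maximum sketch above:

    PCollection<KV<String, Integer>> input = ...;
    PCollection<KV<String, Iterable<Integer>>> grouped =
        input.apply(GroupByKey.<String, Integer>create());
    PCollection<KV<String, Integer>> maxPerKey = grouped.apply(
        Combine.<String, Integer>groupedValues(
            new SerializableFunction<Iterable<Integer>, Integer>() {
              @Override
              public Integer apply(Iterable<Integer> values) {
                int max = Integer.MIN_VALUE;
                for (Integer value : values) {
                  max = Math.max(max, value);
                }
                return max;
              }
            }));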

    Note that {@link #perKey(SerializableFunction)} is typically + * more convenient to use than {@link GroupByKey} followed by + * {@code groupedValues(...)}. + */ + public static GroupedValues groupedValues( + SerializableFunction, V> fn) { + return groupedValues(IterableCombineFn.of(fn)); + } + + /** + * Returns a {@link GroupedValues Combine.GroupedValues} + * {@code PTransform} that takes a {@code PCollection} of + * {@code KV}s where a key maps to an {@code Iterable} of values, e.g., + * the result of a {@code GroupByKey}, then uses the given + * {@code CombineFn} to combine all the values associated with a + * key, ignoring the key. The types of the input and output values + * can differ. + * + *

    Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + *

    See {@link GroupedValues Combine.GroupedValues} for more information. + * + *

    Note that {@link #perKey(CombineFnBase.GlobalCombineFn)} is typically + * more convenient to use than {@link GroupByKey} followed by + * {@code groupedValues(...)}. + */ + public static GroupedValues groupedValues( + GlobalCombineFn fn) { + return groupedValues(fn.asKeyedFn()); + } + + /** + * Returns a {@link GroupedValues Combine.GroupedValues} + * {@code PTransform} that takes a {@code PCollection} of + * {@code KV}s where a key maps to an {@code Iterable} of values, e.g., + * the result of a {@code GroupByKey}, then uses the given + * {@code KeyedCombineFn} to combine all the values associated with + * each key. The combining function is provided the key. The types + * of the input and output values can differ. + * + *

    Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + *

    See {@link GroupedValues Combine.GroupedValues} for more information. + * + *

    Note that {@link #perKey(CombineFnBase.PerKeyCombineFn)} is typically + * more convenient to use than {@link GroupByKey} followed by + * {@code groupedValues(...)}. + */ + public static GroupedValues groupedValues( + PerKeyCombineFn fn) { + return new GroupedValues<>(fn); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code CombineFn} specifies how to combine a + * collection of input values of type {@code InputT} into a single + * output value of type {@code OutputT}. It does this via one or more + * intermediate mutable accumulator values of type {@code AccumT}. + * + *

    The overall process to combine a collection of input + * {@code InputT} values into a single output {@code OutputT} value is as + * follows: + * + *

      + * + *
    1. The input {@code InputT} values are partitioned into one or more + * batches. + * + *
    2. For each batch, the {@link #createAccumulator} operation is + * invoked to create a fresh mutable accumulator value of type + * {@code AccumT}, initialized to represent the combination of zero + * values. + * + *
3. For each input {@code InputT} value in a batch, the + * {@link #addInput} operation is invoked to add the value to that + * batch's accumulator {@code AccumT} value. The accumulator may just + * record the new value (e.g., if {@code AccumT == List}), or may do + * work to represent the combination more compactly. + * + *
4. The {@link #mergeAccumulators} operation is invoked to + * combine a collection of accumulator {@code AccumT} values into a + * single combined output accumulator {@code AccumT} value, once the + * merging accumulators have had all the input values in their + * batches added to them. This operation is invoked repeatedly, + * until there is only one accumulator value left. + * + *
    5. The {@link #extractOutput} operation is invoked on the final + * accumulator {@code AccumT} value to get the output {@code OutputT} value. + * + *
    + * + *

    For example: + *

     {@code
    +   * public class AverageFn extends CombineFn {
    +   *   public static class Accum {
    +   *     int sum = 0;
    +   *     int count = 0;
    +   *   }
    +   *   public Accum createAccumulator() {
    +   *     return new Accum();
    +   *   }
    +   *   public Accum addInput(Accum accum, Integer input) {
    +   *       accum.sum += input;
    +   *       accum.count++;
    +   *       return accum;
    +   *   }
    +   *   public Accum mergeAccumulators(Iterable accums) {
    +   *     Accum merged = createAccumulator();
    +   *     for (Accum accum : accums) {
    +   *       merged.sum += accum.sum;
    +   *       merged.count += accum.count;
    +   *     }
    +   *     return merged;
    +   *   }
    +   *   public Double extractOutput(Accum accum) {
    +   *     return ((double) accum.sum) / accum.count;
    +   *   }
    +   * }
    +   * PCollection pc = ...;
    +   * PCollection average = pc.apply(Combine.globally(new AverageFn()));
    +   * } 
    + * + *
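For a quick check outside a pipeline, the {@code apply(Iterable)} convenience method defined later in this class can drive a {@code CombineFn} directly; a hedged sketch reusing the {@code AverageFn} from the example above:

    Double average = new AverageFn().apply(Arrays.asList(1, 2, 3, 4));  // 2.5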

    Combining functions used by {@link Combine.Globally}, + * {@link Combine.PerKey}, {@link Combine.GroupedValues}, and + * {@code PTransforms} derived from them should be + * associative and commutative. Associativity is + * required because input values are first broken up into subgroups + * before being combined, and their intermediate results further + * combined, in an arbitrary tree structure. Commutativity is + * required because any order of the input values is ignored when + * breaking up input values into groups. + * + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + public abstract static class CombineFn + extends AbstractGlobalCombineFn { + + /** + * Returns a new, mutable accumulator value, representing the accumulation of zero input values. + */ + public abstract AccumT createAccumulator(); + + /** + * Adds the given input value to the given accumulator, returning the + * new accumulator value. + * + *

    For efficiency, the input accumulator may be modified and returned. + */ + public abstract AccumT addInput(AccumT accumulator, InputT input); + + /** + * Returns an accumulator representing the accumulation of all the + * input values accumulated in the merging accumulators. + * + *

    May modify any of the argument accumulators. May return a + * fresh accumulator, or may return one of the (modified) argument + * accumulators. + */ + public abstract AccumT mergeAccumulators(Iterable accumulators); + + /** + * Returns the output value that is the result of combining all + * the input values represented by the given accumulator. + */ + public abstract OutputT extractOutput(AccumT accumulator); + + /** + * Returns an accumulator that represents the same logical value as the + * input accumulator, but may have a more compact representation. + * + *

    For most CombineFns this would be a no-op, but should be overridden + * by CombineFns that (for example) buffer up elements and combine + * them in batches. + * + *

    For efficiency, the input accumulator may be modified and returned. + * + *

    By default returns the original accumulator. + */ + public AccumT compact(AccumT accumulator) { + return accumulator; + } + + /** + * Applies this {@code CombineFn} to a collection of input values + * to produce a combined output value. + * + *

    Useful when using a {@code CombineFn} separately from a + * {@code Combine} transform. Does not invoke the + * {@link mergeAccumulators} operation. + */ + public OutputT apply(Iterable inputs) { + AccumT accum = createAccumulator(); + for (InputT input : inputs) { + accum = addInput(accum, input); + } + return extractOutput(accum); + } + + /** + * {@inheritDoc} + * + *

By default returns the extracted output of an empty accumulator. + */ + @Override + public OutputT defaultValue() { + return extractOutput(createAccumulator()); + } + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the output type of this {@code CombineFn} instance's + * most-derived class. + * + *

    In the normal case of a concrete {@code CombineFn} subclass with + * no generic type parameters of its own, this will be a complete + * non-generic type. + */ + public TypeDescriptor getOutputType() { + return new TypeDescriptor(getClass()) {}; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public KeyedCombineFn asKeyedFn() { + // The key, an object, is never even looked at. + return new KeyedCombineFn() { + @Override + public AccumT createAccumulator(K key) { + return CombineFn.this.createAccumulator(); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT input) { + return CombineFn.this.addInput(accumulator, input); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators) { + return CombineFn.this.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator) { + return CombineFn.this.extractOutput(accumulator); + } + + @Override + public AccumT compact(K key, AccumT accumulator) { + return CombineFn.this.compact(accumulator); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) + throws CannotProvideCoderException { + return CombineFn.this.getAccumulatorCoder(registry, inputCoder); + } + + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) + throws CannotProvideCoderException { + return CombineFn.this.getDefaultOutputCoder(registry, inputCoder); + } + + @Override + public CombineFn forKey(K key, Coder keyCoder) { + return CombineFn.this; + } + }; + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * An abstract subclass of {@link CombineFn} for implementing combiners that are more + * easily expressed as binary operations. + */ + public abstract static class BinaryCombineFn extends + CombineFn, V> { + + /** + * Applies the binary operation to the two operands, returning the result. + */ + public abstract V apply(V left, V right); + + /** + * Returns the value that should be used for the combine of the empty set. + */ + public V identity() { + return null; + } + + @Override + public Holder createAccumulator() { + return new Holder<>(); + } + + @Override + public Holder addInput(Holder accumulator, V input) { + if (accumulator.present) { + accumulator.set(apply(accumulator.value, input)); + } else { + accumulator.set(input); + } + return accumulator; + } + + @Override + public Holder mergeAccumulators(Iterable> accumulators) { + Iterator> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + Holder running = iter.next(); + while (iter.hasNext()) { + Holder accum = iter.next(); + if (accum.present) { + if (running.present) { + running.set(apply(running.value, accum.value)); + } else { + running.set(accum.value); + } + } + } + return running; + } + } + + @Override + public V extractOutput(Holder accumulator) { + if (accumulator.present) { + return accumulator.value; + } else { + return identity(); + } + } + + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return new HolderCoder<>(inputCoder); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return inputCoder; + } + + } + + /** + * Holds a single value value of type {@code V} which may or may not be present. + * + *

    Used only as a private accumulator class. + */ + public static class Holder { + private V value; + private boolean present; + private Holder() { } + private Holder(V value) { + set(value); + } + + private void set(V value) { + this.present = true; + this.value = value; + } + } + + /** + * A {@link Coder} for a {@link Holder}. + */ + private static class HolderCoder extends CustomCoder> { + + private Coder valueCoder; + + public HolderCoder(Coder valueCoder) { + this.valueCoder = valueCoder; + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(valueCoder); + } + + @Override + public void encode(Holder accumulator, OutputStream outStream, Context context) + throws CoderException, IOException { + if (accumulator.present) { + outStream.write(1); + valueCoder.encode(accumulator.value, outStream, context); + } else { + outStream.write(0); + } + } + + @Override + public Holder decode(InputStream inStream, Context context) + throws CoderException, IOException { + if (inStream.read() == 1) { + return new Holder<>(valueCoder.decode(inStream, context)); + } else { + return new Holder<>(); + } + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + valueCoder.verifyDeterministic(); + } + } + + /** + * An abstract subclass of {@link CombineFn} for implementing combiners that are more + * easily and efficiently expressed as binary operations on ints + * + *

    It uses {@code int[0]} as the mutable accumulator. + */ + public abstract static class BinaryCombineIntegerFn extends CombineFn { + + /** + * Applies the binary operation to the two operands, returning the result. + */ + public abstract int apply(int left, int right); + + /** + * Returns the identity element of this operation, i.e. an element {@code e} + * such that {@code apply(e, x) == apply(x, e) == x} for all values of {@code x}. + */ + public abstract int identity(); + + @Override + public int[] createAccumulator() { + return wrap(identity()); + } + + @Override + public int[] addInput(int[] accumulator, Integer input) { + accumulator[0] = apply(accumulator[0], input); + return accumulator; + } + + @Override + public int[] mergeAccumulators(Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + int[] running = iter.next(); + while (iter.hasNext()) { + running[0] = apply(running[0], iter.next()[0]); + } + return running; + } + } + + @Override + public Integer extractOutput(int[] accumulator) { + return accumulator[0]; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return DelegateCoder.of( + inputCoder, + new DelegateCoder.CodingFunction() { + @Override + public Integer apply(int[] accumulator) { + return accumulator[0]; + } + }, + new DelegateCoder.CodingFunction() { + @Override + public int[] apply(Integer value) { + return wrap(value); + } + }); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, + Coder inputCoder) { + return inputCoder; + } + + private int[] wrap(int value) { + return new int[] { value }; + } + + public Counter getCounter(String name) { + throw new UnsupportedOperationException("BinaryCombineDoubleFn does not support getCounter"); + } + } + + /** + * An abstract subclass of {@link CombineFn} for implementing combiners that are more + * easily and efficiently expressed as binary operations on longs. + * + *
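As a hedged sketch of extending the {@code BinaryCombineIntegerFn} defined above (a hypothetical {@code MaxIntegerFn}, not part of this change):

    public static class MaxIntegerFn extends BinaryCombineIntegerFn {
      @Override
      public int apply(int left, int right) {
        return Math.max(left, right);
      }

      @Override
      public int identity() {
        // apply(Integer.MIN_VALUE, x) == x for every int x, so MIN_VALUE is the identity.
        return Integer.MIN_VALUE;
      }
    }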

    It uses {@code long[0]} as the mutable accumulator. + */ + public abstract static class BinaryCombineLongFn extends CombineFn { + /** + * Applies the binary operation to the two operands, returning the result. + */ + public abstract long apply(long left, long right); + + /** + * Returns the identity element of this operation, i.e. an element {@code e} + * such that {@code apply(e, x) == apply(x, e) == x} for all values of {@code x}. + */ + public abstract long identity(); + + @Override + public long[] createAccumulator() { + return wrap(identity()); + } + + @Override + public long[] addInput(long[] accumulator, Long input) { + accumulator[0] = apply(accumulator[0], input); + return accumulator; + } + + @Override + public long[] mergeAccumulators(Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + long[] running = iter.next(); + while (iter.hasNext()) { + running[0] = apply(running[0], iter.next()[0]); + } + return running; + } + } + + @Override + public Long extractOutput(long[] accumulator) { + return accumulator[0]; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return DelegateCoder.of( + inputCoder, + new DelegateCoder.CodingFunction() { + @Override + public Long apply(long[] accumulator) { + return accumulator[0]; + } + }, + new DelegateCoder.CodingFunction() { + @Override + public long[] apply(Long value) { + return wrap(value); + } + }); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return inputCoder; + } + + private long[] wrap(long value) { + return new long[] { value }; + } + + public Counter getCounter(String name) { + throw new UnsupportedOperationException("BinaryCombineDoubleFn does not support getCounter"); + } + } + + /** + * An abstract subclass of {@link CombineFn} for implementing combiners that are more + * easily and efficiently expressed as binary operations on doubles. + * + *

    It uses {@code double[0]} as the mutable accumulator. + */ + public abstract static class BinaryCombineDoubleFn extends CombineFn { + + /** + * Applies the binary operation to the two operands, returning the result. + */ + public abstract double apply(double left, double right); + + /** + * Returns the identity element of this operation, i.e. an element {@code e} + * such that {@code apply(e, x) == apply(x, e) == x} for all values of {@code x}. + */ + public abstract double identity(); + + @Override + public double[] createAccumulator() { + return wrap(identity()); + } + + @Override + public double[] addInput(double[] accumulator, Double input) { + accumulator[0] = apply(accumulator[0], input); + return accumulator; + } + + @Override + public double[] mergeAccumulators(Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + double[] running = iter.next(); + while (iter.hasNext()) { + running[0] = apply(running[0], iter.next()[0]); + } + return running; + } + } + + @Override + public Double extractOutput(double[] accumulator) { + return accumulator[0]; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return DelegateCoder.of( + inputCoder, + new DelegateCoder.CodingFunction() { + @Override + public Double apply(double[] accumulator) { + return accumulator[0]; + } + }, + new DelegateCoder.CodingFunction() { + @Override + public double[] apply(Double value) { + return wrap(value); + } + }); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return inputCoder; + } + + private double[] wrap(double value) { + return new double[] { value }; + } + + public Counter getCounter(String name) { + throw new UnsupportedOperationException("BinaryCombineDoubleFn does not support getCounter"); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code CombineFn} that uses a subclass of + * {@link AccumulatingCombineFn.Accumulator} as its accumulator + * type. By defining the operations of the {@code Accumulator} + * helper class, the operations of the enclosing {@code CombineFn} + * are automatically provided. This can reduce the code required to + * implement a {@code CombineFn}. + * + *

    For example, the example from {@link CombineFn} above can be + * expressed using {@code AccumulatingCombineFn} more concisely as + * follows: + * + *

     {@code
    +   * public class AverageFn
    +   *     extends AccumulatingCombineFn {
    +   *   public Accum createAccumulator() {
    +   *     return new Accum();
    +   *   }
    +   *   public class Accum
    +   *       extends AccumulatingCombineFn
    +   *               .Accumulator {
    +   *     private int sum = 0;
    +   *     private int count = 0;
    +   *     public void addInput(Integer input) {
    +   *       sum += input;
    +   *       count++;
    +   *     }
    +   *     public void mergeAccumulator(Accum other) {
    +   *       sum += other.sum;
    +   *       count += other.count;
    +   *     }
    +   *     public Double extractOutput() {
    +   *       return ((double) sum) / count;
    +   *     }
    +   *   }
    +   * }
    +   * PCollection pc = ...;
    +   * PCollection average = pc.apply(Combine.globally(new AverageFn()));
    +   * } 
    + * + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + public abstract static class AccumulatingCombineFn< + InputT, + AccumT extends AccumulatingCombineFn.Accumulator, + OutputT> + extends CombineFn { + + /** + * The type of mutable accumulator values used by this + * {@code AccumulatingCombineFn}. + */ + public abstract static interface Accumulator { + /** + * Adds the given input value to this accumulator, modifying + * this accumulator. + */ + public abstract void addInput(InputT input); + + /** + * Adds the input values represented by the given accumulator + * into this accumulator. + */ + public abstract void mergeAccumulator(AccumT other); + + /** + * Returns the output value that is the result of combining all + * the input values represented by this accumulator. + */ + public abstract OutputT extractOutput(); + } + + @Override + public final AccumT addInput(AccumT accumulator, InputT input) { + accumulator.addInput(input); + return accumulator; + } + + @Override + public final AccumT mergeAccumulators(Iterable accumulators) { + AccumT accumulator = createAccumulator(); + for (AccumT partial : accumulators) { + accumulator.mergeAccumulator(partial); + } + return accumulator; + } + + @Override + public final OutputT extractOutput(AccumT accumulator) { + return accumulator.extractOutput(); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + + /** + * A {@code KeyedCombineFn} specifies how to combine + * a collection of input values of type {@code InputT}, associated with + * a key of type {@code K}, into a single output value of type + * {@code OutputT}. It does this via one or more intermediate mutable + * accumulator values of type {@code AccumT}. + * + *

    The overall process to combine a collection of input + * {@code InputT} values associated with an input {@code K} key into a + * single output {@code OutputT} value is as follows: + * + *

      + * + *
    1. The input {@code InputT} values are partitioned into one or more + * batches. + * + *
    2. For each batch, the {@link #createAccumulator} operation is + * invoked to create a fresh mutable accumulator value of type + * {@code AccumT}, initialized to represent the combination of zero + * values. + * + *
3. For each input {@code InputT} value in a batch, the + * {@link #addInput} operation is invoked to add the value to that + * batch's accumulator {@code AccumT} value. The accumulator may just + * record the new value (e.g., if {@code AccumT == List}), or may do + * work to represent the combination more compactly. + * + *
4. The {@link #mergeAccumulators} operation is invoked to + * combine a collection of accumulator {@code AccumT} values into a + * single combined output accumulator {@code AccumT} value, once the + * merging accumulators have had all the input values in their + * batches added to them. This operation is invoked repeatedly, + * until there is only one accumulator value left. + * + *
    5. The {@link #extractOutput} operation is invoked on the final + * accumulator {@code AccumT} value to get the output {@code OutputT} value. + * + *
    + * + *

    All of these operations are passed the {@code K} key that the + * values being combined are associated with. + * + *

    For example: + *

     {@code
    +   * public class ConcatFn
    +   *     extends KeyedCombineFn<String, Integer, ConcatFn.Accum, String> {
    +   *   public static class Accum {
    +   *     String s = "";
    +   *   }
    +   *   public Accum createAccumulator(String key) {
    +   *     return new Accum();
    +   *   }
    +   *   public Accum addInput(String key, Accum accum, Integer input) {
    +   *       accum.s += "+" + input;
    +   *       return accum;
    +   *   }
    +   *   public Accum mergeAccumulators(String key, Iterable<Accum> accums) {
    +   *     Accum merged = new Accum();
    +   *     for (Accum accum : accums) {
    +   *       merged.s += accum.s;
    +   *     }
    +   *     return merged;
    +   *   }
    +   *   public String extractOutput(String key, Accum accum) {
    +   *     return key + accum.s;
    +   *   }
    +   * }
    +   * PCollection<KV<String, Integer>> pc = ...;
    +   * PCollection<KV<String, String>> pc2 = pc.apply(
    +   *     Combine.perKey(new ConcatFn()));
    +   * } 
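    +   *
    +   * <p> Informal sketch: outside a pipeline, the behavior of such a
    +   * {@code KeyedCombineFn} can be exercised directly through its
    +   * {@code apply} convenience method (names below are illustrative):
    +   * <pre> {@code
    +   * ConcatFn fn = new ConcatFn();
    +   * String result = fn.apply("a", Arrays.asList(1, 2, 3));  // yields "a+1+2+3"
    +   * } </pre>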
    + * + *

    Keyed combining functions used by {@link Combine.PerKey}, + * {@link Combine.GroupedValues}, and {@code PTransforms} derived + * from them should be associative and commutative. + * Associativity is required because input values are first broken + * up into subgroups before being combined, and their intermediate + * results further combined, in an arbitrary tree structure. + * Commutativity is required because any order of the input values + * is ignored when breaking up input values into groups. + * + * @param type of keys + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + public abstract static class KeyedCombineFn + extends AbstractPerKeyCombineFn { + /** + * Returns a new, mutable accumulator value representing the accumulation of zero input values. + * + * @param key the key that all the accumulated values using the + * accumulator are associated with + */ + public abstract AccumT createAccumulator(K key); + + /** + * Adds the given input value to the given accumulator, returning the new accumulator value. + * + *

    For efficiency, the input accumulator may be modified and returned. + * + * @param key the key that all the accumulated values using the + * accumulator are associated with + */ + public abstract AccumT addInput(K key, AccumT accumulator, InputT value); + + /** + * Returns an accumulator representing the accumulation of all the + * input values accumulated in the merging accumulators. + * + *

    May modify any of the argument accumulators. May return a + * fresh accumulator, or may return one of the (modified) argument + * accumulators. + * + * @param key the key that all the accumulators are associated + * with + */ + public abstract AccumT mergeAccumulators(K key, Iterable accumulators); + + /** + * Returns the output value that is the result of combining all + * the input values represented by the given accumulator. + * + * @param key the key that all the accumulated values using the + * accumulator are associated with + */ + public abstract OutputT extractOutput(K key, AccumT accumulator); + + /** + * Returns an accumulator that represents the same logical value as the + * input accumulator, but may have a more compact representation. + * + *

    For most CombineFns this would be a no-op, but should be overridden + * by CombineFns that (for example) buffer up elements and combine + * them in batches. + * + *

    For efficiency, the input accumulator may be modified and returned. + * + *

    By default returns the original accumulator. + */ + public AccumT compact(K key, AccumT accumulator) { + return accumulator; + } + + @Override + public CombineFn forKey(final K key, final Coder keyCoder) { + return new CombineFn() { + + @Override + public AccumT createAccumulator() { + return KeyedCombineFn.this.createAccumulator(key); + } + + @Override + public AccumT addInput(AccumT accumulator, InputT input) { + return KeyedCombineFn.this.addInput(key, accumulator, input); + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return KeyedCombineFn.this.mergeAccumulators(key, accumulators); + } + + @Override + public OutputT extractOutput(AccumT accumulator) { + return KeyedCombineFn.this.extractOutput(key, accumulator); + } + + @Override + public AccumT compact(AccumT accumulator) { + return KeyedCombineFn.this.compact(key, accumulator); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return KeyedCombineFn.this.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) throws CannotProvideCoderException { + return KeyedCombineFn.this.getDefaultOutputCoder(registry, keyCoder, inputCoder); + } + }; + } + + /** + * Applies this {@code KeyedCombineFn} to a key and a collection + * of input values to produce a combined output value. + * + *

    Useful when testing the behavior of a {@code KeyedCombineFn} + * separately from a {@code Combine} transform. + */ + public OutputT apply(K key, Iterable inputs) { + AccumT accum = createAccumulator(key); + for (InputT input : inputs) { + accum = addInput(key, accum, input); + } + return extractOutput(key, accum); + } + } + + //////////////////////////////////////////////////////////////////////////// + + /** + * {@code Combine.Globally} takes a {@code PCollection} + * and returns a {@code PCollection} whose elements are the result of + * combining all the elements in each window of the input {@code PCollection}, + * using a specified {@link CombineFn CombineFn<InputT, AccumT, OutputT>}. + * It is common for {@code InputT == OutputT}, but not required. Common combining + * functions include sums, mins, maxes, and averages of numbers, + * conjunctions and disjunctions of booleans, statistical + * aggregations, etc. + * + *

    Example of use: + *

     {@code
    +   * PCollection<Integer> pc = ...;
    +   * PCollection<Integer> sum = pc.apply(
    +   *     Combine.globally(new Sum.SumIntegerFn()));
    +   * } 
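    +   *
    +   * <p> Sketch of a variation (the {@code windowedInts} name is hypothetical):
    +   * when the input uses non-global windowing, {@link #withoutDefaults} would
    +   * typically be appended:
    +   * <pre> {@code
    +   * PCollection<Integer> sumPerWindow = windowedInts.apply(
    +   *     Combine.globally(new Sum.SumIntegerFn()).withoutDefaults());
    +   * } </pre>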
    + * + *

    Combining can happen in parallel, with different subsets of the + * input {@code PCollection} being combined separately, and their + * intermediate results combined further, in an arbitrary tree + * reduction pattern, until a single result value is produced. + * + *

    If the input {@code PCollection} is windowed into {@link GlobalWindows}, + * a default value in the {@link GlobalWindow} will be output if the input + * {@code PCollection} is empty. To use this with inputs with other windowing, + * either {@link #withoutDefaults} or {@link #asSingletonView} must be called, + * as the default value cannot be automatically assigned to any single window. + * + *

    By default, the {@code Coder} of the output {@code PValue} + * is inferred from the concrete type of the + * {@code CombineFn}'s output type {@code OutputT}. + * + *

    See also {@link #perKey}/{@link PerKey Combine.PerKey} and + * {@link #groupedValues}/{@link GroupedValues Combine.GroupedValues}, which + * are useful for combining values associated with each key in + * a {@code PCollection} of {@code KV}s. + * + * @param type of input values + * @param type of output values + */ + public static class Globally + extends PTransform, PCollection> { + + private final GlobalCombineFn fn; + private final boolean insertDefault; + private final int fanout; + private final List> sideInputs; + + private Globally(GlobalCombineFn fn, + boolean insertDefault, int fanout) { + this.fn = fn; + this.insertDefault = insertDefault; + this.fanout = fanout; + this.sideInputs = ImmutableList.>of(); + } + + private Globally(String name, GlobalCombineFn fn, + boolean insertDefault, int fanout) { + super(name); + this.fn = fn; + this.insertDefault = insertDefault; + this.fanout = fanout; + this.sideInputs = ImmutableList.>of(); + } + + private Globally(String name, GlobalCombineFn fn, + boolean insertDefault, int fanout, List> sideInputs) { + super(name); + this.fn = fn; + this.insertDefault = insertDefault; + this.fanout = fanout; + this.sideInputs = sideInputs; + } + + /** + * Return a new {@code Globally} transform that's like this transform but with the + * specified name. Does not modify this transform. + */ + public Globally named(String name) { + return new Globally<>(name, fn, insertDefault, fanout); + } + + /** + * Returns a {@link PTransform} that produces a {@code PCollectionView} + * whose elements are the result of combining elements per-window in + * the input {@code PCollection}. If a value is requested from the view + * for a window that is not present, the result of applying the {@code CombineFn} + * to an empty input set will be returned. + */ + public GloballyAsSingletonView asSingletonView() { + return new GloballyAsSingletonView<>(fn, insertDefault, fanout); + } + + /** + * Returns a {@link PTransform} identical to this, but that does not attempt to + * provide a default value in the case of empty input. Required when the input + * is not globally windowed and the output is not being used as a side input. + */ + public Globally withoutDefaults() { + return new Globally<>(name, fn, false, fanout); + } + + /** + * Returns a {@link PTransform} identical to this, but that uses an intermediate node + * to combine parts of the data to reduce load on the final global combine step. + * + *

    The {@code fanout} parameter determines the number of intermediate keys + * that will be used. + */ + public Globally withFanout(int fanout) { + return new Globally<>(name, fn, insertDefault, fanout); + } + + /** + * Returns a {@link PTransform} identical to this, but with the specified side inputs to use + * in {@link CombineFnWithContext}. + */ + public Globally withSideInputs( + Iterable> sideInputs) { + Preconditions.checkState(fn instanceof RequiresContextInternal); + return new Globally(name, fn, insertDefault, fanout, + ImmutableList.>copyOf(sideInputs)); + } + + @Override + public PCollection apply(PCollection input) { + PCollection> withKeys = input + .apply(WithKeys.of((Void) null)) + .setCoder(KvCoder.of(VoidCoder.of(), input.getCoder())); + + Combine.PerKey combine = + Combine.fewKeys(fn.asKeyedFn()); + if (!sideInputs.isEmpty()) { + combine = combine.withSideInputs(sideInputs); + } + + PCollection> combined; + if (fanout >= 2) { + combined = withKeys.apply(combine.withHotKeyFanout(fanout)); + } else { + combined = withKeys.apply(combine); + } + + PCollection output = combined.apply(Values.create()); + + if (insertDefault) { + if (!output.getWindowingStrategy().getWindowFn().isCompatible(new GlobalWindows())) { + throw new IllegalStateException(fn.getIncompatibleGlobalWindowErrorMessage()); + } + return insertDefaultValueIfEmpty(output); + } else { + return output; + } + } + + private PCollection insertDefaultValueIfEmpty(PCollection maybeEmpty) { + final PCollectionView> maybeEmptyView = maybeEmpty.apply( + View.asIterable()); + + + final OutputT defaultValue = fn.defaultValue(); + PCollection defaultIfEmpty = maybeEmpty.getPipeline() + .apply("CreateVoid", Create.of((Void) null).withCoder(VoidCoder.of())) + .apply(ParDo.named("ProduceDefault").withSideInputs(maybeEmptyView).of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) { + Iterator combined = c.sideInput(maybeEmptyView).iterator(); + if (!combined.hasNext()) { + c.output(defaultValue); + } + } + })) + .setCoder(maybeEmpty.getCoder()) + .setWindowingStrategyInternal(maybeEmpty.getWindowingStrategy()); + + return PCollectionList.of(maybeEmpty).and(defaultIfEmpty) + .apply(Flatten.pCollections()); + } + } + + /** + * {@code Combine.GloballyAsSingletonView} takes a {@code PCollection} + * and returns a {@code PCollectionView} whose elements are the result of + * combining all the elements in each window of the input {@code PCollection}, + * using a specified {@link CombineFn CombineFn<InputT, AccumT, OutputT>}. + * It is common for {@code InputT == OutputT}, but not required. Common combining + * functions include sums, mins, maxes, and averages of numbers, + * conjunctions and disjunctions of booleans, statistical + * aggregations, etc. + * + *

    Example of use: + *

     {@code
    +   * PCollection<Integer> pc = ...;
    +   * PCollection<Integer> sum = pc.apply(
    +   *     Combine.globally(new Sum.SumIntegerFn()));
    +   * } 
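    +   *
    +   * <p> A sketch of obtaining the combined result as a singleton view, e.g.
    +   * for later use as a side input:
    +   * <pre> {@code
    +   * PCollectionView<Integer> sumView = pc.apply(
    +   *     Combine.globally(new Sum.SumIntegerFn()).asSingletonView());
    +   * } </pre>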
    + * + *

    Combining can happen in parallel, with different subsets of the + * input {@code PCollection} being combined separately, and their + * intermediate results combined further, in an arbitrary tree + * reduction pattern, until a single result value is produced. + * + *

    If a value is requested from the view for a window that is not present + * and {@code insertDefault} is true, the result of calling the {@code CombineFn} + * on empty input will be returned. If {@code insertDefault} is false, an + * exception will be thrown instead. + * + *

    By default, the {@code Coder} of the output {@code PValue} + * is inferred from the concrete type of the + * {@code CombineFn}'s output type {@code OutputT}. + * + *

    See also {@link #perKey}/{@link PerKey Combine.PerKey} and + * {@link #groupedValues}/{@link GroupedValues Combine.GroupedValues}, which + * are useful for combining values associated with each key in + * a {@code PCollection} of {@code KV}s. + * + * @param type of input values + * @param type of output values + */ + public static class GloballyAsSingletonView + extends PTransform, PCollectionView> { + + private final GlobalCombineFn fn; + private final boolean insertDefault; + private final int fanout; + + private GloballyAsSingletonView( + GlobalCombineFn fn, boolean insertDefault, int fanout) { + this.fn = fn; + this.insertDefault = insertDefault; + this.fanout = fanout; + } + + @Override + public PCollectionView apply(PCollection input) { + Globally combineGlobally = + Combine.globally(fn).withoutDefaults().withFanout(fanout); + if (insertDefault) { + return input + .apply(combineGlobally) + .apply(View.asSingleton().withDefaultValue(fn.defaultValue())); + } else { + return input + .apply(combineGlobally) + .apply(View.asSingleton()); + } + } + + public int getFanout() { + return fanout; + } + + public boolean getInsertDefault() { + return insertDefault; + } + + public GlobalCombineFn getCombineFn() { + return fn; + } + } + + /** + * Converts a {@link SerializableFunction} from {@code Iterable}s + * to {@code V}s into a simple {@link CombineFn} over {@code V}s. + * + *

    Used in the implementation of convenience methods like + * {@link #globally(SerializableFunction)}, + * {@link #perKey(SerializableFunction)}, and + * {@link #groupedValues(SerializableFunction)}. + */ + public static class IterableCombineFn extends CombineFn, V> { + /** + * Returns a {@code CombineFn} that uses the given + * {@code SerializableFunction} to combine values. + */ + public static IterableCombineFn of( + SerializableFunction, V> combiner) { + return of(combiner, DEFAULT_BUFFER_SIZE); + } + + /** + * Returns a {@code CombineFn} that uses the given + * {@code SerializableFunction} to combine values, + * attempting to buffer at least {@code bufferSize} + * values between invocations. + */ + public static IterableCombineFn of( + SerializableFunction, V> combiner, int bufferSize) { + return new IterableCombineFn<>(combiner, bufferSize); + } + + private static final int DEFAULT_BUFFER_SIZE = 20; + + /** The combiner function. */ + private final SerializableFunction, V> combiner; + + /** + * The number of values to accumulate before invoking the combiner + * function to combine them. + */ + private final int bufferSize; + + private IterableCombineFn( + SerializableFunction, V> combiner, int bufferSize) { + this.combiner = combiner; + this.bufferSize = bufferSize; + } + + @Override + public List createAccumulator() { + return new ArrayList<>(); + } + + @Override + public List addInput(List accumulator, V input) { + accumulator.add(input); + if (accumulator.size() > bufferSize) { + return mergeToSingleton(accumulator); + } else { + return accumulator; + } + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + return mergeToSingleton(Iterables.concat(accumulators)); + } + + @Override + public V extractOutput(List accumulator) { + return combiner.apply(accumulator); + } + + @Override + public List compact(List accumulator) { + return accumulator.size() > 1 ? mergeToSingleton(accumulator) : accumulator; + } + + private List mergeToSingleton(Iterable values) { + List singleton = new ArrayList<>(); + singleton.add(combiner.apply(values)); + return singleton; + } + } + + /** + * Converts a {@link SerializableFunction} from {@code Iterable}s + * to {@code V}s into a simple {@link CombineFn} over {@code V}s. + * + *

    @deprecated Use {@link IterableCombineFn} or the more space efficient + * {@link BinaryCombineFn} instead (which avoids buffering values). + */ + @Deprecated + public static class SimpleCombineFn extends IterableCombineFn { + + /** + * Returns a {@code CombineFn} that uses the given + * {@code SerializableFunction} to combine values. + */ + @Deprecated + public static SimpleCombineFn of( + SerializableFunction, V> combiner) { + return new SimpleCombineFn<>(combiner); + } + + protected SimpleCombineFn(SerializableFunction, V> combiner) { + super(combiner, IterableCombineFn.DEFAULT_BUFFER_SIZE); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code PerKey} takes a + * {@code PCollection>}, groups it by key, applies a + * combining function to the {@code InputT} values associated with each + * key to produce a combined {@code OutputT} value, and returns a + * {@code PCollection>} representing a map from each + * distinct key of the input {@code PCollection} to the corresponding + * combined value. {@code InputT} and {@code OutputT} are often the same. + * + *

    This is a concise shorthand for an application of + * {@link GroupByKey} followed by an application of + * {@link GroupedValues Combine.GroupedValues}. See those + * operations for more details on how keys are compared for equality + * and on the default {@code Coder} for the output. + * + *

    Example of use: + *

     {@code
    +   * PCollection<KV<String, Double>> salesRecords = ...;
    +   * PCollection<KV<String, Double>> totalSalesPerPerson =
    +   *     salesRecords.apply(Combine.perKey(
    +   *         new Sum.SumDoubleFn()));
    +   * } 
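    +   *
    +   * <p> Sketch of a variation (the fanout value 16 is arbitrary): if a few
    +   * keys are expected to carry most of the values, an intermediate
    +   * pre-combine can be requested with {@code withHotKeyFanout}:
    +   * <pre> {@code
    +   * PCollection<KV<String, Double>> totalSalesPerPerson =
    +   *     salesRecords.apply(Combine.<String, Double, Double>perKey(
    +   *         new Sum.SumDoubleFn()).withHotKeyFanout(16));
    +   * } </pre>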
    + * + *

    Each output element is in the window by which its corresponding input + * was grouped, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * as the input. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * @param the type of the values of the output {@code PCollection} + */ + public static class PerKey + extends PTransform>, PCollection>> { + + private final transient PerKeyCombineFn fn; + private final boolean fewKeys; + private final List> sideInputs; + + private PerKey( + PerKeyCombineFn fn, boolean fewKeys) { + this.fn = fn; + this.fewKeys = fewKeys; + this.sideInputs = ImmutableList.of(); + } + + private PerKey(String name, + PerKeyCombineFn fn, + boolean fewKeys, List> sideInputs) { + super(name); + this.fn = fn; + this.fewKeys = fewKeys; + this.sideInputs = sideInputs; + } + + private PerKey( + String name, PerKeyCombineFn fn, + boolean fewKeys) { + super(name); + this.fn = fn; + this.fewKeys = fewKeys; + this.sideInputs = ImmutableList.of(); + } + + /** + * Return a new {@code Globally} transform that's like this transform but with the + * specified name. Does not modify this transform. + */ + public PerKey named(String name) { + return new PerKey(name, fn, fewKeys); + } + + /** + * Returns a {@link PTransform} identical to this, but with the specified side inputs to use + * in {@link KeyedCombineFnWithContext}. + */ + public PerKey withSideInputs( + Iterable> sideInputs) { + Preconditions.checkState(fn instanceof RequiresContextInternal); + return new PerKey(name, fn, fewKeys, + ImmutableList.>copyOf(sideInputs)); + } + + /** + * If a single key has disproportionately many values, it may become a + * bottleneck, especially in streaming mode. This returns a new per-key + * combining transform that inserts an intermediate node to combine "hot" + * keys partially before performing the full combine. + * + * @param hotKeyFanout a function from keys to an integer N, where the key + * will be spread among N intermediate nodes for partial combining. + * If N is less than or equal to 1, this key will not be sent through an + * intermediate node. + */ + public PerKeyWithHotKeyFanout withHotKeyFanout( + SerializableFunction hotKeyFanout) { + return new PerKeyWithHotKeyFanout(name, fn, hotKeyFanout); + } + + /** + * Like {@link #withHotKeyFanout(SerializableFunction)}, but returning the given + * constant value for every key. + */ + public PerKeyWithHotKeyFanout withHotKeyFanout(final int hotKeyFanout) { + return new PerKeyWithHotKeyFanout(name, fn, + new SerializableFunction(){ + @Override + public Integer apply(K unused) { + return hotKeyFanout; + } + }); + } + + /** + * Returns the {@link PerKeyCombineFn} used by this Combine operation. + */ + public PerKeyCombineFn getFn() { + return fn; + } + + /** + * Returns the side inputs used by this Combine operation. 
+ */ + public List> getSideInputs() { + return sideInputs; + } + + @Override + public PCollection> apply(PCollection> input) { + if (fn instanceof RequiresContextInternal) { + return input + .apply(GroupByKey.create(fewKeys)) + .apply(ParDo.of(new DoFn>, KV>>() { + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element()); + } + })) + .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); + } else { + return input + .apply(GroupByKey.create(fewKeys)) + .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); + } + } + } + + /** + * Like {@link PerKey}, but sharding the combining of hot keys. + */ + public static class PerKeyWithHotKeyFanout + extends PTransform>, PCollection>> { + + private final transient PerKeyCombineFn fn; + private final SerializableFunction hotKeyFanout; + + private PerKeyWithHotKeyFanout(String name, + PerKeyCombineFn fn, + SerializableFunction hotKeyFanout) { + super(name); + this.fn = fn; + this.hotKeyFanout = hotKeyFanout; + } + + @Override + public PCollection> apply(PCollection> input) { + return applyHelper(input); + } + + private PCollection> applyHelper(PCollection> input) { + + // Name the accumulator type. + @SuppressWarnings("unchecked") + final PerKeyCombineFn typedFn = + (PerKeyCombineFn) this.fn; + + if (!(input.getCoder() instanceof KvCoder)) { + throw new IllegalStateException( + "Expected input coder to be KvCoder, but was " + input.getCoder()); + } + + @SuppressWarnings("unchecked") + final KvCoder inputCoder = (KvCoder) input.getCoder(); + final Coder accumCoder; + + try { + accumCoder = typedFn.getAccumulatorCoder( + input.getPipeline().getCoderRegistry(), + inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Unable to determine accumulator coder.", e); + } + Coder> inputOrAccumCoder = + new InputOrAccum.InputOrAccumCoder( + inputCoder.getValueCoder(), accumCoder); + + // A CombineFn's mergeAccumulator can be applied in a tree-like fashon. + // Here we shard the key using an integer nonce, combine on that partial + // set of values, then drop the nonce and do a final combine of the + // aggregates. We do this by splitting the original CombineFn into two, + // on that does addInput + merge and another that does merge + extract. 
+ PerKeyCombineFn, InputT, AccumT, AccumT> hotPreCombine; + PerKeyCombineFn, AccumT, OutputT> postCombine; + if (!(typedFn instanceof RequiresContextInternal)) { + final KeyedCombineFn keyedFn = + (KeyedCombineFn) typedFn; + hotPreCombine = + new KeyedCombineFn, InputT, AccumT, AccumT>() { + @Override + public AccumT createAccumulator(KV key) { + return keyedFn.createAccumulator(key.getKey()); + } + @Override + public AccumT addInput(KV key, AccumT accumulator, InputT value) { + return keyedFn.addInput(key.getKey(), accumulator, value); + } + @Override + public AccumT mergeAccumulators( + KV key, Iterable accumulators) { + return keyedFn.mergeAccumulators(key.getKey(), accumulators); + } + @Override + public AccumT compact(KV key, AccumT accumulator) { + return keyedFn.compact(key.getKey(), accumulator); + } + @Override + public AccumT extractOutput(KV key, AccumT accumulator) { + return accumulator; + } + @Override + @SuppressWarnings("unchecked") + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder> keyCoder, Coder inputCoder) + throws CannotProvideCoderException { + return accumCoder; + } + }; + postCombine = + new KeyedCombineFn, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key) { + return keyedFn.createAccumulator(key); + } + @Override + public AccumT addInput( + K key, AccumT accumulator, InputOrAccum value) { + if (value.accum == null) { + return keyedFn.addInput(key, accumulator, value.input); + } else { + return keyedFn.mergeAccumulators(key, ImmutableList.of(accumulator, value.accum)); + } + } + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators) { + return keyedFn.mergeAccumulators(key, accumulators); + } + @Override + public AccumT compact(K key, AccumT accumulator) { + return keyedFn.compact(key, accumulator); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator) { + return keyedFn.extractOutput(key, accumulator); + } + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, + Coder keyCoder, + Coder> accumulatorCoder) + throws CannotProvideCoderException { + return keyedFn.getDefaultOutputCoder( + registry, keyCoder, inputCoder.getValueCoder()); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder> inputCoder) + throws CannotProvideCoderException { + return accumCoder; + } + }; + } else { + final KeyedCombineFnWithContext keyedFnWithContext = + (KeyedCombineFnWithContext) typedFn; + hotPreCombine = + new KeyedCombineFnWithContext, InputT, AccumT, AccumT>() { + @Override + public AccumT createAccumulator(KV key, Context c) { + return keyedFnWithContext.createAccumulator(key.getKey(), c); + } + + @Override + public AccumT addInput( + KV key, AccumT accumulator, InputT value, Context c) { + return keyedFnWithContext.addInput(key.getKey(), accumulator, value, c); + } + + @Override + public AccumT mergeAccumulators( + KV key, Iterable accumulators, Context c) { + return keyedFnWithContext.mergeAccumulators(key.getKey(), accumulators, c); + } + + @Override + public AccumT compact(KV key, AccumT accumulator, Context c) { + return keyedFnWithContext.compact(key.getKey(), accumulator, c); + } + + @Override + public AccumT extractOutput(KV key, AccumT accumulator, Context c) { + return accumulator; + } + + @Override + @SuppressWarnings("unchecked") + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder> keyCoder, Coder inputCoder) + throws CannotProvideCoderException { + return accumCoder; + } + }; + postCombine = + new 
KeyedCombineFnWithContext, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key, Context c) { + return keyedFnWithContext.createAccumulator(key, c); + } + @Override + public AccumT addInput( + K key, AccumT accumulator, InputOrAccum value, Context c) { + if (value.accum == null) { + return keyedFnWithContext.addInput(key, accumulator, value.input, c); + } else { + return keyedFnWithContext.mergeAccumulators( + key, ImmutableList.of(accumulator, value.accum), c); + } + } + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, Context c) { + return keyedFnWithContext.mergeAccumulators(key, accumulators, c); + } + @Override + public AccumT compact(K key, AccumT accumulator, Context c) { + return keyedFnWithContext.compact(key, accumulator, c); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator, Context c) { + return keyedFnWithContext.extractOutput(key, accumulator, c); + } + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, + Coder keyCoder, + Coder> accumulatorCoder) + throws CannotProvideCoderException { + return keyedFnWithContext.getDefaultOutputCoder( + registry, keyCoder, inputCoder.getValueCoder()); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder> inputCoder) + throws CannotProvideCoderException { + return accumCoder; + } + }; + } + + // Use the provided hotKeyFanout fn to split into "hot" and "cold" keys, + // augmenting the hot keys with a nonce. + final TupleTag, InputT>> hot = new TupleTag<>(); + final TupleTag> cold = new TupleTag<>(); + PCollectionTuple split = input.apply( + ParDo.named("AddNonce").of( + new DoFn, KV>() { + transient int counter; + @Override + public void startBundle(Context c) { + counter = ThreadLocalRandom.current().nextInt( + Integer.MAX_VALUE); + } + + @Override + public void processElement(ProcessContext c) { + KV kv = c.element(); + int spread = Math.max(1, hotKeyFanout.apply(kv.getKey())); + if (spread <= 1) { + c.output(kv); + } else { + int nonce = counter++ % spread; + c.sideOutput(hot, KV.of(KV.of(kv.getKey(), nonce), kv.getValue())); + } + } + }) + .withOutputTags(cold, TupleTagList.of(hot))); + + // The first level of combine should never use accumulating mode. + WindowingStrategy preCombineStrategy = input.getWindowingStrategy(); + if (preCombineStrategy.getMode() + == WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES) { + preCombineStrategy = preCombineStrategy.withMode( + WindowingStrategy.AccumulationMode.DISCARDING_FIRED_PANES); + } + + // Combine the hot and cold keys separately. 
+ PCollection>> precombinedHot = split + .get(hot) + .setCoder(KvCoder.of(KvCoder.of(inputCoder.getKeyCoder(), VarIntCoder.of()), + inputCoder.getValueCoder())) + .setWindowingStrategyInternal(preCombineStrategy) + .apply("PreCombineHot", Combine.perKey(hotPreCombine)) + .apply(ParDo.named("StripNonce").of( + new DoFn, AccumT>, + KV>>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of( + c.element().getKey().getKey(), + InputOrAccum.accum(c.element().getValue()))); + } + })) + .setCoder(KvCoder.of(inputCoder.getKeyCoder(), inputOrAccumCoder)) + .apply(Window.>>remerge()) + .setWindowingStrategyInternal(input.getWindowingStrategy()); + PCollection>> preprocessedCold = split + .get(cold) + .setCoder(inputCoder) + .apply(ParDo.named("PrepareCold").of( + new DoFn, KV>>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element().getKey(), + InputOrAccum.input(c.element().getValue()))); + } + })) + .setCoder(KvCoder.of(inputCoder.getKeyCoder(), inputOrAccumCoder)); + + // Combine the union of the pre-processed hot and cold key results. + return PCollectionList.of(precombinedHot).and(preprocessedCold) + .apply(Flatten.>>pCollections()) + .apply("PostCombine", Combine.perKey(postCombine)); + } + + /** + * Used to store either an input or accumulator value, for flattening + * the hot and cold key paths. + */ + private static class InputOrAccum { + public final InputT input; + public final AccumT accum; + + private InputOrAccum(InputT input, AccumT aggr) { + this.input = input; + this.accum = aggr; + } + + public static InputOrAccum input(InputT input) { + return new InputOrAccum(input, null); + } + + public static InputOrAccum accum(AccumT aggr) { + return new InputOrAccum(null, aggr); + } + + private static class InputOrAccumCoder + extends StandardCoder> { + + private final Coder inputCoder; + private final Coder accumCoder; + + public InputOrAccumCoder(Coder inputCoder, Coder accumCoder) { + this.inputCoder = inputCoder; + this.accumCoder = accumCoder; + } + + @JsonCreator + @SuppressWarnings({"rawtypes", "unchecked"}) + public static InputOrAccumCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> elementCoders) { + return new InputOrAccumCoder(elementCoders.get(0), elementCoders.get(1)); + } + + @Override + public void encode( + InputOrAccum value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + if (value.input != null) { + outStream.write(0); + inputCoder.encode(value.input, outStream, context); + } else { + outStream.write(1); + accumCoder.encode(value.accum, outStream, context); + } + } + + @Override + public InputOrAccum decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + if (inStream.read() == 0) { + return InputOrAccum.input(inputCoder.decode(inStream, context)); + } else { + return InputOrAccum.accum(accumCoder.decode(inStream, context)); + } + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(inputCoder, accumCoder); + } + + @Override + public void verifyDeterministic() throws Coder.NonDeterministicException { + inputCoder.verifyDeterministic(); + accumCoder.verifyDeterministic(); + } + } + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * {@code GroupedValues} takes a + * {@code PCollection>>}, such as the result of + * {@link GroupByKey}, applies a specified + * {@link KeyedCombineFn KeyedCombineFn<K, InputT, AccumT, OutputT>} + * to each of the 
input {@code KV>} elements to + * produce a combined output {@code KV} element, and returns a + * {@code PCollection>} containing all the combined output + * elements. It is common for {@code InputT == OutputT}, but not required. + * Common combining functions include sums, mins, maxes, and averages + * of numbers, conjunctions and disjunctions of booleans, statistical + * aggregations, etc. + * + *

    Example of use: + *

     {@code
    +   * PCollection<KV<String, Integer>> pc = ...;
    +   * PCollection<KV<String, Iterable<Integer>>> groupedByKey = pc.apply(
    +   *     new GroupByKey<String, Integer>());
    +   * PCollection<KV<String, Integer>> sumByKey = groupedByKey.apply(
    +   *     Combine.groupedValues(
    +   *         new Sum.SumIntegerFn()));
    +   * } 
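    +   *
    +   * <p> For comparison (informal sketch), the same per-key sums can be
    +   * expressed more concisely with {@code Combine.perKey}, which combines the
    +   * {@code GroupByKey} step and the value combining:
    +   * <pre> {@code
    +   * PCollection<KV<String, Integer>> sumByKey = pc.apply(
    +   *     Combine.<String, Integer, Integer>perKey(new Sum.SumIntegerFn()));
    +   * } </pre>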
    + * + *

    See also {@link #perKey}/{@link PerKey Combine.PerKey}, which + * captures the common pattern of "combining by key" in a + * single easy-to-use {@code PTransform}. + * + *

    Combining for different keys can happen in parallel. Moreover, + * combining of the {@code Iterable<InputT>} values associated with a single + * key can happen in parallel, with different subsets of the values + * being combined separately, and their intermediate results combined + * further, in an arbitrary tree reduction pattern, until a single + * result value is produced for each key. + * + *

    By default, the {@code Coder} of the keys of the output + * {@code PCollection>} is that of the keys of the input + * {@code PCollection>}, and the {@code Coder} of the values + * of the output {@code PCollection>} is inferred from the + * concrete type of the {@code KeyedCombineFn}'s output + * type {@code OutputT}. + * + *

    Each output element has the same timestamp and is in the same window + * as its corresponding input element, and the output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + *

    See also {@link #globally}/{@link Globally Combine.Globally}, which + * combines all the values in a {@code PCollection} into a + * single value in a {@code PCollection}. + * + * @param type of input and output keys + * @param type of input values + * @param type of output values + */ + public static class GroupedValues + extends PTransform + >>, + PCollection>> { + + private final PerKeyCombineFn fn; + private final List> sideInputs; + + private GroupedValues(PerKeyCombineFn fn) { + this.fn = SerializableUtils.clone(fn); + this.sideInputs = ImmutableList.>of(); + } + + private GroupedValues( + PerKeyCombineFn fn, + List> sideInputs) { + this.fn = SerializableUtils.clone(fn); + this.sideInputs = sideInputs; + } + + public GroupedValues withSideInputs( + Iterable> sideInputs) { + return new GroupedValues<>(fn, ImmutableList.>copyOf(sideInputs)); + } + + /** + * Returns the KeyedCombineFn used by this Combine operation. + */ + public PerKeyCombineFn getFn() { + return fn; + } + + public List> getSideInputs() { + return sideInputs; + } + + @Override + public PCollection> apply( + PCollection>> input) { + + final PerKeyCombineFnRunner combineFnRunner = + PerKeyCombineFnRunners.create(fn); + PCollection> output = input.apply(ParDo.of( + new DoFn>, KV>() { + @Override + public void processElement(ProcessContext c) { + K key = c.element().getKey(); + + c.output(KV.of(key, combineFnRunner.apply(key, c.element().getValue(), c))); + } + }).withSideInputs(sideInputs)); + + try { + Coder> outputCoder = getDefaultOutputCoder(input); + output.setCoder(outputCoder); + } catch (CannotProvideCoderException exc) { + // let coder inference happen later, if it can + } + + return output; + } + + /** + * Returns the {@link CombineFn} bound to its coders. + * + *

    For internal use. + */ + public AppliedCombineFn getAppliedFn( + CoderRegistry registry, Coder>> inputCoder, + WindowingStrategy windowingStrategy) { + KvCoder kvCoder = getKvCoder(inputCoder); + return AppliedCombineFn.withInputCoder( + fn, registry, kvCoder, sideInputs, windowingStrategy); + } + + private KvCoder getKvCoder( + Coder>> inputCoder) { + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "Combine.GroupedValues requires its input to use KvCoder"); + } + @SuppressWarnings({"unchecked", "rawtypes"}) + KvCoder> kvCoder = (KvCoder) inputCoder; + Coder keyCoder = kvCoder.getKeyCoder(); + Coder> kvValueCoder = kvCoder.getValueCoder(); + if (!(kvValueCoder instanceof IterableCoder)) { + throw new IllegalStateException( + "Combine.GroupedValues requires its input values to use " + + "IterableCoder"); + } + @SuppressWarnings("unchecked") + IterableCoder inputValuesCoder = (IterableCoder) kvValueCoder; + Coder inputValueCoder = inputValuesCoder.getElemCoder(); + return KvCoder.of(keyCoder, inputValueCoder); + } + + @Override + public Coder> getDefaultOutputCoder( + PCollection>> input) + throws CannotProvideCoderException { + KvCoder kvCoder = getKvCoder(input.getCoder()); + @SuppressWarnings("unchecked") + Coder outputValueCoder = + ((PerKeyCombineFn) fn) + .getDefaultOutputCoder( + input.getPipeline().getCoderRegistry(), + kvCoder.getKeyCoder(), kvCoder.getValueCoder()); + return KvCoder.of(kvCoder.getKeyCoder(), outputValueCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFnBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFnBase.java new file mode 100644 index 000000000000..a0b06cf1fc29 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFnBase.java @@ -0,0 +1,283 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableMap; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; + +/** + * This class contains the shared interfaces and abstract classes for different types of combine + * functions. + * + *

    Users should not implement or extend them directly. + */ +public class CombineFnBase { + /** + * A {@code GloballyCombineFn} specifies how to combine a + * collection of input values of type {@code InputT} into a single + * output value of type {@code OutputT}. It does this via one or more + * intermediate mutable accumulator values of type {@code AccumT}. + * + *

    Do not implement this interface directly. + * Extend {@link CombineFn} and {@link CombineFnWithContext} instead. + * + * @param <InputT> type of input values + * @param <AccumT> type of mutable accumulator values + * @param <OutputT> type of output values + */ + public interface GlobalCombineFn<InputT, AccumT, OutputT> extends Serializable { + + /** + * Returns the {@code Coder} to use for accumulator {@code AccumT} + * values, or null if it is not able to be inferred. + * + *

    By default, uses the knowledge of the {@code Coder} being used + * for {@code InputT} values and the enclosing {@code Pipeline}'s + * {@code CoderRegistry} to try to infer the Coder for {@code AccumT} + * values. + * + *

    This is the Coder used to send data through a communication-intensive + * shuffle step, so a compact and efficient representation may have + * significant performance benefits. + */ + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException; + + /** + * Returns the {@code Coder} to use by default for output + * {@code OutputT} values, or null if it is not able to be inferred. + * + *

    By default, uses the knowledge of the {@code Coder} being + * used for input {@code InputT} values and the enclosing + * {@code Pipeline}'s {@code CoderRegistry} to try to infer the + * Coder for {@code OutputT} values. + */ + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException; + + /** + * Returns the error message for not supported default values in Combine.globally(). + */ + public String getIncompatibleGlobalWindowErrorMessage(); + + /** + * Returns the default value when there are no values added to the accumulator. + */ + public OutputT defaultValue(); + + /** + * Converts this {@code GloballyCombineFn} into an equivalent + * {@link PerKeyCombineFn} that ignores the keys passed to it and + * combines the values according to this {@code GloballyCombineFn}. + * + * @param the type of the (ignored) keys + */ + public PerKeyCombineFn asKeyedFn(); + } + + /** + * A {@code PerKeyCombineFn} specifies how to combine + * a collection of input values of type {@code InputT}, associated with + * a key of type {@code K}, into a single output value of type + * {@code OutputT}. It does this via one or more intermediate mutable + * accumulator values of type {@code AccumT}. + * + *

    Do not implement this interface directly. + * Extend {@link KeyedCombineFn} and {@link KeyedCombineFnWithContext} instead. + * + * @param <K> type of keys + * @param <InputT> type of input values + * @param <AccumT> type of mutable accumulator values + * @param <OutputT> type of output values + */ + public interface PerKeyCombineFn<K, InputT, AccumT, OutputT> extends Serializable { + /** + * Returns the {@code Coder} to use for accumulator {@code AccumT} + * values, or null if it is not able to be inferred. + * + *

    By default, uses the knowledge of the {@code Coder} being + * used for {@code K} keys and input {@code InputT} values and the + * enclosing {@code Pipeline}'s {@code CoderRegistry} to try to + * infer the Coder for {@code AccumT} values. + * + *

    This is the Coder used to send data through a communication-intensive + * shuffle step, so a compact and efficient representation may have + * significant performance benefits. + */ + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException; + + /** + * Returns the {@code Coder} to use by default for output + * {@code OutputT} values, or null if it is not able to be inferred. + * + *

    By default, uses the knowledge of the {@code Coder} being + * used for {@code K} keys and input {@code InputT} values and the + * enclosing {@code Pipeline}'s {@code CoderRegistry} to try to + * infer the Coder for {@code OutputT} values. + */ + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException; + + /** + * Returns the a regular {@link GlobalCombineFn} that operates on a specific key. + */ + public abstract GlobalCombineFn forKey( + final K key, final Coder keyCoder); + } + + /** + * An abstract {@link GlobalCombineFn} base class shared by + * {@link CombineFn} and {@link CombineFnWithContext}. + * + *

    Do not extend this class directly. + * Extends {@link CombineFn} and {@link CombineFnWithContext} instead. + * + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + abstract static class AbstractGlobalCombineFn + implements GlobalCombineFn, Serializable { + private static final String INCOMPATIBLE_GLOBAL_WINDOW_ERROR_MESSAGE = + "Default values are not supported in Combine.globally() if the output " + + "PCollection is not windowed by GlobalWindows. Instead, use " + + "Combine.globally().withoutDefaults() to output an empty PCollection if the input " + + "PCollection is empty, or Combine.globally().asSingletonView() to get the default " + + "output of the CombineFn if the input PCollection is empty."; + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return registry.getDefaultCoder(getClass(), AbstractGlobalCombineFn.class, + ImmutableMap.>of(getInputTVariable(), inputCoder), getAccumTVariable()); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return registry.getDefaultCoder(getClass(), AbstractGlobalCombineFn.class, + ImmutableMap.>of(getInputTVariable(), inputCoder, getAccumTVariable(), + this.getAccumulatorCoder(registry, inputCoder)), + getOutputTVariable()); + } + + @Override + public String getIncompatibleGlobalWindowErrorMessage() { + return INCOMPATIBLE_GLOBAL_WINDOW_ERROR_MESSAGE; + } + + /** + * Returns the {@link TypeVariable} of {@code InputT}. + */ + public TypeVariable getInputTVariable() { + return (TypeVariable) + new TypeDescriptor(AbstractGlobalCombineFn.class) {}.getType(); + } + + /** + * Returns the {@link TypeVariable} of {@code AccumT}. + */ + public TypeVariable getAccumTVariable() { + return (TypeVariable) + new TypeDescriptor(AbstractGlobalCombineFn.class) {}.getType(); + } + + /** + * Returns the {@link TypeVariable} of {@code OutputT}. + */ + public TypeVariable getOutputTVariable() { + return (TypeVariable) + new TypeDescriptor(AbstractGlobalCombineFn.class) {}.getType(); + } + } + + /** + * An abstract {@link PerKeyCombineFn} base class shared by + * {@link KeyedCombineFn} and {@link KeyedCombineFnWithContext}. + * + *

    Do not extends this class directly. + * Extends {@link KeyedCombineFn} and {@link KeyedCombineFnWithContext} instead. + * + * @param type of keys + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + abstract static class AbstractPerKeyCombineFn + implements PerKeyCombineFn { + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return registry.getDefaultCoder(getClass(), AbstractPerKeyCombineFn.class, + ImmutableMap.>of( + getKTypeVariable(), keyCoder, getInputTVariable(), inputCoder), + getAccumTVariable()); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return registry.getDefaultCoder(getClass(), AbstractPerKeyCombineFn.class, + ImmutableMap.>of(getKTypeVariable(), keyCoder, getInputTVariable(), + inputCoder, getAccumTVariable(), + this.getAccumulatorCoder(registry, keyCoder, inputCoder)), + getOutputTVariable()); + } + + /** + * Returns the {@link TypeVariable} of {@code K}. + */ + public TypeVariable getKTypeVariable() { + return (TypeVariable) new TypeDescriptor(AbstractPerKeyCombineFn.class) {}.getType(); + } + + /** + * Returns the {@link TypeVariable} of {@code InputT}. + */ + public TypeVariable getInputTVariable() { + return (TypeVariable) + new TypeDescriptor(AbstractPerKeyCombineFn.class) {}.getType(); + } + + /** + * Returns the {@link TypeVariable} of {@code AccumT}. + */ + public TypeVariable getAccumTVariable() { + return (TypeVariable) + new TypeDescriptor(AbstractPerKeyCombineFn.class) {}.getType(); + } + + /** + * Returns the {@link TypeVariable} of {@code OutputT}. + */ + public TypeVariable getOutputTVariable() { + return (TypeVariable) + new TypeDescriptor(AbstractPerKeyCombineFn.class) {}.getType(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineWithContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineWithContext.java new file mode 100644 index 000000000000..fdf56e33c04f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineWithContext.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +/** + * This class contains combine functions that have access to {@code PipelineOptions} and side inputs + * through {@code CombineWithContext.Context}. + * + *

    {@link CombineFnWithContext} and {@link KeyedCombineFnWithContext} are for users to extend. + */ +public class CombineWithContext { + + /** + * Information accessible to all methods in {@code CombineFnWithContext} + * and {@code KeyedCombineFnWithContext}. + */ + public abstract static class Context { + /** + * Returns the {@code PipelineOptions} specified with the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * invoking this {@code KeyedCombineFn}. + */ + public abstract PipelineOptions getPipelineOptions(); + + /** + * Returns the value of the side input for the window corresponding to the + * window of the main input element. + */ + public abstract T sideInput(PCollectionView view); + } + + /** + * An internal interface for signaling that a {@code GloballyCombineFn} + * or a {@code PerKeyCombineFn} needs to access {@code CombineWithContext.Context}. + * + *

    For internal use only. + */ + public interface RequiresContextInternal {} + + /** + * A combine function that has access to {@code PipelineOptions} and side inputs through + * {@code CombineWithContext.Context}. + * + * See the equivalent {@link CombineFn} for details about combine functions. + */ + public abstract static class CombineFnWithContext + extends CombineFnBase.AbstractGlobalCombineFn + implements RequiresContextInternal { + /** + * Returns a new, mutable accumulator value, representing the accumulation of zero input values. + * + *

    It is equivalent to {@link CombineFn#createAccumulator}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public abstract AccumT createAccumulator(Context c); + + /** + * Adds the given input value to the given accumulator, returning the + * new accumulator value. + * + *

    It is equivalent to {@link CombineFn#addInput}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public abstract AccumT addInput(AccumT accumulator, InputT input, Context c); + + /** + * Returns an accumulator representing the accumulation of all the + * input values accumulated in the merging accumulators. + * + *

    It is equivalent to {@link CombineFn#mergeAccumulators}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public abstract AccumT mergeAccumulators(Iterable accumulators, Context c); + + /** + * Returns the output value that is the result of combining all + * the input values represented by the given accumulator. + * + *

    It is equivalent to {@link CombineFn#extractOutput}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public abstract OutputT extractOutput(AccumT accumulator, Context c); + + /** + * Returns an accumulator that represents the same logical value as the + * input accumulator, but may have a more compact representation. + * + *

    It is equivalent to {@link CombineFn#compact}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public AccumT compact(AccumT accumulator, Context c) { + return accumulator; + } + + @Override + public OutputT defaultValue() { + throw new UnsupportedOperationException( + "Override this function to provide the default value."); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public KeyedCombineFnWithContext asKeyedFn() { + // The key, an object, is never even looked at. + return new KeyedCombineFnWithContext() { + @Override + public AccumT createAccumulator(K key, Context c) { + return CombineFnWithContext.this.createAccumulator(c); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT input, Context c) { + return CombineFnWithContext.this.addInput(accumulator, input, c); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, Context c) { + return CombineFnWithContext.this.mergeAccumulators(accumulators, c); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, Context c) { + return CombineFnWithContext.this.extractOutput(accumulator, c); + } + + @Override + public AccumT compact(K key, AccumT accumulator, Context c) { + return CombineFnWithContext.this.compact(accumulator, c); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return CombineFnWithContext.this.getAccumulatorCoder(registry, inputCoder); + } + + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return CombineFnWithContext.this.getDefaultOutputCoder(registry, inputCoder); + } + + @Override + public CombineFnWithContext forKey(K key, Coder keyCoder) { + return CombineFnWithContext.this; + } + }; + } + } + + /** + * A keyed combine function that has access to {@code PipelineOptions} and side inputs through + * {@code CombineWithContext.Context}. + * + * See the equivalent {@link KeyedCombineFn} for details about keyed combine functions. + */ + public abstract static class KeyedCombineFnWithContext + extends CombineFnBase.AbstractPerKeyCombineFn + implements RequiresContextInternal { + /** + * Returns a new, mutable accumulator value representing the accumulation of zero input values. + * + *

    It is equivalent to {@link KeyedCombineFn#createAccumulator}, + * but it has additional access to {@code CombineWithContext.Context}. + */ + public abstract AccumT createAccumulator(K key, Context c); + + /** + * Adds the given input value to the given accumulator, returning the new accumulator value. + * + *

    It is equivalent to {@link KeyedCombineFn#addInput}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public abstract AccumT addInput(K key, AccumT accumulator, InputT value, Context c); + + /** + * Returns an accumulator representing the accumulation of all the + * input values accumulated in the merging accumulators. + * + *

It is equivalent to {@link KeyedCombineFn#mergeAccumulators}, + * but it has additional access to {@code CombineWithContext.Context}. + */ + public abstract AccumT mergeAccumulators(K key, Iterable accumulators, Context c); + + /** + * Returns the output value that is the result of combining all + * the input values represented by the given accumulator. + * + *

    It is equivalent to {@link KeyedCombineFn#extractOutput}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public abstract OutputT extractOutput(K key, AccumT accumulator, Context c); + + /** + * Returns an accumulator that represents the same logical value as the + * input accumulator, but may have a more compact representation. + * + *

    It is equivalent to {@link KeyedCombineFn#compact}, but it has additional access to + * {@code CombineWithContext.Context}. + */ + public AccumT compact(K key, AccumT accumulator, Context c) { + return accumulator; + } + + /** + * Applies this {@code KeyedCombineFnWithContext} to a key and a collection + * of input values to produce a combined output value. + */ + public OutputT apply(K key, Iterable inputs, Context c) { + AccumT accum = createAccumulator(key, c); + for (InputT input : inputs) { + accum = addInput(key, accum, input, c); + } + return extractOutput(key, accum, c); + } + + @Override + public CombineFnWithContext forKey( + final K key, final Coder keyCoder) { + return new CombineFnWithContext() { + @Override + public AccumT createAccumulator(Context c) { + return KeyedCombineFnWithContext.this.createAccumulator(key, c); + } + + @Override + public AccumT addInput(AccumT accumulator, InputT input, Context c) { + return KeyedCombineFnWithContext.this.addInput(key, accumulator, input, c); + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators, Context c) { + return KeyedCombineFnWithContext.this.mergeAccumulators(key, accumulators, c); + } + + @Override + public OutputT extractOutput(AccumT accumulator, Context c) { + return KeyedCombineFnWithContext.this.extractOutput(key, accumulator, c); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return KeyedCombineFnWithContext.this.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) throws CannotProvideCoderException { + return KeyedCombineFnWithContext.this.getDefaultOutputCoder( + registry, keyCoder, inputCoder); + } + }; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java new file mode 100644 index 000000000000..ffa11d13a3c9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code PTransorm}s to count the elements in a {@link PCollection}. + * + *

    {@link Count#perElement()} can be used to count the number of occurrences of each + * distinct element in the PCollection, {@link Count#perKey()} can be used to count the + * number of values per key, and {@link Count#globally()} can be used to count the total + * number of elements in a PCollection. + */ +public class Count { + private Count() { + // do not instantiate + } + + /** + * Returns a {@link Combine.Globally} {@link PTransform} that counts the number of elements in + * its input {@link PCollection}. + */ + public static Combine.Globally globally() { + return Combine.globally(new CountFn()).named("Count.Globally"); + } + + /** + * Returns a {@link Combine.PerKey} {@link PTransform} that counts the number of elements + * associated with each key of its input {@link PCollection}. + */ + public static Combine.PerKey perKey() { + return Combine.perKey(new CountFn()).named("Count.PerKey"); + } + + /** + * Returns a {@link PerElement Count.PerElement} {@link PTransform} that counts the number of + * occurrences of each element in its input {@link PCollection}. + * + *

    See {@link PerElement Count.PerElement} for more details. + */ + public static PerElement perElement() { + return new PerElement<>(); + } + + /** + * {@code Count.PerElement} takes a {@code PCollection} and returns a + * {@code PCollection>} representing a map from each distinct element of the input + * {@code PCollection} to the number of times that element occurs in the input. Each key in the + * output {@code PCollection} is unique. + * + *

    This transform compares two values of type {@code T} by first encoding each element using + * the input {@code PCollection}'s {@code Coder}, then comparing the encoded bytes. Because of + * this, the input coder must be deterministic. + * (See {@link com.google.cloud.dataflow.sdk.coders.Coder#verifyDeterministic()} for more detail). + * Performing the comparison in this manner admits efficient parallel evaluation. + * + *

    By default, the {@code Coder} of the keys of the output {@code PCollection} is the same as + * the {@code Coder} of the elements of the input {@code PCollection}. + * + *

    Example of use: + *

     {@code
+   * PCollection<String> words = ...;
+   * PCollection<KV<String, Long>> wordCounts =
    +   *     words.apply(Count.perElement());
    +   * } 
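For comparison with the {@code perElement()} example just above, a hedged sketch of the {@code globally()} and {@code perKey()} entry points; the collection names and element types are illustrative:

    PCollection<String> words = ...;
    PCollection<Long> totalWords = words.apply(Count.<String>globally());

    PCollection<KV<String, Integer>> purchases = ...;  // hypothetical keyed input
    PCollection<KV<String, Long>> purchasesPerUser =
        purchases.apply(Count.<String, Integer>perKey());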
    + * + * @param the type of the elements of the input {@code PCollection}, and the type of the keys + * of the output {@code PCollection} + */ + public static class PerElement + extends PTransform, PCollection>> { + + public PerElement() { } + + @Override + public PCollection> apply(PCollection input) { + return + input + .apply(ParDo.named("Init").of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), (Void) null)); + } + })) + .apply(Count.perKey()); + } + } + + /** + * A {@link CombineFn} that counts elements. + */ + private static class CountFn extends CombineFn { + + @Override + public Long createAccumulator() { + return 0L; + } + + @Override + public Long addInput(Long accumulator, T input) { + return accumulator + 1; + } + + @Override + public Long mergeAccumulators(Iterable accumulators) { + long result = 0L; + for (Long accum : accumulators) { + result += accum; + } + return result; + } + + @Override + public Long extractOutput(Long accumulator) { + return accumulator; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Create.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Create.java new file mode 100644 index 000000000000..a74e5bff7f65 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Create.java @@ -0,0 +1,426 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TimestampedValue.TimestampedValueCoder; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Function; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * {@code Create} takes a collection of elements of type {@code T} + * known when the pipeline is constructed and returns a + * {@code PCollection} containing the elements. + * + *

    Example of use: + *

     {@code
    + * Pipeline p = ...;
    + *
+ * PCollection<Integer> pc = p.apply(Create.of(3, 4, 5).withCoder(BigEndianIntegerCoder.of()));
    + *
+ * Map<String, Integer> map = ...;
+ * PCollection<KV<String, Integer>> pt =
    + *     p.apply(Create.of(map)
    + *      .withCoder(KvCoder.of(StringUtf8Coder.of(),
    + *                            BigEndianIntegerCoder.of())));
    + * } 
    + * + *

    {@code Create} can automatically determine the {@code Coder} to use + * if all elements have the same run-time class, and a default coder is registered for that + * class. See {@link CoderRegistry} for details on how defaults are determined. + * + *

If a coder cannot be inferred, {@link Create.Values#withCoder} must be called + * explicitly to set the encoding of the resulting + * {@code PCollection}. + * + *
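One concrete case, sketched under the assumption that the input is empty: with no elements to inspect, inference falls back to {@code VoidCoder} (as noted later for {@code Create.Values} with no elements), so the desired coder should be set explicitly. The names below are illustrative:

    Pipeline p = ...;

    // No elements for inference to inspect, so supply the Coder explicitly.
    PCollection<String> empty = p.apply(
        Create.of(Collections.<String>emptyList())
              .withCoder(StringUtf8Coder.of()));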

    A good use for {@code Create} is when a {@code PCollection} + * needs to be created without dependencies on files or other external + * entities. This is especially useful during testing. + * + *

    Caveat: {@code Create} only supports small in-memory datasets, + * particularly when submitting jobs to the Google Cloud Dataflow + * service. + * + * @param the type of the elements of the resulting {@code PCollection} + */ +public class Create { + /** + * Returns a new {@code Create.Values} transform that produces a + * {@link PCollection} containing elements of the provided + * {@code Iterable}. + * + *

    The argument should not be modified after this is called. + * + *

The elements of the output {@link PCollection} will have a timestamp of negative infinity; + * see {@link Create#timestamped} for a way of creating a {@code PCollection} with timestamped + * elements. + * + *
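A hedged sketch of the timestamped variant referenced here; the values, times, and coder are illustrative:

    PCollection<String> events = p.apply(
        Create.timestamped(
            TimestampedValue.of("open", new Instant(0L)),
            TimestampedValue.of("click", new Instant(1000L)))
        .withCoder(StringUtf8Coder.of()));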

    By default, {@code Create.Values} can automatically determine the {@code Coder} to use + * if all elements have the same non-parameterized run-time class, and a default coder is + * registered for that class. See {@link CoderRegistry} for details on how defaults are + * determined. + * Otherwise, use {@link Create.Values#withCoder} to set the coder explicitly. + */ + public static Values of(Iterable elems) { + return new Values<>(elems, Optional.>absent()); + } + + /** + * Returns a new {@code Create.Values} transform that produces a + * {@link PCollection} containing the specified elements. + * + *

The elements will have a timestamp of negative infinity; see + * {@link Create#timestamped} for a way of creating a {@code PCollection} + * with timestamped elements. + * + *

    The arguments should not be modified after this is called. + * + *

    By default, {@code Create.Values} can automatically determine the {@code Coder} to use + * if all elements have the same non-parameterized run-time class, and a default coder is + * registered for that class. See {@link CoderRegistry} for details on how defaults are + * determined. + * Otherwise, use {@link Create.Values#withCoder} to set the coder explicitly. + */ + @SafeVarargs + public static Values of(T... elems) { + return of(Arrays.asList(elems)); + } + + /** + * Returns a new {@code Create.Values} transform that produces a + * {@link PCollection} of {@link KV}s corresponding to the keys and + * values of the specified {@code Map}. + * + *

The elements will have a timestamp of negative infinity; see + * {@link Create#timestamped} for a way of creating a {@code PCollection} + * with timestamped elements. + * + *

    By default, {@code Create.Values} can automatically determine the {@code Coder} to use + * if all elements have the same non-parameterized run-time class, and a default coder is + * registered for that class. See {@link CoderRegistry} for details on how defaults are + * determined. + * Otherwise, use {@link Create.Values#withCoder} to set the coder explicitly. + */ + public static Values> of(Map elems) { + List> kvs = new ArrayList<>(elems.size()); + for (Map.Entry entry : elems.entrySet()) { + kvs.add(KV.of(entry.getKey(), entry.getValue())); + } + return of(kvs); + } + + /** + * Returns a new {@link Create.TimestampedValues} transform that produces a + * {@link PCollection} containing the elements of the provided {@code Iterable} + * with the specified timestamps. + * + *

    The argument should not be modified after this is called. + * + *

    By default, {@code Create.TimestampedValues} can automatically determine the {@code Coder} + * to use if all elements have the same non-parameterized run-time class, and a default coder is + * registered for that class. See {@link CoderRegistry} for details on how defaults are + * determined. + * Otherwise, use {@link Create.TimestampedValues#withCoder} to set the coder explicitly. + */ + public static TimestampedValues timestamped(Iterable> elems) { + return new TimestampedValues<>(elems, Optional.>absent()); + } + + /** + * Returns a new {@link Create.TimestampedValues} transform that produces a {@link PCollection} + * containing the specified elements with the specified timestamps. + * + *

    The arguments should not be modified after this is called. + */ + @SafeVarargs + public static TimestampedValues timestamped( + @SuppressWarnings("unchecked") TimestampedValue... elems) { + return timestamped(Arrays.asList(elems)); + } + + /** + * Returns a new root transform that produces a {@link PCollection} containing + * the specified elements with the specified timestamps. + * + *

    The arguments should not be modified after this is called. + * + *

    By default, {@code Create.TimestampedValues} can automatically determine the {@code Coder} + * to use if all elements have the same non-parameterized run-time class, and a default coder + * is registered for that class. See {@link CoderRegistry} for details on how defaults are + * determined. + * Otherwise, use {@link Create.TimestampedValues#withCoder} to set the coder explicitly. + + * @throws IllegalArgumentException if there are a different number of values + * and timestamps + */ + public static TimestampedValues timestamped( + Iterable values, Iterable timestamps) { + List> elems = new ArrayList<>(); + Iterator valueIter = values.iterator(); + Iterator timestampIter = timestamps.iterator(); + while (valueIter.hasNext() && timestampIter.hasNext()) { + elems.add(TimestampedValue.of(valueIter.next(), new Instant(timestampIter.next()))); + } + Preconditions.checkArgument( + !valueIter.hasNext() && !timestampIter.hasNext(), + "Expect sizes of values and timestamps are same."); + return timestamped(elems); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code PTransform} that creates a {@code PCollection} from a set of in-memory objects. + */ + public static class Values extends PTransform> { + /** + * Returns a {@link Create.Values} PTransform like this one that uses the given + * {@code Coder} to decode each of the objects into a + * value of type {@code T}. + * + *

    By default, {@code Create.Values} can automatically determine the {@code Coder} to use + * if all elements have the same non-parameterized run-time class, and a default coder is + * registered for that class. See {@link CoderRegistry} for details on how defaults are + * determined. + * + *

    Note that for {@link Create.Values} with no elements, the {@link VoidCoder} is used. + */ + public Values withCoder(Coder coder) { + return new Values<>(elems, Optional.of(coder)); + } + + public Iterable getElements() { + return elems; + } + + @Override + public PCollection apply(PInput input) { + try { + Coder coder = getDefaultOutputCoder(input); + return PCollection + .createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + IsBounded.BOUNDED) + .setCoder(coder); + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException("Unable to infer a coder and no Coder was specified. " + + "Please set a coder by invoking Create.withCoder() explicitly.", e); + } + } + + @Override + public Coder getDefaultOutputCoder(PInput input) throws CannotProvideCoderException { + if (coder.isPresent()) { + return coder.get(); + } + // First try to deduce a coder using the types of the elements. + Class elementClazz = Void.class; + for (T elem : elems) { + if (elem == null) { + continue; + } + Class clazz = elem.getClass(); + if (elementClazz.equals(Void.class)) { + elementClazz = clazz; + } else if (!elementClazz.equals(clazz)) { + // Elements are not the same type, require a user-specified coder. + throw new CannotProvideCoderException( + "Cannot provide coder for Create: The elements are not all of the same class."); + } + } + + if (elementClazz.getTypeParameters().length == 0) { + try { + @SuppressWarnings("unchecked") // elementClazz is a wildcard type + Coder coder = (Coder) input.getPipeline().getCoderRegistry() + .getDefaultCoder(TypeDescriptor.of(elementClazz)); + return coder; + } catch (CannotProvideCoderException exc) { + // let the next stage try + } + } + + // If that fails, try to deduce a coder using the elements themselves + Optional> coder = Optional.absent(); + for (T elem : elems) { + Coder c = input.getPipeline().getCoderRegistry().getDefaultCoder(elem); + if (!coder.isPresent()) { + coder = Optional.of(c); + } else if (!Objects.equals(c, coder.get())) { + throw new CannotProvideCoderException( + "Cannot provide coder for elements of " + Create.class.getSimpleName() + ":" + + " For their common class, no coder could be provided." + + " Based on their values, they do not all default to the same Coder."); + } + } + + if (!coder.isPresent()) { + throw new CannotProvideCoderException("Unable to infer a coder. Please register " + + "a coder for "); + } + return coder.get(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** The elements of the resulting PCollection. */ + private final transient Iterable elems; + + /** The coder used to encode the values to and from a binary representation. */ + private final transient Optional> coder; + + /** + * Constructs a {@code Create.Values} transform that produces a + * {@link PCollection} containing the specified elements. + * + *

    The arguments should not be modified after this is called. + */ + private Values(Iterable elems, Optional> coder) { + this.elems = elems; + this.coder = coder; + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code PTransform} that creates a {@code PCollection} whose elements have + * associated timestamps. + */ + public static class TimestampedValues extends Values { + /** + * Returns a {@link Create.TimestampedValues} PTransform like this one that uses the given + * {@code Coder} to decode each of the objects into a + * value of type {@code T}. + * + *

    By default, {@code Create.TimestampedValues} can automatically determine the + * {@code Coder} to use if all elements have the same non-parameterized run-time class, + * and a default coder is registered for that class. See {@link CoderRegistry} for details + * on how defaults are determined. + * + *

    Note that for {@link Create.TimestampedValues with no elements}, the {@link VoidCoder} + * is used. + */ + @Override + public TimestampedValues withCoder(Coder coder) { + return new TimestampedValues<>(elems, Optional.>of(coder)); + } + + @Override + public PCollection apply(PInput input) { + try { + Coder coder = getDefaultOutputCoder(input); + PCollection> intermediate = Pipeline.applyTransform(input, + Create.of(elems).withCoder(TimestampedValueCoder.of(coder))); + + PCollection output = intermediate.apply(ParDo.of(new ConvertTimestamps())); + output.setCoder(coder); + return output; + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException("Unable to infer a coder and no Coder was specified. " + + "Please set a coder by invoking CreateTimestamped.withCoder() explicitly.", e); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** The timestamped elements of the resulting PCollection. */ + private final transient Iterable> elems; + + private TimestampedValues(Iterable> elems, + Optional> coder) { + super( + Iterables.transform(elems, new Function, T>() { + @Override + public T apply(TimestampedValue input) { + return input.getValue(); + } + }), coder); + this.elems = elems; + } + + private static class ConvertTimestamps extends DoFn, T> { + @Override + public void processElement(ProcessContext c) { + c.outputWithTimestamp(c.element().getValue(), c.element().getTimestamp()); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + static { + registerDefaultTransformEvaluator(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static void registerDefaultTransformEvaluator() { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Create.Values.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Create.Values transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + Create.Values transform, + DirectPipelineRunner.EvaluationContext context) { + // Convert the Iterable of elems into a List of elems. + List listElems; + if (transform.elems instanceof Collection) { + Collection collectionElems = (Collection) transform.elems; + listElems = new ArrayList<>(collectionElems.size()); + } else { + listElems = new ArrayList<>(); + } + for (T elem : transform.elems) { + listElems.add( + context.ensureElementEncodable(context.getOutput(transform), elem)); + } + context.setPCollection(context.getOutput(transform), listElems); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java new file mode 100644 index 000000000000..af06cc87961f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java @@ -0,0 +1,552 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.MoreObjects; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; + +/** + * The argument to {@link ParDo} providing the code to use to process + * elements of the input + * {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * + *

    See {@link ParDo} for more explanation, examples of use, and + * discussion of constraints on {@code DoFn}s, including their + * serializability, lack of access to global shared mutable state, + * requirements for failure tolerance, and benefits of optimization. + * + *
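As a minimal, hedged sketch of that contract, a {@code DoFn} that maps each line to its length; the element types and names are illustrative:

    PCollection<String> lines = ...;

    PCollection<Integer> lineLengths = lines.apply(
        ParDo.of(new DoFn<String, Integer>() {
          @Override
          public void processElement(ProcessContext c) {
            // Read the immutable input element and emit one output element.
            c.output(c.element().length());
          }
        }));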

    {@code DoFn}s can be tested in the context of a particular + * {@code Pipeline} by running that {@code Pipeline} on sample input + * and then checking its output. Unit testing of a {@code DoFn}, + * separately from any {@code ParDo} transform or {@code Pipeline}, + * can be done via the {@link DoFnTester} harness. + * + *
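A hedged sketch of such a unit test, assuming the {@code DoFnTester.of(...)} and {@code processBatch(...)} entry points of the harness; the tested {@code DoFn} is the illustrative line-length one sketched above:

    DoFn<String, Integer> fn = ...;  // e.g. the line-length DoFn above

    DoFnTester<String, Integer> tester = DoFnTester.of(fn);
    List<Integer> outputs = tester.processBatch("a", "bcd");
    // For the illustrative fn, outputs is expected to contain 1 and 3.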

    {@link DoFnWithContext} (currently experimental) offers an alternative + * mechanism for accessing {@link ProcessContext#window()} without the need + * to implement {@link RequiresWindowAccess}. + * + *

    See also {@link #processElement} for details on implementing the transformation + * from {@code InputT} to {@code OutputT}. + * + * @param the type of the (main) input elements + * @param the type of the (main) output elements + */ +public abstract class DoFn implements Serializable { + + /** + * Information accessible to all methods in this {@code DoFn}. + * Used primarily to output elements. + */ + public abstract class Context { + + /** + * Returns the {@code PipelineOptions} specified with the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * invoking this {@code DoFn}. The {@code PipelineOptions} will + * be the default running via {@link DoFnTester}. + */ + public abstract PipelineOptions getPipelineOptions(); + + /** + * Adds the given element to the main output {@code PCollection}. + * + *

    Once passed to {@code output} the element should be considered + * immutable and not be modified in any way. It may be cached or retained + * by the Dataflow runtime or later steps in the pipeline, or used in + * other unspecified ways. + * + *

    If invoked from {@link DoFn#processElement processElement}, the output + * element will have the same timestamp and be in the same windows + * as the input element passed to {@link DoFn#processElement processElement}. + * + *

    If invoked from {@link #startBundle startBundle} or {@link #finishBundle finishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element. The output element + * will have a timestamp of negative infinity. + */ + public abstract void output(OutputT output); + + /** + * Adds the given element to the main output {@code PCollection}, + * with the given timestamp. + * + *
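To make this bundle-level behaviour concrete, a hedged sketch of a {@code DoFn} that counts the elements of each bundle and emits the count from {@code finishBundle}; it assumes the default global windowing, so the {@code WindowFn} needs no per-element information:

    class CountPerBundleFn extends DoFn<String, Long> {
      private long count;

      @Override
      public void startBundle(Context c) {
        count = 0;
      }

      @Override
      public void processElement(ProcessContext c) {
        count++;
      }

      @Override
      public void finishBundle(Context c) {
        // Under global windowing this lands in the global window and, as
        // described above, carries a timestamp of negative infinity.
        c.output(count);
      }
    }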

    Once passed to {@code outputWithTimestamp} the element should not be + * modified in any way. + * + *

    If invoked from {@link DoFn#processElement processElement}, the timestamp + * must not be older than the input element's timestamp minus + * {@link DoFn#getAllowedTimestampSkew getAllowedTimestampSkew}. The output element will + * be in the same windows as the input element. + * + *

    If invoked from {@link #startBundle startBundle} or {@link #finishBundle finishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element except for the + * timestamp. + */ + public abstract void outputWithTimestamp(OutputT output, Instant timestamp); + + /** + * Adds the given element to the side output {@code PCollection} with the + * given tag. + * + *

    Once passed to {@code sideOutput} the element should not be modified + * in any way. + * + *

    The caller of {@code ParDo} uses {@link ParDo#withOutputTags withOutputTags} to + * specify the tags of side outputs that it consumes. Non-consumed side + * outputs, e.g., outputs for monitoring purposes only, don't necessarily + * need to be specified. + * + *
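A hedged sketch of that wiring, splitting words into a main output and one side output; the tags, types, and length threshold are illustrative, and {@code TupleTagList} and {@code PCollectionTuple} are the companion value classes in this SDK:

    final TupleTag<String> shortWords = new TupleTag<String>(){};
    final TupleTag<String> longWords  = new TupleTag<String>(){};

    PCollection<String> words = ...;
    PCollectionTuple results = words.apply(
        ParDo.withOutputTags(shortWords, TupleTagList.of(longWords))
             .of(new DoFn<String, String>() {
               @Override
               public void processElement(ProcessContext c) {
                 if (c.element().length() <= 4) {
                   c.output(c.element());                 // main output
                 } else {
                   c.sideOutput(longWords, c.element());  // side output
                 }
               }
             }));

    PCollection<String> longOnes = results.get(longWords);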

    The output element will have the same timestamp and be in the same + * windows as the input element passed to {@link DoFn#processElement processElement}. + * + *

    If invoked from {@link #startBundle startBundle} or {@link #finishBundle finishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element. The output element + * will have a timestamp of negative infinity. + * + * @see ParDo#withOutputTags + */ + public abstract void sideOutput(TupleTag tag, T output); + + /** + * Adds the given element to the specified side output {@code PCollection}, + * with the given timestamp. + * + *

    Once passed to {@code sideOutputWithTimestamp} the element should not be + * modified in any way. + * + *

    If invoked from {@link DoFn#processElement processElement}, the timestamp + * must not be older than the input element's timestamp minus + * {@link DoFn#getAllowedTimestampSkew getAllowedTimestampSkew}. The output element will + * be in the same windows as the input element. + * + *

    If invoked from {@link #startBundle startBundle} or {@link #finishBundle finishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element except for the + * timestamp. + * + * @see ParDo#withOutputTags + */ + public abstract void sideOutputWithTimestamp( + TupleTag tag, T output, Instant timestamp); + + /** + * Creates an {@link Aggregator} in the {@link DoFn} context with the + * specified name and aggregation logic specified by {@link CombineFn}. + * + *

    For internal use only. + * + * @param name the name of the aggregator + * @param combiner the {@link CombineFn} to use in the aggregator + * @return an aggregator for the provided name and {@link CombineFn} in this + * context + */ + @Experimental(Kind.AGGREGATOR) + protected abstract Aggregator + createAggregatorInternal(String name, CombineFn combiner); + + /** + * Sets up {@link Aggregator}s created by the {@link DoFn} so they are + * usable within this context. + * + *

    This method should be called by runners before {@link DoFn#startBundle} + * is executed. + */ + @Experimental(Kind.AGGREGATOR) + protected final void setupDelegateAggregators() { + for (DelegatingAggregator aggregator : aggregators.values()) { + setupDelegateAggregator(aggregator); + } + + aggregatorsAreFinal = true; + } + + private final void setupDelegateAggregator( + DelegatingAggregator aggregator) { + + Aggregator delegate = createAggregatorInternal( + aggregator.getName(), aggregator.getCombineFn()); + + aggregator.setDelegate(delegate); + } + } + + /** + * Information accessible when running {@link DoFn#processElement}. + */ + public abstract class ProcessContext extends Context { + + /** + * Returns the input element to be processed. + * + *

    The element should be considered immutable. The Dataflow runtime will not mutate the + * element, so it is safe to cache, etc. The element should not be mutated by any of the + * {@link DoFn} methods, because it may be cached elsewhere, retained by the Dataflow runtime, + * or used in other unspecified ways. + */ + public abstract InputT element(); + + /** + * Returns the value of the side input for the window corresponding to the + * window of the main input element. + * + *
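A hedged sketch of consuming a singleton side input through this method; {@code View.asSingleton()} is assumed from this SDK, and the collections and threshold are illustrative:

    PCollection<String> words = ...;
    PCollection<Integer> maxLength = ...;  // a single-element PCollection

    final PCollectionView<Integer> maxLengthView =
        maxLength.apply(View.<Integer>asSingleton());

    PCollection<String> filtered = words.apply(
        ParDo.withSideInputs(maxLengthView)
             .of(new DoFn<String, String>() {
               @Override
               public void processElement(ProcessContext c) {
                 // The side input value for the window of the current element.
                 if (c.element().length() <= c.sideInput(maxLengthView)) {
                   c.output(c.element());
                 }
               }
             }));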

    See + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn#getSideInputWindow} + * for how this corresponding window is determined. + * + * @throws IllegalArgumentException if this is not a side input + * @see ParDo#withSideInputs + */ + public abstract T sideInput(PCollectionView view); + + /** + * Returns the timestamp of the input element. + * + *

    See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + */ + public abstract Instant timestamp(); + + /** + * Returns the window into which the input element has been assigned. + * + *

    See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + * + * @throws UnsupportedOperationException if this {@link DoFn} does + * not implement {@link RequiresWindowAccess}. + */ + public abstract BoundedWindow window(); + + /** + * Returns information about the pane within this window into which the + * input element has been assigned. + * + *

    Generally all data is in a single, uninteresting pane unless custom + * triggering and/or late data has been explicitly requested. + * See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + */ + public abstract PaneInfo pane(); + + /** + * Returns the process context to use for implementing windowing. + */ + @Experimental + public abstract WindowingInternals windowingInternals(); + } + + /** + * Returns the allowed timestamp skew duration, which is the maximum + * duration that timestamps can be shifted backward in + * {@link DoFn.Context#outputWithTimestamp}. + * + *

The default value is {@code Duration.ZERO}, in which case + * timestamps can only be shifted forward into the future. For infinite + * skew, return {@code Duration.millis(Long.MAX_VALUE)}. + * + *
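A hedged sketch of a {@code DoFn} that shifts timestamps backwards and therefore widens the allowance; the one-minute shift is illustrative:

    class ShiftBackFn extends DoFn<String, String> {
      @Override
      public void processElement(ProcessContext c) {
        // Emit the element one minute earlier than its input timestamp.
        c.outputWithTimestamp(
            c.element(), c.timestamp().minus(Duration.standardMinutes(1)));
      }

      @Override
      public Duration getAllowedTimestampSkew() {
        // Permit the one-minute backwards shift above.
        return Duration.standardMinutes(1);
      }
    }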

    Note that producing an element whose timestamp is less than the + * current timestamp may result in late data, i.e. returning a non-zero + * value here does not impact watermark calculations used for firing + * windows. + * + * @deprecated does not interact well with the watermark. + */ + @Deprecated + public Duration getAllowedTimestampSkew() { + return Duration.ZERO; + } + + /** + * Interface for signaling that a {@link DoFn} needs to access the window the + * element is being processed in, via {@link DoFn.ProcessContext#window}. + */ + @Experimental + public interface RequiresWindowAccess {} + + public DoFn() { + this(new HashMap>()); + } + + DoFn(Map> aggregators) { + this.aggregators = aggregators; + } + + ///////////////////////////////////////////////////////////////////////////// + + private final Map> aggregators; + + /** + * Protects aggregators from being created after initialization. + */ + private boolean aggregatorsAreFinal; + + /** + * Prepares this {@code DoFn} instance for processing a batch of elements. + * + *

    By default, does nothing. + */ + public void startBundle(Context c) throws Exception { + } + + /** + * Processes one input element. + * + *

    The current element of the input {@code PCollection} is returned by + * {@link ProcessContext#element() c.element()}. It should be considered immutable. The Dataflow + * runtime will not mutate the element, so it is safe to cache, etc. The element should not be + * mutated by any of the {@link DoFn} methods, because it may be cached elsewhere, retained by the + * Dataflow runtime, or used in other unspecified ways. + * + *

    A value is added to the main output {@code PCollection} by {@link ProcessContext#output}. + * Once passed to {@code output} the element should be considered immutable and not be modified in + * any way. It may be cached elsewhere, retained by the Dataflow runtime, or used in other + * unspecified ways. + * + * @see ProcessContext + */ + public abstract void processElement(ProcessContext c) throws Exception; + + /** + * Finishes processing this batch of elements. + * + *

    By default, does nothing. + */ + public void finishBundle(Context c) throws Exception { + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the input type of this {@code DoFn} instance's most-derived + * class. + * + *

    See {@link #getOutputTypeDescriptor} for more discussion. + */ + protected TypeDescriptor getInputTypeDescriptor() { + return new TypeDescriptor(getClass()) {}; + } + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the output type of this {@code DoFn} instance's + * most-derived class. + * + *

    In the normal case of a concrete {@code DoFn} subclass with + * no generic type parameters of its own (including anonymous inner + * classes), this will be a complete non-generic type, which is good + * for choosing a default output {@code Coder} for the output + * {@code PCollection}. + */ + protected TypeDescriptor getOutputTypeDescriptor() { + return new TypeDescriptor(getClass()) {}; + } + + /** + * Returns an {@link Aggregator} with aggregation logic specified by the + * {@link CombineFn} argument. The name provided must be unique across + * {@link Aggregator}s created within the DoFn. Aggregators can only be created + * during pipeline construction. + * + * @param name the name of the aggregator + * @param combiner the {@link CombineFn} to use in the aggregator + * @return an aggregator for the provided name and combiner in the scope of + * this DoFn + * @throws NullPointerException if the name or combiner is null + * @throws IllegalArgumentException if the given name collides with another + * aggregator in this scope + * @throws IllegalStateException if called during pipeline processing. + */ + protected final Aggregator + createAggregator(String name, CombineFn combiner) { + checkNotNull(name, "name cannot be null"); + checkNotNull(combiner, "combiner cannot be null"); + checkArgument(!aggregators.containsKey(name), + "Cannot create aggregator with name %s." + + " An Aggregator with that name already exists within this scope.", + name); + + checkState(!aggregatorsAreFinal, "Cannot create an aggregator during DoFn processing." + + " Aggregators should be registered during pipeline construction."); + + DelegatingAggregator aggregator = + new DelegatingAggregator<>(name, combiner); + aggregators.put(name, aggregator); + return aggregator; + } + + /** + * Returns an {@link Aggregator} with the aggregation logic specified by the + * {@link SerializableFunction} argument. The name provided must be unique + * across {@link Aggregator}s created within the DoFn. Aggregators can only be + * created during pipeline construction. + * + * @param name the name of the aggregator + * @param combiner the {@link SerializableFunction} to use in the aggregator + * @return an aggregator for the provided name and combiner in the scope of + * this DoFn + * @throws NullPointerException if the name or combiner is null + * @throws IllegalArgumentException if the given name collides with another + * aggregator in this scope + * @throws IllegalStateException if called during pipeline processing. + */ + protected final Aggregator createAggregator(String name, + SerializableFunction, AggInputT> combiner) { + checkNotNull(combiner, "combiner cannot be null."); + return createAggregator(name, Combine.IterableCombineFn.of(combiner)); + } + + /** + * Returns the {@link Aggregator Aggregators} created by this {@code DoFn}. + */ + Collection> getAggregators() { + return Collections.>unmodifiableCollection(aggregators.values()); + } + + /** + * An {@link Aggregator} that delegates calls to addValue to another + * aggregator. 
+ * + * @param the type of input element + * @param the type of output element + */ + static class DelegatingAggregator implements + Aggregator, Serializable { + private final UUID id; + + private final String name; + + private final CombineFn combineFn; + + private Aggregator delegate; + + public DelegatingAggregator(String name, + CombineFn combiner) { + this.id = UUID.randomUUID(); + this.name = checkNotNull(name, "name cannot be null"); + // Safe contravariant cast + @SuppressWarnings("unchecked") + CombineFn specificCombiner = + (CombineFn) checkNotNull(combiner, "combineFn cannot be null"); + this.combineFn = specificCombiner; + } + + @Override + public void addValue(AggInputT value) { + if (delegate == null) { + throw new IllegalStateException( + "addValue cannot be called on Aggregator outside of the execution of a DoFn."); + } else { + delegate.addValue(value); + } + } + + @Override + public String getName() { + return name; + } + + @Override + public CombineFn getCombineFn() { + return combineFn; + } + + /** + * Sets the current delegate of the Aggregator. + * + * @param delegate the delegate to set in this aggregator + */ + public void setDelegate(Aggregator delegate) { + this.delegate = delegate; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("name", name) + .add("combineFn", combineFn) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(id, name, combineFn.getClass()); + } + + /** + * Indicates whether some other object is "equal to" this one. + * + *

    {@code DelegatingAggregator} instances are equal if they have the same name, their + * CombineFns are the same class, and they have identical IDs. + */ + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (o == null) { + return false; + } + if (o instanceof DelegatingAggregator) { + DelegatingAggregator that = (DelegatingAggregator) o; + return Objects.equals(this.id, that.id) + && Objects.equals(this.name, that.name) + && Objects.equals(this.combineFn.getClass(), that.combineFn.getClass()); + } + return false; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnReflector.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnReflector.java new file mode 100644 index 000000000000..1bb05fb3405b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnReflector.java @@ -0,0 +1,667 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.ExtraContextFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.FinishBundle; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.ProcessElement; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.StartBundle; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.util.common.ReflectHelpers; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; +import com.google.common.base.Throwables; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableMap; +import com.google.common.reflect.TypeParameter; +import com.google.common.reflect.TypeToken; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Utility implementing the necessary reflection for working with {@link DoFnWithContext}s. 
+ */ +public abstract class DoFnReflector { + + private interface ExtraContextInfo { + /** + * Create an instance of the given instance using the instance factory. + */ + Object createInstance( + DoFnWithContext.ExtraContextFactory factory); + + /** + * Create the type token for the given type, filling in the generics. + */ + TypeToken tokenFor(TypeToken in, TypeToken out); + } + + private static final Map, ExtraContextInfo> EXTRA_CONTEXTS = Collections.emptyMap(); + private static final Map, ExtraContextInfo> EXTRA_PROCESS_CONTEXTS = + ImmutableMap., ExtraContextInfo>builder() + .putAll(EXTRA_CONTEXTS) + .put(BoundedWindow.class, new ExtraContextInfo() { + @Override + public Object + createInstance(ExtraContextFactory factory) { + return factory.window(); + } + + @Override + public TypeToken + tokenFor(TypeToken in, TypeToken out) { + return TypeToken.of(BoundedWindow.class); + } + }) + .put(WindowingInternals.class, new ExtraContextInfo() { + @Override + public Object + createInstance(ExtraContextFactory factory) { + return factory.windowingInternals(); + } + + @Override + public TypeToken + tokenFor(TypeToken in, TypeToken out) { + return new TypeToken>() { + } + .where(new TypeParameter() {}, in) + .where(new TypeParameter() {}, out); + } + }) + .build(); + + /** + * @return true if the reflected {@link DoFnWithContext} uses a Single Window. + */ + public abstract boolean usesSingleWindow(); + + /** + * Invoke the reflected {@link ProcessElement} method on the given instance. + * + * @param fn an instance of the {@link DoFnWithContext} to invoke {@link ProcessElement} on. + * @param c the {@link com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.ProcessContext} + * to pass to {@link ProcessElement}. + */ + abstract void invokeProcessElement( + DoFnWithContext fn, + DoFnWithContext.ProcessContext c, + ExtraContextFactory extra); + + /** + * Invoke the reflected {@link StartBundle} method on the given instance. + * + * @param fn an instance of the {@link DoFnWithContext} to invoke {@link StartBundle} on. + * @param c the {@link com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.Context} + * to pass to {@link StartBundle}. + */ + void invokeStartBundle( + DoFnWithContext fn, + DoFnWithContext.Context c, + ExtraContextFactory extra) { + fn.prepareForProcessing(); + } + + /** + * Invoke the reflected {@link FinishBundle} method on the given instance. + * + * @param fn an instance of the {@link DoFnWithContext} to invoke {@link FinishBundle} on. + * @param c the {@link com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.Context} + * to pass to {@link FinishBundle}. + */ + abstract void invokeFinishBundle( + DoFnWithContext fn, + DoFnWithContext.Context c, + ExtraContextFactory extra); + + private static final Map, DoFnReflector> REFLECTOR_CACHE = + new LinkedHashMap, DoFnReflector>(); + + /** + * @return the {@link DoFnReflector} for the given {@link DoFnWithContext}. + */ + public static DoFnReflector of( + @SuppressWarnings("rawtypes") Class fn) { + DoFnReflector reflector = REFLECTOR_CACHE.get(fn); + if (reflector != null) { + return reflector; + } + + reflector = new GenericDoFnReflector(fn); + REFLECTOR_CACHE.put(fn, reflector); + return reflector; + } + + /** + * Create a {@link DoFn} that the {@link DoFnWithContext}. 
+ */ + public DoFn toDoFn(DoFnWithContext fn) { + if (usesSingleWindow()) { + return new WindowDoFnAdapter(this, fn); + } else { + return new SimpleDoFnAdapter(this, fn); + } + } + + private static String formatType(TypeToken t) { + return ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(t.getType()); + } + + private static String format(Method m) { + return ReflectHelpers.CLASS_AND_METHOD_FORMATTER.apply(m); + } + + private static Collection describeSupportedTypes( + Map, ExtraContextInfo> extraProcessContexts, + final TypeToken in, final TypeToken out) { + return FluentIterable + .from(extraProcessContexts.values()) + .transform(new Function() { + @Override + @Nullable + public String apply(@Nullable ExtraContextInfo input) { + if (input == null) { + return null; + } else { + return formatType(input.tokenFor(in, out)); + } + } + }) + .toSortedSet(String.CASE_INSENSITIVE_ORDER); + } + + @VisibleForTesting + static ExtraContextInfo[] verifyProcessMethodArguments(Method m) { + return verifyMethodArguments(m, + EXTRA_PROCESS_CONTEXTS, + new TypeToken.ProcessContext>() { + }, + new TypeParameter() {}, + new TypeParameter() {}); + } + + @VisibleForTesting + static ExtraContextInfo[] verifyBundleMethodArguments(Method m) { + return verifyMethodArguments(m, + EXTRA_CONTEXTS, + new TypeToken.Context>() { + }, + new TypeParameter() {}, + new TypeParameter() {}); + } + + /** + * Verify the method arguments for a given {@link DoFnWithContext} method. + * + *

The requirements for a method to be valid are (a sketch satisfying them follows this list): + *

      + *
    1. The method has at least one argument. + *
    2. The first argument is of type firstContextArg. + *
    3. The remaining arguments have raw types that appear in {@code contexts} + *
4. Any generics on the extra context arguments match what is expected. E.g., + * {@code WindowingInternals<InputT, OutputT>} either matches the + * {@code InputT} and {@code OutputT} parameters of the + * {@code DoFn.ProcessContext}, or it uses a wildcard, etc. + *
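A hedged sketch of a signature that satisfies these rules: void return, a {@code ProcessContext} first, then a supported extra argument ({@code BoundedWindow} here); the class name and logic are illustrative:

    class FormatWithWindowFn extends DoFnWithContext<Long, String> {
      @ProcessElement
      public void processElement(ProcessContext c, BoundedWindow window) {
        // ProcessContext comes first; BoundedWindow is one of the supported
        // extra context arguments.
        c.output(window + ": " + c.element());
      }
    }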
    + * + * @param m the method to verify + * @param contexts mapping from raw classes to the {@link ExtraContextInfo} used + * to create new instances. + * @param firstContextArg the expected type of the first context argument + * @param iParam TypeParameter representing the input type + * @param oParam TypeParameter representing the output type + */ + @VisibleForTesting static ExtraContextInfo[] verifyMethodArguments(Method m, + Map, ExtraContextInfo> contexts, + TypeToken firstContextArg, TypeParameter iParam, TypeParameter oParam) { + + if (!void.class.equals(m.getReturnType())) { + throw new IllegalStateException(String.format( + "%s must have a void return type", format(m))); + } + if (m.isVarArgs()) { + throw new IllegalStateException(String.format( + "%s must not have var args", format(m))); + } + + // The first parameter must be present, and must be the specified type + Type[] params = m.getGenericParameterTypes(); + TypeToken contextToken = null; + if (params.length > 0) { + contextToken = TypeToken.of(params[0]); + } + if (contextToken == null + || !contextToken.getRawType().equals(firstContextArg.getRawType())) { + throw new IllegalStateException(String.format( + "%s must take a %s as its first argument", + format(m), firstContextArg.getRawType().getSimpleName())); + } + ExtraContextInfo[] contextInfos = new ExtraContextInfo[params.length - 1]; + + // Fill in the generics in the allExtraContextArgs interface from the types in the + // Context or ProcessContext DoFn. + ParameterizedType pt = (ParameterizedType) contextToken.getType(); + // We actually want the owner, since ProcessContext and Context are owned by DoFnWithContext. + pt = (ParameterizedType) pt.getOwnerType(); + @SuppressWarnings("unchecked") + TypeToken iActual = (TypeToken) TypeToken.of(pt.getActualTypeArguments()[0]); + @SuppressWarnings("unchecked") + TypeToken oActual = (TypeToken) TypeToken.of(pt.getActualTypeArguments()[1]); + + // All of the remaining parameters must be a super-interface of allExtraContextArgs + // that is not listed in the EXCLUDED_INTERFACES set. + for (int i = 1; i < params.length; i++) { + TypeToken param = TypeToken.of(params[i]); + + ExtraContextInfo info = contexts.get(param.getRawType()); + if (info == null) { + throw new IllegalStateException(String.format( + "%s is not a valid context parameter for method %s. Should be one of %s", + formatType(param), format(m), + describeSupportedTypes(contexts, iActual, oActual))); + } + + // If we get here, the class matches, but maybe the generics don't: + TypeToken expected = info.tokenFor(iActual, oActual); + if (!expected.isSubtypeOf(param)) { + throw new IllegalStateException(String.format( + "Incompatible generics in context parameter %s for method %s. Should be %s", + formatType(param), format(m), formatType(info.tokenFor(iActual, oActual)))); + } + + // Register the (now validated) context info + contextInfos[i - 1] = info; + } + return contextInfos; + } + + /** + * Implementation of {@link DoFnReflector} for the arbitrary {@link DoFnWithContext}. 
+ */ + private static class GenericDoFnReflector extends DoFnReflector { + + private Method startBundle; + private Method processElement; + private Method finishBundle; + private ExtraContextInfo[] processElementArgs; + private ExtraContextInfo[] startBundleArgs; + private ExtraContextInfo[] finishBundleArgs; + + private GenericDoFnReflector(Class fn) { + // Locate the annotated methods + this.processElement = findAnnotatedMethod(ProcessElement.class, fn, true); + this.startBundle = findAnnotatedMethod(StartBundle.class, fn, false); + this.finishBundle = findAnnotatedMethod(FinishBundle.class, fn, false); + + // Verify that their method arguments satisfy our conditions. + processElementArgs = verifyProcessMethodArguments(processElement); + if (startBundle != null) { + startBundleArgs = verifyBundleMethodArguments(startBundle); + } + if (finishBundle != null) { + finishBundleArgs = verifyBundleMethodArguments(finishBundle); + } + } + + private static Collection declaredMethodsWithAnnotation( + Class anno, + Class startClass, Class stopClass) { + Collection matches = new ArrayList<>(); + + Class clazz = startClass; + LinkedHashSet> interfaces = new LinkedHashSet<>(); + + // First, find all declared methods on the startClass and parents (up to stopClass) + while (clazz != null && !clazz.equals(stopClass)) { + for (Method method : clazz.getDeclaredMethods()) { + if (method.isAnnotationPresent(anno)) { + matches.add(method); + } + } + + Collections.addAll(interfaces, clazz.getInterfaces()); + + clazz = clazz.getSuperclass(); + } + + // Now, iterate over all the discovered interfaces + for (Method method : ReflectHelpers.getClosureOfMethodsOnInterfaces(interfaces)) { + if (method.isAnnotationPresent(anno)) { + matches.add(method); + } + } + return matches; + } + + private static Method findAnnotatedMethod( + Class anno, Class fnClazz, boolean required) { + Collection matches = declaredMethodsWithAnnotation( + anno, fnClazz, DoFnWithContext.class); + + if (matches.size() == 0) { + if (required == true) { + throw new IllegalStateException(String.format( + "No method annotated with @%s found in %s", + anno.getSimpleName(), fnClazz.getName())); + } else { + return null; + } + } + + // If we have at least one match, then either it should be the only match + // or it should be an extension of the other matches (which came from parent + // classes). + Method first = matches.iterator().next(); + for (Method other : matches) { + if (!first.getName().equals(other.getName()) + || !Arrays.equals(first.getParameterTypes(), other.getParameterTypes())) { + throw new IllegalStateException(String.format( + "Found multiple methods annotated with @%s. [%s] and [%s]", + anno.getSimpleName(), format(first), format(other))); + } + } + + // We need to be able to call it. We require it is public. + if ((first.getModifiers() & Modifier.PUBLIC) == 0) { + throw new IllegalStateException(format(first) + " must be public"); + } + + // And make sure its not static. 
+ if ((first.getModifiers() & Modifier.STATIC) != 0) { + throw new IllegalStateException(format(first) + " must not be static"); + } + + first.setAccessible(true); + return first; + } + + @Override + public boolean usesSingleWindow() { + return usesContext(BoundedWindow.class); + } + + private boolean usesContext(Class context) { + for (Class clazz : processElement.getParameterTypes()) { + if (clazz.equals(context)) { + return true; + } + } + return false; + } + + @Override + void invokeProcessElement( + DoFnWithContext fn, + DoFnWithContext.ProcessContext c, + ExtraContextFactory extra) { + invoke(processElement, fn, c, extra, processElementArgs); + } + + @Override + void invokeStartBundle( + DoFnWithContext fn, + DoFnWithContext.Context c, + ExtraContextFactory extra) { + super.invokeStartBundle(fn, c, extra); + if (startBundle != null) { + invoke(startBundle, fn, c, extra, startBundleArgs); + } + } + + @Override + void invokeFinishBundle( + DoFnWithContext fn, + DoFnWithContext.Context c, + ExtraContextFactory extra) { + if (finishBundle != null) { + invoke(finishBundle, fn, c, extra, finishBundleArgs); + } + } + + private void invoke(Method m, + DoFnWithContext on, + DoFnWithContext.Context contextArg, + ExtraContextFactory extraArgFactory, + ExtraContextInfo[] extraArgs) { + + Class[] parameterTypes = m.getParameterTypes(); + Object[] args = new Object[parameterTypes.length]; + args[0] = contextArg; + for (int i = 1; i < args.length; i++) { + args[i] = extraArgs[i - 1].createInstance(extraArgFactory); + } + + try { + m.invoke(on, args); + } catch (InvocationTargetException e) { + // Exception in user code. + throw UserCodeException.wrap(e.getCause()); + } catch (IllegalAccessException | IllegalArgumentException e) { + // Exception in our code. + throw Throwables.propagate(e); + } + } + } + + private static class ContextAdapter + extends DoFnWithContext.Context + implements DoFnWithContext.ExtraContextFactory { + + private DoFn.Context context; + + private ContextAdapter( + DoFnWithContext fn, DoFn.Context context) { + fn.super(); + this.context = context; + } + + @Override + public PipelineOptions getPipelineOptions() { + return context.getPipelineOptions(); + } + + @Override + public void output(OutputT output) { + context.output(output); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + context.outputWithTimestamp(output, timestamp); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + context.sideOutput(tag, output); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + context.sideOutputWithTimestamp(tag, output, timestamp); + } + + @Override + public BoundedWindow window() { + // The DoFnWithContext doesn't allow us to ask for these outside ProcessElements, so this + // should be unreachable. + throw new UnsupportedOperationException("Can only get the window in ProcessElements"); + } + + @Override + public WindowingInternals windowingInternals() { + // The DoFnWithContext doesn't allow us to ask for these outside ProcessElements, so this + // should be unreachable. 
+ throw new UnsupportedOperationException( + "Can only get the windowingInternals in ProcessElements"); + } + } + + private static class ProcessContextAdapter + extends DoFnWithContext.ProcessContext + implements DoFnWithContext.ExtraContextFactory { + + private DoFn.ProcessContext context; + + private ProcessContextAdapter( + DoFnWithContext fn, + DoFn.ProcessContext context) { + fn.super(); + this.context = context; + } + + @Override + public PipelineOptions getPipelineOptions() { + return context.getPipelineOptions(); + } + + @Override + public T sideInput(PCollectionView view) { + return context.sideInput(view); + } + + @Override + public void output(OutputT output) { + context.output(output); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + context.outputWithTimestamp(output, timestamp); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + context.sideOutput(tag, output); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + context.sideOutputWithTimestamp(tag, output, timestamp); + } + + @Override + public InputT element() { + return context.element(); + } + + @Override + public Instant timestamp() { + return context.timestamp(); + } + + @Override + public PaneInfo pane() { + return context.pane(); + } + + @Override + public BoundedWindow window() { + return context.window(); + } + + @Override + public WindowingInternals windowingInternals() { + return context.windowingInternals(); + } + } + + public static Class getDoFnClass(DoFn fn) { + if (fn instanceof SimpleDoFnAdapter) { + return ((SimpleDoFnAdapter) fn).fn.getClass(); + } else { + return fn.getClass(); + } + } + + private static class SimpleDoFnAdapter extends DoFn { + + private transient DoFnReflector reflector; + private DoFnWithContext fn; + + private SimpleDoFnAdapter(DoFnReflector reflector, DoFnWithContext fn) { + super(fn.aggregators); + this.reflector = reflector; + this.fn = fn; + } + + @Override + public void startBundle(DoFn.Context c) throws Exception { + ContextAdapter adapter = new ContextAdapter<>(fn, c); + reflector.invokeStartBundle(fn, adapter, adapter); + } + + @Override + public void finishBundle(DoFn.Context c) throws Exception { + ContextAdapter adapter = new ContextAdapter<>(fn, c); + reflector.invokeFinishBundle(fn, adapter, adapter); + } + + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + ProcessContextAdapter adapter = new ProcessContextAdapter<>(fn, c); + reflector.invokeProcessElement(fn, adapter, adapter); + } + + @Override + protected TypeDescriptor getInputTypeDescriptor() { + return fn.getInputTypeDescriptor(); + } + + @Override + protected TypeDescriptor getOutputTypeDescriptor() { + return fn.getOutputTypeDescriptor(); + } + + private void readObject(java.io.ObjectInputStream in) + throws IOException, ClassNotFoundException { + in.defaultReadObject(); + reflector = DoFnReflector.of(fn.getClass()); + } + } + + private static class WindowDoFnAdapter + extends SimpleDoFnAdapter implements DoFn.RequiresWindowAccess { + + private WindowDoFnAdapter(DoFnReflector reflector, DoFnWithContext fn) { + super(reflector, fn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnTester.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnTester.java new file mode 100644 index 000000000000..544766433c5b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnTester.java @@ -0,0 +1,495 @@ +/* 
+ * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.DirectSideInputReader; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.DoFnRunnerBase; +import com.google.cloud.dataflow.sdk.util.DoFnRunners; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.base.Function; +import com.google.common.base.Objects; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A harness for unit-testing a {@link DoFn}. + * + *

+ * <p>For example:
+ *
+ * <pre> {@code
+ * DoFn<InputT, OutputT> fn = ...;
+ *
+ * DoFnTester<InputT, OutputT> fnTester = DoFnTester.of(fn);
+ *
+ * // Set arguments shared across all batches:
+ * fnTester.setSideInputs(...);      // If fn takes side inputs.
+ * fnTester.setSideOutputTags(...);  // If fn writes to side outputs.
+ *
+ * // Process a batch containing a single input element:
+ * InputT testInput = ...;
+ * List<OutputT> testOutputs = fnTester.processBatch(testInput);
+ * Assert.assertThat(testOutputs,
+ *                   JUnitMatchers.hasItems(...));
+ *
+ * // Process a bigger batch:
+ * Assert.assertThat(fnTester.processBatch(i1, i2, ...),
+ *                   JUnitMatchers.hasItems(...));
+ * } </pre>
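To make the harness above concrete, here is a minimal editorial sketch (not part of this diff) of a JUnit test for a trivial doubling DoFn; the function body and expected values are assumptions chosen for illustration, using only the DoFnTester calls documented in this file:

    @Test
    public void testDoublingFn() throws Exception {
      // A trivial DoFn that doubles each input element (illustrative only).
      DoFn<Integer, Integer> fn = new DoFn<Integer, Integer>() {
        @Override
        public void processElement(ProcessContext c) {
          c.output(c.element() * 2);
        }
      };

      DoFnTester<Integer, Integer> fnTester = DoFnTester.of(fn);

      // processBatch() runs startBundle, processElement per element, and
      // finishBundle, then returns the main-output elements.
      List<Integer> outputs = fnTester.processBatch(1, 2, 3);
      Assert.assertThat(outputs, JUnitMatchers.hasItems(2, 4, 6));
    }

Note that the tester serializes the DoFn before running it (see initializeState below), so the function under test must be serializable, just as in a real pipeline.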
    + * + * @param the type of the {@code DoFn}'s (main) input elements + * @param the type of the {@code DoFn}'s (main) output elements + */ +public class DoFnTester { + /** + * Returns a {@code DoFnTester} supporting unit-testing of the given + * {@link DoFn}. + */ + @SuppressWarnings("unchecked") + public static DoFnTester of(DoFn fn) { + return new DoFnTester(fn); + } + + /** + * Returns a {@code DoFnTester} supporting unit-testing of the given + * {@link DoFn}. + */ + @SuppressWarnings("unchecked") + public static DoFnTester + of(DoFnWithContext fn) { + return new DoFnTester(DoFnReflector.of(fn.getClass()).toDoFn(fn)); + } + + /** + * Registers the tuple of values of the side input {@link PCollectionView}s to + * pass to the {@link DoFn} under test. + * + *

    If needed, first creates a fresh instance of the {@link DoFn} + * under test. + * + *

    If this isn't called, {@code DoFnTester} assumes the + * {@link DoFn} takes no side inputs. + */ + public void setSideInputs(Map, Iterable>> sideInputs) { + this.sideInputs = sideInputs; + resetState(); + } + + /** + * Registers the values of a side input {@link PCollectionView} to + * pass to the {@link DoFn} under test. + * + *

    If needed, first creates a fresh instance of the {@code DoFn} + * under test. + * + *

    If this isn't called, {@code DoFnTester} assumes the + * {@code DoFn} takes no side inputs. + */ + public void setSideInput(PCollectionView sideInput, Iterable> value) { + sideInputs.put(sideInput, value); + } + + /** + * Registers the values for a side input {@link PCollectionView} to + * pass to the {@link DoFn} under test. All values are placed + * in the global window. + */ + public void setSideInputInGlobalWindow( + PCollectionView sideInput, + Iterable value) { + sideInputs.put( + sideInput, + Iterables.transform(value, new Function>() { + @Override + public WindowedValue apply(Object input) { + return WindowedValue.valueInGlobalWindow(input); + } + })); + } + + + /** + * Registers the list of {@code TupleTag}s that can be used by the + * {@code DoFn} under test to output to side output + * {@code PCollection}s. + * + *

    If needed, first creates a fresh instance of the DoFn under test. + * + *

    If this isn't called, {@code DoFnTester} assumes the + * {@code DoFn} doesn't emit to any side outputs. + */ + public void setSideOutputTags(TupleTagList sideOutputTags) { + this.sideOutputTags = sideOutputTags.getAll(); + resetState(); + } + + /** + * A convenience operation that first calls {@link #startBundle}, + * then calls {@link #processElement} on each of the input elements, then + * calls {@link #finishBundle}, then returns the result of + * {@link #takeOutputElements}. + */ + public List processBatch(Iterable inputElements) { + startBundle(); + for (InputT inputElement : inputElements) { + processElement(inputElement); + } + finishBundle(); + return takeOutputElements(); + } + + /** + * A convenience method for testing {@link DoFn DoFns} with bundles of elements. + * Logic proceeds as follows: + * + *

+ * <ol>
+ *   <li>Calls {@link #startBundle}.</li>
+ *   <li>Calls {@link #processElement} on each of the arguments.</li>
+ *   <li>Calls {@link #finishBundle}.</li>
+ *   <li>Returns the result of {@link #takeOutputElements}.</li>
+ * </ol>
    + */ + @SafeVarargs + public final List processBatch(InputT... inputElements) { + return processBatch(Arrays.asList(inputElements)); + } + + /** + * Calls {@link DoFn#startBundle} on the {@code DoFn} under test. + * + *

    If needed, first creates a fresh instance of the DoFn under test. + */ + public void startBundle() { + resetState(); + initializeState(); + fnRunner.startBundle(); + state = State.STARTED; + } + + /** + * Calls {@link DoFn#processElement} on the {@code DoFn} under test, in a + * context where {@link DoFn.ProcessContext#element} returns the + * given element. + * + *

    Will call {@link #startBundle} automatically, if it hasn't + * already been called. + * + * @throws IllegalStateException if the {@code DoFn} under test has already + * been finished + */ + public void processElement(InputT element) { + if (state == State.FINISHED) { + throw new IllegalStateException("finishBundle() has already been called"); + } + if (state == State.UNSTARTED) { + startBundle(); + } + fnRunner.processElement(WindowedValue.valueInGlobalWindow(element)); + } + + /** + * Calls {@link DoFn#finishBundle} of the {@code DoFn} under test. + * + *

    Will call {@link #startBundle} automatically, if it hasn't + * already been called. + * + * @throws IllegalStateException if the {@code DoFn} under test has already + * been finished + */ + public void finishBundle() { + if (state == State.FINISHED) { + throw new IllegalStateException("finishBundle() has already been called"); + } + if (state == State.UNSTARTED) { + startBundle(); + } + fnRunner.finishBundle(); + state = State.FINISHED; + } + + /** + * Returns the elements output so far to the main output. Does not + * clear them, so subsequent calls will continue to include these + * elements. + * + * @see #takeOutputElements + * @see #clearOutputElements + * + */ + public List peekOutputElements() { + // TODO: Should we return an unmodifiable list? + return Lists.transform( + peekOutputElementsWithTimestamp(), + new Function, OutputT>() { + @Override + @SuppressWarnings("unchecked") + public OutputT apply(OutputElementWithTimestamp input) { + return input.getValue(); + } + }); + } + + /** + * Returns the elements output so far to the main output with associated timestamps. Does not + * clear them, so subsequent calls will continue to include these. + * elements. + * + * @see #takeOutputElementsWithTimestamp + * @see #clearOutputElements + */ + @Experimental + public List> peekOutputElementsWithTimestamp() { + // TODO: Should we return an unmodifiable list? + return Lists.transform( + outputManager.getOutput(mainOutputTag), + new Function>() { + @Override + @SuppressWarnings("unchecked") + public OutputElementWithTimestamp apply(Object input) { + return new OutputElementWithTimestamp( + ((WindowedValue) input).getValue(), + ((WindowedValue) input).getTimestamp()); + } + }); + } + + /** + * Clears the record of the elements output so far to the main output. + * + * @see #peekOutputElements + */ + public void clearOutputElements() { + peekOutputElements().clear(); + } + + /** + * Returns the elements output so far to the main output. + * Clears the list so these elements don't appear in future calls. + * + * @see #peekOutputElements + */ + public List takeOutputElements() { + List resultElems = new ArrayList<>(peekOutputElements()); + clearOutputElements(); + return resultElems; + } + + /** + * Returns the elements output so far to the main output with associated timestamps. + * Clears the list so these elements don't appear in future calls. + * + * @see #peekOutputElementsWithTimestamp + * @see #takeOutputElements + * @see #clearOutputElements + */ + @Experimental + public List> takeOutputElementsWithTimestamp() { + List> resultElems = + new ArrayList<>(peekOutputElementsWithTimestamp()); + clearOutputElements(); + return resultElems; + } + + /** + * Returns the elements output so far to the side output with the + * given tag. Does not clear them, so subsequent calls will + * continue to include these elements. + * + * @see #takeSideOutputElements + * @see #clearSideOutputElements + */ + public List peekSideOutputElements(TupleTag tag) { + // TODO: Should we return an unmodifiable list? + return Lists.transform( + outputManager.getOutput(tag), + new Function, T>() { + @SuppressWarnings("unchecked") + @Override + public T apply(WindowedValue input) { + return input.getValue(); + }}); + } + + /** + * Clears the record of the elements output so far to the side + * output with the given tag. 
+ * + * @see #peekSideOutputElements + */ + public void clearSideOutputElements(TupleTag tag) { + peekSideOutputElements(tag).clear(); + } + + /** + * Returns the elements output so far to the side output with the given tag. + * Clears the list so these elements don't appear in future calls. + * + * @see #peekSideOutputElements + */ + public List takeSideOutputElements(TupleTag tag) { + List resultElems = new ArrayList<>(peekSideOutputElements(tag)); + clearSideOutputElements(tag); + return resultElems; + } + + /** + * Returns the value of the provided {@link Aggregator}. + */ + public AggregateT getAggregatorValue(Aggregator agg) { + @SuppressWarnings("unchecked") + Counter counter = + (Counter) + counterSet.getExistingCounter("user-" + STEP_NAME + "-" + agg.getName()); + return counter.getAggregate(); + } + + /** + * Holder for an OutputElement along with its associated timestamp. + */ + @Experimental + public static class OutputElementWithTimestamp { + private final OutputT value; + private final Instant timestamp; + + OutputElementWithTimestamp(OutputT value, Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } + + OutputT getValue() { + return value; + } + + Instant getTimestamp() { + return timestamp; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof OutputElementWithTimestamp)) { + return false; + } + OutputElementWithTimestamp other = (OutputElementWithTimestamp) obj; + return Objects.equal(other.value, value) && Objects.equal(other.timestamp, timestamp); + } + + @Override + public int hashCode() { + return Objects.hashCode(value, timestamp); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** The possible states of processing a DoFn. */ + enum State { + UNSTARTED, + STARTED, + FINISHED + } + + /** The name of the step of a DoFnTester. */ + static final String STEP_NAME = "stepName"; + /** The name of the enclosing DoFn PTransform for a DoFnTester. */ + static final String TRANSFORM_NAME = "transformName"; + + final PipelineOptions options = PipelineOptionsFactory.create(); + + /** The original DoFn under test. */ + final DoFn origFn; + + /** The side input values to provide to the DoFn under test. */ + private Map, Iterable>> sideInputs = + new HashMap<>(); + + /** The output tags used by the DoFn under test. */ + TupleTag mainOutputTag = new TupleTag<>(); + List> sideOutputTags = new ArrayList<>(); + + /** The original DoFn under test, if started. */ + DoFn fn; + + /** The ListOutputManager to examine the outputs. */ + DoFnRunnerBase.ListOutputManager outputManager; + + /** The DoFnRunner if processing is in progress. */ + DoFnRunner fnRunner; + + /** Counters for user-defined Aggregators if processing is in progress. */ + CounterSet counterSet; + + /** The state of processing of the DoFn under test. 
*/ + State state; + + DoFnTester(DoFn origFn) { + this.origFn = origFn; + resetState(); + } + + void resetState() { + fn = null; + outputManager = null; + fnRunner = null; + counterSet = null; + state = State.UNSTARTED; + } + + @SuppressWarnings("unchecked") + void initializeState() { + fn = (DoFn) + SerializableUtils.deserializeFromByteArray( + SerializableUtils.serializeToByteArray(origFn), + origFn.toString()); + counterSet = new CounterSet(); + PTuple runnerSideInputs = PTuple.empty(); + for (Map.Entry, Iterable>> entry + : sideInputs.entrySet()) { + runnerSideInputs = runnerSideInputs.and(entry.getKey().getTagInternal(), entry.getValue()); + } + outputManager = new DoFnRunnerBase.ListOutputManager(); + fnRunner = DoFnRunners.createDefault( + options, + fn, + DirectSideInputReader.of(runnerSideInputs), + outputManager, + mainOutputTag, + sideOutputTags, + DirectModeExecutionContext.create().getOrCreateStepContext(STEP_NAME, TRANSFORM_NAME, null), + counterSet.getAddCounterMutator(), + WindowingStrategy.globalDefault()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnWithContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnWithContext.java new file mode 100644 index 000000000000..10bf0eb16a90 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFnWithContext.java @@ -0,0 +1,416 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.DelegatingAggregator; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.Serializable; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.HashMap; +import java.util.Map; + +/** + * The argument to {@link ParDo} providing the code to use to process + * elements of the input + * {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * + *

    See {@link ParDo} for more explanation, examples of use, and + * discussion of constraints on {@code DoFnWithContext}s, including their + * serializability, lack of access to global shared mutable state, + * requirements for failure tolerance, and benefits of optimization. + * + *

    {@code DoFnWithContext}s can be tested in a particular + * {@code Pipeline} by running that {@code Pipeline} on sample input + * and then checking its output. Unit testing of a {@code DoFnWithContext}, + * separately from any {@code ParDo} transform or {@code Pipeline}, + * can be done via the {@link DoFnTester} harness. + * + *

    Implementations must define a method annotated with {@link ProcessElement} + * that satisfies the requirements described there. See the {@link ProcessElement} + * for details. + * + *

    This functionality is experimental and likely to change. + * + *

+ * <p>Example usage:
+ *
+ * <pre> {@code
+ * PCollection<String> lines = ... ;
+ * PCollection<String> words =
+ *     lines.apply(ParDo.of(new DoFnWithContext<String, String>() {
+ *         @ProcessElement
+ *         public void processElement(ProcessContext c, BoundedWindow window) {
+ *
+ *         }}));
+ * } </pre>
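As a slightly fuller editorial sketch (the CountWordsFn name and the word-count logic are assumptions, not part of the original source), a named DoFnWithContext subclass can combine the annotated lifecycle methods and the extra BoundedWindow parameter described by the annotations further down in this file:

    // Illustrative subclass; each annotated method follows the constraints
    // listed for @StartBundle, @ProcessElement and @FinishBundle below.
    class CountWordsFn extends DoFnWithContext<String, Integer> {

      @StartBundle
      public void startBundle(Context c) {
        // First (and only) argument is a Context, as required.
      }

      @ProcessElement
      public void processElement(ProcessContext c, BoundedWindow window) {
        // The extra BoundedWindow argument is one of the supported
        // context parameters for @ProcessElement.
        c.output(c.element().split(" ").length);
      }

      @FinishBundle
      public void finishBundle(Context c) {
        // First (and only) argument is a Context, as required.
      }
    }

    PCollection<String> lines = ...;
    PCollection<Integer> wordCounts = lines.apply(ParDo.of(new CountWordsFn()));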
    + * + * @param the type of the (main) input elements + * @param the type of the (main) output elements + */ +@Experimental +public abstract class DoFnWithContext implements Serializable { + + /** Information accessible to all methods in this {@code DoFnWithContext}. */ + public abstract class Context { + + /** + * Returns the {@code PipelineOptions} specified with the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} + * invoking this {@code DoFnWithContext}. The {@code PipelineOptions} will + * be the default running via {@link DoFnTester}. + */ + public abstract PipelineOptions getPipelineOptions(); + + /** + * Adds the given element to the main output {@code PCollection}. + * + *

    Once passed to {@code output} the element should not be modified in + * any way. + * + *

    If invoked from {@link ProcessElement}, the output + * element will have the same timestamp and be in the same windows + * as the input element passed to the method annotated with + * {@code @ProcessElement}. + * + *

    If invoked from {@link StartBundle} or {@link FinishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element. The output element + * will have a timestamp of negative infinity. + */ + public abstract void output(OutputT output); + + /** + * Adds the given element to the main output {@code PCollection}, + * with the given timestamp. + * + *

    Once passed to {@code outputWithTimestamp} the element should not be + * modified in any way. + * + *

    If invoked from {@link ProcessElement}), the timestamp + * must not be older than the input element's timestamp minus + * {@link DoFn#getAllowedTimestampSkew}. The output element will + * be in the same windows as the input element. + * + *

    If invoked from {@link StartBundle} or {@link FinishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element except for the + * timestamp. + */ + public abstract void outputWithTimestamp(OutputT output, Instant timestamp); + + /** + * Adds the given element to the side output {@code PCollection} with the + * given tag. + * + *

    Once passed to {@code sideOutput} the element should not be modified + * in any way. + * + *

    The caller of {@code ParDo} uses {@link ParDo#withOutputTags} to + * specify the tags of side outputs that it consumes. Non-consumed side + * outputs, e.g., outputs for monitoring purposes only, don't necessarily + * need to be specified. + * + *

    The output element will have the same timestamp and be in the same + * windows as the input element passed to {@link ProcessElement}). + * + *

    If invoked from {@link StartBundle} or {@link FinishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element. The output element + * will have a timestamp of negative infinity. + * + * @see ParDo#withOutputTags + */ + public abstract void sideOutput(TupleTag tag, T output); + + /** + * Adds the given element to the specified side output {@code PCollection}, + * with the given timestamp. + * + *

    Once passed to {@code sideOutputWithTimestamp} the element should not be + * modified in any way. + * + *

    If invoked from {@link ProcessElement}), the timestamp + * must not be older than the input element's timestamp minus + * {@link DoFn#getAllowedTimestampSkew}. The output element will + * be in the same windows as the input element. + * + *

    If invoked from {@link StartBundle} or {@link FinishBundle}, + * this will attempt to use the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * of the input {@code PCollection} to determine what windows the element + * should be in, throwing an exception if the {@code WindowFn} attempts + * to access any information about the input element except for the + * timestamp. + * + * @see ParDo#withOutputTags + */ + public abstract void sideOutputWithTimestamp( + TupleTag tag, T output, Instant timestamp); + } + + /** + * Information accessible when running {@link DoFn#processElement}. + */ + public abstract class ProcessContext extends Context { + + /** + * Returns the input element to be processed. + * + *

    The element will not be changed -- it is safe to cache, etc. + * without copying. + */ + public abstract InputT element(); + + + /** + * Returns the value of the side input. + * + * @throws IllegalArgumentException if this is not a side input + * @see ParDo#withSideInputs + */ + public abstract T sideInput(PCollectionView view); + + /** + * Returns the timestamp of the input element. + * + *

    See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + */ + public abstract Instant timestamp(); + + /** + * Returns information about the pane within this window into which the + * input element has been assigned. + * + *

    Generally all data is in a single, uninteresting pane unless custom + * triggering and/or late data has been explicitly requested. + * See {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} + * for more information. + */ + public abstract PaneInfo pane(); + } + + /** + * Returns the allowed timestamp skew duration, which is the maximum + * duration that timestamps can be shifted backward in + * {@link DoFnWithContext.Context#outputWithTimestamp}. + * + *

    The default value is {@code Duration.ZERO}, in which case + * timestamps can only be shifted forward to future. For infinite + * skew, return {@code Duration.millis(Long.MAX_VALUE)}. + */ + public Duration getAllowedTimestampSkew() { + return Duration.ZERO; + } + + ///////////////////////////////////////////////////////////////////////////// + + Map> aggregators = new HashMap<>(); + + /** + * Protects aggregators from being created after initialization. + */ + private boolean aggregatorsAreFinal; + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the input type of this {@code DoFnWithContext} instance's most-derived + * class. + * + *

    See {@link #getOutputTypeDescriptor} for more discussion. + */ + protected TypeDescriptor getInputTypeDescriptor() { + return new TypeDescriptor(getClass()) {}; + } + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the output type of this {@code DoFnWithContext} instance's + * most-derived class. + * + *

    In the normal case of a concrete {@code DoFnWithContext} subclass with + * no generic type parameters of its own (including anonymous inner + * classes), this will be a complete non-generic type, which is good + * for choosing a default output {@code Coder} for the output + * {@code PCollection}. + */ + protected TypeDescriptor getOutputTypeDescriptor() { + return new TypeDescriptor(getClass()) {}; + } + + /** + * Interface for runner implementors to provide implementations of extra context information. + * + *

    The methods on this interface are called by {@link DoFnReflector} before invoking an + * annotated {@link StartBundle}, {@link ProcessElement} or {@link FinishBundle} method that + * has indicated it needs the given extra context. + * + *

    In the case of {@link ProcessElement} it is called once per invocation of + * {@link ProcessElement}. + */ + public interface ExtraContextFactory { + /** + * Construct the {@link BoundedWindow} to use within a {@link DoFnWithContext} that + * needs it. This is called if the {@link ProcessElement} method has a parameter of type + * {@link BoundedWindow}. + * + * @return {@link BoundedWindow} of the element currently being processed. + */ + BoundedWindow window(); + + /** + * Construct the {@link WindowingInternals} to use within a {@link DoFnWithContext} that + * needs it. This is called if the {@link ProcessElement} method has a parameter of type + * {@link WindowingInternals}. + */ + WindowingInternals windowingInternals(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Annotation for the method to use to prepare an instance for processing a batch of elements. + * The method annotated with this must satisfy the following constraints: + *

+ * <ul>
+ *   <li>It must have at least one argument.
+ *   <li>Its first (and only) argument must be a {@link DoFnWithContext.Context}.
+ * </ul>
    + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.METHOD) + public @interface StartBundle {} + + /** + * Annotation for the method to use for processing elements. A subclass of + * {@link DoFnWithContext} must have a method with this annotation satisfying + * the following constraints in order for it to be executable: + *
+ * <ul>
+ *   <li>It must have at least one argument.
+ *   <li>Its first argument must be a {@link DoFnWithContext.ProcessContext}.
+ *   <li>Its remaining arguments must be {@link BoundedWindow}, or
+ * {@link WindowingInternals WindowingInternals<InputT, OutputT>}.
+ * </ul>
    + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.METHOD) + public @interface ProcessElement {} + + /** + * Annotation for the method to use to prepare an instance for processing a batch of elements. + * The method annotated with this must satisfy the following constraints: + *
+ * <ul>
+ *   <li>It must have at least one argument.
+ *   <li>Its first (and only) argument must be a {@link DoFnWithContext.Context}.
+ * </ul>
    + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.METHOD) + public @interface FinishBundle {} + + /** + * Returns an {@link Aggregator} with aggregation logic specified by the + * {@link CombineFn} argument. The name provided must be unique across + * {@link Aggregator}s created within the DoFn. Aggregators can only be created + * during pipeline construction. + * + * @param name the name of the aggregator + * @param combiner the {@link CombineFn} to use in the aggregator + * @return an aggregator for the provided name and combiner in the scope of + * this DoFn + * @throws NullPointerException if the name or combiner is null + * @throws IllegalArgumentException if the given name collides with another + * aggregator in this scope + * @throws IllegalStateException if called during pipeline execution. + */ + public final Aggregator + createAggregator(String name, Combine.CombineFn combiner) { + checkNotNull(name, "name cannot be null"); + checkNotNull(combiner, "combiner cannot be null"); + checkArgument(!aggregators.containsKey(name), + "Cannot create aggregator with name %s." + + " An Aggregator with that name already exists within this scope.", + name); + checkState(!aggregatorsAreFinal, + "Cannot create an aggregator during pipeline execution." + + " Aggregators should be registered during pipeline construction."); + + DelegatingAggregator aggregator = + new DelegatingAggregator<>(name, combiner); + aggregators.put(name, aggregator); + return aggregator; + } + + /** + * Returns an {@link Aggregator} with the aggregation logic specified by the + * {@link SerializableFunction} argument. The name provided must be unique + * across {@link Aggregator}s created within the DoFn. Aggregators can only be + * created during pipeline construction. + * + * @param name the name of the aggregator + * @param combiner the {@link SerializableFunction} to use in the aggregator + * @return an aggregator for the provided name and combiner in the scope of + * this DoFn + * @throws NullPointerException if the name or combiner is null + * @throws IllegalArgumentException if the given name collides with another + * aggregator in this scope + * @throws IllegalStateException if called during pipeline execution. + */ + public final Aggregator createAggregator( + String name, SerializableFunction, AggInputT> combiner) { + checkNotNull(combiner, "combiner cannot be null."); + return createAggregator(name, Combine.IterableCombineFn.of(combiner)); + } + + /** + * Finalize the @{link DoFnWithContext} construction to prepare for processing. + * This method should be called by runners before any processing methods. + */ + void prepareForProcessing() { + aggregatorsAreFinal = true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Filter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Filter.java new file mode 100644 index 000000000000..9e123a19fcd1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Filter.java @@ -0,0 +1,234 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code PTransform}s for filtering from a {@code PCollection} the + * elements satisfying a predicate, or satisfying an inequality with + * a given value based on the elements' natural ordering. + * + * @param the type of the values in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ +public class Filter extends PTransform, PCollection> { + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection} with + * elements that satisfy the given predicate. The predicate must be + * a {@code SerializableFunction}. + * + *

+   * <p>Example of use:
+   * <pre> {@code
+   * PCollection<String> wordList = ...;
+   * PCollection<String> longWords =
+   *     wordList.apply(Filter.byPredicate(new MatchIfWordLengthGT(6)));
+   * } </pre>
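MatchIfWordLengthGT is referenced by the example but not defined in this file; a minimal editorial sketch of such a predicate, assuming it is simply the SerializableFunction that byPredicate requires, might look like this:

    // Illustrative predicate: matches words longer than a configured length.
    static class MatchIfWordLengthGT implements SerializableFunction<String, Boolean> {
      private final int minLength;

      MatchIfWordLengthGT(int minLength) {
        this.minLength = minLength;
      }

      @Override
      public Boolean apply(String word) {
        return word.length() > minLength;
      }
    }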
    + * + *

    See also {@link #lessThan}, {@link #lessThanEq}, + * {@link #greaterThan}, {@link #greaterThanEq}, which return elements + * satisfying various inequalities with the specified value based on + * the elements' natural ordering. + */ + public static > Filter + byPredicate(PredicateT predicate) { + return new Filter("Filter", predicate); + } + + /** + * @deprecated use {@link #byPredicate}, which returns a {@link Filter} transform instead of + * a {@link ParDo.Bound}. + */ + @Deprecated + public static > ParDo.Bound + by(final PredicateT filterPred) { + return ParDo.named("Filter").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (filterPred.apply(c.element()) == true) { + c.output(c.element()); + } + } + }); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@link PCollection} and returns a {@link PCollection} with + * elements that are less than a given value, based on the + * elements' natural ordering. Elements must be {@code Comparable}. + * + *

+   * <p>Example of use:
+   * <pre> {@code
+   * PCollection<Integer> listOfNumbers = ...;
+   * PCollection<Integer> smallNumbers =
+   *     listOfNumbers.apply(Filter.lessThan(10));
+   * } </pre>
    + * + *

    See also {@link #lessThanEq}, {@link #greaterThanEq}, + * and {@link #greaterThan}, which return elements satisfying various + * inequalities with the specified value based on the elements' + * natural ordering. + * + *

    See also {@link #byPredicate}, which returns elements + * that satisfy the given predicate. + */ + public static > ParDo.Bound lessThan(final T value) { + return ParDo.named("Filter.lessThan").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (c.element().compareTo(value) < 0) { + c.output(c.element()); + } + } + }); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection} with + * elements that are greater than a given value, based on the + * elements' natural ordering. Elements must be {@code Comparable}. + * + *

+   * <p>Example of use:
+   * <pre> {@code
+   * PCollection<Integer> listOfNumbers = ...;
+   * PCollection<Integer> largeNumbers =
+   *     listOfNumbers.apply(Filter.greaterThan(1000));
+   * } </pre>
    + * + *

    See also {@link #greaterThanEq}, {@link #lessThan}, + * and {@link #lessThanEq}, which return elements satisfying various + * inequalities with the specified value based on the elements' + * natural ordering. + * + *

    See also {@link #byPredicate}, which returns elements + * that satisfy the given predicate. + */ + public static > ParDo.Bound greaterThan(final T value) { + return ParDo.named("Filter.greaterThan").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (c.element().compareTo(value) > 0) { + c.output(c.element()); + } + } + }); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection} with + * elements that are less than or equal to a given value, based on the + * elements' natural ordering. Elements must be {@code Comparable}. + * + *

+   * <p>Example of use:
+   * <pre> {@code
+   * PCollection<Integer> listOfNumbers = ...;
+   * PCollection<Integer> smallOrEqualNumbers =
+   *     listOfNumbers.apply(Filter.lessThanEq(10));
+   * } </pre>
    + * + *

    See also {@link #lessThan}, {@link #greaterThanEq}, + * and {@link #greaterThan}, which return elements satisfying various + * inequalities with the specified value based on the elements' + * natural ordering. + * + *

    See also {@link #byPredicate}, which returns elements + * that satisfy the given predicate. + */ + public static > ParDo.Bound lessThanEq(final T value) { + return ParDo.named("Filter.lessThanEq").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (c.element().compareTo(value) <= 0) { + c.output(c.element()); + } + } + }); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection} with + * elements that are greater than or equal to a given value, based on + * the elements' natural ordering. Elements must be {@code Comparable}. + * + *

+   * <p>Example of use:
+   * <pre> {@code
+   * PCollection<Integer> listOfNumbers = ...;
+   * PCollection<Integer> largeOrEqualNumbers =
+   *     listOfNumbers.apply(Filter.greaterThanEq(1000));
+   * } </pre>
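Because each of these factory methods yields an ordinary transform over the same element type, they compose by chaining apply calls. A hedged editorial sketch of a range filter built this way, in the style of the javadoc examples above (variable names are assumptions):

    PCollection<Integer> numbers = ...;
    // Keep only values in the range [0, 10).
    PCollection<Integer> singleDigits = numbers
        .apply(Filter.greaterThanEq(0))
        .apply(Filter.lessThan(10));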
    + * + *

    See also {@link #greaterThan}, {@link #lessThan}, + * and {@link #lessThanEq}, which return elements satisfying various + * inequalities with the specified value based on the elements' + * natural ordering. + * + *

    See also {@link #byPredicate}, which returns elements + * that satisfy the given predicate. + */ + public static > ParDo.Bound greaterThanEq(final T value) { + return ParDo.named("Filter.greaterThanEq").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (c.element().compareTo(value) >= 0) { + c.output(c.element()); + } + } + }); + } + + /////////////////////////////////////////////////////////////////////////////// + + private SerializableFunction predicate; + + private Filter(SerializableFunction predicate) { + this.predicate = predicate; + } + + private Filter(String name, SerializableFunction predicate) { + super(name); + this.predicate = predicate; + } + + public Filter named(String name) { + return new Filter<>(name, predicate); + } + + @Override + public PCollection apply(PCollection input) { + PCollection output = input.apply(ParDo.named("Filter").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + if (predicate.apply(c.element()) == true) { + c.output(c.element()); + } + } + })); + return output; + } + + @Override + protected Coder getDefaultOutputCoder(PCollection input) { + return input.getCoder(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/FlatMapElements.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/FlatMapElements.java new file mode 100644 index 000000000000..fbaad5be6d93 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/FlatMapElements.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import java.lang.reflect.ParameterizedType; + +/** + * {@code PTransform}s for mapping a simple function that returns iterables over the elements of a + * {@link PCollection} and merging the results. + */ +public class FlatMapElements +extends PTransform, PCollection> { + /** + * For a {@code SerializableFunction>} {@code fn}, + * returns a {@link PTransform} that applies {@code fn} to every element of the input + * {@code PCollection} and outputs all of the elements to the output + * {@code PCollection}. + * + *

+   * <p>Example of use in Java 8:
+   * <pre>{@code
+   * PCollection<String> words = lines.apply(
+   *     FlatMapElements.via((String line) -> Arrays.asList(line.split(" ")))
+   *         .withOutputType(new TypeDescriptor<String>(){}));
+   * }</pre>
    + * + *

    In Java 7, the overload {@link #via(SimpleFunction)} is more concise as the output type + * descriptor need not be provided. + */ + public static MissingOutputTypeDescriptor + via(SerializableFunction> fn) { + return new MissingOutputTypeDescriptor<>(fn); + } + + /** + * For a {@code SimpleFunction>} {@code fn}, + * return a {@link PTransform} that applies {@code fn} to every element of the input + * {@code PCollection} and outputs all of the elements to the output + * {@code PCollection}. + * + *

    This overload is intended primarily for use in Java 7. In Java 8, the overload + * {@link #via(SerializableFunction)} supports use of lambda for greater concision. + * + *

+   * <p>Example of use in Java 7:
+   * <pre>{@code
+   * PCollection<String> lines = ...;
+   * PCollection<String> words = lines.apply(FlatMapElements.via(
+   *     new SimpleFunction<String, Iterable<String>>() {
+   *       public Iterable<String> apply(String line) {
+   *         return Arrays.asList(line.split(" "));
+   *       }
+   *     }));
+   * }</pre>
    + * + *

    To use a Java 8 lambda, see {@link #via(SerializableFunction)}. + */ + public static FlatMapElements + via(SimpleFunction> fn) { + + @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing + TypeDescriptor> iterableType = (TypeDescriptor) fn.getOutputTypeDescriptor(); + + @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType + TypeDescriptor outputType = + (TypeDescriptor) getIterableElementType(iterableType); + + return new FlatMapElements<>(fn, outputType); + } + + /** + * An intermediate builder for a {@link FlatMapElements} transform. To complete the transform, + * provide an output type descriptor to {@link MissingOutputTypeDescriptor#withOutputType}. See + * {@link #via(SerializableFunction)} for a full example of use. + */ + public static final class MissingOutputTypeDescriptor { + + private final SerializableFunction> fn; + + private MissingOutputTypeDescriptor( + SerializableFunction> fn) { + this.fn = fn; + } + + public FlatMapElements withOutputType(TypeDescriptor outputType) { + return new FlatMapElements<>(fn, outputType); + } + } + + private static TypeDescriptor getIterableElementType( + TypeDescriptor> iterableTypeDescriptor) { + + // If a rawtype was used, the type token may be for Object, not a subtype of Iterable. + // In this case, we rely on static typing of the function elsewhere to ensure it is + // at least some kind of iterable, and grossly overapproximate the element type to be Object. + if (!iterableTypeDescriptor.isSubtypeOf(new TypeDescriptor>() {})) { + return new TypeDescriptor() {}; + } + + // Otherwise we can do the proper thing and get the actual type parameter. + ParameterizedType iterableType = + (ParameterizedType) iterableTypeDescriptor.getSupertype(Iterable.class).getType(); + return TypeDescriptor.of(iterableType.getActualTypeArguments()[0]); + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + private final SerializableFunction> fn; + private final transient TypeDescriptor outputType; + + private FlatMapElements( + SerializableFunction> fn, + TypeDescriptor outputType) { + this.fn = fn; + this.outputType = outputType; + } + + @Override + public PCollection apply(PCollection input) { + return input.apply(ParDo.named("Map").of(new DoFn() { + private static final long serialVersionUID = 0L; + @Override + public void processElement(ProcessContext c) { + for (OutputT element : fn.apply(c.element())) { + c.output(element); + } + } + })).setTypeDescriptorInternal(outputType); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Flatten.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Flatten.java new file mode 100644 index 000000000000..de6add0ea3c6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Flatten.java @@ -0,0 +1,219 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableLikeCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@code Flatten} takes multiple {@code PCollection}s bundled + * into a {@code PCollectionList} and returns a single + * {@code PCollection} containing all the elements in all the input + * {@code PCollection}s. The name "Flatten" suggests taking a list of + * lists and flattening them into a single list. + * + *

+ * <p>Example of use:
+ * <pre> {@code
+ * PCollection<String> pc1 = ...;
+ * PCollection<String> pc2 = ...;
+ * PCollection<String> pc3 = ...;
+ * PCollectionList<String> pcs = PCollectionList.of(pc1).and(pc2).and(pc3);
+ * PCollection<String> merged = pcs.apply(Flatten.<String>pCollections());
+ * } </pre>
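A more concrete editorial sketch of the same pattern with in-memory inputs; Create and TestPipeline are other SDK classes assumed to be available here, and the element values are invented for illustration:

    Pipeline p = TestPipeline.create();
    PCollection<String> pc1 = p.apply(Create.of("a", "b"));
    PCollection<String> pc2 = p.apply(Create.of("c"));

    // Merge both inputs into a single PCollection; the output reuses the
    // coder of the first collection in the list, as described below.
    PCollection<String> merged =
        PCollectionList.of(pc1).and(pc2)
            .apply(Flatten.<String>pCollections());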
    + * + *

    By default, the {@code Coder} of the output {@code PCollection} + * is the same as the {@code Coder} of the first {@code PCollection} + * in the input {@code PCollectionList} (if the + * {@code PCollectionList} is non-empty). + * + */ +public class Flatten { + + /** + * Returns a {@link PTransform} that flattens a {@link PCollectionList} + * into a {@link PCollection} containing all the elements of all + * the {@link PCollection}s in its input. + * + *

    All inputs must have equal {@link WindowFn}s. + * The output elements of {@code Flatten} are in the same windows and + * have the same timestamps as their corresponding input elements. The output + * {@code PCollection} will have the same + * {@link WindowFn} as all of the inputs. + * + * @param the type of the elements in the input and output + * {@code PCollection}s. + */ + public static FlattenPCollectionList pCollections() { + return new FlattenPCollectionList<>(); + } + + /** + * Returns a {@code PTransform} that takes a {@code PCollection>} + * and returns a {@code PCollection} containing all the elements from + * all the {@code Iterable}s. + * + *

+   * <p>Example of use:
+   * <pre> {@code
+   * PCollection<Iterable<Integer>> pcOfIterables = ...;
+   * PCollection<Integer> pc = pcOfIterables.apply(Flatten.<Integer>iterables());
+   * } </pre>
    + * + *

    By default, the output {@code PCollection} encodes its elements + * using the same {@code Coder} that the input uses for + * the elements in its {@code Iterable}. + * + * @param the type of the elements of the input {@code Iterable} and + * the output {@code PCollection} + */ + public static FlattenIterables iterables() { + return new FlattenIterables<>(); + } + + /** + * A {@link PTransform} that flattens a {@link PCollectionList} + * into a {@link PCollection} containing all the elements of all + * the {@link PCollection}s in its input. + * Implements {@link #pCollections}. + * + * @param the type of the elements in the input and output + * {@code PCollection}s. + */ + public static class FlattenPCollectionList + extends PTransform, PCollection> { + + private FlattenPCollectionList() { } + + @Override + public PCollection apply(PCollectionList inputs) { + WindowingStrategy windowingStrategy; + IsBounded isBounded = IsBounded.BOUNDED; + if (!inputs.getAll().isEmpty()) { + windowingStrategy = inputs.get(0).getWindowingStrategy(); + for (PCollection input : inputs.getAll()) { + WindowingStrategy other = input.getWindowingStrategy(); + if (!windowingStrategy.getWindowFn().isCompatible(other.getWindowFn())) { + throw new IllegalStateException( + "Inputs to Flatten had incompatible window windowFns: " + + windowingStrategy.getWindowFn() + ", " + other.getWindowFn()); + } + + if (!windowingStrategy.getTrigger().getSpec() + .isCompatible(other.getTrigger().getSpec())) { + throw new IllegalStateException( + "Inputs to Flatten had incompatible triggers: " + + windowingStrategy.getTrigger() + ", " + other.getTrigger()); + } + isBounded = isBounded.and(input.isBounded()); + } + } else { + windowingStrategy = WindowingStrategy.globalDefault(); + } + + return PCollection.createPrimitiveOutputInternal( + inputs.getPipeline(), + windowingStrategy, + isBounded); + } + + @Override + protected Coder getDefaultOutputCoder(PCollectionList input) + throws CannotProvideCoderException { + + // Take coder from first collection + for (PCollection pCollection : input.getAll()) { + return pCollection.getCoder(); + } + + // No inputs + throw new CannotProvideCoderException( + this.getClass().getSimpleName() + " cannot provide a Coder for" + + " empty " + PCollectionList.class.getSimpleName()); + } + } + + /** + * {@code FlattenIterables} takes a {@code PCollection>} and returns a + * {@code PCollection} that contains all the elements from each iterable. + * Implements {@link #iterables}. 
+ * + * @param the type of the elements of the input {@code Iterable}s and + * the output {@code PCollection} + */ + public static class FlattenIterables + extends PTransform>, PCollection> { + + @Override + public PCollection apply(PCollection> in) { + Coder> inCoder = in.getCoder(); + if (!(inCoder instanceof IterableLikeCoder)) { + throw new IllegalArgumentException( + "expecting the input Coder to be an IterableLikeCoder"); + } + @SuppressWarnings("unchecked") + Coder elemCoder = ((IterableLikeCoder) inCoder).getElemCoder(); + + return in.apply(ParDo.named("FlattenIterables").of( + new DoFn, T>() { + @Override + public void processElement(ProcessContext c) { + for (T i : c.element()) { + c.output(i); + } + } + })) + .setCoder(elemCoder); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + FlattenPCollectionList.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + FlattenPCollectionList transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + FlattenPCollectionList transform, + DirectPipelineRunner.EvaluationContext context) { + List> outputElems = new ArrayList<>(); + PCollectionList inputs = context.getInput(transform); + + for (PCollection input : inputs.getAll()) { + outputElems.addAll(context.getPCollectionValuesWithMetadata(input)); + } + + context.setPCollectionValuesWithMetadata(context.getOutput(transform), outputElems); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/GroupByKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/GroupByKey.java new file mode 100644 index 000000000000..8fde3e086974 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/GroupByKey.java @@ -0,0 +1,575 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.ValueWithMetadata; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.InvalidWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowsViaOutputBufferDoFn; +import com.google.cloud.dataflow.sdk.util.ReifyTimestampAndWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.SystemReduceFn; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.WindowedValue.WindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * {@code GroupByKey} takes a {@code PCollection>}, + * groups the values by key and windows, and returns a + * {@code PCollection>>} representing a map from + * each distinct key and window of the input {@code PCollection} to an + * {@code Iterable} over all the values associated with that key in + * the input per window. Absent repeatedly-firing + * {@link Window#triggering triggering}, each key in the output + * {@code PCollection} is unique within each window. + * + *

    {@code GroupByKey} is analogous to converting a multi-map into + * a uni-map, and related to {@code GROUP BY} in SQL. It corresponds + * to the "shuffle" step between the Mapper and the Reducer in the + * MapReduce framework. + * + *
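+ * <p>For instance, a sketch of grouping a small in-memory input (assuming a
+ * {@code Pipeline p}; values within each {@code Iterable} may appear in any order):
+ * <pre> {@code
+ * PCollection<KV<String, Integer>> input = p.apply(
+ *     Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3)));
+ * PCollection<KV<String, Iterable<Integer>>> grouped =
+ *     input.apply(GroupByKey.<String, Integer>create());
+ * // grouped contains "a" -> [1, 3] and "b" -> [2]
+ * } </pre>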

    Two keys of type {@code K} are compared for equality + * not by regular Java {@link Object#equals}, but instead by + * first encoding each of the keys using the {@code Coder} of the + * keys of the input {@code PCollection}, and then comparing the + * encoded bytes. This admits efficient parallel evaluation. Note that + * this requires that the {@code Coder} of the keys be deterministic (see + * {@link Coder#verifyDeterministic()}). If the key {@code Coder} is not + * deterministic, an exception is thrown at pipeline construction time. + * + *
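+ * <p>A sketch of this byte-wise comparison, using a deterministic coder such as
+ * {@code StringUtf8Coder}:
+ * <pre> {@code
+ * Coder<String> keyCoder = StringUtf8Coder.of();
+ * byte[] a = CoderUtils.encodeToByteArray(keyCoder, "tea");
+ * byte[] b = CoderUtils.encodeToByteArray(keyCoder, "tea");
+ * boolean sameGroup = Arrays.equals(a, b);  // true: equal keys encode identically
+ * } </pre>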

    By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input, + * and the {@code Coder} of the elements of the {@code Iterable} + * values of the output {@code PCollection} is the same as the + * {@code Coder} of the values of the input. + * + *
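+ * <p>Concretely, for an input coded by {@code KvCoder.of(keyCoder, valueCoder)}, the
+ * inferred output coder is (a sketch mirroring {@code getOutputKvCoder} below):
+ * <pre> {@code
+ * KvCoder<K, Iterable<V>> outputCoder =
+ *     KvCoder.of(keyCoder, IterableCoder.of(valueCoder));
+ * } </pre>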

    Example of use: + *

     {@code
+ * PCollection<KV<String, Doc>> urlDocPairs = ...;
+ * PCollection<KV<String, Iterable<Doc>>> urlToDocs =
+ *     urlDocPairs.apply(GroupByKey.<String, Doc>create());
+ * PCollection<R> results =
+ *     urlToDocs.apply(ParDo.of(new DoFn<KV<String, Iterable<Doc>>, R>() {
+ *       public void processElement(ProcessContext c) {
+ *         String url = c.element().getKey();
+ *         Iterable<Doc> docsWithThatUrl = c.element().getValue();
+ *         ... process all docs having that url ...
+ *       }}));
    + * } 
    + * + *

    {@code GroupByKey} is a key primitive in data-parallel + * processing, since it is the main way to efficiently bring + * associated data together into one location. It is also a key + * determiner of the performance of a data-parallel pipeline. + * + *

    See {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey} + * for a way to group multiple input PCollections by a common key at once. + * + *
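+ * <p>A minimal {@code CoGroupByKey} sketch (assuming two inputs keyed by {@code String},
+ * {@code PCollection<KV<String, Integer>> counts} and
+ * {@code PCollection<KV<String, String>> names}):
+ * <pre> {@code
+ * TupleTag<Integer> countsTag = new TupleTag<>();
+ * TupleTag<String> namesTag = new TupleTag<>();
+ * PCollection<KV<String, CoGbkResult>> joined = KeyedPCollectionTuple
+ *     .of(countsTag, counts)
+ *     .and(namesTag, names)
+ *     .apply(CoGroupByKey.<String>create());
+ * } </pre>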

    See {@link Combine.PerKey} for a common pattern of + * {@code GroupByKey} followed by {@link Combine.GroupedValues}. + * + *
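+ * <p>For example, these two formulations compute the same per-key sums, given a
+ * {@code PCollection<KV<String, Integer>> counts} (a sketch):
+ * <pre> {@code
+ * counts.apply(Combine.<String, Integer, Integer>perKey(new Sum.SumIntegerFn()));
+ * // is roughly equivalent to
+ * counts.apply(GroupByKey.<String, Integer>create())
+ *       .apply(Combine.<String, Integer, Integer>groupedValues(new Sum.SumIntegerFn()));
+ * } </pre>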

    When grouping, windows that can be merged according to the {@link WindowFn} + * of the input {@code PCollection} will be merged together, and a window pane + * corresponding to the new, merged window will be created. The items in this pane + * will be emitted when a trigger fires. By default this will be when the input + * sources estimate there will be no more data for the window. See + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark} + * for details on the estimation. + * + *
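+ * <p>For example, to group within 1-minute fixed windows (a sketch, assuming a
+ * {@code PCollection<KV<String, Integer>> input}):
+ * <pre> {@code
+ * input.apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.standardMinutes(1))))
+ *      .apply(GroupByKey.<String, Integer>create());
+ * } </pre>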

    The timestamp for each emitted pane is determined by the + * {@link Window.Bound#withOutputTimeFn windowing operation}. + * The output {@code PCollection} will have the same {@link WindowFn} + * as the input. + * + *

    If the input {@code PCollection} contains late data (see + * {@link com.google.cloud.dataflow.sdk.io.PubsubIO.Read.Bound#timestampLabel} + * for an example of how this can occur) or the + * {@link Window#triggering requested TriggerFn} can fire before + * the watermark, then there may be multiple elements + * output by a {@code GroupByKey} that correspond to the same key and window. + * + *
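+ * <p>A sketch of a triggering configuration under which a single key and window can
+ * produce multiple output panes (one at the watermark plus one per late element; assumes a
+ * {@code PCollection<KV<String, Integer>> input}):
+ * <pre> {@code
+ * input.apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.standardMinutes(1)))
+ *          .triggering(AfterWatermark.pastEndOfWindow()
+ *              .withLateFirings(AfterPane.elementCountAtLeast(1)))
+ *          .withAllowedLateness(Duration.standardMinutes(30))
+ *          .accumulatingFiredPanes())
+ *      .apply(GroupByKey.<String, Integer>create());
+ * } </pre>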

    If the {@link WindowFn} of the input requires merging, it is not + * valid to apply another {@code GroupByKey} without first applying a new + * {@link WindowFn} or applying {@link Window#remerge()}. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * and the elements of the {@code Iterable}s in the output + * {@code PCollection} + */ +public class GroupByKey + extends PTransform>, + PCollection>>> { + + private final boolean fewKeys; + + private GroupByKey(boolean fewKeys) { + this.fewKeys = fewKeys; + } + + /** + * Returns a {@code GroupByKey} {@code PTransform}. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * and the elements of the {@code Iterable}s in the output + * {@code PCollection} + */ + public static GroupByKey create() { + return new GroupByKey<>(false); + } + + /** + * Returns a {@code GroupByKey} {@code PTransform}. + * + * @param the type of the keys of the input and output + * {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + * and the elements of the {@code Iterable}s in the output + * {@code PCollection} + * @param fewKeys whether it groups just few keys. + */ + static GroupByKey create(boolean fewKeys) { + return new GroupByKey<>(fewKeys); + } + + /** + * Returns whether it groups just few keys. + */ + public boolean fewKeys() { + return fewKeys; + } + + ///////////////////////////////////////////////////////////////////////////// + + public static void applicableTo(PCollection input) { + WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + // Verify that the input PCollection is bounded, or that there is windowing/triggering being + // used. Without this, the watermark (at end of global window) will never be reached. + if (windowingStrategy.getWindowFn() instanceof GlobalWindows + && windowingStrategy.getTrigger().getSpec() instanceof DefaultTrigger + && input.isBounded() != IsBounded.BOUNDED) { + throw new IllegalStateException("GroupByKey cannot be applied to non-bounded PCollection in " + + "the GlobalWindow without a trigger. Use a Window.into or Window.triggering transform " + + "prior to GroupByKey."); + } + + // Validate the window merge function. + if (windowingStrategy.getWindowFn() instanceof InvalidWindows) { + String cause = ((InvalidWindows) windowingStrategy.getWindowFn()).getCause(); + throw new IllegalStateException( + "GroupByKey must have a valid Window merge function. " + + "Invalid because: " + cause); + } + } + + @Override + public void validate(PCollection> input) { + applicableTo(input); + + // Verify that the input Coder> is a KvCoder, and that + // the key coder is deterministic. + Coder keyCoder = getKeyCoder(input.getCoder()); + try { + keyCoder.verifyDeterministic(); + } catch (NonDeterministicException e) { + throw new IllegalStateException( + "the keyCoder of a GroupByKey must be deterministic", e); + } + } + + public WindowingStrategy updateWindowingStrategy(WindowingStrategy inputStrategy) { + WindowFn inputWindowFn = inputStrategy.getWindowFn(); + if (!inputWindowFn.isNonMerging()) { + // Prevent merging windows again, without explicit user + // involvement, e.g., by Window.into() or Window.remerge(). 
+ inputWindowFn = new InvalidWindows<>( + "WindowFn has already been consumed by previous GroupByKey", inputWindowFn); + } + + // We also switch to the continuation trigger associated with the current trigger. + return inputStrategy + .withWindowFn(inputWindowFn) + .withTrigger(inputStrategy.getTrigger().getSpec().getContinuationTrigger()); + } + + @Override + public PCollection>> apply(PCollection> input) { + // This operation groups by the combination of key and window, + // merging windows as needed, using the windows assigned to the + // key/value input elements and the window merge operation of the + // window function associated with the input PCollection. + WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + + // By default, implement GroupByKey[AndWindow] via a series of lower-level + // operations. + return input + // Make each input element's timestamp and assigned windows + // explicit, in the value part. + .apply(new ReifyTimestampsAndWindows()) + + // Group by just the key. + // Combiner lifting will not happen regardless of the disallowCombinerLifting value. + // There will be no combiners right after the GroupByKeyOnly because of the two ParDos + // introduced in here. + .apply(new GroupByKeyOnly>()) + + // Sort each key's values by timestamp. GroupAlsoByWindow requires + // its input to be sorted by timestamp. + .apply(new SortValuesByTimestamp()) + + // Group each key's values by window, merging windows as needed. + .apply(new GroupAlsoByWindow(windowingStrategy)) + + // And update the windowing strategy as appropriate. + .setWindowingStrategyInternal(updateWindowingStrategy(windowingStrategy)); + } + + @Override + protected Coder>> getDefaultOutputCoder(PCollection> input) { + return getOutputKvCoder(input.getCoder()); + } + + /** + * Returns the {@code Coder} of the input to this transform, which + * should be a {@code KvCoder}. + */ + @SuppressWarnings("unchecked") + static KvCoder getInputKvCoder(Coder> inputCoder) { + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "GroupByKey requires its input to use KvCoder"); + } + return (KvCoder) inputCoder; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns the {@code Coder} of the keys of the input to this + * transform, which is also used as the {@code Coder} of the keys of + * the output of this transform. + */ + static Coder getKeyCoder(Coder> inputCoder) { + return getInputKvCoder(inputCoder).getKeyCoder(); + } + + /** + * Returns the {@code Coder} of the values of the input to this transform. + */ + public static Coder getInputValueCoder(Coder> inputCoder) { + return getInputKvCoder(inputCoder).getValueCoder(); + } + + /** + * Returns the {@code Coder} of the {@code Iterable} values of the + * output of this transform. + */ + static Coder> getOutputValueCoder(Coder> inputCoder) { + return IterableCoder.of(getInputValueCoder(inputCoder)); + } + + /** + * Returns the {@code Coder} of the output of this transform. + */ + static KvCoder> getOutputKvCoder(Coder> inputCoder) { + return KvCoder.of(getKeyCoder(inputCoder), getOutputValueCoder(inputCoder)); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Helper transform that makes timestamps and window assignments + * explicit in the value part of each key/value pair. 
+ */ + public static class ReifyTimestampsAndWindows + extends PTransform>, + PCollection>>> { + @Override + public PCollection>> apply( + PCollection> input) { + @SuppressWarnings("unchecked") + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder inputValueCoder = inputKvCoder.getValueCoder(); + Coder> outputValueCoder = FullWindowedValueCoder.of( + inputValueCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + Coder>> outputKvCoder = + KvCoder.of(keyCoder, outputValueCoder); + return input.apply(ParDo.of(new ReifyTimestampAndWindowsDoFn())) + .setCoder(outputKvCoder); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Helper transform that sorts the values associated with each key + * by timestamp. + */ + public static class SortValuesByTimestamp + extends PTransform>>>, + PCollection>>>> { + @Override + public PCollection>>> apply( + PCollection>>> input) { + return input.apply(ParDo.of( + new DoFn>>, + KV>>>() { + @Override + public void processElement(ProcessContext c) { + KV>> kvs = c.element(); + K key = kvs.getKey(); + Iterable> unsortedValues = kvs.getValue(); + List> sortedValues = new ArrayList<>(); + for (WindowedValue value : unsortedValues) { + sortedValues.add(value); + } + Collections.sort(sortedValues, + new Comparator>() { + @Override + public int compare(WindowedValue e1, WindowedValue e2) { + return e1.getTimestamp().compareTo(e2.getTimestamp()); + } + }); + c.output(KV.>>of(key, sortedValues)); + }})) + .setCoder(input.getCoder()); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Helper transform that takes a collection of timestamp-ordered + * values associated with each key, groups the values by window, + * combines windows as needed, and for each window in each key, + * outputs a collection of key/value-list pairs implicitly assigned + * to the window and with the timestamp derived from that window. 
+ */ + public static class GroupAlsoByWindow + extends PTransform>>>, + PCollection>>> { + private final WindowingStrategy windowingStrategy; + + public GroupAlsoByWindow(WindowingStrategy windowingStrategy) { + this.windowingStrategy = windowingStrategy; + } + + @Override + @SuppressWarnings("unchecked") + public PCollection>> apply( + PCollection>>> input) { + @SuppressWarnings("unchecked") + KvCoder>> inputKvCoder = + (KvCoder>>) input.getCoder(); + + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder>> inputValueCoder = + inputKvCoder.getValueCoder(); + + IterableCoder> inputIterableValueCoder = + (IterableCoder>) inputValueCoder; + Coder> inputIterableElementCoder = + inputIterableValueCoder.getElemCoder(); + WindowedValueCoder inputIterableWindowedValueCoder = + (WindowedValueCoder) inputIterableElementCoder; + + Coder inputIterableElementValueCoder = + inputIterableWindowedValueCoder.getValueCoder(); + Coder> outputValueCoder = + IterableCoder.of(inputIterableElementValueCoder); + Coder>> outputKvCoder = KvCoder.of(keyCoder, outputValueCoder); + + return input + .apply(ParDo.of(groupAlsoByWindowsFn(windowingStrategy, inputIterableElementValueCoder))) + .setCoder(outputKvCoder); + } + + private GroupAlsoByWindowsViaOutputBufferDoFn, W> + groupAlsoByWindowsFn( + WindowingStrategy strategy, Coder inputIterableElementValueCoder) { + return new GroupAlsoByWindowsViaOutputBufferDoFn, W>( + strategy, SystemReduceFn.buffering(inputIterableElementValueCoder)); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Primitive helper transform that groups by key only, ignoring any + * window assignments. + */ + public static class GroupByKeyOnly + extends PTransform>, + PCollection>>> { + + @SuppressWarnings({"rawtypes", "unchecked"}) + @Override + public PCollection>> apply(PCollection> input) { + return PCollection.>>createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()); + } + + /** + * Returns the {@code Coder} of the input to this transform, which + * should be a {@code KvCoder}. 
+ */ + @SuppressWarnings("unchecked") + KvCoder getInputKvCoder(Coder> inputCoder) { + if (!(inputCoder instanceof KvCoder)) { + throw new IllegalStateException( + "GroupByKey requires its input to use KvCoder"); + } + return (KvCoder) inputCoder; + } + + @Override + protected Coder>> getDefaultOutputCoder(PCollection> input) { + return GroupByKey.getOutputKvCoder(input.getCoder()); + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + static { + registerWithDirectPipelineRunner(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static void registerWithDirectPipelineRunner() { + DirectPipelineRunner.registerDefaultTransformEvaluator( + GroupByKeyOnly.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + GroupByKeyOnly transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateHelper(transform, context); + } + }); + } + + private static void evaluateHelper( + GroupByKeyOnly transform, + DirectPipelineRunner.EvaluationContext context) { + PCollection> input = context.getInput(transform); + + List>> inputElems = + context.getPCollectionValuesWithMetadata(input); + + Coder keyCoder = GroupByKey.getKeyCoder(input.getCoder()); + + Map, List> groupingMap = new HashMap<>(); + + for (ValueWithMetadata> elem : inputElems) { + K key = elem.getValue().getKey(); + V value = elem.getValue().getValue(); + byte[] encodedKey; + try { + encodedKey = encodeToByteArray(keyCoder, key); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. + throw new IllegalArgumentException( + "unable to encode key " + key + " of input to " + transform + + " using " + keyCoder, + exn); + } + GroupingKey groupingKey = new GroupingKey<>(key, encodedKey); + List values = groupingMap.get(groupingKey); + if (values == null) { + values = new ArrayList(); + groupingMap.put(groupingKey, values); + } + values.add(value); + } + + List>>> outputElems = + new ArrayList<>(); + for (Map.Entry, List> entry : groupingMap.entrySet()) { + GroupingKey groupingKey = entry.getKey(); + K key = groupingKey.getKey(); + List values = entry.getValue(); + values = context.randomizeIfUnordered(values, true /* inPlaceAllowed */); + outputElems.add(ValueWithMetadata + .of(WindowedValue.valueInEmptyWindows(KV.>of(key, values))) + .withKey(key)); + } + + context.setPCollectionValuesWithMetadata(context.getOutput(transform), + outputElems); + } + + private static class GroupingKey { + private K key; + private byte[] encodedKey; + + public GroupingKey(K key, byte[] encodedKey) { + this.key = key; + this.encodedKey = encodedKey; + } + + public K getKey() { + return key; + } + + @Override + public boolean equals(Object o) { + if (o instanceof GroupingKey) { + GroupingKey that = (GroupingKey) o; + return Arrays.equals(this.encodedKey, that.encodedKey); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(encodedKey); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/IntraBundleParallelization.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/IntraBundleParallelization.java new file mode 100644 index 000000000000..b6497b71c4ef --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/IntraBundleParallelization.java @@ -0,0 +1,346 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; + +import org.joda.time.Instant; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Provides multi-threading of {@link DoFn}s, using threaded execution to + * process multiple elements concurrently within a bundle. + * + *

    Note that each Dataflow worker already processes multiple bundles + * concurrently; this class is meant only for cases where processing + * elements within a bundle is limited by blocking calls. 
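+ * <p>A sketch of the kind of {@code DoFn} this wrapper targets: per-element work
+ * dominated by waiting rather than CPU or IO ({@code Thread.sleep} stands in for a
+ * blocking external call; the fn holds no mutable state, so it is thread safe):
+ * <pre> {@code
+ * class SlowLookupFn extends DoFn<String, String> {
+ *   public void processElement(ProcessContext c) throws Exception {
+ *     Thread.sleep(200);  // stands in for a blocking RPC
+ *     c.output(c.element().toUpperCase());
+ *   }
+ * }
+ * } </pre>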

    CPU-intensive or IO-intensive tasks are in general a poor fit for parallelization. + * This is because a limited resource that is already maximally utilized does not + * benefit from sub-division of work. The parallelization will increase the amount of time + * to process each element, yet the overall throughput will remain roughly the same. + * For example, if the local disk (an IO resource) has a maximum write rate of 10 MiB/s, + * and processing each element requires writing 20 MiB to disk, then processing one element + * will take 2 seconds. Yet processing 3 elements concurrently (each getting an equal + * share of the maximum write rate) will take at least 6 seconds to complete (plus additional + * overhead from the extra parallelization). + * + *

    To parallelize a {@link DoFn} to 10 threads: + *

    {@code
+ * PCollection<T> data = ...;
+ * data.apply(
+ *   IntraBundleParallelization.of(new MyDoFn())
+ *                             .withMaxParallelism(10));
    + * }
    + * + *

    An uncaught exception from the wrapped {@link DoFn} will result in the exception + * being rethrown in later calls to {@link MultiThreadedIntraBundleProcessingDoFn#processElement} + * or a call to {@link MultiThreadedIntraBundleProcessingDoFn#finishBundle}. + */ +public class IntraBundleParallelization { + /** + * Creates a {@link IntraBundleParallelization} {@link PTransform} for the given + * {@link DoFn} that processes elements using multiple threads. + * + *

    Note that the specified {@code doFn} needs to be thread safe. + */ + public static Bound of(DoFn doFn) { + return new Unbound().of(doFn); + } + + /** + * Creates a {@link IntraBundleParallelization} {@link PTransform} with the specified + * maximum concurrency level. + */ + public static Unbound withMaxParallelism(int maxParallelism) { + return new Unbound().withMaxParallelism(maxParallelism); + } + + /** + * An incomplete {@code IntraBundleParallelization} transform, with unbound input/output types. + * + *

    Before being applied, {@link IntraBundleParallelization.Unbound#of} must be + * invoked to specify the {@link DoFn} to invoke, which will also + * bind the input/output types of this {@code PTransform}. + */ + public static class Unbound { + private final int maxParallelism; + + Unbound() { + this(DEFAULT_MAX_PARALLELISM); + } + + Unbound(int maxParallelism) { + Preconditions.checkArgument(maxParallelism > 0, + "Expected parallelism factor greater than zero, received %s.", maxParallelism); + this.maxParallelism = maxParallelism; + } + + /** + * Returns a new {@link IntraBundleParallelization} {@link PTransform} like this one + * with the specified maximum concurrency level. + */ + public Unbound withMaxParallelism(int maxParallelism) { + return new Unbound(maxParallelism); + } + + /** + * Returns a new {@link IntraBundleParallelization} {@link PTransform} like this one + * with the specified {@link DoFn}. + * + *

    Note that the specified {@code doFn} needs to be thread safe. + */ + public Bound of(DoFn doFn) { + return new Bound<>(doFn, maxParallelism); + } + } + + /** + * A {@code PTransform} that, when applied to a {@code PCollection}, + * invokes a user-specified {@code DoFn} on all its elements, + * with all its outputs collected into an output + * {@code PCollection}. + * + *

    Note that the specified {@code doFn} needs to be thread safe. + * + * @param the type of the (main) input {@code PCollection} elements + * @param the type of the (main) output {@code PCollection} elements + */ + public static class Bound + extends PTransform, PCollection> { + private final DoFn doFn; + private final int maxParallelism; + + Bound(DoFn doFn, int maxParallelism) { + Preconditions.checkArgument(maxParallelism > 0, + "Expected parallelism factor greater than zero, received %s.", maxParallelism); + this.doFn = doFn; + this.maxParallelism = maxParallelism; + } + + /** + * Returns a new {@link IntraBundleParallelization} {@link PTransform} like this one + * with the specified maximum concurrency level. + */ + public Bound withMaxParallelism(int maxParallelism) { + return new Bound<>(doFn, maxParallelism); + } + + /** + * Returns a new {@link IntraBundleParallelization} {@link PTransform} like this one + * with the specified {@link DoFn}. + * + *

    Note that the specified {@code doFn} needs to be thread safe. + */ + public Bound + of(DoFn doFn) { + return new Bound<>(doFn, maxParallelism); + } + + @Override + public PCollection apply(PCollection input) { + return input.apply( + ParDo.of(new MultiThreadedIntraBundleProcessingDoFn<>(doFn, maxParallelism))); + } + } + + /** + * A multi-threaded {@code DoFn} wrapper. + * + * @see IntraBundleParallelization#of(DoFn) + * + * @param the type of the (main) input elements + * @param the type of the (main) output elements + */ + public static class MultiThreadedIntraBundleProcessingDoFn + extends DoFn { + + public MultiThreadedIntraBundleProcessingDoFn(DoFn doFn, int maxParallelism) { + Preconditions.checkArgument(maxParallelism > 0, + "Expected parallelism factor greater than zero, received %s.", maxParallelism); + this.doFn = doFn; + this.maxParallelism = maxParallelism; + } + + @Override + public void startBundle(Context c) throws Exception { + doFn.startBundle(c); + + executor = c.getPipelineOptions().as(GcsOptions.class).getExecutorService(); + workTickets = new Semaphore(maxParallelism); + failure = new AtomicReference<>(); + } + + @Override + public void processElement(final ProcessContext c) throws Exception { + try { + workTickets.acquire(); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while scheduling work", e); + } + + if (failure.get() != null) { + throw Throwables.propagate(failure.get()); + } + + executor.submit(new Runnable() { + @Override + public void run() { + try { + doFn.processElement(new WrappedContext(c)); + } catch (Throwable t) { + failure.compareAndSet(null, t); + Throwables.propagateIfPossible(t); + throw new AssertionError("Unexpected checked exception: " + t); + } finally { + workTickets.release(); + } + } + }); + } + + @Override + public void finishBundle(Context c) throws Exception { + // Acquire all the work tickets to guarantee that all the previous + // processElement calls have finished. + workTickets.acquire(maxParallelism); + if (failure.get() != null) { + throw Throwables.propagate(failure.get()); + } + doFn.finishBundle(c); + } + + @Override + protected TypeDescriptor getInputTypeDescriptor() { + return doFn.getInputTypeDescriptor(); + } + + @Override + protected TypeDescriptor getOutputTypeDescriptor() { + return doFn.getOutputTypeDescriptor(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Wraps a DoFn context, forcing single-thread output so that threads don't + * propagate through to downstream functions. 
+ */ + private class WrappedContext extends ProcessContext { + private final ProcessContext context; + + WrappedContext(ProcessContext context) { + this.context = context; + } + + @Override + public InputT element() { + return context.element(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return context.getPipelineOptions(); + } + + @Override + public T sideInput(PCollectionView view) { + return context.sideInput(view); + } + + @Override + public void output(OutputT output) { + synchronized (MultiThreadedIntraBundleProcessingDoFn.this) { + context.output(output); + } + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + synchronized (MultiThreadedIntraBundleProcessingDoFn.this) { + context.outputWithTimestamp(output, timestamp); + } + } + + @Override + public void sideOutput(TupleTag tag, T output) { + synchronized (MultiThreadedIntraBundleProcessingDoFn.this) { + context.sideOutput(tag, output); + } + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + synchronized (MultiThreadedIntraBundleProcessingDoFn.this) { + context.sideOutputWithTimestamp(tag, output, timestamp); + } + } + + @Override + public Instant timestamp() { + return context.timestamp(); + } + + @Override + public BoundedWindow window() { + return context.window(); + } + + @Override + public PaneInfo pane() { + return context.pane(); + } + + @Override + public WindowingInternals windowingInternals() { + return context.windowingInternals(); + } + + @Override + protected Aggregator createAggregatorInternal( + String name, CombineFn combiner) { + return context.createAggregatorInternal(name, combiner); + } + } + + private final DoFn doFn; + private int maxParallelism; + + private transient ExecutorService executor; + private transient Semaphore workTickets; + private transient AtomicReference failure; + } + + /** + * Default maximum for number of concurrent elements to process. + */ + private static final int DEFAULT_MAX_PARALLELISM = 16; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Keys.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Keys.java new file mode 100644 index 000000000000..370d43dd0236 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Keys.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code Keys} takes a {@code PCollection} of {@code KV}s and + * returns a {@code PCollection} of the keys. + * + *

    Example of use: + *

     {@code
+ * PCollection<KV<String, Long>> wordCounts = ...;
+ * PCollection<String> words = wordCounts.apply(Keys.<String>create());
    + * } 
    + * + *

    Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + *
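+ * <p>For example (a sketch), combining {@code Keys} with {@code RemoveDuplicates} to
+ * obtain the distinct keys of the {@code wordCounts} collection above:
+ * <pre> {@code
+ * PCollection<String> distinctKeys = wordCounts
+ *     .apply(Keys.<String>create())
+ *     .apply(RemoveDuplicates.<String>create());
+ * } </pre>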

    See also {@link Values}. + * + * @param the type of the keys in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ +public class Keys extends PTransform>, + PCollection> { + /** + * Returns a {@code Keys} {@code PTransform}. + * + * @param the type of the keys in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ + public static Keys create() { + return new Keys<>(); + } + + private Keys() { } + + @Override + public PCollection apply(PCollection> in) { + return + in.apply(ParDo.named("Keys") + .of(new DoFn, K>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey()); + } + })); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/KvSwap.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/KvSwap.java new file mode 100644 index 000000000000..5a9cc87bc294 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/KvSwap.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code KvSwap} takes a {@code PCollection>} and + * returns a {@code PCollection>}, where all the keys and + * values have been swapped. + * + *

    Example of use: + *

     {@code
+ * PCollection<KV<String, Long>> wordsToCounts = ...;
+ * PCollection<KV<Long, String>> countsToWords =
+ *     wordsToCounts.apply(KvSwap.<String, Long>create());
    + * } 
    + * + *

    Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + * @param the type of the keys in the input {@code PCollection} + * and the values in the output {@code PCollection} + * @param the type of the values in the input {@code PCollection} + * and the keys in the output {@code PCollection} + */ +public class KvSwap extends PTransform>, + PCollection>> { + /** + * Returns a {@code KvSwap} {@code PTransform}. + * + * @param the type of the keys in the input {@code PCollection} + * and the values in the output {@code PCollection} + * @param the type of the values in the input {@code PCollection} + * and the keys in the output {@code PCollection} + */ + public static KvSwap create() { + return new KvSwap<>(); + } + + private KvSwap() { } + + @Override + public PCollection> apply(PCollection> in) { + return + in.apply(ParDo.named("KvSwap") + .of(new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + c.output(KV.of(e.getValue(), e.getKey())); + } + })); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/MapElements.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/MapElements.java new file mode 100644 index 000000000000..89970508645b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/MapElements.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +/** + * {@code PTransform}s for mapping a simple function over the elements of a {@link PCollection}. + */ +public class MapElements +extends PTransform, PCollection> { + + /** + * For a {@code SerializableFunction} {@code fn} and output type descriptor, + * returns a {@code PTransform} that takes an input {@code PCollection} and returns + * a {@code PCollection} containing {@code fn.apply(v)} for every element {@code v} in + * the input. + * + *

    Example of use in Java 8: + *

    {@code
+   * PCollection<Integer> wordLengths = words.apply(
+   *     MapElements.via((String word) -> word.length())
+   *         .withOutputType(new TypeDescriptor<Integer>() {}));
    +   * }
    + * + *

    In Java 7, the overload {@link #via(SimpleFunction)} is more concise as the output type + * descriptor need not be provided. + */ + public static MissingOutputTypeDescriptor + via(SerializableFunction fn) { + return new MissingOutputTypeDescriptor<>(fn); + } + + /** + * For a {@code SimpleFunction} {@code fn}, returns a {@code PTransform} that + * takes an input {@code PCollection} and returns a {@code PCollection} + * containing {@code fn.apply(v)} for every element {@code v} in the input. + * + *

    This overload is intended primarily for use in Java 7. In Java 8, the overload + * {@link #via(SerializableFunction)} supports use of lambda for greater concision. + * + *

    Example of use in Java 7: + *

    {@code
+   * PCollection<String> words = ...;
+   * PCollection<Integer> wordLengths = words.apply(MapElements.via(
+   *     new SimpleFunction<String, Integer>() {
+   *       public Integer apply(String word) {
+   *         return word.length();
+   *       }
+   *     }));
    +   * }
    + */ + public static MapElements + via(final SimpleFunction fn) { + return new MapElements<>(fn, fn.getOutputTypeDescriptor()); + } + + /** + * An intermediate builder for a {@link MapElements} transform. To complete the transform, provide + * an output type descriptor to {@link MissingOutputTypeDescriptor#withOutputType}. See + * {@link #via(SerializableFunction)} for a full example of use. + */ + public static final class MissingOutputTypeDescriptor { + + private final SerializableFunction fn; + + private MissingOutputTypeDescriptor(SerializableFunction fn) { + this.fn = fn; + } + + public MapElements withOutputType(TypeDescriptor outputType) { + return new MapElements<>(fn, outputType); + } + } + + /////////////////////////////////////////////////////////////////// + + private final SerializableFunction fn; + private final transient TypeDescriptor outputType; + + private MapElements( + SerializableFunction fn, + TypeDescriptor outputType) { + this.fn = fn; + this.outputType = outputType; + } + + @Override + public PCollection apply(PCollection input) { + return input.apply(ParDo.named("Map").of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(fn.apply(c.element())); + } + })).setTypeDescriptorInternal(outputType); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Max.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Max.java new file mode 100644 index 000000000000..8678e4f33eae --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Max.java @@ -0,0 +1,255 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind; +import com.google.cloud.dataflow.sdk.util.common.CounterProvider; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * {@code PTransform}s for computing the maximum of the elements in a {@code PCollection}, or the + * maximum of the values associated with each key in a {@code PCollection} of {@code KV}s. + * + *

    Example 1: get the maximum of a {@code PCollection} of {@code Double}s. + *

     {@code
+ * PCollection<Double> input = ...;
+ * PCollection<Double> max = input.apply(Max.doublesGlobally());
    + * } 
    + * + *

    Example 2: calculate the maximum of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

     {@code
+ * PCollection<KV<String, Integer>> input = ...;
+ * PCollection<KV<String, Integer>> maxPerKey = input
+ *     .apply(Max.<String>integersPerKey());
    + * } 
    + */ +public class Max { + + private Max() { + // do not instantiate + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a + * {@code PCollection} whose contents is the maximum of the input {@code PCollection}'s + * elements, or {@code Integer.MIN_VALUE} if there are no elements. + */ + public static Combine.Globally integersGlobally() { + return Combine.globally(new MaxIntegerFn()).named("Max.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and + * returns a {@code PCollection>} that contains an output element mapping each + * distinct key in the input {@code PCollection} to the maximum of the values associated with that + * key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey integersPerKey() { + return Combine.perKey(new MaxIntegerFn()).named("Max.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a {@code + * PCollection} whose contents is the maximum of the input {@code PCollection}'s elements, + * or {@code Long.MIN_VALUE} if there are no elements. + */ + public static Combine.Globally longsGlobally() { + return Combine.globally(new MaxLongFn()).named("Max.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns a + * {@code PCollection>} that contains an output element mapping each distinct key in + * the input {@code PCollection} to the maximum of the values associated with that key in the + * input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey longsPerKey() { + return Combine.perKey(new MaxLongFn()).named("Max.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a + * {@code PCollection} whose contents is the maximum of the input {@code PCollection}'s + * elements, or {@code Double.NEGATIVE_INFINITY} if there are no elements. + */ + public static Combine.Globally doublesGlobally() { + return Combine.globally(new MaxDoubleFn()).named("Max.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns + * a {@code PCollection>} that contains an output element mapping each distinct key + * in the input {@code PCollection} to the maximum of the values associated with that key in the + * input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey doublesPerKey() { + return Combine.perKey(new MaxDoubleFn()).named("Max.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a {@code + * PCollection} whose contents is the maximum according to the natural ordering of {@code T} + * of the input {@code PCollection}'s elements, or {@code null} if there are no elements. + */ + public static > + Combine.Globally globally() { + return Combine.globally(MaxFn.naturalOrder()).named("Max.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns a + * {@code PCollection>} that contains an output element mapping each distinct key in the + * input {@code PCollection} to the maximum according to the natural ordering of {@code T} of the + * values associated with that key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static > + Combine.PerKey perKey() { + return Combine.perKey(MaxFn.naturalOrder()).named("Max.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a {@code + * PCollection} whose contents is the maximum of the input {@code PCollection}'s elements, or + * {@code null} if there are no elements. + */ + public static & Serializable> + Combine.Globally globally(ComparatorT comparator) { + return Combine.globally(MaxFn.of(comparator)).named("Max.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns a + * {@code PCollection>} that contains one output element per key mapping each + * to the maximum of the values associated with that key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static & Serializable> + Combine.PerKey perKey(ComparatorT comparator) { + return Combine.perKey(MaxFn.of(comparator)).named("Max.PerKey"); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code CombineFn} that computes the maximum of a collection of elements of type {@code T} + * using an arbitrary {@link Comparator}, useful as an argument to {@link Combine#globally} or + * {@link Combine#perKey}. + * + * @param the type of the values being compared + */ + public static class MaxFn extends BinaryCombineFn { + + private final T identity; + private final Comparator comparator; + + private & Serializable> MaxFn( + T identity, ComparatorT comparator) { + this.identity = identity; + this.comparator = comparator; + } + + public static & Serializable> + MaxFn of(T identity, ComparatorT comparator) { + return new MaxFn(identity, comparator); + } + + public static & Serializable> + MaxFn of(ComparatorT comparator) { + return new MaxFn(null, comparator); + } + + public static > MaxFn naturalOrder(T identity) { + return new MaxFn(identity, new Top.Largest()); + } + + public static > MaxFn naturalOrder() { + return new MaxFn(null, new Top.Largest()); + } + + @Override + public T identity() { + return identity; + } + + @Override + public T apply(T left, T right) { + return comparator.compare(left, right) >= 0 ? left : right; + } + } + + /** + * A {@code CombineFn} that computes the maximum of a collection of {@code Integer}s, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MaxIntegerFn extends MaxFn implements + CounterProvider { + public MaxIntegerFn() { + super(Integer.MIN_VALUE, new Top.Largest()); + } + + @Override + public Counter getCounter(String name) { + return Counter.ints(name, AggregationKind.MAX); + } + } + + /** + * A {@code CombineFn} that computes the maximum of a collection of {@code Long}s, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MaxLongFn extends MaxFn implements + CounterProvider { + public MaxLongFn() { + super(Long.MIN_VALUE, new Top.Largest()); + } + + @Override + public Counter getCounter(String name) { + return Counter.longs(name, AggregationKind.MAX); + } + } + + /** + * A {@code CombineFn} that computes the maximum of a collection of {@code Double}s, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MaxDoubleFn extends MaxFn implements + CounterProvider { + public MaxDoubleFn() { + super(Double.NEGATIVE_INFINITY, new Top.Largest()); + } + + @Override + public Counter getCounter(String name) { + return Counter.doubles(name, AggregationKind.MAX); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Mean.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Mean.java new file mode 100644 index 000000000000..7dccfb626bd6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Mean.java @@ -0,0 +1,202 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn.Accumulator; +import com.google.common.base.MoreObjects; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Objects; + +/** + * {@code PTransform}s for computing the arithmetic mean + * (a.k.a. average) of the elements in a {@code PCollection}, or the + * mean of the values associated with each key in a + * {@code PCollection} of {@code KV}s. + * + *

    Example 1: get the mean of a {@code PCollection} of {@code Long}s. + *

     {@code
+ * PCollection<Long> input = ...;
+ * PCollection<Double> mean = input.apply(Mean.<Long>globally());
    + * } 
    + * + *

    Example 2: calculate the mean of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

     {@code
+ * PCollection<KV<String, Integer>> input = ...;
+ * PCollection<KV<String, Double>> meanPerKey =
+ *     input.apply(Mean.<String, Integer>perKey());
    + * } 
    + */ +public class Mean { + + private Mean() { } // Namespace only + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the mean of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + * + * @param the type of the {@code Number}s being combined + */ + public static Combine.Globally globally() { + return Combine.globally(new MeanFn<>()).named("Mean.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the mean of the values associated with + * that key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and bucketing. + * + * @param the type of the keys + * @param the type of the {@code Number}s being combined + */ + public static Combine.PerKey perKey() { + return Combine.perKey(new MeanFn<>()).named("Mean.PerKey"); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code Combine.CombineFn} that computes the arithmetic mean + * (a.k.a. average) of an {@code Iterable} of numbers of type + * {@code N}, useful as an argument to {@link Combine#globally} or + * {@link Combine#perKey}. + * + *

    Returns {@code Double.NaN} if combining zero elements. + * + * @param the type of the {@code Number}s being combined + */ + static class MeanFn + extends Combine.AccumulatingCombineFn, Double> { + /** + * Constructs a combining function that computes the mean over + * a collection of values of type {@code N}. + */ + public MeanFn() {} + + @Override + public CountSum createAccumulator() { + return new CountSum<>(); + } + + @Override + public Coder> getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return new CountSumCoder<>(); + } + } + + /** + * Accumulator class for {@link MeanFn}. + */ + static class CountSum + implements Accumulator, Double> { + + long count = 0; + double sum = 0.0; + + public CountSum() { + this(0, 0); + } + + public CountSum(long count, double sum) { + this.count = count; + this.sum = sum; + } + + @Override + public void addInput(NumT element) { + count++; + sum += element.doubleValue(); + } + + @Override + public void mergeAccumulator(CountSum accumulator) { + count += accumulator.count; + sum += accumulator.sum; + } + + @Override + public Double extractOutput() { + return count == 0 ? Double.NaN : sum / count; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof CountSum)) { + return false; + } + @SuppressWarnings("unchecked") + CountSum otherCountSum = (CountSum) other; + return (count == otherCountSum.count) + && (sum == otherCountSum.sum); + } + + @Override + public int hashCode() { + return Objects.hash(count, sum); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("count", count) + .add("sum", sum) + .toString(); + } + } + + static class CountSumCoder + extends AtomicCoder> { + private static final Coder LONG_CODER = BigEndianLongCoder.of(); + private static final Coder DOUBLE_CODER = DoubleCoder.of(); + + @Override + public void encode(CountSum value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + LONG_CODER.encode(value.count, outStream, nestedContext); + DOUBLE_CODER.encode(value.sum, outStream, nestedContext); + } + + @Override + public CountSum decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + return new CountSum<>( + LONG_CODER.decode(inStream, nestedContext), + DOUBLE_CODER.decode(inStream, nestedContext)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Min.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Min.java new file mode 100644 index 000000000000..47ab3a0ad27d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Min.java @@ -0,0 +1,255 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind; +import com.google.cloud.dataflow.sdk.util.common.CounterProvider; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * {@code PTransform}s for computing the minimum of the elements in a {@code PCollection}, or the + * minimum of the values associated with each key in a {@code PCollection} of {@code KV}s. + * + *

    Example 1: get the minimum of a {@code PCollection} of {@code Double}s. + *

     {@code
+ * PCollection<Double> input = ...;
+ * PCollection<Double> min = input.apply(Min.doublesGlobally());
    + * } 
    + * + *

    Example 2: calculate the minimum of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

     {@code
+ * PCollection<KV<String, Integer>> input = ...;
+ * PCollection<KV<String, Integer>> minPerKey = input
+ *     .apply(Min.<String>integersPerKey());
    + * } 
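Beyond the numeric helpers, the class also exposes natural-ordering and comparator-based factories (declared further down). A brief sketch using the natural ordering of {@code String}; the input collection is an illustrative assumption:

    PCollection<String> words = ...;
    // The lexicographically smallest word, or null if there are no elements.
    PCollection<String> firstAlphabetically =
        words.apply(Min.<String>globally());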
    + */ +public class Min { + + private Min() { + // do not instantiate + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a + * {@code PCollection} whose contents is a single value that is the minimum of the input + * {@code PCollection}'s elements, or {@code Integer.MAX_VALUE} if there are no elements. + */ + public static Combine.Globally integersGlobally() { + return Combine.globally(new MinIntegerFn()).named("Min.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and + * returns a {@code PCollection>} that contains an output element mapping each + * distinct key in the input {@code PCollection} to the minimum of the values associated with that + * key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey integersPerKey() { + return Combine.perKey(new MinIntegerFn()).named("Min.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a {@code + * PCollection} whose contents is the minimum of the input {@code PCollection}'s elements, + * or {@code Long.MAX_VALUE} if there are no elements. + */ + public static Combine.Globally longsGlobally() { + return Combine.globally(new MinLongFn()).named("Min.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns a + * {@code PCollection>} that contains an output element mapping each distinct key in + * the input {@code PCollection} to the minimum of the values associated with that key in the + * input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey longsPerKey() { + return Combine.perKey(new MinLongFn()).named("Min.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a + * {@code PCollection} whose contents is the minimum of the input {@code PCollection}'s + * elements, or {@code Double.POSITIVE_INFINITY} if there are no elements. + */ + public static Combine.Globally doublesGlobally() { + return Combine.globally(new MinDoubleFn()).named("Min.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns + * a {@code PCollection>} that contains an output element mapping each distinct key + * in the input {@code PCollection} to the minimum of the values associated with that key in the + * input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static Combine.PerKey doublesPerKey() { + return Combine.perKey(new MinDoubleFn()).named("Min.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a {@code + * PCollection} whose contents is the minimum according to the natural ordering of {@code T} + * of the input {@code PCollection}'s elements, or {@code null} if there are no elements. + */ + public static > + Combine.Globally globally() { + return Combine.globally(MinFn.naturalOrder()).named("Min.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns a + * {@code PCollection>} that contains an output element mapping each distinct key in the + * input {@code PCollection} to the minimum according to the natural ordering of {@code T} of the + * values associated with that key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static > + Combine.PerKey perKey() { + return Combine.perKey(MinFn.naturalOrder()).named("Min.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection} and returns a {@code + * PCollection} whose contents is the minimum of the input {@code PCollection}'s elements, or + * {@code null} if there are no elements. + */ + public static & Serializable> + Combine.Globally globally(ComparatorT comparator) { + return Combine.globally(MinFn.of(comparator)).named("Min.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input {@code PCollection>} and returns a + * {@code PCollection>} that contains one output element per key mapping each + * to the minimum of the values associated with that key in the input {@code PCollection}. + * + *

    See {@link Combine.PerKey} for how this affects timestamps and windowing. + */ + public static & Serializable> + Combine.PerKey perKey(ComparatorT comparator) { + return Combine.perKey(MinFn.of(comparator)).named("Min.PerKey"); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code CombineFn} that computes the maximum of a collection of elements of type {@code T} + * using an arbitrary {@link Comparator}, useful as an argument to {@link Combine#globally} or + * {@link Combine#perKey}. + * + * @param the type of the values being compared + */ + public static class MinFn extends BinaryCombineFn { + + private final T identity; + private final Comparator comparator; + + private & Serializable> MinFn( + T identity, ComparatorT comparator) { + this.identity = identity; + this.comparator = comparator; + } + + public static & Serializable> + MinFn of(T identity, ComparatorT comparator) { + return new MinFn(identity, comparator); + } + + public static & Serializable> + MinFn of(ComparatorT comparator) { + return new MinFn(null, comparator); + } + + public static > MinFn naturalOrder(T identity) { + return new MinFn(identity, new Top.Largest()); + } + + public static > MinFn naturalOrder() { + return new MinFn(null, new Top.Largest()); + } + + @Override + public T identity() { + return identity; + } + + @Override + public T apply(T left, T right) { + return comparator.compare(left, right) <= 0 ? left : right; + } + } + + /** + * A {@code CombineFn} that computes the minimum of a collection of {@code Integer}s, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MinIntegerFn extends MinFn implements + CounterProvider { + public MinIntegerFn() { + super(Integer.MAX_VALUE, new Top.Largest()); + } + + @Override + public Counter getCounter(String name) { + return Counter.ints(name, AggregationKind.MIN); + } + } + + /** + * A {@code CombineFn} that computes the minimum of a collection of {@code Long}s, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MinLongFn extends MinFn implements + CounterProvider { + public MinLongFn() { + super(Long.MAX_VALUE, new Top.Largest()); + } + + @Override + public Counter getCounter(String name) { + return Counter.longs(name, AggregationKind.MIN); + } + } + + /** + * A {@code CombineFn} that computes the minimum of a collection of {@code Double}s, useful as an + * argument to {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class MinDoubleFn extends MinFn implements + CounterProvider { + public MinDoubleFn() { + super(Double.POSITIVE_INFINITY, new Top.Largest()); + } + + @Override + public Counter getCounter(String name) { + return Counter.doubles(name, AggregationKind.MIN); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java new file mode 100644 index 000000000000..8a7450997aab --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.TypedPValue; + +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; + +/** + * A {@code PTransform} is an operation that takes an + * {@code InputT} (some subtype of {@link PInput}) and produces an + * {@code OutputT} (some subtype of {@link POutput}). + * + *

    Common PTransforms include root PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read}, + * {@link Create}, processing and + * conversion operations like {@link ParDo}, + * {@link GroupByKey}, + * {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey}, + * {@link Combine}, and {@link Count}, and outputting + * PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Write}. Users also + * define their own application-specific composite PTransforms. + * + *

    Each {@code PTransform} has a single + * {@code InputT} type and a single {@code OutputT} type. Many + * PTransforms conceptually transform one input value to one output + * value, and in this case {@code InputT} and {@code Output} are + * typically instances of + * {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * A root + * PTransform conceptually has no input; in this case, conventionally + * a {@link com.google.cloud.dataflow.sdk.values.PBegin} object + * produced by calling {@link Pipeline#begin} is used as the input. + * An outputting PTransform conceptually has no output; in this case, + * conventionally {@link com.google.cloud.dataflow.sdk.values.PDone} + * is used as its output type. Some PTransforms conceptually have + * multiple inputs and/or outputs; in these cases special "bundling" + * classes like + * {@link com.google.cloud.dataflow.sdk.values.PCollectionList}, + * {@link com.google.cloud.dataflow.sdk.values.PCollectionTuple} + * are used + * to combine multiple values into a single bundle for passing into or + * returning from the PTransform. + * + *

    A {@code PTransform} is invoked by calling + * {@code apply()} on its {@code InputT}, returning its {@code OutputT}. + * Calls can be chained to concisely create linear pipeline segments. + * For example: + * + *

     {@code
+ * PCollection<T1> pc1 = ...;
+ * PCollection<T2> pc2 =
+ *     pc1.apply(ParDo.of(new MyDoFn<T1,KV<K,V>>()))
+ *        .apply(GroupByKey.<K, V>create())
+ *        .apply(Combine.perKey(new MyKeyedCombineFn<K,V>()))
+ *        .apply(ParDo.of(new MyDoFn2<KV<K,Iterable<V>>,T2>()));
    + * } 
    + * + *

    PTransform operations have unique names, which are used by the + * system when explaining what's going on during optimization and + * execution. Each PTransform gets a system-provided default name, + * but it's a good practice to specify an explicit name, where + * possible, using the {@code named()} method offered by some + * PTransforms such as {@link ParDo}. For example: + * + *

     {@code
    + * ...
    + * .apply(ParDo.named("Step1").of(new MyDoFn3()))
    + * ...
    + * } 
    + * + *

    Each PCollection output produced by a PTransform, + * either directly or within a "bundling" class, automatically gets + * its own name derived from the name of its producing PTransform. + * + *

    Each PCollection output produced by a PTransform + * also records a {@link com.google.cloud.dataflow.sdk.coders.Coder} + * that specifies how the elements of that PCollection + * are to be encoded as a byte string, if necessary. The + * PTransform may provide a default Coder for any of its outputs, for + * instance by deriving it from the PTransform input's Coder. If the + * PTransform does not specify the Coder for an output PCollection, + * the system will attempt to infer a Coder for it, based on + * what's known at run-time about the Java type of the output's + * elements. The enclosing {@link Pipeline}'s + * {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry} + * (accessible via {@link Pipeline#getCoderRegistry}) defines the + * mapping from Java types to the default Coder to use, for a standard + * set of Java types; users can extend this mapping for additional + * types, via + * {@link com.google.cloud.dataflow.sdk.coders.CoderRegistry#registerCoder}. + * If this inference process fails, either because the Java type was + * not known at run-time (e.g., due to Java's "erasure" of generic + * types) or there was no default Coder registered, then the Coder + * should be specified manually by calling + * {@link com.google.cloud.dataflow.sdk.values.TypedPValue#setCoder} + * on the output PCollection. The Coder of every output + * PCollection must be determined one way or another + * before that output is used as an input to another PTransform, or + * before the enclosing Pipeline is run. + * + *
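As a sketch of that fallback, suppose inference fails for the output of a hypothetical ParseFn (a DoFn<MyRecord, String>; both names are made up here). The coder can then be set explicitly, using one of the SDK's standard coders such as StringUtf8Coder, before the output is consumed or the Pipeline is run:

    PCollection<String> parsed =
        records.apply(ParDo.of(new ParseFn()));   // records is a hypothetical PCollection<MyRecord>
    parsed.setCoder(StringUtf8Coder.of());        // explicit coder, set before 'parsed' is used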

    A small number of PTransforms are implemented natively by the + * Google Cloud Dataflow SDK; such PTransforms simply return an + * output value as their apply implementation. + * The majority of PTransforms are + * implemented as composites of other PTransforms. Such a PTransform + * subclass typically just implements {@link #apply}, computing its + * Output value from its {@code InputT} value. User programs are encouraged to + * use this mechanism to modularize their own code. Such composite + * abstractions get their own name, and navigating through the + * composition hierarchy of PTransforms is supported by the monitoring + * interface. Examples of composite PTransforms can be found in this + * directory and in examples. From the caller's point of view, there + * is no distinction between a PTransform implemented natively and one + * implemented in terms of other PTransforms; both kinds of PTransform + * are invoked in the same way, using {@code apply()}. + * + *
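A minimal sketch of such a composite (the class and its contents are illustrative, not part of the SDK). It only overrides apply(), delegating to other transforms, and would typically be declared as a static nested class in the user's pipeline code:

    static class ExtractWords
        extends PTransform<PCollection<String>, PCollection<String>> {
      @Override
      public PCollection<String> apply(PCollection<String> lines) {
        // An anonymous DoFn inside a composite's apply() is the common case
        // motivating the serialization note that follows.
        return lines.apply(ParDo.of(new DoFn<String, String>() {
          @Override
          public void processElement(ProcessContext c) {
            for (String word : c.element().split("[^a-zA-Z']+")) {
              c.output(word);
            }
          }
        }));
      }
    }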

    Note on Serialization

    + * + *

    {@code PTransform} doesn't actually support serialization, despite + * implementing {@code Serializable}. + * + *

{@code PTransform} is marked {@code Serializable} solely + * because it is common for an anonymous {@code DoFn} + * instance to be created within an + * {@code apply()} method of a composite {@code PTransform}. + *

    Each of those {@code *Fn}s is {@code Serializable}, but + * unfortunately its instance state will contain a reference to the + * enclosing {@code PTransform} instance, and so attempt to serialize + * the {@code PTransform} instance, even though the {@code *Fn} + * instance never references anything about the enclosing + * {@code PTransform}. + * + *

    To allow such anonymous {@code *Fn}s to be written + * conveniently, {@code PTransform} is marked as {@code Serializable}, + * and includes dummy {@code writeObject()} and {@code readObject()} + * operations that do not save or restore any state. + * + * @see Applying Transformations + * + * @param the type of the input to this PTransform + * @param the type of the output of this PTransform + */ +public abstract class PTransform + implements Serializable /* See the note above */ { + /** + * Applies this {@code PTransform} on the given {@code InputT}, and returns its + * {@code Output}. + * + *

    Composite transforms, which are defined in terms of other transforms, + * should return the output of one of the composed transforms. Non-composite + * transforms, which do not apply any transforms internally, should return + * a new unbound output and register evaluators (via backend-specific + * registration methods). + * + *

    The default implementation throws an exception. A derived class must + * either implement apply, or else each runner must supply a custom + * implementation via + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner#apply}. + */ + public OutputT apply(InputT input) { + throw new IllegalArgumentException( + "Runner " + input.getPipeline().getRunner() + + " has not registered an implementation for the required primitive operation " + + this); + } + + /** + * Called before invoking apply (which may be intercepted by the runner) to + * verify this transform is fully specified and applicable to the specified + * input. + * + *

    By default, does nothing. + */ + public void validate(InputT input) { } + + /** + * Returns the transform name. + * + *

    This name is provided by the transform creator and is not required to be unique. + */ + public String getName() { + return name != null ? name : getKindString(); + } + + ///////////////////////////////////////////////////////////////////////////// + + // See the note about about PTransform's fake Serializability, to + // understand why all of its instance state is transient. + + /** + * The base name of this {@code PTransform}, e.g., from + * {@link ParDo#named(String)}, or from defaults, or {@code null} if not + * yet assigned. + */ + protected final transient String name; + + protected PTransform() { + this.name = null; + } + + protected PTransform(String name) { + this.name = name; + } + + @Override + public String toString() { + if (name == null) { + return getKindString(); + } else { + return getName() + " [" + getKindString() + "]"; + } + } + + /** + * Returns the name to use by default for this {@code PTransform} + * (not including the names of any enclosing {@code PTransform}s). + * + *

    By default, returns the base name of this {@code PTransform}'s class. + * + *

    The caller is responsible for ensuring that names of applied + * {@code PTransform}s are unique, e.g., by adding a uniquifying + * suffix when needed. + */ + protected String getKindString() { + if (getClass().isAnonymousClass()) { + return "AnonymousTransform"; + } else { + return StringUtils.approximatePTransformName(getClass()); + } + } + + private void writeObject(ObjectOutputStream oos) { + // We don't really want to be serializing this object, but we + // often have serializable anonymous DoFns nested within a + // PTransform. + } + + private void readObject(ObjectInputStream oos) { + // We don't really want to be serializing this object, but we + // often have serializable anonymous DoFns nested within a + // PTransform. + } + + /** + * Returns the default {@code Coder} to use for the output of this + * single-output {@code PTransform}. + * + *

    By default, always throws + * + * @throws CannotProvideCoderException if no coder can be inferred + */ + protected Coder getDefaultOutputCoder() throws CannotProvideCoderException { + throw new CannotProvideCoderException( + "PTransform.getDefaultOutputCoder called."); + } + + /** + * Returns the default {@code Coder} to use for the output of this + * single-output {@code PTransform} when applied to the given input. + * + * @throws CannotProvideCoderException if none can be inferred. + * + *

    By default, always throws. + */ + protected Coder getDefaultOutputCoder(@SuppressWarnings("unused") InputT input) + throws CannotProvideCoderException { + return getDefaultOutputCoder(); + } + + /** + * Returns the default {@code Coder} to use for the given output of + * this single-output {@code PTransform} when applied to the given input. + * + * @throws CannotProvideCoderException if none can be inferred. + * + *

    By default, always throws. + */ + public Coder getDefaultOutputCoder( + InputT input, @SuppressWarnings("unused") TypedPValue output) + throws CannotProvideCoderException { + @SuppressWarnings("unchecked") + Coder defaultOutputCoder = (Coder) getDefaultOutputCoder(input); + return defaultOutputCoder; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java new file mode 100644 index 000000000000..0922767adc82 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java @@ -0,0 +1,1308 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext; +import com.google.cloud.dataflow.sdk.util.DirectSideInputReader; +import com.google.cloud.dataflow.sdk.util.DoFnRunner; +import com.google.cloud.dataflow.sdk.util.DoFnRunnerBase; +import com.google.cloud.dataflow.sdk.util.DoFnRunners; +import com.google.cloud.dataflow.sdk.util.IllegalMutationException; +import com.google.cloud.dataflow.sdk.util.MutationDetector; +import com.google.cloud.dataflow.sdk.util.MutationDetectors; +import com.google.cloud.dataflow.sdk.util.PTuple; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.StringUtils; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nullable; + +/** + * {@link ParDo} is the core element-wise transform in Google Cloud + * Dataflow, invoking a user-specified function on each of the elements of the input + * {@link PCollection} to produce zero or more output elements, all + * of which are collected into 
the output {@link PCollection}. + * + *

    Elements are processed independently, and possibly in parallel across + * distributed cloud resources. + * + *

    The {@link ParDo} processing style is similar to what happens inside + * the "Mapper" or "Reducer" class of a MapReduce-style algorithm. + * + *

    {@link DoFn DoFns}

    + * + *

    The function to use to process each element is specified by a + * {@link DoFn DoFn<InputT, OutputT>}, primarily via its + * {@link DoFn#processElement processElement} method. The {@link DoFn} may also + * override the default implementations of {@link DoFn#startBundle startBundle} + * and {@link DoFn#finishBundle finishBundle}. + * + *
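A skeletal sketch (illustrative names only) of a DoFn that overrides all three of these methods:

    static class ComputeLengths extends DoFn<String, Integer> {
      @Override
      public void startBundle(Context c) {
        // Optional per-bundle setup.
      }

      @Override
      public void processElement(ProcessContext c) {
        c.output(c.element().length());
      }

      @Override
      public void finishBundle(Context c) {
        // Optional per-bundle cleanup.
      }
    }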

    Conceptually, when a {@link ParDo} transform is executed, the + * elements of the input {@link PCollection} are first divided up + * into some number of "bundles". These are farmed off to distributed + * worker machines (or run locally, if using the {@link DirectPipelineRunner}). + * For each bundle of input elements processing proceeds as follows: + * + *

      + *
1. A fresh instance of the argument {@link DoFn} is created on a worker. This may + * be through deserialization or other means. If the {@link DoFn} subclass + * does not override {@link DoFn#startBundle startBundle} or + * {@link DoFn#finishBundle finishBundle} then this may be optimized since + * it cannot observe the start and end of a bundle.
2. The {@link DoFn DoFn's} {@link DoFn#startBundle} method is called to + * initialize it. If this method is not overridden, the call may be optimized + * away.
3. The {@link DoFn DoFn's} {@link DoFn#processElement} method + * is called on each of the input elements in the bundle.
4. The {@link DoFn DoFn's} {@link DoFn#finishBundle} method is called + * to complete its work. After {@link DoFn#finishBundle} is called, the + * framework will never again invoke any of these three processing methods. + * If this method is not overridden, this call may be optimized away.
+ * + * Each of the calls to any of the {@link DoFn DoFn's} processing + * methods can produce zero or more output elements. All of the + * output elements from all of the {@link DoFn} instances + * are included in the output {@link PCollection}. + *

    For example: + * + *

     {@code
+ * PCollection<String> lines = ...;
+ * PCollection<String> words =
+ *     lines.apply(ParDo.of(new DoFn<String, String>() {
    + *         public void processElement(ProcessContext c) {
    + *           String line = c.element();
    + *           for (String word : line.split("[^a-zA-Z']+")) {
    + *             c.output(word);
    + *           }
    + *         }}));
+ * PCollection<Integer> wordLengths =
+ *     words.apply(ParDo.of(new DoFn<String, Integer>() {
    + *         public void processElement(ProcessContext c) {
    + *           String word = c.element();
    + *           Integer length = word.length();
    + *           c.output(length);
    + *         }}));
    + * } 
    + * + *

    Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same {@link WindowFn} associated with it as the input. + * + *

    Naming {@link ParDo ParDo} transforms

    + * + *

The name of a transform is used to provide a name for any node in the + * {@link Pipeline} graph resulting from application of the transform. + * It is best practice to provide a name at the time of application, + * via {@link PCollection#apply(String, PTransform)}. Otherwise, + * a unique name - which may not be stable across pipeline revisions - + * will be generated, based on the transform name. + *

    If a {@link ParDo} is applied exactly once inlined, then + * it can be given a name via {@link #named}. For example: + * + *

     {@code
+ * PCollection<String> words =
+ *     lines.apply(ParDo.named("ExtractWords")
+ *                      .of(new DoFn<String, String>() { ... }));
+ * PCollection<Integer> wordLengths =
+ *     words.apply(ParDo.named("ComputeWordLengths")
+ *                      .of(new DoFn<String, Integer>() { ... }));
    + * } 
    + * + *
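For the recommended style of naming at application time, a sketch in the same fragment style as the example above, reusing the hypothetical {@code lines} collection:

    PCollection<String> words =
        lines.apply("ExtractWords",
                    ParDo.of(new DoFn<String, String>() { ... }));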

    Side Inputs

    + * + *

    While a {@link ParDo} processes elements from a single "main input" + * {@link PCollection}, it can take additional "side input" + * {@link PCollectionView PCollectionViews}. These side input + * {@link PCollectionView PCollectionViews} express styles of accessing + * {@link PCollection PCollections} computed by earlier pipeline operations, + * passed in to the {@link ParDo} transform using + * {@link #withSideInputs}, and their contents accessible to each of + * the {@link DoFn} operations via {@link DoFn.ProcessContext#sideInput sideInput}. + * For example: + * + *

     {@code
+ * PCollection<String> words = ...;
+ * PCollection<Integer> maxWordLengthCutOff = ...; // Singleton PCollection
+ * final PCollectionView<Integer> maxWordLengthCutOffView =
+ *     maxWordLengthCutOff.apply(View.<Integer>asSingleton());
+ * PCollection<String> wordsBelowCutOff =
+ *     words.apply(ParDo.withSideInputs(maxWordLengthCutOffView)
+ *                      .of(new DoFn<String, String>() {
    + *         public void processElement(ProcessContext c) {
    + *           String word = c.element();
    + *           int lengthCutOff = c.sideInput(maxWordLengthCutOffView);
    + *           if (word.length() <= lengthCutOff) {
    + *             c.output(word);
    + *           }
    + *         }}));
    + * } 
    + * + *

    Side Outputs

    + * + *

    Optionally, a {@link ParDo} transform can produce multiple + * output {@link PCollection PCollections}, both a "main output" + * {@code PCollection} plus any number of "side output" + * {@link PCollection PCollections}, each keyed by a distinct {@link TupleTag}, + * and bundled in a {@link PCollectionTuple}. The {@link TupleTag TupleTags} + * to be used for the output {@link PCollectionTuple} are specified by + * invoking {@link #withOutputTags}. Unconsumed side outputs do not + * necessarily need to be explicitly specified, even if the {@link DoFn} + * generates them. Within the {@link DoFn}, an element is added to the + * main output {@link PCollection} as normal, using + * {@link DoFn.Context#output}, while an element is added to a side output + * {@link PCollection} using {@link DoFn.Context#sideOutput}. For example: + * + *

     {@code
+ * PCollection<String> words = ...;
    + * // Select words whose length is below a cut off,
    + * // plus the lengths of words that are above the cut off.
    + * // Also select words starting with "MARKER".
    + * final int wordLengthCutOff = 10;
    + * // Create tags to use for the main and side outputs.
+ * final TupleTag<String> wordsBelowCutOffTag =
+ *     new TupleTag<String>(){};
+ * final TupleTag<Integer> wordLengthsAboveCutOffTag =
+ *     new TupleTag<Integer>(){};
+ * final TupleTag<String> markedWordsTag =
+ *     new TupleTag<String>(){};
    + * PCollectionTuple results =
    + *     words.apply(
    + *         ParDo
    + *         // Specify the main and consumed side output tags of the
    + *         // PCollectionTuple result:
    + *         .withOutputTags(wordsBelowCutOffTag,
    + *                         TupleTagList.of(wordLengthsAboveCutOffTag)
    + *                                     .and(markedWordsTag))
+ *         .of(new DoFn<String, String>() {
+ *             // Create a tag for the unconsumed side output.
+ *             final TupleTag<String> specialWordsTag =
+ *                 new TupleTag<String>(){};
    + *             public void processElement(ProcessContext c) {
    + *               String word = c.element();
    + *               if (word.length() <= wordLengthCutOff) {
    + *                 // Emit this short word to the main output.
    + *                 c.output(word);
    + *               } else {
    + *                 // Emit this long word's length to a side output.
    + *                 c.sideOutput(wordLengthsAboveCutOffTag, word.length());
    + *               }
    + *               if (word.startsWith("MARKER")) {
    + *                 // Emit this word to a different side output.
    + *                 c.sideOutput(markedWordsTag, word);
    + *               }
    + *               if (word.startsWith("SPECIAL")) {
    + *                 // Emit this word to the unconsumed side output.
    + *                 c.sideOutput(specialWordsTag, word);
    + *               }
    + *             }}));
    + * // Extract the PCollection results, by tag.
+ * PCollection<String> wordsBelowCutOff =
+ *     results.get(wordsBelowCutOffTag);
+ * PCollection<Integer> wordLengthsAboveCutOff =
+ *     results.get(wordLengthsAboveCutOffTag);
+ * PCollection<String> markedWords =
    + *     results.get(markedWordsTag);
    + * } 
    + * + *

    Properties May Be Specified In Any Order

    + * + *

    Several properties can be specified for a {@link ParDo} + * {@link PTransform}, including name, side inputs, side output tags, + * and {@link DoFn} to invoke. Only the {@link DoFn} is required; the + * name is encouraged but not required, and side inputs and side + * output tags are only specified when they're needed. These + * properties can be specified in any order, as long as they're + * specified before the {@link ParDo} {@link PTransform} is applied. + * + *

    The approach used to allow these properties to be specified in + * any order, with some properties omitted, is to have each of the + * property "setter" methods defined as static factory methods on + * {@link ParDo} itself, which return an instance of either + * {@link ParDo.Unbound} or + * {@link ParDo.Bound} nested classes, each of which offer + * property setter instance methods to enable setting additional + * properties. {@link ParDo.Bound} is used for {@link ParDo} + * transforms whose {@link DoFn} is specified and whose input and + * output static types have been bound. {@link ParDo.Unbound ParDo.Unbound} is used + * for {@link ParDo} transforms that have not yet had their + * {@link DoFn} specified. Only {@link ParDo.Bound} instances can be + * applied. + * + *

    Another benefit of this approach is that it reduces the number + * of type parameters that need to be specified manually. In + * particular, the input and output types of the {@link ParDo} + * {@link PTransform} are inferred automatically from the type + * parameters of the {@link DoFn} argument passed to {@link ParDo#of}. + * + *
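A short sketch of that flexibility, assuming a {@code DoFn<String, String> fn} and the {@code maxWordLengthCutOffView} side input from the earlier example; both orderings yield an applicable ParDo.Bound once of() has supplied the DoFn:

    ParDo.Bound<String, String> p1 =
        ParDo.named("FilterByLength")
             .withSideInputs(maxWordLengthCutOffView)
             .of(fn);

    ParDo.Bound<String, String> p2 =
        ParDo.withSideInputs(maxWordLengthCutOffView)
             .named("FilterByLength")
             .of(fn);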

    Output Coders

    + * + *

    By default, the {@link Coder Coder<OutputT>} for the + * elements of the main output {@link PCollection PCollection<OutputT>} is + * inferred from the concrete type of the {@link DoFn DoFn<InputT, OutputT>}. + * + *

    By default, the {@link Coder Coder<SideOutputT>} for the elements of + * a side output {@link PCollection PCollection<SideOutputT>} is inferred + * from the concrete type of the corresponding {@link TupleTag TupleTag<SideOutputT>}. + * To be successful, the {@link TupleTag} should be created as an instance + * of a trivial anonymous subclass, with {@code {}} suffixed to the + * constructor call. Such uses block Java's generic type parameter + * inference, so the {@code } argument must be provided explicitly. + * For example: + *

     {@code
    + * // A TupleTag to use for a side input can be written concisely:
+ * final TupleTag<SideInputT> sideInputTag = new TupleTag<>();
    + * // A TupleTag to use for a side output should be written with "{}",
    + * // and explicit generic parameter type:
+ * final TupleTag<SideOutputT> sideOutputTag = new TupleTag<SideOutputT>(){};
    + * } 
    + * This style of {@code TupleTag} instantiation is used in the example of + * multiple side outputs, above. + * + *

    Serializability of {@link DoFn DoFns}

    + * + *

    A {@link DoFn} passed to a {@link ParDo} transform must be + * {@link Serializable}. This allows the {@link DoFn} instance + * created in this "main program" to be sent (in serialized form) to + * remote worker machines and reconstituted for each bundles of elements + * of the input {@link PCollection} being processed. A {@link DoFn} + * can have instance variable state, and non-transient instance + * variable state will be serialized in the main program and then + * deserialized on remote worker machines for each bundle of elements + * to process. + * + *

    To aid in ensuring that {@link DoFn DoFns} are properly + * {@link Serializable}, even local execution using the + * {@link DirectPipelineRunner} will serialize and then deserialize + * {@link DoFn DoFns} before executing them on a bundle. + * + *

    {@link DoFn DoFns} expressed as anonymous inner classes can be + * convenient, but due to a quirk in Java's rules for serializability, + * non-static inner or nested classes (including anonymous inner + * classes) automatically capture their enclosing class's instance in + * their serialized state. This can lead to including much more than + * intended in the serialized state of a {@link DoFn}, or even things + * that aren't {@link Serializable}. + * + *

    There are two ways to avoid unintended serialized state in a + * {@link DoFn}: + * + *

      + * + *
    • Define the {@link DoFn} as a named, static class. + * + *
    • Define the {@link DoFn} as an anonymous inner class inside of + * a static method. + * + *
    + * + *

    Both of these approaches ensure that there is no implicit enclosing + * instance serialized along with the {@link DoFn} instance. + * + *

    Prior to Java 8, any local variables of the enclosing + * method referenced from within an anonymous inner class need to be + * marked as {@code final}. If defining the {@link DoFn} as a named + * static class, such variables would be passed as explicit + * constructor arguments and stored in explicit instance variables. + * + *
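A sketch of that named, static style, with the configuration passed through the constructor (the class, field, and cutoff value are illustrative):

    static class FilterLongerThan extends DoFn<String, String> {
      private final int cutOff;   // serialized with the DoFn and restored on each worker

      FilterLongerThan(int cutOff) {
        this.cutOff = cutOff;
      }

      @Override
      public void processElement(ProcessContext c) {
        if (c.element().length() > cutOff) {
          c.output(c.element());
        }
      }
    }

    // Applied like any other DoFn:
    PCollection<String> longWords =
        words.apply(ParDo.of(new FilterLongerThan(10)));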

    There are three main ways to initialize the state of a + * {@link DoFn} instance processing a bundle: + * + *

      + * + *
    • Define instance variable state (including implicit instance + * variables holding final variables captured by an anonymous inner + * class), initialized by the {@link DoFn}'s constructor (which is + * implicit for an anonymous inner class). This state will be + * automatically serialized and then deserialized in the {@code DoFn} + * instance created for each bundle. This method is good for state + * known when the original {@code DoFn} is created in the main + * program, if it's not overly large. + * + *
    • Compute the state as a singleton {@link PCollection} and pass it + * in as a side input to the {@link DoFn}. This is good if the state + * needs to be computed by the pipeline, or if the state is very large + * and so is best read from file(s) rather than sent as part of the + * {@code DoFn}'s serialized state. + * + *
• Initialize the state in each {@link DoFn} instance, in + * {@link DoFn#startBundle}. This is good if the initialization + * doesn't depend on any information known only by the main program or + * computed by earlier pipeline operations, but is the same for all + * instances of this {@link DoFn} for all program executions, say + * setting up empty caches or initializing constant data; a sketch of + * this style follows the list. + *
    + * + *
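A sketch of the third approach above: per-bundle state rebuilt in startBundle rather than serialized. The cache is illustrative, and java.util.Map and java.util.HashMap are assumed to be imported:

    static class NormalizeWords extends DoFn<String, String> {
      // Transient: rebuilt for every bundle, never serialized with the DoFn.
      private transient Map<String, String> cache;

      @Override
      public void startBundle(Context c) {
        cache = new HashMap<>();   // same empty state for every execution
      }

      @Override
      public void processElement(ProcessContext c) {
        String word = c.element();
        String normalized = cache.get(word);
        if (normalized == null) {
          normalized = word.toLowerCase();
          cache.put(word, normalized);
        }
        c.output(normalized);
      }
    }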

    No Global Shared State

    + * + *

    {@link ParDo} operations are intended to be able to run in + * parallel across multiple worker machines. This precludes easy + * sharing and updating mutable state across those machines. There is + * no support in the Google Cloud Dataflow system for communicating + * and synchronizing updates to shared state across worker machines, + * so programs should not access any mutable static variable state in + * their {@link DoFn}, without understanding that the Java processes + * for the main program and workers will each have its own independent + * copy of such state, and there won't be any automatic copying of + * that state across Java processes. All information should be + * communicated to {@link DoFn} instances via main and side inputs and + * serialized state, and all output should be communicated from a + * {@link DoFn} instance via main and side outputs, in the absence of + * external communication mechanisms written by user code. + * + *

    Fault Tolerance

    + * + *

    In a distributed system, things can fail: machines can crash, + * machines can be unable to communicate across the network, etc. + * While individual failures are rare, the larger the job, the greater + * the chance that something, somewhere, will fail. The Google Cloud + * Dataflow service strives to mask such failures automatically, + * principally by retrying failed {@link DoFn} bundle. This means + * that a {@code DoFn} instance might process a bundle partially, then + * crash for some reason, then be rerun (often on a different worker + * machine) on that same bundle and on the same elements as before. + * Sometimes two or more {@link DoFn} instances will be running on the + * same bundle simultaneously, with the system taking the results of + * the first instance to complete successfully. Consequently, the + * code in a {@link DoFn} needs to be written such that these + * duplicate (sequential or concurrent) executions do not cause + * problems. If the outputs of a {@link DoFn} are a pure function of + * its inputs, then this requirement is satisfied. However, if a + * {@link DoFn DoFn's} execution has external side-effects, such as performing + * updates to external HTTP services, then the {@link DoFn DoFn's} code + * needs to take care to ensure that those updates are idempotent and + * that concurrent updates are acceptable. This property can be + * difficult to achieve, so it is advisable to strive to keep + * {@link DoFn DoFns} as pure functions as much as possible. + * + *

    Optimization

    + * + *

    The Google Cloud Dataflow service automatically optimizes a + * pipeline before it is executed. A key optimization, fusion, + * relates to {@link ParDo} operations. If one {@link ParDo} operation produces a + * {@link PCollection} that is then consumed as the main input of another + * {@link ParDo} operation, the two {@link ParDo} operations will be fused + * together into a single ParDo operation and run in a single pass; + * this is "producer-consumer fusion". Similarly, if + * two or more ParDo operations have the same {@link PCollection} main input, + * they will be fused into a single {@link ParDo} that makes just one pass + * over the input {@link PCollection}; this is "sibling fusion". + * + *

    If after fusion there are no more unfused references to a + * {@link PCollection} (e.g., one between a producer ParDo and a consumer + * {@link ParDo}), the {@link PCollection} itself is "fused away" and won't ever be + * written to disk, saving all the I/O and space expense of + * constructing it. + * + *

    The Google Cloud Dataflow service applies fusion as much as + * possible, greatly reducing the cost of executing pipelines. As a + * result, it is essentially "free" to write {@link ParDo} operations in a + * very modular, composable style, each {@link ParDo} operation doing one + * clear task, and stringing together sequences of {@link ParDo} operations to + * get the desired overall effect. Such programs can be easier to + * understand, easier to unit-test, easier to extend and evolve, and + * easier to reuse in new programs. The predefined library of + * PTransforms that come with Google Cloud Dataflow makes heavy use of + * this modular, composable style, trusting to the Google Cloud + * Dataflow service's optimizer to "flatten out" all the compositions + * into highly optimized stages. + * + * @see the web + * documentation for ParDo + */ +public class ParDo { + + /** + * Creates a {@link ParDo} {@link PTransform} with the given name. + * + *

    See the discussion of naming above for more explanation. + * + *

    The resulting {@link PTransform} is incomplete, and its + * input/output types are not yet bound. Use + * {@link ParDo.Unbound#of} to specify the {@link DoFn} to + * invoke, which will also bind the input/output types of this + * {@link PTransform}. + */ + public static Unbound named(String name) { + return new Unbound().named(name); + } + + /** + * Creates a {@link ParDo} {@link PTransform} with the given + * side inputs. + * + *

    Side inputs are {@link PCollectionView PCollectionViews}, whose contents are + * computed during pipeline execution and then made accessible to + * {@link DoFn} code via {@link DoFn.ProcessContext#sideInput sideInput}. Each + * invocation of the {@link DoFn} receives the same values for these + * side inputs. + * + *

    See the discussion of Side Inputs above for more explanation. + * + *

    The resulting {@link PTransform} is incomplete, and its + * input/output types are not yet bound. Use + * {@link ParDo.Unbound#of} to specify the {@link DoFn} to + * invoke, which will also bind the input/output types of this + * {@link PTransform}. + */ + public static Unbound withSideInputs(PCollectionView... sideInputs) { + return new Unbound().withSideInputs(sideInputs); + } + + /** + * Creates a {@link ParDo} with the given side inputs. + * + *

    Side inputs are {@link PCollectionView}s, whose contents are + * computed during pipeline execution and then made accessible to + * {@code DoFn} code via {@link DoFn.ProcessContext#sideInput sideInput}. + * + *

    See the discussion of Side Inputs above for more explanation. + * + *

    The resulting {@link PTransform} is incomplete, and its + * input/output types are not yet bound. Use + * {@link ParDo.Unbound#of} to specify the {@link DoFn} to + * invoke, which will also bind the input/output types of this + * {@link PTransform}. + */ + public static Unbound withSideInputs( + Iterable> sideInputs) { + return new Unbound().withSideInputs(sideInputs); + } + + /** + * Creates a multi-output {@link ParDo} {@link PTransform} whose + * output {@link PCollection}s will be referenced using the given main + * output and side output tags. + * + *

{@link TupleTag TupleTags} are used to name (with their static element + * type {@code T}) each main and side output {@code PCollection}. + * This {@link PTransform PTransform's} {@link DoFn} emits elements to the main + * output {@link PCollection} as normal, using + * {@link DoFn.Context#output}. The {@link DoFn} emits elements to + * a side output {@code PCollection} using + * {@link DoFn.Context#sideOutput}, passing that side output's tag + * as an argument. The result of invoking this {@link PTransform} + * will be a {@link PCollectionTuple}, and any of the main and + * side output {@code PCollection}s can be retrieved from it via + * {@link PCollectionTuple#get}, passing the output's tag as an + * argument. + *

    See the discussion of Side Outputs above for more explanation. + * + *

    The resulting {@link PTransform} is incomplete, and its input + * type is not yet bound. Use {@link ParDo.UnboundMulti#of} + * to specify the {@link DoFn} to invoke, which will also bind the + * input type of this {@link PTransform}. + */ + public static UnboundMulti withOutputTags( + TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + return new Unbound().withOutputTags(mainOutputTag, sideOutputTags); + } + + /** + * Creates a {@link ParDo} {@link PTransform} that will invoke the + * given {@link DoFn} function. + * + *

    The resulting {@link PTransform PTransform's} types have been bound, with the + * input being a {@code PCollection} and the output a + * {@code PCollection}, inferred from the types of the argument + * {@code DoFn}. It is ready to be applied, or further + * properties can be set on it first. + */ + public static Bound of(DoFn fn) { + return new Unbound().of(fn); + } + + private static DoFn + adapt(DoFnWithContext fn) { + return DoFnReflector.of(fn.getClass()).toDoFn(fn); + } + + /** + * Creates a {@link ParDo} {@link PTransform} that will invoke the + * given {@link DoFnWithContext} function. + * + *

    The resulting {@link PTransform PTransform's} types have been bound, with the + * input being a {@code PCollection} and the output a + * {@code PCollection}, inferred from the types of the argument + * {@code DoFn}. It is ready to be applied, or further + * properties can be set on it first. + * + *

    {@link DoFnWithContext} is an experimental alternative to + * {@link DoFn} which simplifies accessing the window of the element. + */ + @Experimental + public static Bound of(DoFnWithContext fn) { + return of(adapt(fn)); + } + + /** + * An incomplete {@link ParDo} transform, with unbound input/output types. + * + *

    Before being applied, {@link ParDo.Unbound#of} must be + * invoked to specify the {@link DoFn} to invoke, which will also + * bind the input/output types of this {@link PTransform}. + */ + public static class Unbound { + private final String name; + private final List> sideInputs; + + Unbound() { + this(null, ImmutableList.>of()); + } + + Unbound(String name, List> sideInputs) { + this.name = name; + this.sideInputs = sideInputs; + } + + /** + * Returns a new {@link ParDo} transform that's like this + * transform but with the specified name. Does not modify this + * transform. The resulting transform is still incomplete. + * + *

    See the discussion of naming above for more explanation. + */ + public Unbound named(String name) { + return new Unbound(name, sideInputs); + } + + /** + * Returns a new {@link ParDo} transform that's like this + * transform but with the specified additional side inputs. + * Does not modify this transform. The resulting transform is + * still incomplete. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Unbound withSideInputs(PCollectionView... sideInputs) { + return withSideInputs(Arrays.asList(sideInputs)); + } + + /** + * Returns a new {@link ParDo} transform that is like this + * transform but with the specified additional side inputs. Does not modify + * this transform. The resulting transform is still incomplete. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Unbound withSideInputs( + Iterable> sideInputs) { + ImmutableList.Builder> builder = ImmutableList.builder(); + builder.addAll(this.sideInputs); + builder.addAll(sideInputs); + return new Unbound(name, builder.build()); + } + + /** + * Returns a new multi-output {@link ParDo} transform that's like + * this transform but with the specified main and side output + * tags. Does not modify this transform. The resulting transform + * is still incomplete. + * + *

    See the discussion of Side Outputs above and on + * {@link ParDo#withOutputTags} for more explanation. + */ + public UnboundMulti withOutputTags(TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + return new UnboundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags); + } + + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this + * transform but that will invoke the given {@link DoFn} + * function, and that has its input and output types bound. Does + * not modify this transform. The resulting {@link PTransform} is + * sufficiently specified to be applied, but more properties can + * still be specified. + */ + public Bound of(DoFn fn) { + return new Bound<>(name, sideInputs, fn); + } + + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this + * transform but which will invoke the given {@link DoFnWithContext} + * function, and which has its input and output types bound. Does + * not modify this transform. The resulting {@link PTransform} is + * sufficiently specified to be applied, but more properties can + * still be specified. + */ + public Bound of(DoFnWithContext fn) { + return of(adapt(fn)); + } + } + + /** + * A {@link PTransform} that, when applied to a {@code PCollection}, + * invokes a user-specified {@code DoFn} on all its elements, + * with all its outputs collected into an output + * {@code PCollection}. + * + *

    A multi-output form of this transform can be created with + * {@link ParDo.Bound#withOutputTags}. + * + * @param the type of the (main) input {@link PCollection} elements + * @param the type of the (main) output {@link PCollection} elements + */ + public static class Bound + extends PTransform, PCollection> { + // Inherits name. + private final List> sideInputs; + private final DoFn fn; + + Bound(String name, + List> sideInputs, + DoFn fn) { + super(name); + this.sideInputs = sideInputs; + this.fn = SerializableUtils.clone(fn); + } + + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this + * {@link PTransform} but with the specified name. Does not + * modify this {@link PTransform}. + * + *

    See the discussion of Naming above for more explanation. + */ + public Bound named(String name) { + return new Bound<>(name, sideInputs, fn); + } + + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this + * {@link PTransform} but with the specified additional side inputs. Does not + * modify this {@link PTransform}. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Bound withSideInputs(PCollectionView... sideInputs) { + return withSideInputs(Arrays.asList(sideInputs)); + } + + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this + * {@link PTransform} but with the specified additional side inputs. Does not + * modify this {@link PTransform}. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public Bound withSideInputs( + Iterable> sideInputs) { + ImmutableList.Builder> builder = ImmutableList.builder(); + builder.addAll(this.sideInputs); + builder.addAll(sideInputs); + return new Bound<>(name, builder.build(), fn); + } + + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} + * that's like this {@link PTransform} but with the specified main + * and side output tags. Does not modify this {@link PTransform}. + * + *

    See the discussion of Side Outputs above and on + * {@link ParDo#withOutputTags} for more explanation. + */ + public BoundMulti withOutputTags(TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + return new BoundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags, fn); + } + + @Override + public PCollection apply(PCollection input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), + input.getWindowingStrategy(), + input.isBounded()) + .setTypeDescriptorInternal(fn.getOutputTypeDescriptor()); + } + + @Override + @SuppressWarnings("unchecked") + protected Coder getDefaultOutputCoder(PCollection input) + throws CannotProvideCoderException { + return input.getPipeline().getCoderRegistry().getDefaultCoder( + fn.getOutputTypeDescriptor(), + fn.getInputTypeDescriptor(), + ((PCollection) input).getCoder()); + } + + @Override + protected String getKindString() { + Class clazz = DoFnReflector.getDoFnClass(fn); + if (clazz.isAnonymousClass()) { + return "AnonymousParDo"; + } else { + return String.format("ParDo(%s)", StringUtils.approximateSimpleName(clazz)); + } + } + + public DoFn getFn() { + return fn; + } + + public List> getSideInputs() { + return sideInputs; + } + } + + /** + * An incomplete multi-output {@link ParDo} transform, with unbound + * input type. + * + *

    Before being applied, {@link ParDo.UnboundMulti#of} must be + * invoked to specify the {@link DoFn} to invoke, which will also + * bind the input type of this {@link PTransform}. + * + * @param the type of the main output {@code PCollection} elements + */ + public static class UnboundMulti { + private final String name; + private final List> sideInputs; + private final TupleTag mainOutputTag; + private final TupleTagList sideOutputTags; + + UnboundMulti(String name, + List> sideInputs, + TupleTag mainOutputTag, + TupleTagList sideOutputTags) { + this.name = name; + this.sideInputs = sideInputs; + this.mainOutputTag = mainOutputTag; + this.sideOutputTags = sideOutputTags; + } + + /** + * Returns a new multi-output {@link ParDo} transform that's like + * this transform but with the specified name. Does not modify + * this transform. The resulting transform is still incomplete. + * + *

    See the discussion of Naming above for more explanation. + */ + public UnboundMulti named(String name) { + return new UnboundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags); + } + + /** + * Returns a new multi-output {@link ParDo} transform that's like + * this transform but with the specified side inputs. Does not + * modify this transform. The resulting transform is still + * incomplete. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public UnboundMulti withSideInputs( + PCollectionView... sideInputs) { + return withSideInputs(Arrays.asList(sideInputs)); + } + + /** + * Returns a new multi-output {@link ParDo} transform that's like + * this transform but with the specified additional side inputs. Does not + * modify this transform. The resulting transform is still + * incomplete. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public UnboundMulti withSideInputs( + Iterable> sideInputs) { + ImmutableList.Builder> builder = ImmutableList.builder(); + builder.addAll(this.sideInputs); + builder.addAll(sideInputs); + return new UnboundMulti<>( + name, builder.build(), + mainOutputTag, sideOutputTags); + } + + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} + * that's like this transform but that will invoke the given + * {@link DoFn} function, and that has its input type bound. + * Does not modify this transform. The resulting + * {@link PTransform} is sufficiently specified to be applied, but + * more properties can still be specified. + */ + public BoundMulti of(DoFn fn) { + return new BoundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags, fn); + } + + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} + * that's like this transform but which will invoke the given + * {@link DoFnWithContext} function, and which has its input type bound. + * Does not modify this transform. The resulting + * {@link PTransform} is sufficiently specified to be applied, but + * more properties can still be specified. + */ + public BoundMulti of(DoFnWithContext fn) { + return of(adapt(fn)); + } + } + + /** + * A {@link PTransform} that, when applied to a + * {@code PCollection}, invokes a user-specified + * {@code DoFn} on all its elements, which can emit elements + * to any of the {@link PTransform}'s main and side output + * {@code PCollection}s, which are bundled into a result + * {@code PCollectionTuple}. + * + * @param the type of the (main) input {@code PCollection} elements + * @param the type of the main output {@code PCollection} elements + */ + public static class BoundMulti + extends PTransform, PCollectionTuple> { + // Inherits name. + private final List> sideInputs; + private final TupleTag mainOutputTag; + private final TupleTagList sideOutputTags; + private final DoFn fn; + + BoundMulti(String name, + List> sideInputs, + TupleTag mainOutputTag, + TupleTagList sideOutputTags, + DoFn fn) { + super(name); + this.sideInputs = sideInputs; + this.mainOutputTag = mainOutputTag; + this.sideOutputTags = sideOutputTags; + this.fn = SerializableUtils.clone(fn); + } + + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} + * that's like this {@link PTransform} but with the specified + * name. Does not modify this {@link PTransform}. + * + *

    See the discussion of Naming above for more explanation. + */ + public BoundMulti named(String name) { + return new BoundMulti<>( + name, sideInputs, mainOutputTag, sideOutputTags, fn); + } + + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} + * that's like this {@link PTransform} but with the specified additional side + * inputs. Does not modify this {@link PTransform}. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public BoundMulti withSideInputs( + PCollectionView... sideInputs) { + return withSideInputs(Arrays.asList(sideInputs)); + } + + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} + * that's like this {@link PTransform} but with the specified additional side + * inputs. Does not modify this {@link PTransform}. + * + *

    See the discussion of Side Inputs above and on + * {@link ParDo#withSideInputs} for more explanation. + */ + public BoundMulti withSideInputs( + Iterable> sideInputs) { + ImmutableList.Builder> builder = ImmutableList.builder(); + builder.addAll(this.sideInputs); + builder.addAll(sideInputs); + return new BoundMulti<>( + name, builder.build(), + mainOutputTag, sideOutputTags, fn); + } + + + @Override + public PCollectionTuple apply(PCollection input) { + PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( + input.getPipeline(), + TupleTagList.of(mainOutputTag).and(sideOutputTags.getAll()), + input.getWindowingStrategy(), + input.isBounded()); + + // The fn will likely be an instance of an anonymous subclass + // such as DoFn { }, thus will have a high-fidelity + // TypeDescriptor for the output type. + outputs.get(mainOutputTag).setTypeDescriptorInternal(fn.getOutputTypeDescriptor()); + + return outputs; + } + + @Override + protected Coder getDefaultOutputCoder() { + throw new RuntimeException( + "internal error: shouldn't be calling this on a multi-output ParDo"); + } + + @Override + public Coder getDefaultOutputCoder( + PCollection input, TypedPValue output) + throws CannotProvideCoderException { + @SuppressWarnings("unchecked") + Coder inputCoder = ((PCollection) input).getCoder(); + return input.getPipeline().getCoderRegistry().getDefaultCoder( + output.getTypeDescriptor(), + fn.getInputTypeDescriptor(), + inputCoder); + } + + @Override + protected String getKindString() { + Class clazz = DoFnReflector.getDoFnClass(fn); + if (fn.getClass().isAnonymousClass()) { + return "AnonymousParMultiDo"; + } else { + return String.format("ParMultiDo(%s)", StringUtils.approximateSimpleName(clazz)); + } + } + + public DoFn getFn() { + return fn; + } + + public TupleTag getMainOutputTag() { + return mainOutputTag; + } + + public TupleTagList getSideOutputTags() { + return sideOutputTags; + } + + public List> getSideInputs() { + return sideInputs; + } + } + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + Bound.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateSingleHelper(transform, context); + } + }); + } + + private static void evaluateSingleHelper( + Bound transform, + DirectPipelineRunner.EvaluationContext context) { + TupleTag mainOutputTag = new TupleTag<>("out"); + + DirectModeExecutionContext executionContext = DirectModeExecutionContext.create(); + + PCollectionTuple outputs = PCollectionTuple.of(mainOutputTag, context.getOutput(transform)); + + evaluateHelper( + transform.fn, + context.getStepName(transform), + context.getInput(transform), + transform.sideInputs, + mainOutputTag, + Collections.>emptyList(), + outputs, + context, + executionContext); + + context.setPCollectionValuesWithMetadata( + context.getOutput(transform), + executionContext.getOutput(mainOutputTag)); + } + + ///////////////////////////////////////////////////////////////////////////// + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + BoundMulti.class, + new DirectPipelineRunner.TransformEvaluator() { + @Override + public void evaluate( + BoundMulti transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateMultiHelper(transform, context); + } + }); + } + + private static void evaluateMultiHelper( + BoundMulti transform, + 
DirectPipelineRunner.EvaluationContext context) { + + DirectModeExecutionContext executionContext = DirectModeExecutionContext.create(); + + evaluateHelper( + transform.fn, + context.getStepName(transform), + context.getInput(transform), + transform.sideInputs, + transform.mainOutputTag, + transform.sideOutputTags.getAll(), + context.getOutput(transform), + context, + executionContext); + + for (Map.Entry, PCollection> entry + : context.getOutput(transform).getAll().entrySet()) { + @SuppressWarnings("unchecked") + TupleTag tag = (TupleTag) entry.getKey(); + @SuppressWarnings("unchecked") + PCollection pc = (PCollection) entry.getValue(); + + context.setPCollectionValuesWithMetadata( + pc, + (tag == transform.mainOutputTag + ? executionContext.getOutput(tag) + : executionContext.getSideOutput(tag))); + } + } + + /** + * Evaluates a single-output or multi-output {@link ParDo} directly. + * + *

    This evaluation method is intended for use in testing scenarios; it is designed for clarity + * and correctness-checking, not speed. + * + *
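The checks described next are aimed at patterns like the following hypothetical DoFn (for illustration only, not part of this change), which mutates its input element before emitting it:

    import com.google.cloud.dataflow.sdk.transforms.DoFn;
    import java.util.List;

    // Hypothetical DoFn the input-mutation check would reject.
    class AppendMarkerFn extends DoFn<List<String>, List<String>> {
      @Override
      public void processElement(ProcessContext c) {
        List<String> tokens = c.element();
        tokens.add("EOL");  // illegal: input elements must not be mutated
        c.output(tokens);   // the direct evaluation reports an IllegalMutationException
      }
    }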

    Of particular note, this performs best-effort checking that inputs and outputs are not + * mutated in violation of the requirements upon a {@link DoFn}. + */ + private static void evaluateHelper( + DoFn doFn, + String stepName, + PCollection input, + List> sideInputs, + TupleTag mainOutputTag, + List> sideOutputTags, + PCollectionTuple outputs, + DirectPipelineRunner.EvaluationContext context, + DirectModeExecutionContext executionContext) { + // TODO: Run multiple shards? + DoFn fn = context.ensureSerializable(doFn); + + SideInputReader sideInputReader = makeSideInputReader(context, sideInputs); + + // When evaluating via the DirectPipelineRunner, this output manager checks each output for + // illegal mutations when the next output comes along. We then verify again after finishBundle() + // The common case we expect this to catch is a user mutating an input in order to repeatedly + // emit "variations". + ImmutabilityCheckingOutputManager outputManager = + new ImmutabilityCheckingOutputManager<>( + fn.getClass().getSimpleName(), + new DoFnRunnerBase.ListOutputManager(), + outputs); + + DoFnRunner fnRunner = + DoFnRunners.createDefault( + context.getPipelineOptions(), + fn, + sideInputReader, + outputManager, + mainOutputTag, + sideOutputTags, + executionContext.getOrCreateStepContext(stepName, stepName, null), + context.getAddCounterMutator(), + input.getWindowingStrategy()); + + fnRunner.startBundle(); + + for (DirectPipelineRunner.ValueWithMetadata elem + : context.getPCollectionValuesWithMetadata(input)) { + if (elem.getValue() instanceof KV) { + // In case the DoFn needs keyed state, set the implicit keys to the keys + // in the input elements. + @SuppressWarnings("unchecked") + KV kvElem = (KV) elem.getValue(); + executionContext.setKey(kvElem.getKey()); + } else { + executionContext.setKey(elem.getKey()); + } + + // We check the input for mutations only through the call span of processElement. + // This will miss some cases, but the check is ad hoc and best effort. The common case + // is that the input is mutated to be used for output. + try { + MutationDetector inputMutationDetector = MutationDetectors.forValueWithCoder( + elem.getWindowedValue().getValue(), input.getCoder()); + @SuppressWarnings("unchecked") + WindowedValue windowedElem = ((WindowedValue) elem.getWindowedValue()); + fnRunner.processElement(windowedElem); + inputMutationDetector.verifyUnmodified(); + } catch (CoderException e) { + throw UserCodeException.wrap(e); + } catch (IllegalMutationException exn) { + throw new IllegalMutationException( + String.format("DoFn %s mutated input value %s of class %s (new value was %s)." + + " Input values must not be mutated in any way.", + fn.getClass().getSimpleName(), + exn.getSavedValue(), exn.getSavedValue().getClass(), exn.getNewValue()), + exn.getSavedValue(), + exn.getNewValue(), + exn); + } + } + + // Note that the input could have been retained and mutated prior to this final output, + // but for now it degrades readability too much to be worth trying to catch that particular + // corner case. 
+ fnRunner.finishBundle(); + outputManager.verifyLatestOutputsUnmodified(); + } + + private static SideInputReader makeSideInputReader( + DirectPipelineRunner.EvaluationContext context, List> sideInputs) { + PTuple sideInputValues = PTuple.empty(); + for (PCollectionView view : sideInputs) { + sideInputValues = sideInputValues.and( + view.getTagInternal(), + context.getPCollectionView(view)); + } + return DirectSideInputReader.of(sideInputValues); + } + + /** + * A {@code DoFnRunner.OutputManager} that provides facilities for checking output values for + * illegal mutations. + * + *
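As a hedged sketch (not part of this change) of the output-side violation this manager detects, consider a hypothetical DoFn that keeps outputting and then mutating the same list instance:

    import com.google.cloud.dataflow.sdk.transforms.DoFn;
    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical DoFn the output check would flag: the list emitted for one element
    // is mutated again when the next element of the bundle arrives.
    class RunningPrefixFn extends DoFn<String, List<String>> {
      private final List<String> seenSoFar = new ArrayList<>();

      @Override
      public void processElement(ProcessContext c) {
        seenSoFar.add(c.element());
        c.output(seenSoFar);  // detected when the next output for this tag is produced
      }
    }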

    When used via the try-with-resources pattern, it is guaranteed that every value passed + * to {@link #output} will have been checked for illegal mutation. + */ + private static class ImmutabilityCheckingOutputManager + implements DoFnRunners.OutputManager, AutoCloseable { + + private final DoFnRunners.OutputManager underlyingOutputManager; + private final ConcurrentMap, MutationDetector> mutationDetectorForTag; + private final PCollectionTuple outputs; + private String doFnName; + + public ImmutabilityCheckingOutputManager( + String doFnName, + DoFnRunners.OutputManager underlyingOutputManager, + PCollectionTuple outputs) { + this.doFnName = doFnName; + this.underlyingOutputManager = underlyingOutputManager; + this.outputs = outputs; + this.mutationDetectorForTag = Maps.newConcurrentMap(); + } + + @Override + public void output(TupleTag tag, WindowedValue output) { + + // Skip verifying undeclared outputs, since we don't have coders for them. + if (outputs.has(tag)) { + try { + MutationDetector newDetector = + MutationDetectors.forValueWithCoder( + output.getValue(), outputs.get(tag).getCoder()); + MutationDetector priorDetector = mutationDetectorForTag.put(tag, newDetector); + verifyOutputUnmodified(priorDetector); + } catch (CoderException e) { + throw UserCodeException.wrap(e); + } + } + + // Actually perform the output. + underlyingOutputManager.output(tag, output); + } + + /** + * Throws {@link IllegalMutationException} if the prior output for any tag has been mutated + * since being output. + */ + public void verifyLatestOutputsUnmodified() { + for (MutationDetector detector : mutationDetectorForTag.values()) { + verifyOutputUnmodified(detector); + } + } + + /** + * Adapts the error message from the provided {@code detector}. + * + *

    The {@code detector} may be null, in which case no check is performed. This is merely + * to consolidate null checking to this method. + */ + private void verifyOutputUnmodified(@Nullable MutationDetector detector) { + if (detector == null) { + return; + } + + try { + detector.verifyUnmodified(); + } catch (IllegalMutationException exn) { + throw new IllegalMutationException(String.format( + "DoFn %s mutated value %s after it was output (new value was %s)." + + " Values must not be mutated in any way after being output.", + doFnName, exn.getSavedValue(), exn.getNewValue()), + exn.getSavedValue(), exn.getNewValue(), + exn); + } + } + + /** + * When used in a {@code try}-with-resources block, verifies all of the latest outputs upon + * {@link #close()}. + */ + @Override + public void close() { + verifyLatestOutputsUnmodified(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Partition.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Partition.java new file mode 100644 index 000000000000..bbbccbc75ddd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Partition.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import java.io.Serializable; + +/** + * {@code Partition} takes a {@code PCollection} and a + * {@code PartitionFn}, uses the {@code PartitionFn} to split the + * elements of the input {@code PCollection} into {@code N} partitions, and + * returns a {@code PCollectionList} that bundles {@code N} + * {@code PCollection}s containing the split elements. + * + *

    Example of use: + *

     {@code
    + * PCollection<Student> students = ...;
    + * // Split students up into 10 partitions, by percentile:
    + * PCollectionList<Student> studentsByPercentile =
    + *     students.apply(Partition.of(10, new PartitionFn<Student>() {
    + *         public int partitionFor(Student student, int numPartitions) {
    + *             return student.getPercentile()  // 0..99
    + *                  * numPartitions / 100;
    + *         }}))
    + * for (int i = 0; i < 10; i++) {
    + *   PCollection<Student> partition = studentsByPercentile.get(i);
    + *   ...
    + * }
    + * } 
    + * + *

    By default, the {@code Coder} of each of the + * {@code PCollection}s in the output {@code PCollectionList} is the + * same as the {@code Coder} of the input {@code PCollection}. + * + *

    Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and each output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ +public class Partition extends PTransform, PCollectionList> { + + /** + * A function object that chooses an output partition for an element. + * + * @param the type of the elements being partitioned + */ + public interface PartitionFn extends Serializable { + /** + * Chooses the partition into which to put the given element. + * + * @param elem the element to be partitioned + * @param numPartitions the total number of partitions ({@code >= 1}) + * @return index of the selected partition (in the range + * {@code [0..numPartitions-1]}) + */ + public int partitionFor(T elem, int numPartitions); + } + + /** + * Returns a new {@code Partition} {@code PTransform} that divides + * its input {@code PCollection} into the given number of partitions, + * using the given partitioning function. + * + * @param numPartitions the number of partitions to divide the input + * {@code PCollection} into + * @param partitionFn the function to invoke on each element to + * choose its output partition + * @throws IllegalArgumentException if {@code numPartitions <= 0} + */ + public static Partition of( + int numPartitions, PartitionFn partitionFn) { + return new Partition<>(new PartitionDoFn(numPartitions, partitionFn)); + } + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public PCollectionList apply(PCollection in) { + final TupleTagList outputTags = partitionDoFn.getOutputTags(); + + PCollectionTuple outputs = in.apply( + ParDo + .withOutputTags(new TupleTag(){}, outputTags) + .of(partitionDoFn)); + + PCollectionList pcs = PCollectionList.empty(in.getPipeline()); + Coder coder = in.getCoder(); + + for (TupleTag outputTag : outputTags.getAll()) { + // All the tuple tags are actually TupleTag + // And all the collections are actually PCollection + @SuppressWarnings("unchecked") + TupleTag typedOutputTag = (TupleTag) outputTag; + pcs = pcs.and(outputs.get(typedOutputTag).setCoder(coder)); + } + return pcs; + } + + private final transient PartitionDoFn partitionDoFn; + + private Partition(PartitionDoFn partitionDoFn) { + this.partitionDoFn = partitionDoFn; + } + + private static class PartitionDoFn extends DoFn { + private final int numPartitions; + private final PartitionFn partitionFn; + private final TupleTagList outputTags; + + /** + * Constructs a PartitionDoFn. 
+ * + * @throws IllegalArgumentException if {@code numPartitions <= 0} + */ + public PartitionDoFn(int numPartitions, PartitionFn partitionFn) { + if (numPartitions <= 0) { + throw new IllegalArgumentException("numPartitions must be > 0"); + } + + this.numPartitions = numPartitions; + this.partitionFn = partitionFn; + + TupleTagList buildOutputTags = TupleTagList.empty(); + for (int partition = 0; partition < numPartitions; partition++) { + buildOutputTags = buildOutputTags.and(new TupleTag()); + } + outputTags = buildOutputTags; + } + + public TupleTagList getOutputTags() { + return outputTags; + } + + @Override + public void processElement(ProcessContext c) { + X input = c.element(); + int partition = partitionFn.partitionFor(input, numPartitions); + if (0 <= partition && partition < numPartitions) { + @SuppressWarnings("unchecked") + TupleTag typedTag = (TupleTag) outputTags.get(partition); + c.sideOutput(typedTag, input); + } else { + throw new IndexOutOfBoundsException( + "Partition function returned out of bounds index: " + + partition + " not in [0.." + numPartitions + ")"); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicates.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicates.java new file mode 100644 index 000000000000..8913138abb27 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicates.java @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +/** + * {@code RemoveDuplicates} takes a {@code PCollection} and + * returns a {@code PCollection} that has all the elements of the + * input but with duplicate elements removed such that each element is + * unique within each window. + * + *

    Two values of type {@code T} are compared for equality not by + * regular Java {@link Object#equals}, but instead by first encoding + * each of the elements using the {@code PCollection}'s {@code Coder}, and then + * comparing the encoded bytes. This admits efficient parallel + * evaluation. + * + *

    Optionally, a function may be provided that maps each element to a representative + * value. In this case, two elements will be considered duplicates if they have equal + * representative values, with equality being determined as above. + * + *

    By default, the {@code Coder} of the output {@code PCollection} + * is the same as the {@code Coder} of the input {@code PCollection}. + * + *

    Each output element is in the same window as its corresponding input + * element, and has the timestamp of the end of that window. The output + * {@code PCollection} has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * as the input. + * + *

    Does not preserve any order the input PCollection might have had. + * + *

    Example of use: + *

     {@code
    + * PCollection<String> words = ...;
    + * PCollection<String> uniqueWords =
    + *     words.apply(RemoveDuplicates.<String>create());
    + * } 
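For the representative-value variant described above, a hedged sketch (not part of this change); the Purchase type and its getId() accessor are hypothetical:

    // Hypothetical usage sketch: keep one Purchase per id, comparing encoded ids
    // rather than encoding whole records.
    PCollection<Purchase> purchases = ...;
    PCollection<Purchase> uniquePurchases = purchases.apply(
        RemoveDuplicates.withRepresentativeValueFn(
            new SerializableFunction<Purchase, Long>() {
              @Override
              public Long apply(Purchase purchase) {
                return purchase.getId();
              }
            }));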
    + * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ +public class RemoveDuplicates extends PTransform, + PCollection> { + /** + * Returns a {@code RemoveDuplicates} {@code PTransform}. + * + * @param the type of the elements of the input and output + * {@code PCollection}s + */ + public static RemoveDuplicates create() { + return new RemoveDuplicates(); + } + + /** + * Returns a {@code RemoveDuplicates} {@code PTransform}. + * + * @param the type of the elements of the input and output + * {@code PCollection}s + * @param the type of the representative value used to dedup + */ + public static WithRepresentativeValues withRepresentativeValueFn( + SerializableFunction fn) { + return new WithRepresentativeValues(fn, null); + } + + @Override + public PCollection apply(PCollection in) { + return in + .apply(ParDo.named("CreateIndex") + .of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), (Void) null)); + } + })) + .apply(Combine.perKey( + new SerializableFunction, Void>() { + @Override + public Void apply(Iterable iter) { + return null; // ignore input + } + })) + .apply(Keys.create()); + } + + /** + * A {@link RemoveDuplicates} {@link PTransform} that uses a {@link SerializableFunction} to + * obtain a representative value for each input element. + * + * Construct via {@link RemoveDuplicates#withRepresentativeValueFn(SerializableFunction)}. + * + * @param the type of input and output element + * @param the type of representative values used to dedup + */ + public static class WithRepresentativeValues + extends PTransform, PCollection> { + private final SerializableFunction fn; + private final TypeDescriptor representativeType; + + private WithRepresentativeValues( + SerializableFunction fn, TypeDescriptor representativeType) { + this.fn = fn; + this.representativeType = representativeType; + } + + @Override + public PCollection apply(PCollection in) { + WithKeys withKeys = WithKeys.of(fn); + if (representativeType != null) { + withKeys = withKeys.withKeyType(representativeType); + } + return in + .apply(withKeys) + .apply(Combine.perKey( + new Combine.BinaryCombineFn() { + @Override + public T apply(T left, T right) { + return left; + } + })) + .apply(Values.create()); + } + + /** + * Return a {@code WithRepresentativeValues} {@link PTransform} that is like this one, but with + * the specified output type descriptor. + * + * Required for use of {@link RemoveDuplicates#withRepresentativeValueFn(SerializableFunction)} + * in Java 8 with a lambda as the fn. + * + * @param type a {@link TypeDescriptor} describing the representative type of this + * {@code WithRepresentativeValues} + * @return A {@code WithRepresentativeValues} {@link PTransform} that is like this one, but with + * the specified output type descriptor. + */ + public WithRepresentativeValues withRepresentativeType(TypeDescriptor type) { + return new WithRepresentativeValues<>(fn, type); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sample.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sample.java new file mode 100644 index 000000000000..c5b6e7ec1414 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sample.java @@ -0,0 +1,246 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.base.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * {@code PTransform}s for taking samples of the elements in a + * {@code PCollection}, or samples of the values associated with each + * key in a {@code PCollection} of {@code KV}s. + **/ +public class Sample { + + /** + * {@code Sample#any(long)} takes a {@code PCollection} and a limit, and + * produces a new {@code PCollection} containing up to limit + * elements of the input {@code PCollection}. + * + *

    If limit is greater than or equal to the size of the input + * {@code PCollection}, then all the input's elements will be selected. + * + *

    All of the elements of the output {@code PCollection} should fit into + * main memory of a single worker machine. This operation does not + * run in parallel. + * + *

    Example of use: + *

     {@code
    +   * PCollection<String> input = ...;
    +   * PCollection<String> output = input.apply(Sample.<String>any(100));
    +   * } 
    + * + * @param the type of the elements of the input and output + * {@code PCollection}s + * @param limit the number of elements to take from the input + */ + public static PTransform, PCollection> any(long limit) { + return new SampleAny<>(limit); + } + + /** + * Returns a {@code PTransform} that takes a {@code PCollection}, + * selects {@code sampleSize} elements, uniformly at random, and returns a + * {@code PCollection>} containing the selected elements. + * If the input {@code PCollection} has fewer than + * {@code sampleSize} elements, then the output {@code Iterable} + * will be all the input's elements. + * + *

    Example of use: + *

     {@code
    +   * PCollection<String> pc = ...;
    +   * PCollection<Iterable<String>> sampleOfSize10 =
    +   *     pc.apply(Sample.<String>fixedSizeGlobally(10));
    +   * } 
    + * + * @param sampleSize the number of elements to select; must be {@code >= 0} + * @param the type of the elements + */ + public static PTransform, PCollection>> + fixedSizeGlobally(int sampleSize) { + return Combine.globally(new FixedSizedSampleFn(sampleSize)); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to a sample of {@code sampleSize} values + * associated with that key in the input {@code PCollection}, taken + * uniformly at random. If a key in the input {@code PCollection} + * has fewer than {@code sampleSize} values associated with it, then + * the output {@code Iterable} associated with that key will be + * all the values associated with that key in the input + * {@code PCollection}. + * + *

    Example of use: + *

     {@code
    +   * PCollection<KV<String, Integer>> pc = ...;
    +   * PCollection<KV<String, Iterable<Integer>>> sampleOfSize10PerKey =
    +   *     pc.apply(Sample.<String, Integer>fixedSizePerKey(10));
    +   * } 
    + * + * @param sampleSize the number of values to select for each + * distinct key; must be {@code >= 0} + * @param the type of the keys + * @param the type of the values + */ + public static PTransform>, + PCollection>>> + fixedSizePerKey(int sampleSize) { + return Combine.perKey(new FixedSizedSampleFn(sampleSize)); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@link PTransform} that takes a {@code PCollection} and a limit, and + * produces a new {@code PCollection} containing up to limit + * elements of the input {@code PCollection}. + */ + public static class SampleAny extends PTransform, PCollection> { + private final long limit; + + /** + * Constructs a {@code SampleAny} PTransform that, when applied, + * produces a new PCollection containing up to {@code limit} + * elements of its input {@code PCollection}. + */ + private SampleAny(long limit) { + Preconditions.checkArgument(limit >= 0, "Expected non-negative limit, received %s.", limit); + this.limit = limit; + } + + @Override + public PCollection apply(PCollection in) { + PCollectionView> iterableView = in.apply(View.asIterable()); + return + in.getPipeline() + .apply(Create.of((Void) null).withCoder(VoidCoder.of())) + .apply(ParDo + .withSideInputs(iterableView) + .of(new SampleAnyDoFn<>(limit, iterableView))) + .setCoder(in.getCoder()); + } + } + + /** + * A {@link DoFn} that returns up to limit elements from the side input PCollection. + */ + private static class SampleAnyDoFn extends DoFn { + long limit; + final PCollectionView> iterableView; + + public SampleAnyDoFn(long limit, PCollectionView> iterableView) { + this.limit = limit; + this.iterableView = iterableView; + } + + @Override + public void processElement(ProcessContext c) { + for (T i : c.sideInput(iterableView)) { + if (limit-- <= 0) { + break; + } + c.output(i); + } + } + } + + /** + * {@code CombineFn} that computes a fixed-size sample of a + * collection of values. 
+ * + * @param the type of the elements + */ + public static class FixedSizedSampleFn + extends CombineFn, SerializableComparator>>, + Iterable> { + private final Top.TopCombineFn, SerializableComparator>> + topCombineFn; + private final Random rand = new Random(); + + private FixedSizedSampleFn(int sampleSize) { + if (sampleSize < 0) { + throw new IllegalArgumentException("sample size must be >= 0"); + } + topCombineFn = new Top.TopCombineFn, SerializableComparator>>( + sampleSize, new KV.OrderByKey()); + } + + @Override + public Top.BoundedHeap, SerializableComparator>> + createAccumulator() { + return topCombineFn.createAccumulator(); + } + + @Override + public Top.BoundedHeap, SerializableComparator>> addInput( + Top.BoundedHeap, SerializableComparator>> accumulator, + T input) { + accumulator.addInput(KV.of(rand.nextInt(), input)); + return accumulator; + } + + @Override + public Top.BoundedHeap, SerializableComparator>> + mergeAccumulators( + Iterable, SerializableComparator>>> + accumulators) { + return topCombineFn.mergeAccumulators(accumulators); + } + + @Override + public Iterable extractOutput( + Top.BoundedHeap, SerializableComparator>> accumulator) { + List out = new ArrayList<>(); + for (KV element : accumulator.extractOutput()) { + out.add(element.getValue()); + } + return out; + } + + @Override + public Coder, SerializableComparator>>> + getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return topCombineFn.getAccumulatorCoder( + registry, KvCoder.of(BigEndianIntegerCoder.of(), inputCoder)); + } + + @Override + public Coder> getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) { + return IterableCoder.of(inputCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableComparator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableComparator.java new file mode 100644 index 000000000000..7d41917a94c1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableComparator.java @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * A {@code Comparator} that is also {@code Serializable}. + * + * @param type of values being compared + */ +public interface SerializableComparator extends Comparator, Serializable { +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableFunction.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableFunction.java new file mode 100644 index 000000000000..81bf3d4cb584 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SerializableFunction.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import java.io.Serializable; + +/** + * A function that computes an output value of type {@code OutputT} from an input value of type + * {@code InputT} and is {@link Serializable}. + * + * @param input value type + * @param output value type + */ +public interface SerializableFunction extends Serializable { + /** Returns the result of invoking this function on the given input. */ + public OutputT apply(InputT input); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SimpleFunction.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SimpleFunction.java new file mode 100644 index 000000000000..ef6fd81a23f6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/SimpleFunction.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +/** + * A {@link SerializableFunction} which is not a functional interface. + * Concrete subclasses allow us to infer type information, which in turn aids + * {@link Coder} inference. + */ +public abstract class SimpleFunction + implements SerializableFunction { + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the input type of this {@code DoFn} instance's most-derived + * class. + * + *

    See {@link #getOutputTypeDescriptor} for more discussion. + */ + public TypeDescriptor getInputTypeDescriptor() { + return new TypeDescriptor(this) {}; + } + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the output type of this {@code DoFn} instance's + * most-derived class. + * + *
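For illustration, a hedged sketch (not part of this change) of such a concrete subclass; WordLengthFn is hypothetical:

    // Because the type parameters are fixed in the class declaration,
    // getInputTypeDescriptor() can recover String and getOutputTypeDescriptor()
    // can recover Integer, which aids Coder inference.
    class WordLengthFn extends SimpleFunction<String, Integer> {
      @Override
      public Integer apply(String word) {
        return word.length();
      }
    }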

    In the normal case of a concrete {@code DoFn} subclass with + * no generic type parameters of its own (including anonymous inner + * classes), this will be a complete non-generic type, which is good + * for choosing a default output {@code Coder} for the output + * {@code PCollection}. + */ + public TypeDescriptor getOutputTypeDescriptor() { + return new TypeDescriptor(this) {}; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sum.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sum.java new file mode 100644 index 000000000000..5b30475a9d8c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Sum.java @@ -0,0 +1,188 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind; +import com.google.cloud.dataflow.sdk.util.common.CounterProvider; + +/** + * {@code PTransform}s for computing the sum of the elements in a + * {@code PCollection}, or the sum of the values associated with + * each key in a {@code PCollection} of {@code KV}s. + * + *

    Example 1: get the sum of a {@code PCollection} of {@code Double}s. + *

     {@code
    + * PCollection<Double> input = ...;
    + * PCollection<Double> sum = input.apply(Sum.doublesGlobally());
    + * } 
    + * + *

    Example 2: calculate the sum of the {@code Integer}s + * associated with each unique key (which is of type {@code String}). + *

     {@code
    + * PCollection<KV<String, Integer>> input = ...;
    + * PCollection<KV<String, Integer>> sumPerKey = input
    + *     .apply(Sum.<String>integersPerKey());
    + * } 
    + */ +public class Sum { + + private Sum() { + // do not instantiate + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the sum of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + */ + public static Combine.Globally integersGlobally() { + return Combine.globally(new SumIntegerFn()).named("Sum.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the sum of the values associated with + * that key in the input {@code PCollection}. + */ + public static Combine.PerKey integersPerKey() { + return Combine.perKey(new SumIntegerFn()).named("Sum.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the sum of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + */ + public static Combine.Globally longsGlobally() { + return Combine.globally(new SumLongFn()).named("Sum.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the sum of the values associated with + * that key in the input {@code PCollection}. + */ + public static Combine.PerKey longsPerKey() { + return Combine.perKey(new SumLongFn()).named("Sum.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a + * {@code PCollection} whose contents is the sum of the + * input {@code PCollection}'s elements, or + * {@code 0} if there are no elements. + */ + public static Combine.Globally doublesGlobally() { + return Combine.globally(new SumDoubleFn()).named("Sum.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the sum of the values associated with + * that key in the input {@code PCollection}. + */ + public static Combine.PerKey doublesPerKey() { + return Combine.perKey(new SumDoubleFn()).named("Sum.PerKey"); + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@code SerializableFunction} that computes the sum of an + * {@code Iterable} of {@code Integer}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class SumIntegerFn + extends Combine.BinaryCombineIntegerFn implements CounterProvider { + @Override + public int apply(int a, int b) { + return a + b; + } + + @Override + public int identity() { + return 0; + } + + @Override + public Counter getCounter(String name) { + return Counter.ints(name, AggregationKind.SUM); + } + } + + /** + * A {@code SerializableFunction} that computes the sum of an + * {@code Iterable} of {@code Long}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. 
+ */ + public static class SumLongFn + extends Combine.BinaryCombineLongFn implements CounterProvider { + @Override + public long apply(long a, long b) { + return a + b; + } + + @Override + public long identity() { + return 0; + } + + @Override + public Counter getCounter(String name) { + return Counter.longs(name, AggregationKind.SUM); + } + } + + /** + * A {@code SerializableFunction} that computes the sum of an + * {@code Iterable} of {@code Double}s, useful as an argument to + * {@link Combine#globally} or {@link Combine#perKey}. + */ + public static class SumDoubleFn + extends Combine.BinaryCombineDoubleFn implements CounterProvider { + @Override + public double apply(double a, double b) { + return a + b; + } + + @Override + public double identity() { + return 0; + } + + @Override + public Counter getCounter(String name) { + return Counter.doubles(name, AggregationKind.SUM); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Top.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Top.java new file mode 100644 index 000000000000..98fe53c0a847 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Top.java @@ -0,0 +1,559 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.AccumulatingCombineFn.Accumulator; +import com.google.cloud.dataflow.sdk.transforms.Combine.PerKey; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; + +/** + * {@code PTransform}s for finding the largest (or smallest) set + * of elements in a {@code PCollection}, or the largest (or smallest) + * set of values associated with each key in a {@code PCollection} of + * {@code KV}s. 
+ */ +public class Top { + + private Top() { + // do not instantiate + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection>} with a + * single element containing the largest {@code count} elements of the input + * {@code PCollection}, in decreasing order, sorted using the + * given {@code Comparator}. The {@code Comparator} must also + * be {@code Serializable}. + * + *

    If {@code count} {@code <} the number of elements in the + * input {@code PCollection}, then all the elements of the input + * {@code PCollection} will be in the resulting + * {@code List}, albeit in sorted order. + * + *

    All the elements of the result's {@code List} + * must fit into the memory of a single machine. + * + *

    Example of use: + *

     {@code
    +   * PCollection<Student> students = ...;
    +   * PCollection<List<Student>> top10Students =
    +   *     students.apply(Top.of(10, new CompareStudentsByAvgGrade()));
    +   * } 
    + * + *
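A hedged sketch (not part of this change) of the comparator used in the example above; the Student type and its getAvgGrade() accessor are hypothetical. The comparator must be Serializable, which the SerializableComparator interface added in this change provides:

    // Top.of keeps the greatest elements under this ordering, so the example yields
    // the 10 students with the highest average grade.
    class CompareStudentsByAvgGrade implements SerializableComparator<Student> {
      @Override
      public int compare(Student a, Student b) {
        return Double.compare(a.getAvgGrade(), b.getAvgGrade());
      }
    }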

    By default, the {@code Coder} of the output {@code PCollection} + * is a {@code ListCoder} of the {@code Coder} of the elements of + * the input {@code PCollection}. + * + *

    If the input {@code PCollection} is windowed into {@link GlobalWindows}, + * an empty {@code List} in the {@link GlobalWindow} will be output if the input + * {@code PCollection} is empty. To use this with inputs with other windowing, + * either {@link Combine.Globally#withoutDefaults withoutDefaults} or + * {@link Combine.Globally#asSingletonView asSingletonView} must be called. + * + *
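For a non-globally-windowed input, a hedged sketch (not part of this change) of dropping the default empty result; the temperatures collection and the window size are hypothetical, and Window and FixedWindows from the SDK's windowing package plus org.joda.time.Duration are assumed to be imported:

    // Hypothetical usage sketch: per-window top 3, with no default output for empty windows.
    PCollection<List<Double>> top3PerWindow = temperatures
        .apply(Window.<Double>into(FixedWindows.of(Duration.standardMinutes(5))))
        .apply(Top.<Double>largest(3).withoutDefaults());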

    See also {@link #smallest} and {@link #largest}, which sort + * {@code Comparable} elements using their natural ordering. + * + *

    See also {@link #perKey}, {@link #smallestPerKey}, and + * {@link #largestPerKey}, which take a {@code PCollection} of + * {@code KV}s and return the top values associated with each key. + */ + public static & Serializable> + Combine.Globally> of(int count, ComparatorT compareFn) { + return Combine.globally(new TopCombineFn<>(count, compareFn)).named("Top.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection>} with a + * single element containing the smallest {@code count} elements of the input + * {@code PCollection}, in increasing order, sorted according to + * their natural order. + * + *

    If {@code count} {@code <} the number of elements in the + * input {@code PCollection}, then all the elements of the input + * {@code PCollection} will be in the resulting {@code PCollection}'s + * {@code List}, albeit in sorted order. + * + *

    All the elements of the result {@code List} + * must fit into the memory of a single machine. + * + *

    Example of use: + *

     {@code
    +   * PCollection<Integer> values = ...;
    +   * PCollection<List<Integer>> smallest10Values = values.apply(Top.<Integer>smallest(10));
    +   * } 
    + * + *

    By default, the {@code Coder} of the output {@code PCollection} + * is a {@code ListCoder} of the {@code Coder} of the elements of + * the input {@code PCollection}. + * + *

    If the input {@code PCollection} is windowed into {@link GlobalWindows}, + * an empty {@code List} in the {@link GlobalWindow} will be output if the input + * {@code PCollection} is empty. To use this with inputs with other windowing, + * either {@link Combine.Globally#withoutDefaults withoutDefaults} or + * {@link Combine.Globally#asSingletonView asSingletonView} must be called. + * + *

    See also {@link #largest}. + * + *

    See also {@link #of}, which sorts using a user-specified + * {@code Comparator} function. + * + *

    See also {@link #perKey}, {@link #smallestPerKey}, and + * {@link #largestPerKey}, which take a {@code PCollection} of + * {@code KV}s and return the top values associated with each key. + */ + public static > Combine.Globally> smallest(int count) { + return Combine.globally(new TopCombineFn<>(count, new Smallest())) + .named("Smallest.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection} and returns a {@code PCollection>} with a + * single element containing the largest {@code count} elements of the input + * {@code PCollection}, in decreasing order, sorted according to + * their natural order. + * + *

    If {@code count} {@code <} the number of elements in the + * input {@code PCollection}, then all the elements of the input + * {@code PCollection} will be in the resulting {@code PCollection}'s + * {@code List}, albeit in sorted order. + * + *

    All the elements of the result's {@code List} + * must fit into the memory of a single machine. + * + *

    Example of use: + *

     {@code
    +   * PCollection<Integer> values = ...;
    +   * PCollection<List<Integer>> largest10Values = values.apply(Top.<Integer>largest(10));
    +   * } 
    + * + *

    By default, the {@code Coder} of the output {@code PCollection} + * is a {@code ListCoder} of the {@code Coder} of the elements of + * the input {@code PCollection}. + * + *

    If the input {@code PCollection} is windowed into {@link GlobalWindows}, + * an empty {@code List} in the {@link GlobalWindow} will be output if the input + * {@code PCollection} is empty. To use this with inputs with other windowing, + * either {@link Combine.Globally#withoutDefaults withoutDefaults} or + * {@link Combine.Globally#asSingletonView asSingletonView} must be called. + * + *

    See also {@link #smallest}. + * + *

    See also {@link #of}, which sorts using a user-specified + * {@code Comparator} function. + * + *

    See also {@link #perKey}, {@link #smallestPerKey}, and + * {@link #largestPerKey}, which take a {@code PCollection} of + * {@code KV}s and return the top values associated with each key. + */ + public static > Combine.Globally> largest(int count) { + return Combine.globally(new TopCombineFn<>(count, new Largest())).named("Largest.Globally"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the largest {@code count} values + * associated with that key in the input + * {@code PCollection>}, in decreasing order, sorted using + * the given {@code Comparator}. The + * {@code Comparator} must also be {@code Serializable}. + * + *

    If there are fewer than {@code count} values associated with + * a particular key, then all those values will be in the result + * mapping for that key, albeit in sorted order. + * + *

    All the values associated with a single key must fit into the + * memory of a single machine, but there can be many more + * {@code KV}s in the resulting {@code PCollection} than can fit + * into the memory of a single machine. + * + *

    Example of use: + *

     {@code
    +   * PCollection<KV<School, Student>> studentsBySchool = ...;
    +   * PCollection<KV<School, List<Student>>> top10StudentsBySchool =
    +   *     studentsBySchool.apply(
    +   *         Top.perKey(10, new CompareStudentsByAvgGrade()));
    +   * } 
    + * + *

    By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input + * {@code PCollection}, and the {@code Coder} of the values of the + * output {@code PCollection} is a {@code ListCoder} of the + * {@code Coder} of the values of the input {@code PCollection}. + * + *

    See also {@link #smallestPerKey} and {@link #largestPerKey}, which + * sort {@code Comparable} values using their natural + * ordering. + * + *

    See also {@link #of}, {@link #smallest}, and {@link #largest}, which + * take a {@code PCollection} and return the top elements. + */ + public static & Serializable> + PTransform>, PCollection>>> + perKey(int count, ComparatorT compareFn) { + return Combine.perKey( + new TopCombineFn<>(count, compareFn).asKeyedFn()).named("Top.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the smallest {@code count} values + * associated with that key in the input + * {@code PCollection>}, in increasing order, sorted + * according to their natural order. + * + *

    If there are fewer than {@code count} values associated with + * a particular key, then all those values will be in the result + * mapping for that key, albeit in sorted order. + * + *

    All the values associated with a single key must fit into the + * memory of a single machine, but there can be many more + * {@code KV}s in the resulting {@code PCollection} than can fit + * into the memory of a single machine. + * + *

    Example of use: + *

     {@code
+   * PCollection<KV<String, Integer>> keyedValues = ...;
+   * PCollection<KV<String, List<Integer>>> smallest10ValuesPerKey =
    +   *     keyedValues.apply(Top.smallestPerKey(10));
    +   * } 
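If the compiler cannot infer the key and value types at this call site, the call can carry an explicit type witness. A one-line sketch, with {@code String}/{@code Integer} chosen purely for illustration:

    keyedValues.apply(Top.<String, Integer>smallestPerKey(10));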
    + * + *

    By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input + * {@code PCollection}, and the {@code Coder} of the values of the + * output {@code PCollection} is a {@code ListCoder} of the + * {@code Coder} of the values of the input {@code PCollection}. + * + *

    See also {@link #largestPerKey}. + * + *

    See also {@link #perKey}, which sorts values using a user-specified + * {@code Comparator} function. + * + *

    See also {@link #of}, {@link #smallest}, and {@link #largest}, which + * take a {@code PCollection} and return the top elements. + */ + public static > + PTransform>, PCollection>>> + smallestPerKey(int count) { + return Combine.perKey(new TopCombineFn<>(count, new Smallest()).asKeyedFn()) + .named("Smallest.PerKey"); + } + + /** + * Returns a {@code PTransform} that takes an input + * {@code PCollection>} and returns a + * {@code PCollection>>} that contains an output + * element mapping each distinct key in the input + * {@code PCollection} to the largest {@code count} values + * associated with that key in the input + * {@code PCollection>}, in decreasing order, sorted + * according to their natural order. + * + *

    If there are fewer than {@code count} values associated with + * a particular key, then all those values will be in the result + * mapping for that key, albeit in sorted order. + * + *

    All the values associated with a single key must fit into the + * memory of a single machine, but there can be many more + * {@code KV}s in the resulting {@code PCollection} than can fit + * into the memory of a single machine. + * + *

    Example of use: + *

     {@code
+   * PCollection<KV<String, Integer>> keyedValues = ...;
+   * PCollection<KV<String, List<Integer>>> largest10ValuesPerKey =
    +   *     keyedValues.apply(Top.largestPerKey(10));
    +   * } 
    + * + *

    By default, the {@code Coder} of the keys of the output + * {@code PCollection} is the same as that of the keys of the input + * {@code PCollection}, and the {@code Coder} of the values of the + * output {@code PCollection} is a {@code ListCoder} of the + * {@code Coder} of the values of the input {@code PCollection}. + * + *

    See also {@link #smallestPerKey}. + * + *

    See also {@link #perKey}, which sorts values using a user-specified + * {@code Comparator} function. + * + *

    See also {@link #of}, {@link #smallest}, and {@link #largest}, which + * take a {@code PCollection} and return the top elements. + */ + public static > + PerKey> + largestPerKey(int count) { + return Combine.perKey( +new TopCombineFn<>(count, new Largest()).asKeyedFn()) + .named("Largest.PerKey"); + } + + /** + * A {@code Serializable} {@code Comparator} that that uses the compared elements' natural + * ordering. + */ + public static class Largest> + implements Comparator, Serializable { + @Override + public int compare(T a, T b) { + return a.compareTo(b); + } + } + + /** + * {@code Serializable} {@code Comparator} that that uses the reverse of the compared elements' + * natural ordering. + */ + public static class Smallest> + implements Comparator, Serializable { + @Override + public int compare(T a, T b) { + return b.compareTo(a); + } + } + + + //////////////////////////////////////////////////////////////////////////// + + /** + * {@code CombineFn} for {@code Top} transforms that combines a + * bunch of {@code T}s into a single {@code count}-long + * {@code List}, using {@code compareFn} to choose the largest + * {@code T}s. + * + * @param type of element being compared + */ + public static class TopCombineFn & Serializable> + extends AccumulatingCombineFn, List> { + + private final int count; + private final ComparatorT compareFn; + + public TopCombineFn(int count, ComparatorT compareFn) { + Preconditions.checkArgument( + count >= 0, + "count must be >= 0"); + this.count = count; + this.compareFn = compareFn; + } + + @Override + public BoundedHeap createAccumulator() { + return new BoundedHeap<>(count, compareFn, new ArrayList()); + } + + @Override + public Coder> getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return new BoundedHeapCoder<>(count, compareFn, inputCoder); + } + + @Override + public String getIncompatibleGlobalWindowErrorMessage() { + return "Default values are not supported in Top.[of, smallest, largest]() if the output " + + "PCollection is not windowed by GlobalWindows. Instead, use " + + "Top.[of, smallest, largest]().withoutDefaults() to output an empty PCollection if the" + + " input PCollection is empty, or Top.[of, smallest, largest]().asSingletonView() to " + + "get a PCollection containing the empty list if the input PCollection is empty."; + } + } + + /** + * A heap that stores only a finite number of top elements according to its provided + * {@code Comparator}. Implemented as an {@link Accumulator} to facilitate implementation of + * {@link Top}. + * + *

    This class is not safe for multithreaded use, except read-only. + */ + static class BoundedHeap & Serializable> + implements Accumulator, List> { + + /** + * A queue with smallest at the head, for quick adds. + * + *

Only one of asList and asQueue may be non-null. + */ + private PriorityQueue<T> asQueue; + + /** + * A list with the largest elements first, in the form returned by extractOutput(). + * + *

    Only one of asList and asQueue may be non-null. + */ + private List asList; + + /** The user-provided Comparator. */ + private final ComparatorT compareFn; + + /** The maximum size of the heap. */ + private final int maximumSize; + + /** + * Creates a new heap with the provided size, comparator, and initial elements. + */ + private BoundedHeap(int maximumSize, ComparatorT compareFn, List asList) { + this.maximumSize = maximumSize; + this.asList = asList; + this.compareFn = compareFn; + } + + @Override + public void addInput(T value) { + maybeAddInput(value); + } + + /** + * Adds {@code value} to this heap if it is larger than any of the current elements. + * Returns {@code true} if {@code value} was added. + */ + private boolean maybeAddInput(T value) { + if (maximumSize == 0) { + // Don't add anything. + return false; + } + + // If asQueue == null, then this is the first add after the latest call to the + // constructor or asList(). + if (asQueue == null) { + asQueue = new PriorityQueue<>(maximumSize, compareFn); + for (T item : asList) { + asQueue.add(item); + } + asList = null; + } + + if (asQueue.size() < maximumSize) { + asQueue.add(value); + return true; + } else if (compareFn.compare(value, asQueue.peek()) > 0) { + asQueue.poll(); + asQueue.add(value); + return true; + } else { + return false; + } + } + + @Override + public void mergeAccumulator(BoundedHeap accumulator) { + for (T value : accumulator.asList()) { + if (!maybeAddInput(value)) { + // If this element of accumulator does not make the top N, neither + // will the rest, which are all smaller. + break; + } + } + } + + @Override + public List extractOutput() { + return asList(); + } + + /** + * Returns the contents of this Heap as a List sorted largest-to-smallest. + */ + private List asList() { + if (asList == null) { + List smallestFirstList = Lists.newArrayListWithCapacity(asQueue.size()); + while (!asQueue.isEmpty()) { + smallestFirstList.add(asQueue.poll()); + } + asList = Lists.reverse(smallestFirstList); + asQueue = null; + } + return asList; + } + } + + /** + * A {@link Coder} for {@link BoundedHeap}, using Java serialization via {@link CustomCoder}. 
+ */ + private static class BoundedHeapCoder & Serializable> + extends CustomCoder> { + private final Coder> listCoder; + private final ComparatorT compareFn; + private final int maximumSize; + + public BoundedHeapCoder(int maximumSize, ComparatorT compareFn, Coder elementCoder) { + listCoder = ListCoder.of(elementCoder); + this.compareFn = compareFn; + this.maximumSize = maximumSize; + } + + @Override + public void encode( + BoundedHeap value, OutputStream outStream, Context context) + throws CoderException, IOException { + listCoder.encode(value.asList(), outStream, context); + } + + @Override + public BoundedHeap decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + return new BoundedHeap<>(maximumSize, compareFn, listCoder.decode(inStream, context)); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "HeapCoder requires a deterministic list coder", listCoder); + } + + @Override + public boolean isRegisterByteSizeObserverCheap( + BoundedHeap value, Context context) { + return listCoder.isRegisterByteSizeObserverCheap( + value.asList(), context); + } + + @Override + public void registerByteSizeObserver( + BoundedHeap value, ElementByteSizeObserver observer, Context context) + throws Exception { + listCoder.registerByteSizeObserver(value.asList(), observer, context); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Values.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Values.java new file mode 100644 index 000000000000..d84bc779ece9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Values.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +/** + * {@code Values} takes a {@code PCollection} of {@code KV}s and + * returns a {@code PCollection} of the values. + * + *

    Example of use: + *

     {@code
+ * PCollection<KV<String, Long>> wordCounts = ...;
+ * PCollection<Long> counts = wordCounts.apply(Values.<Long>create());
    + * } 
    + * + *

    Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + *

    See also {@link Keys}. + * + * @param the type of the values in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ +public class Values extends PTransform>, + PCollection> { + /** + * Returns a {@code Values} {@code PTransform}. + * + * @param the type of the values in the input {@code PCollection}, + * and the type of the elements in the output {@code PCollection} + */ + public static Values create() { + return new Values<>(); + } + + private Values() { } + + @Override + public PCollection apply(PCollection> in) { + return + in.apply(ParDo.named("Values") + .of(new DoFn, V>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getValue()); + } + })); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/View.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/View.java new file mode 100644 index 000000000000..e2c4487ae5f1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/View.java @@ -0,0 +1,470 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.util.PCollectionViews; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import java.util.List; +import java.util.Map; + +/** + * Transforms for creating {@link PCollectionView PCollectionViews} from + * {@link PCollection PCollections} (to read them as side inputs). + * + *

    While a {@link PCollection PCollection<ElemT>} has many values of type {@code ElemT} per + * window, a {@link PCollectionView PCollectionView<ViewT>} has a single value of type + * {@code ViewT} for each window. It can be thought of as a mapping from windows to values of + * type {@code ViewT}. The transforms here represent ways of converting the {@code ElemT} values + * in a window into a {@code ViewT} for that window. + * + *

When a {@link ParDo} transform is processing a main input + * element in a window {@code w} and a {@link PCollectionView} is read via + * {@link DoFn.ProcessContext#sideInput}, the value of the view for {@code w} is + * returned. + * + *

    The SDK supports viewing a {@link PCollection}, per window, as a single value, + * a {@link List}, an {@link Iterable}, a {@link Map}, or a multimap (iterable-valued {@link Map}). + * + *

    For a {@link PCollection} that contains a single value of type {@code T} + * per window, such as the output of {@link Combine#globally}, + * use {@link View#asSingleton()} to prepare it for use as a side input: + * + *

    + * {@code
+ * PCollectionView<T> output = someOtherPCollection
    + *     .apply(Combine.globally(...))
    + *     .apply(View.asSingleton());
    + * }
    + * 
    + * + *

    For a small {@link PCollection} with windows that can fit entirely in memory, + * use {@link View#asList()} to prepare it for use as a {@code List}. + * When read as a side input, the entire list for a window will be cached in memory. + * + *

    + * {@code
+ * PCollectionView<List<T>> output =
    + *    smallPCollection.apply(View.asList());
    + * }
    + * 
    + * + *

If a {@link PCollection} of {@code KV<K, V>} is known to + * have a single value per window for each key, then use {@link View#asMap()} + * to view it as a {@code Map<K, V>}: + * + *

    + * {@code
+ * PCollectionView<Map<K, V>> output =
+ *     somePCollection.apply(View.<K, V>asMap());
    + * }
    + * 
    + * + *

Otherwise, to access a {@link PCollection} of {@code KV<K, V>} as a + * {@code Map<K, Iterable<V>>} side input, use {@link View#asMultimap()}: + * + *

    + * {@code
+ * PCollectionView<Map<K, Iterable<V>>> output =
+ *     somePCollection.apply(View.<K, V>asMultimap());
    + * }
    + * 
    + * + *

    To iterate over an entire window of a {@link PCollection} via + * side input, use {@link View#asIterable()}: + * + *

    + * {@code
+ * PCollectionView<Iterable<T>> output =
    + *     somePCollection.apply(View.asIterable());
    + * }
    + * 
    + * + * + *

    Both {@link View#asMultimap()} and {@link View#asMap()} are useful + * for implementing lookup based "joins" with the main input, when the + * side input is small enough to fit into memory. + * + *

For example, if you represent a page on a website via some {@code Page} object and + * have some type {@code UrlVisit} logging that a URL was visited, you could convert these + * to more fully structured {@code PageVisit} objects using a side input, something like the + * following: + * + *

    + * {@code
+ * PCollection<Page> pages = ... // pages fit into memory
+ * PCollection<UrlVisit> urlVisits = ... // very large collection
+ * final PCollectionView<Map<URL, Page>> urlToPageView = pages
+ *     .apply(WithKeys.of( ... )) // extract the URL from the page
+ *     .apply(View.<URL, Page>asMap());
+ *
+ * PCollection<PageVisit> pageVisits = urlVisits
+ *     .apply(ParDo.withSideInputs(urlToPageView)
+ *         .of(new DoFn<UrlVisit, PageVisit>() {
+ *             {@literal @}Override
+ *             void processElement(ProcessContext context) {
+ *               UrlVisit urlVisit = context.element();
+ *               Map<URL, Page> urlToPage = context.sideInput(urlToPageView);
+ *               Page page = urlToPage.get(urlVisit.getUrl());
+ *               context.output(new PageVisit(page, urlVisit.getVisitData()));
+ *             }
+ *         }));
    + * }
    + * 
    + * + *

    See {@link ParDo#withSideInputs} for details on how to access + * this variable inside a {@link ParDo} over another {@link PCollection}. + */ +public class View { + + // Do not instantiate + private View() { } + + /** + * Returns a {@link AsSingleton} transform that takes a + * {@link PCollection} with a single value per window + * as input and produces a {@link PCollectionView} that returns + * the value in the main input window when read as a side input. + * + *

    +   * {@code
+   * PCollection<InputT> input = ...
+   * CombineFn<InputT, AccumT, OutputT> yourCombineFn = ...
+   * PCollectionView<OutputT> output = input
    +   *     .apply(Combine.globally(yourCombineFn))
    +   *     .apply(View.asSingleton());
    +   * }
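As noted just below, an empty window would otherwise cause {@code NoSuchElementException} in the consuming {@code DoFn}; a default can be supplied through the {@code withDefaultValue} option that {@code AsSingleton} defines later in this file. A hedged sketch, where the {@code Long} element type and the {@code 0L} fallback are illustrative assumptions:

    // Same pipeline shape as the example above, but tolerant of empty windows.
    PCollectionView<Long> output = input
        .apply(Combine.globally(yourCombineFn).withoutDefaults())
        .apply(View.<Long>asSingleton().withDefaultValue(0L));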
    + * + *

    If the input {@link PCollection} is empty, + * throws {@link java.util.NoSuchElementException} in the consuming + * {@link DoFn}. + * + *

    If the input {@link PCollection} contains more than one + * element, throws {@link IllegalArgumentException} in the + * consuming {@link DoFn}. + */ + public static AsSingleton asSingleton() { + return new AsSingleton<>(); + } + + /** + * Returns a {@link View.AsList} transform that takes a {@link PCollection} and returns a + * {@link PCollectionView} mapping each window to a {@link List} containing + * all of the elements in the window. + * + *

    The resulting list is required to fit in memory. + */ + public static AsList asList() { + return new AsList<>(); + } + + /** + * Returns a {@link View.AsIterable} transform that takes a {@link PCollection} as input + * and produces a {@link PCollectionView} mapping each window to an + * {@link Iterable} of the values in that window. + * + *

The values of the {@link Iterable} for a window are not required to fit in memory, + * but they may also not be effectively cached. If it is known that every window fits in memory, + * and stronger caching is desired, use {@link #asList}. + */ + public static <T> AsIterable<T> asIterable() { + return new AsIterable<>(); + } + + /** + * Returns a {@link View.AsMap} transform that takes a + * {@link PCollection PCollection<KV<K, V>>} as + * input and produces a {@link PCollectionView} mapping each window to + * a {@link Map Map<K, V>}. It is required that each key of the input be + * associated with a single value, per window. If this is not the case, precede this + * view with {@code Combine.perKey}, as in the example below, or alternatively + * use {@link View#asMultimap()}. + * + *

    +   * {@code
+   * PCollection<KV<K, V>> input = ...
+   * CombineFn<V, AccumT, OutputT> yourCombineFn = ...
+   * PCollectionView<Map<K, OutputT>> output = input
    +   *     .apply(Combine.perKey(yourCombineFn.asKeyedFn()))
    +   *     .apply(View.asMap());
    +   * }
    + * + *

    Currently, the resulting map is required to fit into memory. + */ + public static AsMap asMap() { + return new AsMap(); + } + + /** + * Returns a {@link View.AsMultimap} transform that takes a + * {@link PCollection PCollection<KV<K, V>>} + * as input and produces a {@link PCollectionView} mapping + * each window to its contents as a {@link Map Map<K, Iterable<V>>} + * for use as a side input. + * In contrast to {@link View#asMap()}, it is not required that the keys in the + * input collection be unique. + * + *

    +   * {@code
+   * PCollection<KV<K, V>> input = ... // maybe more than one occurrence of some keys
+   * PCollectionView<Map<K, Iterable<V>>> output = input.apply(View.<K, V>asMultimap());
    +   * }
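A hedged sketch of reading this multimap view inside a {@code DoFn}; the {@code queries} input, the {@code perKeyView} name, and the {@code String} key/value types are assumptions for illustration.

    // Assumes perKeyView is the PCollectionView<Map<String, Iterable<String>>> built above
    // and queries is a PCollection<String> of keys to look up.
    PCollection<String> valuesForKey = queries.apply(
        ParDo.withSideInputs(perKeyView).of(new DoFn<String, String>() {
          @Override
          public void processElement(ProcessContext c) {
            Map<String, Iterable<String>> multimap = c.sideInput(perKeyView);
            Iterable<String> values = multimap.get(c.element());
            if (values != null) {  // the key may be absent from this window's contents
              for (String value : values) {
                c.output(value);
              }
            }
          }
        }));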
    + * + *

    Currently, the resulting map is required to fit into memory. + */ + public static AsMultimap asMultimap() { + return new AsMultimap(); + } + + /** + * Not intended for direct use by pipeline authors; public only so a {@link PipelineRunner} may + * override its behavior. + * + *

    See {@link View#asList()}. + */ + public static class AsList extends PTransform, PCollectionView>> { + private AsList() { } + + @Override + public void validate(PCollection input) { + try { + GroupByKey.applicableTo(input); + } catch (IllegalStateException e) { + throw new IllegalStateException("Unable to create a side-input view from input", e); + } + } + + @Override + public PCollectionView> apply(PCollection input) { + return input.apply(CreatePCollectionView.>of(PCollectionViews.listView( + input.getPipeline(), input.getWindowingStrategy(), input.getCoder()))); + } + } + + /** + * Not intended for direct use by pipeline authors; public only so a {@link PipelineRunner} may + * override its behavior. + * + *

    See {@link View#asIterable()}. + */ + public static class AsIterable + extends PTransform, PCollectionView>> { + private AsIterable() { } + + @Override + public void validate(PCollection input) { + try { + GroupByKey.applicableTo(input); + } catch (IllegalStateException e) { + throw new IllegalStateException("Unable to create a side-input view from input", e); + } + } + + @Override + public PCollectionView> apply(PCollection input) { + return input.apply(CreatePCollectionView.>of(PCollectionViews.iterableView( + input.getPipeline(), input.getWindowingStrategy(), input.getCoder()))); + } + } + + /** + * Not intended for direct use by pipeline authors; public only so a {@link PipelineRunner} may + * override its behavior. + * + *

    See {@link View#asSingleton()}. + */ + public static class AsSingleton extends PTransform, PCollectionView> { + private final T defaultValue; + private final boolean hasDefault; + + private AsSingleton() { + this.defaultValue = null; + this.hasDefault = false; + } + + private AsSingleton(T defaultValue) { + this.defaultValue = defaultValue; + this.hasDefault = true; + } + + /** + * Returns whether this transform has a default value. + */ + public boolean hasDefaultValue() { + return hasDefault; + } + + /** + * Returns the default value of this transform, or null if there isn't one. + */ + public T defaultValue() { + return defaultValue; + } + + /** + * Default value to return for windows with no value in them. + */ + public AsSingleton withDefaultValue(T defaultValue) { + return new AsSingleton<>(defaultValue); + } + + @Override + public void validate(PCollection input) { + try { + GroupByKey.applicableTo(input); + } catch (IllegalStateException e) { + throw new IllegalStateException("Unable to create a side-input view from input", e); + } + } + + @Override + public PCollectionView apply(PCollection input) { + return input.apply(CreatePCollectionView.of(PCollectionViews.singletonView( + input.getPipeline(), + input.getWindowingStrategy(), + hasDefault, + defaultValue, + input.getCoder()))); + } + } + + /** + * Not intended for direct use by pipeline authors; public only so a {@link PipelineRunner} may + * override its behavior. + * + *

    See {@link View#asMultimap()}. + */ + public static class AsMultimap + extends PTransform>, PCollectionView>>> { + private AsMultimap() { } + + @Override + public void validate(PCollection> input) { + try { + GroupByKey.applicableTo(input); + } catch (IllegalStateException e) { + throw new IllegalStateException("Unable to create a side-input view from input", e); + } + } + + @Override + public PCollectionView>> apply(PCollection> input) { + return input.apply(CreatePCollectionView., Map>>of( + PCollectionViews.multimapView( + input.getPipeline(), + input.getWindowingStrategy(), + input.getCoder()))); + } + } + + /** + * Not intended for direct use by pipeline authors; public only so a {@link PipelineRunner} may + * override its behavior. + * + *

    See {@link View#asMap()}. + */ + public static class AsMap + extends PTransform>, PCollectionView>> { + private AsMap() { } + + /** + * @deprecated this method simply returns this AsMap unmodified + */ + @Deprecated() + public AsMap withSingletonValues() { + return this; + } + + @Override + public void validate(PCollection> input) { + try { + GroupByKey.applicableTo(input); + } catch (IllegalStateException e) { + throw new IllegalStateException("Unable to create a side-input view from input", e); + } + } + + @Override + public PCollectionView> apply(PCollection> input) { + return input.apply(CreatePCollectionView., Map>of( + PCollectionViews.mapView( + input.getPipeline(), + input.getWindowingStrategy(), + input.getCoder()))); + } + } + + //////////////////////////////////////////////////////////////////////////// + // Internal details below + + /** + * Creates a primitive {@link PCollectionView}. + * + *

    For internal use only by runner implementors. + * + * @param The type of the elements of the input PCollection + * @param The type associated with the {@link PCollectionView} used as a side input + */ + public static class CreatePCollectionView + extends PTransform, PCollectionView> { + private PCollectionView view; + + private CreatePCollectionView(PCollectionView view) { + this.view = view; + } + + public static CreatePCollectionView of( + PCollectionView view) { + return new CreatePCollectionView<>(view); + } + + public PCollectionView getView() { + return view; + } + + @Override + public PCollectionView apply(PCollection input) { + return view; + } + + static { + DirectPipelineRunner.registerDefaultTransformEvaluator( + CreatePCollectionView.class, + new DirectPipelineRunner.TransformEvaluator() { + @SuppressWarnings("rawtypes") + @Override + public void evaluate( + CreatePCollectionView transform, + DirectPipelineRunner.EvaluationContext context) { + evaluateTyped(transform, context); + } + + private void evaluateTyped( + CreatePCollectionView transform, + DirectPipelineRunner.EvaluationContext context) { + List> elems = + context.getPCollectionWindowedValues(context.getInput(transform)); + context.setPCollectionView(context.getOutput(transform), elems); + } + }); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithKeys.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithKeys.java new file mode 100644 index 000000000000..c06795c70384 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithKeys.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +/** + * {@code WithKeys} takes a {@code PCollection}, and either a + * constant key of type {@code K} or a function from {@code V} to + * {@code K}, and returns a {@code PCollection>}, where each + * of the values in the input {@code PCollection} has been paired with + * either the constant key or a key computed from the value. + * + *

    Example of use: + *

     {@code
+ * PCollection<String> words = ...;
+ * PCollection<KV<Integer, String>> lengthsToWords =
+ *     words.apply(WithKeys.of(new SerializableFunction<String, Integer>() {
    + *         public Integer apply(String s) { return s.length(); } }));
    + * } 
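Besides a key-extraction function, {@code WithKeys} also accepts a constant key (see {@code of(K key)} below). A minimal sketch, with the {@code "all"} key chosen purely for illustration:

    PCollection<KV<String, String>> keyedWords =
        words.apply(WithKeys.<String, String>of("all"));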
    + * + *

    Each output element has the same timestamp and is in the same windows + * as its corresponding input element, and the output {@code PCollection} + * has the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn} + * associated with it as the input. + * + * @param the type of the keys in the output {@code PCollection} + * @param the type of the elements in the input + * {@code PCollection} and the values in the output + * {@code PCollection} + */ +public class WithKeys extends PTransform, + PCollection>> { + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection>}, where each of the + * values in the input {@code PCollection} has been paired with a + * key computed from the value by invoking the given + * {@code SerializableFunction}. + * + *

    If using a lambda in Java 8, {@link #withKeyType(TypeDescriptor)} must + * be called on the result {@link PTransform}. + */ + public static WithKeys of(SerializableFunction fn) { + return new WithKeys<>(fn, null); + } + + /** + * Returns a {@code PTransform} that takes a {@code PCollection} + * and returns a {@code PCollection>}, where each of the + * values in the input {@code PCollection} has been paired with the + * given key. + */ + @SuppressWarnings("unchecked") + public static WithKeys of(final K key) { + return new WithKeys<>( + new SerializableFunction() { + @Override + public K apply(V value) { + return key; + } + }, + (Class) (key == null ? null : key.getClass())); + } + + + ///////////////////////////////////////////////////////////////////////////// + + private SerializableFunction fn; + private transient Class keyClass; + + private WithKeys(SerializableFunction fn, Class keyClass) { + this.fn = fn; + this.keyClass = keyClass; + } + + /** + * Return a {@link WithKeys} that is like this one with the specified key type descriptor. + * + * For use with lambdas in Java 8, either this method must be called with an appropriate type + * descriptor or {@link PCollection#setCoder(Coder)} must be called on the output + * {@link PCollection}. + */ + public WithKeys withKeyType(TypeDescriptor keyType) { + // Safe cast + @SuppressWarnings("unchecked") + Class rawType = (Class) keyType.getRawType(); + return new WithKeys<>(fn, rawType); + } + + @Override + public PCollection> apply(PCollection in) { + PCollection> result = + in.apply(ParDo.named("AddKeys") + .of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(fn.apply(c.element()), + c.element())); + } + })); + + try { + Coder keyCoder; + CoderRegistry coderRegistry = in.getPipeline().getCoderRegistry(); + if (keyClass == null) { + keyCoder = coderRegistry.getDefaultOutputCoder(fn, in.getCoder()); + } else { + keyCoder = coderRegistry.getDefaultCoder(TypeDescriptor.of(keyClass)); + } + // TODO: Remove when we can set the coder inference context. + result.setCoder(KvCoder.of(keyCoder, in.getCoder())); + } catch (CannotProvideCoderException exc) { + // let lazy coder inference have a try + } + + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithTimestamps.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithTimestamps.java new file mode 100644 index 000000000000..85a93bfe18d4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/WithTimestamps.java @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.io.Source; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** + * A {@link PTransform} for assigning timestamps to all the elements of a {@link PCollection}. + * + *

    Timestamps are used to assign {@link BoundedWindow Windows} to elements within the + * {@link Window#into(com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn)} + * {@link PTransform}. Assigning timestamps is useful when the input data set comes from a + * {@link Source} without implicit timestamps (such as + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read TextIO}). + * + */ +public class WithTimestamps extends PTransform, PCollection> { + /** + * For a {@link SerializableFunction} {@code fn} from {@code T} to {@link Instant}, outputs a + * {@link PTransform} that takes an input {@link PCollection PCollection<T>} and outputs a + * {@link PCollection PCollection<T>} containing every element {@code v} in the input where + * each element is output with a timestamp obtained as the result of {@code fn.apply(v)}. + * + *

    If the input {@link PCollection} elements have timestamps, the output timestamp for each + * element must not be before the input element's timestamp minus the value of + * {@link #getAllowedTimestampSkew()}. If an output timestamp is before this time, the transform + * will throw an {@link IllegalArgumentException} when executed. Use + * {@link #withAllowedTimestampSkew(Duration)} to update the allowed skew. + * + *

    Each output element will be in the same windows as the input element. If a new window based + * on the new output timestamp is desired, apply a new instance of {@link Window#into(WindowFn)}. + * + *

    This transform will fail at execution time with a {@link NullPointerException} if for any + * input element the result of {@code fn.apply(v)} is {@code null}. + * + *

    Example of use in Java 8: + *

    {@code
+   * PCollection<Record> timestampedRecords = records.apply(
+   *     WithTimestamps.of((Record rec) -> rec.getInstant()));
    +   * }
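The snippet above uses a Java 8 lambda; a hedged, anonymous-class equivalent that also compiles on Java 7 is sketched below ({@code Record} and {@code getInstant()} remain illustrative assumptions).

    PCollection<Record> timestampedRecords = records.apply(
        WithTimestamps.of(new SerializableFunction<Record, Instant>() {
          @Override
          public Instant apply(Record rec) {
            return rec.getInstant();
          }
        }));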
    + */ + public static WithTimestamps of(SerializableFunction fn) { + return new WithTimestamps<>(fn, Duration.ZERO); + } + + /////////////////////////////////////////////////////////////////// + + private final SerializableFunction fn; + private final Duration allowedTimestampSkew; + + private WithTimestamps(SerializableFunction fn, Duration allowedTimestampSkew) { + this.fn = checkNotNull(fn, "WithTimestamps fn cannot be null"); + this.allowedTimestampSkew = allowedTimestampSkew; + } + + /** + * Return a new WithTimestamps like this one with updated allowed timestamp skew, which is the + * maximum duration that timestamps can be shifted backward. Does not modify this object. + * + *

    The default value is {@code Duration.ZERO}, allowing timestamps to only be shifted into the + * future. For infinite skew, use {@code new Duration(Long.MAX_VALUE)}. + */ + public WithTimestamps withAllowedTimestampSkew(Duration allowedTimestampSkew) { + return new WithTimestamps<>(this.fn, allowedTimestampSkew); + } + + /** + * Returns the allowed timestamp skew duration, which is the maximum + * duration that timestamps can be shifted backwards from the timestamp of the input element. + * + * @see DoFn#getAllowedTimestampSkew() + */ + public Duration getAllowedTimestampSkew() { + return allowedTimestampSkew; + } + + @Override + public PCollection apply(PCollection input) { + return input + .apply(ParDo.named("AddTimestamps").of(new AddTimestampsDoFn(fn, allowedTimestampSkew))) + .setTypeDescriptorInternal(input.getTypeDescriptor()); + } + + private static class AddTimestampsDoFn extends DoFn { + private final SerializableFunction fn; + private final Duration allowedTimestampSkew; + + public AddTimestampsDoFn(SerializableFunction fn, Duration allowedTimestampSkew) { + this.fn = fn; + this.allowedTimestampSkew = allowedTimestampSkew; + } + + @Override + public void processElement(ProcessContext c) { + Instant timestamp = fn.apply(c.element()); + checkNotNull( + timestamp, "Timestamps for WithTimestamps cannot be null. Timestamp provided by %s.", fn); + c.outputWithTimestamp(c.element(), timestamp); + } + + @Override + public Duration getAllowedTimestampSkew() { + return allowedTimestampSkew; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Write.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Write.java new file mode 100644 index 000000000000..5cf655a3e563 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Write.java @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +/** + * A backwards-compatible {@code Write} class that simply inherits from the + * {@link com.google.cloud.dataflow.sdk.io.Write} class that should be used instead. + * + * @deprecated: use {@link com.google.cloud.dataflow.sdk.io.Write} from the + * {@code com.google.cloud.dataflow.sdk.io} package instead. + */ +@Deprecated +public class Write extends com.google.cloud.dataflow.sdk.io.Write { +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResult.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResult.java new file mode 100644 index 000000000000..aac57bc5fc1b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResult.java @@ -0,0 +1,463 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static com.google.cloud.dataflow.sdk.util.Structs.addObject; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * A row result of a {@link CoGroupByKey}. This is a tuple of {@link Iterable}s produced for + * a given key, and these can be accessed in different ways. + */ +public class CoGbkResult { + /** + * A map of integer union tags to a list of union objects. + * Note: the key and the embedded union tag are the same, so it is redundant + * to store it multiple times, but for now it makes encoding easier. + */ + private final List> valueMap; + + private final CoGbkResultSchema schema; + + private static final int DEFAULT_IN_MEMORY_ELEMENT_COUNT = 10_000; + + private static final Logger LOG = LoggerFactory.getLogger(CoGbkResult.class); + + /** + * A row in the {@link PCollection} resulting from a {@link CoGroupByKey} transform. + * Currently, this row must fit into memory. + * + * @param schema the set of tuple tags used to refer to input tables and + * result values + * @param taggedValues the raw results from a group-by-key + */ + public CoGbkResult( + CoGbkResultSchema schema, + Iterable taggedValues) { + this(schema, taggedValues, DEFAULT_IN_MEMORY_ELEMENT_COUNT); + } + + @SuppressWarnings("unchecked") + public CoGbkResult( + CoGbkResultSchema schema, + Iterable taggedValues, + int inMemoryElementCount) { + this.schema = schema; + valueMap = new ArrayList<>(); + for (int unionTag = 0; unionTag < schema.size(); unionTag++) { + valueMap.add(new ArrayList<>()); + } + + // Demultiplex the first imMemoryElementCount tagged union values + // according to their tag. + final Iterator taggedIter = taggedValues.iterator(); + int elementCount = 0; + while (taggedIter.hasNext()) { + if (elementCount++ >= inMemoryElementCount && taggedIter instanceof Reiterator) { + // Let the tails be lazy. + break; + } + RawUnionValue value = taggedIter.next(); + // Make sure the given union tag has a corresponding tuple tag in the + // schema. 
+ int unionTag = value.getUnionTag(); + if (schema.size() <= unionTag) { + throw new IllegalStateException("union tag " + unionTag + + " has no corresponding tuple tag in the result schema"); + } + List valueList = (List) valueMap.get(unionTag); + valueList.add(value.getValue()); + } + + if (taggedIter.hasNext()) { + // If we get here, there were more elements than we can afford to + // keep in memory, so we copy the re-iterable of remaining items + // and append filtered views to each of the sorted lists computed earlier. + LOG.info("CoGbkResult has more than " + inMemoryElementCount + " elements," + + " reiteration (which may be slow) is required."); + final Reiterator tail = (Reiterator) taggedIter; + // This is a trinary-state array recording whether a given tag is present in the tail. The + // initial value is null (unknown) for all tags, and the first iteration through the entire + // list will set these values to true or false to avoid needlessly iterating if filtering + // against a given tag would not match anything. + final Boolean[] containsTag = new Boolean[schema.size()]; + for (int unionTag = 0; unionTag < schema.size(); unionTag++) { + final int unionTag0 = unionTag; + updateUnionTag(tail, containsTag, unionTag, unionTag0); + } + } + } + + private void updateUnionTag( + final Reiterator tail, final Boolean[] containsTag, + int unionTag, final int unionTag0) { + @SuppressWarnings("unchecked") + final Iterable head = (Iterable) valueMap.get(unionTag); + valueMap.set( + unionTag, + new Iterable() { + @Override + public Iterator iterator() { + return Iterators.concat( + head.iterator(), + new UnionValueIterator(unionTag0, tail.copy(), containsTag)); + } + }); + } + + public boolean isEmpty() { + for (Iterable tagValues : valueMap) { + if (tagValues.iterator().hasNext()) { + return false; + } + } + return true; + } + + /** + * Returns the schema used by this {@link CoGbkResult}. + */ + public CoGbkResultSchema getSchema() { + return schema; + } + + @Override + public String toString() { + return valueMap.toString(); + } + + /** + * Returns the values from the table represented by the given + * {@code TupleTag} as an {@code Iterable} (which may be empty if there + * are no results). + * + *

    If tag was not part of the original {@link CoGroupByKey}, + * throws an IllegalArgumentException. + */ + public Iterable getAll(TupleTag tag) { + int index = schema.getIndex(tag); + if (index < 0) { + throw new IllegalArgumentException("TupleTag " + tag + + " is not in the schema"); + } + @SuppressWarnings("unchecked") + Iterable unions = (Iterable) valueMap.get(index); + return unions; + } + + /** + * If there is a singleton value for the given tag, returns it. + * Otherwise, throws an IllegalArgumentException. + * + *

    If tag was not part of the original {@link CoGroupByKey}, + * throws an IllegalArgumentException. + */ + public V getOnly(TupleTag tag) { + return innerGetOnly(tag, null, false); + } + + /** + * If there is a singleton value for the given tag, returns it. If there is + * no value for the given tag, returns the defaultValue. + * + *

    If tag was not part of the original {@link CoGroupByKey}, + * throws an IllegalArgumentException. + */ + public V getOnly(TupleTag tag, V defaultValue) { + return innerGetOnly(tag, defaultValue, true); + } + + /** + * A {@link Coder} for {@link CoGbkResult}s. + */ + public static class CoGbkResultCoder extends StandardCoder { + + private final CoGbkResultSchema schema; + private final UnionCoder unionCoder; + + /** + * Returns a {@link CoGbkResultCoder} for the given schema and {@link UnionCoder}. + */ + public static CoGbkResultCoder of( + CoGbkResultSchema schema, + UnionCoder unionCoder) { + return new CoGbkResultCoder(schema, unionCoder); + } + + @JsonCreator + public static CoGbkResultCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components, + @JsonProperty(PropertyNames.CO_GBK_RESULT_SCHEMA) CoGbkResultSchema schema) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return new CoGbkResultCoder(schema, (UnionCoder) components.get(0)); + } + + private CoGbkResultCoder( + CoGbkResultSchema tupleTags, + UnionCoder unionCoder) { + this.schema = tupleTags; + this.unionCoder = unionCoder; + } + + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return Arrays.>asList(unionCoder); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addObject(result, PropertyNames.CO_GBK_RESULT_SCHEMA, schema.asCloudObject()); + return result; + } + + @Override + @SuppressWarnings("unchecked") + public void encode( + CoGbkResult value, + OutputStream outStream, + Context context) throws CoderException, + IOException { + if (!schema.equals(value.getSchema())) { + throw new CoderException("input schema does not match coder schema"); + } + for (int unionTag = 0; unionTag < schema.size(); unionTag++) { + tagListCoder(unionTag).encode(value.valueMap.get(unionTag), outStream, Context.NESTED); + } + } + + @Override + public CoGbkResult decode( + InputStream inStream, + Context context) + throws CoderException, IOException { + List> valueMap = new ArrayList<>(); + for (int unionTag = 0; unionTag < schema.size(); unionTag++) { + valueMap.add(tagListCoder(unionTag).decode(inStream, Context.NESTED)); + } + return new CoGbkResult(schema, valueMap); + } + + @SuppressWarnings("rawtypes") + private IterableCoder tagListCoder(int unionTag) { + return IterableCoder.of(unionCoder.getComponents().get(unionTag)); + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + if (!(object instanceof CoGbkResultCoder)) { + return false; + } + CoGbkResultCoder other = (CoGbkResultCoder) object; + return schema.equals(other.schema) && unionCoder.equals(other.unionCoder); + } + + @Override + public int hashCode() { + return Objects.hashCode(schema); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "CoGbkResult requires the union coder to be deterministic", unionCoder); + } + } + + + ////////////////////////////////////////////////////////////////////////////// + // Methods for directly constructing a CoGbkResult + // + // (for example, creating test data for a transform that consumes a + // CoGbkResult) + + /** + * Returns a new CoGbkResult that contains just the given tag and given data. 
+ */ + public static CoGbkResult of(TupleTag tag, List data) { + return CoGbkResult.empty().and(tag, data); + } + + /** + * Returns a new {@link CoGbkResult} based on this, with the given tag and given data + * added to it. + */ + public CoGbkResult and(TupleTag tag, List data) { + if (nextTestUnionId != schema.size()) { + throw new IllegalArgumentException( + "Attempting to call and() on a CoGbkResult apparently not created by" + + " of()."); + } + List> valueMap = new ArrayList<>(this.valueMap); + valueMap.add(data); + return new CoGbkResult( + new CoGbkResultSchema(schema.getTupleTagList().and(tag)), valueMap, + nextTestUnionId + 1); + } + + /** + * Returns an empty {@link CoGbkResult}. + */ + public static CoGbkResult empty() { + return new CoGbkResult(new CoGbkResultSchema(TupleTagList.empty()), + new ArrayList>()); + } + + ////////////////////////////////////////////////////////////////////////////// + + private int nextTestUnionId = 0; + + private CoGbkResult( + CoGbkResultSchema schema, + List> valueMap, + int nextTestUnionId) { + this(schema, valueMap); + this.nextTestUnionId = nextTestUnionId; + } + + private CoGbkResult( + CoGbkResultSchema schema, + List> valueMap) { + this.schema = schema; + this.valueMap = valueMap; + } + + private V innerGetOnly( + TupleTag tag, + V defaultValue, + boolean useDefault) { + int index = schema.getIndex(tag); + if (index < 0) { + throw new IllegalArgumentException("TupleTag " + tag + + " is not in the schema"); + } + @SuppressWarnings("unchecked") + Iterator unions = (Iterator) valueMap.get(index).iterator(); + if (!unions.hasNext()) { + if (useDefault) { + return defaultValue; + } else { + throw new IllegalArgumentException("TupleTag " + tag + + " corresponds to an empty result, and no default was provided"); + } + } + V value = unions.next(); + if (unions.hasNext()) { + throw new IllegalArgumentException("TupleTag " + tag + + " corresponds to a non-singleton result"); + } + return value; + } + + /** + * Lazily filters and recasts an {@code Iterator} into an + * {@code Iterator}, where V is the type of the raw union value's contents. + */ + private static class UnionValueIterator implements Iterator { + + private final int tag; + private final PeekingIterator unions; + private final Boolean[] containsTag; + + private UnionValueIterator(int tag, Iterator unions, Boolean[] containsTag) { + this.tag = tag; + this.unions = Iterators.peekingIterator(unions); + this.containsTag = containsTag; + } + + @Override + public boolean hasNext() { + if (containsTag[tag] == Boolean.FALSE) { + return false; + } + advance(); + if (unions.hasNext()) { + return true; + } else { + // Now that we've iterated over all the values, we can resolve all the "unknown" null + // values to false. 
+ for (int i = 0; i < containsTag.length; i++) { + if (containsTag[i] == null) { + containsTag[i] = false; + } + } + return false; + } + } + + @Override + @SuppressWarnings("unchecked") + public V next() { + advance(); + return (V) unions.next().getValue(); + } + + private void advance() { + while (unions.hasNext()) { + int curTag = unions.peek().getUnionTag(); + containsTag[curTag] = true; + if (curTag == tag) { + break; + } + unions.next(); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultSchema.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultSchema.java new file mode 100644 index 000000000000..2860ba70e1e0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultSchema.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static com.google.cloud.dataflow.sdk.util.Structs.addList; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +/** + * A schema for the results of a {@link CoGroupByKey}. This maintains the full + * set of {@link TupleTag}s for the results of a {@link CoGroupByKey} and + * facilitates mapping between {@link TupleTag}s and + * {@link RawUnionValue} tags (which are used as secondary keys in the + * {@link CoGroupByKey}). + */ +public class CoGbkResultSchema implements Serializable { + + private final TupleTagList tupleTagList; + + @JsonCreator + public static CoGbkResultSchema of( + @JsonProperty(PropertyNames.TUPLE_TAGS) List> tags) { + TupleTagList tupleTags = TupleTagList.empty(); + for (TupleTag tag : tags) { + tupleTags = tupleTags.and(tag); + } + return new CoGbkResultSchema(tupleTags); + } + + /** + * Maps TupleTags to union tags. This avoids needing to encode the tags + * themselves. + */ + private final HashMap, Integer> tagMap = new HashMap<>(); + + /** + * Builds a schema from a tuple of {@code TupleTag}s. + */ + public CoGbkResultSchema(TupleTagList tupleTagList) { + this.tupleTagList = tupleTagList; + int index = -1; + for (TupleTag tag : tupleTagList.getAll()) { + index++; + tagMap.put(tag, index); + } + } + + /** + * Returns the index for the given tuple tag, if the tag is present in this + * schema, -1 if it isn't. + */ + public int getIndex(TupleTag tag) { + Integer index = tagMap.get(tag); + return index == null ? -1 : index; + } + + /** + * Returns the tuple tag at the given index. 
+ */ + public TupleTag getTag(int index) { + return tupleTagList.get(index); + } + + /** + * Returns the number of columns for this schema. + */ + public int size() { + return tupleTagList.getAll().size(); + } + + /** + * Returns the TupleTagList tuple associated with this schema. + */ + public TupleTagList getTupleTagList() { + return tupleTagList; + } + + public CloudObject asCloudObject() { + CloudObject result = CloudObject.forClass(getClass()); + List serializedTags = new ArrayList<>(tupleTagList.size()); + for (TupleTag tag : tupleTagList.getAll()) { + serializedTags.add(tag.asCloudObject()); + } + addList(result, PropertyNames.TUPLE_TAGS, serializedTags); + return result; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof CoGbkResultSchema)) { + return false; + } + CoGbkResultSchema other = (CoGbkResultSchema) obj; + return tupleTagList.getAll().equals(other.tupleTagList.getAll()); + } + + @Override + public int hashCode() { + return tupleTagList.getAll().hashCode(); + } + + @Override + public String toString() { + return "CoGbkResultSchema: " + tupleTagList.getAll(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKey.java new file mode 100644 index 000000000000..b84068295a23 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKey.java @@ -0,0 +1,211 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult.CoGbkResultCoder; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple.TaggedKeyedPCollection; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import java.util.ArrayList; +import java.util.List; + +/** + * A {@link PTransform} that performs a {@link CoGroupByKey} on a tuple + * of tables. A {@link CoGroupByKey} groups results from all + * tables by like keys into {@link CoGbkResult}s, + * from which the results for any specific table can be accessed by the + * {@link com.google.cloud.dataflow.sdk.values.TupleTag} + * supplied with the initial table. + * + *
<p>Example of performing a {@link CoGroupByKey} followed by a + * {@link ParDo} that consumes + * the results: + * <pre> {@code
+ * PCollection<KV<K, V1>> pt1 = ...;
+ * PCollection<KV<K, V2>> pt2 = ...;
+ *
+ * final TupleTag<V1> t1 = new TupleTag<>();
+ * final TupleTag<V2> t2 = new TupleTag<>();
+ * PCollection<KV<K, CoGbkResult>> coGbkResultCollection =
+ *   KeyedPCollectionTuple.of(t1, pt1)
+ *                        .and(t2, pt2)
+ *                        .apply(CoGroupByKey.<K>create());
+ *
+ * PCollection<T> finalResultCollection =
+ *   coGbkResultCollection.apply(ParDo.of(
+ *     new DoFn<KV<K, CoGbkResult>, T>() {
+ *       @Override
+ *       public void processElement(ProcessContext c) {
+ *         KV<K, CoGbkResult> e = c.element();
+ *         Iterable<V1> pt1Vals = e.getValue().getAll(t1);
+ *         V2 pt2Val = e.getValue().getOnly(t2);
+ *          ... Do Something ....
+ *         c.output(...some T...);
+ *       }
+ *     }));
+ * } </pre>
    + * + * @param the type of the keys in the input and output + * {@code PCollection}s + */ +public class CoGroupByKey extends + PTransform, + PCollection>> { + /** + * Returns a {@code CoGroupByKey} {@code PTransform}. + * + * @param the type of the keys in the input and output + * {@code PCollection}s + */ + public static CoGroupByKey create() { + return new CoGroupByKey<>(); + } + + private CoGroupByKey() { } + + @Override + public PCollection> apply( + KeyedPCollectionTuple input) { + if (input.isEmpty()) { + throw new IllegalArgumentException( + "must have at least one input to a KeyedPCollections"); + } + + // First build the union coder. + // TODO: Look at better integration of union types with the + // schema specified in the input. + List> codersList = new ArrayList<>(); + for (TaggedKeyedPCollection entry : input.getKeyedCollections()) { + codersList.add(getValueCoder(entry.pCollection)); + } + UnionCoder unionCoder = UnionCoder.of(codersList); + Coder keyCoder = input.getKeyCoder(); + KvCoder kVCoder = + KvCoder.of(keyCoder, unionCoder); + + PCollectionList> unionTables = + PCollectionList.empty(input.getPipeline()); + + // TODO: Use the schema to order the indices rather than depending + // on the fact that the schema ordering is identical to the ordering from + // input.getJoinCollections(). + int index = -1; + for (TaggedKeyedPCollection entry : input.getKeyedCollections()) { + index++; + PCollection> unionTable = + makeUnionTable(index, entry.pCollection, kVCoder); + unionTables = unionTables.and(unionTable); + } + + PCollection> flattenedTable = + unionTables.apply(Flatten.>pCollections()); + + PCollection>> groupedTable = + flattenedTable.apply(GroupByKey.create()); + + CoGbkResultSchema tupleTags = input.getCoGbkResultSchema(); + PCollection> result = groupedTable.apply( + ParDo.of(new ConstructCoGbkResultFn(tupleTags)) + .named("ConstructCoGbkResultFn")); + result.setCoder(KvCoder.of(keyCoder, + CoGbkResultCoder.of(tupleTags, unionCoder))); + + return result; + } + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Returns the value coder for the given PCollection. Assumes that the value + * coder is an instance of {@code KvCoder}. + */ + private Coder getValueCoder(PCollection> pCollection) { + // Assumes that the PCollection uses a KvCoder. + Coder entryCoder = pCollection.getCoder(); + if (!(entryCoder instanceof KvCoder)) { + throw new IllegalArgumentException("PCollection does not use a KvCoder"); + } + @SuppressWarnings("unchecked") + KvCoder coder = (KvCoder) entryCoder; + return coder.getValueCoder(); + } + + /** + * Returns a UnionTable for the given input PCollection, using the given + * union index and the given unionTableEncoder. + */ + private PCollection> makeUnionTable( + final int index, + PCollection> pCollection, + KvCoder unionTableEncoder) { + + return pCollection.apply(ParDo.of( + new ConstructUnionTableFn(index)).named("MakeUnionTable" + index)) + .setCoder(unionTableEncoder); + } + + /** + * A DoFn to construct a UnionTable (i.e., a + * {@code PCollection>} from a + * {@code PCollection>}. + */ + private static class ConstructUnionTableFn extends + DoFn, KV> { + + private final int index; + + public ConstructUnionTableFn(int index) { + this.index = index; + } + + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + c.output(KV.of(e.getKey(), new RawUnionValue(index, e.getValue()))); + } + } + + /** + * A DoFn to construct a CoGbkResult from an input grouped union + * table. 
+ */ + private static class ConstructCoGbkResultFn + extends DoFn>, + KV> { + + private final CoGbkResultSchema schema; + + public ConstructCoGbkResultFn(CoGbkResultSchema schema) { + this.schema = schema; + } + + @Override + public void processElement(ProcessContext c) { + KV> e = c.element(); + c.output(KV.of(e.getKey(), new CoGbkResult(schema, e.getValue()))); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/KeyedPCollectionTuple.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/KeyedPCollectionTuple.java new file mode 100644 index 000000000000..abfbe08b2859 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/KeyedPCollectionTuple.java @@ -0,0 +1,247 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * An immutable tuple of keyed {@link PCollection PCollections} + * with key type K. + * ({@link PCollection PCollections} containing values of type + * {@code KV}) + * + * @param the type of key shared by all constituent PCollections + */ +public class KeyedPCollectionTuple implements PInput { + /** + * Returns an empty {@code KeyedPCollectionTuple} on the given pipeline. + */ + public static KeyedPCollectionTuple empty(Pipeline pipeline) { + return new KeyedPCollectionTuple<>(pipeline); + } + + /** + * Returns a new {@code KeyedPCollectionTuple} with the given tag and initial + * PCollection. + */ + public static KeyedPCollectionTuple of( + TupleTag tag, + PCollection> pc) { + return new KeyedPCollectionTuple(pc.getPipeline()).and(tag, pc); + } + + /** + * Returns a new {@code KeyedPCollectionTuple} that is the same as this, + * appended with the given PCollection. + */ + public KeyedPCollectionTuple and( + TupleTag< V> tag, + PCollection> pc) { + if (pc.getPipeline() != getPipeline()) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + TaggedKeyedPCollection wrapper = + new TaggedKeyedPCollection<>(tag, pc); + Coder myKeyCoder = keyCoder == null ? 
getKeyCoder(pc) : keyCoder; + List> + newKeyedCollections = + copyAddLast( + keyedCollections, + wrapper); + return new KeyedPCollectionTuple<>( + getPipeline(), + newKeyedCollections, + schema.getTupleTagList().and(tag), + myKeyCoder); + } + + public boolean isEmpty() { + return keyedCollections.isEmpty(); + } + + /** + * Returns a list of {@link TaggedKeyedPCollection TaggedKeyedPCollections} for the + * {@link PCollection PCollections} contained in this {@link KeyedPCollectionTuple}. + */ + public List> getKeyedCollections() { + return keyedCollections; + } + + /** + * Like {@link #apply(String, PTransform)} but defaulting to the name + * provided by the {@link PTransform}. + */ + public OutputT apply( + PTransform, OutputT> transform) { + return Pipeline.applyTransform(this, transform); + } + + /** + * Applies the given {@link PTransform} to this input {@code KeyedPCollectionTuple} and returns + * its {@code OutputT}. This uses {@code name} to identify the specific application of + * the transform. This name is used in various places, including the monitoring UI, + * logging, and to stably identify this application node in the job graph. + */ + public OutputT apply( + String name, PTransform, OutputT> transform) { + return Pipeline.applyTransform(name, this, transform); + } + + /** + * Expands the component {@link PCollection PCollections}, stripping off + * any tag-specific information. + */ + @Override + public Collection expand() { + List> retval = new ArrayList<>(); + for (TaggedKeyedPCollection taggedPCollection : keyedCollections) { + retval.add(taggedPCollection.pCollection); + } + return retval; + } + + /** + * Returns the key {@link Coder} for all {@link PCollection PCollections} + * in this {@link KeyedPCollectionTuple}. + */ + public Coder getKeyCoder() { + if (keyCoder == null) { + throw new IllegalStateException("cannot return null keyCoder"); + } + return keyCoder; + } + + /** + * Returns the {@link CoGbkResultSchema} associated with this + * {@link KeyedPCollectionTuple}. + */ + public CoGbkResultSchema getCoGbkResultSchema() { + return schema; + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public void finishSpecifying() { + for (TaggedKeyedPCollection taggedPCollection : keyedCollections) { + taggedPCollection.pCollection.finishSpecifying(); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A utility class to help ensure coherence of tag and input PCollection + * types. + */ + public static class TaggedKeyedPCollection { + + final TupleTag tupleTag; + final PCollection> pCollection; + + public TaggedKeyedPCollection( + TupleTag tupleTag, + PCollection> pCollection) { + this.tupleTag = tupleTag; + this.pCollection = pCollection; + } + + /** + * Returns the underlying PCollection of this TaggedKeyedPCollection. + */ + public PCollection> getCollection() { + return pCollection; + } + + /** + * Returns the TupleTag of this TaggedKeyedPCollection. + */ + public TupleTag getTupleTag() { + return tupleTag; + } + } + + /** + * We use a List to properly track the order in which collections are added. 
+ */ + private final List> keyedCollections; + + private final Coder keyCoder; + + private final CoGbkResultSchema schema; + + private final Pipeline pipeline; + + KeyedPCollectionTuple(Pipeline pipeline) { + this(pipeline, + new ArrayList>(), + TupleTagList.empty(), + null); + } + + KeyedPCollectionTuple( + Pipeline pipeline, + List> keyedCollections, + TupleTagList tupleTagList, + Coder keyCoder) { + this.pipeline = pipeline; + this.keyedCollections = keyedCollections; + this.schema = new CoGbkResultSchema(tupleTagList); + this.keyCoder = keyCoder; + } + + private static Coder getKeyCoder(PCollection> pc) { + // Need to run coder inference on this PCollection before inspecting it. + pc.finishSpecifying(); + + // Assumes that the PCollection uses a KvCoder. + Coder entryCoder = pc.getCoder(); + if (!(entryCoder instanceof KvCoder)) { + throw new IllegalArgumentException("PCollection does not use a KvCoder"); + } + @SuppressWarnings("unchecked") + KvCoder coder = (KvCoder) entryCoder; + return coder.getKeyCoder(); + } + + private static List> copyAddLast( + List> keyedCollections, + TaggedKeyedPCollection taggedCollection) { + List> retval = + new ArrayList<>(keyedCollections); + retval.add(taggedCollection); + return retval; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/RawUnionValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/RawUnionValue.java new file mode 100644 index 000000000000..514853e44643 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/RawUnionValue.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +// TODO: Think about making this a complete dynamic union by adding +// a schema. Type would then be defined by the corresponding schema entry. + +/** + * This corresponds to an integer union tag and value. The mapping of + * union tag to type must come from elsewhere. + */ +public class RawUnionValue { + private final int unionTag; + private final Object value; + + /** + * Constructs a partial union from the given union tag and value. + */ + public RawUnionValue(int unionTag, Object value) { + this.unionTag = unionTag; + this.value = value; + } + + public int getUnionTag() { + return unionTag; + } + + public Object getValue() { + return value; + } + + @Override + public String toString() { + return unionTag + ":" + value; + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoder.java new file mode 100644 index 000000000000..2f1c2befd42f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoder.java @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * A UnionCoder encodes RawUnionValues. + */ +class UnionCoder extends StandardCoder { + // TODO: Think about how to integrate this with a schema object (i.e. + // a tuple of tuple tags). + /** + * Builds a union coder with the given list of element coders. This list + * corresponds to a mapping of union tag to Coder. Union tags start at 0. + */ + public static UnionCoder of(List> elementCoders) { + return new UnionCoder(elementCoders); + } + + @JsonCreator + public static UnionCoder jsonOf( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> elements) { + return UnionCoder.of(elements); + } + + private int getIndexForEncoding(RawUnionValue union) { + if (union == null) { + throw new IllegalArgumentException("cannot encode a null tagged union"); + } + int index = union.getUnionTag(); + if (index < 0 || index >= elementCoders.size()) { + throw new IllegalArgumentException( + "union value index " + index + " not in range [0.." + + (elementCoders.size() - 1) + "]"); + } + return index; + } + + @SuppressWarnings("unchecked") + @Override + public void encode( + RawUnionValue union, + OutputStream outStream, + Context context) + throws IOException, CoderException { + int index = getIndexForEncoding(union); + // Write out the union tag. + VarInt.encode(index, outStream); + + // Write out the actual value. + Coder coder = (Coder) elementCoders.get(index); + coder.encode( + union.getValue(), + outStream, + context); + } + + @Override + public RawUnionValue decode(InputStream inStream, Context context) + throws IOException, CoderException { + int index = VarInt.decodeInt(inStream); + Object value = elementCoders.get(index).decode(inStream, context); + return new RawUnionValue(index, value); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return elementCoders; + } + + /** + * Since this coder uses elementCoders.get(index) and coders that are known to run in constant + * time, we defer the return value to that coder. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(RawUnionValue union, Context context) { + int index = getIndexForEncoding(union); + @SuppressWarnings("unchecked") + Coder coder = (Coder) elementCoders.get(index); + return coder.isRegisterByteSizeObserverCheap(union.getValue(), context); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder. 
+ */ + @Override + public void registerByteSizeObserver( + RawUnionValue union, ElementByteSizeObserver observer, Context context) + throws Exception { + int index = getIndexForEncoding(union); + // Write out the union tag. + observer.update(VarInt.getLength(index)); + // Write out the actual value. + @SuppressWarnings("unchecked") + Coder coder = (Coder) elementCoders.get(index); + coder.registerByteSizeObserver(union.getValue(), observer, context); + } + + ///////////////////////////////////////////////////////////////////////////// + + private final List> elementCoders; + + private UnionCoder(List> elementCoders) { + this.elementCoders = elementCoders; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "UnionCoder is only deterministic if all element coders are", + elementCoders); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/package-info.java new file mode 100644 index 000000000000..be8bffad29a1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/join/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines the {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey} transform + * for joining multiple PCollections. + */ +package com.google.cloud.dataflow.sdk.transforms.join; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/package-info.java new file mode 100644 index 000000000000..3c041f6736e5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/package-info.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s for transforming + * data in a pipeline. + * + *
<p>A {@link com.google.cloud.dataflow.sdk.transforms.PTransform} is an operation that takes an + * {@code InputT} (some subtype of {@link com.google.cloud.dataflow.sdk.values.PInput}) + * and produces an + * {@code OutputT} (some subtype of {@link com.google.cloud.dataflow.sdk.values.POutput}). + * + *
<p>Common PTransforms include root PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Read} and + * {@link com.google.cloud.dataflow.sdk.transforms.Create}, processing and + * conversion operations like {@link com.google.cloud.dataflow.sdk.transforms.ParDo}, + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}, + * {@link com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey}, + * {@link com.google.cloud.dataflow.sdk.transforms.Combine}, and + * {@link com.google.cloud.dataflow.sdk.transforms.Count}, and outputting + * PTransforms like + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Write}. + * + *
<p>
    New PTransforms can be created by composing existing PTransforms. + * Most PTransforms in this package are composites, and users can also create composite PTransforms + * for their own application-specific logic. + * + */ +package com.google.cloud.dataflow.sdk.transforms; + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterAll.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterAll.java new file mode 100644 index 000000000000..bb43010ae512 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterAll.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; +import com.google.common.base.Preconditions; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.List; + +/** + * Create a {@link Trigger} that fires and finishes once after all of its sub-triggers have fired. + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code Trigger} + */ +@Experimental(Experimental.Kind.TRIGGER) +public class AfterAll extends OnceTrigger { + + private AfterAll(List> subTriggers) { + super(subTriggers); + Preconditions.checkArgument(subTriggers.size() > 1); + } + + /** + * Returns an {@code AfterAll} {@code Trigger} with the given subtriggers. + */ + @SafeVarargs + public static OnceTrigger of( + OnceTrigger... triggers) { + return new AfterAll(Arrays.>asList(triggers)); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + for (ExecutableTrigger subTrigger : c.trigger().unfinishedSubTriggers()) { + // Since subTriggers are all OnceTriggers, they must either CONTINUE or FIRE_AND_FINISH. + // invokeElement will automatically mark the finish bit if they return FIRE_AND_FINISH. + subTrigger.invokeOnElement(c); + } + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + subTrigger.invokeOnMerge(c); + } + boolean allFinished = true; + for (ExecutableTrigger subTrigger1 : c.trigger().subTriggers()) { + allFinished &= c.forTrigger(subTrigger1).trigger().isFinished(); + } + c.trigger().setFinished(allFinished); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + // This trigger will fire after the latest of its sub-triggers. 
+ Instant deadline = BoundedWindow.TIMESTAMP_MIN_VALUE; + for (Trigger subTrigger : subTriggers) { + Instant subDeadline = subTrigger.getWatermarkThatGuaranteesFiring(window); + if (deadline.isBefore(subDeadline)) { + deadline = subDeadline; + } + } + return deadline; + } + + @Override + public OnceTrigger getContinuationTrigger(List> continuationTriggers) { + return new AfterAll(continuationTriggers); + } + + /** + * {@inheritDoc} + * + * @return {@code true} if all subtriggers return {@code true}. + */ + @Override + public boolean shouldFire(TriggerContext context) throws Exception { + for (ExecutableTrigger subtrigger : context.trigger().subTriggers()) { + if (!context.forTrigger(subtrigger).trigger().isFinished() + && !subtrigger.invokeShouldFire(context)) { + return false; + } + } + return true; + } + + /** + * Invokes {@link #onFire} for all subtriggers, eliding redundant calls to {@link #shouldFire} + * because they all must be ready to fire. + */ + @Override + public void onOnlyFiring(TriggerContext context) throws Exception { + for (ExecutableTrigger subtrigger : context.trigger().subTriggers()) { + subtrigger.invokeOnFire(context); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterDelayFromFirstElement.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterDelayFromFirstElement.java new file mode 100644 index 000000000000..71968e919cee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterDelayFromFirstElement.java @@ -0,0 +1,322 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.state.AccumulatorCombiningState; +import com.google.cloud.dataflow.sdk.util.state.CombiningState; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateMerging; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.List; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * A base class for triggers that happen after a processing time delay from the arrival + * of the first element in a pane. + * + *
    This class is for internal use only and may change at any time. + */ +@Experimental(Experimental.Kind.TRIGGER) +public abstract class AfterDelayFromFirstElement extends OnceTrigger { + + protected static final List> IDENTITY = + ImmutableList.>of(); + + protected static final StateTag, Instant>> DELAYED_UNTIL_TAG = + StateTags.makeSystemTagInternal(StateTags.combiningValueFromInputInternal( + "delayed", InstantCoder.of(), Min.MinFn.naturalOrder())); + + /** + * To complete an implementation, return the desired time from the TriggerContext. + */ + @Nullable + public abstract Instant getCurrentTime(Trigger.TriggerContext context); + + /** + * To complete an implementation, return a new instance like this one, but incorporating + * the provided timestamp mapping functions. Generally should be used by calling the + * constructor of this class from the constructor of the subclass. + */ + protected abstract AfterDelayFromFirstElement newWith( + List> transform); + + /** + * A list of timestampMappers m1, m2, m3, ... m_n considered to be composed in sequence. The + * overall mapping for an instance `instance` is `m_n(... m3(m2(m1(instant))`, + * implemented via #computeTargetTimestamp + */ + protected final List> timestampMappers; + + private final TimeDomain timeDomain; + + public AfterDelayFromFirstElement( + TimeDomain timeDomain, + List> timestampMappers) { + super(null); + this.timestampMappers = timestampMappers; + this.timeDomain = timeDomain; + } + + private Instant getTargetTimestamp(OnElementContext c) { + return computeTargetTimestamp(c.currentProcessingTime()); + } + + /** + * Aligns timestamps to the smallest multiple of {@code size} since the {@code offset} greater + * than the timestamp. + * + *
    TODO: Consider sharing this with FixedWindows, and bring over the equivalent of + * CalendarWindows. + */ + public AfterDelayFromFirstElement alignedTo(final Duration size, final Instant offset) { + return newWith(new AlignFn(size, offset)); + } + + /** + * Aligns the time to be the smallest multiple of {@code size} greater than the timestamp + * since the epoch. + */ + public AfterDelayFromFirstElement alignedTo(final Duration size) { + return alignedTo(size, new Instant(0)); + } + + /** + * Adds some delay to the original target time. + * + * @param delay the delay to add + * @return An updated time trigger that will wait the additional time before firing. + */ + public AfterDelayFromFirstElement plusDelayOf(final Duration delay) { + return newWith(new DelayFn(delay)); + } + + /** + * @deprecated This will be removed in the next major version. Please use only + * {@link #plusDelayOf} and {@link #alignedTo}. + */ + @Deprecated + public OnceTrigger mappedTo(SerializableFunction timestampMapper) { + return newWith(timestampMapper); + } + + @Override + public boolean isCompatible(Trigger other) { + if (!getClass().equals(other.getClass())) { + return false; + } + + AfterDelayFromFirstElement that = (AfterDelayFromFirstElement) other; + return this.timestampMappers.equals(that.timestampMappers); + } + + + private AfterDelayFromFirstElement newWith( + SerializableFunction timestampMapper) { + return newWith( + ImmutableList.>builder() + .addAll(timestampMappers) + .add(timestampMapper) + .build()); + } + + @Override + public void prefetchOnElement(StateAccessor state) { + state.access(DELAYED_UNTIL_TAG).readLater(); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + CombiningState delayUntilState = c.state().access(DELAYED_UNTIL_TAG); + Instant oldDelayUntil = delayUntilState.read(); + + // Since processing time can only advance, resulting in target wake-up times we would + // ignore anyhow, we don't bother with it if it is already set. + if (oldDelayUntil != null) { + return; + } + + Instant targetTimestamp = getTargetTimestamp(c); + delayUntilState.add(targetTimestamp); + c.setTimer(targetTimestamp, timeDomain); + } + + @Override + public void prefetchOnMerge(MergingStateAccessor state) { + super.prefetchOnMerge(state); + StateMerging.prefetchCombiningValues(state, DELAYED_UNTIL_TAG); + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + // NOTE: We could try to delete all timers which are still active, but we would + // need access to a timer context for each merging window. + // for (CombiningValueStateInternal, Instant> state : + // c.state().accessInEachMergingWindow(DELAYED_UNTIL_TAG).values()) { + // Instant timestamp = state.get().read(); + // if (timestamp != null) { + // .deleteTimer(timestamp, timeDomain); + // } + // } + // Instead let them fire and be ignored. + + // If the trigger is already finished, there is no way it will become re-activated + if (c.trigger().isFinished()) { + StateMerging.clear(c.state(), DELAYED_UNTIL_TAG); + // NOTE: We do not attempt to delete the timers. + return; + } + + // Determine the earliest point across all the windows, and delay to that. 
+ StateMerging.mergeCombiningValues(c.state(), DELAYED_UNTIL_TAG); + + Instant earliestTargetTime = c.state().access(DELAYED_UNTIL_TAG).read(); + if (earliestTargetTime != null) { + c.setTimer(earliestTargetTime, timeDomain); + } + } + + @Override + public void prefetchShouldFire(StateAccessor state) { + state.access(DELAYED_UNTIL_TAG).readLater(); + } + + @Override + public void clear(TriggerContext c) throws Exception { + c.state().access(DELAYED_UNTIL_TAG).clear(); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + Instant delayedUntil = context.state().access(DELAYED_UNTIL_TAG).read(); + return delayedUntil != null + && getCurrentTime(context) != null + && getCurrentTime(context).isAfter(delayedUntil); + } + + @Override + protected void onOnlyFiring(Trigger.TriggerContext context) throws Exception { + clear(context); + } + + protected Instant computeTargetTimestamp(Instant time) { + Instant result = time; + for (SerializableFunction timestampMapper : timestampMappers) { + result = timestampMapper.apply(result); + } + return result; + } + + /** + * A {@link SerializableFunction} to delay the timestamp at which this triggers fires. + */ + private static final class DelayFn implements SerializableFunction { + private final Duration delay; + + public DelayFn(Duration delay) { + this.delay = delay; + } + + @Override + public Instant apply(Instant input) { + return input.plus(delay); + } + + @Override + public boolean equals(Object object) { + if (object == this) { + return true; + } + + if (!(object instanceof DelayFn)) { + return false; + } + + return this.delay.equals(((DelayFn) object).delay); + } + + @Override + public int hashCode() { + return Objects.hash(delay); + } + } + + /** + * A {@link SerializableFunction} to align an instant to the nearest interval boundary. + */ + static final class AlignFn implements SerializableFunction { + private final Duration size; + private final Instant offset; + + + /** + * Aligns timestamps to the smallest multiple of {@code size} since the {@code offset} greater + * than the timestamp. + */ + public AlignFn(Duration size, Instant offset) { + this.size = size; + this.offset = offset; + } + + @Override + public Instant apply(Instant point) { + long millisSinceStart = new Duration(offset, point).getMillis() % size.getMillis(); + return millisSinceStart == 0 ? point : point.plus(size).minus(millisSinceStart); + } + + @Override + public boolean equals(Object object) { + if (object == this) { + return true; + } + + if (!(object instanceof AlignFn)) { + return false; + } + + AlignFn other = (AlignFn) object; + return other.size.equals(this.size) + && other.offset.equals(this.offset); + } + + @Override + public int hashCode() { + return Objects.hash(size, offset); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterEach.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterEach.java new file mode 100644 index 000000000000..4b052faeb8c6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterEach.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.List; + +/** + * A composite {@link Trigger} that executes its sub-triggers in order. + * Only one sub-trigger is executing at a time, + * and any time it fires the {@code AfterEach} fires. When the currently executing + * sub-trigger finishes, the {@code AfterEach} starts executing the next sub-trigger. + * + *
<p>{@code AfterEach.inOrder(t1, t2, ...)} finishes when all of the sub-triggers have finished. + * + *
<p>The following properties hold: + * <ul>
+ *   <li> {@code AfterEach.inOrder(AfterEach.inOrder(a, b), c)} behaves the same as + *   {@code AfterEach.inOrder(a, b, c)} and {@code AfterEach.inOrder(a, AfterEach.inOrder(b, c))}.
+ *   <li> {@code AfterEach.inOrder(Repeatedly.forever(a), b)} behaves the same as + *   {@code Repeatedly.forever(a)}, since the repeated trigger never finishes.
+ * </ul>
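+ *
+ * <p>A minimal illustrative sketch of such a composition (the particular sub-triggers and the + * element count are arbitrary example choices): first fire once at least 100 elements have + * arrived in the pane, then fire once more when the watermark passes the end of the window: + * <pre> {@code
+ * AfterEach.inOrder(
+ *     AfterPane.elementCountAtLeast(100),
+ *     AfterWatermark.pastEndOfWindow())
+ * } </pre>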
    + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code Trigger} + */ +@Experimental(Experimental.Kind.TRIGGER) +public class AfterEach extends Trigger { + + private AfterEach(List> subTriggers) { + super(subTriggers); + checkArgument(subTriggers.size() > 1); + } + + /** + * Returns an {@code AfterEach} {@code Trigger} with the given subtriggers. + */ + @SafeVarargs + public static Trigger inOrder(Trigger... triggers) { + return new AfterEach(Arrays.>asList(triggers)); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + if (!c.trigger().isMerging()) { + // If merges are not possible, we need only run the first unfinished subtrigger + c.trigger().firstUnfinishedSubTrigger().invokeOnElement(c); + } else { + // If merges are possible, we need to run all subtriggers in parallel + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + // Even if the subTrigger is done, it may be revived via merging and must have + // adequate state. + subTrigger.invokeOnElement(c); + } + } + } + + @Override + public void onMerge(OnMergeContext context) throws Exception { + // If merging makes a subtrigger no-longer-finished, it will automatically + // begin participating in shouldFire and onFire appropriately. + + // All the following triggers are retroactively "not started" but that is + // also automatic because they are cleared whenever this trigger + // fires. + boolean priorTriggersAllFinished = true; + for (ExecutableTrigger subTrigger : context.trigger().subTriggers()) { + if (priorTriggersAllFinished) { + subTrigger.invokeOnMerge(context); + priorTriggersAllFinished &= context.forTrigger(subTrigger).trigger().isFinished(); + } else { + subTrigger.invokeClear(context); + } + } + updateFinishedState(context); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + // This trigger will fire at least once when the first trigger in the sequence + // fires at least once. + return subTriggers.get(0).getWatermarkThatGuaranteesFiring(window); + } + + @Override + public Trigger getContinuationTrigger(List> continuationTriggers) { + return Repeatedly.forever(new AfterFirst(continuationTriggers)); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + ExecutableTrigger firstUnfinished = context.trigger().firstUnfinishedSubTrigger(); + return firstUnfinished.invokeShouldFire(context); + } + + @Override + public void onFire(Trigger.TriggerContext context) throws Exception { + context.trigger().firstUnfinishedSubTrigger().invokeOnFire(context); + + // Reset all subtriggers if in a merging context; any may be revived by merging so they are + // all run in parallel for each pending pane. + if (context.trigger().isMerging()) { + for (ExecutableTrigger subTrigger : context.trigger().subTriggers()) { + subTrigger.invokeClear(context); + } + } + + updateFinishedState(context); + } + + private void updateFinishedState(TriggerContext context) { + context.trigger().setFinished(context.trigger().firstUnfinishedSubTrigger() == null); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterFirst.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterFirst.java new file mode 100644 index 000000000000..29b19bf9b9c2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterFirst.java @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; +import com.google.common.base.Preconditions; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.List; + +/** + * Create a composite {@link Trigger} that fires once after at least one of its sub-triggers have + * fired. + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code Trigger} + */ +@Experimental(Experimental.Kind.TRIGGER) +public class AfterFirst extends OnceTrigger { + + AfterFirst(List> subTriggers) { + super(subTriggers); + Preconditions.checkArgument(subTriggers.size() > 1); + } + + /** + * Returns an {@code AfterFirst} {@code Trigger} with the given subtriggers. + */ + @SafeVarargs + public static OnceTrigger of( + OnceTrigger... triggers) { + return new AfterFirst(Arrays.>asList(triggers)); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + subTrigger.invokeOnElement(c); + } + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + subTrigger.invokeOnMerge(c); + } + updateFinishedStatus(c); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + // This trigger will fire after the earliest of its sub-triggers. + Instant deadline = BoundedWindow.TIMESTAMP_MAX_VALUE; + for (Trigger subTrigger : subTriggers) { + Instant subDeadline = subTrigger.getWatermarkThatGuaranteesFiring(window); + if (deadline.isAfter(subDeadline)) { + deadline = subDeadline; + } + } + return deadline; + } + + @Override + public OnceTrigger getContinuationTrigger(List> continuationTriggers) { + return new AfterFirst(continuationTriggers); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + for (ExecutableTrigger subtrigger : context.trigger().subTriggers()) { + if (context.forTrigger(subtrigger).trigger().isFinished() + || subtrigger.invokeShouldFire(context)) { + return true; + } + } + return false; + } + + @Override + protected void onOnlyFiring(TriggerContext context) throws Exception { + for (ExecutableTrigger subtrigger : context.trigger().subTriggers()) { + TriggerContext subContext = context.forTrigger(subtrigger); + if (subtrigger.invokeShouldFire(subContext)) { + // If the trigger is ready to fire, then do whatever it needs to do. + subtrigger.invokeOnFire(subContext); + } else { + // If the trigger is not ready to fire, it is nonetheless true that whatever + // pending pane it was tracking is now gone. 
+ subtrigger.invokeClear(subContext); + } + } + } + + private void updateFinishedStatus(TriggerContext c) { + boolean anyFinished = false; + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + anyFinished |= c.forTrigger(subTrigger).trigger().isFinished(); + } + c.trigger().setFinished(anyFinished); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterPane.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterPane.java new file mode 100644 index 000000000000..28c8560ac4ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterPane.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.state.AccumulatorCombiningState; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateMerging; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; + +import org.joda.time.Instant; + +import java.util.List; +import java.util.Objects; + +/** + * {@link Trigger}s that fire based on properties of the elements in the current pane. + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@link Trigger} + */ +@Experimental(Experimental.Kind.TRIGGER) +public class AfterPane extends OnceTrigger{ + +private static final StateTag> + ELEMENTS_IN_PANE_TAG = + StateTags.makeSystemTagInternal(StateTags.combiningValueFromInputInternal( + "count", VarLongCoder.of(), new Sum.SumLongFn())); + + private final int countElems; + + private AfterPane(int countElems) { + super(null); + this.countElems = countElems; + } + + /** + * Creates a trigger that fires when the pane contains at least {@code countElems} elements. + */ + public static AfterPane elementCountAtLeast(int countElems) { + return new AfterPane<>(countElems); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + c.state().access(ELEMENTS_IN_PANE_TAG).add(1L); + } + + @Override + public void prefetchOnMerge(MergingStateAccessor state) { + super.prefetchOnMerge(state); + StateMerging.prefetchCombiningValues(state, ELEMENTS_IN_PANE_TAG); + } + + @Override + public void onMerge(OnMergeContext context) throws Exception { + // If we've already received enough elements and finished in some window, + // then this trigger is just finished. 
+ if (context.trigger().finishedInAnyMergingWindow()) { + context.trigger().setFinished(true); + StateMerging.clear(context.state(), ELEMENTS_IN_PANE_TAG); + return; + } + + // Otherwise, compute the sum of elements in all the active panes. + StateMerging.mergeCombiningValues(context.state(), ELEMENTS_IN_PANE_TAG); + } + + @Override + public void prefetchShouldFire(StateAccessor state) { + state.access(ELEMENTS_IN_PANE_TAG).readLater(); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + long count = context.state().access(ELEMENTS_IN_PANE_TAG).read(); + return count >= countElems; + } + + @Override + public void clear(TriggerContext c) throws Exception { + c.state().access(ELEMENTS_IN_PANE_TAG).clear(); + } + + @Override + public boolean isCompatible(Trigger other) { + return this.equals(other); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + public OnceTrigger getContinuationTrigger(List> continuationTriggers) { + return AfterPane.elementCountAtLeast(1); + } + + @Override + public String toString() { + return "AfterPane.elementCountAtLeast(" + countElems + ")"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof AfterPane)) { + return false; + } + AfterPane that = (AfterPane) obj; + return this.countElems == that.countElems; + } + + @Override + public int hashCode() { + return Objects.hash(countElems); + } + + @Override + protected void onOnlyFiring(Trigger.TriggerContext context) throws Exception { + clear(context); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterProcessingTime.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterProcessingTime.java new file mode 100644 index 000000000000..7e8990274136 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterProcessingTime.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.TimeDomain; + +import org.joda.time.Instant; + +import java.util.List; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * {@code AfterProcessingTime} triggers fire based on the current processing time. They operate in + * the real-time domain. + * + *
    The time at which to fire the timer can be adjusted via the methods in {@link TimeTrigger}, + * such as {@link TimeTrigger#plusDelayOf} or {@link TimeTrigger#alignedTo}. + * + * @param {@link BoundedWindow} subclass used to represent the windows used + */ +@Experimental(Experimental.Kind.TRIGGER) +public class AfterProcessingTime extends AfterDelayFromFirstElement { + + @Override + @Nullable + public Instant getCurrentTime(Trigger.TriggerContext context) { + return context.currentProcessingTime(); + } + + private AfterProcessingTime(List> transforms) { + super(TimeDomain.PROCESSING_TIME, transforms); + } + + /** + * Creates a trigger that fires when the current processing time passes the processing time + * at which this trigger saw the first element in a pane. + */ + public static AfterProcessingTime pastFirstElementInPane() { + return new AfterProcessingTime(IDENTITY); + } + + @Override + protected AfterProcessingTime newWith( + List> transforms) { + return new AfterProcessingTime(transforms); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + protected Trigger getContinuationTrigger(List> continuationTriggers) { + return new AfterSynchronizedProcessingTime(); + } + + @Override + public String toString() { + return "AfterProcessingTime.pastFirstElementInPane(" + timestampMappers + ")"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof AfterProcessingTime)) { + return false; + } + AfterProcessingTime that = (AfterProcessingTime) obj; + return Objects.equals(this.timestampMappers, that.timestampMappers); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), this.timestampMappers); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterSynchronizedProcessingTime.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterSynchronizedProcessingTime.java new file mode 100644 index 000000000000..0a274c9ce08e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterSynchronizedProcessingTime.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.common.base.Objects; + +import org.joda.time.Instant; + +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nullable; + +class AfterSynchronizedProcessingTime + extends AfterDelayFromFirstElement { + + @Override + @Nullable + public Instant getCurrentTime(Trigger.TriggerContext context) { + return context.currentSynchronizedProcessingTime(); + } + + public AfterSynchronizedProcessingTime() { + super(TimeDomain.SYNCHRONIZED_PROCESSING_TIME, + Collections.>emptyList()); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + protected Trigger getContinuationTrigger(List> continuationTriggers) { + return this; + } + + @Override + public String toString() { + return "AfterSynchronizedProcessingTime.pastFirstElementInPane()"; + } + + @Override + public boolean equals(Object obj) { + return this == obj || obj instanceof AfterSynchronizedProcessingTime; + } + + @Override + public int hashCode() { + return Objects.hashCode(AfterSynchronizedProcessingTime.class); + } + + @Override + protected AfterSynchronizedProcessingTime + newWith(List> transforms) { + // ignore transforms + return this; + } + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java new file mode 100644 index 000000000000..da16db99c619 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java @@ -0,0 +1,397 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; + +import java.util.List; +import java.util.Objects; + +/** + *

    {@code AfterWatermark} triggers fire based on progress of the system watermark. This time is a + * lower-bound, sometimes heuristically established, on event times that have been fully processed + * by the pipeline. + * + *

    For sources that provide non-heuristic watermarks (e.g. + * {@link com.google.cloud.dataflow.sdk.io.PubsubIO} when using arrival times as event times), the + * watermark is a strict guarantee that no data with an event time earlier than + * that watermark will ever be observed in the pipeline. In this case, it's safe to assume that any + * pane triggered by an {@code AfterWatermark} trigger with a reference point at or beyond the end + * of the window will be the last pane ever for that window. + * + *

For sources that provide heuristic watermarks (e.g. + * {@link com.google.cloud.dataflow.sdk.io.PubsubIO} when using user-supplied event times), the + * watermark itself becomes an estimate that no data with an event time earlier than that + * watermark (i.e. "late data") will ever be observed in the pipeline. These heuristics can + * often be quite accurate, but the chance of seeing late data for any given window is non-zero. + * Thus, if absolute correctness over time is important to your use case, you may want to consider + * using a trigger that accounts for late data. The default trigger, + * {@code Repeatedly.forever(AfterWatermark.pastEndOfWindow())}, which fires + * once when the watermark passes the end of the window and then immediately thereafter when any + * late data arrives, is one such example. + * + *
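One hedged sketch of a strategy that accounts for late data, assuming the SDK's Window transform and AfterPane trigger (neither shown in this excerpt) and a hypothetical input named events: keep the on-time firing at the watermark, then emit a corrected pane for each late element that arrives within the allowed lateness.

    PCollection<String> corrected = events.apply(
        Window.<String>into(FixedWindows.of(Duration.standardMinutes(10)))
            .triggering(AfterWatermark.pastEndOfWindow()
                .withLateFirings(AfterPane.elementCountAtLeast(1)))
            .withAllowedLateness(Duration.standardHours(1))
            .accumulatingFiredPanes());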

    The watermark is the clock that defines {@link TimeDomain#EVENT_TIME}. + * + * Additionaly firings before or after the watermark can be requested by calling + * {@code AfterWatermark.pastEndOfWindow.withEarlyFirings(OnceTrigger)} or + * {@code AfterWatermark.pastEndOfWindow.withEarlyFirings(OnceTrigger)}. + * + * @param {@link BoundedWindow} subclass used to represent the windows used. + */ +@Experimental(Experimental.Kind.TRIGGER) +public class AfterWatermark { + + // Static factory class. + private AfterWatermark() {} + + /** + * Creates a trigger that fires when the watermark passes the end of the window. + */ + public static FromEndOfWindow pastEndOfWindow() { + return new FromEndOfWindow(); + } + + /** + * Interface for building an AfterWatermarkTrigger with early firings already filled in. + */ + public interface AfterWatermarkEarly extends TriggerBuilder { + /** + * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever + * the given {@code Trigger} fires before the watermark has passed the end of the window. + */ + TriggerBuilder withLateFirings(OnceTrigger lateTrigger); + } + + /** + * Interface for building an AfterWatermarkTrigger with late firings already filled in. + */ + public interface AfterWatermarkLate extends TriggerBuilder { + /** + * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever + * the given {@code Trigger} fires after the watermark has passed the end of the window. + */ + TriggerBuilder withEarlyFirings(OnceTrigger earlyTrigger); + } + + /** + * A trigger which never fires. Used for the "early" trigger when only a late trigger was + * specified. + */ + private static class NeverTrigger extends OnceTrigger { + + protected NeverTrigger() { + super(null); + } + + @Override + public void onElement(OnElementContext c) throws Exception { } + + @Override + public void onMerge(OnMergeContext c) throws Exception { } + + @Override + protected Trigger getContinuationTrigger(List> continuationTriggers) { + return this; + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + return false; + } + + @Override + protected void onOnlyFiring(Trigger.TriggerContext context) throws Exception { + throw new UnsupportedOperationException( + String.format("%s should never fire", getClass().getSimpleName())); + } + } + + private static class AfterWatermarkEarlyAndLate + extends Trigger + implements TriggerBuilder, AfterWatermarkEarly, AfterWatermarkLate { + + private static final int EARLY_INDEX = 0; + private static final int LATE_INDEX = 1; + + private final OnceTrigger earlyTrigger; + private final OnceTrigger lateTrigger; + + @SuppressWarnings("unchecked") + private AfterWatermarkEarlyAndLate(OnceTrigger earlyTrigger, OnceTrigger lateTrigger) { + super(lateTrigger == null + ? 
ImmutableList.>of(earlyTrigger) + : ImmutableList.>of(earlyTrigger, lateTrigger)); + this.earlyTrigger = checkNotNull(earlyTrigger, "earlyTrigger should not be null"); + this.lateTrigger = lateTrigger; + } + + @Override + public TriggerBuilder withEarlyFirings(OnceTrigger earlyTrigger) { + return new AfterWatermarkEarlyAndLate(earlyTrigger, lateTrigger); + } + + @Override + public TriggerBuilder withLateFirings(OnceTrigger lateTrigger) { + return new AfterWatermarkEarlyAndLate(earlyTrigger, lateTrigger); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + if (!c.trigger().isMerging()) { + // If merges can never happen, we just run the unfinished subtrigger + c.trigger().firstUnfinishedSubTrigger().invokeOnElement(c); + } else { + // If merges can happen, we run for all subtriggers because they might be + // de-activated or re-activated + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + subTrigger.invokeOnElement(c); + } + } + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + // NOTE that the ReduceFnRunner will delete all end-of-window timers for the + // merged-away windows. + + ExecutableTrigger earlySubtrigger = c.trigger().subTrigger(EARLY_INDEX); + // We check the early trigger to determine if we are still processing it or + // if the end of window has transitioned us to the late trigger + OnMergeContext earlyContext = c.forTrigger(earlySubtrigger); + + // If the early trigger is still active in any merging window then it is still active in + // the new merged window, because even if the merged window is "done" some pending elements + // haven't had a chance to fire. + if (!earlyContext.trigger().finishedInAllMergingWindows() || !endOfWindowReached(c)) { + earlyContext.trigger().setFinished(false); + if (lateTrigger != null) { + ExecutableTrigger lateSubtrigger = c.trigger().subTrigger(LATE_INDEX); + OnMergeContext lateContext = c.forTrigger(lateSubtrigger); + lateContext.trigger().setFinished(false); + lateSubtrigger.invokeClear(lateContext); + } + } else { + // Otherwise the early trigger and end-of-window bit is done for good. + earlyContext.trigger().setFinished(true); + if (lateTrigger != null) { + c.trigger().subTrigger(LATE_INDEX).invokeOnMerge(c); + } + } + } + + @Override + public Trigger getContinuationTrigger() { + return new AfterWatermarkEarlyAndLate( + earlyTrigger.getContinuationTrigger(), + lateTrigger == null ? null : lateTrigger.getContinuationTrigger()); + } + + @Override + protected Trigger getContinuationTrigger(List> continuationTriggers) { + throw new UnsupportedOperationException( + "Should not call getContinuationTrigger(List>)"); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + // Even without an early or late trigger, we'll still produce a firing at the watermark. + return window.maxTimestamp(); + } + + private boolean endOfWindowReached(Trigger.TriggerContext context) { + return context.currentEventTime() != null + && context.currentEventTime().isAfter(context.window().maxTimestamp()); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + if (!context.trigger().isFinished(EARLY_INDEX)) { + // We have not yet transitioned to late firings. + // We should fire if either the trigger is ready or we reach the end of the window. 
+ return context.trigger().subTrigger(EARLY_INDEX).invokeShouldFire(context) + || endOfWindowReached(context); + } else if (lateTrigger == null) { + return false; + } else { + // We are running the late trigger + return context.trigger().subTrigger(LATE_INDEX).invokeShouldFire(context); + } + } + + @Override + public void onFire(Trigger.TriggerContext context) throws Exception { + if (!context.forTrigger(context.trigger().subTrigger(EARLY_INDEX)).trigger().isFinished()) { + onNonLateFiring(context); + } else if (lateTrigger != null) { + onLateFiring(context); + } else { + // all done + context.trigger().setFinished(true); + } + } + + private void onNonLateFiring(Trigger.TriggerContext context) throws Exception { + // We have not yet transitioned to late firings. + ExecutableTrigger earlySubtrigger = context.trigger().subTrigger(EARLY_INDEX); + Trigger.TriggerContext earlyContext = context.forTrigger(earlySubtrigger); + + if (!endOfWindowReached(context)) { + // This is an early firing, since we have not arrived at the end of the window + // Implicitly repeats + earlySubtrigger.invokeOnFire(context); + earlySubtrigger.invokeClear(context); + earlyContext.trigger().setFinished(false); + } else { + // We have arrived at the end of the window; terminate the early trigger + // and clear out the late trigger's state + if (earlySubtrigger.invokeShouldFire(context)) { + earlySubtrigger.invokeOnFire(context); + } + earlyContext.trigger().setFinished(true); + earlySubtrigger.invokeClear(context); + + if (lateTrigger == null) { + // Done if there is no late trigger. + context.trigger().setFinished(true); + } else { + // If there is a late trigger, we transition to it, and need to clear its state + // because it was run in parallel. + context.trigger().subTrigger(LATE_INDEX).invokeClear(context); + } + } + + } + + private void onLateFiring(Trigger.TriggerContext context) throws Exception { + // We are firing the late trigger, with implicit repeat + ExecutableTrigger lateSubtrigger = context.trigger().subTrigger(LATE_INDEX); + lateSubtrigger.invokeOnFire(context); + // It is a OnceTrigger, so it must have finished; unfinished it and clear it + lateSubtrigger.invokeClear(context); + context.forTrigger(lateSubtrigger).trigger().setFinished(false); + } + } + + /** + * A watermark trigger targeted relative to the end of the window. + */ + public static class FromEndOfWindow extends OnceTrigger { + + private FromEndOfWindow() { + super(null); + } + + /** + * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever + * the given {@code Trigger} fires before the watermark has passed the end of the window. + */ + public AfterWatermarkEarly withEarlyFirings(OnceTrigger earlyFirings) { + Preconditions.checkNotNull(earlyFirings, + "Must specify the trigger to use for early firings"); + return new AfterWatermarkEarlyAndLate(earlyFirings, null); + } + + /** + * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever + * the given {@code Trigger} fires after the watermark has passed the end of the window. + */ + public AfterWatermarkLate withLateFirings(OnceTrigger lateFirings) { + Preconditions.checkNotNull(lateFirings, + "Must specify the trigger to use for late firings"); + return new AfterWatermarkEarlyAndLate(new NeverTrigger(), lateFirings); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + // We're interested in knowing when the input watermark passes the end of the window. 
+ // (It is possible this has already happened, in which case the timer will be fired + // almost immediately). + c.setTimer(c.window().maxTimestamp(), TimeDomain.EVENT_TIME); + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + // NOTE that the ReduceFnRunner will delete all end-of-window timers for the + // merged-away windows. + + if (!c.trigger().finishedInAllMergingWindows()) { + // If the trigger is still active in any merging window then it is still active in the new + // merged window, because even if the merged window is "done" some pending elements haven't + // had a chance to fire + c.trigger().setFinished(false); + } else if (!endOfWindowReached(c)) { + // If the end of the new window has not been reached, then the trigger is active again. + c.trigger().setFinished(false); + } else { + // Otherwise it is done for good + c.trigger().setFinished(true); + } + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return window.maxTimestamp(); + } + + @Override + public FromEndOfWindow getContinuationTrigger(List> continuationTriggers) { + return this; + } + + @Override + public String toString() { + return "AfterWatermark.pastEndOfWindow()"; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof FromEndOfWindow; + } + + @Override + public int hashCode() { + return Objects.hash(getClass()); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + return endOfWindowReached(context); + } + + private boolean endOfWindowReached(Trigger.TriggerContext context) { + return context.currentEventTime() != null + && context.currentEventTime().isAfter(context.window().maxTimestamp()); + } + + @Override + protected void onOnlyFiring(Trigger.TriggerContext context) throws Exception { } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/BoundedWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/BoundedWindow.java new file mode 100644 index 000000000000..0afd8e33c2d7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/BoundedWindow.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import org.joda.time.Instant; + +import java.util.concurrent.TimeUnit; + +/** + * A {@code BoundedWindow} represents a finite grouping of elements, with an + * upper bound (larger timestamps represent more recent data) on the timestamps + * of elements that can be placed in the window. This finiteness means that for + * every window, at some point in time, all data for that window will have + * arrived and can be processed together. + * + *

    Windows must also implement {@link Object#equals} and + * {@link Object#hashCode} such that windows that are logically equal will + * be treated as equal by {@code equals()} and {@code hashCode()}. + */ +public abstract class BoundedWindow { + // The min and max timestamps that won't overflow when they are converted to + // usec. + public static final Instant TIMESTAMP_MIN_VALUE = + new Instant(TimeUnit.MICROSECONDS.toMillis(Long.MIN_VALUE)); + public static final Instant TIMESTAMP_MAX_VALUE = + new Instant(TimeUnit.MICROSECONDS.toMillis(Long.MAX_VALUE)); + + /** + * Returns the inclusive upper bound of timestamps for values in this window. + */ + public abstract Instant maxTimestamp(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindows.java new file mode 100644 index 000000000000..de5140f2a5d6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindows.java @@ -0,0 +1,348 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.Days; +import org.joda.time.Instant; +import org.joda.time.Months; +import org.joda.time.Years; + +/** + * A collection of {@link WindowFn}s that windows values into calendar-based + * windows such as spans of days, months, or years. + * + *

    For example, to group data into quarters that change on the 15th, use + * {@code CalendarWindows.months(3).withStartingMonth(2014, 1).beginningOnDay(15)}. + */ +public class CalendarWindows { + + /** + * Returns a {@link WindowFn} that windows elements into periods measured by days. + * + *
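A sketch of the quarters example above, assuming the SDK's Window transform and a hypothetical input PCollection named events:

    PCollection<String> quarterly = events.apply(
        Window.<String>into(
            CalendarWindows.months(3).withStartingMonth(2014, 1).beginningOnDay(15)));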

    For example, {@code CalendarWindows.days(1)} will window elements into + * separate windows for each day. + */ + public static DaysWindows days(int number) { + return new DaysWindows(number, new DateTime(0, DateTimeZone.UTC), DateTimeZone.UTC); + } + + /** + * Returns a {@link WindowFn} that windows elements into periods measured by weeks. + * + *

    For example, {@code CalendarWindows.weeks(1, DateTimeConstants.TUESDAY)} will + * window elements into week-long windows starting on Tuesdays. + */ + public static DaysWindows weeks(int number, int startDayOfWeek) { + return new DaysWindows( + 7 * number, + new DateTime(0, DateTimeZone.UTC).withDayOfWeek(startDayOfWeek), + DateTimeZone.UTC); + } + + /** + * Returns a {@link WindowFn} that windows elements into periods measured by months. + * + *

For example, + * {@code CalendarWindows.months(8).withStartingMonth(2014, 1).beginningOnDay(10)} + * will window elements into 8-month windows that start on the 10th day of the month, + * and the first window begins in January 2014. + */ + public static MonthsWindows months(int number) { + return new MonthsWindows(number, 1, new DateTime(0, DateTimeZone.UTC), DateTimeZone.UTC); + } + + /** + * Returns a {@link WindowFn} that windows elements into periods measured by years. + * + *

    For example, + * {@code CalendarWindows.years(1).withTimeZone(DateTimeZone.forId("America/Los_Angeles"))} + * will window elements into year-long windows that start at midnight on Jan 1, in the + * America/Los_Angeles time zone. + */ + public static YearsWindows years(int number) { + return new YearsWindows(number, 1, 1, new DateTime(0, DateTimeZone.UTC), DateTimeZone.UTC); + } + + /** + * A {@link WindowFn} that windows elements into periods measured by days. + * + *
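And a sketch of the time-zone variant just described, under the same assumptions (Window transform and a hypothetical events PCollection):

    PCollection<String> yearly = events.apply(
        Window.<String>into(
            CalendarWindows.years(1).withTimeZone(DateTimeZone.forId("America/Los_Angeles"))));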

    By default, periods of multiple days are measured starting at the + * epoch. This can be overridden with {@link #withStartingDay}. + * + *

    The time zone used to determine calendar boundaries is UTC, unless this + * is overridden with the {@link #withTimeZone} method. + */ + public static class DaysWindows extends PartitioningWindowFn { + public DaysWindows withStartingDay(int year, int month, int day) { + return new DaysWindows( + number, new DateTime(year, month, day, 0, 0, timeZone), timeZone); + } + + public DaysWindows withTimeZone(DateTimeZone timeZone) { + return new DaysWindows( + number, startDate.withZoneRetainFields(timeZone), timeZone); + } + + //////////////////////////////////////////////////////////////////////////// + + private int number; + private DateTime startDate; + private DateTimeZone timeZone; + + private DaysWindows(int number, DateTime startDate, DateTimeZone timeZone) { + this.number = number; + this.startDate = startDate; + this.timeZone = timeZone; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + DateTime datetime = new DateTime(timestamp, timeZone); + + int dayOffset = Days.daysBetween(startDate, datetime).getDays() / number * number; + + DateTime begin = startDate.plusDays(dayOffset); + DateTime end = begin.plusDays(number); + + return new IntervalWindow(begin.toInstant(), end.toInstant()); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowFn other) { + if (!(other instanceof DaysWindows)) { + return false; + } + DaysWindows that = (DaysWindows) other; + return number == that.number + && startDate == that.startDate + && timeZone == that.timeZone; + } + + public int getNumber() { + return number; + } + + public DateTime getStartDate() { + return startDate; + } + + public DateTimeZone getTimeZone() { + return timeZone; + } + + } + + /** + * A {@link WindowFn} that windows elements into periods measured by months. + * + *

    By default, periods of multiple months are measured starting at the + * epoch. This can be overridden with {@link #withStartingMonth}. + * + *

    Months start on the first day of each calendar month, unless overridden by + * {@link #beginningOnDay}. + * + *

    The time zone used to determine calendar boundaries is UTC, unless this + * is overridden with the {@link #withTimeZone} method. + */ + public static class MonthsWindows extends PartitioningWindowFn { + public MonthsWindows beginningOnDay(int dayOfMonth) { + return new MonthsWindows( + number, dayOfMonth, startDate, timeZone); + } + + public MonthsWindows withStartingMonth(int year, int month) { + return new MonthsWindows( + number, dayOfMonth, new DateTime(year, month, 1, 0, 0, timeZone), timeZone); + } + + public MonthsWindows withTimeZone(DateTimeZone timeZone) { + return new MonthsWindows( + number, dayOfMonth, startDate.withZoneRetainFields(timeZone), timeZone); + } + + //////////////////////////////////////////////////////////////////////////// + + private int number; + private int dayOfMonth; + private DateTime startDate; + private DateTimeZone timeZone; + + private MonthsWindows(int number, int dayOfMonth, DateTime startDate, DateTimeZone timeZone) { + this.number = number; + this.dayOfMonth = dayOfMonth; + this.startDate = startDate; + this.timeZone = timeZone; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + DateTime datetime = new DateTime(timestamp, timeZone); + + int monthOffset = + Months.monthsBetween(startDate.withDayOfMonth(dayOfMonth), datetime).getMonths() + / number * number; + + DateTime begin = startDate.withDayOfMonth(dayOfMonth).plusMonths(monthOffset); + DateTime end = begin.plusMonths(number); + + return new IntervalWindow(begin.toInstant(), end.toInstant()); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowFn other) { + if (!(other instanceof MonthsWindows)) { + return false; + } + MonthsWindows that = (MonthsWindows) other; + return number == that.number + && dayOfMonth == that.dayOfMonth + && startDate == that.startDate + && timeZone == that.timeZone; + } + + public int getNumber() { + return number; + } + + public int getDayOfMonth() { + return dayOfMonth; + } + + public DateTime getStartDate() { + return startDate; + } + + public DateTimeZone getTimeZone() { + return timeZone; + } + + } + + /** + * A {@link WindowFn} that windows elements into periods measured by years. + * + *

    By default, periods of multiple years are measured starting at the + * epoch. This can be overridden with {@link #withStartingYear}. + * + *

    Years start on the first day of each calendar year, unless overridden by + * {@link #beginningOnDay}. + * + *

    The time zone used to determine calendar boundaries is UTC, unless this + * is overridden with the {@link #withTimeZone} method. + */ + public static class YearsWindows extends PartitioningWindowFn { + public YearsWindows beginningOnDay(int monthOfYear, int dayOfMonth) { + return new YearsWindows( + number, monthOfYear, dayOfMonth, startDate, timeZone); + } + + public YearsWindows withStartingYear(int year) { + return new YearsWindows( + number, monthOfYear, dayOfMonth, new DateTime(year, 1, 1, 0, 0, timeZone), timeZone); + } + + public YearsWindows withTimeZone(DateTimeZone timeZone) { + return new YearsWindows( + number, monthOfYear, dayOfMonth, startDate.withZoneRetainFields(timeZone), timeZone); + } + + //////////////////////////////////////////////////////////////////////////// + + private int number; + private int monthOfYear; + private int dayOfMonth; + private DateTime startDate; + private DateTimeZone timeZone; + + private YearsWindows( + int number, int monthOfYear, int dayOfMonth, DateTime startDate, DateTimeZone timeZone) { + this.number = number; + this.monthOfYear = monthOfYear; + this.dayOfMonth = dayOfMonth; + this.startDate = startDate; + this.timeZone = timeZone; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + DateTime datetime = new DateTime(timestamp, timeZone); + + DateTime offsetStart = startDate.withMonthOfYear(monthOfYear).withDayOfMonth(dayOfMonth); + + int yearOffset = + Years.yearsBetween(offsetStart, datetime).getYears() / number * number; + + DateTime begin = offsetStart.plusYears(yearOffset); + DateTime end = begin.plusYears(number); + + return new IntervalWindow(begin.toInstant(), end.toInstant()); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowFn other) { + if (!(other instanceof YearsWindows)) { + return false; + } + YearsWindows that = (YearsWindows) other; + return number == that.number + && monthOfYear == that.monthOfYear + && dayOfMonth == that.dayOfMonth + && startDate == that.startDate + && timeZone == that.timeZone; + } + + public DateTimeZone getTimeZone() { + return timeZone; + } + + public DateTime getStartDate() { + return startDate; + } + + public int getDayOfMonth() { + return dayOfMonth; + } + + public int getMonthOfYear() { + return monthOfYear; + } + + public int getNumber() { + return number; + } + + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/DefaultTrigger.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/DefaultTrigger.java new file mode 100644 index 000000000000..9ac4abd894be --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/DefaultTrigger.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.util.TimeDomain; + +import org.joda.time.Instant; + +import java.util.List; + +/** + * A trigger that is equivalent to {@code Repeatedly.forever(AfterWatermark.pastEndOfWindow())}. + * See {@link Repeatedly#forever} and {@link AfterWatermark#pastEndOfWindow} for more details. + * + * @param The type of windows being triggered/encoded. + */ +@Experimental(Experimental.Kind.TRIGGER) +public class DefaultTrigger extends Trigger{ + + private DefaultTrigger() { + super(null); + } + + /** + * Returns the default trigger. + */ + public static DefaultTrigger of() { + return new DefaultTrigger(); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + // If the end of the window has already been reached, then we are already ready to fire + // and do not need to set a wake-up timer. + if (!endOfWindowReached(c)) { + c.setTimer(c.window().maxTimestamp(), TimeDomain.EVENT_TIME); + } + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + // If the end of the window has already been reached, then we are already ready to fire + // and do not need to set a wake-up timer. + if (!endOfWindowReached(c)) { + c.setTimer(c.window().maxTimestamp(), TimeDomain.EVENT_TIME); + } + } + + @Override + public void clear(TriggerContext c) throws Exception { } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + return window.maxTimestamp(); + } + + @Override + public boolean isCompatible(Trigger other) { + // Semantically, all default triggers are identical + return other instanceof DefaultTrigger; + } + + @Override + public Trigger getContinuationTrigger(List> continuationTriggers) { + return this; + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + return endOfWindowReached(context); + } + + private boolean endOfWindowReached(Trigger.TriggerContext context) { + return context.currentEventTime() != null + && context.currentEventTime().isAfter(context.window().maxTimestamp()); + } + + @Override + public void onFire(Trigger.TriggerContext context) throws Exception { } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindows.java new file mode 100644 index 000000000000..12a0f1b9185e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindows.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.Objects; + +/** + * A {@link WindowFn} that windows values into fixed-size timestamp-based windows. 
+ * + *

    For example, in order to partition the data into 10 minute windows: + *

<pre> {@code
+ * PCollection<String> items = ...;
+ * PCollection<String> windowedItems = items.apply(
+ *   Window.<String>into(FixedWindows.of(Duration.standardMinutes(10))));
+ * } </pre>
    + */ +public class FixedWindows extends PartitioningWindowFn { + + /** + * Size of this window. + */ + private final Duration size; + + /** + * Offset of this window. Windows start at time + * N * size + offset, where 0 is the epoch. + */ + private final Duration offset; + + /** + * Partitions the timestamp space into half-open intervals of the form + * [N * size, (N + 1) * size), where 0 is the epoch. + */ + public static FixedWindows of(Duration size) { + return new FixedWindows(size, Duration.ZERO); + } + + /** + * Partitions the timestamp space into half-open intervals of the form + * [N * size + offset, (N + 1) * size + offset), + * where 0 is the epoch. + * + * @throws IllegalArgumentException if offset is not in [0, size) + */ + public FixedWindows withOffset(Duration offset) { + return new FixedWindows(size, offset); + } + + private FixedWindows(Duration size, Duration offset) { + if (offset.isShorterThan(Duration.ZERO) || !offset.isShorterThan(size)) { + throw new IllegalArgumentException( + "FixedWindows WindowingStrategies must have 0 <= offset < size"); + } + this.size = size; + this.offset = offset; + } + + @Override + public IntervalWindow assignWindow(Instant timestamp) { + long start = timestamp.getMillis() + - timestamp.plus(size).minus(offset).getMillis() % size.getMillis(); + return new IntervalWindow(new Instant(start), size); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowFn other) { + return this.equals(other); + } + + public Duration getSize() { + return size; + } + + public Duration getOffset() { + return offset; + } + + @Override + public boolean equals(Object object) { + if (!(object instanceof FixedWindows)) { + return false; + } + FixedWindows other = (FixedWindows) object; + return getOffset().equals(other.getOffset()) + && getSize().equals(other.getSize()); + } + + @Override + public int hashCode() { + return Objects.hash(size, offset); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindow.java new file mode 100644 index 000000000000..d7fc396f493b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindow.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * The default window into which all data is placed (via {@link GlobalWindows}). + */ +public class GlobalWindow extends BoundedWindow { + /** + * Singleton instance of {@link GlobalWindow}. 
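To make the offset semantics above concrete, a small sketch with illustrative values only:

    // Hour-long windows of the form [N hours + 15 min, (N + 1) hours + 15 min).
    FixedWindows hourlyAtQuarterPast =
        FixedWindows.of(Duration.standardHours(1)).withOffset(Duration.standardMinutes(15));
    // A timestamp at 10:07 lands in [09:15, 10:15); one at 10:20 lands in [10:15, 11:15).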
+ */ + public static final GlobalWindow INSTANCE = new GlobalWindow(); + + // Triggers use maxTimestamp to set timers' timestamp. Timers fires when + // the watermark passes their timestamps. So, the maxTimestamp needs to be + // smaller than the TIMESTAMP_MAX_VALUE. + // One standard day is subtracted from TIMESTAMP_MAX_VALUE to make sure + // the maxTimestamp is smaller than TIMESTAMP_MAX_VALUE even after rounding up + // to seconds or minutes. + private static final Instant END_OF_GLOBAL_WINDOW = + TIMESTAMP_MAX_VALUE.minus(Duration.standardDays(1)); + + @Override + public Instant maxTimestamp() { + return END_OF_GLOBAL_WINDOW; + } + + private GlobalWindow() {} + + /** + * {@link Coder} for encoding and decoding {@code GlobalWindow}s. + */ + public static class Coder extends AtomicCoder { + public static final Coder INSTANCE = new Coder(); + + @Override + public void encode(GlobalWindow window, OutputStream outStream, Context context) {} + + @Override + public GlobalWindow decode(InputStream inStream, Context context) { + return GlobalWindow.INSTANCE; + } + + private Coder() {} + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindows.java new file mode 100644 index 000000000000..d3d949c7c9c3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/GlobalWindows.java @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Instant; + +import java.util.Collection; +import java.util.Collections; + +/** + * Default {@link WindowFn} that assigns all data to the same window. + */ +public class GlobalWindows extends NonMergingWindowFn { + + private static final Collection GLOBAL_WINDOWS = + Collections.singletonList(GlobalWindow.INSTANCE); + + @Override + public Collection assignWindows(AssignContext c) { + return GLOBAL_WINDOWS; + } + + @Override + public boolean isCompatible(WindowFn o) { + return o instanceof GlobalWindows; + } + + @Override + public Coder windowCoder() { + return GlobalWindow.Coder.INSTANCE; + } + + @Override + public GlobalWindow getSideInputWindow(BoundedWindow window) { + return GlobalWindow.INSTANCE; + } + + @Override + public boolean assignsToSingleWindow() { + return true; + } + + @Override + public Instant getOutputTime(Instant inputTimestamp, GlobalWindow window) { + return inputTimestamp; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindow.java new file mode 100644 index 000000000000..58287c71d6b0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindow.java @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2015 Google Inc. 
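A usage sketch for the global window, assuming the Window transform and AfterPane trigger defined elsewhere in the SDK and a hypothetical events PCollection; with a single all-encompassing window, the trigger alone decides when output is emitted.

    PCollection<String> globallyWindowed = events.apply(
        Window.<String>into(new GlobalWindows())
            .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1000)))
            .withAllowedLateness(Duration.ZERO)
            .discardingFiredPanes());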
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.DurationCoder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.ReadableDuration; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * An implementation of {@link BoundedWindow} that represents an interval from + * {@link #start} (inclusive) to {@link #end} (exclusive). + */ +public class IntervalWindow extends BoundedWindow + implements Comparable { + /** + * Start of the interval, inclusive. + */ + private final Instant start; + + /** + * End of the interval, exclusive. + */ + private final Instant end; + + /** + * Creates a new IntervalWindow that represents the half-open time + * interval [start, end). + */ + public IntervalWindow(Instant start, Instant end) { + this.start = start; + this.end = end; + } + + public IntervalWindow(Instant start, ReadableDuration size) { + this.start = start; + this.end = start.plus(size); + } + + /** + * Returns the start of this window, inclusive. + */ + public Instant start() { + return start; + } + + /** + * Returns the end of this window, exclusive. + */ + public Instant end() { + return end; + } + + /** + * Returns the largest timestamp that can be included in this window. + */ + @Override + public Instant maxTimestamp() { + // end not inclusive + return end.minus(1); + } + + /** + * Returns whether this window contains the given window. + */ + public boolean contains(IntervalWindow other) { + return !this.start.isAfter(other.start) && !this.end.isBefore(other.end); + } + + /** + * Returns whether this window is disjoint from the given window. + */ + public boolean isDisjoint(IntervalWindow other) { + return !this.end.isAfter(other.start) || !other.end.isAfter(this.start); + } + + /** + * Returns whether this window intersects the given window. + */ + public boolean intersects(IntervalWindow other) { + return !isDisjoint(other); + } + + /** + * Returns the minimal window that includes both this window and + * the given window. 
+ */ + public IntervalWindow span(IntervalWindow other) { + return new IntervalWindow( + new Instant(Math.min(start.getMillis(), other.start.getMillis())), + new Instant(Math.max(end.getMillis(), other.end.getMillis()))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof IntervalWindow) + && ((IntervalWindow) o).end.isEqual(end) + && ((IntervalWindow) o).start.isEqual(start); + } + + @Override + public int hashCode() { + // The end values are themselves likely to be arithmetic sequence, which + // is a poor distribution to use for a hashtable, so we + // add a highly non-linear transformation. + return (int) + (start.getMillis() + modInverse((int) (end.getMillis() << 1) + 1)); + } + + /** + * Compute the inverse of (odd) x mod 2^32. + */ + private int modInverse(int x) { + // Cube gives inverse mod 2^4, as x^4 == 1 (mod 2^4) for all odd x. + int inverse = x * x * x; + // Newton iteration doubles correct bits at each step. + inverse *= 2 - x * inverse; + inverse *= 2 - x * inverse; + inverse *= 2 - x * inverse; + return inverse; + } + + @Override + public String toString() { + return "[" + start + ".." + end + ")"; + } + + @Override + public int compareTo(IntervalWindow o) { + if (start.isEqual(o.start)) { + return end.compareTo(o.end); + } + return start.compareTo(o.start); + } + + /** + * Returns a {@link Coder} suitable for {@link IntervalWindow}. + */ + public static Coder getCoder() { + return IntervalWindowCoder.of(); + } + + /** + * Encodes an {@link IntervalWindow} as a pair of its upper bound and duration. + */ + private static class IntervalWindowCoder extends AtomicCoder { + + private static final IntervalWindowCoder INSTANCE = + new IntervalWindowCoder(); + + private static final Coder instantCoder = InstantCoder.of(); + private static final Coder durationCoder = DurationCoder.of(); + + @JsonCreator + public static IntervalWindowCoder of() { + return INSTANCE; + } + + @Override + public void encode(IntervalWindow window, + OutputStream outStream, + Context context) + throws IOException, CoderException { + instantCoder.encode(window.end, outStream, context.nested()); + durationCoder.encode(new Duration(window.start, window.end), outStream, context.nested()); + } + + @Override + public IntervalWindow decode(InputStream inStream, Context context) + throws IOException, CoderException { + Instant end = instantCoder.decode(inStream, context.nested()); + ReadableDuration duration = durationCoder.decode(inStream, context.nested()); + return new IntervalWindow(end.minus(duration), end); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/InvalidWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/InvalidWindows.java new file mode 100644 index 000000000000..596f4e7253dd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/InvalidWindows.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
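A few concrete values illustrate the half-open interval semantics of the class above (Instant and Duration are the Joda-Time types this file already imports):

    IntervalWindow a = new IntervalWindow(new Instant(0), new Instant(10));     // [0, 10)
    IntervalWindow b = new IntervalWindow(new Instant(5), Duration.millis(10)); // [5, 15)
    a.intersects(b);  // true: the windows overlap on [5, 10)
    a.span(b);        // [0, 15), the minimal window containing both
    a.maxTimestamp(); // new Instant(9), because the end is exclusive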
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * A {@link WindowFn} that represents an invalid pipeline state. + * + * @param window type + */ +public class InvalidWindows extends WindowFn { + private String cause; + private WindowFn originalWindowFn; + + public InvalidWindows(String cause, WindowFn originalWindowFn) { + this.originalWindowFn = originalWindowFn; + this.cause = cause; + } + + /** + * Returns the reason that this {@code WindowFn} is invalid. + */ + public String getCause() { + return cause; + } + + /** + * Returns the original windowFn that this InvalidWindows replaced. + */ + public WindowFn getOriginalWindowFn() { + return originalWindowFn; + } + + @Override + public Collection assignWindows(AssignContext c) { + throw new UnsupportedOperationException(); + } + + @Override + public void mergeWindows(MergeContext c) { + throw new UnsupportedOperationException(); + } + + @Override + public Coder windowCoder() { + return originalWindowFn.windowCoder(); + } + + /** + * {@code InvalidWindows} objects with the same {@code originalWindowFn} are compatible. + */ + @Override + public boolean isCompatible(WindowFn other) { + return getClass() == other.getClass() + && getOriginalWindowFn().isCompatible( + ((InvalidWindows) other).getOriginalWindowFn()); + } + + @Override + public W getSideInputWindow(BoundedWindow window) { + throw new UnsupportedOperationException("InvalidWindows is not allowed in side inputs"); + } + + @Override + public Instant getOutputTime(Instant inputTimestamp, W window) { + return inputTimestamp; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java new file mode 100644 index 000000000000..4e06234c15f8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * A utility function for merging overlapping {@link IntervalWindow}s. + */ +public class MergeOverlappingIntervalWindows { + + /** + * Merge overlapping {@link IntervalWindow}s. + */ + public static void mergeWindows(WindowFn.MergeContext c) throws Exception { + // Merge any overlapping windows into a single window. + // Sort the list of existing windows so we only have to + // traverse the list once rather than considering all + // O(n^2) window pairs. 
+ List sortedWindows = new ArrayList<>(); + for (IntervalWindow window : c.windows()) { + sortedWindows.add(window); + } + Collections.sort(sortedWindows); + List merges = new ArrayList<>(); + MergeCandidate current = new MergeCandidate(); + for (IntervalWindow window : sortedWindows) { + if (current.intersects(window)) { + current.add(window); + } else { + merges.add(current); + current = new MergeCandidate(window); + } + } + merges.add(current); + for (MergeCandidate merge : merges) { + merge.apply(c); + } + } + + private static class MergeCandidate { + private IntervalWindow union; + private final List parts; + public MergeCandidate() { + parts = new ArrayList<>(); + } + public MergeCandidate(IntervalWindow window) { + union = window; + parts = new ArrayList<>(Arrays.asList(window)); + } + public boolean intersects(IntervalWindow window) { + return union == null || union.intersects(window); + } + public void add(IntervalWindow window) { + union = union == null ? window : union.span(window); + parts.add(window); + } + public void apply(WindowFn.MergeContext c) throws Exception { + if (parts.size() > 1) { + c.merge(parts, union); + } + } + + @Override + public String toString() { + return "MergeCandidate[union=" + union + ", parts=" + parts + "]"; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/NonMergingWindowFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/NonMergingWindowFn.java new file mode 100644 index 000000000000..8aa66fcfc843 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/NonMergingWindowFn.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +/** + * Abstract base class for {@link WindowFn}s that do not merge windows. + * + * @param type of elements being windowed + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code WindowFn} + */ +public abstract class NonMergingWindowFn + extends WindowFn { + @Override + public final void mergeWindows(MergeContext c) { } + + @Override + public final boolean isNonMerging() { + return true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OrFinallyTrigger.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OrFinallyTrigger.java new file mode 100644 index 000000000000..652092ad6e27 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OrFinallyTrigger.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
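The merge helper above is what session-style windowing builds on. A hedged sketch, assuming the SDK's Sessions WindowFn (not part of this excerpt), which gives every element its own IntervalWindow and relies on overlapping windows being merged into one session per burst of activity:

    PCollection<String> sessions = events.apply(
        Window.<String>into(Sessions.withGapDuration(Duration.standardMinutes(5))));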
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; +import com.google.common.annotations.VisibleForTesting; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.List; + +/** + * Executes the {@code actual} trigger until it finishes or until the {@code until} trigger fires. + */ +class OrFinallyTrigger extends Trigger { + + private static final int ACTUAL = 0; + private static final int UNTIL = 1; + + @VisibleForTesting OrFinallyTrigger(Trigger actual, Trigger.OnceTrigger until) { + super(Arrays.asList(actual, until)); + } + + @Override + public void onElement(OnElementContext c) throws Exception { + c.trigger().subTrigger(ACTUAL).invokeOnElement(c); + c.trigger().subTrigger(UNTIL).invokeOnElement(c); + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + subTrigger.invokeOnMerge(c); + } + updateFinishedState(c); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + // This trigger fires once either the trigger or the until trigger fires. + Instant actualDeadline = subTriggers.get(ACTUAL).getWatermarkThatGuaranteesFiring(window); + Instant untilDeadline = subTriggers.get(UNTIL).getWatermarkThatGuaranteesFiring(window); + return actualDeadline.isBefore(untilDeadline) ? actualDeadline : untilDeadline; + } + + @Override + public Trigger getContinuationTrigger(List> continuationTriggers) { + // Use OrFinallyTrigger instead of AfterFirst because the continuation of ACTUAL + // may not be a OnceTrigger. + return Repeatedly.forever( + new OrFinallyTrigger( + continuationTriggers.get(ACTUAL), + (Trigger.OnceTrigger) continuationTriggers.get(UNTIL))); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + return context.trigger().subTrigger(ACTUAL).invokeShouldFire(context) + || context.trigger().subTrigger(UNTIL).invokeShouldFire(context); + } + + @Override + public void onFire(Trigger.TriggerContext context) throws Exception { + ExecutableTrigger actualSubtrigger = context.trigger().subTrigger(ACTUAL); + ExecutableTrigger untilSubtrigger = context.trigger().subTrigger(UNTIL); + + if (untilSubtrigger.invokeShouldFire(context)) { + untilSubtrigger.invokeOnFire(context); + actualSubtrigger.invokeClear(context); + } else { + // If until didn't fire, then the actual must have (or it is forbidden to call + // onFire) so we are done only if actual is done. + actualSubtrigger.invokeOnFire(context); + // Do not clear the until trigger, because it tracks data cross firings. 
+ } + updateFinishedState(context); + } + + private void updateFinishedState(TriggerContext c) throws Exception { + boolean anyStillFinished = false; + for (ExecutableTrigger subTrigger : c.trigger().subTriggers()) { + anyStillFinished |= c.forTrigger(subTrigger).trigger().isFinished(); + } + c.trigger().setFinished(anyStillFinished); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OutputTimeFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OutputTimeFn.java new file mode 100644 index 000000000000..c5d943d3c5eb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OutputTimeFn.java @@ -0,0 +1,319 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.common.collect.Ordering; + +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.Objects; + +/** + * (Experimental) A function from timestamps of input values to the timestamp for a + * computed value. + * + *
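Stepping back to the OrFinallyTrigger defined above: it is typically not constructed directly but reached through a composing call such as Trigger#orFinally (assumed here; it is not shown in this excerpt). A sketch that fires every 100 elements until the watermark passes the end of the window:

    PCollection<String> capped = events.apply(
        Window.<String>into(FixedWindows.of(Duration.standardMinutes(10)))
            .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(100))
                .orFinally(AfterWatermark.pastEndOfWindow()))
            .withAllowedLateness(Duration.ZERO)
            .discardingFiredPanes());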

    The function is represented via three components: + *

+ * <ol>
+ *   <li>{@link #assignOutputTime} calculates an output timestamp for any input
+ *       value in a particular window.</li>
+ *   <li>The output timestamps for all non-late input values within a window are combined
+ *       according to {@link #combine combine()}, a commutative and associative operation on
+ *       the output timestamps.</li>
+ *   <li>The output timestamp when windows merge is provided by {@link #merge merge()}.</li>
+ * </ol>
+ *

    This abstract class cannot be subclassed directly, by design: it may grow + * in consumer-compatible ways that require mutually-exclusive default implementations. To + * create a concrete subclass, extend {@link OutputTimeFn.Defaults} or + * {@link OutputTimeFn.DependsOnlyOnWindow}. Note that as long as this class remains + * experimental, we may also choose to change it in arbitrary backwards-incompatible ways. + * + * @param the type of window. Contravariant: methods accepting any subtype of + * {@code OutputTimeFn} should use the parameter type {@code OutputTimeFn}. + */ +@Experimental(Experimental.Kind.OUTPUT_TIME) +public abstract class OutputTimeFn implements Serializable { + + /** + * Private constructor to prevent subclassing other than provided base classes. + */ + private OutputTimeFn() { } + + /** + * Returns the output timestamp to use for data depending on the given + * {@code inputTimestamp} in the specified {@code window}. + * + * + *

    The result of this method must be between {@code inputTimestamp} and + * {@code window.maxTimestamp()} (inclusive on both sides). + * + *

    This function must be monotonic across input timestamps. Specifically, if {@code A < B}, + * then {@code assignOutputTime(A, window) <= assignOutputTime(B, window)}. + * + *

    For a {@link WindowFn} that doesn't produce overlapping windows, this can (and typically + * should) just return {@code inputTimestamp}. In the presence of overlapping windows, it is + * suggested that the result in later overlapping windows is past the end of earlier windows + * so that the later windows don't prevent the watermark from + * progressing past the end of the earlier window. + * + *

    See the overview of {@link OutputTimeFn} for the consistency properties required + * between {@link #assignOutputTime}, {@link #combine}, and {@link #merge}. + */ + public abstract Instant assignOutputTime(Instant inputTimestamp, W window); + + /** + * Combines the given output times, which must be from the same window, into an output time + * for a computed value. + * + *

      + *
    • {@code combine} must be commutative: {@code combine(a, b).equals(combine(b, a))}.
    • {@code combine} must be associative: + * {@code combine(a, combine(b, c)).equals(combine(combine(a, b), c))}. + *
    + */ + public abstract Instant combine(Instant outputTime, Instant otherOutputTime); + + /** + * Merges the given output times, presumed to be combined output times for windows that + * are merging, into an output time for the {@code resultWindow}. + * + *

    When windows {@code w1} and {@code w2} merge to become a new window {@code w1plus2}, + * then {@link #merge} must be implemented such that the output time is the same as + * if all timestamps were assigned in {@code w1plus2}. Formally: + * + *

    {@code fn.merge(w, fn.assignOutputTime(t1, w1), fn.assignOutputTime(t2, w2))} + * + *

    must be equal to + * + *

    {@code fn.combine(fn.assignOutputTime(t1, w1plus2), fn.assignOutputTime(t2, w1plus2))} + * + *

    If the assigned time depends only on the window, the correct implementation of + * {@link #merge merge()} necessarily returns the result of + * {@link #assignOutputTime assignOutputTime(t1, w1plus2)} + * (which equals {@link #assignOutputTime assignOutputTime(t2, w1plus2)}). + * Defaults for this case are provided by {@link DependsOnlyOnWindow}. + * + *
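To make the consistency property concrete, here is a small sketch using the earliest-input-timestamp policy from OutputTimeFns. The window bounds and timestamps are invented for illustration, and java.util.Arrays plus org.joda.time.Instant are assumed to be imported.

    // Suppose w1 = [0, 10) and w2 = [5, 15) merge into w1plus2 = [0, 15).
    OutputTimeFn<BoundedWindow> fn = OutputTimeFns.outputAtEarliestInputTimestamp();
    IntervalWindow w1 = new IntervalWindow(new Instant(0), new Instant(10));
    IntervalWindow w2 = new IntervalWindow(new Instant(5), new Instant(15));
    IntervalWindow w1plus2 = new IntervalWindow(new Instant(0), new Instant(15));
    Instant t1 = new Instant(3);
    Instant t2 = new Instant(7);

    Instant merged = fn.merge(w1plus2,
        Arrays.asList(fn.assignOutputTime(t1, w1), fn.assignOutputTime(t2, w2)));
    Instant combined = fn.combine(
        fn.assignOutputTime(t1, w1plus2), fn.assignOutputTime(t2, w1plus2));
    // Both merged and combined are new Instant(3), the earliest input
    // timestamp, so the required equality holds for this policy.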

    For many other {@link OutputTimeFn} implementations, such as taking the earliest or latest + * timestamp, this will be the same as {@link #combine combine()}. Defaults for this + * case are provided by {@link Defaults}. + */ + public abstract Instant merge(W intoWindow, Iterable mergingTimestamps); + + /** + * Returns {@code true} if the result of combination of many output timestamps actually depends + * only on the earliest. + * + *

    This may allow optimizations when it is very efficient to retrieve the earliest timestamp + * to be combined. + */ + public abstract boolean dependsOnlyOnEarliestInputTimestamp(); + + /** + * Returns {@code true} if the result does not depend on what outputs were combined but only + * the window they are in. The canonical example is if all timestamps are sure to + * be the end of the window. + * + *

    This may allow optimizations, since it is typically very efficient to retrieve the window + * and combining output timestamps is not necessary. + * + *

    If the assigned output time for an implementation depends only on the window, consider + * extending {@link DependsOnlyOnWindow}, which returns {@code true} here and also provides + * a framework for easily implementing a correct {@link #merge}, {@link #combine} and + * {@link #assignOutputTime}. + */ + public abstract boolean dependsOnlyOnWindow(); + + /** + * (Experimental) Default method implementations for {@link OutputTimeFn} where the + * output time depends on the input element timestamps and possibly the window. + * + *

    To complete an implementation, override {@link #assignOutputTime}, at a minimum. + * + *

    By default, {@link #combine} and {@link #merge} return the earliest timestamp of their + * inputs. + */ + public abstract static class Defaults extends OutputTimeFn { + + protected Defaults() { + super(); + } + + /** + * {@inheritDoc} + * + * @return the earlier of the two timestamps. + */ + @Override + public Instant combine(Instant outputTimestamp, Instant otherOutputTimestamp) { + return Ordering.natural().min(outputTimestamp, otherOutputTimestamp); + } + + /** + * {@inheritDoc} + * + * @return the result of {@link #combine combine(outputTimstamp, otherOutputTimestamp)}, + * by default. + */ + @Override + public Instant merge(W resultWindow, Iterable mergingTimestamps) { + return OutputTimeFns.combineOutputTimes(this, mergingTimestamps); + } + + /** + * {@inheritDoc} + * + * @return {@code false}. An {@link OutputTimeFn} that depends only on the window should extend + * {@link OutputTimeFn.DependsOnlyOnWindow}. + */ + @Override + public final boolean dependsOnlyOnWindow() { + return false; + } + + /** + * {@inheritDoc} + * + * @return {@code true} by default. + */ + @Override + public boolean dependsOnlyOnEarliestInputTimestamp() { + return false; + } + + /** + * {@inheritDoc} + * + * @return {@code true} if the two {@link OutputTimeFn} instances have the same class, by + * default. + */ + @Override + public boolean equals(Object other) { + if (other == null) { + return false; + } + + return this.getClass().equals(other.getClass()); + } + + @Override + public int hashCode() { + return Objects.hash(getClass()); + } + } + + /** + * (Experimental) Default method implementations for {@link OutputTimeFn} when the + * output time depends only on the window. + * + *

    To complete an implementation, override {@link #assignOutputTime(BoundedWindow)}. + */ + public abstract static class DependsOnlyOnWindow + extends OutputTimeFn { + + protected DependsOnlyOnWindow() { + super(); + } + + /** + * Returns the output timestamp to use for data in the specified {@code window}. + * + *

    Note that the result of this method must be between the maximum possible input timestamp + * in {@code window} and {@code window.maxTimestamp()} (inclusive on both sides). + * + *

    For example, using {@code Sessions.withGapDuration(gapDuration)}, we know that all input + * timestamps must lie at least {@code gapDuration} from the end of the session, so + * {@code window.maxTimestamp() - gapDuration} is an acceptable assigned timestamp. + * + * @see #assignOutputTime(Instant, BoundedWindow) + */ + protected abstract Instant assignOutputTime(W window); + + /** + * {@inheritDoc} + * + * @return the result of {#link assignOutputTime(BoundedWindow) assignOutputTime(window)}. + */ + @Override + public final Instant assignOutputTime(Instant timestamp, W window) { + return assignOutputTime(window); + } + + /** + * {@inheritDoc} + * + * @return the same timestamp as both argument timestamps, which are necessarily equal. + */ + @Override + public final Instant combine(Instant outputTimestamp, Instant otherOutputTimestamp) { + return outputTimestamp; + } + + /** + * {@inheritDoc} + * + * @return the result of + * {@link #assignOutputTime(BoundedWindow) assignOutputTime(resultWindow)}. + */ + @Override + public final Instant merge(W resultWindow, Iterable mergingTimestamps) { + return assignOutputTime(resultWindow); + } + + /** + * {@inheritDoc} + * + * @return {@code true}. + */ + @Override + public final boolean dependsOnlyOnWindow() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true}. Since the output time depends only on the window, it can + * certainly be ascertained given a single input timestamp. + */ + @Override + public final boolean dependsOnlyOnEarliestInputTimestamp() { + return true; + } + + /** + * {@inheritDoc} + * + * @return {@code true} if the two {@link OutputTimeFn} instances have the same class, by + * default. + */ + @Override + public boolean equals(Object other) { + if (other == null) { + return false; + } + + return this.getClass().equals(other.getClass()); + } + + @Override + public int hashCode() { + return Objects.hash(getClass()); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OutputTimeFns.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OutputTimeFns.java new file mode 100644 index 000000000000..dcc0f5b7b9c2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/OutputTimeFns.java @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.common.collect.Iterables; +import com.google.common.collect.Ordering; + +import org.joda.time.Instant; + +import javax.annotation.Nullable; + +/** + * (Experimental) Static utility methods and provided implementations for + * {@link OutputTimeFn}. 
+ */ +@Experimental(Experimental.Kind.OUTPUT_TIME) +public class OutputTimeFns { + /** + * The policy of outputting at the earliest of the input timestamps for non-late input data + * that led to a computed value. + * + *

    For example, suppose v1 through vn are all on-time + * elements being aggregated via some function {@code f} into + * {@code f}(v1, ..., vn). When emitted, the output + * timestamp of the result will be the earliest of the event time timestamps of those elements. + * + *

    If data arrives late, it has no effect on the output timestamp. + */ + public static OutputTimeFn outputAtEarliestInputTimestamp() { + return new OutputAtEarliestInputTimestamp(); + } + + /** + * The policy of holding the watermark to the latest of the input timestamps + * for non-late input data that led to a computed value. + * + *

    For example, suppose v1 through vn are all on-time + * elements being aggregated via some function {@code f} into + * {@code f}(v1, ..., vn). When emitted, the output + * timestamp of the result will be the latest of the event time timestamps of those elements. + * + *

    If data arrives late, it has no effect on the output timestamp. + */ + public static OutputTimeFn outputAtLatestInputTimestamp() { + return new OutputAtLatestInputTimestamp(); + } + + /** + * The policy of outputting with timestamps at the end of the window. + * + *

    Note that this output timestamp depends only on the window. See + * {@link OutputTimeFn#dependsOnlyOnWindow()}. + * + *

    When windows merge, instead of using {@link OutputTimeFn#combine} to obtain an output + * timestamp for the results in the new window, it is mandatory to obtain a new output + * timestamp from {@link OutputTimeFn#assignOutputTime} with the new window and an arbitrary + * timestamp (because it is guaranteed that the timestamp is irrelevant). + * + *

    For non-merging window functions, this {@link OutputTimeFn} works transparently. + */ + public static OutputTimeFn outputAtEndOfWindow() { + return new OutputAtEndOfWindow(); + } + + /** + * Applies the given {@link OutputTimeFn} to the given output times, obtaining + * the output time for a value computed. See {@link OutputTimeFn#combine} for + * a full specification. + * + * @throws IllegalArgumentException if {@code outputTimes} is empty. + */ + public static Instant combineOutputTimes( + OutputTimeFn outputTimeFn, Iterable outputTimes) { + checkArgument( + !Iterables.isEmpty(outputTimes), + "Collection of output times must not be empty in %s.combineOutputTimes", + OutputTimeFns.class.getName()); + + @Nullable + Instant combinedOutputTime = null; + for (Instant outputTime : outputTimes) { + combinedOutputTime = + combinedOutputTime == null + ? outputTime : outputTimeFn.combine(combinedOutputTime, outputTime); + } + return combinedOutputTime; + } + + /** + * See {@link #outputAtEarliestInputTimestamp}. + */ + private static class OutputAtEarliestInputTimestamp extends OutputTimeFn.Defaults { + @Override + public Instant assignOutputTime(Instant inputTimestamp, BoundedWindow window) { + return inputTimestamp; + } + + @Override + public Instant combine(Instant outputTime, Instant otherOutputTime) { + return Ordering.natural().min(outputTime, otherOutputTime); + } + + /** + * {@inheritDoc} + * + * @return {@code true}. The result of any combine will be the earliest input timestamp. + */ + @Override + public boolean dependsOnlyOnEarliestInputTimestamp() { + return true; + } + } + + /** + * See {@link #outputAtLatestInputTimestamp}. + */ + private static class OutputAtLatestInputTimestamp extends OutputTimeFn.Defaults { + @Override + public Instant assignOutputTime(Instant inputTimestamp, BoundedWindow window) { + return inputTimestamp; + } + + @Override + public Instant combine(Instant outputTime, Instant otherOutputTime) { + return Ordering.natural().max(outputTime, otherOutputTime); + } + + /** + * {@inheritDoc} + * + * @return {@code false}. + */ + @Override + public boolean dependsOnlyOnEarliestInputTimestamp() { + return false; + } + } + + private static class OutputAtEndOfWindow extends OutputTimeFn.DependsOnlyOnWindow { + + /** + *{@inheritDoc} + * + *@return {@code window.maxTimestamp()}. + */ + @Override + protected Instant assignOutputTime(BoundedWindow window) { + return window.maxTimestamp(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PaneInfo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PaneInfo.java new file mode 100644 index 000000000000..18f7a973cc9d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PaneInfo.java @@ -0,0 +1,384 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Objects; + +/** + * Provides information about the pane an element belongs to. Every pane is implicitly associated + * with a window. Panes are observable only via the + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.ProcessContext#pane} method of the context + * passed to a {@link DoFn#processElement} overridden method. + * + *

    Note: This does not uniquely identify a pane, and should not be used for comparisons. + */ +public final class PaneInfo { + /** + * Enumerates the possibilities for the timing of this pane firing related to the + * input and output watermarks for its computation. + * + *

    A window may fire multiple panes, and the timing of those panes generally follows the + * regular expression {@code EARLY* ON_TIME? LATE*}. Generally a pane is considered: + *

      + *
    1. {@code EARLY} if the system cannot be sure it has seen all data which may contribute + * to the pane's window. + *
    2. {@code ON_TIME} if the system predicts it has seen all the data which may contribute + * to the pane's window. + *
    3. {@code LATE} if the system has encountered new data after predicting no more could arrive. + * It is possible an {@code ON_TIME} pane has already been emitted, in which case any + * following panes are considered {@code LATE}. + *
    + * + *

    Only an + * {@link AfterWatermark#pastEndOfWindow} trigger may produce an {@code ON_TIME} pane. + * With merging {@link WindowFn}s, windows may be merged to produce new windows that satisfy + * their own instance of the above regular expression. The only guarantee is that once a window + * produces a final pane, it will not be merged into any new windows. + * + *
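For example, a downstream DoFn can observe which kind of pane produced each element via ProcessContext#pane. A rough sketch follows; the DoFn itself and what it does with late panes are hypothetical:

    import com.google.cloud.dataflow.sdk.transforms.DoFn;
    import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo;

    // Hypothetical pass-through DoFn that inspects the timing of each pane.
    class TagPaneTiming<T> extends DoFn<T, T> {
      @Override
      public void processElement(ProcessContext c) {
        PaneInfo pane = c.pane();
        if (pane.getTiming() == PaneInfo.Timing.LATE) {
          // e.g. count or divert results produced by late firings
        }
        c.output(c.element());
      }
    }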

    The predictions above are made using the mechanism of watermarks. + * See {@link com.google.cloud.dataflow.sdk.util.TimerInternals} for more information + * about watermarks. + * + *

    We can state some properties of {@code LATE} and {@code ON_TIME} panes, but first need some + * definitions: + *

      + *
    1. We'll call a pipeline 'simple' if it does not use + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#outputWithTimestamp} in + * any {@code DoFn}, and it uses the same + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window.Bound#withAllowedLateness} + * argument value on all windows (or uses the default of {@link org.joda.time.Duration#ZERO}). + *
    2. We'll call an element 'locally late', from the point of view of a computation on a + * worker, if the element's timestamp is before the input watermark for that computation + * on that worker. The element is otherwise 'locally on-time'. + *
    3. We'll say 'the pane's timestamp' to mean the timestamp of the element produced to + * represent the pane's contents. + *
    + * + *

    Then in simple pipelines: + *

      + *
    1. (Soundness) An {@code ON_TIME} pane can never cause a later computation to generate a + * {@code LATE} pane. (If it did, it would imply a later computation's input watermark progressed + * ahead of an earlier stage's output watermark, which by design is not possible.) + *
    2. (Liveness) An {@code ON_TIME} pane is emitted as soon as possible after the input + * watermark passes the end of the pane's window. + *
    3. (Consistency) A pane with only locally on-time elements will always be {@code ON_TIME}. + * And a {@code LATE} pane cannot contain locally on-time elements. + *
    + * + * However, note that: + *
      + *
    1. An {@code ON_TIME} pane may contain locally late elements. It may even contain only + * locally late elements. Provided a locally late element finds its way into an {@code ON_TIME} + * pane, its lateness becomes unobservable. + *
    2. A {@code LATE} pane does not necessarily cause any following computation panes to be + * marked as {@code LATE}. + *
    + */ + public enum Timing { + /** + * Pane was fired before the input watermark had progressed after the end of the window. + */ + EARLY, + /** + * Pane was fired by a {@link AfterWatermark#pastEndOfWindow} trigger because the input + * watermark progressed after the end of the window. However the output watermark has not + * yet progressed after the end of the window. Thus it is still possible to assign a timestamp + * to the element representing this pane which cannot be considered locally late by any + * following computation. + */ + ON_TIME, + /** + * Pane was fired after the output watermark had progressed past the end of the window. + */ + LATE, + /** + * This element was not produced in a triggered pane and its relation to input and + * output watermarks is unknown. + */ + UNKNOWN; + + // NOTE: Do not add fields or re-order them. The ordinal is used as part of + // the encoding. + } + + private static byte encodedByte(boolean isFirst, boolean isLast, Timing timing) { + byte result = 0x0; + if (isFirst) { + result |= 1; + } + if (isLast) { + result |= 2; + } + result |= timing.ordinal() << 2; + return result; + } + + private static final ImmutableMap BYTE_TO_PANE_INFO; + static { + ImmutableMap.Builder decodingBuilder = ImmutableMap.builder(); + for (Timing timing : Timing.values()) { + long onTimeIndex = timing == Timing.EARLY ? -1 : 0; + register(decodingBuilder, new PaneInfo(true, true, timing, 0, onTimeIndex)); + register(decodingBuilder, new PaneInfo(true, false, timing, 0, onTimeIndex)); + register(decodingBuilder, new PaneInfo(false, true, timing, -1, onTimeIndex)); + register(decodingBuilder, new PaneInfo(false, false, timing, -1, onTimeIndex)); + } + BYTE_TO_PANE_INFO = decodingBuilder.build(); + } + + private static void register(ImmutableMap.Builder builder, PaneInfo info) { + builder.put(info.encodedByte, info); + } + + private final byte encodedByte; + + private final boolean isFirst; + private final boolean isLast; + private final Timing timing; + private final long index; + private final long nonSpeculativeIndex; + + /** + * {@code PaneInfo} to use for elements on (and before) initial window assignemnt (including + * elements read from sources) before they have passed through a {@link GroupByKey} and are + * associated with a particular trigger firing. + */ + public static final PaneInfo NO_FIRING = + PaneInfo.createPane(true, true, Timing.UNKNOWN, 0, 0); + + /** + * {@code PaneInfo} to use when there will be exactly one firing and it is on time. + */ + public static final PaneInfo ON_TIME_AND_ONLY_FIRING = + PaneInfo.createPane(true, true, Timing.ON_TIME, 0, 0); + + private PaneInfo(boolean isFirst, boolean isLast, Timing timing, long index, long onTimeIndex) { + this.encodedByte = encodedByte(isFirst, isLast, timing); + this.isFirst = isFirst; + this.isLast = isLast; + this.timing = timing; + this.index = index; + this.nonSpeculativeIndex = onTimeIndex; + } + + public static PaneInfo createPane(boolean isFirst, boolean isLast, Timing timing) { + Preconditions.checkArgument(isFirst, "Indices must be provided for non-first pane info."); + return createPane(isFirst, isLast, timing, 0, timing == Timing.EARLY ? -1 : 0); + } + + /** + * Factory method to create a {@link PaneInfo} with the specified parameters. 
+ */ + public static PaneInfo createPane( + boolean isFirst, boolean isLast, Timing timing, long index, long onTimeIndex) { + if (isFirst || timing == Timing.UNKNOWN) { + return Preconditions.checkNotNull( + BYTE_TO_PANE_INFO.get(encodedByte(isFirst, isLast, timing))); + } else { + return new PaneInfo(isFirst, isLast, timing, index, onTimeIndex); + } + } + + public static PaneInfo decodePane(byte encodedPane) { + return Preconditions.checkNotNull(BYTE_TO_PANE_INFO.get(encodedPane)); + } + + /** + * Return true if there is no timing information for the current {@link PaneInfo}. + * This typically indicates that the current element has not been assigned to + * windows or passed through an operation that executes triggers yet. + */ + public boolean isUnknown() { + return Timing.UNKNOWN.equals(timing); + } + + /** + * Return true if this is the first pane produced for the associated window. + */ + public boolean isFirst() { + return isFirst; + } + + /** + * Return true if this is the last pane that will be produced in the associated window. + */ + public boolean isLast() { + return isLast; + } + + /** + * Return true if this is the last pane that will be produced in the associated window. + */ + public Timing getTiming() { + return timing; + } + + /** + * The zero-based index of this trigger firing that produced this pane. + * + *

    This will return 0 for the first time the timer fires, 1 for the next time, etc. + * + *

    A given (key, window, pane-index) is guaranteed to be unique in the + * output of a group-by-key operation. + */ + public long getIndex() { + return index; + } + + /** + * The zero-based index of this trigger firing among non-speculative panes. + * + *

    This will return 0 for the first non-{@link Timing#EARLY} timer firing, 1 for the next one, + * etc. + * + *

    Always -1 for speculative data. + */ + public long getNonSpeculativeIndex() { + return nonSpeculativeIndex; + } + + int getEncodedByte() { + return encodedByte; + } + + @Override + public int hashCode() { + return Objects.hash(encodedByte, index, nonSpeculativeIndex); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + // Simple PaneInfos are interned. + return true; + } else if (obj instanceof PaneInfo) { + PaneInfo that = (PaneInfo) obj; + return this.encodedByte == that.encodedByte + && this.index == that.index + && this.nonSpeculativeIndex == that.nonSpeculativeIndex; + } else { + return false; + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .omitNullValues() + .add("isFirst", isFirst ? true : null) + .add("isLast", isLast ? true : null) + .add("timing", timing) + .add("index", index) + .add("onTimeIndex", nonSpeculativeIndex != -1 ? nonSpeculativeIndex : null) + .toString(); + } + + /** + * A Coder for encoding PaneInfo instances. + */ + public static class PaneInfoCoder extends AtomicCoder { + private static enum Encoding { + FIRST, + ONE_INDEX, + TWO_INDICES; + + // NOTE: Do not reorder fields. The ordinal is used as part of + // the encoding. + + public final byte tag; + + private Encoding() { + assert ordinal() < 16; + tag = (byte) (ordinal() << 4); + } + + public static Encoding fromTag(byte b) { + return Encoding.values()[b >> 4]; + } + } + + private Encoding chooseEncoding(PaneInfo value) { + if (value.index == 0 && value.nonSpeculativeIndex == 0 || value.timing == Timing.UNKNOWN) { + return Encoding.FIRST; + } else if (value.index == value.nonSpeculativeIndex || value.timing == Timing.EARLY) { + return Encoding.ONE_INDEX; + } else { + return Encoding.TWO_INDICES; + } + } + + public static final PaneInfoCoder INSTANCE = new PaneInfoCoder(); + + @Override + public void encode(PaneInfo value, final OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + Encoding encoding = chooseEncoding(value); + switch (chooseEncoding(value)) { + case FIRST: + outStream.write(value.encodedByte); + break; + case ONE_INDEX: + outStream.write(value.encodedByte | encoding.tag); + VarInt.encode(value.index, outStream); + break; + case TWO_INDICES: + outStream.write(value.encodedByte | encoding.tag); + VarInt.encode(value.index, outStream); + VarInt.encode(value.nonSpeculativeIndex, outStream); + break; + default: + throw new CoderException("Unknown encoding " + encoding); + } + } + + @Override + public PaneInfo decode(final InputStream inStream, Coder.Context context) + throws CoderException, IOException { + byte keyAndTag = (byte) inStream.read(); + PaneInfo base = BYTE_TO_PANE_INFO.get((byte) (keyAndTag & 0x0F)); + long index, onTimeIndex; + switch (Encoding.fromTag(keyAndTag)) { + case FIRST: + return base; + case ONE_INDEX: + index = VarInt.decodeLong(inStream); + onTimeIndex = base.timing == Timing.EARLY ? 
-1 : index; + break; + case TWO_INDICES: + index = VarInt.decodeLong(inStream); + onTimeIndex = VarInt.decodeLong(inStream); + break; + default: + throw new CoderException("Unknown encoding " + (keyAndTag & 0xF0)); + } + return new PaneInfo(base.isFirst, base.isLast, base.timing, index, onTimeIndex); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PartitioningWindowFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PartitioningWindowFn.java new file mode 100644 index 000000000000..bea0285b61a8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/PartitioningWindowFn.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.Collection; + +/** + * A {@link WindowFn} that places each value into exactly one window based on its timestamp and + * never merges windows. + * + * @param type of elements being windowed + * @param window type + */ +public abstract class PartitioningWindowFn + extends NonMergingWindowFn { + /** + * Returns the single window to which elements with this timestamp belong. + */ + public abstract W assignWindow(Instant timestamp); + + @Override + public final Collection assignWindows(AssignContext c) { + return Arrays.asList(assignWindow(c.timestamp())); + } + + @Override + public W getSideInputWindow(final BoundedWindow window) { + if (window instanceof GlobalWindow) { + throw new IllegalArgumentException( + "Attempted to get side input window for GlobalWindow from non-global WindowFn"); + } + return assignWindow(window.maxTimestamp()); + } + + @Override + public boolean assignsToSingleWindow() { + return true; + } + + @Override + public Instant getOutputTime(Instant inputTimestamp, W window) { + return inputTimestamp; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Repeatedly.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Repeatedly.java new file mode 100644 index 000000000000..e77e2a120338 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Repeatedly.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; + +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.List; + +/** + * Repeat a trigger, either until some condition is met or forever. + * + *

    For example, to fire after the end of the window, and every time late data arrives: + *

     {@code
    + *     Repeatedly.forever(AfterWatermark.pastEndOfWindow());
    + * } 
    + * + *
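Another common pattern is to fire whenever a batch of elements accumulates in the pane. AfterPane.elementCountAtLeast is assumed to be available in this SDK (it does not appear in this change):

    Repeatedly.forever(AfterPane.elementCountAtLeast(100));

Roughly, this fires each time at least 100 additional elements have arrived since the previous firing, for as long as the window continues to receive data.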

    {@code Repeatedly.forever(someTrigger)} behaves like an infinite + * {@code AfterEach.inOrder(someTrigger, someTrigger, someTrigger, ...)}. + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code Trigger} + */ +public class Repeatedly extends Trigger { + + private static final int REPEATED = 0; + + /** + * Create a composite trigger that repeatedly executes the trigger {@code toRepeat}, firing each + * time it fires and ignoring any indications to finish. + * + *

    Unless used with {@link Trigger#orFinally} the composite trigger will never finish. + * + * @param repeated the trigger to execute repeatedly. + */ + public static Repeatedly forever(Trigger repeated) { + return new Repeatedly(repeated); + } + + private Repeatedly(Trigger repeated) { + super(Arrays.asList(repeated)); + } + + + @Override + public void onElement(OnElementContext c) throws Exception { + getRepeated(c).invokeOnElement(c); + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + getRepeated(c).invokeOnMerge(c); + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + // This trigger fires once the repeated trigger fires. + return subTriggers.get(REPEATED).getWatermarkThatGuaranteesFiring(window); + } + + @Override + public Trigger getContinuationTrigger(List> continuationTriggers) { + return new Repeatedly(continuationTriggers.get(REPEATED)); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + return getRepeated(context).invokeShouldFire(context); + } + + @Override + public void onFire(TriggerContext context) throws Exception { + getRepeated(context).invokeOnFire(context); + + if (context.trigger().isFinished(REPEATED)) { + context.trigger().setFinished(false, REPEATED); + getRepeated(context).invokeClear(context); + } + } + + private ExecutableTrigger getRepeated(TriggerContext context) { + return context.trigger().subTrigger(REPEATED); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Sessions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Sessions.java new file mode 100644 index 000000000000..da137c1f47f1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Sessions.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Duration; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Objects; + +/** + * A {@link WindowFn} windowing values into sessions separated by {@link #gapDuration}-long + * periods with no elements. + * + *

    For example, in order to window data into sessions with at least 10 minute + * gaps in between them: + *

     {@code
    + * PCollection pc = ...;
    + * PCollection windowed_pc = pc.apply(
    + *   Window.into(Sessions.withGapDuration(Duration.standardMinutes(10))));
    + * } 
    + */ +public class Sessions extends WindowFn { + /** + * Duration of the gaps between sessions. + */ + private final Duration gapDuration; + + /** + * Creates a {@code Sessions} {@link WindowFn} with the specified gap duration. + */ + public static Sessions withGapDuration(Duration gapDuration) { + return new Sessions(gapDuration); + } + + /** + * Creates a {@code Sessions} {@link WindowFn} with the specified gap duration. + */ + private Sessions(Duration gapDuration) { + this.gapDuration = gapDuration; + } + + @Override + public Collection assignWindows(AssignContext c) { + // Assign each element into a window from its timestamp until gapDuration in the + // future. Overlapping windows (representing elements within gapDuration of + // each other) will be merged. + return Arrays.asList(new IntervalWindow(c.timestamp(), gapDuration)); + } + + @Override + public void mergeWindows(MergeContext c) throws Exception { + MergeOverlappingIntervalWindows.mergeWindows(c); + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public boolean isCompatible(WindowFn other) { + return other instanceof Sessions; + } + + @Override + public IntervalWindow getSideInputWindow(BoundedWindow window) { + throw new UnsupportedOperationException("Sessions is not allowed in side inputs"); + } + + @Experimental(Kind.OUTPUT_TIME) + @Override + public OutputTimeFn getOutputTimeFn() { + return OutputTimeFns.outputAtEarliestInputTimestamp(); + } + + public Duration getGapDuration() { + return gapDuration; + } + + @Override + public boolean equals(Object object) { + if (!(object instanceof Sessions)) { + return false; + } + Sessions other = (Sessions) object; + return getGapDuration().equals(other.getGapDuration()); + } + + @Override + public int hashCode() { + return Objects.hash(gapDuration); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindows.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindows.java new file mode 100644 index 000000000000..b0066d6124eb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindows.java @@ -0,0 +1,214 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.Coder; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** + * A {@link WindowFn} that windows values into possibly overlapping fixed-size + * timestamp-based windows. + * + *

    For example, in order to window data into 10 minute windows that + * update every minute: + *

     {@code
    + * PCollection items = ...;
    + * PCollection windowedItems = items.apply(
    + *   Window.into(SlidingWindows.of(Duration.standardMinutes(10))));
    + * } 
    + */ +public class SlidingWindows extends NonMergingWindowFn { + + /** + * Amount of time between generated windows. + */ + private final Duration period; + + /** + * Size of the generated windows. + */ + private final Duration size; + + /** + * Offset of the generated windows. + * Windows start at time N * period + offset, where 0 is the epoch. + */ + private final Duration offset; + + /** + * Assigns timestamps into half-open intervals of the form + * [N * period, N * period + size), where 0 is the epoch. + * + *

    If {@link SlidingWindows#every} is not called, the period defaults + * to the largest time unit smaller than the given duration. For example, + * specifying a size of 5 seconds will result in a default period of 1 second. + */ + public static SlidingWindows of(Duration size) { + return new SlidingWindows(getDefaultPeriod(size), size, Duration.ZERO); + } + + /** + * Returns a new {@code SlidingWindows} with the original size, that assigns + * timestamps into half-open intervals of the form + * [N * period, N * period + size), where 0 is the epoch. + */ + public SlidingWindows every(Duration period) { + return new SlidingWindows(period, size, offset); + } + + /** + * Assigns timestamps into half-open intervals of the form + * [N * period + offset, N * period + offset + size). + * + * @throws IllegalArgumentException if offset is not in [0, period) + */ + public SlidingWindows withOffset(Duration offset) { + return new SlidingWindows(period, size, offset); + } + + private SlidingWindows(Duration period, Duration size, Duration offset) { + if (offset.isShorterThan(Duration.ZERO) + || !offset.isShorterThan(period) + || !size.isLongerThan(Duration.ZERO)) { + throw new IllegalArgumentException( + "SlidingWindows WindowingStrategies must have 0 <= offset < period and 0 < size"); + } + this.period = period; + this.size = size; + this.offset = offset; + } + + @Override + public Coder windowCoder() { + return IntervalWindow.getCoder(); + } + + @Override + public Collection assignWindows(AssignContext c) { + List windows = + new ArrayList<>((int) (size.getMillis() / period.getMillis())); + Instant timestamp = c.timestamp(); + long lastStart = lastStartFor(timestamp); + for (long start = lastStart; + start > timestamp.minus(size).getMillis(); + start -= period.getMillis()) { + windows.add(new IntervalWindow(new Instant(start), size)); + } + return windows; + } + + /** + * Return the earliest window that contains the end of the main-input window. + */ + @Override + public IntervalWindow getSideInputWindow(final BoundedWindow window) { + if (window instanceof GlobalWindow) { + throw new IllegalArgumentException( + "Attempted to get side input window for GlobalWindow from non-global WindowFn"); + } + long lastStart = lastStartFor(window.maxTimestamp().minus(size)); + return new IntervalWindow(new Instant(lastStart + period.getMillis()), size); + } + + @Override + public boolean isCompatible(WindowFn other) { + return equals(other); + } + + /** + * Return the last start of a sliding window that contains the timestamp. + */ + private long lastStartFor(Instant timestamp) { + return timestamp.getMillis() + - timestamp.plus(period).minus(offset).getMillis() % period.getMillis(); + } + + static Duration getDefaultPeriod(Duration size) { + if (size.isLongerThan(Duration.standardHours(1))) { + return Duration.standardHours(1); + } + if (size.isLongerThan(Duration.standardMinutes(1))) { + return Duration.standardMinutes(1); + } + if (size.isLongerThan(Duration.standardSeconds(1))) { + return Duration.standardSeconds(1); + } + return Duration.millis(1); + } + + public Duration getPeriod() { + return period; + } + + public Duration getSize() { + return size; + } + + public Duration getOffset() { + return offset; + } + + /** + * Ensures that later sliding windows have an output time that is past the end of earlier windows. + * + *

    If this is the earliest sliding window containing {@code inputTimestamp}, that's fine. + * Otherwise, we pick the earliest time that doesn't overlap with earlier windows. + */ + @Experimental(Kind.OUTPUT_TIME) + @Override + public OutputTimeFn getOutputTimeFn() { + return new OutputTimeFn.Defaults() { + @Override + public Instant assignOutputTime(Instant inputTimestamp, BoundedWindow window) { + Instant startOfLastSegment = window.maxTimestamp().minus(period); + return startOfLastSegment.isBefore(inputTimestamp) + ? inputTimestamp + : startOfLastSegment.plus(1); + } + + @Override + public boolean dependsOnlyOnEarliestInputTimestamp() { + return true; + } + }; + } + + @Override + public boolean equals(Object object) { + if (!(object instanceof SlidingWindows)) { + return false; + } + SlidingWindows other = (SlidingWindows) object; + return getOffset().equals(other.getOffset()) + && getSize().equals(other.getSize()) + && getPeriod().equals(other.getPeriod()); + } + + @Override + public int hashCode() { + return Objects.hash(size, offset, period); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Trigger.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Trigger.java new file mode 100644 index 000000000000..4471563e70c5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Trigger.java @@ -0,0 +1,544 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.util.ExecutableTrigger; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.common.base.Joiner; + +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * {@code Trigger}s control when the elements for a specific key and window are output. As elements + * arrive, they are put into one or more windows by a {@link Window} transform and its associated + * {@link WindowFn}, and then passed to the associated {@code Trigger} to determine if the + * {@code Window}s contents should be output. + * + *

    See {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} and {@link Window} + * for more information about how grouping with windows works. + * + *

    The elements that are assigned to a window since the last time it was fired (or since the + * window was created) are placed into the current window pane. Triggers are evaluated against the + * elements as they are added. When the root trigger fires, the elements in the current pane will be + * output. When the root trigger finishes (indicating it will never fire again), the window is + * closed and any new elements assigned to that window are discarded. + * + *
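For orientation, here is a sketch of how a trigger is typically attached to a windowing transform. Window, FixedWindows, AfterPane.elementCountAtLeast, and the allowed-lateness and pane-accumulation settings belong to the same SDK but do not appear in this change, so treat the exact calls as assumptions:

    PCollection<String> input = ...;
    PCollection<String> output = input.apply(
        Window.<String>into(FixedWindows.of(Duration.standardMinutes(1)))
            .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(10)))
            .withAllowedLateness(Duration.standardMinutes(30))
            .discardingFiredPanes());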

    Several predefined {@code Trigger}s are provided: + *

      + *
    • {@link AfterWatermark} for firing when the watermark passes a timestamp determined from + * either the end of the window or the arrival of the first element in a pane. + *
    • {@link AfterProcessingTime} for firing after some amount of processing time has elapsed + * (typically since the first element in a pane). + *
    • {@link AfterPane} for firing off a property of the elements in the current pane, such as + * the number of elements that have been assigned to the current pane. + *
    + * + *

    In addition, {@code Trigger}s can be combined in a variety of ways: + *

      + *
    • {@link Repeatedly#forever} to create a trigger that executes forever. Any time its + * argument finishes it gets reset and starts over. Can be combined with + * {@link Trigger#orFinally} to specify a condition that causes the repetition to stop. + *
    • {@link AfterEach#inOrder} to execute each trigger in sequence, firing each (and every) + * time that a trigger fires, and advancing to the next trigger in the sequence when it finishes. + *
    • {@link AfterFirst#of} to create a trigger that fires after at least one of its arguments + * fires. An {@link AfterFirst} trigger finishes after it fires once. + *
    • {@link AfterAll#of} to create a trigger that fires after all of its arguments + * have fired at least once. An {@link AfterAll} trigger finishes after it fires once. + *
    + * + *
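For example, a composite trigger that fires as soon as either 100 elements have arrived or one minute of processing time has passed since the first element in the pane might look like the following; AfterProcessingTime.pastFirstElementInPane and plusDelayOf are assumed from the same SDK:

    AfterFirst.of(
        AfterPane.elementCountAtLeast(100),
        AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.standardMinutes(1)));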

    Each trigger tree is instantiated per-key and per-window. Every trigger in the tree is in one + * of the following states: + *

      + *
    • Never Existed - before the trigger has started executing, there is no state associated + * with it anywhere in the system. A trigger moves to the executing state as soon as it + * processes an element in the current pane. + *
    • Executing - while the trigger is receiving items and may fire. While it is in this state, + * it may persist book-keeping information to persisted state, set timers, etc. + *
    • Finished - after a trigger finishes, all of its book-keeping data is cleaned up, and the + * system remembers only that it is finished. Entering this state causes us to discard any + * elements in the buffer for that window, as well. + *
    + * + *

    Once finished, a trigger cannot return itself to an earlier state; however, a composite + * trigger could reset its sub-triggers. + * + *

    Triggers should not build up any state internally since they may be recreated + * between invocations of the callbacks. All important values should be persisted using + * state before the callback returns. + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code Trigger} + */ +@Experimental(Experimental.Kind.TRIGGER) +public abstract class Trigger implements Serializable, TriggerBuilder { + + /** + * Interface for accessing information about the trigger being executed and other triggers in the + * same tree. + */ + public interface TriggerInfo { + + /** + * Returns true if the windowing strategy of the current {@code PCollection} is a merging + * WindowFn. If true, the trigger execution needs to keep enough information to support the + * possibility of {@link Trigger#onMerge} being called. If false, {@link Trigger#onMerge} will + * never be called. + */ + boolean isMerging(); + + /** + * Access the executable versions of the sub-triggers of the current trigger. + */ + Iterable> subTriggers(); + + /** + * Access the executable version of the specified sub-trigger. + */ + ExecutableTrigger subTrigger(int subtriggerIndex); + + /** + * Returns true if the current trigger is marked finished. + */ + boolean isFinished(); + + /** + * Return true if the given subtrigger is marked finished. + */ + boolean isFinished(int subtriggerIndex); + + /** + * Returns true if all the sub-triggers of the current trigger are marked finished. + */ + boolean areAllSubtriggersFinished(); + + /** + * Returns an iterable over the unfinished sub-triggers of the current trigger. + */ + Iterable> unfinishedSubTriggers(); + + /** + * Returns the first unfinished sub-trigger. + */ + ExecutableTrigger firstUnfinishedSubTrigger(); + + /** + * Clears all keyed state for triggers in the current sub-tree and unsets all the associated + * finished bits. + */ + void resetTree() throws Exception; + + /** + * Sets the finished bit for the current trigger. + */ + void setFinished(boolean finished); + + /** + * Sets the finished bit for the given sub-trigger. + */ + void setFinished(boolean finished, int subTriggerIndex); + } + + /** + * Interact with properties of the trigger being executed, with extensions to deal with the + * merging windows. + */ + public interface MergingTriggerInfo extends TriggerInfo { + + /** Return true if the trigger is finished in any window being merged. */ + public abstract boolean finishedInAnyMergingWindow(); + + /** Return true if the trigger is finished in all windows being merged. */ + public abstract boolean finishedInAllMergingWindows(); + + /** Return the merging windows in which the trigger is finished. */ + public abstract Iterable getFinishedMergingWindows(); + } + + /** + * Information accessible to all operational hooks in this {@code Trigger}. + * + *

    Used directly in {@link Trigger#shouldFire} and {@link Trigger#clear}, and + * extended with additional information in other methods. + */ + public abstract class TriggerContext { + + /** Returns the interface for accessing trigger info. */ + public abstract TriggerInfo trigger(); + + /** Returns the interface for accessing persistent state. */ + public abstract StateAccessor state(); + + /** The window that the current context is executing in. */ + public abstract W window(); + + /** Create a sub-context for the given sub-trigger. */ + public abstract TriggerContext forTrigger(ExecutableTrigger trigger); + + /** + * Removes the timer set in this trigger context for the given {@link Instant} + * and {@link TimeDomain}. + */ + public abstract void deleteTimer(Instant timestamp, TimeDomain domain); + + /** The current processing time. */ + public abstract Instant currentProcessingTime(); + + /** The current synchronized upstream processing time or {@code null} if unknown. */ + @Nullable + public abstract Instant currentSynchronizedProcessingTime(); + + /** The current event time for the input or {@code null} if unknown. */ + @Nullable + public abstract Instant currentEventTime(); + } + + /** + * Extended {@link TriggerContext} containing information accessible to the {@link #onElement} + * operational hook. + */ + public abstract class OnElementContext extends TriggerContext { + /** The event timestamp of the element currently being processed. */ + public abstract Instant eventTimestamp(); + + /** + * Sets a timer to fire when the watermark or processing time is beyond the given timestamp. + * Timers are not guaranteed to fire immediately, but will be delivered at some time afterwards. + * + *

    As with {@link #state}, timers are implicitly scoped to the current window. All + * timer firings for a window will be received, but the implementation should choose to ignore + * those that are not applicable. + * + * @param timestamp the time at which the trigger should be re-evaluated + * @param domain the domain that the {@code timestamp} applies to + */ + public abstract void setTimer(Instant timestamp, TimeDomain domain); + + /** Create an {@code OnElementContext} for executing the given trigger. */ + @Override + public abstract OnElementContext forTrigger(ExecutableTrigger trigger); + } + + /** + * Extended {@link TriggerContext} containing information accessible to the {@link #onMerge} + * operational hook. + */ + public abstract class OnMergeContext extends TriggerContext { + /** + * Sets a timer to fire when the watermark or processing time is beyond the given timestamp. + * Timers are not guaranteed to fire immediately, but will be delivered at some time afterwards. + * + *

    As with {@link #state}, timers are implicitly scoped to the current window. All + * timer firings for a window will be received, but the implementation should choose to ignore + * those that are not applicable. + * + * @param timestamp the time at which the trigger should be re-evaluated + * @param domain the domain that the {@code timestamp} applies to + */ + public abstract void setTimer(Instant timestamp, TimeDomain domain); + + /** Create an {@code OnMergeContext} for executing the given trigger. */ + @Override + public abstract OnMergeContext forTrigger(ExecutableTrigger trigger); + + @Override + public abstract MergingStateAccessor state(); + + @Override + public abstract MergingTriggerInfo trigger(); + } + + @Nullable + protected final List> subTriggers; + + protected Trigger(@Nullable List> subTriggers) { + this.subTriggers = subTriggers; + } + + + /** + * Called immediately after an element is first incorporated into a window. + */ + public abstract void onElement(OnElementContext c) throws Exception; + + /** + * Called immediately after windows have been merged. + * + *

    Leaf triggers should update their state by inspecting their status and any state + * in the merging windows. Composite triggers should update their state by calling + * {@link ExecutableTrigger#invokeOnMerge} on their sub-triggers, and applying appropriate logic. + * + *

    A trigger such as {@link AfterWatermark#pastEndOfWindow} may no longer be finished; + * it is the responsibility of the trigger itself to record this fact. It is forbidden for + * a trigger to become finished due to {@link #onMerge}, as it has not yet fired the pending + * elements that led to it being ready to fire. + * + *

    The implementation does not need to clear out any state associated with the old windows. + */ + public abstract void onMerge(OnMergeContext c) throws Exception; + + /** + * Returns {@code true} if the current state of the trigger indicates that its condition + * is satisfied and it is ready to fire. + */ + public abstract boolean shouldFire(TriggerContext context) throws Exception; + + /** + * Adjusts the state of the trigger to be ready for the next pane. For example, a + * {@link Repeatedly} trigger will reset its inner trigger, since it has fired. + * + *

    If the trigger is finished, it is the responsibility of the trigger itself to + * record that fact via the {@code context}. + */ + public abstract void onFire(TriggerContext context) throws Exception; + + /** + * Called to allow the trigger to prefetch any state it will likely need to read from during + * an {@link #onElement} call. + */ + public void prefetchOnElement(StateAccessor state) { + if (subTriggers != null) { + for (Trigger trigger : subTriggers) { + trigger.prefetchOnElement(state); + } + } + } + + /** + * Called to allow the trigger to prefetch any state it will likely need to read from during + * an {@link #onMerge} call. + */ + public void prefetchOnMerge(MergingStateAccessor state) { + if (subTriggers != null) { + for (Trigger trigger : subTriggers) { + trigger.prefetchOnMerge(state); + } + } + } + + /** + * Called to allow the trigger to prefetch any state it will likely need to read from during + * an {@link #shouldFire} call. + */ + public void prefetchShouldFire(StateAccessor state) { + if (subTriggers != null) { + for (Trigger trigger : subTriggers) { + trigger.prefetchShouldFire(state); + } + } + } + + /** + * Called to allow the trigger to prefetch any state it will likely need to read from during + * an {@link #onFire} call. + */ + public void prefetchOnFire(StateAccessor state) { + if (subTriggers != null) { + for (Trigger trigger : subTriggers) { + trigger.prefetchOnFire(state); + } + } + } + + /** + * Clear any state associated with this trigger in the given window. + * + *

    This is called after a trigger has indicated it will never fire again. The trigger system + * keeps enough information to know that the trigger is finished, so this trigger should clear all + * of its state. + */ + public void clear(TriggerContext c) throws Exception { + if (subTriggers != null) { + for (ExecutableTrigger trigger : c.trigger().subTriggers()) { + trigger.invokeClear(c); + } + } + } + + public Iterable> subTriggers() { + return subTriggers; + } + + /** + * Return a trigger to use after a {@code GroupByKey} to preserve the + * intention of this trigger. Specifically, triggers that are time based + * and intended to provide speculative results should continue providing + * speculative results. Triggers that fire once (or multiple times) should + * continue firing once (or multiple times). + */ + public Trigger getContinuationTrigger() { + if (subTriggers == null) { + return getContinuationTrigger(null); + } + + List> subTriggerContinuations = new ArrayList<>(); + for (Trigger subTrigger : subTriggers) { + subTriggerContinuations.add(subTrigger.getContinuationTrigger()); + } + return getContinuationTrigger(subTriggerContinuations); + } + + /** + * Return the {@link #getContinuationTrigger} of this {@code Trigger}. For convenience, this + * is provided the continuation trigger of each of the sub-triggers. + */ + protected abstract Trigger getContinuationTrigger(List> continuationTriggers); + + /** + * Returns a bound in watermark time by which this trigger would have fired at least once + * for a given window had there been input data. This is a static property of a trigger + * that does not depend on its state. + * + *

    For triggers that do not fire based on the watermark advancing, returns + * {@link BoundedWindow#TIMESTAMP_MAX_VALUE}. + * + *

    This estimate is used to determine that there are no elements in a side-input window, which + * causes the default value to be used instead. + */ + public abstract Instant getWatermarkThatGuaranteesFiring(W window); + + /** + * Returns whether this performs the same triggering as the given {@code Trigger}. + */ + public boolean isCompatible(Trigger other) { + if (!getClass().equals(other.getClass())) { + return false; + } + + if (subTriggers == null) { + return other.subTriggers == null; + } else if (other.subTriggers == null) { + return false; + } else if (subTriggers.size() != other.subTriggers.size()) { + return false; + } + + for (int i = 0; i < subTriggers.size(); i++) { + if (!subTriggers.get(i).isCompatible(other.subTriggers.get(i))) { + return false; + } + } + + return true; + } + + @Override + public String toString() { + String simpleName = getClass().getSimpleName(); + if (getClass().getEnclosingClass() != null) { + simpleName = getClass().getEnclosingClass().getSimpleName() + "." + simpleName; + } + if (subTriggers == null || subTriggers.size() == 0) { + return simpleName; + } else { + return simpleName + "(" + Joiner.on(", ").join(subTriggers) + ")"; + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Trigger)) { + return false; + } + @SuppressWarnings("unchecked") + Trigger that = (Trigger) obj; + return Objects.equals(getClass(), that.getClass()) + && Objects.equals(subTriggers, that.subTriggers); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), subTriggers); + } + + /** + * Specify an ending condition for this trigger. If the {@code until} fires then the combination + * fires. + * + *

    The expression {@code t1.orFinally(t2)} fires every time {@code t1} fires, and finishes + * as soon as either {@code t1} finishes or {@code t2} fires, in which case it fires one last time + * for {@code t2}. Both {@code t1} and {@code t2} are executed in parallel. This means that + * {@code t1} may have fired since {@code t2} started, so not all of the elements that {@code t2} + * has seen are necessarily in the current pane. + * + *

For example, the final firing of the following trigger may only have 1 element: + *

     {@code
    +   * Repeatedly.forever(AfterPane.elementCountAtLeast(2))
    +   *     .orFinally(AfterPane.elementCountAtLeast(5))
    +   * } 
    + * + *
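(Illustrative sketch, not part of this change: one way the composite trigger above might be
attached to a windowed PCollection. The element type, window size, allowed lateness, and
accumulation mode are assumptions chosen for the example.)

  PCollection<String> input = ...;  // hypothetical windowed input
  PCollection<String> triggered = input.apply(
      Window.<String>into(FixedWindows.of(Duration.standardMinutes(1)))
          .triggering(
              Repeatedly.forever(AfterPane.elementCountAtLeast(2))
                  .orFinally(AfterPane.elementCountAtLeast(5)))
          .withAllowedLateness(Duration.standardMinutes(10))
          .discardingFiredPanes());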

    Note that if {@code t1} is {@link OnceTrigger}, then {@code t1.orFinally(t2)} is the same + * as {@code AfterFirst.of(t1, t2)}. + */ + public Trigger orFinally(OnceTrigger until) { + return new OrFinallyTrigger(this, until); + } + + @Override + public Trigger buildTrigger() { + return this; + } + + /** + * {@link Trigger}s that are guaranteed to fire at most once should extend from this, rather + * than the general {@link Trigger} class to indicate that behavior. + * + * @param {@link BoundedWindow} subclass used to represent the windows used by this + * {@code AtMostOnceTrigger} + */ + public abstract static class OnceTrigger extends Trigger { + protected OnceTrigger(List> subTriggers) { + super(subTriggers); + } + + @Override + public final OnceTrigger getContinuationTrigger() { + Trigger continuation = super.getContinuationTrigger(); + if (!(continuation instanceof OnceTrigger)) { + throw new IllegalStateException("Continuation of a OnceTrigger must be a OnceTrigger"); + } + return (OnceTrigger) continuation; + } + + /** + * {@inheritDoc} + */ + @Override + public final void onFire(TriggerContext context) throws Exception { + onOnlyFiring(context); + context.trigger().setFinished(true); + } + + /** + * Called exactly once by {@link #onFire} when the trigger is fired. By default, + * invokes {@link #onFire} on all subtriggers for which {@link #shouldFire} is {@code true}. + */ + protected abstract void onOnlyFiring(TriggerContext context) throws Exception; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/TriggerBuilder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/TriggerBuilder.java new file mode 100644 index 000000000000..cc817ba1a3ab --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/TriggerBuilder.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms.windowing; + +/** + * Anything that can be used to create an instance of a {@code Trigger} implements this interface. + * + *

    This includes {@code Trigger}s (which can return themselves) and any "enhanced" syntax for + * constructing a trigger. + * + * @param The type of windows the built trigger will operate on. + */ +public interface TriggerBuilder { + /** Return the {@code Trigger} built by this builder. */ + Trigger buildTrigger(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Window.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Window.java new file mode 100644 index 000000000000..6793e7648ff6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/Window.java @@ -0,0 +1,662 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.AssignWindowsDoFn; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; + +import javax.annotation.Nullable; + +/** + * {@code Window} logically divides up or groups the elements of a + * {@link PCollection} into finite windows according to a {@link WindowFn}. + * The output of {@code Window} contains the same elements as input, but they + * have been logically assigned to windows. The next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey GroupByKeys}, + * including one within composite transforms, will group by the combination of + * keys and windows. + + *

    See {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} + * for more information about how grouping with windows works. + * + *

    Windowing

    + * + *

    Windowing a {@code PCollection} divides the elements into windows based + * on the associated event time for each element. This is especially useful + * for {@code PCollection}s with unbounded size, since it allows operating on + * a sub-group of the elements placed into a related window. For {@code PCollection}s + * with a bounded size (aka. conventional batch mode), by default, all data is + * implicitly in a single window, unless {@code Window} is applied. + * + *

    For example, a simple form of windowing divides up the data into + * fixed-width time intervals, using {@link FixedWindows}. + * The following example demonstrates how to use {@code Window} in a pipeline + * that counts the number of occurrences of strings each minute: + * + *

     {@code
    + * PCollection items = ...;
    + * PCollection windowed_items = items.apply(
    + *   Window.into(FixedWindows.of(Duration.standardMinutes(1))));
    + * PCollection> windowed_counts = windowed_items.apply(
    + *   Count.perElement());
    + * } 
    + * + *

    Let (data, timestamp) denote a data element along with its timestamp. + * Then, if the input to this pipeline consists of + * {("foo", 15s), ("bar", 30s), ("foo", 45s), ("foo", 1m30s)}, + * the output will be + * {(KV("foo", 2), 1m), (KV("bar", 1), 1m), (KV("foo", 1), 2m)} + * + *

    Several predefined {@link WindowFn}s are provided: + *

      + *
    • {@link FixedWindows} partitions the timestamps into fixed-width intervals. + *
    • {@link SlidingWindows} places data into overlapping fixed-width intervals. + *
    • {@link Sessions} groups data into sessions where each item in a window + * is separated from the next by no more than a specified gap. + *
    + * + *
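(Illustrative sketch, not part of this change: the predefined WindowFns listed above are applied
the same way as FixedWindows. The element type, the durations, and the
SlidingWindows.of(...).every(...) and Sessions.withGapDuration(...) factory methods are
assumptions about this SDK.)

  // Sliding windows: hour-long windows that start every ten minutes, so each element
  // falls into six overlapping windows.
  PCollection<String> slidingWindowedItems = items.apply(
      Window.<String>into(SlidingWindows.of(Duration.standardHours(1))
          .every(Duration.standardMinutes(10))));

  // Session windows: elements are grouped until a gap of at least ten minutes occurs.
  PCollection<String> sessionedItems = items.apply(
      Window.<String>into(Sessions.withGapDuration(Duration.standardMinutes(10))));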

Additionally, custom {@link WindowFn}s can be created by writing new + * subclasses of {@link WindowFn}. + * + *

    Triggers

    + * + *

    {@link Window.Bound#triggering(TriggerBuilder)} allows specifying a trigger to control when + * (in processing time) results for the given window can be produced. If unspecified, the default + * behavior is to trigger first when the watermark passes the end of the window, and then trigger + * again every time there is late arriving data. + * + *

    Elements are added to the current window pane as they arrive. When the root trigger fires, + * output is produced based on the elements in the current pane. + * + *

Depending on the trigger, this can be used both to output partial results + * early during the processing of the whole window, and to deal with late-arriving + * data in batches. + * + *

Continuing the earlier example, if we wanted to emit the values that were available + * when the watermark passed the end of the window, and then output any late-arriving + * elements once per (actual) hour until we have finished processing the next 24 hours of data, + * we could do the following. (The use of watermark time to stop processing tends to be more robust + * if the data source is slow for a few days, etc.) + * + *

     {@code
    + * PCollection items = ...;
    + * PCollection windowed_items = items.apply(
    + *   Window.into(FixedWindows.of(Duration.standardMinutes(1)))
    + *      .triggering(
    + *          AfterWatermark.pastEndOfWindow()
    + *              .withLateFirings(AfterProcessingTime
    + *                  .pastFirstElementInPane().plusDelayOf(Duration.standardHours(1))))
    + *      .withAllowedLateness(Duration.standardDays(1)));
    + * PCollection> windowed_counts = windowed_items.apply(
    + *   Count.perElement());
    + * } 
    + * + *

    On the other hand, if we wanted to get early results every minute of processing + * time (for which there were new elements in the given window) we could do the following: + * + *

     {@code
    + * PCollection windowed_items = items.apply(
+ *   Window.into(FixedWindows.of(Duration.standardMinutes(1)))
+ *      .triggering(
    + *          AfterWatermark.pastEndOfWindow()
    + *              .withEarlyFirings(AfterProcessingTime
    + *                  .pastFirstElementInPane().plusDelayOf(Duration.standardMinutes(1))))
    + *      .withAllowedLateness(Duration.ZERO));
    + * } 
    + * + *
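(Illustrative sketch, not part of this change: early and late firings can be combined on the same
trigger; the specific late firing, lateness, and accumulation mode below are assumptions.)

  PCollection<String> windowed_items = items.apply(
      Window.<String>into(FixedWindows.of(Duration.standardMinutes(1)))
          .triggering(
              AfterWatermark.pastEndOfWindow()
                  .withEarlyFirings(AfterProcessingTime
                      .pastFirstElementInPane().plusDelayOf(Duration.standardMinutes(1)))
                  .withLateFirings(AfterPane.elementCountAtLeast(1)))
          .withAllowedLateness(Duration.standardDays(1))
          .accumulatingFiredPanes());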

After a {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}, the trigger is replaced by + * a continuation trigger that preserves the intent of the upstream trigger. See + * {@link Trigger#getContinuationTrigger} for more information. + * + *

    See {@link Trigger} for details on the available triggers. + */ +public class Window { + + /** + * Specifies the conditions under which a final pane will be created when a window is permanently + * closed. + */ + public enum ClosingBehavior { + /** + * Always fire the last pane. Even if there is no new data since the previous firing, an element + * with {@link PaneInfo#isLast()} {@code true} will be produced. + */ + FIRE_ALWAYS, + /** + * Only fire the last pane if there is new data since the previous firing. + * + *

    This is the default behavior. + */ + FIRE_IF_NON_EMPTY; + } + + /** + * Creates a {@code Window} {@code PTransform} with the given name. + * + *

    See the discussion of Naming in + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} for more explanation. + * + *

    The resulting {@code PTransform} is incomplete, and its input/output + * type is not yet bound. Use {@link Window.Unbound#into} to specify the + * {@link WindowFn} to use, which will also bind the input/output type of this + * {@code PTransform}. + */ + public static Unbound named(String name) { + return new Unbound().named(name); + } + + /** + * Creates a {@code Window} {@code PTransform} that uses the given + * {@link WindowFn} to window the data. + * + *

    The resulting {@code PTransform}'s types have been bound, with both the + * input and output being a {@code PCollection}, inferred from the types of + * the argument {@code WindowFn}. It is ready to be applied, or further + * properties can be set on it first. + */ + public static Bound into(WindowFn fn) { + return new Unbound().into(fn); + } + + /** + * Sets a non-default trigger for this {@code Window} {@code PTransform}. + * Elements that are assigned to a specific window will be output when + * the trigger fires. + * + *

    Must also specify allowed lateness using {@link #withAllowedLateness} and accumulation + * mode using either {@link #discardingFiredPanes()} or {@link #accumulatingFiredPanes()}. + */ + @Experimental(Kind.TRIGGER) + public static Bound triggering(TriggerBuilder trigger) { + return new Unbound().triggering(trigger); + } + + /** + * Returns a new {@code Window} {@code PTransform} that uses the registered WindowFn and + * Triggering behavior, and that discards elements in a pane after they are triggered. + * + *

    Does not modify this transform. The resulting {@code PTransform} is sufficiently + * specified to be applied, but more properties can still be specified. + */ + @Experimental(Kind.TRIGGER) + public static Bound discardingFiredPanes() { + return new Unbound().discardingFiredPanes(); + } + + /** + * Returns a new {@code Window} {@code PTransform} that uses the registered WindowFn and + * Triggering behavior, and that accumulates elements in a pane after they are triggered. + * + *

    Does not modify this transform. The resulting {@code PTransform} is sufficiently + * specified to be applied, but more properties can still be specified. + */ + @Experimental(Kind.TRIGGER) + public static Bound accumulatingFiredPanes() { + return new Unbound().accumulatingFiredPanes(); + } + + /** + * Override the amount of lateness allowed for data elements in the pipeline. Like + * the other properties on this {@link Window} operation, this will be applied at + * the next {@link GroupByKey}. Any elements that are later than this as decided by + * the system-maintained watermark will be dropped. + * + *

    This value also determines how long state will be kept around for old windows. + * Once no elements will be added to a window (because this duration has passed) any state + * associated with the window will be cleaned up. + */ + @Experimental(Kind.TRIGGER) + public static Bound withAllowedLateness(Duration allowedLateness) { + return new Unbound().withAllowedLateness(allowedLateness); + } + + /** + * An incomplete {@code Window} transform, with unbound input/output type. + * + *

    Before being applied, {@link Window.Unbound#into} must be + * invoked to specify the {@link WindowFn} to invoke, which will also + * bind the input/output type of this {@code PTransform}. + */ + public static class Unbound { + String name; + + Unbound() {} + + Unbound(String name) { + this.name = name; + } + + /** + * Returns a new {@code Window} transform that's like this + * transform but with the specified name. Does not modify this + * transform. The resulting transform is still incomplete. + * + *

    See the discussion of Naming in + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} for more + * explanation. + */ + public Unbound named(String name) { + return new Unbound(name); + } + + /** + * Returns a new {@code Window} {@code PTransform} that's like this + * transform but that will use the given {@link WindowFn}, and that has + * its input and output types bound. Does not modify this transform. The + * resulting {@code PTransform} is sufficiently specified to be applied, + * but more properties can still be specified. + */ + public Bound into(WindowFn fn) { + return new Bound(name).into(fn); + } + + /** + * Sets a non-default trigger for this {@code Window} {@code PTransform}. + * Elements that are assigned to a specific window will be output when + * the trigger fires. + * + *

    {@link com.google.cloud.dataflow.sdk.transforms.windowing.Trigger} + * has more details on the available triggers. + * + *

    Must also specify allowed lateness using {@link #withAllowedLateness} and accumulation + * mode using either {@link #discardingFiredPanes()} or {@link #accumulatingFiredPanes()}. + */ + @Experimental(Kind.TRIGGER) + public Bound triggering(TriggerBuilder trigger) { + return new Bound(name).triggering(trigger); + } + + /** + * Returns a new {@code Window} {@code PTransform} that uses the registered WindowFn and + * Triggering behavior, and that discards elements in a pane after they are triggered. + * + *

    Does not modify this transform. The resulting {@code PTransform} is sufficiently + * specified to be applied, but more properties can still be specified. + */ + @Experimental(Kind.TRIGGER) + public Bound discardingFiredPanes() { + return new Bound(name).discardingFiredPanes(); + } + + /** + * Returns a new {@code Window} {@code PTransform} that uses the registered WindowFn and + * Triggering behavior, and that accumulates elements in a pane after they are triggered. + * + *

    Does not modify this transform. The resulting {@code PTransform} is sufficiently + * specified to be applied, but more properties can still be specified. + */ + @Experimental(Kind.TRIGGER) + public Bound accumulatingFiredPanes() { + return new Bound(name).accumulatingFiredPanes(); + } + + /** + * Override the amount of lateness allowed for data elements in the pipeline. Like + * the other properties on this {@link Window} operation, this will be applied at + * the next {@link GroupByKey}. Any elements that are later than this as decided by + * the system-maintained watermark will be dropped. + * + *

    This value also determines how long state will be kept around for old windows. + * Once no elements will be added to a window (because this duration has passed) any state + * associated with the window will be cleaned up. + * + *

    Depending on the trigger this may not produce a pane with {@link PaneInfo#isLast}. See + * {@link ClosingBehavior#FIRE_IF_NON_EMPTY} for more details. + */ + @Experimental(Kind.TRIGGER) + public Bound withAllowedLateness(Duration allowedLateness) { + return new Bound(name).withAllowedLateness(allowedLateness); + } + + /** + * Override the amount of lateness allowed for data elements in the pipeline. Like + * the other properties on this {@link Window} operation, this will be applied at + * the next {@link GroupByKey}. Any elements that are later than this as decided by + * the system-maintained watermark will be dropped. + * + *

    This value also determines how long state will be kept around for old windows. + * Once no elements will be added to a window (because this duration has passed) any state + * associated with the window will be cleaned up. + */ + @Experimental(Kind.TRIGGER) + public Bound withAllowedLateness(Duration allowedLateness, ClosingBehavior behavior) { + return new Bound(name).withAllowedLateness(allowedLateness, behavior); + } + } + + /** + * A {@code PTransform} that windows the elements of a {@code PCollection}, + * into finite windows according to a user-specified {@code WindowFn}. + * + * @param The type of elements this {@code Window} is applied to + */ + public static class Bound extends PTransform, PCollection> { + + @Nullable private final WindowFn windowFn; + @Nullable private final Trigger trigger; + @Nullable private final AccumulationMode mode; + @Nullable private final Duration allowedLateness; + @Nullable private final ClosingBehavior closingBehavior; + @Nullable private final OutputTimeFn outputTimeFn; + + private Bound(String name, + @Nullable WindowFn windowFn, @Nullable Trigger trigger, + @Nullable AccumulationMode mode, @Nullable Duration allowedLateness, + ClosingBehavior behavior, @Nullable OutputTimeFn outputTimeFn) { + super(name); + this.windowFn = windowFn; + this.trigger = trigger; + this.mode = mode; + this.allowedLateness = allowedLateness; + this.closingBehavior = behavior; + this.outputTimeFn = outputTimeFn; + } + + private Bound(String name) { + this(name, null, null, null, null, null, null); + } + + /** + * Returns a new {@code Window} {@code PTransform} that's like this + * transform but that will use the given {@link WindowFn}, and that has + * its input and output types bound. Does not modify this transform. The + * resulting {@code PTransform} is sufficiently specified to be applied, + * but more properties can still be specified. + */ + private Bound into(WindowFn windowFn) { + try { + windowFn.windowCoder().verifyDeterministic(); + } catch (NonDeterministicException e) { + throw new IllegalArgumentException("Window coders must be deterministic.", e); + } + + return new Bound<>( + name, windowFn, trigger, mode, allowedLateness, closingBehavior, outputTimeFn); + } + + /** + * Returns a new {@code Window} {@code PTransform} that's like this + * {@code PTransform} but with the specified name. Does not + * modify this {@code PTransform}. + * + *

    See the discussion of Naming in + * {@link com.google.cloud.dataflow.sdk.transforms.ParDo} for more + * explanation. + */ + public Bound named(String name) { + return new Bound<>( + name, windowFn, trigger, mode, allowedLateness, closingBehavior, outputTimeFn); + } + + /** + * Sets a non-default trigger for this {@code Window} {@code PTransform}. + * Elements that are assigned to a specific window will be output when + * the trigger fires. + * + *

    {@link com.google.cloud.dataflow.sdk.transforms.windowing.Trigger} + * has more details on the available triggers. + * + *

    Must also specify allowed lateness using {@link #withAllowedLateness} and accumulation + * mode using either {@link #discardingFiredPanes()} or {@link #accumulatingFiredPanes()}. + */ + @Experimental(Kind.TRIGGER) + public Bound triggering(TriggerBuilder trigger) { + return new Bound( + name, + windowFn, + trigger.buildTrigger(), + mode, + allowedLateness, + closingBehavior, + outputTimeFn); + } + + /** + * Returns a new {@code Window} {@code PTransform} that uses the registered WindowFn and + * Triggering behavior, and that discards elements in a pane after they are triggered. + * + *

    Does not modify this transform. The resulting {@code PTransform} is sufficiently + * specified to be applied, but more properties can still be specified. + */ + @Experimental(Kind.TRIGGER) + public Bound discardingFiredPanes() { + return new Bound( + name, + windowFn, + trigger, + AccumulationMode.DISCARDING_FIRED_PANES, + allowedLateness, + closingBehavior, + outputTimeFn); + } + + /** + * Returns a new {@code Window} {@code PTransform} that uses the registered WindowFn and + * Triggering behavior, and that accumulates elements in a pane after they are triggered. + * + *

    Does not modify this transform. The resulting {@code PTransform} is sufficiently + * specified to be applied, but more properties can still be specified. + */ + @Experimental(Kind.TRIGGER) + public Bound accumulatingFiredPanes() { + return new Bound( + name, + windowFn, + trigger, + AccumulationMode.ACCUMULATING_FIRED_PANES, + allowedLateness, + closingBehavior, + outputTimeFn); + } + + /** + * Override the amount of lateness allowed for data elements in the pipeline. Like + * the other properties on this {@link Window} operation, this will be applied at + * the next {@link GroupByKey}. Any elements that are later than this as decided by + * the system-maintained watermark will be dropped. + * + *

    This value also determines how long state will be kept around for old windows. + * Once no elements will be added to a window (because this duration has passed) any state + * associated with the window will be cleaned up. + * + *

    Depending on the trigger this may not produce a pane with {@link PaneInfo#isLast}. See + * {@link ClosingBehavior#FIRE_IF_NON_EMPTY} for more details. + */ + @Experimental(Kind.TRIGGER) + public Bound withAllowedLateness(Duration allowedLateness) { + return new Bound( + name, windowFn, trigger, mode, allowedLateness, closingBehavior, outputTimeFn); + } + + /** + * (Experimental) Override the default {@link OutputTimeFn}, to control + * the output timestamp of values output from a {@link GroupByKey} operation. + */ + @Experimental(Kind.OUTPUT_TIME) + public Bound withOutputTimeFn(OutputTimeFn outputTimeFn) { + return new Bound( + name, windowFn, trigger, mode, allowedLateness, closingBehavior, outputTimeFn); + } + + /** + * Override the amount of lateness allowed for data elements in the pipeline. Like + * the other properties on this {@link Window} operation, this will be applied at + * the next {@link GroupByKey}. Any elements that are later than this as decided by + * the system-maintained watermark will be dropped. + * + *

    This value also determines how long state will be kept around for old windows. + * Once no elements will be added to a window (because this duration has passed) any state + * associated with the window will be cleaned up. + */ + @Experimental(Kind.TRIGGER) + public Bound withAllowedLateness(Duration allowedLateness, ClosingBehavior behavior) { + return new Bound(name, windowFn, trigger, mode, allowedLateness, behavior, outputTimeFn); + } + + /** + * Get the output strategy of this {@link Window.Bound Window PTransform}. For internal use + * only. + */ + // Rawtype cast of OutputTimeFn cannot be eliminated with intermediate variable, as it is + // casting between wildcards + public WindowingStrategy getOutputStrategyInternal( + WindowingStrategy inputStrategy) { + WindowingStrategy result = inputStrategy; + if (windowFn != null) { + result = result.withWindowFn(windowFn); + } + if (trigger != null) { + result = result.withTrigger(trigger); + } + if (mode != null) { + result = result.withMode(mode); + } + if (allowedLateness != null) { + result = result.withAllowedLateness(allowedLateness); + } + if (closingBehavior != null) { + result = result.withClosingBehavior(closingBehavior); + } + if (outputTimeFn != null) { + result = result.withOutputTimeFn(outputTimeFn); + } + return result; + } + + /** + * Get the {@link WindowFn} of this {@link Window.Bound Window PTransform}. + */ + public WindowFn getWindowFn() { + return windowFn; + } + + @Override + public void validate(PCollection input) { + WindowingStrategy outputStrategy = + getOutputStrategyInternal(input.getWindowingStrategy()); + + // Make sure that the windowing strategy is complete & valid. + if (outputStrategy.isTriggerSpecified() + && !(outputStrategy.getTrigger().getSpec() instanceof DefaultTrigger)) { + if (!(outputStrategy.getWindowFn() instanceof GlobalWindows) + && !outputStrategy.isAllowedLatenessSpecified()) { + throw new IllegalArgumentException("Except when using GlobalWindows," + + " calling .triggering() to specify a trigger requires that the allowed lateness be" + + " specified using .withAllowedLateness() to set the upper bound on how late data" + + " can arrive before being dropped. See Javadoc for more details."); + } + + if (!outputStrategy.isModeSpecified()) { + throw new IllegalArgumentException( + "Calling .triggering() to specify a trigger requires that the accumulation mode be" + + " specified using .discardingFiredPanes() or .accumulatingFiredPanes()." + + " See Javadoc for more details."); + } + } + } + + @Override + public PCollection apply(PCollection input) { + WindowingStrategy outputStrategy = + getOutputStrategyInternal(input.getWindowingStrategy()); + PCollection output; + if (windowFn != null) { + // If the windowFn changed, we create a primitive, and run the AssignWindows operation here. + output = assignWindows(input, windowFn); + } else { + // If the windowFn didn't change, we just run a pass-through transform and then set the + // new windowing strategy. 
+ output = input.apply(Window.identity()); + } + return output.setWindowingStrategyInternal(outputStrategy); + } + + private PCollection assignWindows( + PCollection input, WindowFn windowFn) { + return input.apply("AssignWindows", ParDo.of(new AssignWindowsDoFn(windowFn))); + } + + @Override + protected Coder getDefaultOutputCoder(PCollection input) { + return input.getCoder(); + } + + @Override + protected String getKindString() { + return "Window.Into()"; + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private static PTransform, PCollection> identity() { + return ParDo.named("Identity").of(new DoFn() { + @Override public void processElement(ProcessContext c) { + c.output(c.element()); + } + }); + } + + /** + * Creates a {@code Window} {@code PTransform} that does not change assigned + * windows, but will cause windows to be merged again as part of the next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}. + */ + public static Remerge remerge() { + return new Remerge(); + } + + /** + * {@code PTransform} that does not change assigned windows, but will cause + * windows to be merged again as part of the next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}. + */ + public static class Remerge extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + WindowingStrategy outputWindowingStrategy = getOutputWindowing( + input.getWindowingStrategy()); + + return input.apply(Window.identity()) + .setWindowingStrategyInternal(outputWindowingStrategy); + } + + private WindowingStrategy getOutputWindowing( + WindowingStrategy inputStrategy) { + if (inputStrategy.getWindowFn() instanceof InvalidWindows) { + @SuppressWarnings("unchecked") + InvalidWindows invalidWindows = (InvalidWindows) inputStrategy.getWindowFn(); + return inputStrategy.withWindowFn(invalidWindows.getOriginalWindowFn()); + } else { + return inputStrategy; + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowFn.java new file mode 100644 index 000000000000..d51fc7ead46b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowFn.java @@ -0,0 +1,221 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.common.collect.Ordering; + +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.Collection; + +/** + * The argument to the {@link Window} transform used to assign elements into + * windows and to determine how windows are merged. 
See {@link Window} for more + * information on how {@code WindowFn}s are used and for a library of + * predefined {@code WindowFn}s. + * + *

    Users will generally want to use the predefined + * {@code WindowFn}s, but it is also possible to create new + * subclasses. + * + *
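(Illustrative sketch, not part of this change: a minimal custom WindowFn that assigns each element
to a single one-minute window starting at its timestamp. IntervalWindow and its coder are assumed
to be available in this SDK, and the class extends WindowFn directly only to keep the sketch
self-contained; the next paragraph describes which subclass a real implementation should extend.)

  import com.google.cloud.dataflow.sdk.coders.Coder;
  import org.joda.time.Duration;
  import java.util.Collection;
  import java.util.Collections;

  // Assumed to live in the same package as WindowFn and IntervalWindow.
  class StartAtTimestampWindows extends WindowFn<Object, IntervalWindow> {
    private static final Duration SIZE = Duration.standardMinutes(1);

    @Override
    public Collection<IntervalWindow> assignWindows(AssignContext c) {
      // One window per element, anchored at the element's timestamp.
      return Collections.singletonList(
          new IntervalWindow(c.timestamp(), c.timestamp().plus(SIZE)));
    }

    @Override
    public void mergeWindows(MergeContext c) {
      // No merging is performed; a real non-merging WindowFn should extend NonMergingWindowFn.
    }

    @Override
    public boolean isCompatible(WindowFn<?, ?> other) {
      return other instanceof StartAtTimestampWindows;
    }

    @Override
    public Coder<IntervalWindow> windowCoder() {
      return IntervalWindow.getCoder();  // assumed coder accessor for IntervalWindow
    }

    @Override
    public IntervalWindow getSideInputWindow(BoundedWindow window) {
      throw new UnsupportedOperationException("Side inputs are not supported by this sketch");
    }
  }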

    To create a custom {@code WindowFn}, inherit from this class and override all required + * methods. If no merging is required, inherit from {@link NonMergingWindowFn} + * instead. If no merging is required and each element is assigned to a single window, inherit from + * {@code PartitioningWindowFn}. Inheriting from the most specific subclass will enable more + * optimizations in the runner. + * + * @param type of elements being windowed + * @param {@link BoundedWindow} subclass used to represent the + * windows used by this {@code WindowFn} + */ +public abstract class WindowFn + implements Serializable { + /** + * Information available when running {@link #assignWindows}. + */ + public abstract class AssignContext { + /** + * Returns the current element. + */ + public abstract T element(); + + /** + * Returns the timestamp of the current element. + */ + public abstract Instant timestamp(); + + /** + * Returns the windows the current element was in, prior to this + * {@code WindowFn} being called. + */ + public abstract Collection windows(); + } + + /** + * Given a timestamp and element, returns the set of windows into which it + * should be placed. + */ + public abstract Collection assignWindows(AssignContext c) throws Exception; + + /** + * Information available when running {@link #mergeWindows}. + */ + public abstract class MergeContext { + /** + * Returns the current set of windows. + */ + public abstract Collection windows(); + + /** + * Signals to the framework that the windows in {@code toBeMerged} should + * be merged together to form {@code mergeResult}. + * + *

    {@code toBeMerged} should be a subset of {@link #windows} + * and disjoint from the {@code toBeMerged} set of previous calls + * to {@code merge}. + * + *

    {@code mergeResult} must either not be in {@link #windows} or be in + * {@code toBeMerged}. + * + * @throws IllegalArgumentException if any elements of toBeMerged are not + * in windows(), or have already been merged + */ + public abstract void merge(Collection toBeMerged, W mergeResult) + throws Exception; + } + + /** + * Does whatever merging of windows is necessary. + * + *

    See {@link MergeOverlappingIntervalWindows#mergeWindows} for an + * example of how to override this method. + */ + public abstract void mergeWindows(MergeContext c) throws Exception; + + /** + * Returns whether this performs the same merging as the given + * {@code WindowFn}. + */ + public abstract boolean isCompatible(WindowFn other); + + /** + * Returns the {@link Coder} used for serializing the windows used + * by this windowFn. + */ + public abstract Coder windowCoder(); + + /** + * Returns the window of the side input corresponding to the given window of + * the main input. + * + *

    Authors of custom {@code WindowFn}s should override this. + */ + public abstract W getSideInputWindow(final BoundedWindow window); + + /** + * @deprecated Implement {@link #getOutputTimeFn} to return one of the appropriate + * {@link OutputTimeFns}, or a custom {@link OutputTimeFn} extending + * {@link OutputTimeFn.Defaults}. + */ + @Deprecated + @Experimental(Kind.OUTPUT_TIME) + public Instant getOutputTime(Instant inputTimestamp, W window) { + return getOutputTimeFn().assignOutputTime(inputTimestamp, window); + } + + /** + * Provides a default implementation for {@link WindowingStrategy#getOutputTimeFn()}. + * See the full specification there. + * + *

    If this {@link WindowFn} doesn't produce overlapping windows, this need not (and probably + * should not) override any of the default implementations in {@link OutputTimeFn.Defaults}. + * + *

If this {@link WindowFn} does produce overlapping windows that can be predicted here, it is + * suggested that the output timestamp in later overlapping windows be past the end of earlier windows, so + * that the later windows do not prevent the watermark from progressing past the end of the earlier + * window. + * + *

    For example, a timestamp in a sliding window should be moved past the beginning of the next + * sliding window. See {@link SlidingWindows#getOutputTimeFn}. + */ + @Experimental(Kind.OUTPUT_TIME) + public OutputTimeFn getOutputTimeFn() { + return new OutputAtEarliestAssignedTimestamp<>(this); + } + + /** + * Returns true if this {@code WindowFn} never needs to merge any windows. + */ + public boolean isNonMerging() { + return false; + } + + /** + * Returns true if this {@code WindowFn} assigns each element to a single window. + */ + public boolean assignsToSingleWindow() { + return false; + } + + /** + * A compatibility adapter that will return the assigned timestamps according to the + * {@link WindowFn}, which was the prior policy. Specifying the assigned output timestamps + * on the {@link WindowFn} is now deprecated. + */ + private static class OutputAtEarliestAssignedTimestamp + extends OutputTimeFn.Defaults { + + private final WindowFn windowFn; + + public OutputAtEarliestAssignedTimestamp(WindowFn windowFn) { + this.windowFn = windowFn; + } + + /** + * {@inheritDoc} + * + * @return the result of {@link WindowFn#getOutputTime windowFn.getOutputTime()}. + */ + @Override + @SuppressWarnings("deprecation") // this is an adapter for the deprecated behavior + public Instant assignOutputTime(Instant timestamp, W window) { + return windowFn.getOutputTime(timestamp, window); + } + + @Override + public Instant combine(Instant outputTime, Instant otherOutputTime) { + return Ordering.natural().min(outputTime, otherOutputTime); + } + + /** + * {@inheritDoc} + * + * @return {@code true}. When the {@link OutputTimeFn} is not overridden by {@link WindowFn} + * or {@link WindowingStrategy}, the minimum output timestamp is taken, which depends + * only on the minimum input timestamp by monotonicity of {@link #assignOutputTime}. + */ + @Override + public boolean dependsOnlyOnEarliestInputTimestamp() { + return true; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/package-info.java new file mode 100644 index 000000000000..65ccf710bdf8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/package-info.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines the {@link com.google.cloud.dataflow.sdk.transforms.windowing.Window} transform + * for dividing the elements in a PCollection into windows, and the + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.Trigger} for controlling when those + * elements are output. + * + *

    {@code Window} logically divides up or groups the elements of a + * {@link com.google.cloud.dataflow.sdk.values.PCollection} into finite windows according to a + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn}. + * The output of {@code Window} contains the same elements as input, but they + * have been logically assigned to windows. The next + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}s, including one + * within composite transforms, will group by the combination of keys and + * windows. + * + *

    Windowing a {@code PCollection} allows chunks of it to be processed + * individually, before the entire {@code PCollection} is available. This is + * especially important for {@code PCollection}s with unbounded size, since the full + * {@code PCollection} is never available at once. + * + *

    For {@code PCollection}s with a bounded size, by default, all data is implicitly in a + * single window, and this replicates conventional batch mode. However, windowing can still be a + * convenient way to express time-sliced algorithms over bounded {@code PCollection}s. + * + *

As elements are assigned to a window, they are placed into a pane. When the trigger fires, + * all of the elements in the current pane are output. + * + *

    The {@link com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger} will output a + * window when the system watermark passes the end of the window. See + * {@link com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark} for details on the + * watermark. + */ +package com.google.cloud.dataflow.sdk.transforms.windowing; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ActiveWindowSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ActiveWindowSet.java new file mode 100644 index 000000000000..69350cb3eb6e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ActiveWindowSet.java @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; + +import java.util.Collection; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Track which active windows have their state associated with merged-away windows. + * + * When windows are merged we must track which state previously associated with the merged windows + * must now be associated with the result window. Some of that state may be combined eagerly when + * the windows are merged. The rest is combined lazily when the final state is actually + * required when emitting a pane. We keep track of this using an {@link ActiveWindowSet}. + * + *

    An {@link ActiveWindowSet} considers a window to be in one of the following states: + * + *

  + *
1. NEW: The initial state for a window on an incoming element; we do not yet know + * if it should be merged into an ACTIVE window, or whether it is already present as an + * ACTIVE window, since we have not yet called + * {@link WindowFn#mergeWindows}.
+ *
2. ACTIVE: A window that has state associated with it and has not itself been merged + * away. The window may have one or more state address windows under which its + * non-empty state is stored. A state value for an ACTIVE window must be derived by reading + * the state in all of its state address windows.
+ *
3. EPHEMERAL: A NEW window that has been merged into an ACTIVE window before any state + * has been associated with that window. Thus the window is neither ACTIVE nor MERGED. These + * windows are not persistently represented since if they reappear the merge function should + * again redirect them to an ACTIVE window. EPHEMERAL windows are an optimization for + * the common case of in-order events and {@link Sessions session window} by never associating + * state with windows that are created and immediately merged away.
+ *
4. MERGED: An ACTIVE window has been merged into another ACTIVE window after it had + * state associated with it. The window will thus appear as a state address window for exactly + * one ACTIVE window.
+ *
5. EXPIRED: The window has expired and may have been garbage collected. No new elements + * (even late elements) will ever be assigned to that window. These windows are not explicitly + * represented anywhere; it is expected that the user of {@link ActiveWindowSet} will store + * no state associated with the window.
+ *
+ *

    + * + *
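(Illustrative sketch, not part of this change: how a runner might drive an ActiveWindowSet for an
element assigned to a mergeable window. IntervalWindow and the surrounding runner plumbing are
assumptions.)

  import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow;
  import com.google.cloud.dataflow.sdk.util.ActiveWindowSet;
  import java.util.Collection;
  import java.util.Set;

  class ActiveWindowSetSketch {
    void track(ActiveWindowSet<IntervalWindow> activeWindows, IntervalWindow window)
        throws Exception {
      // A freshly assigned window starts as NEW; merging decides whether it becomes ACTIVE,
      // EPHEMERAL, or is redirected to an existing ACTIVE window.
      activeWindows.addNew(window);
      activeWindows.merge(new ActiveWindowSet.MergeCallback<IntervalWindow>() {
        @Override
        public void prefetchOnMerge(Collection<IntervalWindow> toBeMerged,
            Collection<IntervalWindow> activeToBeMerged, IntervalWindow mergeResult) {
          // Optionally prefetch state for the windows about to be merged.
        }

        @Override
        public void onMerge(Collection<IntervalWindow> toBeMerged,
            Collection<IntervalWindow> activeToBeMerged, IntervalWindow mergeResult) {
          // Combine or relabel state from the merged-away windows into mergeResult here.
        }
      });
      // State reads span all state address windows of the ACTIVE window; new state is
      // written to a single address window.
      IntervalWindow active = activeWindows.representative(window);
      Set<IntervalWindow> readFrom = activeWindows.readStateAddresses(active);
      IntervalWindow writeTo = activeWindows.writeStateAddress(active);
    }
  }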

    If no windows will ever be merged we can use the trivial implementation {@link + * NonMergingActiveWindowSet}. Otherwise, the actual implementation of this data structure is in + * {@link MergingActiveWindowSet}. + * + * @param the type of window being managed + */ +public interface ActiveWindowSet { + /** + * Callback for {@link #merge}. + */ + public interface MergeCallback { + /** + * Called when windows are about to be merged, but before any {@link #onMerge} callback + * has been made. + */ + void prefetchOnMerge(Collection toBeMerged, Collection activeToBeMerged, W mergeResult) + throws Exception; + + /** + * Called when windows are about to be merged, after all {@link #prefetchOnMerge} calls + * have been made, but before the active window set has been updated to reflect the merge. + * + * @param toBeMerged the windows about to be merged. + * @param activeToBeMerged the subset of {@code toBeMerged} corresponding to windows which + * are currently ACTIVE (and about to be merged). The remaining windows have been deemed + * EPHEMERAL, and thus have no state associated with them. + * @param mergeResult the result window, either a member of {@code toBeMerged} or new. + */ + void onMerge(Collection toBeMerged, Collection activeToBeMerged, W mergeResult) + throws Exception; + } + + /** + * Remove EPHEMERAL windows since we only need to know about them while processing new elements. + */ + void removeEphemeralWindows(); + + /** + * Save any state changes needed. + */ + void persist(); + + /** + * Return the ACTIVE window into which {@code window} has been merged. + * Return {@code window} itself if it is ACTIVE. Return null if {@code window} has not + * yet been seen. + */ + @Nullable + W representative(W window); + + /** + * Return (a view of) the set of currently ACTIVE windows. + */ + Set getActiveWindows(); + + /** + * Return {@code true} if {@code window} is ACTIVE. + */ + boolean isActive(W window); + + /** + * If {@code window} is not already known to be ACTIVE, MERGED or EPHEMERAL then add it + * as NEW. All NEW windows will be accounted for as ACTIVE, MERGED or EPHEMERAL by a call + * to {@link #merge}. + */ + void addNew(W window); + + /** + * If {@code window} is not already known to be ACTIVE, MERGED or EPHEMERAL then add it + * as ACTIVE. + */ + void addActive(W window); + + /** + * Remove {@code window} from the set. + */ + void remove(W window); + + /** + * Invoke {@link WindowFn#mergeWindows} on the {@code WindowFn} associated with this window set, + * merging as many of the active windows as possible. {@code mergeCallback} will be invoked for + * each group of windows that are merged. After this no NEW windows will remain, all merge + * result windows will be ACTIVE, and all windows which have been merged away will not be ACTIVE. + */ + void merge(MergeCallback mergeCallback) throws Exception; + + /** + * Signal that all state in {@link #readStateAddresses} for {@code window} has been merged into + * the {@link #writeStateAddress} for {@code window}. + */ + void merged(W window); + + /** + * Return the state address windows for ACTIVE {@code window} from which all state associated + * should be read and merged. + */ + Set readStateAddresses(W window); + + /** + * Return the state address window of ACTIVE {@code window} into which all new state should be + * written. Always one of the results of {@link #readStateAddresses}. 
+ */ + W writeStateAddress(W window); + + /** + * Return the state address window into which all new state should be written after + * ACTIVE windows {@code toBeMerged} have been merged into {@code mergeResult}. + */ + W mergedWriteStateAddress(Collection toBeMerged, W mergeResult); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ApiSurface.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ApiSurface.java new file mode 100644 index 000000000000..7a9c87733b2b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ApiSurface.java @@ -0,0 +1,642 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.common.base.Joiner; +import com.google.common.base.Supplier; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; +import com.google.common.collect.Sets; +import com.google.common.reflect.ClassPath; +import com.google.common.reflect.ClassPath.ClassInfo; +import com.google.common.reflect.Invokable; +import com.google.common.reflect.Parameter; +import com.google.common.reflect.TypeToken; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.lang.reflect.WildcardType; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * Represents the API surface of a package prefix. Used for accessing public classes, + * methods, and the types they reference, to control what dependencies are re-exported. + * + *

    For the purposes of calculating the public API surface, exposure includes any public + * or protected occurrence of: + * + *

      + *
    • superclasses + *
    • interfaces implemented + *
    • actual type arguments to generic types + *
    • array component types + *
    • method return types + *
    • method parameter types + *
    • type variable bounds + *
    • wildcard bounds + *
    + * + *

    Exposure is a transitive property. The resulting map excludes primitives + * and array classes themselves. + * + *
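(Illustrative sketch, not part of this change: computing and printing the exposed surface of an
SDK package while pruning JDK references. The package name is only an example.)

  import com.google.cloud.dataflow.sdk.util.ApiSurface;
  import java.io.IOException;
  import java.util.List;

  class ApiSurfaceSketch {
    public static void main(String[] args) throws IOException {
      ApiSurface surface = ApiSurface
          .ofPackage("com.google.cloud.dataflow.sdk")
          .pruningPrefix("java");
      for (Class<?> exposed : surface.getExposedClasses()) {
        // Every exposed class can be traced back to a root class through which it is exposed.
        List<Class<?>> path = surface.getAnyExposurePath(exposed);
        System.out.println(exposed + " exposed via " + path);
      }
    }
  }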

    It is prudent (though not required) to prune prefixes like "java" via the builder + * method {@link #pruningPrefix} to halt the traversal so it does not uselessly catalog references + * that are not interesting. + */ +@SuppressWarnings("rawtypes") +public class ApiSurface { + private static Logger logger = LoggerFactory.getLogger(ApiSurface.class); + + /** + * Returns an empty {@link ApiSurface}. + */ + public static ApiSurface empty() { + logger.debug("Returning an empty ApiSurface"); + return new ApiSurface(Collections.>emptySet(), Collections.emptySet()); + } + + /** + * Returns an {@link ApiSurface} object representing the given package and all subpackages. + */ + public static ApiSurface ofPackage(String packageName) throws IOException { + return ApiSurface.empty().includingPackage(packageName); + } + + /** + * Returns an {@link ApiSurface} object representing just the surface of the given class. + */ + public static ApiSurface ofClass(Class clazz) { + return ApiSurface.empty().includingClass(clazz); + } + + /** + * Returns an {@link ApiSurface} like this one, but also including the named + * package and all of its subpackages. + */ + public ApiSurface includingPackage(String packageName) throws IOException { + ClassPath classPath = ClassPath.from(ClassLoader.getSystemClassLoader()); + + Set> newRootClasses = Sets.newHashSet(); + for (ClassInfo classInfo : classPath.getTopLevelClassesRecursive(packageName)) { + Class clazz = classInfo.load(); + if (exposed(clazz.getModifiers())) { + newRootClasses.add(clazz); + } + } + logger.debug("Including package {} and subpackages: {}", packageName, newRootClasses); + newRootClasses.addAll(rootClasses); + + return new ApiSurface(newRootClasses, patternsToPrune); + } + + /** + * Returns an {@link ApiSurface} like this one, but also including the given class. + */ + public ApiSurface includingClass(Class clazz) { + Set> newRootClasses = Sets.newHashSet(); + logger.debug("Including class {}", clazz); + newRootClasses.add(clazz); + newRootClasses.addAll(rootClasses); + return new ApiSurface(newRootClasses, patternsToPrune); + } + + /** + * Returns an {@link ApiSurface} like this one, but pruning transitive + * references from classes whose full name (including package) begins with the provided prefix. + */ + public ApiSurface pruningPrefix(String prefix) { + return pruningPattern(Pattern.compile(Pattern.quote(prefix) + ".*")); + } + + /** + * Returns an {@link ApiSurface} like this one, but pruning references from the named + * class. + */ + public ApiSurface pruningClassName(String className) { + return pruningPattern(Pattern.compile(Pattern.quote(className))); + } + + /** + * Returns an {@link ApiSurface} like this one, but pruning references from the + * provided class. + */ + public ApiSurface pruningClass(Class clazz) { + return pruningClassName(clazz.getName()); + } + + /** + * Returns an {@link ApiSurface} like this one, but pruning transitive + * references from classes whose full name (including package) begins with the provided prefix. + */ + public ApiSurface pruningPattern(Pattern pattern) { + Set newPatterns = Sets.newHashSet(); + newPatterns.addAll(patternsToPrune); + newPatterns.add(pattern); + return new ApiSurface(rootClasses, newPatterns); + } + + /** + * See {@link #pruningPattern(Pattern)}. + */ + public ApiSurface pruningPattern(String patternString) { + return pruningPattern(Pattern.compile(patternString)); + } + + /** + * Returns all public classes originally belonging to the package + * in the {@link ApiSurface}. 
+ */ + public Set> getRootClasses() { + return rootClasses; + } + + /** + * Returns exposed types in this set, including arrays and primitives as + * specified. + */ + public Set> getExposedClasses() { + return getExposedToExposers().keySet(); + } + + /** + * Returns a path from an exposed class to a root class. There may be many, but this + * gives only one. + * + *

+   * <p>If there are only cycles, with no path back to a root class, throws
+   * IllegalArgumentException.
+   */
+  public List<Class<?>> getAnyExposurePath(Class<?> exposedClass) {
+    Set<Class<?>> excluded = Sets.newHashSet();
+    excluded.add(exposedClass);
+    List<Class<?>> path = getAnyExposurePath(exposedClass, excluded);
+    if (path == null) {
+      throw new IllegalArgumentException(
+          "Class " + exposedClass + " has no path back to any root class."
+          + " It should never have been considered exposed.");
+    } else {
+      return path;
+    }
+  }
+
+  /**
+   * Returns a path from an exposed class to a root class. There may be many, but this
+   * gives only one. It will not return a path that crosses the excluded classes.
+   *
+   *

    If there are only cycles or paths through the excluded classes, returns null. + * + *

+   * <p>If the class is not actually in the exposure map, throws IllegalArgumentException.
+   */
+  private List<Class<?>> getAnyExposurePath(Class<?> exposedClass, Set<Class<?>> excluded) {
+    List<Class<?>> exposurePath = Lists.newArrayList();
+    exposurePath.add(exposedClass);
+
+    Collection<Class<?>> exposers = getExposedToExposers().get(exposedClass);
+    if (exposers.isEmpty()) {
+      throw new IllegalArgumentException("Class " + exposedClass + " is not exposed.");
+    }
+
+    for (Class<?> exposer : exposers) {
+      if (excluded.contains(exposer)) {
+        continue;
+      }
+
+      // A null exposer means this is already a root class.
+      if (exposer == null) {
+        return exposurePath;
+      }
+
+      List<Class<?>> restOfPath = getAnyExposurePath(
+          exposer,
+          Sets.union(excluded, Sets.newHashSet(exposer)));
+
+      if (restOfPath != null) {
+        exposurePath.addAll(restOfPath);
+        return exposurePath;
+      }
+    }
+    return null;
+  }
+
+  ////////////////////////////////////////////////////////////////////
+
+  // Fields initialized upon construction
+  private final Set<Class<?>> rootClasses;
+  private final Set<Pattern> patternsToPrune;
+
+  // Fields computed on-demand
+  private Multimap<Class<?>, Class<?>> exposedToExposers = null;
+  private Pattern prunedPattern = null;
+  private Set<Type> visited = null;
+
+  private ApiSurface(Set<Class<?>> rootClasses, Set<Pattern> patternsToPrune) {
+    this.rootClasses = rootClasses;
+    this.patternsToPrune = patternsToPrune;
+  }
+
+  /**
+   * A map from exposed types to the places where they are exposed, in the sense of being a part
+   * of a public-facing API surface.
+   *
+   *

+   * <p>This map is the adjacency list representation of a directed graph, where an edge from type
+   * {@code T1} to type {@code T2} indicates that {@code T2} directly exposes {@code T1} in its API
+   * surface.
+   *
+   *
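+   * <p>For example, if a root class {@code Foo} declares a public method returning {@code Bar},
+   * the map records an edge from {@code Bar} to {@code Foo}. ({@code Foo} and {@code Bar} are
+   * hypothetical names used only to illustrate the direction of the edges.)
+   *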

    The traversal methods in this class are designed to avoid repeatedly processing types, since + * there will almost always be cyclic references. + */ + private Multimap, Class> getExposedToExposers() { + if (exposedToExposers == null) { + constructExposedToExposers(); + } + return exposedToExposers; + } + + /** + * See {@link #getExposedToExposers}. + */ + private void constructExposedToExposers() { + visited = Sets.newHashSet(); + exposedToExposers = Multimaps.newSetMultimap( + Maps., Collection>>newHashMap(), + new Supplier>>() { + @Override + public Set> get() { + return Sets.newHashSet(); + } + }); + + for (Class clazz : rootClasses) { + addExposedTypes(clazz, null); + } + } + + /** + * A combined {@code Pattern} that implements all the pruning specified. + */ + private Pattern getPrunedPattern() { + if (prunedPattern == null) { + constructPrunedPattern(); + } + return prunedPattern; + } + + /** + * See {@link #getPrunedPattern}. + */ + private void constructPrunedPattern() { + Set prunedPatternStrings = Sets.newHashSet(); + for (Pattern patternToPrune : patternsToPrune) { + prunedPatternStrings.add(patternToPrune.pattern()); + } + prunedPattern = Pattern.compile("(" + Joiner.on(")|(").join(prunedPatternStrings) + ")"); + } + + /** + * Whether a type and all that it references should be pruned from the graph. + */ + private boolean pruned(Type type) { + return pruned(TypeToken.of(type).getRawType()); + } + + /** + * Whether a class and all that it references should be pruned from the graph. + */ + private boolean pruned(Class clazz) { + return clazz.isPrimitive() + || clazz.isArray() + || getPrunedPattern().matcher(clazz.getName()).matches(); + } + + /** + * Whether a type has already beens sufficiently processed. + */ + private boolean done(Type type) { + return visited.contains(type); + } + + private void recordExposure(Class exposed, Class cause) { + exposedToExposers.put(exposed, cause); + } + + private void recordExposure(Type exposed, Class cause) { + exposedToExposers.put(TypeToken.of(exposed).getRawType(), cause); + } + + private void visit(Type type) { + visited.add(type); + } + + /** + * See {@link #addExposedTypes(Type, Class)}. + */ + private void addExposedTypes(TypeToken type, Class cause) { + logger.debug( + "Adding exposed types from {}, which is the type in type token {}", type.getType(), type); + addExposedTypes(type.getType(), cause); + } + + /** + * Adds any references learned by following a link from {@code cause} to {@code type}. + * This will dispatch according to the concrete {@code Type} implementation. See the + * other overloads of {@code addExposedTypes} for their details. 
+ */ + private void addExposedTypes(Type type, Class cause) { + if (type instanceof TypeVariable) { + logger.debug("Adding exposed types from {}, which is a type variable", type); + addExposedTypes((TypeVariable) type, cause); + } else if (type instanceof WildcardType) { + logger.debug("Adding exposed types from {}, which is a wildcard type", type); + addExposedTypes((WildcardType) type, cause); + } else if (type instanceof GenericArrayType) { + logger.debug("Adding exposed types from {}, which is a generic array type", type); + addExposedTypes((GenericArrayType) type, cause); + } else if (type instanceof ParameterizedType) { + logger.debug("Adding exposed types from {}, which is a parameterized type", type); + addExposedTypes((ParameterizedType) type, cause); + } else if (type instanceof Class) { + logger.debug("Adding exposed types from {}, which is a class", type); + addExposedTypes((Class) type, cause); + } else { + throw new IllegalArgumentException("Unknown implementation of Type"); + } + } + + /** + * Adds any types exposed to this set. These will + * come from the (possibly absent) bounds on the + * type variable. + */ + private void addExposedTypes(TypeVariable type, Class cause) { + if (done(type)) { + return; + } + visit(type); + for (Type bound : type.getBounds()) { + logger.debug("Adding exposed types from {}, which is a type bound on {}", bound, type); + addExposedTypes(bound, cause); + } + } + + /** + * Adds any types exposed to this set. These will come from the (possibly absent) bounds on the + * wildcard. + */ + private void addExposedTypes(WildcardType type, Class cause) { + visit(type); + for (Type lowerBound : type.getLowerBounds()) { + logger.debug( + "Adding exposed types from {}, which is a type lower bound on wildcard type {}", + lowerBound, + type); + addExposedTypes(lowerBound, cause); + } + for (Type upperBound : type.getUpperBounds()) { + logger.debug( + "Adding exposed types from {}, which is a type upper bound on wildcard type {}", + upperBound, + type); + addExposedTypes(upperBound, cause); + } + } + + /** + * Adds any types exposed from the given array type. The array type itself is not added. The + * cause of the exposure of the underlying type is considered whatever type exposed the array + * type. + */ + private void addExposedTypes(GenericArrayType type, Class cause) { + if (done(type)) { + return; + } + visit(type); + logger.debug( + "Adding exposed types from {}, which is the component type on generic array type {}", + type.getGenericComponentType(), + type); + addExposedTypes(type.getGenericComponentType(), cause); + } + + /** + * Adds any types exposed to this set. Even if the + * root type is to be pruned, the actual type arguments + * are processed. + */ + private void addExposedTypes(ParameterizedType type, Class cause) { + // Even if the type is already done, this link to it may be new + boolean alreadyDone = done(type); + if (!pruned(type)) { + visit(type); + recordExposure(type, cause); + } + if (alreadyDone) { + return; + } + + // For a parameterized type, pruning does not take place + // here, only for the raw class. + // The type parameters themselves may not be pruned, + // for example with List probably the + // standard List is pruned, but MyApiType is not. 
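+    // (Concretely: for a declaration such as List<MyApiType>, java.util.List would typically be
+    // pruned via a "java" prefix, while the hypothetical MyApiType is still recorded as exposed.)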
+ logger.debug( + "Adding exposed types from {}, which is the raw type on parameterized type {}", + type.getRawType(), + type); + addExposedTypes(type.getRawType(), cause); + for (Type typeArg : type.getActualTypeArguments()) { + logger.debug( + "Adding exposed types from {}, which is a type argument on parameterized type {}", + typeArg, + type); + addExposedTypes(typeArg, cause); + } + } + + /** + * Adds a class and all of the types it exposes. The cause + * of the class being exposed is given, and the cause + * of everything within the class is that class itself. + */ + private void addExposedTypes(Class clazz, Class cause) { + if (pruned(clazz)) { + return; + } + // Even if `clazz` has been visited, the link from `cause` may be new + boolean alreadyDone = done(clazz); + visit(clazz); + recordExposure(clazz, cause); + if (alreadyDone || pruned(clazz)) { + return; + } + + TypeToken token = TypeToken.of(clazz); + for (TypeToken superType : token.getTypes()) { + if (!superType.equals(token)) { + logger.debug( + "Adding exposed types from {}, which is a super type token on {}", superType, clazz); + addExposedTypes(superType, clazz); + } + } + for (Class innerClass : clazz.getDeclaredClasses()) { + if (exposed(innerClass.getModifiers())) { + logger.debug( + "Adding exposed types from {}, which is an exposed inner class of {}", + innerClass, + clazz); + addExposedTypes(innerClass, clazz); + } + } + for (Field field : clazz.getDeclaredFields()) { + if (exposed(field.getModifiers())) { + logger.debug("Adding exposed types from {}, which is an exposed field on {}", field, clazz); + addExposedTypes(field, clazz); + } + } + for (Invokable invokable : getExposedInvokables(token)) { + logger.debug( + "Adding exposed types from {}, which is an exposed invokable on {}", invokable, clazz); + addExposedTypes(invokable, clazz); + } + } + + private void addExposedTypes(Invokable invokable, Class cause) { + addExposedTypes(invokable.getReturnType(), cause); + for (Annotation annotation : invokable.getAnnotations()) { + logger.debug( + "Adding exposed types from {}, which is an annotation on invokable {}", + annotation, + invokable); + addExposedTypes(annotation.annotationType(), cause); + } + for (Parameter parameter : invokable.getParameters()) { + logger.debug( + "Adding exposed types from {}, which is a parameter on invokable {}", + parameter, + invokable); + addExposedTypes(parameter, cause); + } + for (TypeToken exceptionType : invokable.getExceptionTypes()) { + logger.debug( + "Adding exposed types from {}, which is an exception type on invokable {}", + exceptionType, + invokable); + addExposedTypes(exceptionType, cause); + } + } + + private void addExposedTypes(Parameter parameter, Class cause) { + logger.debug( + "Adding exposed types from {}, which is the type of parameter {}", + parameter.getType(), + parameter); + addExposedTypes(parameter.getType(), cause); + for (Annotation annotation : parameter.getAnnotations()) { + logger.debug( + "Adding exposed types from {}, which is an annotation on parameter {}", + annotation, + parameter); + addExposedTypes(annotation.annotationType(), cause); + } + } + + private void addExposedTypes(Field field, Class cause) { + addExposedTypes(field.getGenericType(), cause); + for (Annotation annotation : field.getDeclaredAnnotations()) { + logger.debug( + "Adding exposed types from {}, which is an annotation on field {}", annotation, field); + addExposedTypes(annotation.annotationType(), cause); + } + } + + /** + * Returns an {@link Invokable} for each public 
methods or constructors of a type. + */ + private Set getExposedInvokables(TypeToken type) { + Set invokables = Sets.newHashSet(); + + for (Constructor constructor : type.getRawType().getConstructors()) { + if (0 != (constructor.getModifiers() & (Modifier.PUBLIC | Modifier.PROTECTED))) { + invokables.add(type.constructor(constructor)); + } + } + + for (Method method : type.getRawType().getMethods()) { + if (0 != (method.getModifiers() & (Modifier.PUBLIC | Modifier.PROTECTED))) { + invokables.add(type.method(method)); + } + } + + return invokables; + } + + /** + * Returns true of the given modifier bitmap indicates exposure (public or protected access). + */ + private boolean exposed(int modifiers) { + return 0 != (modifiers & (Modifier.PUBLIC | Modifier.PROTECTED)); + } + + + //////////////////////////////////////////////////////////////////////////// + + public static ApiSurface getSdkApiSurface() throws IOException { + return ApiSurface.ofPackage("com.google.cloud.dataflow") + .pruningPattern("com[.]google[.]cloud[.]dataflow.*Test") + .pruningPattern("com[.]google[.]cloud[.]dataflow.*Benchmark") + .pruningPrefix("com.google.cloud.dataflow.integration") + .pruningPrefix("java") + .pruningPrefix("com.google.api") + .pruningPrefix("com.google.auth") + .pruningPrefix("com.google.bigtable.v1") + .pruningPrefix("com.google.cloud.bigtable.config") + .pruningPrefix("com.google.cloud.bigtable.grpc.Bigtable*Name") + .pruningPrefix("com.google.protobuf") + .pruningPrefix("org.joda.time") + .pruningPrefix("org.apache.avro") + .pruningPrefix("org.junit") + .pruningPrefix("com.fasterxml.jackson.annotation"); + } + + public static void main(String[] args) throws Exception { + List names = Lists.newArrayList(); + for (Class clazz : getSdkApiSurface().getExposedClasses()) { + names.add(clazz.getName()); + } + List sortedNames = Lists.newArrayList(names); + Collections.sort(sortedNames); + + for (String name : sortedNames) { + System.out.println(name); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppEngineEnvironment.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppEngineEnvironment.java new file mode 100644 index 000000000000..c7fe4b4ff245 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppEngineEnvironment.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.lang.reflect.InvocationTargetException; + +/** Stores whether we are running within AppEngine or not. */ +public class AppEngineEnvironment { + /** + * True if running inside of AppEngine, false otherwise. + */ + @Deprecated + public static final boolean IS_APP_ENGINE = isAppEngine(); + + /** + * Attempts to detect whether we are inside of AppEngine. + * + *

    Purposely copied and left private from private code.google.common.util.concurrent.MoreExecutors#isAppEngine. + * + * @return true if we are inside of AppEngine, false otherwise. + */ + static boolean isAppEngine() { + if (System.getProperty("com.google.appengine.runtime.environment") == null) { + return false; + } + try { + // If the current environment is null, we're not inside AppEngine. + return Class.forName("com.google.apphosting.api.ApiProxy") + .getMethod("getCurrentEnvironment") + .invoke(null) != null; + } catch (ClassNotFoundException e) { + // If ApiProxy doesn't exist, we're not on AppEngine at all. + return false; + } catch (InvocationTargetException e) { + // If ApiProxy throws an exception, we're not in a proper AppEngine environment. + return false; + } catch (IllegalAccessException e) { + // If the method isn't accessible, we're not on a supported version of AppEngine; + return false; + } catch (NoSuchMethodException e) { + // If the method doesn't exist, we're not on a supported version of AppEngine; + return false; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppliedCombineFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppliedCombineFn.java new file mode 100644 index 000000000000..512d72def90b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AppliedCombineFn.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.annotations.VisibleForTesting; + +import java.io.Serializable; + +/** + * A {@link KeyedCombineFnWithContext} with a fixed accumulator coder. This is created from a + * specific application of the {@link KeyedCombineFnWithContext}. + * + *
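+ * <p>For example, an applied fn and its accumulator coder might be obtained from an input
+ * {@link KvCoder} as sketched below ({@code fn}, {@code registry}, and {@code kvCoder} are
+ * assumed to exist in the caller's, typically generic, context):
+ *
+ * <pre>{@code
+ * AppliedCombineFn<K, InputT, ?, OutputT> applied =
+ *     AppliedCombineFn.withInputCoder(fn, registry, kvCoder);
+ * Coder<?> accumulatorCoder = applied.getAccumulatorCoder();
+ * }</pre>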

    Because the {@code AccumT} may reference {@code InputT}, the specific {@code Coder} + * may depend on the {@code Coder}. + * + * @param type of keys + * @param type of input values + * @param type of mutable accumulator values + * @param type of output values + */ +public class AppliedCombineFn implements Serializable { + + private final PerKeyCombineFn fn; + private final Coder accumulatorCoder; + + private final Iterable> sideInputViews; + private final KvCoder kvCoder; + private final WindowingStrategy windowingStrategy; + + private AppliedCombineFn(PerKeyCombineFn fn, + Coder accumulatorCoder, Iterable> sideInputViews, + KvCoder kvCoder, WindowingStrategy windowingStrategy) { + this.fn = fn; + this.accumulatorCoder = accumulatorCoder; + this.sideInputViews = sideInputViews; + this.kvCoder = kvCoder; + this.windowingStrategy = windowingStrategy; + } + + public static AppliedCombineFn + withAccumulatorCoder( + PerKeyCombineFn fn, + Coder accumCoder) { + return withAccumulatorCoder(fn, accumCoder, null, null, null); + } + + public static AppliedCombineFn + withAccumulatorCoder( + PerKeyCombineFn fn, + Coder accumCoder, Iterable> sideInputViews, + KvCoder kvCoder, WindowingStrategy windowingStrategy) { + // Casting down the K and InputT is safe because they're only used as inputs. + @SuppressWarnings("unchecked") + PerKeyCombineFn clonedFn = + (PerKeyCombineFn) SerializableUtils.clone(fn); + return create(clonedFn, accumCoder, sideInputViews, kvCoder, windowingStrategy); + } + + @VisibleForTesting + public static AppliedCombineFn + withInputCoder(PerKeyCombineFn fn, + CoderRegistry registry, KvCoder kvCoder) { + return withInputCoder(fn, registry, kvCoder, null, null); + } + + public static AppliedCombineFn + withInputCoder(PerKeyCombineFn fn, + CoderRegistry registry, KvCoder kvCoder, + Iterable> sideInputViews, WindowingStrategy windowingStrategy) { + // Casting down the K and InputT is safe because they're only used as inputs. + @SuppressWarnings("unchecked") + PerKeyCombineFn clonedFn = + (PerKeyCombineFn) SerializableUtils.clone(fn); + try { + Coder accumulatorCoder = clonedFn.getAccumulatorCoder( + registry, kvCoder.getKeyCoder(), kvCoder.getValueCoder()); + return create(clonedFn, accumulatorCoder, sideInputViews, kvCoder, windowingStrategy); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + } + + private static AppliedCombineFn create( + PerKeyCombineFn fn, + Coder accumulatorCoder, Iterable> sideInputViews, + KvCoder kvCoder, WindowingStrategy windowingStrategy) { + return new AppliedCombineFn<>( + fn, accumulatorCoder, sideInputViews, kvCoder, windowingStrategy); + } + + public PerKeyCombineFn getFn() { + return fn; + } + + public Iterable> getSideInputViews() { + return sideInputViews; + } + + public Coder getAccumulatorCoder() { + return accumulatorCoder; + } + + public KvCoder getKvCoder() { + return kvCoder; + } + + public WindowingStrategy getWindowingStrategy() { + return windowingStrategy; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AssignWindowsDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AssignWindowsDoFn.java new file mode 100644 index 000000000000..ca59c5395732 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AssignWindowsDoFn.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * {@link DoFn} that tags elements of a PCollection with windows, according + * to the provided {@link WindowFn}. + * @param Type of elements being windowed + * @param Window type + */ +@SystemDoFnInternal +public class AssignWindowsDoFn extends DoFn { + private WindowFn fn; + + public AssignWindowsDoFn(WindowFn fn) { + this.fn = fn; + } + + @Override + @SuppressWarnings("unchecked") + public void processElement(final ProcessContext c) throws Exception { + Collection windows = + ((WindowFn) fn).assignWindows( + ((WindowFn) fn).new AssignContext() { + @Override + public T element() { + return c.element(); + } + + @Override + public Instant timestamp() { + return c.timestamp(); + } + + @Override + public Collection windows() { + return c.windowingInternals().windows(); + } + }); + + c.windowingInternals() + .outputWindowedValue(c.element(), c.timestamp(), windows, PaneInfo.NO_FIRING); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptAndTimeBoundedExponentialBackOff.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptAndTimeBoundedExponentialBackOff.java new file mode 100644 index 000000000000..e94d414fc072 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptAndTimeBoundedExponentialBackOff.java @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOff; +import com.google.api.client.util.NanoClock; +import com.google.common.base.Preconditions; + +import java.util.concurrent.TimeUnit; + +/** + * Extension of {@link AttemptBoundedExponentialBackOff} that bounds the total time that the backoff + * is happening as well as the amount of retries. Acts exactly as a AttemptBoundedExponentialBackOff + * unless the time interval has expired since the object was created. At this point, it will always + * return BackOff.STOP. 
Calling reset() resets both the timer and the number of retry attempts, + * unless a custom ResetPolicy (ResetPolicy.ATTEMPTS or ResetPolicy.TIMER) is passed to the + * constructor. + * + *
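+ * <p>A typical use might look like the following (an illustrative sketch; the attempt, interval,
+ * and time bounds are arbitrary example values):
+ *
+ * <pre>{@code
+ * BackOff backOff = new AttemptAndTimeBoundedExponentialBackOff(10, 100, 60000);
+ * long waitMillis = backOff.nextBackOffMillis(); // BackOff.STOP once either bound is exhausted
+ * }</pre>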

    Implementation is not thread-safe. + */ +public class AttemptAndTimeBoundedExponentialBackOff extends AttemptBoundedExponentialBackOff { + private long endTimeMillis; + private long maximumTotalWaitTimeMillis; + private ResetPolicy resetPolicy; + private final NanoClock nanoClock; + // NanoClock.SYSTEM has a max elapsed time of 292 years or 2^63 ns. Here, we choose 2^53 ns as + // a smaller but still huge limit. + private static final long MAX_ELAPSED_TIME_MILLIS = 1L << 53; + + /** + * A ResetPolicy controls the behavior of this BackOff when reset() is called. By default, both + * the number of attempts and the time bound for the BackOff are reset, but an alternative + * ResetPolicy may be set to only reset one of these two. + */ + public static enum ResetPolicy { + ALL, + ATTEMPTS, + TIMER + } + + /** + * Constructs an instance of AttemptAndTimeBoundedExponentialBackoff. + * + * @param maximumNumberOfAttempts The maximum number of attempts it will make. + * @param initialIntervalMillis The original interval to wait between attempts in milliseconds. + * @param maximumTotalWaitTimeMillis The maximum total time that this object will + * allow more attempts in milliseconds. + */ + public AttemptAndTimeBoundedExponentialBackOff( + int maximumNumberOfAttempts, long initialIntervalMillis, long maximumTotalWaitTimeMillis) { + this( + maximumNumberOfAttempts, + initialIntervalMillis, + maximumTotalWaitTimeMillis, + ResetPolicy.ALL, + NanoClock.SYSTEM); + } + + /** + * Constructs an instance of AttemptAndTimeBoundedExponentialBackoff. + * + * @param maximumNumberOfAttempts The maximum number of attempts it will make. + * @param initialIntervalMillis The original interval to wait between attempts in milliseconds. + * @param maximumTotalWaitTimeMillis The maximum total time that this object will + * allow more attempts in milliseconds. + * @param resetPolicy The ResetPolicy specifying the properties of this BackOff that are subject + * to being reset. + */ + public AttemptAndTimeBoundedExponentialBackOff( + int maximumNumberOfAttempts, + long initialIntervalMillis, + long maximumTotalWaitTimeMillis, + ResetPolicy resetPolicy) { + this( + maximumNumberOfAttempts, + initialIntervalMillis, + maximumTotalWaitTimeMillis, + resetPolicy, + NanoClock.SYSTEM); + } + + /** + * Constructs an instance of AttemptAndTimeBoundedExponentialBackoff. + * + * @param maximumNumberOfAttempts The maximum number of attempts it will make. + * @param initialIntervalMillis The original interval to wait between attempts in milliseconds. + * @param maximumTotalWaitTimeMillis The maximum total time that this object will + * allow more attempts in milliseconds. + * @param resetPolicy The ResetPolicy specifying the properties of this BackOff that are subject + * to being reset. + * @param nanoClock clock used to measure the time that has passed. 
+ */ + public AttemptAndTimeBoundedExponentialBackOff( + int maximumNumberOfAttempts, + long initialIntervalMillis, + long maximumTotalWaitTimeMillis, + ResetPolicy resetPolicy, + NanoClock nanoClock) { + super(maximumNumberOfAttempts, initialIntervalMillis); + Preconditions.checkArgument( + maximumTotalWaitTimeMillis > 0, "Maximum total wait time must be greater than zero."); + Preconditions.checkArgument( + maximumTotalWaitTimeMillis < MAX_ELAPSED_TIME_MILLIS, + "Maximum total wait time must be less than " + MAX_ELAPSED_TIME_MILLIS + " milliseconds"); + Preconditions.checkArgument(resetPolicy != null, "resetPolicy may not be null"); + Preconditions.checkArgument(nanoClock != null, "nanoClock may not be null"); + this.maximumTotalWaitTimeMillis = maximumTotalWaitTimeMillis; + this.resetPolicy = resetPolicy; + this.nanoClock = nanoClock; + // Set the end time for this BackOff. Note that we cannot simply call reset() here since the + // resetPolicy may not be set to reset the time bound. + endTimeMillis = getTimeMillis() + maximumTotalWaitTimeMillis; + } + + @Override + public void reset() { + // reset() is called in the constructor of the parent class before resetPolicy and nanoClock are + // set. In this case, we call the parent class's reset() method and return. + if (resetPolicy == null) { + super.reset(); + return; + } + // Reset the number of attempts. + if (resetPolicy == ResetPolicy.ALL || resetPolicy == ResetPolicy.ATTEMPTS) { + super.reset(); + } + // Reset the time bound. + if (resetPolicy == ResetPolicy.ALL || resetPolicy == ResetPolicy.TIMER) { + endTimeMillis = getTimeMillis() + maximumTotalWaitTimeMillis; + } + } + + public void setEndtimeMillis(long endTimeMillis) { + this.endTimeMillis = endTimeMillis; + } + + @Override + public long nextBackOffMillis() { + if (atMaxAttempts()) { + return BackOff.STOP; + } + long backoff = Math.min(super.nextBackOffMillis(), endTimeMillis - getTimeMillis()); + return (backoff > 0 ? backoff : BackOff.STOP); + } + + private long getTimeMillis() { + return TimeUnit.NANOSECONDS.toMillis(nanoClock.nanoTime()); + } + + @Override + public boolean atMaxAttempts() { + return super.atMaxAttempts() || getTimeMillis() >= endTimeMillis; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOff.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOff.java new file mode 100644 index 000000000000..613316ea0e81 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOff.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOff; +import com.google.common.base.Preconditions; + +/** + * Implementation of {@link BackOff} that increases the back off period for each retry attempt + * using a randomization function that grows exponentially. + * + *

    Example: The initial interval is .5 seconds and the maximum number of retries is 10. + * For 10 tries the sequence will be (values in seconds): + * + *

    + * retry#      retry_interval     randomized_interval
    + * 1             0.5                [0.25,   0.75]
    + * 2             0.75               [0.375,  1.125]
    + * 3             1.125              [0.562,  1.687]
    + * 4             1.687              [0.8435, 2.53]
    + * 5             2.53               [1.265,  3.795]
    + * 6             3.795              [1.897,  5.692]
    + * 7             5.692              [2.846,  8.538]
    + * 8             8.538              [4.269, 12.807]
    + * 9            12.807              [6.403, 19.210]
    + * 10           {@link BackOff#STOP}
    + * 
    + * + *
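+ * <p>Concretely, the nth interval (before randomization) is
+ * {@code initialIntervalMillis * 1.5^(n-1)}, and the returned value is drawn uniformly from
+ * within +/-50% of that interval, per {@code DEFAULT_MULTIPLIER} and
+ * {@code DEFAULT_RANDOMIZATION_FACTOR} below.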

    Implementation is not thread-safe. + */ +public class AttemptBoundedExponentialBackOff implements BackOff { + public static final double DEFAULT_MULTIPLIER = 1.5; + public static final double DEFAULT_RANDOMIZATION_FACTOR = 0.5; + private final int maximumNumberOfAttempts; + private final long initialIntervalMillis; + private int currentAttempt; + + public AttemptBoundedExponentialBackOff(int maximumNumberOfAttempts, long initialIntervalMillis) { + Preconditions.checkArgument(maximumNumberOfAttempts > 0, + "Maximum number of attempts must be greater than zero."); + Preconditions.checkArgument(initialIntervalMillis > 0, + "Initial interval must be greater than zero."); + this.maximumNumberOfAttempts = maximumNumberOfAttempts; + this.initialIntervalMillis = initialIntervalMillis; + reset(); + } + + @Override + public void reset() { + currentAttempt = 1; + } + + @Override + public long nextBackOffMillis() { + if (currentAttempt >= maximumNumberOfAttempts) { + return BackOff.STOP; + } + double currentIntervalMillis = initialIntervalMillis + * Math.pow(DEFAULT_MULTIPLIER, currentAttempt - 1); + double randomOffset = (Math.random() * 2 - 1) + * DEFAULT_RANDOMIZATION_FACTOR * currentIntervalMillis; + currentAttempt += 1; + return Math.round(currentIntervalMillis + randomOffset); + } + + public boolean atMaxAttempts() { + return currentAttempt >= maximumNumberOfAttempts; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AvroUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AvroUtils.java new file mode 100644 index 000000000000..c3a486102e9d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/AvroUtils.java @@ -0,0 +1,345 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import org.apache.avro.Schema; +import org.apache.avro.Schema.Field; +import org.apache.avro.Schema.Type; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.DecoderFactory; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.util.Arrays; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A set of utilities for working with Avro files. + * + *
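+ * <p>For example, file-level metadata can be read without scanning the records (a sketch; the
+ * file path below is a placeholder):
+ *
+ * <pre>{@code
+ * AvroMetadata metadata = AvroUtils.readMetadataFromFile("gs://my-bucket/records.avro");
+ * String schemaJson = metadata.getSchemaString();
+ * byte[] syncMarker = metadata.getSyncMarker();
+ * }</pre>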

    These utilities are based on the Avro 1.7.7 specification. + */ +public class AvroUtils { + + /** + * Avro file metadata. + */ + public static class AvroMetadata { + private byte[] syncMarker; + private String codec; + private String schemaString; + + AvroMetadata(byte[] syncMarker, String codec, String schemaString) { + this.syncMarker = syncMarker; + this.codec = codec; + this.schemaString = schemaString; + } + + /** + * The JSON-encoded schema + * string for the file. + */ + public String getSchemaString() { + return schemaString; + } + + /** + * The codec of the + * file. + */ + public String getCodec() { + return codec; + } + + /** + * The 16-byte sync marker for the file. See the documentation for + * Object + * Container File for more information. + */ + public byte[] getSyncMarker() { + return syncMarker; + } + } + + /** + * Reads the {@link AvroMetadata} from the header of an Avro file. + * + *

    This method parses the header of an Avro + * + * Object Container File. + * + * @throws IOException if the file is an invalid format. + */ + public static AvroMetadata readMetadataFromFile(String fileName) throws IOException { + String codec = null; + String schemaString = null; + byte[] syncMarker; + try (InputStream stream = + Channels.newInputStream(IOChannelUtils.getFactory(fileName).open(fileName))) { + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(stream, null); + + // The header of an object container file begins with a four-byte magic number, followed + // by the file metadata (including the schema and codec), encoded as a map. Finally, the + // header ends with the file's 16-byte sync marker. + // See https://avro.apache.org/docs/1.7.7/spec.html#Object+Container+Files for details on + // the encoding of container files. + + // Read the magic number. + byte[] magic = new byte[DataFileConstants.MAGIC.length]; + decoder.readFixed(magic); + if (!Arrays.equals(magic, DataFileConstants.MAGIC)) { + throw new IOException("Missing Avro file signature: " + fileName); + } + + // Read the metadata to find the codec and schema. + ByteBuffer valueBuffer = ByteBuffer.allocate(512); + long numRecords = decoder.readMapStart(); + while (numRecords > 0) { + for (long recordIndex = 0; recordIndex < numRecords; recordIndex++) { + String key = decoder.readString(); + // readBytes() clears the buffer and returns a buffer where: + // - position is the start of the bytes read + // - limit is the end of the bytes read + valueBuffer = decoder.readBytes(valueBuffer); + byte[] bytes = new byte[valueBuffer.remaining()]; + valueBuffer.get(bytes); + if (key.equals(DataFileConstants.CODEC)) { + codec = new String(bytes, "UTF-8"); + } else if (key.equals(DataFileConstants.SCHEMA)) { + schemaString = new String(bytes, "UTF-8"); + } + } + numRecords = decoder.mapNext(); + } + if (codec == null) { + codec = DataFileConstants.NULL_CODEC; + } + + // Finally, read the sync marker. + syncMarker = new byte[DataFileConstants.SYNC_SIZE]; + decoder.readFixed(syncMarker); + } + return new AvroMetadata(syncMarker, codec, schemaString); + } + + /** + * Formats BigQuery seconds-since-epoch into String matching JSON export. Thread-safe and + * immutable. + */ + private static final DateTimeFormatter DATE_AND_SECONDS_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss").withZoneUTC(); + // Package private for BigQueryTableRowIterator to use. + static String formatTimestamp(String timestamp) { + // timestamp is in "seconds since epoch" format, with scientific notation. + // e.g., "1.45206229112345E9" to mean "2016-01-06 06:38:11.123456 UTC". + // Separate into seconds and microseconds. + double timestampDoubleMicros = Double.parseDouble(timestamp) * 1000000; + long timestampMicros = (long) timestampDoubleMicros; + long seconds = timestampMicros / 1000000; + int micros = (int) (timestampMicros % 1000000); + String dayAndTime = DATE_AND_SECONDS_FORMATTER.print(seconds * 1000); + + // No sub-second component. + if (micros == 0) { + return String.format("%s UTC", dayAndTime); + } + + // Sub-second component. 
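+    // For example, micros == 123450 strips one trailing zero below, leaving subsecond == 12345
+    // and digits == 5, so the timestamp in the comment above renders as
+    // "2016-01-06 06:38:11.12345 UTC".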
+ int digits = 6; + int subsecond = micros; + while (subsecond % 10 == 0) { + digits--; + subsecond /= 10; + } + String formatString = String.format("%%0%dd", digits); + String fractionalSeconds = String.format(formatString, subsecond); + return String.format("%s.%s UTC", dayAndTime, fractionalSeconds); + } + + /** + * Utility function to convert from an Avro {@link GenericRecord} to a BigQuery {@link TableRow}. + * + * See + * "Avro format" for more information. + */ + public static TableRow convertGenericRecordToTableRow(GenericRecord record, TableSchema schema) { + return convertGenericRecordToTableRow(record, schema.getFields()); + } + + private static TableRow convertGenericRecordToTableRow( + GenericRecord record, List fields) { + TableRow row = new TableRow(); + for (TableFieldSchema subSchema : fields) { + // Per https://cloud.google.com/bigquery/docs/reference/v2/tables#schema, the name field + // is required, so it may not be null. + Field field = record.getSchema().getField(subSchema.getName()); + Object convertedValue = + getTypedCellValue(field.schema(), subSchema, record.get(field.name())); + if (convertedValue != null) { + // To match the JSON files exported by BigQuery, do not include null values in the output. + row.set(field.name(), convertedValue); + } + } + return row; + } + + @Nullable + private static Object getTypedCellValue(Schema schema, TableFieldSchema fieldSchema, Object v) { + // Per https://cloud.google.com/bigquery/docs/reference/v2/tables#schema, the mode field + // is optional (and so it may be null), but defaults to "NULLABLE". + String mode = firstNonNull(fieldSchema.getMode(), "NULLABLE"); + switch (mode) { + case "REQUIRED": + return convertRequiredField(schema.getType(), fieldSchema, v); + case "REPEATED": + return convertRepeatedField(schema, fieldSchema, v); + case "NULLABLE": + return convertNullableField(schema, fieldSchema, v); + default: + throw new UnsupportedOperationException( + "Parsing a field with BigQuery field schema mode " + fieldSchema.getMode()); + } + } + + private static List convertRepeatedField( + Schema schema, TableFieldSchema fieldSchema, Object v) { + Type arrayType = schema.getType(); + verify( + arrayType == Type.ARRAY, + "BigQuery REPEATED field %s should be Avro ARRAY, not %s", + fieldSchema.getName(), + arrayType); + // REPEATED fields are represented as Avro arrays. + if (v == null) { + // Handle the case of an empty repeated field. + return ImmutableList.of(); + } + @SuppressWarnings("unchecked") + List elements = (List) v; + ImmutableList.Builder values = ImmutableList.builder(); + Type elementType = schema.getElementType().getType(); + for (Object element : elements) { + values.add(convertRequiredField(elementType, fieldSchema, element)); + } + return values.build(); + } + + private static Object convertRequiredField( + Type avroType, TableFieldSchema fieldSchema, Object v) { + // REQUIRED fields are represented as the corresponding Avro types. For example, a BigQuery + // INTEGER type maps to an Avro LONG type. + checkNotNull(v, "REQUIRED field %s should not be null", fieldSchema.getName()); + ImmutableMap fieldMap = + ImmutableMap.builder() + .put("STRING", Type.STRING) + .put("INTEGER", Type.LONG) + .put("FLOAT", Type.DOUBLE) + .put("BOOLEAN", Type.BOOLEAN) + .put("TIMESTAMP", Type.LONG) + .put("RECORD", Type.RECORD) + .build(); + // Per https://cloud.google.com/bigquery/docs/reference/v2/tables#schema, the type field + // is required, so it may not be null. 
+ String bqType = fieldSchema.getType(); + Type expectedAvroType = fieldMap.get(bqType); + verify( + avroType == expectedAvroType, + "Expected Avro schema type %s, not %s, for BigQuery %s field %s", + expectedAvroType, + avroType, + bqType, + fieldSchema.getName()); + switch (fieldSchema.getType()) { + case "STRING": + // Avro will use a CharSequence to represent String objects, but it may not always use + // java.lang.String; for example, it may prefer org.apache.avro.util.Utf8. + verify(v instanceof CharSequence, "Expected CharSequence (String), got %s", v.getClass()); + return v.toString(); + case "INTEGER": + verify(v instanceof Long, "Expected Long, got %s", v.getClass()); + return ((Long) v).toString(); + case "FLOAT": + verify(v instanceof Double, "Expected Double, got %s", v.getClass()); + return v; + case "BOOLEAN": + verify(v instanceof Boolean, "Expected Boolean, got %s", v.getClass()); + return v; + case "TIMESTAMP": + // TIMESTAMP data types are represented as Avro LONG types. They are converted back to + // Strings with variable-precision (up to six digits) to match the JSON files export + // by BigQuery. + verify(v instanceof Long, "Expected Long, got %s", v.getClass()); + Double doubleValue = ((Long) v) / 1000000.0; + return formatTimestamp(doubleValue.toString()); + case "RECORD": + verify(v instanceof GenericRecord, "Expected GenericRecord, got %s", v.getClass()); + return convertGenericRecordToTableRow((GenericRecord) v, fieldSchema.getFields()); + default: + throw new UnsupportedOperationException( + String.format( + "Unexpected BigQuery field schema type %s for field named %s", + fieldSchema.getType(), + fieldSchema.getName())); + } + } + + @Nullable + private static Object convertNullableField( + Schema avroSchema, TableFieldSchema fieldSchema, Object v) { + // NULLABLE fields are represented as an Avro Union of the corresponding type and "null". + verify( + avroSchema.getType() == Type.UNION, + "Expected Avro schema type UNION, not %s, for BigQuery NULLABLE field %s", + avroSchema.getType(), + fieldSchema.getName()); + List unionTypes = avroSchema.getTypes(); + verify( + unionTypes.size() == 2, + "BigQuery NULLABLE field %s should be an Avro UNION of NULL and another type, not %s", + fieldSchema.getName(), + unionTypes); + + if (v == null) { + return null; + } + + Type firstType = unionTypes.get(0).getType(); + if (!firstType.equals(Type.NULL)) { + return convertRequiredField(firstType, fieldSchema, v); + } + return convertRequiredField(unionTypes.get(1).getType(), fieldSchema, v); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BaseExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BaseExecutionContext.java new file mode 100644 index 000000000000..6a0ccf3531bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BaseExecutionContext.java @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Base class for implementations of {@link ExecutionContext}. + * + *

    A concrete subclass should implement {@link #createStepContext} to create the appropriate + * {@link StepContext} implementation. Any {@code StepContext} created will + * be cached for the lifetime of this {@link ExecutionContext}. + * + *

    BaseExecutionContext is generic to allow implementing subclasses to return a concrete subclass + * of {@link StepContext} from {@link #getOrCreateStepContext(String, String, StateSampler)} and + * {@link #getAllStepContexts()} without forcing each subclass to override the method, e.g. + *

    + * @Override
    + * StreamingModeExecutionContext.StepContext getOrCreateStepContext(...) {
    + *   return (StreamingModeExecutionContext.StepContext) super.getOrCreateStepContext(...);
    + * }
    + * 
    + * + *

    When a subclass of {@code BaseExecutionContext} has been downcast, the return types of + * {@link #createStepContext(String, String, StateSampler)}, + * {@link #getOrCreateStepContext(String, String, StateSampler}, and {@link #getAllStepContexts()} + * will be appropriately specialized. + */ +public abstract class BaseExecutionContext + implements ExecutionContext { + + private Map cachedStepContexts = new HashMap<>(); + + /** + * Implementations should override this to create the specific type + * of {@link StepContext} they need. + */ + protected abstract T createStepContext( + String stepName, String transformName, StateSampler stateSampler); + + + /** + * Returns the {@link StepContext} associated with the given step. + */ + @Override + public T getOrCreateStepContext( + String stepName, String transformName, StateSampler stateSampler) { + T context = cachedStepContexts.get(stepName); + if (context == null) { + context = createStepContext(stepName, transformName, stateSampler); + cachedStepContexts.put(stepName, context); + } + return context; + } + + /** + * Returns a collection view of all of the {@link StepContext}s. + */ + @Override + public Collection getAllStepContexts() { + return Collections.unmodifiableCollection(cachedStepContexts.values()); + } + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#output} + * is called. + */ + @Override + public void noteOutput(WindowedValue output) {} + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#sideOutput} + * is called. + */ + @Override + public void noteSideOutput(TupleTag tag, WindowedValue output) {} + + /** + * Base class for implementations of {@link ExecutionContext.StepContext}. + * + *

    To complete a concrete subclass, implement {@link #timerInternals} and + * {@link #stateInternals}. + */ + public abstract static class StepContext implements ExecutionContext.StepContext { + private final ExecutionContext executionContext; + private final String stepName; + private final String transformName; + + public StepContext(ExecutionContext executionContext, String stepName, String transformName) { + this.executionContext = executionContext; + this.stepName = stepName; + this.transformName = transformName; + } + + @Override + public String getStepName() { + return stepName; + } + + @Override + public String getTransformName() { + return transformName; + } + + @Override + public void noteOutput(WindowedValue output) { + executionContext.noteOutput(output); + } + + @Override + public void noteSideOutput(TupleTag tag, WindowedValue output) { + executionContext.noteSideOutput(tag, output); + } + + @Override + public void writePCollectionViewData( + TupleTag tag, + Iterable> data, Coder>> dataCoder, + W window, Coder windowCoder) throws IOException { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public abstract StateInternals stateInternals(); + + @Override + public abstract TimerInternals timerInternals(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BatchTimerInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BatchTimerInternals.java new file mode 100644 index 000000000000..b6a1493239bd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BatchTimerInternals.java @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import org.joda.time.Instant; + +import java.util.HashSet; +import java.util.PriorityQueue; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * TimerInternals that uses priority queues to manage the timers that are ready to fire. + */ +public class BatchTimerInternals implements TimerInternals { + /** Set of timers that are scheduled used for deduplicating timers. */ + private Set existingTimers = new HashSet<>(); + + // Keep these queues separate so we can advance over them separately. + private PriorityQueue watermarkTimers = new PriorityQueue<>(11); + private PriorityQueue processingTimers = new PriorityQueue<>(11); + + private Instant inputWatermarkTime; + private Instant processingTime; + + private PriorityQueue queue(TimeDomain domain) { + return TimeDomain.EVENT_TIME.equals(domain) ? 
watermarkTimers : processingTimers; + } + + public BatchTimerInternals(Instant processingTime) { + this.processingTime = processingTime; + this.inputWatermarkTime = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + @Override + public void setTimer(TimerData timer) { + if (existingTimers.add(timer)) { + queue(timer.getDomain()).add(timer); + } + } + + @Override + public void deleteTimer(TimerData timer) { + existingTimers.remove(timer); + queue(timer.getDomain()).remove(timer); + } + + @Override + public Instant currentProcessingTime() { + return processingTime; + } + + /** + * {@inheritDoc} + * + * @return {@link BoundedWindow#TIMESTAMP_MAX_VALUE}: in batch mode, upstream processing + * is already complete. + */ + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + public Instant currentInputWatermarkTime() { + return inputWatermarkTime; + } + + @Override + @Nullable + public Instant currentOutputWatermarkTime() { + // The output watermark is always undefined in batch mode. + return null; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("watermarkTimers", watermarkTimers) + .add("processingTimers", processingTimers) + .toString(); + } + + public void advanceInputWatermark(ReduceFnRunner runner, Instant newInputWatermark) + throws Exception { + Preconditions.checkState(!newInputWatermark.isBefore(inputWatermarkTime), + "Cannot move input watermark time backwards from %s to %s", inputWatermarkTime, + newInputWatermark); + inputWatermarkTime = newInputWatermark; + advance(runner, newInputWatermark, TimeDomain.EVENT_TIME); + } + + public void advanceProcessingTime(ReduceFnRunner runner, Instant newProcessingTime) + throws Exception { + Preconditions.checkState(!newProcessingTime.isBefore(processingTime), + "Cannot move processing time backwards from %s to %s", processingTime, newProcessingTime); + processingTime = newProcessingTime; + advance(runner, newProcessingTime, TimeDomain.PROCESSING_TIME); + } + + private void advance(ReduceFnRunner runner, Instant newTime, TimeDomain domain) + throws Exception { + PriorityQueue timers = queue(domain); + boolean shouldFire = false; + + do { + TimerData timer = timers.peek(); + // Timers fire if the new time is ahead of the timer + shouldFire = timer != null && newTime.isAfter(timer.getTimestamp()); + if (shouldFire) { + // Remove before firing, so that if the trigger adds another identical + // timer we don't remove it. + timers.remove(); + runner.onTimer(timer); + } + } while (shouldFire); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserter.java new file mode 100644 index 000000000000..cd5106275646 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserter.java @@ -0,0 +1,434 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.client.util.Sleeper; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableDataInsertAllRequest; +import com.google.api.services.bigquery.model.TableDataInsertAllResponse; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.hadoop.util.ApiErrorExtractor; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.MoreExecutors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +/** + * Inserts rows into BigQuery. + */ +public class BigQueryTableInserter { + private static final Logger LOG = LoggerFactory.getLogger(BigQueryTableInserter.class); + + // Approximate amount of table data to upload per InsertAll request. + private static final long UPLOAD_BATCH_SIZE_BYTES = 64 * 1024; + + // The maximum number of rows to upload per InsertAll request. + private static final long MAX_ROWS_PER_BATCH = 500; + + // The maximum number of times to retry inserting rows into BigQuery. + private static final int MAX_INSERT_ATTEMPTS = 5; + + // The initial backoff after a failure inserting rows into BigQuery. + private static final long INITIAL_INSERT_BACKOFF_INTERVAL_MS = 200L; + + private final Bigquery client; + private final TableReference defaultRef; + private final long maxRowsPerBatch; + + private static final ExecutorService executor = MoreExecutors.getExitingExecutorService( + (ThreadPoolExecutor) Executors.newFixedThreadPool(100), 10, TimeUnit.SECONDS); + + /** + * Constructs a new row inserter. + * + * @param client a BigQuery client + */ + public BigQueryTableInserter(Bigquery client) { + this.client = client; + this.defaultRef = null; + this.maxRowsPerBatch = MAX_ROWS_PER_BATCH; + } + + /** + * Constructs a new row inserter. 
+ * + * @param client a BigQuery client + * @param defaultRef identifies the table to insert into + * @deprecated replaced by {@link #BigQueryTableInserter(Bigquery)} + */ + @Deprecated + public BigQueryTableInserter(Bigquery client, TableReference defaultRef) { + this.client = client; + this.defaultRef = defaultRef; + this.maxRowsPerBatch = MAX_ROWS_PER_BATCH; + } + + /** + * Constructs a new row inserter. + * + * @param client a BigQuery client + */ + public BigQueryTableInserter(Bigquery client, int maxRowsPerBatch) { + this.client = client; + this.defaultRef = null; + this.maxRowsPerBatch = maxRowsPerBatch; + } + + /** + * Constructs a new row inserter. + * + * @param client a BigQuery client + * @param defaultRef identifies the default table to insert into + * @deprecated replaced by {@link #BigQueryTableInserter(Bigquery, int)} + */ + @Deprecated + public BigQueryTableInserter(Bigquery client, TableReference defaultRef, int maxRowsPerBatch) { + this.client = client; + this.defaultRef = defaultRef; + this.maxRowsPerBatch = maxRowsPerBatch; + } + + /** + * Insert all rows from the given list. + * + * @deprecated replaced by {@link #insertAll(TableReference, List)} + */ + @Deprecated + public void insertAll(List rowList) throws IOException { + insertAll(defaultRef, rowList, null, null); + } + + /** + * Insert all rows from the given list using specified insertIds if not null. + * + * @deprecated replaced by {@link #insertAll(TableReference, List, List)} + */ + @Deprecated + public void insertAll(List rowList, + @Nullable List insertIdList) throws IOException { + insertAll(defaultRef, rowList, insertIdList, null); + } + + /** + * Insert all rows from the given list. + */ + public void insertAll(TableReference ref, List rowList) throws IOException { + insertAll(ref, rowList, null, null); + } + + /** + * Insert all rows from the given list using specified insertIds if not null. Track count of + * bytes written with the Aggregator. + */ + public void insertAll(TableReference ref, List rowList, + @Nullable List insertIdList, Aggregator byteCountAggregator) + throws IOException { + Preconditions.checkNotNull(ref, "ref"); + if (insertIdList != null && rowList.size() != insertIdList.size()) { + throw new AssertionError("If insertIdList is not null it needs to have at least " + + "as many elements as rowList"); + } + + AttemptBoundedExponentialBackOff backoff = new AttemptBoundedExponentialBackOff( + MAX_INSERT_ATTEMPTS, + INITIAL_INSERT_BACKOFF_INTERVAL_MS); + + List allErrors = new ArrayList<>(); + // These lists contain the rows to publish. Initially the contain the entire list. If there are + // failures, they will contain only the failed rows to be retried. + List rowsToPublish = rowList; + List idsToPublish = insertIdList; + while (true) { + List retryRows = new ArrayList<>(); + List retryIds = (idsToPublish != null) ? new ArrayList() : null; + + int strideIndex = 0; + // Upload in batches. 
+ List rows = new LinkedList<>(); + int dataSize = 0; + + List>> futures = new ArrayList<>(); + List strideIndices = new ArrayList<>(); + + for (int i = 0; i < rowsToPublish.size(); ++i) { + TableRow row = rowsToPublish.get(i); + TableDataInsertAllRequest.Rows out = new TableDataInsertAllRequest.Rows(); + if (idsToPublish != null) { + out.setInsertId(idsToPublish.get(i)); + } + out.setJson(row.getUnknownKeys()); + rows.add(out); + + dataSize += row.toString().length(); + if (dataSize >= UPLOAD_BATCH_SIZE_BYTES || rows.size() >= maxRowsPerBatch || + i == rowsToPublish.size() - 1) { + TableDataInsertAllRequest content = new TableDataInsertAllRequest(); + content.setRows(rows); + + final Bigquery.Tabledata.InsertAll insert = client.tabledata() + .insertAll(ref.getProjectId(), ref.getDatasetId(), ref.getTableId(), + content); + + futures.add( + executor.submit(new Callable>() { + @Override + public List call() throws IOException { + return insert.execute().getInsertErrors(); + } + })); + strideIndices.add(strideIndex); + + if (byteCountAggregator != null) { + byteCountAggregator.addValue(Long.valueOf(dataSize)); + } + dataSize = 0; + strideIndex = i + 1; + rows = new LinkedList<>(); + } + } + + try { + for (int i = 0; i < futures.size(); i++) { + List errors = futures.get(i).get(); + if (errors != null) { + for (TableDataInsertAllResponse.InsertErrors error : errors) { + allErrors.add(error); + if (error.getIndex() == null) { + throw new IOException("Insert failed: " + allErrors); + } + + int errorIndex = error.getIndex().intValue() + strideIndices.get(i); + retryRows.add(rowsToPublish.get(errorIndex)); + if (retryIds != null) { + retryIds.add(idsToPublish.get(errorIndex)); + } + } + } + } + } catch (InterruptedException e) { + throw new IOException("Interrupted while inserting " + rowsToPublish); + } catch (ExecutionException e) { + Throwables.propagate(e.getCause()); + } + + if (!allErrors.isEmpty() && !backoff.atMaxAttempts()) { + try { + Thread.sleep(backoff.nextBackOffMillis()); + } catch (InterruptedException e) { + throw new IOException("Interrupted while waiting before retrying insert of " + retryRows); + } + LOG.info("Retrying failed inserts to BigQuery"); + rowsToPublish = retryRows; + idsToPublish = retryIds; + allErrors.clear(); + } else { + break; + } + } + if (!allErrors.isEmpty()) { + throw new IOException("Insert failed: " + allErrors); + } + } + + /** + * Retrieves or creates the table. + * + *

    The table is checked to conform to insertion requirements as specified + * by WriteDisposition and CreateDisposition. + * + *

    If table truncation is requested (WriteDisposition.WRITE_TRUNCATE), then + * this will re-create the table if necessary to ensure it is empty. + * + *

    If an empty table is required (WriteDisposition.WRITE_EMPTY), then this + * will fail if the table exists and is not empty. + * + *

    When constructing a table, a {@code TableSchema} must be available. If a + * schema is provided, then it will be used. If no schema is provided, but + * an existing table is being cleared (WRITE_TRUNCATE option above), then + * the existing schema will be re-used. If no schema is available, then an + * {@code IOException} is thrown. + */ + public Table getOrCreateTable( + TableReference ref, + WriteDisposition writeDisposition, + CreateDisposition createDisposition, + @Nullable TableSchema schema) throws IOException { + // Check if table already exists. + Bigquery.Tables.Get get = client.tables() + .get(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + Table table = null; + try { + table = get.execute(); + } catch (IOException e) { + ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + if (!errorExtractor.itemNotFound(e) || + createDisposition != CreateDisposition.CREATE_IF_NEEDED) { + // Rethrow. + throw e; + } + } + + // If we want an empty table, and it isn't, then delete it first. + if (table != null) { + if (writeDisposition == WriteDisposition.WRITE_APPEND) { + return table; + } + + boolean empty = isEmpty(ref); + if (empty) { + if (writeDisposition == WriteDisposition.WRITE_TRUNCATE) { + LOG.info("Empty table found, not removing {}", BigQueryIO.toTableSpec(ref)); + } + return table; + + } else if (writeDisposition == WriteDisposition.WRITE_EMPTY) { + throw new IOException("WriteDisposition is WRITE_EMPTY, " + + "but table is not empty"); + } + + // Reuse the existing schema if none was provided. + if (schema == null) { + schema = table.getSchema(); + } + + // Delete table and fall through to re-creating it below. + LOG.info("Deleting table {}", BigQueryIO.toTableSpec(ref)); + Bigquery.Tables.Delete delete = client.tables() + .delete(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + delete.execute(); + } + + if (schema == null) { + throw new IllegalArgumentException( + "Table schema required for new table."); + } + + // Create the table. + return tryCreateTable(ref, schema); + } + + /** + * Checks if a table is empty. + */ + public boolean isEmpty(TableReference ref) throws IOException { + Bigquery.Tabledata.List list = client.tabledata() + .list(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + list.setMaxResults(1L); + TableDataList dataList = list.execute(); + + return dataList.getRows() == null || dataList.getRows().isEmpty(); + } + + /** + * Retry table creation up to 5 minutes (with exponential backoff) when this user is near the + * quota for table creation. This relatively innocuous behavior can happen when BigQueryIO is + * configured with a table spec function to use different tables for each window. + */ + private static final int RETRY_CREATE_TABLE_DURATION_MILLIS = (int) TimeUnit.MINUTES.toMillis(5); + + /** + * Tries to create the BigQuery table. + * If a table with the same name already exists in the dataset, the table + * creation fails, and the function returns null. In such a case, + * the existing table doesn't necessarily have the same schema as specified + * by the parameter. + * + * @param schema Schema of the new BigQuery table. + * @return The newly created BigQuery table information, or null if the table + * with the same name already exists. + * @throws IOException if other error than already existing table occurs. 
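As a hedged illustration of the dispositions described above, the sketch below prepares a table and then writes a couple of rows with the inserter. The project, dataset, table, column names, and row values are illustrative assumptions, and the Bigquery client is assumed to be constructed and authorized elsewhere.

    import com.google.api.services.bigquery.Bigquery;
    import com.google.api.services.bigquery.model.TableFieldSchema;
    import com.google.api.services.bigquery.model.TableReference;
    import com.google.api.services.bigquery.model.TableRow;
    import com.google.api.services.bigquery.model.TableSchema;
    import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition;
    import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition;
    import com.google.cloud.dataflow.sdk.util.BigQueryTableInserter;

    import java.io.IOException;
    import java.util.Arrays;
    import java.util.List;

    public class TableInserterSketch {
      static void writeWordCounts(Bigquery client) throws IOException {
        TableReference ref = new TableReference()
            .setProjectId("my-project")    // assumed project
            .setDatasetId("my_dataset")    // assumed dataset
            .setTableId("word_counts");    // assumed table

        TableSchema schema = new TableSchema().setFields(Arrays.asList(
            new TableFieldSchema().setName("word").setType("STRING"),
            new TableFieldSchema().setName("count").setType("INTEGER")));

        BigQueryTableInserter inserter = new BigQueryTableInserter(client);

        // CREATE_IF_NEEDED + WRITE_TRUNCATE: create the table if it is missing,
        // keep it if it already exists but is empty, otherwise delete and re-create it.
        inserter.getOrCreateTable(
            ref, WriteDisposition.WRITE_TRUNCATE, CreateDisposition.CREATE_IF_NEEDED, schema);

        List<TableRow> rows = Arrays.asList(
            new TableRow().set("word", "hello").set("count", 3),
            new TableRow().set("word", "world").set("count", 5));
        // One insert id per row lets BigQuery de-duplicate rows that are retried
        // after a partial failure; the byte-count aggregator argument is optional.
        List<String> insertIds = Arrays.asList("id-0", "id-1");
        inserter.insertAll(ref, rows, insertIds, null);
      }
    }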
+ */ + @Nullable + public Table tryCreateTable(TableReference ref, TableSchema schema) throws IOException { + LOG.info("Trying to create BigQuery table: {}", BigQueryIO.toTableSpec(ref)); + BackOff backoff = + new ExponentialBackOff.Builder() + .setMaxElapsedTimeMillis(RETRY_CREATE_TABLE_DURATION_MILLIS) + .build(); + + Table table = new Table().setTableReference(ref).setSchema(schema); + return tryCreateTable(table, ref.getProjectId(), ref.getDatasetId(), backoff, Sleeper.DEFAULT); + } + + @VisibleForTesting + @Nullable + Table tryCreateTable( + Table table, String projectId, String datasetId, BackOff backoff, Sleeper sleeper) + throws IOException { + boolean retry = false; + while (true) { + try { + return client.tables().insert(projectId, datasetId, table).execute(); + } catch (IOException e) { + ApiErrorExtractor extractor = new ApiErrorExtractor(); + if (extractor.itemAlreadyExists(e)) { + // The table already exists, nothing to return. + return null; + } else if (extractor.rateLimited(e)) { + // The request failed because we hit a temporary quota. Back off and try again. + try { + if (BackOffUtils.next(sleeper, backoff)) { + if (!retry) { + LOG.info( + "Quota limit reached when creating table {}:{}.{}, retrying up to {} minutes", + projectId, + datasetId, + table.getTableReference().getTableId(), + TimeUnit.MILLISECONDS.toSeconds(RETRY_CREATE_TABLE_DURATION_MILLIS) / 60.0); + retry = true; + } + continue; + } + } catch (InterruptedException e1) { + // Restore interrupted state and throw the last failure. + Thread.currentThread().interrupt(); + throw e; + } + } + throw e; + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIterator.java new file mode 100644 index 000000000000..c2c80f79c31d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIterator.java @@ -0,0 +1,469 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.api.client.googleapis.services.AbstractGoogleClientRequest; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.ClassInfo; +import com.google.api.client.util.Data; +import com.google.api.client.util.Sleeper; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.Bigquery.Jobs.Insert; +import com.google.api.services.bigquery.model.Dataset; +import com.google.api.services.bigquery.model.DatasetReference; +import com.google.api.services.bigquery.model.ErrorProto; +import com.google.api.services.bigquery.model.Job; +import com.google.api.services.bigquery.model.JobConfiguration; +import com.google.api.services.bigquery.model.JobConfigurationQuery; +import com.google.api.services.bigquery.model.JobReference; +import com.google.api.services.bigquery.model.JobStatus; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableCell; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Random; + +import javax.annotation.Nullable; + +/** + * Iterates over all rows in a table. + */ +public class BigQueryTableRowIterator implements AutoCloseable { + private static final Logger LOG = LoggerFactory.getLogger(BigQueryTableRowIterator.class); + + @Nullable private TableReference ref; + @Nullable private final String projectId; + @Nullable private TableSchema schema; + private final Bigquery client; + private String pageToken; + private Iterator iteratorOverCurrentBatch; + private TableRow current; + // Set true when the final page is seen from the service. + private boolean lastPage = false; + + // The maximum number of times a BigQuery request will be retried + private static final int MAX_RETRIES = 3; + // Initial wait time for the backoff implementation + private static final Duration INITIAL_BACKOFF_TIME = Duration.standardSeconds(1); + + // After sending a query to BQ service we will be polling the BQ service to check the status with + // following interval to check the status of query execution job + private static final Duration QUERY_COMPLETION_POLL_TIME = Duration.standardSeconds(1); + + private final String query; + // Whether to flatten query results. + private final boolean flattenResults; + // Temporary dataset used to store query results. + private String temporaryDatasetId = null; + // Temporary table used to store query results. 
+ private String temporaryTableId = null; + + private BigQueryTableRowIterator( + @Nullable TableReference ref, @Nullable String query, @Nullable String projectId, + Bigquery client, boolean flattenResults) { + this.ref = ref; + this.query = query; + this.projectId = projectId; + this.client = checkNotNull(client, "client"); + this.flattenResults = flattenResults; + } + + /** + * Constructs a {@code BigQueryTableRowIterator} that reads from the specified table. + */ + public static BigQueryTableRowIterator fromTable(TableReference ref, Bigquery client) { + checkNotNull(ref, "ref"); + checkNotNull(client, "client"); + return new BigQueryTableRowIterator(ref, null, ref.getProjectId(), client, true); + } + + /** + * Constructs a {@code BigQueryTableRowIterator} that reads from the results of executing the + * specified query in the specified project. + */ + public static BigQueryTableRowIterator fromQuery( + String query, String projectId, Bigquery client, @Nullable Boolean flattenResults) { + checkNotNull(query, "query"); + checkNotNull(projectId, "projectId"); + checkNotNull(client, "client"); + return new BigQueryTableRowIterator(null, query, projectId, client, + MoreObjects.firstNonNull(flattenResults, Boolean.TRUE)); + } + + /** + * Opens the table for read. + * @throws IOException on failure + */ + public void open() throws IOException, InterruptedException { + if (query != null) { + ref = executeQueryAndWaitForCompletion(); + } + // Get table schema. + Bigquery.Tables.Get get = + client.tables().get(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + + Table table = + executeWithBackOff( + get, + "Error opening BigQuery table %s of dataset %s : {}", + ref.getTableId(), + ref.getDatasetId()); + schema = table.getSchema(); + } + + public boolean advance() throws IOException, InterruptedException { + while (true) { + if (iteratorOverCurrentBatch != null && iteratorOverCurrentBatch.hasNext()) { + // Embed schema information into the raw row, so that values have an + // associated key. This matches how rows are read when using the + // DataflowPipelineRunner. + current = getTypedTableRow(schema.getFields(), iteratorOverCurrentBatch.next()); + return true; + } + if (lastPage) { + return false; + } + + Bigquery.Tabledata.List list = + client.tabledata().list(ref.getProjectId(), ref.getDatasetId(), ref.getTableId()); + if (pageToken != null) { + list.setPageToken(pageToken); + } + + TableDataList result = + executeWithBackOff( + list, + "Error reading from BigQuery table %s of dataset %s : {}", + ref.getTableId(), + ref.getDatasetId()); + + pageToken = result.getPageToken(); + iteratorOverCurrentBatch = + result.getRows() != null + ? result.getRows().iterator() + : Collections.emptyIterator(); + + // The server may return a page token indefinitely on a zero-length table. + if (pageToken == null || result.getTotalRows() != null && result.getTotalRows() == 0) { + lastPage = true; + } + } + } + + public TableRow getCurrent() { + if (current == null) { + throw new NoSuchElementException(); + } + return current; + } + + /** + * Adjusts a field returned from the BigQuery API to match what we will receive when running + * BigQuery's export-to-GCS and parallel read, which is the efficient parallel implementation + * used for batch jobs executed on the Cloud Dataflow service. + * + *
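Before turning to the type mapping, note that reading with this iterator follows an open/advance/getCurrent cycle. A minimal sketch is shown below; the client and table reference are assumed to be supplied by the caller, and fromQuery(...) can be substituted for fromTable(...) to iterate over query results instead.

    import com.google.api.services.bigquery.Bigquery;
    import com.google.api.services.bigquery.model.TableReference;
    import com.google.cloud.dataflow.sdk.util.BigQueryTableRowIterator;

    public class RowIteratorSketch {
      static void printAllRows(Bigquery client, TableReference ref) throws Exception {
        // try-with-resources calls close(), which also cleans up the temporary
        // dataset and table created when iterating over query results.
        try (BigQueryTableRowIterator iterator = BigQueryTableRowIterator.fromTable(ref, client)) {
          iterator.open();              // fetches the table schema
          while (iterator.advance()) {  // pages through the table
            System.out.println(iterator.getCurrent());
          }
        }
      }
    }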

    The following is the relationship between BigQuery schema and Java types: + * + *

      + *
    • Nulls are {@code null}. + *
    • Repeated fields are {@code List} of objects. + *
    • Record columns are {@link TableRow} objects. + *
    • {@code BOOLEAN} columns are JSON booleans, hence Java {@code Boolean} objects. + *
    • {@code FLOAT} columns are JSON floats, hence Java {@code Double} objects. + *
    • {@code TIMESTAMP} columns are {@code String} objects that are of the format + * {@code yyyy-MM-dd HH:mm:ss[.SSSSSS] UTC}, where the {@code .SSSSSS} has no trailing + * zeros and can be 1 to 6 digits long. + *
    • Every other atomic type is a {@code String}. + *
    + * + *

    Note that integers are encoded as strings to match BigQuery's exported JSON format. + * + *
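For example, a caller reading an INTEGER column from a converted row parses the string itself; in this sketch the column name "count" is an assumed placeholder.

    import com.google.api.services.bigquery.model.TableRow;

    public class IntegerCellSketch {
      // 'row' is assumed to come from BigQueryTableRowIterator#getCurrent().
      static long readCount(TableRow row) {
        Object raw = row.get("count");        // INTEGER cells arrive as Strings, e.g. "42"
        return Long.parseLong((String) raw);  // so callers parse integer columns themselves
      }
    }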

    Finally, values are stored in the {@link TableRow} as {"field name": value} pairs + * and are not accessible through the {@link TableRow#getF} function. + */ + @Nullable private Object getTypedCellValue(TableFieldSchema fieldSchema, Object v) { + if (Data.isNull(v)) { + return null; + } + + if (Objects.equals(fieldSchema.getMode(), "REPEATED")) { + TableFieldSchema elementSchema = fieldSchema.clone().setMode("REQUIRED"); + @SuppressWarnings("unchecked") + List> rawCells = (List>) v; + ImmutableList.Builder values = ImmutableList.builder(); + for (Map element : rawCells) { + values.add(getTypedCellValue(elementSchema, element.get("v"))); + } + return values.build(); + } + + if (fieldSchema.getType().equals("RECORD")) { + @SuppressWarnings("unchecked") + Map typedV = (Map) v; + return getTypedTableRow(fieldSchema.getFields(), typedV); + } + + if (fieldSchema.getType().equals("FLOAT")) { + return Double.parseDouble((String) v); + } + + if (fieldSchema.getType().equals("BOOLEAN")) { + return Boolean.parseBoolean((String) v); + } + + if (fieldSchema.getType().equals("TIMESTAMP")) { + return AvroUtils.formatTimestamp((String) v); + } + + return v; + } + + /** + * A list of the field names that cannot be used in BigQuery tables processed by Dataflow, + * because they are reserved keywords in {@link TableRow}. + */ + // TODO: This limitation is unfortunate. We need to give users a way to use BigQueryIO that does + // not indirect through our broken use of {@link TableRow}. + // See discussion: https://github.com/GoogleCloudPlatform/DataflowJavaSDK/pull/41 + private static final Collection RESERVED_FIELD_NAMES = + ClassInfo.of(TableRow.class).getNames(); + + /** + * Converts a row returned from the BigQuery JSON API as a {@code Map} into a + * Java {@link TableRow} with nested {@link TableCell TableCells}. The {@code Object} values in + * the cells are converted to Java types according to the provided field schemas. + * + *

    See {@link #getTypedCellValue(TableFieldSchema, Object)} for details on how BigQuery + * types are mapped to Java types. + */ + private TableRow getTypedTableRow(List fields, Map rawRow) { + // If rawRow is a TableRow, use it. If not, create a new one. + TableRow row; + List> cells; + if (rawRow instanceof TableRow) { + // Since rawRow is a TableRow it already has TableCell objects in setF. We do not need to do + // any type conversion, but extract the cells for cell-wise processing below. + row = (TableRow) rawRow; + cells = row.getF(); + // Clear the cells from the row, so that row.getF() will return null. This matches the + // behavior of rows produced by the BigQuery export API used on the service. + row.setF(null); + } else { + row = new TableRow(); + + // Since rawRow is a Map we use Map.get("f") instead of TableRow.getF() to + // get its cells. Similarly, when rawCell is a Map instead of a TableCell, + // we will use Map.get("v") instead of TableCell.getV() get its value. + @SuppressWarnings("unchecked") + List> rawCells = + (List>) rawRow.get("f"); + cells = rawCells; + } + + checkState(cells.size() == fields.size(), + "Expected that the row has the same number of cells %s as fields in the schema %s", + cells.size(), fields.size()); + + // Loop through all the fields in the row, normalizing their types with the TableFieldSchema + // and storing the normalized values by field name in the Map that + // underlies the TableRow. + Iterator> cellIt = cells.iterator(); + Iterator fieldIt = fields.iterator(); + while (cellIt.hasNext()) { + Map cell = cellIt.next(); + TableFieldSchema fieldSchema = fieldIt.next(); + + // Convert the object in this cell to the Java type corresponding to its type in the schema. + Object convertedValue = getTypedCellValue(fieldSchema, cell.get("v")); + + String fieldName = fieldSchema.getName(); + checkArgument(!RESERVED_FIELD_NAMES.contains(fieldName), + "BigQueryIO does not support records with columns named %s", fieldName); + + if (convertedValue == null) { + // BigQuery does not include null values when the export operation (to JSON) is used. + // To match that behavior, BigQueryTableRowiterator, and the DirectPipelineRunner, + // intentionally omits columns with null values. + continue; + } + + row.set(fieldName, convertedValue); + } + return row; + } + + // Create a new BigQuery dataset + private void createDataset(String datasetId) throws IOException, InterruptedException { + Dataset dataset = new Dataset(); + DatasetReference reference = new DatasetReference(); + reference.setProjectId(projectId); + reference.setDatasetId(datasetId); + dataset.setDatasetReference(reference); + + String createDatasetError = + "Error when trying to create the temporary dataset " + datasetId + " in project " + + projectId; + executeWithBackOff( + client.datasets().insert(projectId, dataset), createDatasetError + " :{}"); + } + + // Delete the given table that is available in the given dataset. + private void deleteTable(String datasetId, String tableId) + throws IOException, InterruptedException { + executeWithBackOff( + client.tables().delete(projectId, datasetId, tableId), + "Error when trying to delete the temporary table " + datasetId + " in dataset " + datasetId + + " of project " + projectId + ". Manual deletion may be required. Error message : {}"); + } + + // Delete the given dataset. This will fail if the given dataset has any tables. 
+ private void deleteDataset(String datasetId) throws IOException, InterruptedException { + executeWithBackOff( + client.datasets().delete(projectId, datasetId), + "Error when trying to delete the temporary dataset " + datasetId + " in project " + + projectId + ". Manual deletion may be required. Error message : {}"); + } + + /** + * Executes the specified query and returns a reference to the temporary BigQuery table created + * to hold the results. + * + * @throws IOException if the query fails. + */ + private TableReference executeQueryAndWaitForCompletion() + throws IOException, InterruptedException { + // Create a temporary dataset to store results. + // Starting dataset name with an "_" so that it is hidden. + Random rnd = new Random(System.currentTimeMillis()); + temporaryDatasetId = "_dataflow_temporary_dataset_" + rnd.nextInt(1000000); + temporaryTableId = "dataflow_temporary_table_" + rnd.nextInt(1000000); + + createDataset(temporaryDatasetId); + Job job = new Job(); + JobConfiguration config = new JobConfiguration(); + JobConfigurationQuery queryConfig = new JobConfigurationQuery(); + config.setQuery(queryConfig); + job.setConfiguration(config); + queryConfig.setQuery(query); + queryConfig.setAllowLargeResults(true); + queryConfig.setFlattenResults(flattenResults); + + TableReference destinationTable = new TableReference(); + destinationTable.setProjectId(projectId); + destinationTable.setDatasetId(temporaryDatasetId); + destinationTable.setTableId(temporaryTableId); + queryConfig.setDestinationTable(destinationTable); + + Insert insert = client.jobs().insert(projectId, job); + Job queryJob = executeWithBackOff( + insert, "Error when trying to execute the job for query " + query + " :{}"); + JobReference jobId = queryJob.getJobReference(); + + while (true) { + Job pollJob = executeWithBackOff( + client.jobs().get(projectId, jobId.getJobId()), + "Error when trying to get status of the job for query " + query + " :{}"); + JobStatus status = pollJob.getStatus(); + if (status.getState().equals("DONE")) { + // Job is DONE, but did not necessarily succeed. + ErrorProto error = status.getErrorResult(); + if (error == null) { + return pollJob.getConfiguration().getQuery().getDestinationTable(); + } else { + // There will be no temporary table to delete, so null out the reference. + temporaryTableId = null; + throw new IOException("Executing query " + query + " failed: " + error.getMessage()); + } + } + try { + Thread.sleep(QUERY_COMPLETION_POLL_TIME.getMillis()); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + + // Execute a BQ request with exponential backoff and return the result. + // client - BQ request to be executed + // error - Formatted message to log if when a request fails. Takes exception message as a + // formatter parameter. + public static T executeWithBackOff(AbstractGoogleClientRequest client, String error, + Object... errorArgs) throws IOException, InterruptedException { + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backOff = + new AttemptBoundedExponentialBackOff(MAX_RETRIES, INITIAL_BACKOFF_TIME.getMillis()); + + T result = null; + while (true) { + try { + result = client.execute(); + break; + } catch (IOException e) { + LOG.error(String.format(error, errorArgs), e.getMessage()); + if (!BackOffUtils.next(sleeper, backOff)) { + LOG.error( + String.format(error, errorArgs), "Failing after retrying " + MAX_RETRIES + " times."); + throw e; + } + } + } + + return result; + } + + @Override + public void close() { + // Prevent any further requests. 
+ lastPage = true; + + try { + // Deleting temporary table and dataset that gets generated when executing a query. + if (temporaryDatasetId != null) { + if (temporaryTableId != null) { + deleteTable(temporaryDatasetId, temporaryTableId); + } + deleteDataset(temporaryDatasetId); + } + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BitSetCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BitSetCoder.java new file mode 100644 index 000000000000..f3a039ad6649 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BitSetCoder.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder; +import com.google.cloud.dataflow.sdk.coders.CoderException; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.BitSet; + +/** + * Coder for the BitSet used to track child-trigger finished states. + */ +class BitSetCoder extends AtomicCoder { + + private static final BitSetCoder INSTANCE = new BitSetCoder(); + private transient ByteArrayCoder byteArrayCoder = ByteArrayCoder.of(); + + private BitSetCoder() {} + + public static BitSetCoder of() { + return INSTANCE; + } + + @Override + public void encode(BitSet value, OutputStream outStream, Context context) + throws CoderException, IOException { + byteArrayCoder.encodeAndOwn(value.toByteArray(), outStream, context); + } + + @Override + public BitSet decode(InputStream inStream, Context context) + throws CoderException, IOException { + return BitSet.valueOf(byteArrayCoder.decode(inStream, context)); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "BitSetCoder requires its byteArrayCoder to be deterministic.", + byteArrayCoder); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BufferedElementCountingOutputStream.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BufferedElementCountingOutputStream.java new file mode 100644 index 000000000000..e8e693a996c1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BufferedElementCountingOutputStream.java @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder.Context; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Provides an efficient encoding for {@link Iterable}s containing small values by + * buffering up to {@code bufferSize} bytes of data before prefixing the count. + * Note that each element needs to be encoded in a nested context. See + * {@link Context Coder.Context} for more details. + * + *

    To use this stream: + *

    
    + * BufferedElementCountingOutputStream os = ...
    + * for (Element E : elements) {
    + *   os.markElementStart();
    + *   // write an element to os
    + * }
    + * os.finish();
    + * 
    + * + *

The resulting output stream has the form: + *

    + * countA element(0) element(1) ... element(countA - 1)
    + * countB element(0) element(1) ... element(countB - 1)
    + * ...
    + * countX element(0) element(1) ... element(countX - 1)
    + * countY
    + * 
    + * + *
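For instance, assuming three elements whose nested encodings are the byte sequences e0, e1, and e2, and assuming they all fit in one buffer, the stream would be the single block 0x03 e0 e1 e2 followed by the terminating count 0x00 (VarInt encodes such small counts in one byte).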

    To read this stream: + *

    
    + * InputStream is = ...
    + * long count;
    + * do {
    + *   count = VarInt.decodeLong(is);
    + *   for (int i = 0; i < count; ++i) {
    + *     // read an element from is
    + *   }
    + * } while(count > 0);
    + * 
    + * + *
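Putting the two halves together, a self-contained round trip might look like the sketch below; StringUtf8Coder is used purely as an example element coder, and each element is encoded in the nested context as required above.

    import com.google.cloud.dataflow.sdk.coders.Coder.Context;
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import com.google.cloud.dataflow.sdk.util.BufferedElementCountingOutputStream;
    import com.google.cloud.dataflow.sdk.util.VarInt;

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.InputStream;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class CountingStreamSketch {
      public static void main(String[] args) throws Exception {
        StringUtf8Coder coder = StringUtf8Coder.of();

        // Write: mark each element, encode it in the NESTED context, then finish().
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        BufferedElementCountingOutputStream os = new BufferedElementCountingOutputStream(bytes);
        for (String element : Arrays.asList("a", "b", "c")) {
          os.markElementStart();
          coder.encode(element, os, Context.NESTED);
        }
        os.finish();

        // Read: decode count-prefixed blocks until a count of 0 is seen.
        List<String> decoded = new ArrayList<>();
        InputStream is = new ByteArrayInputStream(bytes.toByteArray());
        long count;
        do {
          count = VarInt.decodeLong(is);
          for (long i = 0; i < count; ++i) {
            decoded.add(coder.decode(is, Context.NESTED));
          }
        } while (count > 0);

        System.out.println(decoded); // [a, b, c]
      }
    }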

    The counts are encoded as variable length longs. See {@link VarInt#encode(long, OutputStream)} + * for more details. The end of the iterable is detected by reading a count of 0. + */ +@NotThreadSafe +public class BufferedElementCountingOutputStream extends OutputStream { + public static final int DEFAULT_BUFFER_SIZE = 64 * 1024; + private final ByteBuffer buffer; + private final OutputStream os; + private boolean finished; + private long count; + + /** + * Creates an output stream which encodes the number of elements output to it in a streaming + * manner. + */ + public BufferedElementCountingOutputStream(OutputStream os) { + this(os, DEFAULT_BUFFER_SIZE); + } + + /** + * Creates an output stream which encodes the number of elements output to it in a streaming + * manner with the given {@code bufferSize}. + */ + BufferedElementCountingOutputStream(OutputStream os, int bufferSize) { + this.buffer = ByteBuffer.allocate(bufferSize); + this.os = os; + this.finished = false; + this.count = 0; + } + + /** + * Finishes the encoding by flushing any buffered data, + * and outputting a final count of 0. + */ + public void finish() throws IOException { + if (finished) { + return; + } + flush(); + // Finish the stream by stating that there are 0 elements that follow. + VarInt.encode(0, os); + finished = true; + } + + /** + * Marks that a new element is being output. This allows this output stream + * to use the buffer if it had previously overflowed marking the start of a new + * block of elements. + */ + public void markElementStart() throws IOException { + if (finished) { + throw new IOException("Stream has been finished. Can not add any more elements."); + } + count++; + } + + @Override + public void write(int b) throws IOException { + if (finished) { + throw new IOException("Stream has been finished. Can not write any more data."); + } + if (count == 0) { + os.write(b); + return; + } + + if (buffer.hasRemaining()) { + buffer.put((byte) b); + } else { + outputBuffer(); + os.write(b); + } + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (finished) { + throw new IOException("Stream has been finished. Can not write any more data."); + } + if (count == 0) { + os.write(b, off, len); + return; + } + + if (buffer.remaining() >= len) { + buffer.put(b, off, len); + } else { + outputBuffer(); + os.write(b, off, len); + } + } + + @Override + public void flush() throws IOException { + if (finished) { + return; + } + outputBuffer(); + os.flush(); + } + + @Override + public void close() throws IOException { + finish(); + os.close(); + } + + // Output the buffer if it contains any data. + private void outputBuffer() throws IOException { + if (count > 0) { + VarInt.encode(count, os); + // We are using a heap based buffer and not a direct buffer so it is safe to access + // the underlying array. + os.write(buffer.array(), buffer.arrayOffset(), buffer.position()); + buffer.clear(); + // The buffer has been flushed so we must write to the underlying stream until + // we learn of the next element. We reset the count to zero marking that we should + // not use the buffer. + count = 0; + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudKnownType.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudKnownType.java new file mode 100644 index 000000000000..8b41eb83dacf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudKnownType.java @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.Nullable; + +/** A utility for manipulating well-known cloud types. */ +enum CloudKnownType { + TEXT("http://schema.org/Text", String.class) { + @Override + public T parse(Object value, Class clazz) { + return clazz.cast(value); + } + }, + BOOLEAN("http://schema.org/Boolean", Boolean.class) { + @Override + public T parse(Object value, Class clazz) { + return clazz.cast(value); + } + }, + INTEGER("http://schema.org/Integer", Long.class, Integer.class) { + @Override + public T parse(Object value, Class clazz) { + Object result = null; + if (value.getClass() == clazz) { + result = value; + } else if (clazz == Long.class) { + if (value instanceof Integer) { + result = ((Integer) value).longValue(); + } else if (value instanceof String) { + result = Long.valueOf((String) value); + } + } else if (clazz == Integer.class) { + if (value instanceof Long) { + result = ((Long) value).intValue(); + } else if (value instanceof String) { + result = Integer.valueOf((String) value); + } + } + return clazz.cast(result); + } + }, + FLOAT("http://schema.org/Float", Double.class, Float.class) { + @Override + public T parse(Object value, Class clazz) { + Object result = null; + if (value.getClass() == clazz) { + result = value; + } else if (clazz == Double.class) { + if (value instanceof Float) { + result = ((Float) value).doubleValue(); + } else if (value instanceof String) { + result = Double.valueOf((String) value); + } + } else if (clazz == Float.class) { + if (value instanceof Double) { + result = ((Double) value).floatValue(); + } else if (value instanceof String) { + result = Float.valueOf((String) value); + } + } + return clazz.cast(result); + } + }; + + private final String uri; + private final Class[] classes; + + private CloudKnownType(String uri, Class... 
classes) { + this.uri = uri; + this.classes = classes; + } + + public String getUri() { + return uri; + } + + public abstract T parse(Object value, Class clazz); + + public Class defaultClass() { + return classes[0]; + } + + private static final Map typesByUri = + Collections.unmodifiableMap(buildTypesByUri()); + + private static Map buildTypesByUri() { + Map result = new HashMap<>(); + for (CloudKnownType ty : CloudKnownType.values()) { + result.put(ty.getUri(), ty); + } + return result; + } + + @Nullable + public static CloudKnownType forUri(@Nullable String uri) { + if (uri == null) { + return null; + } + return typesByUri.get(uri); + } + + private static final Map, CloudKnownType> typesByClass = + Collections.unmodifiableMap(buildTypesByClass()); + + private static Map, CloudKnownType> buildTypesByClass() { + Map, CloudKnownType> result = new HashMap<>(); + for (CloudKnownType ty : CloudKnownType.values()) { + for (Class clazz : ty.classes) { + result.put(clazz, ty); + } + } + return result; + } + + @Nullable + public static CloudKnownType forClass(Class clazz) { + return typesByClass.get(clazz); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudObject.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudObject.java new file mode 100644 index 000000000000..8c704bf6d96c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CloudObject.java @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.api.client.util.Preconditions.checkNotNull; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; + +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A representation of an arbitrary Java object to be instantiated by Dataflow + * workers. + * + *

Typically, an object to be written by the SDK to the Dataflow service will + * implement a method (conventionally called {@code asCloudObject()}) that returns a + * {@code CloudObject} representing the object in the protocol. Once the + * {@code CloudObject} is constructed, the method should explicitly add any + * additional properties needed during deserialization, representing + * child objects by building additional {@code CloudObject}s. + */ +public final class CloudObject extends GenericJson { + /** + * Constructs a {@code CloudObject} by copying the supplied serialized object + * spec, which must represent an SDK object serialized for transport via the + * Dataflow API. + * + *
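A hedged sketch of the overall pattern is shown below; the class name and property names are invented for illustration, and the worker-side rebuild simply reuses fromSpec on the same map.

    import com.google.cloud.dataflow.sdk.util.CloudObject;

    public class CloudObjectSketch {
      // A hypothetical component describing itself to the service.
      static CloudObject asCloudObject() {
        CloudObject spec = CloudObject.forClassName(
            "com.example.MyCustomCoder");  // assumed worker-side class name
        spec.put("delimiter", ",");        // additional instance properties
        spec.put("charset", "UTF-8");
        return spec;
      }

      public static void main(String[] args) {
        CloudObject spec = asCloudObject();
        // On the worker side the same map would be rebuilt via fromSpec(...).
        CloudObject rebuilt = CloudObject.fromSpec(spec);
        System.out.println(rebuilt.getClassName()); // com.example.MyCustomCoder
      }
    }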

    The most common use of this method is during deserialization on the worker, + * where it's used as a binding type during instance construction. + * + * @param spec supplies the serialized form of the object as a nested map + * @throws RuntimeException if the supplied map does not represent an SDK object + */ + public static CloudObject fromSpec(Map spec) { + CloudObject result = new CloudObject(); + result.putAll(spec); + if (result.className == null) { + throw new RuntimeException("Unable to create an SDK object from " + spec + + ": Object class not specified (missing \"" + + PropertyNames.OBJECT_TYPE_NAME + "\" field)"); + } + return result; + } + + /** + * Constructs a {@code CloudObject} to be used for serializing an instance of + * the supplied class for transport via the Dataflow API. The instance + * parameters to be serialized must be supplied explicitly after the + * {@code CloudObject} is created, by using {@link CloudObject#put}. + * + * @param cls the class to use when deserializing the object on the worker + */ + public static CloudObject forClass(Class cls) { + CloudObject result = new CloudObject(); + result.className = checkNotNull(cls).getName(); + return result; + } + + /** + * Constructs a {@code CloudObject} to be used for serializing data to be + * deserialized using the supplied class name the supplied class name for + * transport via the Dataflow API. The instance parameters to be serialized + * must be supplied explicitly after the {@code CloudObject} is created, by + * using {@link CloudObject#put}. + * + * @param className the class to use when deserializing the object on the worker + */ + public static CloudObject forClassName(String className) { + CloudObject result = new CloudObject(); + result.className = checkNotNull(className); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forString(String value) { + CloudObject result = forClassName(CloudKnownType.TEXT.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forBoolean(Boolean value) { + CloudObject result = forClassName(CloudKnownType.BOOLEAN.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forInteger(Long value) { + CloudObject result = forClassName(CloudKnownType.INTEGER.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forInteger(Integer value) { + CloudObject result = forClassName(CloudKnownType.INTEGER.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. + * @param value the scalar value to represent. + */ + public static CloudObject forFloat(Float value) { + CloudObject result = forClassName(CloudKnownType.FLOAT.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value. 
+ * @param value the scalar value to represent. + */ + public static CloudObject forFloat(Double value) { + CloudObject result = forClassName(CloudKnownType.FLOAT.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + /** + * Constructs a {@code CloudObject} representing the given value of a + * well-known cloud object type. + * @param value the scalar value to represent. + * @throws RuntimeException if the value does not have a + * {@link CloudKnownType} mapping + */ + public static CloudObject forKnownType(Object value) { + @Nullable CloudKnownType ty = CloudKnownType.forClass(value.getClass()); + if (ty == null) { + throw new RuntimeException("Unable to represent value via the Dataflow API: " + value); + } + CloudObject result = forClassName(ty.getUri()); + result.put(PropertyNames.SCALAR_FIELD_NAME, value); + return result; + } + + @Key(PropertyNames.OBJECT_TYPE_NAME) + private String className; + + private CloudObject() {} + + /** + * Gets the name of the Java class that this CloudObject represents. + */ + public String getClassName() { + return className; + } + + @Override + public CloudObject clone() { + return (CloudObject) super.clone(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CoderUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CoderUtils.java new file mode 100644 index 000000000000..771bf09eb346 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CoderUtils.java @@ -0,0 +1,327 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addList; + +import com.google.api.client.util.Base64; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoderBase; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoderBase; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Throwables; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonTypeInfo.As; +import com.fasterxml.jackson.annotation.JsonTypeInfo.Id; +import com.fasterxml.jackson.databind.DatabindContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver; +import com.fasterxml.jackson.databind.jsontype.impl.TypeIdResolverBase; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.type.TypeFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.ref.SoftReference; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.TypeVariable; + +/** + * Utilities for working with Coders. + */ +public final class CoderUtils { + private CoderUtils() {} // Non-instantiable + + /** + * Coder class-name alias for a key-value type. + */ + public static final String KIND_PAIR = "kind:pair"; + + /** + * Coder class-name alias for a stream type. + */ + public static final String KIND_STREAM = "kind:stream"; + + private static ThreadLocal> threadLocalOutputStream + = new ThreadLocal<>(); + + /** + * If true, a call to {@code encodeToByteArray} is already on the call stack. + */ + private static ThreadLocal threadLocalOutputStreamInUse = new ThreadLocal() { + @Override + protected Boolean initialValue() { + return false; + } + }; + + /** + * Encodes the given value using the specified Coder, and returns + * the encoded bytes. + * + *

    This function is not reentrant; it should not be called from methods of the provided + * {@link Coder}. + */ + public static byte[] encodeToByteArray(Coder coder, T value) throws CoderException { + return encodeToByteArray(coder, value, Coder.Context.OUTER); + } + + public static byte[] encodeToByteArray(Coder coder, T value, Coder.Context context) + throws CoderException { + if (threadLocalOutputStreamInUse.get()) { + // encodeToByteArray() is called recursively and the thread local stream is in use, + // allocating a new one. + ByteArrayOutputStream stream = new ExposedByteArrayOutputStream(); + encodeToSafeStream(coder, value, stream, context); + return stream.toByteArray(); + } else { + threadLocalOutputStreamInUse.set(true); + try { + ByteArrayOutputStream stream = getThreadLocalOutputStream(); + encodeToSafeStream(coder, value, stream, context); + return stream.toByteArray(); + } finally { + threadLocalOutputStreamInUse.set(false); + } + } + } + + /** + * Encodes {@code value} to the given {@code stream}, which should be a stream that never throws + * {@code IOException}, such as {@code ByteArrayOutputStream} or + * {@link ExposedByteArrayOutputStream}. + */ + private static void encodeToSafeStream( + Coder coder, T value, OutputStream stream, Coder.Context context) throws CoderException { + try { + coder.encode(value, new UnownedOutputStream(stream), context); + } catch (IOException exn) { + Throwables.propagateIfPossible(exn, CoderException.class); + throw new IllegalArgumentException( + "Forbidden IOException when writing to OutputStream", exn); + } + } + + /** + * Decodes the given bytes using the specified Coder, and returns + * the resulting decoded value. + */ + public static T decodeFromByteArray(Coder coder, byte[] encodedValue) + throws CoderException { + return decodeFromByteArray(coder, encodedValue, Coder.Context.OUTER); + } + + public static T decodeFromByteArray( + Coder coder, byte[] encodedValue, Coder.Context context) throws CoderException { + try (ExposedByteArrayInputStream stream = new ExposedByteArrayInputStream(encodedValue)) { + T result = decodeFromSafeStream(coder, stream, context); + if (stream.available() != 0) { + throw new CoderException( + stream.available() + " unexpected extra bytes after decoding " + result); + } + return result; + } + } + + /** + * Decodes a value from the given {@code stream}, which should be a stream that never throws + * {@code IOException}, such as {@code ByteArrayInputStream} or + * {@link ExposedByteArrayInputStream}. + */ + private static T decodeFromSafeStream( + Coder coder, InputStream stream, Coder.Context context) throws CoderException { + try { + return coder.decode(new UnownedInputStream(stream), context); + } catch (IOException exn) { + Throwables.propagateIfPossible(exn, CoderException.class); + throw new IllegalArgumentException( + "Forbidden IOException when reading from InputStream", exn); + } + } + + private static ByteArrayOutputStream getThreadLocalOutputStream() { + SoftReference refStream = threadLocalOutputStream.get(); + ExposedByteArrayOutputStream stream = refStream == null ? null : refStream.get(); + if (stream == null) { + stream = new ExposedByteArrayOutputStream(); + threadLocalOutputStream.set(new SoftReference<>(stream)); + } + stream.reset(); + return stream; + } + + /** + * Clones the given value by encoding and then decoding it with the specified Coder. + * + *
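A minimal round trip through these helpers, using the SDK's StringUtf8Coder as an example coder, looks like:

    import com.google.cloud.dataflow.sdk.coders.CoderException;
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import com.google.cloud.dataflow.sdk.util.CoderUtils;

    public class CoderUtilsSketch {
      public static void main(String[] args) throws CoderException {
        StringUtf8Coder coder = StringUtf8Coder.of();

        // Byte-array round trip (outer context by default).
        byte[] bytes = CoderUtils.encodeToByteArray(coder, "hello");
        String decoded = CoderUtils.decodeFromByteArray(coder, bytes);

        // Base64 round trip, e.g. for embedding encoded values in JSON specs.
        String base64 = CoderUtils.encodeToBase64(coder, "hello");
        String fromBase64 = CoderUtils.decodeFromBase64(coder, base64);

        System.out.println(decoded + " " + fromBase64); // hello hello
      }
    }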

    This function is not reentrant; it should not be called from methods of the provided + * {@link Coder}. + */ + public static T clone(Coder coder, T value) throws CoderException { + return decodeFromByteArray(coder, encodeToByteArray(coder, value, Coder.Context.OUTER)); + } + + /** + * Encodes the given value using the specified Coder, and returns the Base64 encoding of the + * encoded bytes. + * + * @throws CoderException if there are errors during encoding. + */ + public static String encodeToBase64(Coder coder, T value) + throws CoderException { + byte[] rawValue = encodeToByteArray(coder, value); + return Base64.encodeBase64URLSafeString(rawValue); + } + + /** + * Parses a value from a base64-encoded String using the given coder. + */ + public static T decodeFromBase64(Coder coder, String encodedValue) throws CoderException { + return decodeFromSafeStream( + coder, new ByteArrayInputStream(Base64.decodeBase64(encodedValue)), Coder.Context.OUTER); + } + + /** + * If {@code coderType} is a subclass of {@code Coder} for a specific + * type {@code T}, returns {@code T.class}. + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + public static TypeDescriptor getCodedType(TypeDescriptor coderDescriptor) { + ParameterizedType coderType = + (ParameterizedType) coderDescriptor.getSupertype(Coder.class).getType(); + TypeDescriptor codedType = TypeDescriptor.of(coderType.getActualTypeArguments()[0]); + return codedType; + } + + public static CloudObject makeCloudEncoding( + String type, + CloudObject... componentSpecs) { + CloudObject encoding = CloudObject.forClassName(type); + if (componentSpecs.length > 0) { + addList(encoding, PropertyNames.COMPONENT_ENCODINGS, componentSpecs); + } + return encoding; + } + + /** + * A {@link com.fasterxml.jackson.databind.Module} that adds the type + * resolver needed for Coder definitions created by the Dataflow service. + */ + static final class Jackson2Module extends SimpleModule { + /** + * The Coder custom type resolver. + * + *

    This resolver resolves coders. If the Coder ID is a particular + * well-known identifier supplied by the Dataflow service, it's replaced + * with the corresponding class. All other Coder instances are resolved + * by class name, using the package com.google.cloud.dataflow.sdk.coders + * if there are no "."s in the ID. + */ + private static final class Resolver extends TypeIdResolverBase { + @SuppressWarnings("unused") // Used via @JsonTypeIdResolver annotation on Mixin + public Resolver() { + super(TypeFactory.defaultInstance().constructType(Coder.class), + TypeFactory.defaultInstance()); + } + + @Deprecated + @Override + public JavaType typeFromId(String id) { + return typeFromId(null, id); + } + + @Override + public JavaType typeFromId(DatabindContext context, String id) { + Class clazz = getClassForId(id); + if (clazz == KvCoder.class) { + clazz = KvCoderBase.class; + } + if (clazz == MapCoder.class) { + clazz = MapCoderBase.class; + } + @SuppressWarnings("rawtypes") + TypeVariable[] tvs = clazz.getTypeParameters(); + JavaType[] types = new JavaType[tvs.length]; + for (int lupe = 0; lupe < tvs.length; lupe++) { + types[lupe] = TypeFactory.unknownType(); + } + return _typeFactory.constructSimpleType(clazz, types); + } + + private Class getClassForId(String id) { + try { + if (id.contains(".")) { + return Class.forName(id); + } + + if (id.equals(KIND_STREAM)) { + return IterableCoder.class; + } else if (id.equals(KIND_PAIR)) { + return KvCoder.class; + } + + // Otherwise, see if the ID is the name of a class in + // com.google.cloud.dataflow.sdk.coders. We do this via creating + // the class object so that class loaders have a chance to get + // involved -- and since we need the class object anyway. + return Class.forName("com.google.cloud.dataflow.sdk.coders." + id); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Unable to convert coder ID " + id + " to class", e); + } + } + + @Override + public String idFromValueAndType(Object o, Class clazz) { + return clazz.getName(); + } + + @Override + public String idFromValue(Object o) { + return o.getClass().getName(); + } + + @Override + public JsonTypeInfo.Id getMechanism() { + return JsonTypeInfo.Id.CUSTOM; + } + } + + /** + * The mixin class defining how Coders are handled by the deserialization + * {@link ObjectMapper}. + * + *

    This is done via a mixin so that this resolver is only used + * during deserialization requested by the Dataflow SDK. + */ + @JsonTypeIdResolver(Resolver.class) + @JsonTypeInfo(use = Id.CUSTOM, include = As.PROPERTY, property = PropertyNames.OBJECT_TYPE_NAME) + private static final class Mixin {} + + public Jackson2Module() { + super("DataflowCoders"); + setMixInAnnotation(Coder.class, Mixin.class); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java new file mode 100644 index 000000000000..6f2b89b84d01 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.state.StateContext; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +/** + * Factory that produces {@code Combine.Context} based on different inputs. + */ +public class CombineContextFactory { + + private static final Context NULL_CONTEXT = new Context() { + @Override + public PipelineOptions getPipelineOptions() { + throw new IllegalArgumentException("cannot call getPipelineOptions() in a null context"); + } + + @Override + public T sideInput(PCollectionView view) { + throw new IllegalArgumentException("cannot call sideInput() in a null context"); + } + }; + + /** + * Returns a fake {@code Combine.Context} for tests. + */ + public static Context nullContext() { + return NULL_CONTEXT; + } + + /** + * Returns a {@code Combine.Context} that wraps a {@code DoFn.ProcessContext}. + */ + public static Context createFromProcessContext(final DoFn.ProcessContext c) { + return new Context() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public T sideInput(PCollectionView view) { + return c.sideInput(view); + } + }; + } + + /** + * Returns a {@code Combine.Context} that wraps a {@link StateContext}. + */ + public static Context createFromStateContext(final StateContext c) { + return new Context() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public T sideInput(PCollectionView view) { + return c.sideInput(view); + } + }; + } + + /** + * Returns a {@code Combine.Context} from {@code PipelineOptions}, {@code SideInputReader}, + * and the main input window. 
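+   * <p>A sketch of how a runner might use this (here {@code options}, {@code reader},
+   * {@code window} and {@code view} are placeholders for values the caller already has,
+   * with {@code view} standing for a {@code PCollectionView<Integer>}):
+   * <pre>{@code
+   * Context context = CombineContextFactory.createFromComponents(options, reader, window);
+   * Integer sideValue = context.sideInput(view);
+   * }</pre>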
+ */ + public static Context createFromComponents(final PipelineOptions options, + final SideInputReader sideInputReader, final BoundedWindow mainInputWindow) { + return new Context() { + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public T sideInput(PCollectionView view) { + if (!sideInputReader.contains(view)) { + throw new IllegalArgumentException("calling sideInput() with unknown view"); + } + + BoundedWindow sideInputWindow = + view.getWindowingStrategyInternal().getWindowFn().getSideInputWindow(mainInputWindow); + return sideInputReader.get(view, sideInputWindow); + } + }; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java new file mode 100644 index 000000000000..6201e6e7bb1d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java @@ -0,0 +1,97 @@ + +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.state.StateContext; + +import java.io.IOException; +import java.io.NotSerializableException; +import java.io.ObjectOutputStream; + +/** + * Static utility methods that create combine function instances. + */ +public class CombineFnUtil { + /** + * Returns the partial application of the {@link KeyedCombineFnWithContext} to a specific + * context to produce a {@link KeyedCombineFn}. + * + *
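+   * <p>For example, inside a method that is itself generic in {@code K}, {@code InputT},
+   * {@code AccumT} and {@code OutputT} (a sketch; {@code fnWithContext} and
+   * {@code stateContext} are assumed to be in scope):
+   * <pre>{@code
+   * KeyedCombineFn<K, InputT, AccumT, OutputT> boundFn =
+   *     CombineFnUtil.bindContext(fnWithContext, stateContext);
+   * }</pre>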

    The returned {@link KeyedCombineFn} cannot be serialized. + */ + public static KeyedCombineFn + bindContext( + KeyedCombineFnWithContext combineFn, + StateContext stateContext) { + Context context = CombineContextFactory.createFromStateContext(stateContext); + return new NonSerializableBoundedKeyedCombineFn<>(combineFn, context); + } + + private static class NonSerializableBoundedKeyedCombineFn + extends KeyedCombineFn { + private final KeyedCombineFnWithContext combineFn; + private final Context context; + + private NonSerializableBoundedKeyedCombineFn( + KeyedCombineFnWithContext combineFn, + Context context) { + this.combineFn = combineFn; + this.context = context; + } + @Override + public AccumT createAccumulator(K key) { + return combineFn.createAccumulator(key, context); + } + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value) { + return combineFn.addInput(key, accumulator, value, context); + } + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators) { + return combineFn.mergeAccumulators(key, accumulators, context); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator) { + return combineFn.extractOutput(key, accumulator, context); + } + @Override + public AccumT compact(K key, AccumT accumulator) { + return combineFn.compact(key, accumulator, context); + } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return combineFn.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return combineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder); + } + + private void writeObject(@SuppressWarnings("unused") ObjectOutputStream out) + throws IOException { + throw new NotSerializableException( + "Cannot serialize the CombineFn resulting from CombineFnUtil.bindContext."); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CounterAggregator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CounterAggregator.java new file mode 100644 index 000000000000..824825f41fe1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CounterAggregator.java @@ -0,0 +1,96 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterProvider; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; + +/** + * An implementation of the {@code Aggregator} interface that uses a + * {@link Counter} as the underlying representation. Supports {@link CombineFn}s + * from the {@link Sum}, {@link Min} and {@link Max} classes. + * + * @param the type of input values + * @param the type of accumulator values + * @param the type of output value + */ +public class CounterAggregator implements Aggregator { + + private final Counter counter; + private final CombineFn combiner; + + /** + * Constructs a new aggregator with the given name and aggregation logic + * specified in the CombineFn argument. The underlying counter is + * automatically added into the provided CounterSet. + * + *
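+   * <p>For example, a sketch that counts elements with a {@link Sum} combiner (the
+   * counter set and its mutator are assumed to come from the surrounding worker code):
+   * <pre>{@code
+   * Aggregator<Long, Long> elementCount = new CounterAggregator<>(
+   *     "elementCount", new Sum.SumLongFn(), counterSet.getAddCounterMutator());
+   * elementCount.addValue(1L);
+   * }</pre>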

    If a counter with the same name already exists, it will be reused, as + * long as it has the same type. + */ + public CounterAggregator(String name, CombineFn combiner, + CounterSet.AddCounterMutator addCounterMutator) { + // Safe contravariant cast + this(constructCounter(name, combiner), addCounterMutator, + (CombineFn) combiner); + } + + private CounterAggregator(Counter counter, + CounterSet.AddCounterMutator addCounterMutator, + CombineFn combiner) { + try { + this.counter = addCounterMutator.addCounter(counter); + } catch (IllegalArgumentException ex) { + throw new IllegalArgumentException( + "aggregator's name collides with an existing aggregator " + + "or system-provided counter of an incompatible type"); + } + this.combiner = combiner; + } + + private static Counter constructCounter(String name, + CombineFn combiner) { + if (combiner instanceof CounterProvider) { + @SuppressWarnings("unchecked") + CounterProvider counterProvider = (CounterProvider) combiner; + return counterProvider.getCounter(name); + } else { + throw new IllegalArgumentException("unsupported combiner in Aggregator: " + + combiner.getClass().getName()); + } + } + + @Override + public void addValue(InputT value) { + counter.addValue(value); + } + + @Override + public String getName() { + return counter.getName(); + } + + @Override + public CombineFn getCombineFn() { + return combiner; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CredentialFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CredentialFactory.java new file mode 100644 index 000000000000..4913a1e66ddb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CredentialFactory.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.Credential; + +import java.io.IOException; +import java.security.GeneralSecurityException; + +/** + * Construct an oauth credential to be used by the SDK and the SDK workers. + */ +public interface CredentialFactory { + public Credential getCredential() throws IOException, GeneralSecurityException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Credentials.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Credentials.java new file mode 100644 index 000000000000..671b131554ea --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Credentials.java @@ -0,0 +1,192 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.extensions.java6.auth.oauth2.AbstractPromptReceiver; +import com.google.api.client.extensions.java6.auth.oauth2.AuthorizationCodeInstalledApp; +import com.google.api.client.googleapis.auth.oauth2.GoogleAuthorizationCodeFlow; +import com.google.api.client.googleapis.auth.oauth2.GoogleClientSecrets; +import com.google.api.client.googleapis.auth.oauth2.GoogleCredential; +import com.google.api.client.googleapis.auth.oauth2.GoogleOAuthConstants; +import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.util.store.FileDataStoreFactory; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.common.base.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * Provides support for loading credentials. + */ +public class Credentials { + + private static final Logger LOG = LoggerFactory.getLogger(Credentials.class); + + /** + * OAuth 2.0 scopes used by a local worker (not on GCE). + * The scope cloud-platform provides access to all Cloud Platform resources. + * cloud-platform isn't sufficient yet for talking to datastore so we request + * those resources separately. + * + *

    Note that trusted scope relationships don't apply to OAuth tokens, so for + * services we access directly (GCS) as opposed to through the backend + * (BigQuery, GCE), we need to explicitly request that scope. + */ + private static final List SCOPES = Arrays.asList( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/devstorage.full_control", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/datastore"); + + private static class PromptReceiver extends AbstractPromptReceiver { + @Override + public String getRedirectUri() { + return GoogleOAuthConstants.OOB_REDIRECT_URI; + } + } + + /** + * Initializes OAuth2 credentials. + * + *
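+   * <p>Typical use, shown as a sketch (in practice the options would be populated from
+   * command-line arguments rather than created empty as here):
+   * <pre>{@code
+   * GcpOptions options = PipelineOptionsFactory.create().as(GcpOptions.class);
+   * Credential credential = Credentials.getCredential(options);
+   * }</pre>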

    This can use 3 different mechanisms for obtaining a credential:
+   * <ol>
+   *   <li>It can fetch the application default credentials.</li>
+   *   <li>The user can specify a client secrets file and go through the OAuth2
+   *       webflow. The credential will then be cached in the user's home
+   *       directory for reuse. Provide the property "secrets_file" to use this
+   *       mechanism.</li>
+   *   <li>The user can specify a file containing a service account.
+   *       Provide the properties "service_account_keyfile" and
+   *       "service_account_name" to use this mechanism.</li>
+   * </ol>
    + * The default mechanism is to use the + * + * application default credentials. The other options can be used by providing the + * corresponding properties. + */ + public static Credential getCredential(GcpOptions options) + throws IOException, GeneralSecurityException { + String keyFile = options.getServiceAccountKeyfile(); + String accountName = options.getServiceAccountName(); + + if (keyFile != null && accountName != null) { + try { + return getCredentialFromFile(keyFile, accountName, SCOPES); + } catch (GeneralSecurityException e) { + throw new IOException("Unable to obtain credentials from file", e); + } + } + + if (options.getSecretsFile() != null) { + return getCredentialFromClientSecrets(options, SCOPES); + } + + try { + return GoogleCredential.getApplicationDefault().createScoped(SCOPES); + } catch (IOException e) { + throw new RuntimeException("Unable to get application default credentials. Please see " + + "https://developers.google.com/accounts/docs/application-default-credentials " + + "for details on how to specify credentials. This version of the SDK is " + + "dependent on the gcloud core component version 2015.02.05 or newer to " + + "be able to get credentials from the currently authorized user via gcloud auth.", e); + } + } + + /** + * Loads OAuth2 credential from a local file. + */ + private static Credential getCredentialFromFile( + String keyFile, String accountId, Collection scopes) + throws IOException, GeneralSecurityException { + GoogleCredential credential = new GoogleCredential.Builder() + .setTransport(Transport.getTransport()) + .setJsonFactory(Transport.getJsonFactory()) + .setServiceAccountId(accountId) + .setServiceAccountScopes(scopes) + .setServiceAccountPrivateKeyFromP12File(new File(keyFile)) + .build(); + + LOG.info("Created credential from file {}", keyFile); + return credential; + } + + /** + * Loads OAuth2 credential from client secrets, which may require an + * interactive authorization prompt. + */ + private static Credential getCredentialFromClientSecrets( + GcpOptions options, Collection scopes) + throws IOException, GeneralSecurityException { + String clientSecretsFile = options.getSecretsFile(); + + Preconditions.checkArgument(clientSecretsFile != null); + HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport(); + + JsonFactory jsonFactory = JacksonFactory.getDefaultInstance(); + GoogleClientSecrets clientSecrets; + + try { + clientSecrets = GoogleClientSecrets.load(jsonFactory, + new FileReader(clientSecretsFile)); + } catch (IOException e) { + throw new RuntimeException( + "Could not read the client secrets from file: " + clientSecretsFile, + e); + } + + FileDataStoreFactory dataStoreFactory = + new FileDataStoreFactory(new java.io.File(options.getCredentialDir())); + + GoogleAuthorizationCodeFlow flow = new GoogleAuthorizationCodeFlow.Builder( + httpTransport, jsonFactory, clientSecrets, scopes) + .setDataStoreFactory(dataStoreFactory) + .setTokenServerUrl(new GenericUrl(options.getTokenServerUrl())) + .setAuthorizationServerEncodedUrl(options.getAuthorizationServerEncodedUrl()) + .build(); + + // The credentialId identifies the credential if we're using a persistent + // credential store. 
+ Credential credential = + new AuthorizationCodeInstalledApp(flow, new PromptReceiver()) + .authorize(options.getCredentialId()); + + LOG.info("Got credential from client secret"); + return credential; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowPathValidator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowPathValidator.java new file mode 100644 index 000000000000..cfb120cff35c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowPathValidator.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.common.base.Preconditions; + +import java.io.IOException; + +/** + * GCP implementation of {@link PathValidator}. Only GCS paths are allowed. + */ +public class DataflowPathValidator implements PathValidator { + + private DataflowPipelineOptions dataflowOptions; + + DataflowPathValidator(DataflowPipelineOptions options) { + this.dataflowOptions = options; + } + + public static DataflowPathValidator fromOptions(PipelineOptions options) { + return new DataflowPathValidator(options.as(DataflowPipelineOptions.class)); + } + + /** + * Validates the the input GCS path is accessible and that the path + * is well formed. + */ + @Override + public String validateInputFilePatternSupported(String filepattern) { + GcsPath gcsPath = getGcsPath(filepattern); + Preconditions.checkArgument( + dataflowOptions.getGcsUtil().isGcsPatternSupported(gcsPath.getObject())); + String returnValue = verifyPath(filepattern); + verifyPathIsAccessible(filepattern, "Could not find file %s"); + return returnValue; + } + + /** + * Validates the the output GCS path is accessible and that the path + * is well formed. 
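+   * <p>For example (a sketch; the bucket and object names are placeholders):
+   * <pre>{@code
+   * DataflowPathValidator validator = DataflowPathValidator.fromOptions(options);
+   * String resource =
+   *     validator.validateOutputFilePrefixSupported("gs://my-bucket/output/part");
+   * }</pre>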
+ */ + @Override + public String validateOutputFilePrefixSupported(String filePrefix) { + String returnValue = verifyPath(filePrefix); + verifyPathIsAccessible(filePrefix, "Output path does not exist or is not writeable: %s"); + return returnValue; + } + + @Override + public String verifyPath(String path) { + GcsPath gcsPath = getGcsPath(path); + Preconditions.checkArgument(gcsPath.isAbsolute(), + "Must provide absolute paths for Dataflow"); + Preconditions.checkArgument(!gcsPath.getObject().contains("//"), + "Dataflow Service does not allow objects with consecutive slashes"); + return gcsPath.toResourceName(); + } + + private void verifyPathIsAccessible(String path, String errorMessage) { + GcsPath gcsPath = getGcsPath(path); + try { + Preconditions.checkArgument(dataflowOptions.getGcsUtil().bucketExists(gcsPath), + errorMessage, path); + } catch (IOException e) { + throw new RuntimeException( + String.format("Unable to verify that GCS bucket gs://%s exists.", gcsPath.getBucket()), + e); + } + } + + private GcsPath getGcsPath(String path) { + try { + return GcsPath.fromUri(path); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(String.format( + "%s expected a valid 'gs://' path but was given '%s'", + dataflowOptions.getRunner().getSimpleName(), path), e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowReleaseInfo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowReleaseInfo.java new file mode 100644 index 000000000000..39b30054f123 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DataflowReleaseInfo.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; + +/** + * Utilities for working with the Dataflow distribution. + */ +public final class DataflowReleaseInfo extends GenericJson { + private static final Logger LOG = LoggerFactory.getLogger(DataflowReleaseInfo.class); + + private static final String DATAFLOW_PROPERTIES_PATH = + "/com/google/cloud/dataflow/sdk/sdk.properties"; + + private static class LazyInit { + private static final DataflowReleaseInfo INSTANCE = + new DataflowReleaseInfo(DATAFLOW_PROPERTIES_PATH); + } + + /** + * Returns an instance of DataflowReleaseInfo. + */ + public static DataflowReleaseInfo getReleaseInfo() { + return LazyInit.INSTANCE; + } + + @Key private String name = "Google Cloud Dataflow Java SDK"; + @Key private String version = "Unknown"; + + /** Provides the SDK name. */ + public String getName() { + return name; + } + + /** Provides the SDK version. 
*/ + public String getVersion() { + return version; + } + + private DataflowReleaseInfo(String resourcePath) { + Properties properties = new Properties(); + + InputStream in = DataflowReleaseInfo.class.getResourceAsStream( + DATAFLOW_PROPERTIES_PATH); + if (in == null) { + LOG.warn("Dataflow properties resource not found: {}", resourcePath); + return; + } + + try { + properties.load(in); + } catch (IOException e) { + LOG.warn("Error loading Dataflow properties resource: ", e); + } + + for (String name : properties.stringPropertyNames()) { + if (name.equals("name")) { + // We don't allow the properties to override the SDK name. + continue; + } + put(name, properties.getProperty(name)); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectModeExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectModeExecutionContext.java new file mode 100644 index 000000000000..6e970535dbdf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectModeExecutionContext.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.ValueWithMetadata; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; + +/** + * {@link ExecutionContext} for use in direct mode. + */ +public class DirectModeExecutionContext + extends BaseExecutionContext { + + private Object key; + private List> output = Lists.newArrayList(); + private Map, List>> sideOutputs = Maps.newHashMap(); + + protected DirectModeExecutionContext() {} + + public static DirectModeExecutionContext create() { + return new DirectModeExecutionContext(); + } + + @Override + protected StepContext createStepContext( + String stepName, String transformName, StateSampler stateSampler) { + return new StepContext(this, stepName, transformName); + } + + public Object getKey() { + return key; + } + + public void setKey(Object newKey) { + // The direct mode runner may reorder elements, so we need to keep + // around the state used for each key. 
+ for (ExecutionContext.StepContext stepContext : getAllStepContexts()) { + ((StepContext) stepContext).switchKey(newKey); + } + key = newKey; + } + + @Override + public void noteOutput(WindowedValue outputElem) { + output.add(ValueWithMetadata.of(outputElem).withKey(getKey())); + } + + @Override + public void noteSideOutput(TupleTag tag, WindowedValue outputElem) { + List> output = sideOutputs.get(tag); + if (output == null) { + output = Lists.newArrayList(); + sideOutputs.put(tag, output); + } + output.add(ValueWithMetadata.of(outputElem).withKey(getKey())); + } + + public List> getOutput(@SuppressWarnings("unused") TupleTag tag) { + @SuppressWarnings({"unchecked", "rawtypes"}) // Cast not expressible without rawtypes + List> typedOutput = (List) output; + return typedOutput; + } + + public List> getSideOutput(TupleTag tag) { + if (sideOutputs.containsKey(tag)) { + @SuppressWarnings({"unchecked", "rawtypes"}) // Cast not expressible without rawtypes + List> typedOutput = (List) sideOutputs.get(tag); + return typedOutput; + } else { + return Lists.newArrayList(); + } + } + + /** + * {@link ExecutionContext.StepContext} used in direct mode. + */ + public static class StepContext extends BaseExecutionContext.StepContext { + + /** A map from each key to the state associated with it. */ + private final Map> stateInternals = Maps.newHashMap(); + private InMemoryStateInternals currentStateInternals = null; + + private StepContext(ExecutionContext executionContext, String stepName, String transformName) { + super(executionContext, stepName, transformName); + switchKey(null); + } + + public void switchKey(Object newKey) { + currentStateInternals = stateInternals.get(newKey); + if (currentStateInternals == null) { + currentStateInternals = InMemoryStateInternals.forKey(newKey); + stateInternals.put(newKey, currentStateInternals); + } + } + + @Override + public StateInternals stateInternals() { + return checkNotNull(currentStateInternals); + } + + @Override + public TimerInternals timerInternals() { + throw new UnsupportedOperationException("Direct mode cannot return timerInternals"); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectSideInputReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectSideInputReader.java new file mode 100644 index 000000000000..ee8c922897ac --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DirectSideInputReader.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Predicate; +import com.google.common.collect.Iterables; + +/** + * Basic side input reader wrapping a {@link PTuple} of side input iterables. Encapsulates + * conversion according to the {@link PCollectionView} and projection to a particular + * window. + */ +public class DirectSideInputReader implements SideInputReader { + + private PTuple sideInputValues; + + private DirectSideInputReader(PTuple sideInputValues) { + this.sideInputValues = sideInputValues; + } + + public static DirectSideInputReader of(PTuple sideInputValues) { + return new DirectSideInputReader(sideInputValues); + } + + @Override + public boolean contains(PCollectionView view) { + return sideInputValues.has(view.getTagInternal()); + } + + @Override + public boolean isEmpty() { + return sideInputValues.isEmpty(); + } + + @Override + public T get(PCollectionView view, final BoundedWindow window) { + final TupleTag>> tag = view.getTagInternal(); + if (!sideInputValues.has(tag)) { + throw new IllegalArgumentException("calling getSideInput() with unknown view"); + } + + if (view.getWindowingStrategyInternal().getWindowFn() instanceof GlobalWindows) { + return view.fromIterableInternal(sideInputValues.get(tag)); + } else { + return view.fromIterableInternal( + Iterables.filter(sideInputValues.get(tag), + new Predicate>() { + @Override + public boolean apply(WindowedValue element) { + return element.getWindows().contains(window); + } + })); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnInfo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnInfo.java new file mode 100644 index 000000000000..15a3a471c23c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnInfo.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import java.io.Serializable; + +/** + * Wrapper class holding the necessary information to serialize a DoFn. 
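+ * <p>For example (a sketch; {@code myDoFn} is any {@code DoFn<String, Integer>} supplied
+ * by the caller, and the global default windowing strategy is assumed here):
+ * <pre>{@code
+ * DoFnInfo<String, Integer> info =
+ *     new DoFnInfo<>(myDoFn, WindowingStrategy.globalDefault());
+ * }</pre>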
+ * + * @param the type of the (main) input elements of the DoFn + * @param the type of the (main) output elements of the DoFn + */ +public class DoFnInfo implements Serializable { + private final DoFn doFn; + private final WindowingStrategy windowingStrategy; + private final Iterable> sideInputViews; + private final Coder inputCoder; + + public DoFnInfo(DoFn doFn, WindowingStrategy windowingStrategy) { + this.doFn = doFn; + this.windowingStrategy = windowingStrategy; + this.sideInputViews = null; + this.inputCoder = null; + } + + public DoFnInfo(DoFn doFn, WindowingStrategy windowingStrategy, + Iterable> sideInputViews, Coder inputCoder) { + this.doFn = doFn; + this.windowingStrategy = windowingStrategy; + this.sideInputViews = sideInputViews; + this.inputCoder = inputCoder; + } + + public DoFn getDoFn() { + return doFn; + } + + public WindowingStrategy getWindowingStrategy() { + return windowingStrategy; + } + + public Iterable> getSideInputViews() { + return sideInputViews; + } + + public Coder getInputCoder() { + return inputCoder; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunner.java new file mode 100644 index 000000000000..51c3f39584cb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunner.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.ProcessContext; +import com.google.cloud.dataflow.sdk.values.KV; + +/** + * An wrapper interface that represents the execution of a {@link DoFn}. + */ +public interface DoFnRunner { + /** + * Prepares and calls {@link DoFn#startBundle}. + */ + public void startBundle(); + + /** + * Calls {@link DoFn#processElement} with a {@link ProcessContext} containing the current element. + */ + public void processElement(WindowedValue elem); + + /** + * Calls {@link DoFn#finishBundle} and performs additional tasks, such as + * flushing in-memory states. + */ + public void finishBundle(); + + /** + * An internal interface for signaling that a {@link DoFn} requires late data dropping. + */ + public interface ReduceFnExecutor { + /** + * Gets this object as a {@link DoFn}. + * + * Most implementors of this interface are expected to be {@link DoFn} instances, and will + * return themselves. + */ + DoFn, KV> asDoFn(); + + /** + * Returns an aggregator that tracks elements that are dropped due to being late. 
+ */ + Aggregator getDroppedDueToLatenessAggregator(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java new file mode 100644 index 000000000000..04ec59f57dfa --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java @@ -0,0 +1,558 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.DoFnRunners.OutputManager; +import com.google.cloud.dataflow.sdk.util.ExecutionContext.StepContext; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; + +import org.joda.time.Instant; +import org.joda.time.format.PeriodFormat; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A base implementation of {@link DoFnRunner}. + * + *
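+ * <p>Callers drive any {@link DoFnRunner} through the same bundle lifecycle, sketched here
+ * ({@code elements} stands for whatever bundle the caller is processing):
+ * <pre>{@code
+ * runner.startBundle();
+ * for (WindowedValue<InputT> elem : elements) {
+ *   runner.processElement(elem);
+ * }
+ * runner.finishBundle();
+ * }</pre>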

    Sub-classes should override {@link #invokeProcessElement}. + */ +public abstract class DoFnRunnerBase implements DoFnRunner { + + /** The DoFn being run. */ + public final DoFn fn; + + /** The context used for running the DoFn. */ + public final DoFnContext context; + + protected DoFnRunnerBase( + PipelineOptions options, + DoFn fn, + SideInputReader sideInputReader, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator, + WindowingStrategy windowingStrategy) { + this.fn = fn; + this.context = new DoFnContext<>( + options, + fn, + sideInputReader, + outputManager, + mainOutputTag, + sideOutputTags, + stepContext, + addCounterMutator, + windowingStrategy == null ? null : windowingStrategy.getWindowFn()); + } + + /** + * An implementation of {@code OutputManager} using simple lists, for testing and in-memory + * contexts such as the {@link DirectPipelineRunner}. + */ + public static class ListOutputManager implements OutputManager { + + private Map, List>> outputLists = Maps.newHashMap(); + + @Override + public void output(TupleTag tag, WindowedValue output) { + @SuppressWarnings({"rawtypes", "unchecked"}) + List> outputList = (List) outputLists.get(tag); + + if (outputList == null) { + outputList = Lists.newArrayList(); + @SuppressWarnings({"rawtypes", "unchecked"}) + List> untypedList = (List) outputList; + outputLists.put(tag, untypedList); + } + + outputList.add(output); + } + + public List> getOutput(TupleTag tag) { + // Safe cast by design, inexpressible in Java without rawtypes + @SuppressWarnings({"rawtypes", "unchecked"}) + List> outputList = (List) outputLists.get(tag); + return (outputList != null) ? outputList : Collections.>emptyList(); + } + } + + @Override + public void startBundle() { + // This can contain user code. Wrap it in case it throws an exception. + try { + fn.startBundle(context); + } catch (Throwable t) { + // Exception in user code. + throw wrapUserCodeException(t); + } + } + + @Override + public void processElement(WindowedValue elem) { + if (elem.getWindows().size() <= 1 + || (!RequiresWindowAccess.class.isAssignableFrom(fn.getClass()) + && context.sideInputReader.isEmpty())) { + invokeProcessElement(elem); + } else { + // We could modify the windowed value (and the processContext) to + // avoid repeated allocations, but this is more straightforward. + for (BoundedWindow window : elem.getWindows()) { + invokeProcessElement(WindowedValue.of( + elem.getValue(), elem.getTimestamp(), window, elem.getPane())); + } + } + } + + /** + * Invokes {@link DoFn#processElement} after certain pre-processings has been done in + * {@link DoFnRunnerBase#processElement}. + */ + protected abstract void invokeProcessElement(WindowedValue elem); + + @Override + public void finishBundle() { + // This can contain user code. Wrap it in case it throws an exception. + try { + fn.finishBundle(context); + } catch (Throwable t) { + // Exception in user code. + throw wrapUserCodeException(t); + } + } + + /** + * A concrete implementation of {@code DoFn.Context} used for running a {@link DoFn}. 
+ * + * @param the type of the DoFn's (main) input elements + * @param the type of the DoFn's (main) output elements + */ + private static class DoFnContext + extends DoFn.Context { + private static final int MAX_SIDE_OUTPUTS = 1000; + + final PipelineOptions options; + final DoFn fn; + final SideInputReader sideInputReader; + final OutputManager outputManager; + final TupleTag mainOutputTag; + final StepContext stepContext; + final CounterSet.AddCounterMutator addCounterMutator; + final WindowFn windowFn; + + /** + * The set of known output tags, some of which may be undeclared, so we can throw an + * exception when it exceeds {@link #MAX_SIDE_OUTPUTS}. + */ + private Set> outputTags; + + public DoFnContext(PipelineOptions options, + DoFn fn, + SideInputReader sideInputReader, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator, + WindowFn windowFn) { + fn.super(); + this.options = options; + this.fn = fn; + this.sideInputReader = sideInputReader; + this.outputManager = outputManager; + this.mainOutputTag = mainOutputTag; + this.outputTags = Sets.newHashSet(); + + outputTags.add(mainOutputTag); + for (TupleTag sideOutputTag : sideOutputTags) { + outputTags.add(sideOutputTag); + } + + this.stepContext = stepContext; + this.addCounterMutator = addCounterMutator; + this.windowFn = windowFn; + super.setupDelegateAggregators(); + } + + ////////////////////////////////////////////////////////////////////////////// + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + WindowedValue makeWindowedValue( + T output, Instant timestamp, Collection windows, PaneInfo pane) { + final Instant inputTimestamp = timestamp; + + if (timestamp == null) { + timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + if (windows == null) { + try { + // The windowFn can never succeed at accessing the element, so its type does not + // matter here + @SuppressWarnings("unchecked") + WindowFn objectWindowFn = (WindowFn) windowFn; + windows = objectWindowFn.assignWindows(objectWindowFn.new AssignContext() { + @Override + public Object element() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input element when none was available"); + } + + @Override + public Instant timestamp() { + if (inputTimestamp == null) { + throw new UnsupportedOperationException( + "WindowFn attempted to access input timestamp when none was available"); + } + return inputTimestamp; + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input windows when none were available"); + } + }); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } + } + + return WindowedValue.of(output, timestamp, windows, pane); + } + + public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { + if (!sideInputReader.contains(view)) { + throw new IllegalArgumentException("calling sideInput() with unknown view"); + } + BoundedWindow sideInputWindow = + view.getWindowingStrategyInternal().getWindowFn().getSideInputWindow(mainInputWindow); + return sideInputReader.get(view, sideInputWindow); + } + + void outputWindowedValue( + OutputT output, + Instant timestamp, + Collection windows, + PaneInfo pane) { + outputWindowedValue(makeWindowedValue(output, timestamp, windows, pane)); + } + + void outputWindowedValue(WindowedValue windowedElem) { + outputManager.output(mainOutputTag, windowedElem); + if (stepContext != 
null) { + stepContext.noteOutput(windowedElem); + } + } + + protected void sideOutputWindowedValue(TupleTag tag, + T output, + Instant timestamp, + Collection windows, + PaneInfo pane) { + sideOutputWindowedValue(tag, makeWindowedValue(output, timestamp, windows, pane)); + } + + protected void sideOutputWindowedValue(TupleTag tag, WindowedValue windowedElem) { + if (!outputTags.contains(tag)) { + // This tag wasn't declared nor was it seen before during this execution. + // Thus, this must be a new, undeclared and unconsumed output. + // To prevent likely user errors, enforce the limit on the number of side + // outputs. + if (outputTags.size() >= MAX_SIDE_OUTPUTS) { + throw new IllegalArgumentException( + "the number of side outputs has exceeded a limit of " + MAX_SIDE_OUTPUTS); + } + outputTags.add(tag); + } + + outputManager.output(tag, windowedElem); + if (stepContext != null) { + stepContext.noteSideOutput(tag, windowedElem); + } + } + + // Following implementations of output, outputWithTimestamp, and sideOutput + // are only accessible in DoFn.startBundle and DoFn.finishBundle, and will be shadowed by + // ProcessContext's versions in DoFn.processElement. + @Override + public void output(OutputT output) { + outputWindowedValue(output, null, null, PaneInfo.NO_FIRING); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + outputWindowedValue(output, timestamp, null, PaneInfo.NO_FIRING); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + Preconditions.checkNotNull(tag, "TupleTag passed to sideOutput cannot be null"); + sideOutputWindowedValue(tag, output, null, null, PaneInfo.NO_FIRING); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + Preconditions.checkNotNull(tag, "TupleTag passed to sideOutputWithTimestamp cannot be null"); + sideOutputWindowedValue(tag, output, timestamp, null, PaneInfo.NO_FIRING); + } + + private String generateInternalAggregatorName(String userName) { + boolean system = fn.getClass().isAnnotationPresent(SystemDoFnInternal.class); + return (system ? "" : "user-") + stepContext.getStepName() + "-" + userName; + } + + @Override + protected Aggregator createAggregatorInternal( + String name, CombineFn combiner) { + Preconditions.checkNotNull(combiner, + "Combiner passed to createAggregator cannot be null"); + return new CounterAggregator<>(generateInternalAggregatorName(name), + combiner, addCounterMutator); + } + } + + /** + * Returns a new {@code DoFn.ProcessContext} for the given element. + */ + protected DoFn.ProcessContext createProcessContext(WindowedValue elem) { + return new DoFnProcessContext(fn, context, elem); + } + + protected RuntimeException wrapUserCodeException(Throwable t) { + throw UserCodeException.wrapIf(!isSystemDoFn(), t); + } + + private boolean isSystemDoFn() { + return fn.getClass().isAnnotationPresent(SystemDoFnInternal.class); + } + + /** + * A concrete implementation of {@code DoFn.ProcessContext} used for + * running a {@link DoFn} over a single element. 
+ * + * @param the type of the DoFn's (main) input elements + * @param the type of the DoFn's (main) output elements + */ + static class DoFnProcessContext + extends DoFn.ProcessContext { + + + final DoFn fn; + final DoFnContext context; + final WindowedValue windowedValue; + + public DoFnProcessContext(DoFn fn, + DoFnContext context, + WindowedValue windowedValue) { + fn.super(); + this.fn = fn; + this.context = context; + this.windowedValue = windowedValue; + } + + @Override + public PipelineOptions getPipelineOptions() { + return context.getPipelineOptions(); + } + + @Override + public InputT element() { + return windowedValue.getValue(); + } + + @Override + public T sideInput(PCollectionView view) { + Preconditions.checkNotNull(view, "View passed to sideInput cannot be null"); + Iterator windowIter = windows().iterator(); + BoundedWindow window; + if (!windowIter.hasNext()) { + if (context.windowFn instanceof GlobalWindows) { + // TODO: Remove this once GroupByKeyOnly no longer outputs elements + // without windows + window = GlobalWindow.INSTANCE; + } else { + throw new IllegalStateException( + "sideInput called when main input element is not in any windows"); + } + } else { + window = windowIter.next(); + if (windowIter.hasNext()) { + throw new IllegalStateException( + "sideInput called when main input element is in multiple windows"); + } + } + return context.sideInput(view, window); + } + + @Override + public BoundedWindow window() { + if (!(fn instanceof RequiresWindowAccess)) { + throw new UnsupportedOperationException( + "window() is only available in the context of a DoFn marked as RequiresWindow."); + } + return Iterables.getOnlyElement(windows()); + } + + @Override + public PaneInfo pane() { + return windowedValue.getPane(); + } + + @Override + public void output(OutputT output) { + context.outputWindowedValue(windowedValue.withValue(output)); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + checkTimestamp(timestamp); + context.outputWindowedValue(output, timestamp, + windowedValue.getWindows(), windowedValue.getPane()); + } + + void outputWindowedValue( + OutputT output, + Instant timestamp, + Collection windows, + PaneInfo pane) { + context.outputWindowedValue(output, timestamp, windows, pane); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + Preconditions.checkNotNull(tag, "Tag passed to sideOutput cannot be null"); + context.sideOutputWindowedValue(tag, windowedValue.withValue(output)); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + Preconditions.checkNotNull(tag, "Tag passed to sideOutputWithTimestamp cannot be null"); + checkTimestamp(timestamp); + context.sideOutputWindowedValue( + tag, output, timestamp, windowedValue.getWindows(), windowedValue.getPane()); + } + + @Override + public Instant timestamp() { + return windowedValue.getTimestamp(); + } + + public Collection windows() { + return windowedValue.getWindows(); + } + + private void checkTimestamp(Instant timestamp) { + if (timestamp.isBefore(windowedValue.getTimestamp().minus(fn.getAllowedTimestampSkew()))) { + throw new IllegalArgumentException(String.format( + "Cannot output with timestamp %s. Output timestamps must be no earlier than the " + + "timestamp of the current input (%s) minus the allowed skew (%s). 
See the " + + "DoFn#getAllowedTimestampSkew() Javadoc for details on changing the allowed skew.", + timestamp, windowedValue.getTimestamp(), + PeriodFormat.getDefault().print(fn.getAllowedTimestampSkew().toPeriod()))); + } + } + + @Override + public WindowingInternals windowingInternals() { + return new WindowingInternals() { + @Override + public void outputWindowedValue(OutputT output, Instant timestamp, + Collection windows, PaneInfo pane) { + context.outputWindowedValue(output, timestamp, windows, pane); + } + + @Override + public Collection windows() { + return windowedValue.getWindows(); + } + + @Override + public PaneInfo pane() { + return windowedValue.getPane(); + } + + @Override + public TimerInternals timerInternals() { + return context.stepContext.timerInternals(); + } + + @Override + public void writePCollectionViewData( + TupleTag tag, + Iterable> data, + Coder elemCoder) throws IOException { + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) context.windowFn.windowCoder(); + + context.stepContext.writePCollectionViewData( + tag, data, IterableCoder.of(WindowedValue.getFullCoder(elemCoder, windowCoder)), + window(), windowCoder); + } + + @Override + public StateInternals stateInternals() { + return context.stepContext.stateInternals(); + } + + @Override + public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { + return context.sideInput(view, mainInputWindow); + } + }; + } + + @Override + protected Aggregator + createAggregatorInternal( + String name, CombineFn combiner) { + return context.createAggregatorInternal(name, combiner); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunners.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunners.java new file mode 100644 index 000000000000..64a0968e0fb2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunners.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.DoFnRunner.ReduceFnExecutor; +import com.google.cloud.dataflow.sdk.util.ExecutionContext.StepContext; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.CounterSet.AddCounterMutator; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.List; + +/** + * Static utility methods that provide {@link DoFnRunner} implementations. + */ +public class DoFnRunners { + /** + * Information about how to create output receivers and output to them. + */ + public interface OutputManager { + /** + * Outputs a single element to the receiver indicated by the given {@link TupleTag}. 
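For illustration (not part of this diff), a trivial in-memory OutputManager, e.g. for unit tests, could simply buffer every output by its tag. The class and helper names below are assumptions; the sketch presumes it lives in the same package as DoFnRunners and WindowedValue.

import com.google.cloud.dataflow.sdk.values.TupleTag;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical OutputManager that collects every output, keyed by its TupleTag. */
class BufferingOutputManager implements DoFnRunners.OutputManager {
  private final Map<TupleTag<?>, List<WindowedValue<?>>> outputs = new HashMap<>();

  @Override
  public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
    List<WindowedValue<?>> values = outputs.get(tag);
    if (values == null) {
      values = new ArrayList<>();
      outputs.put(tag, values);
    }
    values.add(output);
  }

  /** Returns everything emitted to the given tag, or an empty list if nothing was. */
  List<WindowedValue<?>> valuesFor(TupleTag<?> tag) {
    List<WindowedValue<?>> values = outputs.get(tag);
    return values == null ? new ArrayList<WindowedValue<?>>() : values;
  }
}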
+ */ + public void output(TupleTag tag, WindowedValue output); + } + + /** + * Returns a basic implementation of {@link DoFnRunner} that works for most {@link DoFn DoFns}. + * + *

<p>
    It invokes {@link DoFn#processElement} for each input. + */ + public static DoFnRunner simpleRunner( + PipelineOptions options, + DoFn fn, + SideInputReader sideInputReader, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator, + WindowingStrategy windowingStrategy) { + return new SimpleDoFnRunner<>( + options, + fn, + sideInputReader, + outputManager, + mainOutputTag, + sideOutputTags, + stepContext, + addCounterMutator, + windowingStrategy); + } + + /** + * Returns an implementation of {@link DoFnRunner} that handles late data dropping. + * + *

<p>
    It drops elements from expired windows before they reach the underlying {@link DoFn}. + */ + public static + DoFnRunner, KV> lateDataDroppingRunner( + PipelineOptions options, + ReduceFnExecutor reduceFnExecutor, + SideInputReader sideInputReader, + OutputManager outputManager, + TupleTag> mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + CounterSet.AddCounterMutator addCounterMutator, + WindowingStrategy windowingStrategy) { + DoFnRunner, KV> simpleDoFnRunner = + simpleRunner( + options, + reduceFnExecutor.asDoFn(), + sideInputReader, + outputManager, + mainOutputTag, + sideOutputTags, + stepContext, + addCounterMutator, + windowingStrategy); + return new LateDataDroppingDoFnRunner<>( + simpleDoFnRunner, + windowingStrategy, + stepContext.timerInternals(), + reduceFnExecutor.getDroppedDueToLatenessAggregator()); + } + + public static DoFnRunner createDefault( + PipelineOptions options, + DoFn doFn, + SideInputReader sideInputReader, + OutputManager outputManager, + TupleTag mainOutputTag, + List> sideOutputTags, + StepContext stepContext, + AddCounterMutator addCounterMutator, + WindowingStrategy windowingStrategy) { + if (doFn instanceof ReduceFnExecutor) { + @SuppressWarnings("rawtypes") + ReduceFnExecutor fn = (ReduceFnExecutor) doFn; + return lateDataDroppingRunner( + options, + fn, + sideInputReader, + outputManager, + (TupleTag) mainOutputTag, + sideOutputTags, + stepContext, + addCounterMutator, + (WindowingStrategy) windowingStrategy); + } + return simpleRunner( + options, + doFn, + sideInputReader, + outputManager, + mainOutputTag, + sideOutputTags, + stepContext, + addCounterMutator, + windowingStrategy); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutableTrigger.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutableTrigger.java new file mode 100644 index 000000000000..22a3762dfc10 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutableTrigger.java @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.common.base.Preconditions; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * A wrapper around a trigger used during execution. While an actual trigger may appear multiple + * times (both in the same trigger expression and in other trigger expressions), the + * {@code ExecutableTrigger} wrapped around them forms a tree (only one occurrence). + * + * @param {@link BoundedWindow} subclass used to represent the windows used. + */ +public class ExecutableTrigger implements Serializable { + + /** Store the index assigned to this trigger. 
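To make the index scheme concrete, a hedged sketch (assuming AfterEach, AfterPane, AfterWatermark and GlobalWindow from the SDK's windowing package are available): indices are assigned depth-first, so the root gets 0, its sub-triggers follow in order, and firstIndexAfterSubtree is the index the next sibling would receive.

  ExecutableTrigger<GlobalWindow> root = ExecutableTrigger.create(
      AfterEach.<GlobalWindow>inOrder(
          AfterPane.<GlobalWindow>elementCountAtLeast(10),
          AfterWatermark.<GlobalWindow>pastEndOfWindow()));

  int rootIndex = root.getTriggerIndex();                        // 0
  int firstChild = root.subTriggers().get(0).getTriggerIndex();  // 1
  int secondChild = root.subTriggers().get(1).getTriggerIndex(); // 2
  int next = root.getFirstIndexAfterSubtree();                   // 3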
*/ + private final int triggerIndex; + private final int firstIndexAfterSubtree; + private final List> subTriggers = new ArrayList<>(); + private final Trigger trigger; + + public static ExecutableTrigger create(Trigger trigger) { + return create(trigger, 0); + } + + private static ExecutableTrigger create( + Trigger trigger, int nextUnusedIndex) { + if (trigger instanceof OnceTrigger) { + return new ExecutableOnceTrigger((OnceTrigger) trigger, nextUnusedIndex); + } else { + return new ExecutableTrigger(trigger, nextUnusedIndex); + } + } + + public static ExecutableTrigger createForOnceTrigger( + OnceTrigger trigger, int nextUnusedIndex) { + return new ExecutableOnceTrigger(trigger, nextUnusedIndex); + } + + private ExecutableTrigger(Trigger trigger, int nextUnusedIndex) { + this.trigger = Preconditions.checkNotNull(trigger, "trigger must not be null"); + this.triggerIndex = nextUnusedIndex++; + + if (trigger.subTriggers() != null) { + for (Trigger subTrigger : trigger.subTriggers()) { + ExecutableTrigger subExecutable = create(subTrigger, nextUnusedIndex); + subTriggers.add(subExecutable); + nextUnusedIndex = subExecutable.firstIndexAfterSubtree; + } + } + firstIndexAfterSubtree = nextUnusedIndex; + } + + public List> subTriggers() { + return subTriggers; + } + + @Override + public String toString() { + return trigger.toString(); + } + + /** + * Return the underlying trigger specification corresponding to this {@code ExecutableTrigger}. + */ + public Trigger getSpec() { + return trigger; + } + + public int getTriggerIndex() { + return triggerIndex; + } + + public final int getFirstIndexAfterSubtree() { + return firstIndexAfterSubtree; + } + + public boolean isCompatible(ExecutableTrigger other) { + return trigger.isCompatible(other.trigger); + } + + public ExecutableTrigger getSubTriggerContaining(int index) { + Preconditions.checkNotNull(subTriggers); + Preconditions.checkState(index > triggerIndex && index < firstIndexAfterSubtree, + "Cannot find sub-trigger containing index not in this tree."); + ExecutableTrigger previous = null; + for (ExecutableTrigger subTrigger : subTriggers) { + if (index < subTrigger.triggerIndex) { + return previous; + } + previous = subTrigger; + } + return previous; + } + + /** + * Invoke the {@link Trigger#onElement} method for this trigger, ensuring that the bits are + * properly updated if the trigger finishes. + */ + public void invokeOnElement(Trigger.OnElementContext c) throws Exception { + trigger.onElement(c.forTrigger(this)); + } + + /** + * Invoke the {@link Trigger#onMerge} method for this trigger, ensuring that the bits are properly + * updated. + */ + public void invokeOnMerge(Trigger.OnMergeContext c) throws Exception { + Trigger.OnMergeContext subContext = c.forTrigger(this); + trigger.onMerge(subContext); + } + + public boolean invokeShouldFire(Trigger.TriggerContext c) throws Exception { + return trigger.shouldFire(c.forTrigger(this)); + } + + public void invokeOnFire(Trigger.TriggerContext c) throws Exception { + trigger.onFire(c.forTrigger(this)); + } + + /** + * Invoke clear for the current this trigger. + */ + public void invokeClear(Trigger.TriggerContext c) throws Exception { + trigger.clear(c.forTrigger(this)); + } + + /** + * {@link ExecutableTrigger} that enforces the fact that the trigger should always FIRE_AND_FINISH + * and never just FIRE. 
+ */ + private static class ExecutableOnceTrigger extends ExecutableTrigger { + + public ExecutableOnceTrigger(OnceTrigger trigger, int nextUnusedIndex) { + super(trigger, nextUnusedIndex); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutionContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutionContext.java new file mode 100644 index 000000000000..cff5b95cf9a9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExecutionContext.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.io.IOException; +import java.util.Collection; + +/** + * Context for the current execution. This is guaranteed to exist during processing, + * but does not necessarily persist between different batches of work. + */ +public interface ExecutionContext { + /** + * Returns the {@link StepContext} associated with the given step. + */ + StepContext getOrCreateStepContext( + String stepName, String transformName, StateSampler stateSampler); + + /** + * Returns a collection view of all of the {@link StepContext}s. + */ + Collection getAllStepContexts(); + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#output} + * is called. + */ + void noteOutput(WindowedValue output); + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#sideOutput} + * is called. + */ + void noteSideOutput(TupleTag tag, WindowedValue output); + + /** + * Per-step, per-key context used for retrieving state. + */ + public interface StepContext { + + /** + * The name of the step. + */ + String getStepName(); + + /** + * The name of the transform for the step. + */ + String getTransformName(); + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#output} + * is called. + */ + void noteOutput(WindowedValue output); + + /** + * Hook for subclasses to implement that will be called whenever + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#sideOutput} + * is called. + */ + void noteSideOutput(TupleTag tag, WindowedValue output); + + /** + * Writes the given {@code PCollectionView} data to a globally accessible location. 
+ */ + void writePCollectionViewData( + TupleTag tag, + Iterable> data, + Coder>> dataCoder, + W window, + Coder windowCoder) + throws IOException; + + StateInternals stateInternals(); + + TimerInternals timerInternals(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayInputStream.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayInputStream.java new file mode 100644 index 000000000000..dff5fd17a374 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayInputStream.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.ByteArrayInputStream; +import java.io.IOException; + +/** + * {@link ByteArrayInputStream} that allows accessing the entire internal buffer without copying. + */ +public class ExposedByteArrayInputStream extends ByteArrayInputStream{ + + public ExposedByteArrayInputStream(byte[] buf) { + super(buf); + } + + /** Read all remaining bytes. + * @throws IOException */ + public byte[] readAll() throws IOException { + if (pos == 0 && count == buf.length) { + pos = count; + return buf; + } + byte[] ret = new byte[count - pos]; + super.read(ret); + return ret; + } + + @Override + public void close() { + try { + super.close(); + } catch (IOException exn) { + throw new RuntimeException("Unexpected IOException closing ByteArrayInputStream", exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayOutputStream.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayOutputStream.java new file mode 100644 index 000000000000..d8e4d50714b5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayOutputStream.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** + * {@link ByteArrayOutputStream} special cased to treat writes of a single byte-array specially. + * When calling {@link #toByteArray()} after writing only one {@code byte[]} using + * {@link #writeAndOwn(byte[])}, it will return that array directly. 
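A short usage sketch of the zero-copy contract described above (illustrative only; assumes the class is imported from com.google.cloud.dataflow.sdk.util):

import java.nio.charset.StandardCharsets;

public class ZeroCopyExample {
  public static void main(String[] args) throws Exception {
    byte[] payload = "hello".getBytes(StandardCharsets.UTF_8);

    ExposedByteArrayOutputStream stream = new ExposedByteArrayOutputStream();
    stream.writeAndOwn(payload);           // the stream takes ownership; payload must not be modified
    byte[] result = stream.toByteArray();
    System.out.println(result == payload); // true: the same array is returned, nothing was copied

    stream.write('!');                     // any further write falls back to normal copying behavior
    System.out.println(stream.toByteArray() == payload); // false: a defensive copy is made
  }
}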
+ */ +public class ExposedByteArrayOutputStream extends ByteArrayOutputStream { + + private byte[] swappedBuffer; + + /** + * If true, this stream doesn't allow direct access to the passed in byte-array. It behaves just + * like a normal {@link ByteArrayOutputStream}. + * + *

<p>
    It is set to true after any write operations other than the first call to + * {@link #writeAndOwn(byte[])}. + */ + private boolean isFallback = false; + + /** + * Fall back to the behavior of a normal {@link ByteArrayOutputStream}. + */ + private void fallback() { + isFallback = true; + if (swappedBuffer != null) { + // swappedBuffer != null means buf is actually provided by the caller of writeAndOwn(), + // while swappedBuffer is the original buffer. + // Recover the buffer and copy the bytes from buf. + byte[] tempBuffer = buf; + count = 0; + buf = swappedBuffer; + super.write(tempBuffer, 0, tempBuffer.length); + swappedBuffer = null; + } + } + + /** + * Write {@code b} to the stream and take the ownership of {@code b}. + * If the stream is empty, {@code b} itself will be used as the content of the stream and + * no content copy will be involved. + *

<p>
    Note: After passing any byte array to this method, it must not be modified again. + * + * @throws IOException + */ + public void writeAndOwn(byte[] b) throws IOException { + if (b.length == 0) { + return; + } + if (count == 0) { + // Optimized first-time whole write. + // The original buffer will be swapped to swappedBuffer, while the input b is used as buf. + swappedBuffer = buf; + buf = b; + count = b.length; + } else { + fallback(); + super.write(b); + } + } + + @Override + public void write(byte[] b, int off, int len) { + fallback(); + super.write(b, off, len); + } + + @Override + public void write(int b) { + fallback(); + super.write(b); + } + + @Override + public byte[] toByteArray() { + // Note: count == buf.length is not a correct criteria to "return buf;", because the internal + // buf may be reused after reset(). + if (!isFallback && count > 0) { + return buf; + } else { + return super.toByteArray(); + } + } + + @Override + public void reset() { + if (count == 0) { + return; + } + count = 0; + if (isFallback) { + isFallback = false; + } else { + buf = swappedBuffer; + swappedBuffer = null; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java new file mode 100644 index 000000000000..77d0b830cb92 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.common.base.Predicate; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.file.FileSystems; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.PathMatcher; +import java.nio.file.Paths; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import java.util.regex.Matcher; + +/** + * Implements IOChannelFactory for local files. + */ +public class FileIOChannelFactory implements IOChannelFactory { + private static final Logger LOG = LoggerFactory.getLogger(FileIOChannelFactory.class); + + // This implementation only allows for wildcards in the file name. + // The directory portion must exist as-is. 
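For example (the path below is hypothetical), only the file-name component may contain a glob; the parent directory must already exist:

  // Illustrative usage: expands /tmp/logs/part-*.txt against the existing /tmp/logs directory.
  FileIOChannelFactory factory = new FileIOChannelFactory();
  Collection<String> shards = factory.match("/tmp/logs/part-*.txt");
  for (String path : shards) {
    System.out.println(path);
  }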
+ @Override + public Collection match(String spec) throws IOException { + File file = new File(spec); + + File parent = file.getAbsoluteFile().getParentFile(); + if (!parent.exists()) { + throw new IOException("Unable to find parent directory of " + spec); + } + + // Method getAbsolutePath() on Windows platform may return something like + // "c:\temp\file.txt". FileSystem.getPathMatcher() call below will treat + // '\' (backslash) as an escape character, instead of a directory + // separator. Replacing backslash with double-backslash solves the problem. + // We perform the replacement on all platforms, even those that allow + // backslash as a part of the filename, because Globs.toRegexPattern will + // eat one backslash. + String pathToMatch = file.getAbsolutePath().replaceAll(Matcher.quoteReplacement("\\"), + Matcher.quoteReplacement("\\\\")); + + final PathMatcher matcher = FileSystems.getDefault().getPathMatcher("glob:" + pathToMatch); + + Iterable files = com.google.common.io.Files.fileTreeTraverser().preOrderTraversal(parent); + Iterable matchedFiles = Iterables.filter(files, + Predicates.and( + com.google.common.io.Files.isFile(), + new Predicate() { + @Override + public boolean apply(File input) { + return matcher.matches(input.toPath()); + } + })); + + List result = new LinkedList<>(); + for (File match : matchedFiles) { + result.add(match.getPath()); + } + + return result; + } + + @Override + public ReadableByteChannel open(String spec) throws IOException { + LOG.debug("opening file {}", spec); + @SuppressWarnings("resource") // The caller is responsible for closing the channel. + FileInputStream inputStream = new FileInputStream(spec); + // Use this method for creating the channel (rather than new FileChannel) so that we get + // regular FileNotFoundException. Closing the underyling channel will close the inputStream. + return inputStream.getChannel(); + } + + @Override + public WritableByteChannel create(String spec, String mimeType) + throws IOException { + LOG.debug("creating file {}", spec); + File file = new File(spec); + if (file.getAbsoluteFile().getParentFile() != null + && !file.getAbsoluteFile().getParentFile().exists() + && !file.getAbsoluteFile().getParentFile().mkdirs()) { + throw new IOException("Unable to create parent directories for '" + spec + "'"); + } + return Channels.newChannel( + new BufferedOutputStream(new FileOutputStream(file))); + } + + @Override + public long getSizeBytes(String spec) throws IOException { + try { + return Files.size(FileSystems.getDefault().getPath(spec)); + } catch (NoSuchFileException e) { + throw new FileNotFoundException(e.getReason()); + } + } + + @Override + public boolean isReadSeekEfficient(String spec) throws IOException { + return true; + } + + @Override + public String resolve(String path, String other) throws IOException { + return Paths.get(path).resolve(other).toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggers.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggers.java new file mode 100644 index 000000000000..e75be23eee41 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggers.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +/** + * A mutable set which tracks whether any particular {@link ExecutableTrigger} is + * finished. + */ +public interface FinishedTriggers { + /** + * Returns {@code true} if the trigger is finished. + */ + public boolean isFinished(ExecutableTrigger trigger); + + /** + * Sets the fact that the trigger is finished. + */ + public void setFinished(ExecutableTrigger trigger, boolean value); + + /** + * Sets the trigger and all of its subtriggers to unfinished. + */ + public void clearRecursively(ExecutableTrigger trigger); + + /** + * Create an independent copy of this mutable {@link FinishedTriggers}. + */ + public FinishedTriggers copy(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersBitSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersBitSet.java new file mode 100644 index 000000000000..09f7af7a95f1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersBitSet.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.util.BitSet; + +/** + * A {@link FinishedTriggers} implementation based on an underlying {@link BitSet}. + */ +public class FinishedTriggersBitSet implements FinishedTriggers { + + private final BitSet bitSet; + + private FinishedTriggersBitSet(BitSet bitSet) { + this.bitSet = bitSet; + } + + public static FinishedTriggersBitSet emptyWithCapacity(int capacity) { + return new FinishedTriggersBitSet(new BitSet(capacity)); + } + + public static FinishedTriggersBitSet fromBitSet(BitSet bitSet) { + return new FinishedTriggersBitSet(bitSet); + } + + /** + * Returns the underlying {@link BitSet} for this {@link FinishedTriggersBitSet}. 
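A brief illustration (not part of this diff) of tracking a trigger's finished bit by its index, assuming AfterPane and GlobalWindow from the SDK's windowing package:

  ExecutableTrigger<GlobalWindow> trigger =
      ExecutableTrigger.create(AfterPane.<GlobalWindow>elementCountAtLeast(1));
  FinishedTriggersBitSet finished =
      FinishedTriggersBitSet.emptyWithCapacity(trigger.getFirstIndexAfterSubtree());

  finished.setFinished(trigger, true);
  boolean isDone = finished.isFinished(trigger);  // true
  finished.clearRecursively(trigger);             // resets the trigger and all of its sub-triggers
  boolean cleared = finished.isFinished(trigger); // false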
+ */ + public BitSet getBitSet() { + return bitSet; + } + + @Override + public boolean isFinished(ExecutableTrigger trigger) { + return bitSet.get(trigger.getTriggerIndex()); + } + + @Override + public void setFinished(ExecutableTrigger trigger, boolean value) { + bitSet.set(trigger.getTriggerIndex(), value); + } + + @Override + public void clearRecursively(ExecutableTrigger trigger) { + bitSet.clear(trigger.getTriggerIndex(), trigger.getFirstIndexAfterSubtree()); + } + + @Override + public FinishedTriggersBitSet copy() { + return new FinishedTriggersBitSet((BitSet) bitSet.clone()); + } +} + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersSet.java new file mode 100644 index 000000000000..6da673d28c08 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersSet.java @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.common.collect.Sets; + +import java.util.Set; + +/** + * An implementation of {@link FinishedTriggers} atop a user-provided mutable {@link Set}. + */ +public class FinishedTriggersSet implements FinishedTriggers { + + private final Set> finishedTriggers; + + private FinishedTriggersSet(Set> finishedTriggers) { + this.finishedTriggers = finishedTriggers; + } + + public static FinishedTriggersSet fromSet(Set> finishedTriggers) { + return new FinishedTriggersSet(finishedTriggers); + } + + /** + * Returns a mutable {@link Set} of the underlying triggers that are finished. + */ + public Set> getFinishedTriggers() { + return finishedTriggers; + } + + @Override + public boolean isFinished(ExecutableTrigger trigger) { + return finishedTriggers.contains(trigger); + } + + @Override + public void setFinished(ExecutableTrigger trigger, boolean value) { + if (value) { + finishedTriggers.add(trigger); + } else { + finishedTriggers.remove(trigger); + } + } + + @Override + public void clearRecursively(ExecutableTrigger trigger) { + finishedTriggers.remove(trigger); + for (ExecutableTrigger subTrigger : trigger.subTriggers()) { + clearRecursively(subTrigger); + } + } + + @Override + public FinishedTriggersSet copy() { + return fromSet(Sets.newHashSet(finishedTriggers)); + } + +} + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcpCredentialFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcpCredentialFactory.java new file mode 100644 index 000000000000..8b6f495e9e8b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcpCredentialFactory.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.io.IOException; +import java.security.GeneralSecurityException; + +/** + * Construct an oauth credential to be used by the SDK and the SDK workers. + * Returns a GCP credential. + */ +public class GcpCredentialFactory implements CredentialFactory { + private GcpOptions options; + + private GcpCredentialFactory(GcpOptions options) { + this.options = options; + } + + public static GcpCredentialFactory fromOptions(PipelineOptions options) { + return new GcpCredentialFactory(options.as(GcpOptions.class)); + } + + @Override + public Credential getCredential() throws IOException, GeneralSecurityException { + return Credentials.getCredential(options); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactory.java new file mode 100644 index 000000000000..ce933f563aac --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactory.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; + +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * Implements IOChannelFactory for GCS. 
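Illustrative usage (the bucket and object names are made up; PipelineOptionsFactory is assumed to be imported from the SDK's options package):

  // Sketch: expand a GCS glob and open each match for reading.
  GcsOptions gcsOptions = PipelineOptionsFactory.create().as(GcsOptions.class);
  GcsIOChannelFactory factory = new GcsIOChannelFactory(gcsOptions);

  Collection<String> matches = factory.match("gs://my-bucket/logs/2015/*.txt");
  for (String uri : matches) {
    try (ReadableByteChannel channel = factory.open(uri)) {
      // consume the channel...
    }
  }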
+ */ +public class GcsIOChannelFactory implements IOChannelFactory { + + private final GcsOptions options; + + public GcsIOChannelFactory(GcsOptions options) { + this.options = options; + } + + @Override + public Collection match(String spec) throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + List matched = util.expand(path); + + List specs = new LinkedList<>(); + for (GcsPath match : matched) { + specs.add(match.toString()); + } + + return specs; + } + + @Override + public ReadableByteChannel open(String spec) throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + return util.open(path); + } + + @Override + public WritableByteChannel create(String spec, String mimeType) + throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + return util.create(path, mimeType); + } + + @Override + public long getSizeBytes(String spec) throws IOException { + GcsPath path = GcsPath.fromUri(spec); + GcsUtil util = options.getGcsUtil(); + return util.fileSize(path); + } + + @Override + public boolean isReadSeekEfficient(String spec) throws IOException { + // TODO It is incorrect to return true here for files with content encoding set to gzip. + return true; + } + + @Override + public String resolve(String path, String other) throws IOException { + return GcsPath.fromUri(path).resolve(other).toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsStager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsStager.java new file mode 100644 index 000000000000..4219bc4269fb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsStager.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineDebugOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.common.base.Preconditions; + +import java.util.List; + +/** + * Utility class for staging files to GCS. 
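A hedged sketch of how the stager is parameterized (the bucket and jar path are hypothetical, and the setter names below are assumed to follow the usual PipelineOptions getter/setter convention):

  DataflowPipelineOptions stagerOptions =
      PipelineOptionsFactory.create().as(DataflowPipelineOptions.class);
  stagerOptions.setStagingLocation("gs://my-bucket/staging");       // hypothetical bucket
  stagerOptions.setFilesToStage(
      new ArrayList<>(Arrays.asList("/tmp/my-pipeline-bundle.jar"))); // hypothetical jar

  List<DataflowPackage> staged = GcsStager.fromOptions(stagerOptions).stageFiles();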
+ */ +public class GcsStager implements Stager { + private DataflowPipelineOptions options; + + private GcsStager(DataflowPipelineOptions options) { + this.options = options; + } + + public static GcsStager fromOptions(PipelineOptions options) { + return new GcsStager(options.as(DataflowPipelineOptions.class)); + } + + @Override + public List stageFiles() { + Preconditions.checkNotNull(options.getStagingLocation()); + List filesToStage = options.getFilesToStage(); + String windmillBinary = + options.as(DataflowPipelineDebugOptions.class).getOverrideWindmillBinary(); + if (windmillBinary != null) { + filesToStage.add("windmill_main=" + windmillBinary); + } + return PackageUtil.stageClasspathElements( + options.getFilesToStage(), options.getStagingLocation()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsUtil.java new file mode 100644 index 000000000000..8fd258f6d23b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GcsUtil.java @@ -0,0 +1,406 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.Sleeper; +import com.google.api.services.storage.Storage; +import com.google.api.services.storage.model.Objects; +import com.google.api.services.storage.model.StorageObject; +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.hadoop.gcsio.GoogleCloudStorageReadChannel; +import com.google.cloud.hadoop.gcsio.GoogleCloudStorageWriteChannel; +import com.google.cloud.hadoop.gcsio.ObjectWriteConditions; +import com.google.cloud.hadoop.util.ApiErrorExtractor; +import com.google.cloud.hadoop.util.AsyncWriteChannelOptions; +import com.google.cloud.hadoop.util.ClientRequestHelper; +import com.google.cloud.hadoop.util.ResilientOperation; +import com.google.cloud.hadoop.util.RetryDeterminer; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.channels.SeekableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import 
javax.annotation.Nullable; + +/** + * Provides operations on GCS. + */ +public class GcsUtil { + /** + * This is a {@link DefaultValueFactory} able to create a {@link GcsUtil} using + * any transport flags specified on the {@link PipelineOptions}. + */ + public static class GcsUtilFactory implements DefaultValueFactory { + /** + * Returns an instance of {@link GcsUtil} based on the + * {@link PipelineOptions}. + * + *

<p>
    If no instance has previously been created, one is created and the value + * stored in {@code options}. + */ + @Override + public GcsUtil create(PipelineOptions options) { + LOG.debug("Creating new GcsUtil"); + GcsOptions gcsOptions = options.as(GcsOptions.class); + + return new GcsUtil(Transport.newStorageClient(gcsOptions).build(), + gcsOptions.getExecutorService(), gcsOptions.getGcsUploadBufferSizeBytes()); + } + } + + private static final Logger LOG = LoggerFactory.getLogger(GcsUtil.class); + + /** Maximum number of items to retrieve per Objects.List request. */ + private static final long MAX_LIST_ITEMS_PER_CALL = 1024; + + /** Matches a glob containing a wildcard, capturing the portion before the first wildcard. */ + private static final Pattern GLOB_PREFIX = Pattern.compile("(?[^\\[*?]*)[\\[*?].*"); + + private static final String RECURSIVE_WILDCARD = "[*]{2}"; + + /** + * A {@link Pattern} for globs with a recursive wildcard. + */ + private static final Pattern RECURSIVE_GCS_PATTERN = + Pattern.compile(".*" + RECURSIVE_WILDCARD + ".*"); + + ///////////////////////////////////////////////////////////////////////////// + + /** Client for the GCS API. */ + private Storage storageClient; + /** Buffer size for GCS uploads (in bytes). */ + @Nullable private final Integer uploadBufferSizeBytes; + + // Helper delegate for turning IOExceptions from API calls into higher-level semantics. + private final ApiErrorExtractor errorExtractor = new ApiErrorExtractor(); + + // Exposed for testing. + final ExecutorService executorService; + + /** + * Returns true if the given GCS pattern is supported otherwise fails with an + * exception. + */ + public boolean isGcsPatternSupported(String gcsPattern) { + if (RECURSIVE_GCS_PATTERN.matcher(gcsPattern).matches()) { + throw new IllegalArgumentException("Unsupported wildcard usage in \"" + gcsPattern + "\": " + + " recursive wildcards are not supported."); + } + + return true; + } + + private GcsUtil( + Storage storageClient, ExecutorService executorService, + @Nullable Integer uploadBufferSizeBytes) { + this.storageClient = storageClient; + this.uploadBufferSizeBytes = uploadBufferSizeBytes; + this.executorService = executorService; + } + + // Use this only for testing purposes. + protected void setStorageClient(Storage storageClient) { + this.storageClient = storageClient; + } + + /** + * Expands a pattern into matched paths. The pattern path may contain globs, which are expanded + * in the result. For patterns that only match a single object, we ensure that the object + * exists. + */ + public List expand(GcsPath gcsPattern) throws IOException { + Preconditions.checkArgument(isGcsPatternSupported(gcsPattern.getObject())); + Matcher m = GLOB_PREFIX.matcher(gcsPattern.getObject()); + Pattern p = null; + String prefix = null; + if (!m.matches()) { + // Not a glob. + Storage.Objects.Get getObject = storageClient.objects().get( + gcsPattern.getBucket(), gcsPattern.getObject()); + try { + // Use a get request to fetch the metadata of the object, + // the request has strong global consistency. + ResilientOperation.retry( + ResilientOperation.getGoogleRequestCallable(getObject), + new AttemptBoundedExponentialBackOff(3, 200), + RetryDeterminer.SOCKET_ERRORS, + IOException.class); + return ImmutableList.of(gcsPattern); + } catch (IOException | InterruptedException e) { + if (e instanceof IOException && errorExtractor.itemNotFound((IOException) e)) { + // If the path was not found, return an empty list. 
+ return ImmutableList.of(); + } + throw new IOException("Unable to match files for pattern " + gcsPattern, e); + } + } else { + // Part before the first wildcard character. + prefix = m.group("PREFIX"); + p = Pattern.compile(globToRegexp(gcsPattern.getObject())); + } + + LOG.debug("matching files in bucket {}, prefix {} against pattern {}", gcsPattern.getBucket(), + prefix, p.toString()); + + // List all objects that start with the prefix (including objects in sub-directories). + Storage.Objects.List listObject = storageClient.objects().list(gcsPattern.getBucket()); + listObject.setMaxResults(MAX_LIST_ITEMS_PER_CALL); + listObject.setPrefix(prefix); + + String pageToken = null; + List results = new LinkedList<>(); + do { + if (pageToken != null) { + listObject.setPageToken(pageToken); + } + + Objects objects; + try { + objects = ResilientOperation.retry( + ResilientOperation.getGoogleRequestCallable(listObject), + new AttemptBoundedExponentialBackOff(3, 200), + RetryDeterminer.SOCKET_ERRORS, + IOException.class); + } catch (Exception e) { + throw new IOException("Unable to match files in bucket " + gcsPattern.getBucket() + + ", prefix " + prefix + " against pattern " + p.toString(), e); + } + //Objects objects = listObject.execute(); + Preconditions.checkNotNull(objects); + + if (objects.getItems() == null) { + break; + } + + // Filter objects based on the regex. + for (StorageObject o : objects.getItems()) { + String name = o.getName(); + // Skip directories, which end with a slash. + if (p.matcher(name).matches() && !name.endsWith("/")) { + LOG.debug("Matched object: {}", name); + results.add(GcsPath.fromObject(o)); + } + } + + pageToken = objects.getNextPageToken(); + } while (pageToken != null); + + return results; + } + + @VisibleForTesting + @Nullable + Integer getUploadBufferSizeBytes() { + return uploadBufferSizeBytes; + } + + /** + * Returns the file size from GCS or throws {@link FileNotFoundException} + * if the resource does not exist. + */ + public long fileSize(GcsPath path) throws IOException { + return fileSize(path, new AttemptBoundedExponentialBackOff(4, 200), Sleeper.DEFAULT); + } + + /** + * Returns the file size from GCS or throws {@link FileNotFoundException} + * if the resource does not exist. + */ + @VisibleForTesting + long fileSize(GcsPath path, BackOff backoff, Sleeper sleeper) throws IOException { + Storage.Objects.Get getObject = + storageClient.objects().get(path.getBucket(), path.getObject()); + try { + StorageObject object = ResilientOperation.retry( + ResilientOperation.getGoogleRequestCallable(getObject), + backoff, + RetryDeterminer.SOCKET_ERRORS, + IOException.class, + sleeper); + return object.getSize().longValue(); + } catch (Exception e) { + if (e instanceof IOException && errorExtractor.itemNotFound((IOException) e)) { + throw new FileNotFoundException(path.toString()); + } + throw new IOException("Unable to get file size", e); + } + } + + /** + * Opens an object in GCS. + * + *

<p>
    Returns a SeekableByteChannel that provides access to data in the bucket. + * + * @param path the GCS filename to read from + * @return a SeekableByteChannel that can read the object data + * @throws IOException + */ + public SeekableByteChannel open(GcsPath path) + throws IOException { + return new GoogleCloudStorageReadChannel(storageClient, path.getBucket(), + path.getObject(), errorExtractor, + new ClientRequestHelper()); + } + + /** + * Creates an object in GCS. + * + *

<p>
    Returns a WritableByteChannel that can be used to write data to the + * object. + * + * @param path the GCS file to write to + * @param type the type of object, eg "text/plain". + * @return a Callable object that encloses the operation. + * @throws IOException + */ + public WritableByteChannel create(GcsPath path, + String type) throws IOException { + GoogleCloudStorageWriteChannel channel = new GoogleCloudStorageWriteChannel( + executorService, + storageClient, + new ClientRequestHelper(), + path.getBucket(), + path.getObject(), + AsyncWriteChannelOptions.newBuilder().build(), + new ObjectWriteConditions(), + Collections.emptyMap(), + type); + if (uploadBufferSizeBytes != null) { + channel.setUploadBufferSize(uploadBufferSizeBytes); + } + channel.initialize(); + return channel; + } + + /** + * Returns whether the GCS bucket exists. If the bucket exists, it must + * be accessible otherwise the permissions exception will be propagated. + */ + public boolean bucketExists(GcsPath path) throws IOException { + return bucketExists(path, new AttemptBoundedExponentialBackOff(4, 200), Sleeper.DEFAULT); + } + + /** + * Returns whether the GCS bucket exists. This will return false if the bucket + * is inaccessible due to permissions. + */ + @VisibleForTesting + boolean bucketExists(GcsPath path, BackOff backoff, Sleeper sleeper) throws IOException { + Storage.Buckets.Get getBucket = + storageClient.buckets().get(path.getBucket()); + + try { + ResilientOperation.retry( + ResilientOperation.getGoogleRequestCallable(getBucket), + backoff, + new RetryDeterminer() { + @Override + public boolean shouldRetry(IOException e) { + if (errorExtractor.itemNotFound(e) || errorExtractor.accessDenied(e)) { + return false; + } + return RetryDeterminer.SOCKET_ERRORS.shouldRetry(e); + } + }, + IOException.class, + sleeper); + return true; + } catch (GoogleJsonResponseException e) { + if (errorExtractor.itemNotFound(e) || errorExtractor.accessDenied(e)) { + return false; + } + throw e; + } catch (InterruptedException e) { + throw new IOException( + String.format("Error while attempting to verify existence of bucket gs://%s", + path.getBucket()), e); + } + } + + /** + * Expands glob expressions to regular expressions. 
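Worked examples of the translation performed below: a '*' stays within one path segment, '?' matches a single non-slash character, and regex metacharacters in the glob are escaped. The method is package-private, so calls like these would only compile from the same package (e.g. a test):

  GcsUtil.globToRegexp("foo*.txt");    // -> "foo[^/]*\.txt"
  GcsUtil.globToRegexp("a/b?/c.csv");  // -> "a/b[^/]/c\.csv"
  GcsUtil.globToRegexp("data+x");      // -> "data\+x"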
+ * + * @param globExp the glob expression to expand + * @return a string with the regular expression this glob expands to + */ + static String globToRegexp(String globExp) { + StringBuilder dst = new StringBuilder(); + char[] src = globExp.toCharArray(); + int i = 0; + while (i < src.length) { + char c = src[i++]; + switch (c) { + case '*': + dst.append("[^/]*"); + break; + case '?': + dst.append("[^/]"); + break; + case '.': + case '+': + case '{': + case '}': + case '(': + case ')': + case '|': + case '^': + case '$': + // These need to be escaped in regular expressions + dst.append('\\').append(c); + break; + case '\\': + i = doubleSlashes(dst, src, i); + break; + default: + dst.append(c); + break; + } + } + return dst.toString(); + } + + private static int doubleSlashes(StringBuilder dst, char[] src, int i) { + // Emit the next character without special interpretation + dst.append('\\'); + if ((i - 1) != src.length) { + dst.append(src[i]); + i++; + } else { + // A backslash at the very end is treated like an escaped backslash + dst.append('\\'); + } + return i; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java new file mode 100644 index 000000000000..f6246d16414d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.DoFnRunner.ReduceFnExecutor; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.KV; + +/** + * A general {@link GroupAlsoByWindowsDoFn}. This delegates all of the logic to the + * {@link ReduceFnRunner}. 
+ */ +@SystemDoFnInternal +public class GroupAlsoByWindowViaWindowSetDoFn< + K, InputT, OutputT, W extends BoundedWindow, RinT extends KeyedWorkItem> + extends DoFn> implements ReduceFnExecutor { + + public static + DoFn, KV> create( + WindowingStrategy strategy, SystemReduceFn reduceFn) { + return new GroupAlsoByWindowViaWindowSetDoFn<>(strategy, reduceFn); + } + + protected final Aggregator droppedDueToClosedWindow = + createAggregator( + GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER, new Sum.SumLongFn()); + protected final Aggregator droppedDueToLateness = + createAggregator(GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_LATENESS_COUNTER, new Sum.SumLongFn()); + + private final WindowingStrategy windowingStrategy; + private SystemReduceFn reduceFn; + + private GroupAlsoByWindowViaWindowSetDoFn( + WindowingStrategy windowingStrategy, + SystemReduceFn reduceFn) { + @SuppressWarnings("unchecked") + WindowingStrategy noWildcard = (WindowingStrategy) windowingStrategy; + this.windowingStrategy = noWildcard; + this.reduceFn = reduceFn; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + KeyedWorkItem element = c.element(); + + K key = c.element().key(); + TimerInternals timerInternals = c.windowingInternals().timerInternals(); + + // It is the responsibility of the user of GroupAlsoByWindowsViaWindowSet to only + // provide a WindowingInternals instance with the appropriate key type for StateInternals. + @SuppressWarnings("unchecked") + StateInternals stateInternals = (StateInternals) c.windowingInternals().stateInternals(); + + ReduceFnRunner reduceFnRunner = + new ReduceFnRunner<>( + key, + windowingStrategy, + stateInternals, + timerInternals, + c.windowingInternals(), + droppedDueToClosedWindow, + reduceFn, + c.getPipelineOptions()); + + for (TimerData timer : element.timersIterable()) { + reduceFnRunner.onTimer(timer); + } + reduceFnRunner.processElements(element.elementsIterable()); + reduceFnRunner.persist(); + } + + @Override + public DoFn, KV> asDoFn() { + // Safe contravariant cast + @SuppressWarnings("unchecked") + DoFn, KV> asFn = + (DoFn, KV>) this; + return asFn; + } + + @Override + public Aggregator getDroppedDueToLatenessAggregator() { + return droppedDueToLateness; + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFn.java new file mode 100644 index 000000000000..175921d43497 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsDoFn.java @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.KV; + +/** + * DoFn that merges windows and groups elements in those windows, optionally + * combining values. + * + * @param key type + * @param input value element type + * @param output value element type + * @param window type + */ +@SystemDoFnInternal +public abstract class GroupAlsoByWindowsDoFn + extends DoFn>>, KV> { + public static final String DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER = "DroppedDueToClosedWindow"; + public static final String DROPPED_DUE_TO_LATENESS_COUNTER = "DroppedDueToLateness"; + + protected final Aggregator droppedDueToClosedWindow = + createAggregator(DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER, new Sum.SumLongFn()); + protected final Aggregator droppedDueToLateness = + createAggregator(DROPPED_DUE_TO_LATENESS_COUNTER, new Sum.SumLongFn()); + + /** + * Create the default {@link GroupAlsoByWindowsDoFn}, which uses window sets to implement the + * grouping. + * + * @param windowingStrategy The window function and trigger to use for grouping + * @param inputCoder the input coder to use + */ + public static GroupAlsoByWindowsDoFn, W> + createDefault(WindowingStrategy windowingStrategy, Coder inputCoder) { + return new GroupAlsoByWindowsViaOutputBufferDoFn<>( + windowingStrategy, SystemReduceFn.buffering(inputCoder)); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java new file mode 100644 index 000000000000..d394e81a0edf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.Iterables; + +import org.joda.time.Instant; + +import java.util.List; + +/** + * The default batch {@link GroupAlsoByWindowsDoFn} implementation, if no specialized "fast path" + * implementation is applicable. 
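A sketch of how the default factory above is typically parameterized (the fixed-window strategy and coder are illustrative choices, not prescribed by this change; FixedWindows, IntervalWindow, WindowingStrategy and VarLongCoder are assumed to come from the SDK, Duration from org.joda.time):

  WindowingStrategy<Object, IntervalWindow> strategy =
      WindowingStrategy.of(FixedWindows.of(Duration.standardMinutes(10)));

  GroupAlsoByWindowsDoFn<String, Long, Iterable<Long>, IntervalWindow> gabwFn =
      GroupAlsoByWindowsDoFn.<String, Long, IntervalWindow>createDefault(
          strategy, VarLongCoder.of());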
+ */ +@SystemDoFnInternal +public class GroupAlsoByWindowsViaOutputBufferDoFn + extends GroupAlsoByWindowsDoFn { + + private final WindowingStrategy strategy; + private SystemReduceFn reduceFn; + + public GroupAlsoByWindowsViaOutputBufferDoFn( + WindowingStrategy windowingStrategy, + SystemReduceFn reduceFn) { + this.strategy = windowingStrategy; + this.reduceFn = reduceFn; + } + + @Override + public void processElement( + DoFn>>, KV>.ProcessContext c) + throws Exception { + K key = c.element().getKey(); + // Used with Batch, we know that all the data is available for this key. We can't use the + // timer manager from the context because it doesn't exist. So we create one and emulate the + // watermark, knowing that we have all data and it is in timestamp order. + BatchTimerInternals timerInternals = new BatchTimerInternals(Instant.now()); + + // It is the responsibility of the user of GroupAlsoByWindowsViaOutputBufferDoFn to only + // provide a WindowingInternals instance with the appropriate key type for StateInternals. + @SuppressWarnings("unchecked") + StateInternals stateInternals = (StateInternals) c.windowingInternals().stateInternals(); + + ReduceFnRunner reduceFnRunner = + new ReduceFnRunner( + key, + strategy, + stateInternals, + timerInternals, + c.windowingInternals(), + droppedDueToClosedWindow, + reduceFn, + c.getPipelineOptions()); + + Iterable>> chunks = + Iterables.partition(c.element().getValue(), 1000); + for (Iterable> chunk : chunks) { + // Process the chunk of elements. + reduceFnRunner.processElements(chunk); + + // Then, since elements are sorted by their timestamp, advance the input watermark + // to the first element, and fire any timers that may have been scheduled. + timerInternals.advanceInputWatermark(reduceFnRunner, chunk.iterator().next().getTimestamp()); + + // Fire any processing timers that need to fire + timerInternals.advanceProcessingTime(reduceFnRunner, Instant.now()); + + // Leave the output watermark undefined. Since there's no late data in batch mode + // there's really no need to track it as we do for streaming. + } + + // Finish any pending windows by advancing the input watermark to infinity. + timerInternals.advanceInputWatermark(reduceFnRunner, BoundedWindow.TIMESTAMP_MAX_VALUE); + + // Finally, advance the processing time to infinity to fire any timers. + timerInternals.advanceProcessingTime(reduceFnRunner, BoundedWindow.TIMESTAMP_MAX_VALUE); + + reduceFnRunner.persist(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelFactory.java new file mode 100644 index 000000000000..f7d0b9a27e1e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelFactory.java @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
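// A hedged, self-contained illustration (not part of the original patch) of the chunking
// strategy used by GroupAlsoByWindowsViaOutputBufferDoFn above: timestamp-ordered values are
// processed in chunks of 1000 while an emulated input watermark is advanced to the first
// timestamp of each chunk. The ChunkedWatermarkSketch class is hypothetical.
import com.google.common.collect.Iterables;
import org.joda.time.Instant;

import java.util.List;

class ChunkedWatermarkSketch {
  static void process(Iterable<Instant> timestampOrderedElements) {
    Instant watermark = new Instant(Long.MIN_VALUE); // Stand-in for the initial watermark.
    for (List<Instant> chunk : Iterables.partition(timestampOrderedElements, 1000)) {
      // Safe because the input is known to be sorted by timestamp.
      Instant first = chunk.get(0);
      if (first.isAfter(watermark)) {
        watermark = first;
      }
      // ... process the chunk and fire any timers up to 'watermark' here ...
    }
  }
}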
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Collection; + +/** + * Defines a factory for working with read and write channels. + * + *
    Channels provide an abstract API for IO operations. + * + *
    See Java NIO Channels + */ +public interface IOChannelFactory { + + /** + * Matches a specification, which may contain globs, against available + * resources. + * + *
    Glob handling is dependent on the implementation. Implementations should + * all support globs in the final component of a path (eg /foo/bar/*.txt), + * however they are not required to support globs in the directory paths. + * + *
    The resources returned are required to exist and must not represent abstract + * resources such as symlinks and directories. + */ + Collection<String> match(String spec) throws IOException; + + /** + * Returns a read channel for the given specification. + * + *
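// A hedged usage sketch (not part of the original patch) for match(); it assumes the matched
// specs come back as strings, since the generic parameters were lost in this rendering of the
// diff. The MatchSketch class and the path are hypothetical.
import com.google.cloud.dataflow.sdk.util.IOChannelFactory;

import java.io.IOException;
import java.util.Collection;

class MatchSketch {
  static void listParts(IOChannelFactory factory) throws IOException {
    // Globs are only guaranteed to be supported in the final component of the path.
    Collection<String> parts = factory.match("/tmp/output/part-*.txt");
    for (String part : parts) {
      System.out.println(part);
    }
  }
}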
    The specification is not expanded; it is used verbatim. + * + *
    If seeking is supported, then this returns a + * {@link java.nio.channels.SeekableByteChannel}. + */ + ReadableByteChannel open(String spec) throws IOException; + + /** + * Returns a write channel for the given specification. + * + *
    The specification is not expanded; it is used verbatim. + */ + WritableByteChannel create(String spec, String mimeType) throws IOException; + + /** + * Returns the size in bytes for the given specification. + * + *
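// A hedged sketch (not part of the original patch) of writing through the channel returned by
// create(); plain java.nio usage. The CreateChannelSketch class, path, and MIME type are
// hypothetical.
import com.google.cloud.dataflow.sdk.util.IOChannelFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;

class CreateChannelSketch {
  static void writeGreeting(IOChannelFactory factory) throws IOException {
    try (WritableByteChannel channel = factory.create("/tmp/greeting.txt", "text/plain")) {
      channel.write(ByteBuffer.wrap("hello\n".getBytes(StandardCharsets.UTF_8)));
    }
  }
}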
    The specification is not expanded; it is used verbatim. + * + *
    {@link FileNotFoundException} will be thrown if the resource does not exist. + */ + long getSizeBytes(String spec) throws IOException; + + /** + * Returns {@code true} if the channel created when invoking method {@link #open} for the given + * file specification is guaranteed to be of type {@link java.nio.channels.SeekableByteChannel + * SeekableByteChannel} and if seeking into positions of the channel is recommended. Returns + * {@code false} if the channel returned is not a {@code SeekableByteChannel}. May return + * {@code false} even if the channel returned is a {@code SeekableByteChannel}, if seeking is not + * efficient for the given file specification. + * + *
    Only efficiently seekable files can be split into offset ranges. + * + *
    The specification is not expanded; it is used verbatim. + */ + boolean isReadSeekEfficient(String spec) throws IOException; + + /** + * Resolve the given {@code other} against the {@code path}. + * + *
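// A hedged sketch (not part of the original patch) showing how a caller might honor the
// isReadSeekEfficient() contract above: only downcast to SeekableByteChannel when seeking is
// reported as efficient, otherwise fall back to getSizeBytes(). The SeekSketch class is
// hypothetical.
import com.google.cloud.dataflow.sdk.util.IOChannelFactory;

import java.io.IOException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;

class SeekSketch {
  static long sizeOf(IOChannelFactory factory, String spec) throws IOException {
    if (factory.isReadSeekEfficient(spec)) {
      try (ReadableByteChannel channel = factory.open(spec)) {
        return ((SeekableByteChannel) channel).size();
      }
    }
    return factory.getSizeBytes(spec);
  }
}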
    If the {@code other} parameter is an absolute path then this method trivially returns + * other. If {@code other} is an empty path then this method trivially returns the given + * {@code path}. Otherwise this method considers the given {@code path} to be a directory and + * resolves the {@code other} path against this path. In the simplest case, the {@code other} + * path does not have a root component, in which case this method joins the {@code other} path + * to the given {@code path} and returns a resulting path that ends with the {@code other} path. + * Where the {@code other} path has a root component then resolution is highly implementation + * dependent and therefore unspecified. + */ + public String resolve(String path, String other) throws IOException; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelUtils.java new file mode 100644 index 000000000000..cbf420ec6bdd --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IOChannelUtils.java @@ -0,0 +1,204 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.channels.WritableByteChannel; +import java.text.DecimalFormat; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Provides utilities for creating read and write channels. + */ +public class IOChannelUtils { + // TODO: add registration mechanism for adding new schemas. + private static final Map FACTORY_MAP = + Collections.synchronizedMap(new HashMap()); + + // Pattern that matches shard placeholders within a shard template. + private static final Pattern SHARD_FORMAT_RE = Pattern.compile("(S+|N+)"); + + /** + * Associates a scheme with an {@link IOChannelFactory}. + * + *
    The given factory is used to construct read and write channels when + * a URI is provided with the given scheme. + * + *
    For example, when reading from "gs://bucket/path", the scheme "gs" is + * used to lookup the appropriate factory. + */ + public static void setIOFactory(String scheme, IOChannelFactory factory) { + FACTORY_MAP.put(scheme, factory); + } + + /** + * Registers standard factories globally. This requires {@link PipelineOptions} + * to provide, e.g., credentials for GCS. + */ + public static void registerStandardIOFactories(PipelineOptions options) { + setIOFactory("gs", new GcsIOChannelFactory(options.as(GcsOptions.class))); + } + + /** + * Creates a write channel for the given filename. + */ + public static WritableByteChannel create(String filename, String mimeType) + throws IOException { + return getFactory(filename).create(filename, mimeType); + } + + /** + * Creates a write channel for the given file components. + * + *
    If numShards is specified, then a ShardingWritableByteChannel is + * returned. + * + *
    Shard numbers are 0 based, meaning they start with 0 and end at the + * number of shards - 1. + */ + public static WritableByteChannel create(String prefix, String shardTemplate, + String suffix, int numShards, String mimeType) throws IOException { + if (numShards == 1) { + return create(constructName(prefix, shardTemplate, suffix, 0, 1), + mimeType); + } + + // It is the callers responsibility to close this channel. + @SuppressWarnings("resource") + ShardingWritableByteChannel shardingChannel = + new ShardingWritableByteChannel(); + + Set outputNames = new HashSet<>(); + for (int i = 0; i < numShards; i++) { + String outputName = + constructName(prefix, shardTemplate, suffix, i, numShards); + if (!outputNames.add(outputName)) { + throw new IllegalArgumentException( + "Shard name collision detected for: " + outputName); + } + WritableByteChannel channel = create(outputName, mimeType); + shardingChannel.addChannel(channel); + } + + return shardingChannel; + } + + /** + * Returns the size in bytes for the given specification. + * + *
    The specification is not expanded; it is used verbatim. + * + *
    {@link FileNotFoundException} will be thrown if the resource does not exist. + */ + public static long getSizeBytes(String spec) throws IOException { + return getFactory(spec).getSizeBytes(spec); + } + + /** + * Constructs a fully qualified name from components. + * + *
    The name is built from a prefix, shard template (with shard numbers + * applied), and a suffix. All components are required, but may be empty + * strings. + * + *
    Within a shard template, repeating sequences of the letters "S" or "N" + * are replaced with the shard number, or number of shards respectively. The + * numbers are formatted with leading zeros to match the length of the + * repeated sequence of letters. + * + *
    For example, if prefix = "output", shardTemplate = "-SSS-of-NNN", and + * suffix = ".txt", with shardNum = 1 and numShards = 100, the following is + * produced: "output-001-of-100.txt". + */ + public static String constructName(String prefix, + String shardTemplate, String suffix, int shardNum, int numShards) { + // Matcher API works with StringBuffer, rather than StringBuilder. + StringBuffer sb = new StringBuffer(); + sb.append(prefix); + + Matcher m = SHARD_FORMAT_RE.matcher(shardTemplate); + while (m.find()) { + boolean isShardNum = (m.group(1).charAt(0) == 'S'); + + char[] zeros = new char[m.end() - m.start()]; + Arrays.fill(zeros, '0'); + DecimalFormat df = new DecimalFormat(String.valueOf(zeros)); + String formatted = df.format(isShardNum + ? shardNum + : numShards); + m.appendReplacement(sb, formatted); + } + m.appendTail(sb); + + sb.append(suffix); + return sb.toString(); + } + + private static final Pattern URI_SCHEME_PATTERN = Pattern.compile( + "(?[a-zA-Z][-a-zA-Z0-9+.]*)://.*"); + + /** + * Returns the IOChannelFactory associated with an input specification. + */ + public static IOChannelFactory getFactory(String spec) throws IOException { + // The spec is almost, but not quite, a URI. In particular, + // the reserved characters '[', ']', and '?' have meanings that differ + // from their use in the URI spec. ('*' is not reserved). + // Here, we just need the scheme, which is so circumscribed as to be + // very easy to extract with a regex. + Matcher matcher = URI_SCHEME_PATTERN.matcher(spec); + + if (!matcher.matches()) { + return new FileIOChannelFactory(); + } + + String scheme = matcher.group("scheme"); + IOChannelFactory ioFactory = FACTORY_MAP.get(scheme); + if (ioFactory != null) { + return ioFactory; + } + + throw new IOException("Unable to find handler for " + spec); + } + + /** + * Resolve the given {@code other} against the {@code path}. + * + *
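// A small illustration (not part of the original patch) of the shard-template expansion that
// constructName() documents above; the expected value is taken from that Javadoc. The
// ShardNameSketch class is hypothetical.
import com.google.cloud.dataflow.sdk.util.IOChannelUtils;

class ShardNameSketch {
  public static void main(String[] args) {
    // Runs of "S" expand to the zero-padded shard number, runs of "N" to the shard count.
    String name = IOChannelUtils.constructName("output", "-SSS-of-NNN", ".txt", 1, 100);
    System.out.println(name); // Per the Javadoc above: output-001-of-100.txt
  }
}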
    If the {@code other} parameter is an absolute path then this method trivially returns + * other. If {@code other} is an empty path then this method trivially returns the given + * {@code path}. Otherwise this method considers the given {@code path} to be a directory and + * resolves the {@code other} path against this path. In the simplest case, the {@code other} + * path does not have a root component, in which case this method joins the {@code other} path + * to the given {@code path} and returns a resulting path that ends with the {@code other} path. + * Where the {@code other} path has a root component then resolution is highly implementation + * dependent and therefore unspecified. + */ + public static String resolve(String path, String other) throws IOException { + return getFactory(path).resolve(path, other); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IllegalMutationException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IllegalMutationException.java new file mode 100644 index 000000000000..dbe249eeab5b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IllegalMutationException.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * Thrown when a value appears to have been mutated, but that mutation is forbidden. + */ +public class IllegalMutationException extends RuntimeException { + private Object savedValue; + private Object newValue; + + public IllegalMutationException(String message, Object savedValue, Object newValue) { + super(message); + this.savedValue = savedValue; + this.newValue = newValue; + } + + public IllegalMutationException( + String message, Object savedValue, Object newValue, Throwable cause) { + super(message, cause); + this.savedValue = savedValue; + this.newValue = newValue; + } + + /** + * The original value, before the illegal mutation. + */ + public Object getSavedValue() { + return savedValue; + } + + /** + * The value after the illegal mutation. + */ + public Object getNewValue() { + return newValue; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/InstanceBuilder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/InstanceBuilder.java new file mode 100644 index 000000000000..99442d045365 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/InstanceBuilder.java @@ -0,0 +1,269 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
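// A hedged sketch (not part of the original patch) of how a caller might surface a forbidden
// in-place mutation with the IllegalMutationException type added above. Real enforcement would
// typically compare encoded forms rather than rely on equals(); the MutationCheckSketch class
// is hypothetical.
import com.google.cloud.dataflow.sdk.util.IllegalMutationException;

import java.util.Objects;

class MutationCheckSketch {
  static void checkUnchanged(Object savedValue, Object currentValue) {
    if (!Objects.equals(savedValue, currentValue)) {
      throw new IllegalMutationException(
          "Value was mutated after being output: " + savedValue + " became " + currentValue,
          savedValue,
          currentValue);
    }
  }
}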
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.LinkedList; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Utility for creating objects dynamically. + * + * @param type type of object returned by this instance builder + */ +public class InstanceBuilder { + + /** + * Create an InstanceBuilder for the given type. + * + *
    The specified type is the type returned by {@link #build}, which is + * typically the common base type or interface of the instance being + * constructed. + */ + public static InstanceBuilder ofType(Class type) { + return new InstanceBuilder<>(type); + } + + /** + * Create an InstanceBuilder for the given type. + * + *
    The specified type is the type returned by {@link #build}, which is + * typically the common base type or interface for the instance to be + * constructed. + * + *
    The TypeDescriptor argument allows specification of generic types. For example, + * a {@code List} return type can be specified as + * {@code ofType(new TypeDescriptor>(){})}. + */ + public static InstanceBuilder ofType(TypeDescriptor token) { + @SuppressWarnings("unchecked") + Class type = (Class) token.getRawType(); + return new InstanceBuilder<>(type); + } + + /** + * Sets the class name to be constructed. + * + *
    If the name is a simple name (ie {@link Class#getSimpleName()}), then + * the package of the return type is added as a prefix. + * + *
    The default class is the return type, specified in {@link #ofType}. + * + *
    Modifies and returns the {@code InstanceBuilder} for chaining. + * + * @throws ClassNotFoundException if no class can be found by the given name + */ + public InstanceBuilder fromClassName(String name) + throws ClassNotFoundException { + Preconditions.checkArgument(factoryClass == null, + "Class name may only be specified once"); + if (name.indexOf('.') == -1) { + name = type.getPackage().getName() + "." + name; + } + + try { + factoryClass = Class.forName(name); + } catch (ClassNotFoundException e) { + throw new ClassNotFoundException( + String.format("Could not find class: %s", name), e); + } + return this; + } + + /** + * Sets the factory class to use for instance construction. + * + *
    Modifies and returns the {@code InstanceBuilder} for chaining. + */ + public InstanceBuilder fromClass(Class factoryClass) { + this.factoryClass = factoryClass; + return this; + } + + /** + * Sets the name of the factory method used to construct the instance. + * + *
    The default, if no factory method was specified, is to look for a class + * constructor. + * + *
    Modifies and returns the {@code InstanceBuilder} for chaining. + */ + public InstanceBuilder fromFactoryMethod(String methodName) { + Preconditions.checkArgument(this.methodName == null, + "Factory method name may only be specified once"); + this.methodName = methodName; + return this; + } + + /** + * Adds an argument to be passed to the factory method. + * + *
    The argument type is used to look up the factory method. This type may be + * a supertype of the argument value's class. + * + *
    Modifies and returns the {@code InstanceBuilder} for chaining. + * + * @param the argument type + */ + public InstanceBuilder withArg(Class argType, ArgT value) { + parameterTypes.add(argType); + arguments.add(value); + return this; + } + + /** + * Creates the instance by calling the factory method with the given + * arguments. + * + *
    Defaults:
 + * • factory class: defaults to the output type class, overridden
 + *   via {@link #fromClassName(String)}.
 + * • factory method: defaults to using a constructor on the factory
 + *   class, overridden via {@link #fromFactoryMethod(String)}.
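// A hedged usage sketch (not part of the original patch) for the builder described above,
// using only JDK types: String.valueOf(Object) is a static factory whose return type (String)
// is assignable to the requested CharSequence. The InstanceBuilderSketch class is hypothetical.
import com.google.cloud.dataflow.sdk.util.InstanceBuilder;

class InstanceBuilderSketch {
  static CharSequence fortyTwo() {
    return InstanceBuilder.ofType(CharSequence.class)
        .fromClass(String.class)            // factory class
        .fromFactoryMethod("valueOf")       // static factory method on String
        .withArg(Object.class, 42)          // looked up as valueOf(Object)
        .build();                           // yields the CharSequence "42"
  }
}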
    + * + * @throws RuntimeException if the method does not exist, on type mismatch, + * or if the method cannot be made accessible. + */ + public T build() { + if (factoryClass == null) { + factoryClass = type; + } + + Class[] types = parameterTypes + .toArray(new Class[parameterTypes.size()]); + + // TODO: cache results, to speed repeated type lookups? + if (methodName != null) { + return buildFromMethod(types); + } else { + return buildFromConstructor(types); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Type of object to construct. + */ + private final Class type; + + /** + * Types of parameters for Method lookup. + * + * @see Class#getDeclaredMethod(String, Class[]) + */ + private final List> parameterTypes = new LinkedList<>(); + + /** + * Arguments to factory method {@link Method#invoke(Object, Object...)}. + */ + private final List arguments = new LinkedList<>(); + + /** + * Name of factory method, or null to invoke the constructor. + */ + @Nullable private String methodName; + + /** + * Factory class, or null to instantiate {@code type}. + */ + @Nullable private Class factoryClass; + + private InstanceBuilder(Class type) { + this.type = type; + } + + private T buildFromMethod(Class[] types) { + Preconditions.checkState(factoryClass != null); + Preconditions.checkState(methodName != null); + + try { + Method method = factoryClass.getDeclaredMethod(methodName, types); + + Preconditions.checkState(Modifier.isStatic(method.getModifiers()), + "Factory method must be a static method for " + + factoryClass.getName() + "#" + method.getName() + ); + + Preconditions.checkState(type.isAssignableFrom(method.getReturnType()), + "Return type for " + factoryClass.getName() + "#" + method.getName() + + " must be assignable to " + type.getSimpleName()); + + if (!method.isAccessible()) { + method.setAccessible(true); + } + + Object[] args = arguments.toArray(new Object[arguments.size()]); + return type.cast(method.invoke(null, args)); + + } catch (NoSuchMethodException e) { + throw new RuntimeException( + String.format("Unable to find factory method %s#%s(%s)", + factoryClass.getSimpleName(), + methodName, + Joiner.on(", ").join(types))); + + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException( + String.format("Failed to construct instance from factory method %s#%s(%s)", + factoryClass.getSimpleName(), + methodName, + Joiner.on(", ").join(types)), + e); + } + } + + private T buildFromConstructor(Class[] types) { + Preconditions.checkState(factoryClass != null); + + try { + Constructor constructor = factoryClass.getDeclaredConstructor(types); + + Preconditions.checkState(type.isAssignableFrom(factoryClass), + "Instance type " + factoryClass.getName() + + " must be assignable to " + type.getSimpleName()); + + if (!constructor.isAccessible()) { + constructor.setAccessible(true); + } + + Object[] args = arguments.toArray(new Object[arguments.size()]); + return type.cast(constructor.newInstance(args)); + + } catch (NoSuchMethodException e) { + throw new RuntimeException("Unable to find constructor for " + + factoryClass.getName()); + + } catch (InvocationTargetException | + InstantiationException | + IllegalAccessException e) { + throw new RuntimeException("Failed to construct instance from " + + "constructor " + factoryClass.getName(), e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IntervalBoundedExponentialBackOff.java 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IntervalBoundedExponentialBackOff.java new file mode 100644 index 000000000000..4406ee5c52b4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/IntervalBoundedExponentialBackOff.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOff; +import com.google.common.base.Preconditions; + +/** + * Implementation of {@link BackOff} that increases the back off period for each retry attempt + * using a randomization function that grows exponentially. + * + *
    Example: The initial interval is .5 seconds and the maximum interval is 60 secs. + * For 14 tries the sequence will be (values in seconds): + * + *
    + * <pre>
    + * retry#      retry_interval     randomized_interval
    + * 1             0.5                [0.25,   0.75]
    + * 2             0.75               [0.375,  1.125]
    + * 3             1.125              [0.562,  1.687]
    + * 4             1.687              [0.8435, 2.53]
    + * 5             2.53               [1.265,  3.795]
    + * 6             3.795              [1.897,  5.692]
    + * 7             5.692              [2.846,  8.538]
    + * 8             8.538              [4.269, 12.807]
    + * 9            12.807              [6.403, 19.210]
    + * 10           28.832              [14.416, 43.248]
    + * 11           43.248              [21.624, 64.873]
    + * 12           60.0                [30.0, 90.0]
    + * 13           60.0                [30.0, 90.0]
    + * 14           60.0                [30.0, 90.0]
    + * </pre>
    + *
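// A hedged retry-loop sketch (not part of the original patch) for the class below: a 500 ms
// initial interval capped at 60 seconds, matching the table above. Giving up once atMaxInterval()
// returns true is a choice of this sketch, not of the class itself; BackOffSketch is hypothetical.
import com.google.cloud.dataflow.sdk.util.IntervalBoundedExponentialBackOff;

class BackOffSketch {
  static boolean runWithRetries(Runnable attempt) throws InterruptedException {
    IntervalBoundedExponentialBackOff backOff =
        new IntervalBoundedExponentialBackOff(60000, 500L);
    while (true) {
      try {
        attempt.run();
        return true;
      } catch (RuntimeException e) {
        if (backOff.atMaxInterval()) {
          return false;
        }
        Thread.sleep(backOff.nextBackOffMillis());
      }
    }
  }
}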
    Implementation is not thread-safe. + */ +public class IntervalBoundedExponentialBackOff implements BackOff { + public static final double DEFAULT_MULTIPLIER = 1.5; + public static final double DEFAULT_RANDOMIZATION_FACTOR = 0.5; + private final long maximumIntervalMillis; + private final long initialIntervalMillis; + private int currentAttempt; + + public IntervalBoundedExponentialBackOff(int maximumIntervalMillis, long initialIntervalMillis) { + Preconditions.checkArgument( + maximumIntervalMillis > 0, "Maximum interval must be greater than zero."); + Preconditions.checkArgument( + initialIntervalMillis > 0, "Initial interval must be greater than zero."); + this.maximumIntervalMillis = maximumIntervalMillis; + this.initialIntervalMillis = initialIntervalMillis; + reset(); + } + + @Override + public void reset() { + currentAttempt = 1; + } + + @Override + public long nextBackOffMillis() { + double currentIntervalMillis = + Math.min( + initialIntervalMillis * Math.pow(DEFAULT_MULTIPLIER, currentAttempt - 1), + maximumIntervalMillis); + double randomOffset = + (Math.random() * 2 - 1) * DEFAULT_RANDOMIZATION_FACTOR * currentIntervalMillis; + currentAttempt += 1; + return Math.round(currentIntervalMillis + randomOffset); + } + + public boolean atMaxInterval() { + return initialIntervalMillis * Math.pow(DEFAULT_MULTIPLIER, currentAttempt - 1) + >= maximumIntervalMillis; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItem.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItem.java new file mode 100644 index 000000000000..355f0bbc476e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItem.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; + +/** + * Interface that contains all the timers and elements associated with a specific work item. + * + * @param the key type + * @param the element type + */ +public interface KeyedWorkItem { + /** + * Returns the key. + */ + K key(); + + /** + * Returns an iterable containing the timers. + */ + Iterable timersIterable(); + + /** + * Returns an iterable containing the elements. + */ + Iterable> elementsIterable(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItemCoder.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItemCoder.java new file mode 100644 index 000000000000..398e82a8d688 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItemCoder.java @@ -0,0 +1,120 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerDataCoder; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.common.collect.ImmutableList; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * A {@link Coder} for {@link KeyedWorkItem KeyedWorkItems}. + */ +public class KeyedWorkItemCoder extends StandardCoder> { + /** + * Create a new {@link KeyedWorkItemCoder} with the provided key coder, element coder, and window + * coder. + */ + public static KeyedWorkItemCoder of( + Coder keyCoder, Coder elemCoder, Coder windowCoder) { + return new KeyedWorkItemCoder<>(keyCoder, elemCoder, windowCoder); + } + + @JsonCreator + public static KeyedWorkItemCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) List> components) { + checkArgument(components.size() == 3, "Expecting 3 components, got %s", components.size()); + @SuppressWarnings("unchecked") + Coder keyCoder = (Coder) components.get(0); + @SuppressWarnings("unchecked") + Coder elemCoder = (Coder) components.get(1); + @SuppressWarnings("unchecked") + Coder windowCoder = (Coder) components.get(2); + return new KeyedWorkItemCoder<>(keyCoder, elemCoder, windowCoder); + } + + private final Coder keyCoder; + private final Coder elemCoder; + private final Coder windowCoder; + private final Coder> timersCoder; + private final Coder>> elemsCoder; + + private KeyedWorkItemCoder( + Coder keyCoder, Coder elemCoder, Coder windowCoder) { + this.keyCoder = keyCoder; + this.elemCoder = elemCoder; + this.windowCoder = windowCoder; + this.timersCoder = IterableCoder.of(TimerDataCoder.of(windowCoder)); + this.elemsCoder = IterableCoder.of(FullWindowedValueCoder.of(elemCoder, windowCoder)); + } + + @Override + public void encode(KeyedWorkItem value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + keyCoder.encode(value.key(), outStream, nestedContext); + timersCoder.encode(value.timersIterable(), outStream, nestedContext); + elemsCoder.encode(value.elementsIterable(), outStream, nestedContext); + } + + @Override + public KeyedWorkItem decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + Coder.Context nestedContext = context.nested(); + K key = keyCoder.decode(inStream, nestedContext); + Iterable timers = timersCoder.decode(inStream, nestedContext); + Iterable> elems = 
elemsCoder.decode(inStream, nestedContext); + return KeyedWorkItems.workItem(key, timers, elems); + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(keyCoder, elemCoder, windowCoder); + } + + @Override + public void verifyDeterministic() throws Coder.NonDeterministicException { + keyCoder.verifyDeterministic(); + timersCoder.verifyDeterministic(); + elemsCoder.verifyDeterministic(); + } + + /** + * {@inheritDoc}. + * + * {@link KeyedWorkItemCoder} is not consistent with equals as it can return a + * {@link KeyedWorkItem} of a type different from the originally encoded type. + */ + @Override + public boolean consistentWithEquals() { + return false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItems.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItems.java new file mode 100644 index 000000000000..734bd2c537e2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItems.java @@ -0,0 +1,120 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.common.base.MoreObjects; +import com.google.common.collect.Iterables; + +import java.util.Collections; +import java.util.Objects; + +/** + * Static utility methods that provide {@link KeyedWorkItem} implementations. + */ +public class KeyedWorkItems { + /** + * Returns an implementation of {@link KeyedWorkItem} that wraps around an elements iterable. + * + * @param the key type + * @param the element type + */ + public static KeyedWorkItem elementsWorkItem( + K key, Iterable> elementsIterable) { + return new ComposedKeyedWorkItem<>(key, Collections.emptyList(), elementsIterable); + } + + /** + * Returns an implementation of {@link KeyedWorkItem} that wraps around an timers iterable. + * + * @param the key type + * @param the element type + */ + public static KeyedWorkItem timersWorkItem( + K key, Iterable timersIterable) { + return new ComposedKeyedWorkItem<>( + key, timersIterable, Collections.>emptyList()); + } + + /** + * Returns an implementation of {@link KeyedWorkItem} that wraps around + * an timers iterable and an elements iterable. + * + * @param the key type + * @param the element type + */ + public static KeyedWorkItem workItem( + K key, Iterable timersIterable, Iterable> elementsIterable) { + return new ComposedKeyedWorkItem<>(key, timersIterable, elementsIterable); + } + + /** + * A {@link KeyedWorkItem} composed of an underlying key, {@link TimerData} iterable, and element + * iterable. 
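// A hedged round-trip sketch (not part of the original patch) for KeyedWorkItemCoder above.
// StringUtf8Coder, VarLongCoder, GlobalWindow.Coder.INSTANCE, WindowedValue.valueInGlobalWindow
// and Coder.Context.OUTER are assumed to be the SDK classes/members of those names; the
// KeyedWorkItemCoderSketch class is hypothetical.
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.coders.VarLongCoder;
import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow;
import com.google.cloud.dataflow.sdk.util.KeyedWorkItem;
import com.google.cloud.dataflow.sdk.util.KeyedWorkItemCoder;
import com.google.cloud.dataflow.sdk.util.KeyedWorkItems;
import com.google.cloud.dataflow.sdk.util.WindowedValue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;

class KeyedWorkItemCoderSketch {
  static KeyedWorkItem<String, Long> roundTrip() throws Exception {
    KeyedWorkItemCoder<String, Long> coder =
        KeyedWorkItemCoder.of(StringUtf8Coder.of(), VarLongCoder.of(), GlobalWindow.Coder.INSTANCE);
    KeyedWorkItem<String, Long> item =
        KeyedWorkItems.elementsWorkItem(
            "key",
            Arrays.asList(
                WindowedValue.valueInGlobalWindow(1L), WindowedValue.valueInGlobalWindow(2L)));
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    coder.encode(item, out, Coder.Context.OUTER);
    return coder.decode(new ByteArrayInputStream(out.toByteArray()), Coder.Context.OUTER);
  }
}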
+ */ + public static class ComposedKeyedWorkItem implements KeyedWorkItem { + private final K key; + private final Iterable timers; + private final Iterable> elements; + + private ComposedKeyedWorkItem( + K key, Iterable timers, Iterable> elements) { + this.key = key; + this.timers = timers; + this.elements = elements; + } + + @Override + public K key() { + return key; + } + + @Override + public Iterable timersIterable() { + return timers; + } + + @Override + public Iterable> elementsIterable() { + return elements; + } + + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof ComposedKeyedWorkItem)) { + return false; + } + KeyedWorkItem that = (KeyedWorkItem) other; + return Objects.equals(this.key, that.key()) + && Iterables.elementsEqual(this.timersIterable(), that.timersIterable()) + && Iterables.elementsEqual(this.elementsIterable(), that.elementsIterable()); + } + + @Override + public int hashCode() { + return Objects.hash(key, timers, elements); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(ComposedKeyedWorkItem.class) + .add("key", key) + .add("elements", elements) + .add("timers", timers) + .toString(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/LateDataDroppingDoFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/LateDataDroppingDoFnRunner.java new file mode 100644 index 000000000000..31927ab8823b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/LateDataDroppingDoFnRunner.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Iterables; + +import org.joda.time.Instant; + +/** + * A customized {@link DoFnRunner} that handles late data dropping for + * a {@link KeyedWorkItem} input {@link DoFn}. + * + *
    It expands windows before checking data lateness. + * + *
    {@link KeyedWorkItem KeyedWorkItems} are always in empty windows. + * + * @param key type + * @param input value element type + * @param output value element type + * @param window type + */ +public class LateDataDroppingDoFnRunner + implements DoFnRunner, KV> { + private final DoFnRunner, KV> doFnRunner; + private final LateDataFilter lateDataFilter; + + public LateDataDroppingDoFnRunner( + DoFnRunner, KV> doFnRunner, + WindowingStrategy windowingStrategy, + TimerInternals timerInternals, + Aggregator droppedDueToLateness) { + this.doFnRunner = doFnRunner; + lateDataFilter = new LateDataFilter(windowingStrategy, timerInternals, droppedDueToLateness); + } + + @Override + public void startBundle() { + doFnRunner.startBundle(); + } + + @Override + public void processElement(WindowedValue> elem) { + Iterable> nonLateElements = lateDataFilter.filter( + elem.getValue().key(), elem.getValue().elementsIterable()); + KeyedWorkItem keyedWorkItem = KeyedWorkItems.workItem( + elem.getValue().key(), elem.getValue().timersIterable(), nonLateElements); + doFnRunner.processElement(elem.withValue(keyedWorkItem)); + } + + @Override + public void finishBundle() { + doFnRunner.finishBundle(); + } + + /** + * It filters late data in a {@link KeyedWorkItem}. + */ + @VisibleForTesting + static class LateDataFilter { + private final WindowingStrategy windowingStrategy; + private final TimerInternals timerInternals; + private final Aggregator droppedDueToLateness; + + public LateDataFilter( + WindowingStrategy windowingStrategy, + TimerInternals timerInternals, + Aggregator droppedDueToLateness) { + this.windowingStrategy = windowingStrategy; + this.timerInternals = timerInternals; + this.droppedDueToLateness = droppedDueToLateness; + } + + /** + * Returns an {@code Iterable>} that only contains + * non-late input elements. + */ + public Iterable> filter( + final K key, Iterable> elements) { + Iterable>> windowsExpandedElements = Iterables.transform( + elements, + new Function, Iterable>>() { + @Override + public Iterable> apply(final WindowedValue input) { + return Iterables.transform( + input.getWindows(), + new Function>() { + @Override + public WindowedValue apply(BoundedWindow window) { + return WindowedValue.of( + input.getValue(), input.getTimestamp(), window, input.getPane()); + } + }); + }}); + + Iterable> nonLateElements = Iterables.filter( + Iterables.concat(windowsExpandedElements), + new Predicate>() { + @Override + public boolean apply(WindowedValue input) { + BoundedWindow window = Iterables.getOnlyElement(input.getWindows()); + if (canDropDueToExpiredWindow(window)) { + // The element is too late for this window. + droppedDueToLateness.addValue(1L); + WindowTracing.debug( + "ReduceFnRunner.processElement: Dropping element at {} for key:{}; window:{} " + + "since too far behind inputWatermark:{}; outputWatermark:{}", + input.getTimestamp(), key, window, timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + return false; + } else { + return true; + } + } + }); + return nonLateElements; + } + + /** Is {@code window} expired w.r.t. the garbage collection watermark? 
*/ + private boolean canDropDueToExpiredWindow(BoundedWindow window) { + Instant inputWM = timerInternals.currentInputWatermarkTime(); + return inputWM != null + && window.maxTimestamp().plus(windowingStrategy.getAllowedLateness()).isBefore(inputWM); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MapAggregatorValues.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MapAggregatorValues.java new file mode 100644 index 000000000000..a4d8ffd168e5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MapAggregatorValues.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.common.base.MoreObjects; + +import java.util.Map; + +/** + * An {@link AggregatorValues} implementation that is backed by an in-memory map. + * + * @param the output type of the {@link Aggregator} + */ +public class MapAggregatorValues extends AggregatorValues { + private final Map stepValues; + + public MapAggregatorValues(Map stepValues) { + this.stepValues = stepValues; + } + + @Override + public Map getValuesAtSteps() { + return stepValues; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(MapAggregatorValues.class) + .add("stepValues", stepValues) + .toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MergingActiveWindowSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MergingActiveWindowSet.java new file mode 100644 index 000000000000..95e378d9f4c7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MergingActiveWindowSet.java @@ -0,0 +1,543 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
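// A self-contained restatement (not part of the original patch) of the expiry test used by
// canDropDueToExpiredWindow() above, in plain Joda-Time terms: a window can be dropped once the
// input watermark passes its max timestamp plus the allowed lateness. LatenessSketch is
// hypothetical.
import org.joda.time.Duration;
import org.joda.time.Instant;

class LatenessSketch {
  static boolean isExpired(Instant windowMaxTimestamp, Duration allowedLateness, Instant inputWatermark) {
    return inputWatermark != null
        && windowMaxTimestamp.plus(allowedLateness).isBefore(inputWatermark);
  }

  public static void main(String[] args) {
    Instant windowEnd = Instant.parse("2016-01-01T00:01:00Z");
    Instant watermark = Instant.parse("2016-01-01T00:03:00Z");
    System.out.println(isExpired(windowEnd, Duration.standardMinutes(1), watermark)); // true
  }
}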
+ */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.SetCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.util.state.ValueState; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * An {@link ActiveWindowSet} for merging {@link WindowFn} implementations. + * + *
    The underlying notion of {@link MergingActiveWindowSet} is that of representing equivalence + * classes of merged windows as a mapping from the merged "super-window" to a set of + * state address windows in which some state has been persisted. The mapping need not + * contain EPHEMERAL windows, because they are created and merged without any persistent state. + * Each window must be a state address window for at most one window, so the mapping is + * invertible. + * + *
    The states of a non-expired window are treated as follows: + * + *
 + * • NEW: a NEW window has an empty set of associated state address windows.
 + * • ACTIVE: an ACTIVE window will be associated with some nonempty set of state
 + *   address windows. If the window has not merged, this will necessarily be the singleton set
 + *   containing just itself, but it is not required that an ACTIVE window be amongst its
 + *   state address windows.
 + * • MERGED: a MERGED window will be in the set of associated windows for some
 + *   other window - that window is retrieved via {@link #representative} (this reverse
 + *   association is implemented in O(1) time).
 + * • EPHEMERAL: EPHEMERAL windows are not persisted but are tracked transiently;
 + *   an EPHEMERAL window must be registered with this {@link ActiveWindowSet} by a call
 + *   to {@link #recordMerge} prior to any request for a {@link #representative}.
 + *
    To illustrate why an ACTIVE window need not be amongst its own state address windows, + * consider two active windows W1 and W2 that are merged to form W12. Further writes may be + * applied to either of W1 or W2, since a read of W12 implies reading both of W12 and merging + * their results. Hence W12 need not have state directly associated with it. + */ +public class MergingActiveWindowSet implements ActiveWindowSet { + private final WindowFn windowFn; + + @Nullable + private Map> activeWindowToStateAddressWindows; + + /** + * As above, but only for EPHEMERAL windows. Does not need to be persisted. + */ + private final Map> activeWindowToEphemeralWindows; + + /** + * A map from window to the ACTIVE window it has been merged into. + * + *
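// A plain-Java illustration (not part of the original patch) of the W1/W2 -> W12 example above:
// after the merge both original windows resolve to the same ACTIVE representative, while state
// remains addressed to the pre-merge windows. MergeMappingSketch is hypothetical and uses
// strings in place of real window objects.
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class MergeMappingSketch {
  public static void main(String[] args) {
    Map<String, Set<String>> activeToStateAddressWindows = new HashMap<>();
    Map<String, String> windowToActiveWindow = new HashMap<>();

    // W1 and W2 start out ACTIVE, each holding its own state.
    activeToStateAddressWindows.put("W1", new HashSet<>(Arrays.asList("W1")));
    activeToStateAddressWindows.put("W2", new HashSet<>(Arrays.asList("W2")));

    // Merging them into W12 re-points their state address windows at the new representative.
    Set<String> merged = new HashSet<>();
    merged.addAll(activeToStateAddressWindows.remove("W1"));
    merged.addAll(activeToStateAddressWindows.remove("W2"));
    activeToStateAddressWindows.put("W12", merged);
    for (String w : merged) {
      windowToActiveWindow.put(w, "W12");
    }
    windowToActiveWindow.put("W12", "W12");

    System.out.println(activeToStateAddressWindows); // e.g. {W12=[W1, W2]}
    System.out.println(windowToActiveWindow);        // e.g. {W1=W12, W2=W12, W12=W12}
  }
}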
    Does not need to be persisted. + * + *
 + * • Key window may be ACTIVE, MERGED or EPHEMERAL.
 + * • ACTIVE windows map to themselves.
 + * • If W1 maps to W2 then W2 is in {@link #activeWindowToStateAddressWindows}.
 + * • If W1 = W2 then W1 is ACTIVE. If W1 is in the state address window set for W2 then W1 is
 + *   MERGED. Otherwise W1 is EPHEMERAL.
    + */ + @Nullable + private Map windowToActiveWindow; + + /** + * Deep clone of {@link #activeWindowToStateAddressWindows} as of last commit. + * + *
    Used to avoid writing to state if no changes have been made during the work unit. + */ + @Nullable + private Map> originalActiveWindowToStateAddressWindows; + + /** + * Handle representing our state in the backend. + */ + private final ValueState>> valueState; + + public MergingActiveWindowSet(WindowFn windowFn, StateInternals state) { + this.windowFn = windowFn; + + StateTag>>> mergeTreeAddr = + StateTags.makeSystemTagInternal(StateTags.value( + "tree", MapCoder.of(windowFn.windowCoder(), SetCoder.of(windowFn.windowCoder())))); + valueState = state.state(StateNamespaces.global(), mergeTreeAddr); + // Little use trying to prefetch this state since the ReduceFnRunner is stymied until it is + // available. + activeWindowToStateAddressWindows = emptyIfNull(valueState.read()); + activeWindowToEphemeralWindows = new HashMap<>(); + originalActiveWindowToStateAddressWindows = deepCopy(activeWindowToStateAddressWindows); + windowToActiveWindow = invert(activeWindowToStateAddressWindows); + } + + @Override + public void removeEphemeralWindows() { + for (Map.Entry> entry : activeWindowToEphemeralWindows.entrySet()) { + for (W ephemeral : entry.getValue()) { + windowToActiveWindow.remove(ephemeral); + } + } + activeWindowToEphemeralWindows.clear(); + } + + @Override + public void persist() { + if (activeWindowToStateAddressWindows.isEmpty()) { + // Force all persistent state to disappear. + valueState.clear(); + return; + } + if (activeWindowToStateAddressWindows.equals(originalActiveWindowToStateAddressWindows)) { + // No change. + return; + } + // All NEW windows must have been accounted for. + for (Map.Entry> entry : activeWindowToStateAddressWindows.entrySet()) { + Preconditions.checkState( + !entry.getValue().isEmpty(), "Cannot persist NEW window %s", entry.getKey()); + } + // Should be no EPHEMERAL windows. + Preconditions.checkState( + activeWindowToEphemeralWindows.isEmpty(), "Unexpected EPHEMERAL windows before persist"); + + valueState.write(activeWindowToStateAddressWindows); + // No need to update originalActiveWindowToStateAddressWindows since this object is about to + // become garbage. 
+ } + + @Override + @Nullable + public W representative(W window) { + return windowToActiveWindow.get(window); + } + + @Override + public Set getActiveWindows() { + return activeWindowToStateAddressWindows.keySet(); + } + + @Override + public boolean isActive(W window) { + return activeWindowToStateAddressWindows.containsKey(window); + } + + @Override + public void addNew(W window) { + if (!windowToActiveWindow.containsKey(window)) { + activeWindowToStateAddressWindows.put(window, new LinkedHashSet()); + } + } + + @Override + public void addActive(W window) { + if (!windowToActiveWindow.containsKey(window)) { + Set stateAddressWindows = new LinkedHashSet<>(); + stateAddressWindows.add(window); + activeWindowToStateAddressWindows.put(window, stateAddressWindows); + windowToActiveWindow.put(window, window); + } + } + + @Override + public void remove(W window) { + for (W stateAddressWindow : activeWindowToStateAddressWindows.get(window)) { + windowToActiveWindow.remove(stateAddressWindow); + } + activeWindowToStateAddressWindows.remove(window); + Set ephemeralWindows = activeWindowToEphemeralWindows.get(window); + if (ephemeralWindows != null) { + for (W ephemeralWindow : ephemeralWindows) { + windowToActiveWindow.remove(ephemeralWindow); + } + activeWindowToEphemeralWindows.remove(window); + } + windowToActiveWindow.remove(window); + } + + private class MergeContextImpl extends WindowFn.MergeContext { + private MergeCallback mergeCallback; + private final List> allToBeMerged; + private final List> allActiveToBeMerged; + private final List allMergeResults; + private final Set seen; + + public MergeContextImpl(MergeCallback mergeCallback) { + windowFn.super(); + this.mergeCallback = mergeCallback; + allToBeMerged = new ArrayList<>(); + allActiveToBeMerged = new ArrayList<>(); + allMergeResults = new ArrayList<>(); + seen = new HashSet<>(); + } + + @Override + public Collection windows() { + return activeWindowToStateAddressWindows.keySet(); + } + + @Override + public void merge(Collection toBeMerged, W mergeResult) throws Exception { + // The arguments have come from userland. 
+ Preconditions.checkNotNull(toBeMerged); + Preconditions.checkNotNull(mergeResult); + List copyOfToBeMerged = new ArrayList<>(toBeMerged.size()); + List activeToBeMerged = new ArrayList<>(toBeMerged.size()); + boolean includesMergeResult = false; + for (W window : toBeMerged) { + Preconditions.checkNotNull(window); + Preconditions.checkState( + isActive(window), "Expecting merge window %s to be active", window); + if (window.equals(mergeResult)) { + includesMergeResult = true; + } + boolean notDup = seen.add(window); + Preconditions.checkState( + notDup, "Expecting merge window %s to appear in at most one merge set", window); + copyOfToBeMerged.add(window); + if (!activeWindowToStateAddressWindows.get(window).isEmpty()) { + activeToBeMerged.add(window); + } + } + if (!includesMergeResult) { + Preconditions.checkState( + !isActive(mergeResult), "Expecting result window %s to be new", mergeResult); + } + allToBeMerged.add(copyOfToBeMerged); + allActiveToBeMerged.add(activeToBeMerged); + allMergeResults.add(mergeResult); + } + + public void recordMerges() throws Exception { + for (int i = 0; i < allToBeMerged.size(); i++) { + mergeCallback.prefetchOnMerge( + allToBeMerged.get(i), allActiveToBeMerged.get(i), allMergeResults.get(i)); + } + for (int i = 0; i < allToBeMerged.size(); i++) { + mergeCallback.onMerge( + allToBeMerged.get(i), allActiveToBeMerged.get(i), allMergeResults.get(i)); + recordMerge(allToBeMerged.get(i), allMergeResults.get(i)); + } + allToBeMerged.clear(); + allActiveToBeMerged.clear(); + allMergeResults.clear(); + seen.clear(); + } + } + + @Override + public void merge(MergeCallback mergeCallback) throws Exception { + MergeContextImpl context = new MergeContextImpl(mergeCallback); + + // See what the window function does with the NEW and already ACTIVE windows. + // Entering userland. + windowFn.mergeWindows(context); + + // Actually do the merging and invoke the callbacks. + context.recordMerges(); + + // Any remaining NEW windows should become implicitly ACTIVE. + for (Map.Entry> entry : activeWindowToStateAddressWindows.entrySet()) { + if (entry.getValue().isEmpty()) { + // This window was NEW but since it survived merging must now become ACTIVE. + W window = entry.getKey(); + entry.getValue().add(window); + windowToActiveWindow.put(window, window); + } + } + } + + /** + * A {@link WindowFn#mergeWindows} call has determined that {@code toBeMerged} (which must + * all be ACTIVE}) should be considered equivalent to {@code activeWindow} (which is either a + * member of {@code toBeMerged} or is a new window). Make the corresponding change in + * the active window set. + */ + private void recordMerge(Collection toBeMerged, W mergeResult) throws Exception { + Set newStateAddressWindows = new LinkedHashSet<>(); + Set existingStateAddressWindows = activeWindowToStateAddressWindows.get(mergeResult); + if (existingStateAddressWindows != null) { + // Preserve all the existing state address windows for mergeResult. + newStateAddressWindows.addAll(existingStateAddressWindows); + } + + Set newEphemeralWindows = new HashSet<>(); + Set existingEphemeralWindows = activeWindowToEphemeralWindows.get(mergeResult); + if (existingEphemeralWindows != null) { + // Preserve all the existing EPHEMERAL windows for meregResult. 
+ newEphemeralWindows.addAll(existingEphemeralWindows); + } + + for (W other : toBeMerged) { + Set otherStateAddressWindows = activeWindowToStateAddressWindows.get(other); + Preconditions.checkState(otherStateAddressWindows != null, "Window %s is not ACTIVE", other); + + for (W otherStateAddressWindow : otherStateAddressWindows) { + // Since otherTarget equiv other AND other equiv mergeResult + // THEN otherTarget equiv mergeResult. + newStateAddressWindows.add(otherStateAddressWindow); + windowToActiveWindow.put(otherStateAddressWindow, mergeResult); + } + activeWindowToStateAddressWindows.remove(other); + + Set otherEphemeralWindows = activeWindowToEphemeralWindows.get(other); + if (otherEphemeralWindows != null) { + for (W otherEphemeral : otherEphemeralWindows) { + // Since otherEphemeral equiv other AND other equiv mergeResult + // THEN otherEphemeral equiv mergeResult. + newEphemeralWindows.add(otherEphemeral); + windowToActiveWindow.put(otherEphemeral, mergeResult); + } + } + activeWindowToEphemeralWindows.remove(other); + + // Now other equiv mergeResult. + if (otherStateAddressWindows.contains(other)) { + // Other was ACTIVE and is now known to be MERGED. + } else if (otherStateAddressWindows.isEmpty()) { + // Other was NEW thus has no state. It is now EPHEMERAL. + newEphemeralWindows.add(other); + } else if (other.equals(mergeResult)) { + // Other was ACTIVE, was never used to store elements, but is still ACTIVE. + // Leave it as active. + } else { + // Other was ACTIVE, was never used to store element, as is no longer considered ACTIVE. + // It is now EPHEMERAL. + newEphemeralWindows.add(other); + } + windowToActiveWindow.put(other, mergeResult); + } + + if (newStateAddressWindows.isEmpty()) { + // If stateAddressWindows is empty then toBeMerged must have only contained EPHEMERAL windows. + // Promote mergeResult to be active now. + newStateAddressWindows.add(mergeResult); + } + windowToActiveWindow.put(mergeResult, mergeResult); + + activeWindowToStateAddressWindows.put(mergeResult, newStateAddressWindows); + if (!newEphemeralWindows.isEmpty()) { + activeWindowToEphemeralWindows.put(mergeResult, newEphemeralWindows); + } + + merged(mergeResult); + } + + @Override + public void merged(W window) { + Set stateAddressWindows = activeWindowToStateAddressWindows.get(window); + Preconditions.checkState(stateAddressWindows != null, "Window %s is not ACTIVE", window); + W first = Iterables.getFirst(stateAddressWindows, null); + stateAddressWindows.clear(); + stateAddressWindows.add(first); + } + + /** + * Return the state address windows for ACTIVE {@code window} from which all state associated + * should + * be read and merged. + */ + @Override + public Set readStateAddresses(W window) { + Set stateAddressWindows = activeWindowToStateAddressWindows.get(window); + Preconditions.checkState(stateAddressWindows != null, "Window %s is not ACTIVE", window); + return stateAddressWindows; + } + + /** + * Return the state address window of ACTIVE {@code window} into which all new state should be + * written. 
+ */ + @Override + public W writeStateAddress(W window) { + Set stateAddressWindows = activeWindowToStateAddressWindows.get(window); + Preconditions.checkState(stateAddressWindows != null, "Window %s is not ACTIVE", window); + W result = Iterables.getFirst(stateAddressWindows, null); + Preconditions.checkState(result != null, "Window %s is still NEW", window); + return result; + } + + @Override + public W mergedWriteStateAddress(Collection toBeMerged, W mergeResult) { + Set stateAddressWindows = activeWindowToStateAddressWindows.get(mergeResult); + if (stateAddressWindows != null && !stateAddressWindows.isEmpty()) { + return Iterables.getFirst(stateAddressWindows, null); + } + for (W mergedWindow : toBeMerged) { + stateAddressWindows = activeWindowToStateAddressWindows.get(mergedWindow); + if (stateAddressWindows != null && !stateAddressWindows.isEmpty()) { + return Iterables.getFirst(stateAddressWindows, null); + } + } + return mergeResult; + } + + @VisibleForTesting + public void checkInvariants() { + Set knownStateAddressWindows = new HashSet<>(); + for (Map.Entry> entry : activeWindowToStateAddressWindows.entrySet()) { + W active = entry.getKey(); + Preconditions.checkState(!entry.getValue().isEmpty(), + "Unexpected empty state address window set for ACTIVE window %s", active); + for (W stateAddressWindow : entry.getValue()) { + Preconditions.checkState(knownStateAddressWindows.add(stateAddressWindow), + "%s is in more than one state address window set", stateAddressWindow); + Preconditions.checkState(active.equals(windowToActiveWindow.get(stateAddressWindow)), + "%s should have %s as its ACTIVE window", stateAddressWindow, active); + } + } + for (Map.Entry> entry : activeWindowToEphemeralWindows.entrySet()) { + W active = entry.getKey(); + Preconditions.checkState(activeWindowToStateAddressWindows.containsKey(active), + "%s must be ACTIVE window", active); + Preconditions.checkState( + !entry.getValue().isEmpty(), "Unexpected empty EPHEMERAL set for %s", active); + for (W ephemeralWindow : entry.getValue()) { + Preconditions.checkState(knownStateAddressWindows.add(ephemeralWindow), + "%s is EPHEMERAL/state address of more than one ACTIVE window", ephemeralWindow); + Preconditions.checkState(active.equals(windowToActiveWindow.get(ephemeralWindow)), + "%s should have %s as its ACTIVE window", ephemeralWindow, active); + } + } + for (Map.Entry entry : windowToActiveWindow.entrySet()) { + Preconditions.checkState(activeWindowToStateAddressWindows.containsKey(entry.getValue()), + "%s should be ACTIVE since representative for %s", entry.getValue(), entry.getKey()); + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("MergingActiveWindowSet {\n"); + for (Map.Entry> entry : activeWindowToStateAddressWindows.entrySet()) { + W active = entry.getKey(); + Set stateAddressWindows = entry.getValue(); + if (stateAddressWindows.isEmpty()) { + sb.append(" NEW "); + sb.append(active); + sb.append('\n'); + } else { + sb.append(" ACTIVE "); + sb.append(active); + sb.append(":\n"); + for (W stateAddressWindow : stateAddressWindows) { + if (stateAddressWindow.equals(active)) { + sb.append(" ACTIVE "); + } else { + sb.append(" MERGED "); + } + sb.append(stateAddressWindow); + sb.append("\n"); + W active2 = windowToActiveWindow.get(stateAddressWindow); + Preconditions.checkState(active2.equals(active)); + } + Set ephemeralWindows = activeWindowToEphemeralWindows.get(active); + if (ephemeralWindows != null) { + for (W ephemeralWindow : ephemeralWindows) { + 
sb.append(" EPHEMERAL "); + sb.append(ephemeralWindow); + sb.append('\n'); + } + } + } + } + sb.append("}"); + return sb.toString(); + } + + // ====================================================================== + + /** + * Replace null {@code multimap} with empty map, and replace null entries in {@code multimap} with + * empty sets. + */ + private static Map> emptyIfNull(@Nullable Map> multimap) { + if (multimap == null) { + return new HashMap<>(); + } else { + for (Map.Entry> entry : multimap.entrySet()) { + if (entry.getValue() == null) { + entry.setValue(new LinkedHashSet()); + } + } + return multimap; + } + } + + /** Return a deep copy of {@code multimap}. */ + private static Map> deepCopy(Map> multimap) { + Map> newMultimap = new HashMap<>(); + for (Map.Entry> entry : multimap.entrySet()) { + newMultimap.put(entry.getKey(), new LinkedHashSet(entry.getValue())); + } + return newMultimap; + } + + /** Return inversion of {@code multimap}, which must be invertible. */ + private static Map invert(Map> multimap) { + Map result = new HashMap<>(); + for (Map.Entry> entry : multimap.entrySet()) { + W active = entry.getKey(); + for (W target : entry.getValue()) { + W previous = result.put(target, active); + Preconditions.checkState(previous == null, + "Window %s has both %s and %s as representatives", target, previous, active); + } + } + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MimeTypes.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MimeTypes.java new file mode 100644 index 000000000000..489d1832a1c9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MimeTypes.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** Constants representing various mime types. */ +public class MimeTypes { + public static final String TEXT = "text/plain"; + public static final String BINARY = "application/octet-stream"; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MonitoringUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MonitoringUtil.java new file mode 100644 index 000000000000..d45018798d74 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MonitoringUtil.java @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
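
The invert helper above is what builds the windowToActiveWindow index: every state address window points back at the ACTIVE window that owns its state. A minimal standalone sketch of that inversion, using plain strings as stand-in windows (the interval-style names are purely illustrative, not SDK types):

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

public class RepresentativeIndexSketch {
  // Invert a representative -> members multimap into a member -> representative map,
  // failing if any member claims two representatives (the multimap must be invertible).
  static <W> Map<W, W> invert(Map<W, Set<W>> multimap) {
    Map<W, W> result = new HashMap<>();
    for (Map.Entry<W, Set<W>> entry : multimap.entrySet()) {
      for (W member : entry.getValue()) {
        W previous = result.put(member, entry.getKey());
        if (previous != null) {
          throw new IllegalStateException(
              member + " has two representatives: " + previous + " and " + entry.getKey());
        }
      }
    }
    return result;
  }

  public static void main(String[] args) {
    // ACTIVE window "[0,12)" carries state under two address windows; "[20,30)" only under itself.
    Map<String, Set<String>> activeToStateAddresses = new HashMap<>();
    activeToStateAddresses.put(
        "[0,12)", new LinkedHashSet<>(Arrays.asList("[0,10)", "[5,12)")));
    activeToStateAddresses.put("[20,30)", new LinkedHashSet<>(Arrays.asList("[20,30)")));

    // Prints something like {[0,10)=[0,12), [5,12)=[0,12), [20,30)=[20,30)}.
    System.out.println(invert(activeToStateAddresses));
  }
}
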
+ */ +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.Dataflow.Projects.Jobs.Messages; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.ListJobMessagesResponse; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A helper class for monitoring jobs submitted to the service. + */ +public final class MonitoringUtil { + + private static final String GCLOUD_DATAFLOW_PREFIX = "gcloud alpha dataflow"; + private static final String ENDPOINT_OVERRIDE_ENV_VAR = + "CLOUDSDK_API_ENDPOINT_OVERRIDES_DATAFLOW"; + + private static final Map DATAFLOW_STATE_TO_JOB_STATE = + ImmutableMap + .builder() + .put("JOB_STATE_UNKNOWN", State.UNKNOWN) + .put("JOB_STATE_STOPPED", State.STOPPED) + .put("JOB_STATE_RUNNING", State.RUNNING) + .put("JOB_STATE_DONE", State.DONE) + .put("JOB_STATE_FAILED", State.FAILED) + .put("JOB_STATE_CANCELLED", State.CANCELLED) + .put("JOB_STATE_UPDATED", State.UPDATED) + .build(); + + private String projectId; + private Messages messagesClient; + + /** + * An interface that can be used for defining callbacks to receive a list + * of JobMessages containing monitoring information. + */ + public interface JobMessagesHandler { + /** Process the rows. */ + void process(List messages); + } + + /** A handler that prints monitoring messages to a stream. */ + public static class PrintHandler implements JobMessagesHandler { + private PrintStream out; + + /** + * Construct the handler. + * + * @param stream The stream to write the messages to. + */ + public PrintHandler(PrintStream stream) { + out = stream; + } + + @Override + public void process(List messages) { + for (JobMessage message : messages) { + if (message.getMessageText() == null || message.getMessageText().isEmpty()) { + continue; + } + String importanceString = null; + if (message.getMessageImportance() == null) { + continue; + } else if (message.getMessageImportance().equals("JOB_MESSAGE_ERROR")) { + importanceString = "Error: "; + } else if (message.getMessageImportance().equals("JOB_MESSAGE_WARNING")) { + importanceString = "Warning: "; + } else if (message.getMessageImportance().equals("JOB_MESSAGE_BASIC")) { + importanceString = "Basic: "; + } else if (message.getMessageImportance().equals("JOB_MESSAGE_DETAILED")) { + importanceString = "Detail: "; + } else { + // TODO: Remove filtering here once getJobMessages supports minimum + // importance. + continue; + } + @Nullable Instant time = TimeUtil.fromCloudTime(message.getTime()); + if (time == null) { + out.print("UNKNOWN TIMESTAMP: "); + } else { + out.print(time + ": "); + } + if (importanceString != null) { + out.print(importanceString); + } + out.println(message.getMessageText()); + } + out.flush(); + } + } + + /** Construct a helper for monitoring. 
*/ + public MonitoringUtil(String projectId, Dataflow dataflow) { + this(projectId, dataflow.projects().jobs().messages()); + } + + // @VisibleForTesting + MonitoringUtil(String projectId, Messages messagesClient) { + this.projectId = projectId; + this.messagesClient = messagesClient; + } + + /** + * Comparator for sorting rows in increasing order based on timestamp. + */ + public static class TimeStampComparator implements Comparator { + @Override + public int compare(JobMessage o1, JobMessage o2) { + @Nullable Instant t1 = fromCloudTime(o1.getTime()); + if (t1 == null) { + return -1; + } + @Nullable Instant t2 = fromCloudTime(o2.getTime()); + if (t2 == null) { + return 1; + } + return t1.compareTo(t2); + } + } + + /** + * Return job messages sorted in ascending order by timestamp. + * @param jobId The id of the job to get the messages for. + * @param startTimestampMs Return only those messages with a + * timestamp greater than this value. + * @return collection of messages + * @throws IOException + */ + public ArrayList getJobMessages( + String jobId, long startTimestampMs) throws IOException { + // TODO: Allow filtering messages by importance + Instant startTimestamp = new Instant(startTimestampMs); + ArrayList allMessages = new ArrayList<>(); + String pageToken = null; + while (true) { + Messages.List listRequest = messagesClient.list(projectId, jobId); + if (pageToken != null) { + listRequest.setPageToken(pageToken); + } + ListJobMessagesResponse response = listRequest.execute(); + + if (response == null || response.getJobMessages() == null) { + return allMessages; + } + + for (JobMessage m : response.getJobMessages()) { + @Nullable Instant timestamp = fromCloudTime(m.getTime()); + if (timestamp == null) { + continue; + } + if (timestamp.isAfter(startTimestamp)) { + allMessages.add(m); + } + } + + if (response.getNextPageToken() == null) { + break; + } else { + pageToken = response.getNextPageToken(); + } + } + + Collections.sort(allMessages, new TimeStampComparator()); + return allMessages; + } + + public static String getJobMonitoringPageURL(String projectName, String jobId) { + try { + // Project name is allowed in place of the project id: the user will be redirected to a URL + // that has the project name replaced with project id. + return String.format( + "https://console.developers.google.com/project/%s/dataflow/job/%s", + URLEncoder.encode(projectName, "UTF-8"), + URLEncoder.encode(jobId, "UTF-8")); + } catch (UnsupportedEncodingException e) { + // Should never happen. + throw new AssertionError("UTF-8 encoding is not supported by the environment", e); + } + } + + public static String getGcloudCancelCommand(DataflowPipelineOptions options, String jobId) { + + // If using a different Dataflow API than default, prefix command with an API override. + String dataflowApiOverridePrefix = ""; + String apiUrl = options.getDataflowClient().getBaseUrl(); + if (!apiUrl.equals(Dataflow.DEFAULT_BASE_URL)) { + dataflowApiOverridePrefix = String.format("%s=%s ", ENDPOINT_OVERRIDE_ENV_VAR, apiUrl); + } + + // Assemble cancel command from optional prefix and project/job parameters. 
+ return String.format("%s%s jobs --project=%s cancel %s", + dataflowApiOverridePrefix, GCLOUD_DATAFLOW_PREFIX, options.getProject(), jobId); + } + + public static State toState(String stateName) { + return MoreObjects.firstNonNull(DATAFLOW_STATE_TO_JOB_STATE.get(stateName), + State.UNKNOWN); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MutationDetector.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MutationDetector.java new file mode 100644 index 000000000000..51e65ab878cb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MutationDetector.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * An object for detecting illegal mutations. + * + *
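
A usage sketch for the MonitoringUtil helper above; the Dataflow client, project id, and job id are assumed to be obtained elsewhere, and the class and method names here are hypothetical:

import com.google.api.services.dataflow.Dataflow;
import com.google.api.services.dataflow.model.JobMessage;
import com.google.cloud.dataflow.sdk.util.MonitoringUtil;

import java.io.IOException;
import java.util.List;

public class JobMessagePrinter {
  // Fetch every message logged so far for the job and print it with PrintHandler.
  static void printJobMessages(Dataflow dataflow, String projectId, String jobId)
      throws IOException {
    MonitoringUtil monitor = new MonitoringUtil(projectId, dataflow);
    List<JobMessage> messages = monitor.getJobMessages(jobId, 0L /* startTimestampMs */);
    new MonitoringUtil.PrintHandler(System.out).process(messages);
  }
}
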

    The {@link AutoCloseable} aspect of this interface allows use in a try-with-resources + * style, where the implementing class may choose to perform a final mutation check upon + * {@link #close()}. + */ +public interface MutationDetector extends AutoCloseable { + /** + * @throws IllegalMutationException if illegal mutations are detected. + */ + void verifyUnmodified(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MutationDetectors.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MutationDetectors.java new file mode 100644 index 000000000000..412e3eb72520 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MutationDetectors.java @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.common.base.Throwables; + +import java.util.Arrays; +import java.util.Objects; + +/** + * Static methods for creating and working with {@link MutationDetector}. + */ +public class MutationDetectors { + + private MutationDetectors() {} + + /** + * Creates a new {@code MutationDetector} for the provided {@code value} that uses the provided + * {@link Coder} to perform deep copies and comparisons by serializing and deserializing values. + * + *
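
A sketch of the try-with-resources style described above, assuming IllegalMutationException lives in the same util package and using the SDK's ListCoder and VarIntCoder; the mutated list is purely illustrative:

import com.google.cloud.dataflow.sdk.coders.ListCoder;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import com.google.cloud.dataflow.sdk.util.IllegalMutationException;
import com.google.cloud.dataflow.sdk.util.MutationDetector;
import com.google.cloud.dataflow.sdk.util.MutationDetectors;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MutationDetectorDemo {
  public static void main(String[] args) throws Exception {
    List<Integer> value = new ArrayList<>(Arrays.asList(1, 2, 3));
    try (MutationDetector detector =
        MutationDetectors.forValueWithCoder(value, ListCoder.of(VarIntCoder.of()))) {
      value.add(4); // A forbidden in-place mutation of the tracked value.
    } catch (IllegalMutationException e) {
      // close() performs a final verifyUnmodified() and reports the mutation.
      System.out.println("Detected: " + e.getMessage());
    }
  }
}
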

    It is permissible for {@code value} to be {@code null}. Since {@code null} is immutable, + * the mutation check will always succeed. + */ + public static MutationDetector forValueWithCoder(T value, Coder coder) + throws CoderException { + if (value == null) { + return noopMutationDetector(); + } else { + return new CodedValueMutationDetector<>(value, coder); + } + } + + /** + * Creates a new {@code MutationDetector} that always succeeds. + * + *

    This is useful, for example, for providing a very efficient mutation detector for a value + * which is already immutable by design. + */ + public static MutationDetector noopMutationDetector() { + return new NoopMutationDetector(); + } + + /** + * A {@link MutationDetector} for {@code null}, which is immutable. + */ + private static class NoopMutationDetector implements MutationDetector { + + @Override + public void verifyUnmodified() { } + + @Override + public void close() { } + } + + /** + * Given a value of type {@code T} and a {@link Coder} for that type, provides facilities to save + * check that the value has not changed. + * + * @param the type of values checked for mutation + */ + private static class CodedValueMutationDetector implements MutationDetector { + + private final Coder coder; + + /** + * A saved pointer to an in-memory value provided upon construction, which we will check for + * forbidden mutations. + */ + private final T possiblyModifiedObject; + + /** + * A saved encoded copy of the same value as {@link #possiblyModifiedObject}. Naturally, it + * will not change if {@link #possiblyModifiedObject} is mutated. + */ + private final byte[] encodedOriginalObject; + + /** + * The object decoded from {@link #encodedOriginalObject}. It will be used during every call to + * {@link #verifyUnmodified}, which could be called many times throughout the lifetime of this + * {@link CodedValueMutationDetector}. + */ + private final T clonedOriginalObject; + + /** + * Create a mutation detector for the provided {@code value}, using the provided {@link Coder} + * for cloning and checking serialized forms for equality. + */ + public CodedValueMutationDetector(T value, Coder coder) throws CoderException { + this.coder = coder; + this.possiblyModifiedObject = value; + this.encodedOriginalObject = CoderUtils.encodeToByteArray(coder, value); + this.clonedOriginalObject = CoderUtils.decodeFromByteArray(coder, encodedOriginalObject); + } + + @Override + public void verifyUnmodified() { + try { + verifyUnmodifiedThrowingCheckedExceptions(); + } catch (CoderException exn) { + Throwables.propagate(exn); + } + } + + private void verifyUnmodifiedThrowingCheckedExceptions() throws CoderException { + // If either object believes they are equal, we trust that and short-circuit deeper checks. + if (Objects.equals(possiblyModifiedObject, clonedOriginalObject) + || Objects.equals(clonedOriginalObject, possiblyModifiedObject)) { + return; + } + + // Since retainedObject is in general an instance of a subclass of T, when it is cloned to + // clonedObject using a Coder, the two will generally be equivalent viewed as a T, but in + // general neither retainedObject.equals(clonedObject) nor clonedObject.equals(retainedObject) + // will hold. + // + // For example, CoderUtils.clone(IterableCoder, IterableSubclass) will + // produce an ArrayList with the same contents as the IterableSubclass, but the + // latter will quite reasonably not consider itself equivalent to an ArrayList (and vice + // versa). + // + // To enable a reasonable comparison, we clone retainedObject again here, converting it to + // the same sort of T that the Coder output when it created clonedObject. + T clonedPossiblyModifiedObject = CoderUtils.clone(coder, possiblyModifiedObject); + + // If deepEquals() then we trust the equals implementation. + // This deliberately allows fields to escape this check. 
+ if (Objects.deepEquals(clonedPossiblyModifiedObject, clonedOriginalObject)) { + return; + } + + // If not deepEquals(), the class may just have a poor equals() implementation. + // So we next try checking their serialized forms. We re-serialize instead of checking + // encodedObject, because the Coder may treat it differently. + // + // For example, an unbounded Iterable will be encoded in an unbounded way, but decoded into an + // ArrayList, which will then be re-encoded in a bounded format. So we really do need to + // encode-decode-encode retainedObject. + if (Arrays.equals( + CoderUtils.encodeToByteArray(coder, clonedOriginalObject), + CoderUtils.encodeToByteArray(coder, clonedPossiblyModifiedObject))) { + return; + } + + // If we got here, then they are not deepEquals() and do not have deepEquals() encodings. + // Even if there is some conceptual sense in which the objects are equivalent, it has not + // been adequately expressed in code. + illegalMutation(clonedOriginalObject, clonedPossiblyModifiedObject); + } + + private void illegalMutation(T previousValue, T newValue) throws CoderException { + throw new IllegalMutationException( + String.format("Value %s mutated illegally, new value was %s." + + " Encoding was %s, now %s.", + previousValue, newValue, + CoderUtils.encodeToBase64(coder, previousValue), + CoderUtils.encodeToBase64(coder, newValue)), + previousValue, newValue); + } + + @Override + public void close() { + verifyUnmodified(); + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NonEmptyPanes.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NonEmptyPanes.java new file mode 100644 index 000000000000..1270f014230f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NonEmptyPanes.java @@ -0,0 +1,148 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.util.state.AccumulatorCombiningState; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateMerging; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; + +/** + * Tracks which windows have non-empty panes. Specifically, which windows have new elements since + * their last triggering. + * + * @param The kind of windows being tracked. 
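
The general-purpose implementation below records a per-window count of pane additions in combining state; the following standalone sketch mimics that bookkeeping with a plain map and string-valued windows (names hypothetical), outside the state API:

import java.util.HashMap;
import java.util.Map;

public class NonEmptyPanesSketch {
  private final Map<String, Long> paneAdditions = new HashMap<>();

  // Mirrors recordContent(): note that the window's current pane has received data.
  void recordContent(String window) {
    Long count = paneAdditions.get(window);
    paneAdditions.put(window, count == null ? 1L : count + 1);
  }

  // Mirrors clearPane(): called after the pane has fired.
  void clearPane(String window) {
    paneAdditions.remove(window);
  }

  // Mirrors isEmpty(): has anything arrived since the last firing?
  boolean isEmpty(String window) {
    return !paneAdditions.containsKey(window);
  }

  public static void main(String[] args) {
    NonEmptyPanesSketch panes = new NonEmptyPanesSketch();
    panes.recordContent("[0,10)");
    System.out.println(panes.isEmpty("[0,10)")); // false: the pane should fire.
    panes.clearPane("[0,10)");
    System.out.println(panes.isEmpty("[0,10)")); // true: nothing new since the firing.
  }
}
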
+ */ +public abstract class NonEmptyPanes { + + static NonEmptyPanes create( + WindowingStrategy strategy, ReduceFn reduceFn) { + if (strategy.getMode() == AccumulationMode.DISCARDING_FIRED_PANES) { + return new DiscardingModeNonEmptyPanes<>(reduceFn); + } else { + return new GeneralNonEmptyPanes<>(); + } + } + + /** + * Record that some content has been added to the window in {@code context}, and therefore the + * current pane is not empty. + */ + public abstract void recordContent(StateAccessor context); + + /** + * Record that the given pane is empty. + */ + public abstract void clearPane(StateAccessor state); + + /** + * Return true if the current pane for the window in {@code context} is empty. + */ + public abstract ReadableState isEmpty(StateAccessor context); + + /** + * Prefetch in preparation for merging. + */ + public abstract void prefetchOnMerge(MergingStateAccessor state); + + /** + * Eagerly merge backing state. + */ + public abstract void onMerge(MergingStateAccessor context); + + /** + * An implementation of {@code NonEmptyPanes} optimized for use with discarding mode. Uses the + * presence of data in the accumulation buffer to record non-empty panes. + */ + private static class DiscardingModeNonEmptyPanes + extends NonEmptyPanes { + + private ReduceFn reduceFn; + + private DiscardingModeNonEmptyPanes(ReduceFn reduceFn) { + this.reduceFn = reduceFn; + } + + @Override + public ReadableState isEmpty(StateAccessor state) { + return reduceFn.isEmpty(state); + } + + @Override + public void recordContent(StateAccessor state) { + // Nothing to do -- the reduceFn is tracking contents + } + + @Override + public void clearPane(StateAccessor state) { + // Nothing to do -- the reduceFn is tracking contents + } + + @Override + public void prefetchOnMerge(MergingStateAccessor state) { + // Nothing to do -- the reduceFn is tracking contents + } + + @Override + public void onMerge(MergingStateAccessor context) { + // Nothing to do -- the reduceFn is tracking contents + } + } + + /** + * An implementation of {@code NonEmptyPanes} for general use. + */ + private static class GeneralNonEmptyPanes + extends NonEmptyPanes { + + private static final StateTag> + PANE_ADDITIONS_TAG = + StateTags.makeSystemTagInternal(StateTags.combiningValueFromInputInternal( + "count", VarLongCoder.of(), new Sum.SumLongFn())); + + @Override + public void recordContent(StateAccessor state) { + state.access(PANE_ADDITIONS_TAG).add(1L); + } + + @Override + public void clearPane(StateAccessor state) { + state.access(PANE_ADDITIONS_TAG).clear(); + } + + @Override + public ReadableState isEmpty(StateAccessor state) { + return state.access(PANE_ADDITIONS_TAG).isEmpty(); + } + + @Override + public void prefetchOnMerge(MergingStateAccessor state) { + StateMerging.prefetchCombiningValues(state, PANE_ADDITIONS_TAG); + } + + @Override + public void onMerge(MergingStateAccessor context) { + StateMerging.mergeCombiningValues(context, PANE_ADDITIONS_TAG); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NonMergingActiveWindowSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NonMergingActiveWindowSet.java new file mode 100644 index 000000000000..cb7f9b06e7c3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NonMergingActiveWindowSet.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.Set; + +/** + * Implementation of {@link ActiveWindowSet} used with {@link WindowFn WindowFns} that don't support + * merging. + * + * @param the types of windows being managed + */ +public class NonMergingActiveWindowSet implements ActiveWindowSet { + @Override + public void removeEphemeralWindows() {} + + @Override + public void persist() {} + + @Override + public W representative(W window) { + // Always represented by itself. + return window; + } + + @Override + public Set getActiveWindows() { + // Only supported when merging. + throw new java.lang.UnsupportedOperationException(); + } + + @Override + public boolean isActive(W window) { + // Windows should never disappear, since we don't support merging. + return true; + } + + @Override + public void addNew(W window) {} + + @Override + public void addActive(W window) {} + + @Override + public void remove(W window) {} + + @Override + public void merge(MergeCallback mergeCallback) throws Exception {} + + @Override + public void merged(W window) {} + + @Override + public Set readStateAddresses(W window) { + return ImmutableSet.of(window); + } + + @Override + public W writeStateAddress(W window) { + return window; + } + + @Override + public W mergedWriteStateAddress(Collection toBeMerged, W mergeResult) { + return mergeResult; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NoopCredentialFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NoopCredentialFactory.java new file mode 100644 index 000000000000..9ef4c2eb09b7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NoopCredentialFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.io.IOException; +import java.security.GeneralSecurityException; + +/** + * Construct an oauth credential to be used by the SDK and the SDK workers. + * Always returns a null Credential object. 
+ */ +public class NoopCredentialFactory implements CredentialFactory { + public static NoopCredentialFactory fromOptions(PipelineOptions options) { + return new NoopCredentialFactory(); + } + + @Override + public Credential getCredential() throws IOException, GeneralSecurityException { + return null; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NoopPathValidator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NoopPathValidator.java new file mode 100644 index 000000000000..00abbb146ff7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NoopPathValidator.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +/** + * Noop implementation of {@link PathValidator}. All paths are allowed and returned unchanged. + */ +public class NoopPathValidator implements PathValidator { + + private NoopPathValidator() { + } + + public static PathValidator fromOptions( + @SuppressWarnings("unused") PipelineOptions options) { + return new NoopPathValidator(); + } + + @Override + public String validateInputFilePatternSupported(String filepattern) { + return filepattern; + } + + @Override + public String validateOutputFilePrefixSupported(String filePrefix) { + return filePrefix; + } + + @Override + public String verifyPath(String path) { + return path; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NullSideInputReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NullSideInputReader.java new file mode 100644 index 000000000000..0fc264606f88 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/NullSideInputReader.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.collect.Sets; + +import java.util.Collections; +import java.util.Set; + +/** + * A {@link SideInputReader} representing a well-defined set of views, but not storing + * any values for them. Used to check if a side input is present when the data itself + * comes from elsewhere. 
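
A sketch of how a runner component might use the class defined below purely as a presence check; the helper name and the source of the views are hypothetical:

import com.google.cloud.dataflow.sdk.util.NullSideInputReader;
import com.google.cloud.dataflow.sdk.util.SideInputReader;
import com.google.cloud.dataflow.sdk.values.PCollectionView;

public class SideInputPresenceCheck {
  // True if the given view is among the side inputs declared for a step.
  static boolean isDeclaredSideInput(
      Iterable<PCollectionView<?>> declaredViews, PCollectionView<?> view) {
    SideInputReader reader = NullSideInputReader.of(declaredViews);
    // get(view, window) would throw for this reader; it only answers containment questions.
    return reader.contains(view);
  }
}
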
+ */ +public class NullSideInputReader implements SideInputReader { + + private Set> views; + + public static NullSideInputReader empty() { + return new NullSideInputReader(Collections.>emptySet()); + } + + public static NullSideInputReader of(Iterable> views) { + return new NullSideInputReader(views); + } + + private NullSideInputReader(Iterable> views) { + this.views = Sets.newHashSet(views); + } + + @Override + public T get(PCollectionView view, BoundedWindow window) { + throw new IllegalArgumentException("cannot call NullSideInputReader.get()"); + } + + @Override + public boolean isEmpty() { + return views.isEmpty(); + } + + @Override + public boolean contains(PCollectionView view) { + return views.contains(view); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/OutputReference.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/OutputReference.java new file mode 100644 index 000000000000..096c9965284f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/OutputReference.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.api.client.util.Preconditions.checkNotNull; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Key; + +/** + * A representation used by {@link com.google.api.services.dataflow.model.Step}s + * to reference the output of other {@code Step}s. + */ +public final class OutputReference extends GenericJson { + @Key("@type") + public final String type = "OutputReference"; + + @Key("step_name") + private final String stepName; + + @Key("output_name") + private final String outputName; + + public OutputReference(String stepName, String outputName) { + this.stepName = checkNotNull(stepName); + this.outputName = checkNotNull(outputName); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PCollectionViewWindow.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PCollectionViewWindow.java new file mode 100644 index 000000000000..7cf636eb63c4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PCollectionViewWindow.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import java.util.Objects; + +/** + * A pair of a {@link PCollectionView} and a {@link BoundedWindow}, which can + * be thought of as window "of" the view. This is a value class for use e.g. + * as a compound cache key. + * + * @param the type of the underlying PCollectionView + */ +public final class PCollectionViewWindow { + + private final PCollectionView view; + private final BoundedWindow window; + + private PCollectionViewWindow(PCollectionView view, BoundedWindow window) { + this.view = view; + this.window = window; + } + + public static PCollectionViewWindow of(PCollectionView view, BoundedWindow window) { + return new PCollectionViewWindow<>(view, window); + } + + public PCollectionView getView() { + return view; + } + + public BoundedWindow getWindow() { + return window; + } + + @Override + public boolean equals(Object otherObject) { + if (!(otherObject instanceof PCollectionViewWindow)) { + return false; + } + @SuppressWarnings("unchecked") + PCollectionViewWindow other = (PCollectionViewWindow) otherObject; + return getView().equals(other.getView()) && getWindow().equals(other.getWindow()); + } + + @Override + public int hashCode() { + return Objects.hash(getView(), getWindow()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PCollectionViews.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PCollectionViews.java new file mode 100644 index 000000000000..7e735473c3a2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PCollectionViews.java @@ -0,0 +1,426 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
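
Because equals and hashCode above are defined on the (view, window) pair, the class works directly as a compound cache key; a minimal sketch, with hypothetical cache and method names:

import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
import com.google.cloud.dataflow.sdk.util.PCollectionViewWindow;
import com.google.cloud.dataflow.sdk.values.PCollectionView;

import java.util.HashMap;
import java.util.Map;

public class SideInputCacheSketch {
  private final Map<PCollectionViewWindow<?>, Object> cache = new HashMap<>();

  // Cache a materialized side input value for one (view, window) pair.
  <T> void put(PCollectionView<T> view, BoundedWindow window, T value) {
    cache.put(PCollectionViewWindow.of(view, window), value);
  }

  // Returns the cached value, or null if this (view, window) pair was never materialized.
  @SuppressWarnings("unchecked")
  <T> T get(PCollectionView<T> view, BoundedWindow window) {
    return (T) cache.get(PCollectionViewWindow.of(view, window));
  }
}
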
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.InvalidWindows; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValueBase; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Function; +import com.google.common.base.MoreObjects; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * Implementations of {@link PCollectionView} shared across the SDK. + * + *

    For internal use only, subject to change. + */ +public class PCollectionViews { + + /** + * Returns a {@code PCollectionView} capable of processing elements encoded using the provided + * {@link Coder} and windowed using the provided * {@link WindowingStrategy}. + * + *

    If {@code hasDefault} is {@code true}, then the view will take on the value + * {@code defaultValue} for any empty windows. + */ + public static PCollectionView singletonView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + boolean hasDefault, + T defaultValue, + Coder valueCoder) { + return new SingletonPCollectionView<>( + pipeline, windowingStrategy, hasDefault, defaultValue, valueCoder); + } + + /** + * Returns a {@code PCollectionView>} capable of processing elements encoded using the + * provided {@link Coder} and windowed using the provided {@link WindowingStrategy}. + */ + public static PCollectionView> iterableView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder valueCoder) { + return new IterablePCollectionView<>(pipeline, windowingStrategy, valueCoder); + } + + /** + * Returns a {@code PCollectionView>} capable of processing elements encoded using the + * provided {@link Coder} and windowed using the provided {@link WindowingStrategy}. + */ + public static PCollectionView> listView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder valueCoder) { + return new ListPCollectionView<>(pipeline, windowingStrategy, valueCoder); + } + + /** + * Returns a {@code PCollectionView>} capable of processing elements encoded using the + * provided {@link Coder} and windowed using the provided {@link WindowingStrategy}. + */ + public static PCollectionView> mapView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder> valueCoder) { + + return new MapPCollectionView(pipeline, windowingStrategy, valueCoder); + } + + /** + * Returns a {@code PCollectionView>>} capable of processing elements encoded + * using the provided {@link Coder} and windowed using the provided {@link WindowingStrategy}. + */ + public static PCollectionView>> multimapView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder> valueCoder) { + return new MultimapPCollectionView(pipeline, windowingStrategy, valueCoder); + } + + /** + * Implementation of conversion of singleton {@code Iterable>} to {@code T}. + * + *

    For internal use only. + * + *

    Instantiate via {@link PCollectionViews#singletonView}. + */ + public static class SingletonPCollectionView + extends PCollectionViewBase { + @Nullable private byte[] encodedDefaultValue; + @Nullable private transient T defaultValue; + @Nullable private Coder valueCoder; + private boolean hasDefault; + + private SingletonPCollectionView( + Pipeline pipeline, WindowingStrategy windowingStrategy, + boolean hasDefault, T defaultValue, Coder valueCoder) { + super(pipeline, windowingStrategy, valueCoder); + this.hasDefault = hasDefault; + this.defaultValue = defaultValue; + this.valueCoder = valueCoder; + if (hasDefault) { + try { + this.encodedDefaultValue = CoderUtils.encodeToByteArray(valueCoder, defaultValue); + } catch (IOException e) { + throw new RuntimeException("Unexpected IOException: ", e); + } + } + } + + /** + * Returns the default value that was specified. + * + *

    For internal use only. + * + * @throws NoSuchElementException if no default was specified. + */ + public T getDefaultValue() { + if (!hasDefault) { + throw new NoSuchElementException("Empty PCollection accessed as a singleton view."); + } + // Lazily decode the default value once + synchronized (this) { + if (encodedDefaultValue != null) { + try { + defaultValue = CoderUtils.decodeFromByteArray(valueCoder, encodedDefaultValue); + encodedDefaultValue = null; + } catch (IOException e) { + throw new RuntimeException("Unexpected IOException: ", e); + } + } + } + return defaultValue; + } + + @Override + protected T fromElements(Iterable> contents) { + try { + return Iterables.getOnlyElement(contents).getValue(); + } catch (NoSuchElementException exc) { + return getDefaultValue(); + } catch (IllegalArgumentException exc) { + throw new IllegalArgumentException( + "PCollection with more than one element " + + "accessed as a singleton view."); + } + } + } + + /** + * Implementation of conversion {@code Iterable>} to {@code Iterable}. + * + *

    For internal use only. + * + *

    Instantiate via {@link PCollectionViews#iterableView}. + */ + public static class IterablePCollectionView + extends PCollectionViewBase, W> { + private IterablePCollectionView( + Pipeline pipeline, WindowingStrategy windowingStrategy, Coder valueCoder) { + super(pipeline, windowingStrategy, valueCoder); + } + + @Override + protected Iterable fromElements(Iterable> contents) { + return Iterables.unmodifiableIterable( + Iterables.transform(contents, new Function, T>() { + @SuppressWarnings("unchecked") + @Override + public T apply(WindowedValue input) { + return input.getValue(); + } + })); + } + } + + /** + * Implementation of conversion {@code Iterable>} to {@code List}. + * + *

    For internal use only. + * + *

    Instantiate via {@link PCollectionViews#listView}. + */ + public static class ListPCollectionView + extends PCollectionViewBase, W> { + private ListPCollectionView( + Pipeline pipeline, WindowingStrategy windowingStrategy, Coder valueCoder) { + super(pipeline, windowingStrategy, valueCoder); + } + + @Override + protected List fromElements(Iterable> contents) { + return ImmutableList.copyOf( + Iterables.transform(contents, new Function, T>() { + @SuppressWarnings("unchecked") + @Override + public T apply(WindowedValue input) { + return input.getValue(); + } + })); + } + } + + /** + * Implementation of conversion {@code Iterable>>} + * to {@code Map>}. + * + *

    For internal use only. + */ + public static class MultimapPCollectionView + extends PCollectionViewBase, Map>, W> { + private MultimapPCollectionView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder> valueCoder) { + super(pipeline, windowingStrategy, valueCoder); + } + + @Override + protected Map> fromElements(Iterable>> elements) { + Multimap multimap = HashMultimap.create(); + for (WindowedValue> elem : elements) { + KV kv = elem.getValue(); + multimap.put(kv.getKey(), kv.getValue()); + } + // Safe covariant cast that Java cannot express without rawtypes, even with unchecked casts + @SuppressWarnings({"unchecked", "rawtypes"}) + Map> resultMap = (Map) multimap.asMap(); + return Collections.unmodifiableMap(resultMap); + } + } + + /** + * Implementation of conversion {@code Iterable>} with + * one value per key to {@code Map}. + * + *

    For internal use only. + */ + public static class MapPCollectionView + extends PCollectionViewBase, Map, W> { + private MapPCollectionView( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder> valueCoder) { + super(pipeline, windowingStrategy, valueCoder); + } + + /** + * Input iterable must actually be {@code Iterable>>}. + */ + @Override + protected Map fromElements(Iterable>> elements) { + Map map = new HashMap<>(); + for (WindowedValue> elem : elements) { + KV kv = elem.getValue(); + if (map.containsKey(kv.getKey())) { + throw new IllegalArgumentException("Duplicate values for " + kv.getKey()); + } + map.put(kv.getKey(), kv.getValue()); + } + return Collections.unmodifiableMap(map); + } + } + + /** + * A base class for {@link PCollectionView} implementations, with additional type parameters + * that are not visible at pipeline assembly time when the view is used as a side input. + */ + private abstract static class PCollectionViewBase + extends PValueBase + implements PCollectionView { + /** A unique tag for the view, typed according to the elements underlying the view. */ + private TupleTag>> tag; + + /** The windowing strategy for the PCollection underlying the view. */ + private WindowingStrategy windowingStrategy; + + /** The coder for the elements underlying the view. */ + private Coder>> coder; + + /** + * Implement this to complete the implementation. It is a conversion function from + * all of the elements of the underlying {@link PCollection} to the value of the view. + */ + protected abstract ViewT fromElements(Iterable> elements); + + /** + * Call this constructor to initialize the fields for which this base class provides + * boilerplate accessors. + */ + protected PCollectionViewBase( + Pipeline pipeline, + TupleTag>> tag, + WindowingStrategy windowingStrategy, + Coder valueCoder) { + super(pipeline); + if (windowingStrategy.getWindowFn() instanceof InvalidWindows) { + throw new IllegalArgumentException("WindowFn of PCollectionView cannot be InvalidWindows"); + } + this.tag = tag; + this.windowingStrategy = windowingStrategy; + this.coder = + IterableCoder.of(WindowedValue.getFullCoder( + valueCoder, windowingStrategy.getWindowFn().windowCoder())); + } + + /** + * Call this constructor to initialize the fields for which this base class provides + * boilerplate accessors, with an auto-generated tag. + */ + protected PCollectionViewBase( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + Coder valueCoder) { + this(pipeline, new TupleTag>>(), windowingStrategy, valueCoder); + } + + /** + * For serialization only. Do not use directly. Subclasses should call from their own + * protected no-argument constructor. + */ + @SuppressWarnings("unused") // used for serialization + protected PCollectionViewBase() { + super(); + } + + @Override + public ViewT fromIterableInternal(Iterable> elements) { + // Safe cast: it is required that the rest of the SDK maintain the invariant + // that a PCollectionView is only provided an iterable for the elements of an + // appropriately typed PCollection. + @SuppressWarnings({"rawtypes", "unchecked"}) + Iterable> typedElements = (Iterable) elements; + return fromElements(typedElements); + } + + /** + * Returns a unique {@link TupleTag} identifying this {@link PCollectionView}. + * + *

    For internal use only by runner implementors. + */ + @Override + public TupleTag>> getTagInternal() { + // Safe cast: It is required that the rest of the SDK maintain the invariant that + // this tag is only used to access the contents of an appropriately typed underlying + // PCollection + @SuppressWarnings({"rawtypes", "unchecked"}) + TupleTag>> untypedTag = (TupleTag) tag; + return untypedTag; + } + + /** + * Returns the {@link WindowingStrategy} of this {@link PCollectionView}, which should + * be that of the underlying {@link PCollection}. + * + *

    For internal use only by runner implementors. + */ + @Override + public WindowingStrategy getWindowingStrategyInternal() { + return windowingStrategy; + } + + @Override + public Coder>> getCoderInternal() { + // Safe cast: It is required that the rest of the SDK only use this untyped coder + // for the elements of an appropriately typed underlying PCollection. + @SuppressWarnings({"rawtypes", "unchecked"}) + Coder>> untypedCoder = (Coder) coder; + return untypedCoder; + } + + @Override + public int hashCode() { + return Objects.hash(tag); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PCollectionView) || other == null) { + return false; + } + @SuppressWarnings("unchecked") + PCollectionView otherView = (PCollectionView) other; + return tag.equals(otherView.getTagInternal()); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("tag", tag).toString(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PTuple.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PTuple.java new file mode 100644 index 000000000000..5b87b5cf8832 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PTuple.java @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A {@code PTuple} is an immutable tuple of + * heterogeneously-typed values, "keyed" by {@link TupleTag}s. + * + *

PTuples can be created and accessed as follows: + *

     {@code
    + * String v1 = ...;
    + * Integer v2 = ...;
+ * Iterable<Double> v3 = ...;
    + *
    + * // Create TupleTags for each of the values to put in the
    + * // PTuple (the type of the TupleTag enables tracking the
    + * // static type of each of the values in the PTuple):
+ * TupleTag<String> tag1 = new TupleTag<>();
+ * TupleTag<Integer> tag2 = new TupleTag<>();
+ * TupleTag<Iterable<Double>> tag3 = new TupleTag<>();
    + *
    + * // Create a PTuple with three values:
    + * PTuple povs =
    + *     PTuple.of(tag1, v1)
    + *         .and(tag2, v2)
    + *         .and(tag3, v3);
    + *
    + * // Create an empty PTuple:
+ * PTuple povs2 = PTuple.empty();
    + *
    + * // Get values out of a PTuple, using the same tags
    + * // that were used to put them in:
    + * Integer vX = povs.get(tag2);
    + * String vY = povs.get(tag1);
+ * Iterable<Double> vZ = povs.get(tag3);
    + *
    + * // Get a map of all values in a PTuple:
+ * Map<TupleTag<?>, ?> allVs = povs.getAll();
    + * } 
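For illustration, a brief sketch of the lookup contract spelled out on has() and get() below; the tag and value here are hypothetical:

    TupleTag<String> greetingTag = new TupleTag<>();
    PTuple tuple = PTuple.of(greetingTag, "hello");

    if (tuple.has(greetingTag)) {
      String greeting = tuple.get(greetingTag);   // "hello"
    }
    // A separately constructed tag was never added to this PTuple, so get() would
    // throw IllegalArgumentException rather than return null:
    // tuple.get(new TupleTag<String>());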
    + */ +public class PTuple { + /** + * Returns an empty PTuple. + * + *

    Longer PTuples can be created by calling + * {@link #and} on the result. + */ + public static PTuple empty() { + return new PTuple(); + } + + /** + * Returns a singleton PTuple containing the given + * value keyed by the given TupleTag. + * + *

    Longer PTuples can be created by calling + * {@link #and} on the result. + */ + public static PTuple of(TupleTag tag, V value) { + return empty().and(tag, value); + } + + /** + * Returns a new PTuple that has all the values and + * tags of this PTuple plus the given value and tag. + * + *

    The given TupleTag should not already be mapped to a + * value in this PTuple. + */ + public PTuple and(TupleTag tag, V value) { + Map, Object> newMap = new LinkedHashMap, Object>(); + newMap.putAll(valueMap); + newMap.put(tag, value); + return new PTuple(newMap); + } + + /** + * Returns whether this PTuple contains a value with + * the given tag. + */ + public boolean has(TupleTag tag) { + return valueMap.containsKey(tag); + } + + /** + * Returns true if this {@code PTuple} is empty. + */ + public boolean isEmpty() { + return valueMap.isEmpty(); + } + + /** + * Returns the value with the given tag in this + * PTuple. Throws IllegalArgumentException if there is no + * such value, i.e., {@code !has(tag)}. + */ + public V get(TupleTag tag) { + if (!has(tag)) { + throw new IllegalArgumentException( + "TupleTag not found in this PTuple"); + } + @SuppressWarnings("unchecked") + V value = (V) valueMap.get(tag); + return value; + } + + /** + * Returns an immutable Map from TupleTag to corresponding + * value, for all the members of this PTuple. + */ + public Map, ?> getAll() { + return valueMap; + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + private final Map, ?> valueMap; + + @SuppressWarnings("rawtypes") + private PTuple() { + this(new LinkedHashMap()); + } + + private PTuple(Map, ?> valueMap) { + this.valueMap = Collections.unmodifiableMap(valueMap); + } + + /** + * Returns a PTuple with each of the given tags mapping + * to the corresponding value. + * + *

    For internal use only. + */ + public static PTuple ofInternal(Map, ?> valueMap) { + return new PTuple(valueMap); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PackageUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PackageUtil.java new file mode 100644 index 000000000000..8b2d56f13320 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PackageUtil.java @@ -0,0 +1,327 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.BackOffUtils; +import com.google.api.client.util.Sleeper; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.cloud.hadoop.util.ApiErrorExtractor; +import com.google.common.hash.Funnels; +import com.google.common.hash.Hasher; +import com.google.common.hash.Hashing; +import com.google.common.io.CountingOutputStream; +import com.google.common.io.Files; + +import com.fasterxml.jackson.core.Base64Variants; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** Helper routines for packages. */ +public class PackageUtil { + private static final Logger LOG = LoggerFactory.getLogger(PackageUtil.class); + /** + * A reasonable upper bound on the number of jars required to launch a Dataflow job. + */ + public static final int SANE_CLASSPATH_SIZE = 1000; + /** + * The initial interval to use between package staging attempts. + */ + private static final long INITIAL_BACKOFF_INTERVAL_MS = 5000L; + /** + * The maximum number of attempts when staging a file. + */ + private static final int MAX_ATTEMPTS = 5; + + /** + * Translates exceptions from API calls. + */ + private static final ApiErrorExtractor ERROR_EXTRACTOR = new ApiErrorExtractor(); + + /** + * Creates a DataflowPackage containing information about how a classpath element should be + * staged, including the staging destination as well as its size and hash. + * + * @param classpathElement The local path for the classpath element. + * @param stagingPath The base location for staged classpath elements. + * @param overridePackageName If non-null, use the given value as the package name + * instead of generating one automatically. + * @return The package. + */ + @Deprecated + public static DataflowPackage createPackage(File classpathElement, + String stagingPath, String overridePackageName) { + return createPackageAttributes(classpathElement, stagingPath, overridePackageName) + .getDataflowPackage(); + } + + /** + * Compute and cache the attributes of a classpath element that we will need to stage it. + * + * @param classpathElement the file or directory to be staged. 
+ * @param stagingPath The base location for staged classpath elements. + * @param overridePackageName If non-null, use the given value as the package name + * instead of generating one automatically. + * @return a {@link PackageAttributes} that containing metadata about the object to be staged. + */ + static PackageAttributes createPackageAttributes(File classpathElement, + String stagingPath, String overridePackageName) { + try { + boolean directory = classpathElement.isDirectory(); + + // Compute size and hash in one pass over file or directory. + Hasher hasher = Hashing.md5().newHasher(); + OutputStream hashStream = Funnels.asOutputStream(hasher); + CountingOutputStream countingOutputStream = new CountingOutputStream(hashStream); + + if (!directory) { + // Files are staged as-is. + Files.asByteSource(classpathElement).copyTo(countingOutputStream); + } else { + // Directories are recursively zipped. + ZipFiles.zipDirectory(classpathElement, countingOutputStream); + } + + long size = countingOutputStream.getCount(); + String hash = Base64Variants.MODIFIED_FOR_URL.encode(hasher.hash().asBytes()); + + // Create the DataflowPackage with staging name and location. + String uniqueName = getUniqueContentName(classpathElement, hash); + String resourcePath = IOChannelUtils.resolve(stagingPath, uniqueName); + DataflowPackage target = new DataflowPackage(); + target.setName(overridePackageName != null ? overridePackageName : uniqueName); + target.setLocation(resourcePath); + + return new PackageAttributes(size, hash, directory, target); + } catch (IOException e) { + throw new RuntimeException("Package setup failure for " + classpathElement, e); + } + } + + /** + * Transfers the classpath elements to the staging location. + * + * @param classpathElements The elements to stage. + * @param stagingPath The base location to stage the elements to. + * @return A list of cloud workflow packages, each representing a classpath element. + */ + public static List stageClasspathElements( + Collection classpathElements, String stagingPath) { + return stageClasspathElements(classpathElements, stagingPath, Sleeper.DEFAULT); + } + + // Visible for testing. + static List stageClasspathElements( + Collection classpathElements, String stagingPath, + Sleeper retrySleeper) { + LOG.info("Uploading {} files from PipelineOptions.filesToStage to staging location to " + + "prepare for execution.", classpathElements.size()); + + if (classpathElements.size() > SANE_CLASSPATH_SIZE) { + LOG.warn("Your classpath contains {} elements, which Google Cloud Dataflow automatically " + + "copies to all workers. Having this many entries on your classpath may be indicative " + + "of an issue in your pipeline. 
You may want to consider trimming the classpath to " + + "necessary dependencies only, using --filesToStage pipeline option to override " + + "what files are being staged, or bundling several dependencies into one.", + classpathElements.size()); + } + + ArrayList packages = new ArrayList<>(); + + if (stagingPath == null) { + throw new IllegalArgumentException( + "Can't stage classpath elements on because no staging location has been provided"); + } + + int numUploaded = 0; + int numCached = 0; + for (String classpathElement : classpathElements) { + String packageName = null; + if (classpathElement.contains("=")) { + String[] components = classpathElement.split("=", 2); + packageName = components[0]; + classpathElement = components[1]; + } + + File file = new File(classpathElement); + if (!file.exists()) { + LOG.warn("Skipping non-existent classpath element {} that was specified.", + classpathElement); + continue; + } + + PackageAttributes attributes = createPackageAttributes(file, stagingPath, packageName); + + DataflowPackage workflowPackage = attributes.getDataflowPackage(); + packages.add(workflowPackage); + String target = workflowPackage.getLocation(); + + // TODO: Should we attempt to detect the Mime type rather than + // always using MimeTypes.BINARY? + try { + try { + long remoteLength = IOChannelUtils.getSizeBytes(target); + if (remoteLength == attributes.getSize()) { + LOG.debug("Skipping classpath element already staged: {} at {}", + classpathElement, target); + numCached++; + continue; + } + } catch (FileNotFoundException expected) { + // If the file doesn't exist, it means we need to upload it. + } + + // Upload file, retrying on failure. + AttemptBoundedExponentialBackOff backoff = new AttemptBoundedExponentialBackOff( + MAX_ATTEMPTS, + INITIAL_BACKOFF_INTERVAL_MS); + while (true) { + try { + LOG.debug("Uploading classpath element {} to {}", classpathElement, target); + try (WritableByteChannel writer = IOChannelUtils.create(target, MimeTypes.BINARY)) { + copyContent(classpathElement, writer); + } + numUploaded++; + break; + } catch (IOException e) { + if (ERROR_EXTRACTOR.accessDenied(e)) { + String errorMessage = String.format( + "Uploaded failed due to permissions error, will NOT retry staging " + + "of classpath %s. Please verify credentials are valid and that you have " + + "write access to %s. Stale credentials can be resolved by executing " + + "'gcloud auth login'.", classpathElement, target); + LOG.error(errorMessage); + throw new IOException(errorMessage, e); + } else if (!backoff.atMaxAttempts()) { + LOG.warn("Upload attempt failed, sleeping before retrying staging of classpath: {}", + classpathElement, e); + BackOffUtils.next(retrySleeper, backoff); + } else { + // Rethrow last error, to be included as a cause in the catch below. + LOG.error("Upload failed, will NOT retry staging of classpath: {}", + classpathElement, e); + throw e; + } + } + } + } catch (Exception e) { + throw new RuntimeException("Could not stage classpath element: " + classpathElement, e); + } + } + + LOG.info("Uploading PipelineOptions.filesToStage complete: {} files newly uploaded, " + + "{} files cached", + numUploaded, numCached); + + return packages; + } + + /** + * Returns a unique name for a file with a given content hash. + * + *

    Directory paths are removed. Example: + *

    +   * dir="a/b/c/d", contentHash="f000" => d-f000.jar
    +   * file="a/b/c/d.txt", contentHash="f000" => d-f000.txt
    +   * file="a/b/c/d", contentHash="f000" => d-f000
    +   * 
    + */ + static String getUniqueContentName(File classpathElement, String contentHash) { + String fileName = Files.getNameWithoutExtension(classpathElement.getAbsolutePath()); + String fileExtension = Files.getFileExtension(classpathElement.getAbsolutePath()); + if (classpathElement.isDirectory()) { + return fileName + "-" + contentHash + ".jar"; + } else if (fileExtension.isEmpty()) { + return fileName + "-" + contentHash; + } + return fileName + "-" + contentHash + "." + fileExtension; + } + + /** + * Copies the contents of the classpathElement to the output channel. + * + *

    If the classpathElement is a directory, a Zip stream is constructed on the fly, + * otherwise the file contents are copied as-is. + * + *

    The output channel is not closed. + */ + private static void copyContent(String classpathElement, WritableByteChannel outputChannel) + throws IOException { + final File classpathElementFile = new File(classpathElement); + if (classpathElementFile.isDirectory()) { + ZipFiles.zipDirectory(classpathElementFile, Channels.newOutputStream(outputChannel)); + } else { + Files.asByteSource(classpathElementFile).copyTo(Channels.newOutputStream(outputChannel)); + } + } + /** + * Holds the metadata necessary to stage a file or confirm that a staged file has not changed. + */ + static class PackageAttributes { + private final boolean directory; + private final long size; + private final String hash; + private DataflowPackage dataflowPackage; + + public PackageAttributes(long size, String hash, boolean directory, + DataflowPackage dataflowPackage) { + this.size = size; + this.hash = Objects.requireNonNull(hash, "hash"); + this.directory = directory; + this.dataflowPackage = Objects.requireNonNull(dataflowPackage, "dataflowPackage"); + } + + /** + * @return the dataflowPackage + */ + public DataflowPackage getDataflowPackage() { + return dataflowPackage; + } + + /** + * @return the directory + */ + public boolean isDirectory() { + return directory; + } + + /** + * @return the size + */ + public long getSize() { + return size; + } + + /** + * @return the hash + */ + public String getHash() { + return hash; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PaneInfoTracker.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PaneInfoTracker.java new file mode 100644 index 000000000000..38499c2e2b88 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PaneInfoTracker.java @@ -0,0 +1,151 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.util.state.ValueState; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import org.joda.time.Instant; + +/** + * Determine the timing and other properties of a new pane for a given computation, key and window. + * Incorporates any previous pane, whether the pane has been produced because an + * on-time {@link AfterWatermark} trigger firing, and the relation between the element's timestamp + * and the current output watermark. 
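As a rough sketch of the timing decision implemented in describePane below (simplified to the three booleans it derives; the real method also tracks pane indices and persisted state):

    // Timing refers to PaneInfo.Timing.
    static Timing timingFor(
        boolean isLateForOutput, boolean onlyEarlyPanesSoFar, boolean isEndOfWindow) {
      if (isLateForOutput || !onlyEarlyPanesSoFar) {
        // The output watermark is already past the window, or a non-EARLY pane
        // has already been emitted, so this pane must be LATE.
        return Timing.LATE;
      } else if (isEndOfWindow) {
        // The unique on-time firing for the window.
        return Timing.ON_TIME;
      } else {
        // Everything else is a speculative early firing.
        return Timing.EARLY;
      }
    }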
+ */ +public class PaneInfoTracker { + private TimerInternals timerInternals; + + public PaneInfoTracker(TimerInternals timerInternals) { + this.timerInternals = timerInternals; + } + + @VisibleForTesting + static final StateTag> PANE_INFO_TAG = + StateTags.makeSystemTagInternal(StateTags.value("pane", PaneInfoCoder.INSTANCE)); + + public void clear(StateAccessor state) { + state.access(PANE_INFO_TAG).clear(); + } + + /** + * Return a ({@link ReadableState} for) the pane info appropriate for {@code context}. The pane + * info includes the timing for the pane, who's calculation is quite subtle. + * + * @param isEndOfWindow should be {@code true} only if the pane is being emitted + * because an end-of-window timer has fired and the trigger agreed we should fire. + * @param isFinal should be {@code true} only if the triggering machinery can guarantee + * no further firings for the + */ + public ReadableState getNextPaneInfo(ReduceFn.Context context, + final boolean isEndOfWindow, final boolean isFinal) { + final Object key = context.key(); + final ReadableState previousPaneFuture = + context.state().access(PaneInfoTracker.PANE_INFO_TAG); + final Instant windowMaxTimestamp = context.window().maxTimestamp(); + + return new ReadableState() { + @Override + public ReadableState readLater() { + previousPaneFuture.readLater(); + return this; + } + + @Override + public PaneInfo read() { + PaneInfo previousPane = previousPaneFuture.read(); + return describePane(key, windowMaxTimestamp, previousPane, isEndOfWindow, isFinal); + } + }; + } + + public void storeCurrentPaneInfo(ReduceFn.Context context, PaneInfo currentPane) { + context.state().access(PANE_INFO_TAG).write(currentPane); + } + + private PaneInfo describePane(Object key, Instant windowMaxTimestamp, PaneInfo previousPane, + boolean isEndOfWindow, boolean isFinal) { + boolean isFirst = previousPane == null; + Timing previousTiming = isFirst ? null : previousPane.getTiming(); + long index = isFirst ? 0 : previousPane.getIndex() + 1; + long nonSpeculativeIndex = isFirst ? 0 : previousPane.getNonSpeculativeIndex() + 1; + Instant outputWM = timerInternals.currentOutputWatermarkTime(); + Instant inputWM = timerInternals.currentInputWatermarkTime(); + + // True if it is not possible to assign the element representing this pane a timestamp + // which will make an ON_TIME pane for any following computation. + // Ie true if the element's latest possible timestamp is before the current output watermark. + boolean isLateForOutput = outputWM != null && windowMaxTimestamp.isBefore(outputWM); + + // True if all emitted panes (if any) were EARLY panes. + // Once the ON_TIME pane has fired, all following panes must be considered LATE even + // if the output watermark is behind the end of the window. + boolean onlyEarlyPanesSoFar = previousTiming == null || previousTiming == Timing.EARLY; + + Timing timing; + if (isLateForOutput || !onlyEarlyPanesSoFar) { + // The output watermark has already passed the end of this window, or we have already + // emitted a non-EARLY pane. Irrespective of how this pane was triggered we must + // consider this pane LATE. + timing = Timing.LATE; + } else if (isEndOfWindow) { + // This is the unique ON_TIME firing for the window. + timing = Timing.ON_TIME; + } else { + // All other cases are EARLY. 
+ timing = Timing.EARLY; + nonSpeculativeIndex = -1; + } + + WindowTracing.debug( + "describePane: {} pane (prev was {}) for key:{}; windowMaxTimestamp:{}; " + + "inputWatermark:{}; outputWatermark:{}; isEndOfWindow:{}; isLateForOutput:{}", + timing, previousTiming, key, windowMaxTimestamp, inputWM, outputWM, isEndOfWindow, + isLateForOutput); + + if (previousPane != null) { + // Timing transitions should follow EARLY* ON_TIME? LATE* + switch (previousTiming) { + case EARLY: + Preconditions.checkState( + timing == Timing.EARLY || timing == Timing.ON_TIME || timing == Timing.LATE, + "EARLY cannot transition to %s", timing); + break; + case ON_TIME: + Preconditions.checkState( + timing == Timing.LATE, "ON_TIME cannot transition to %s", timing); + break; + case LATE: + Preconditions.checkState(timing == Timing.LATE, "LATE cannot transtion to %s", timing); + break; + case UNKNOWN: + break; + } + Preconditions.checkState(!previousPane.isLast(), "Last pane was not last after all."); + } + + return PaneInfo.createPane(isFirst, isFinal, timing, index, nonSpeculativeIndex); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PathValidator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PathValidator.java new file mode 100644 index 000000000000..658de2a78f9e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PathValidator.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * Interface for controlling validation of paths. + */ +public interface PathValidator { + /** + * Validate that a file pattern is conforming. + * + * @param filepattern The file pattern to verify. + * @return The post-validation filepattern. + */ + public String validateInputFilePatternSupported(String filepattern); + + /** + * Validate that an output file prefix is conforming. + * + * @param filePrefix the file prefix to verify. + * @return The post-validation filePrefix. + */ + public String validateOutputFilePrefixSupported(String filePrefix); + + /** + * Validate that a path is a valid path and that the path + * is accessible. + * + * @param path The path to verify. + * @return The post-validation path. + */ + public String verifyPath(String path); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PerKeyCombineFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PerKeyCombineFnRunner.java new file mode 100644 index 000000000000..b5f328f014b1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PerKeyCombineFnRunner.java @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; + +import java.io.Serializable; +import java.util.Collection; + +/** + * An interface that runs a {@link PerKeyCombineFn} with unified APIs. + * + *

    Different keyed combine functions have their own implementations. + * For example, the implementation can skip allocating {@code Combine.Context}, + * if the keyed combine function doesn't use it. + */ +public interface PerKeyCombineFnRunner extends Serializable { + /** + * Returns the {@link PerKeyCombineFn} it holds. + * + *

    It can be a {@code KeyedCombineFn} or a {@code KeyedCombineFnWithContext}. + */ + public PerKeyCombineFn fn(); + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Forwards the call to a {@link PerKeyCombineFn} to create the accumulator in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public AccumT createAccumulator(K key, DoFn.ProcessContext c); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to add the input in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public AccumT addInput(K key, AccumT accumulator, InputT input, DoFn.ProcessContext c); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to merge accumulators in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public AccumT mergeAccumulators( + K key, Iterable accumulators, DoFn.ProcessContext c); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to extract the output in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public OutputT extractOutput(K key, AccumT accumulator, DoFn.ProcessContext c); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to compact the accumulator in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public AccumT compact(K key, AccumT accumulator, DoFn.ProcessContext c); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to combine the inputs and extract output + * in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public OutputT apply(K key, Iterable inputs, DoFn.ProcessContext c); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to add all inputs in a {@link DoFn}. + * + *

    It constructs a {@code CombineWithContext.Context} from {@code DoFn.ProcessContext} + * if it is required. + */ + public AccumT addInputs(K key, Iterable inputs, DoFn.ProcessContext c); + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Forwards the call to a {@link PerKeyCombineFn} to create the accumulator. + * + *

    It constructs a {@code CombineWithContext.Context} from + * {@link PipelineOptions} and {@link SideInputReader} if it is required. + */ + public AccumT createAccumulator(K key, PipelineOptions options, + SideInputReader sideInputReader, Collection windows); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to add the input. + * + *

    It constructs a {@code CombineWithContext.Context} from + * {@link PipelineOptions} and {@link SideInputReader} if it is required. + */ + public AccumT addInput(K key, AccumT accumulator, InputT value, PipelineOptions options, + SideInputReader sideInputReader, Collection windows); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to merge accumulators. + * + *

    It constructs a {@code CombineWithContext.Context} from + * {@link PipelineOptions} and {@link SideInputReader} if it is required. + */ + public AccumT mergeAccumulators(K key, Iterable accumulators, PipelineOptions options, + SideInputReader sideInputReader, Collection windows); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to extract the output. + * + *

    It constructs a {@code CombineWithContext.Context} from + * {@link PipelineOptions} and {@link SideInputReader} if it is required. + */ + public OutputT extractOutput(K key, AccumT accumulator, PipelineOptions options, + SideInputReader sideInputReader, Collection windows); + + /** + * Forwards the call to a {@link PerKeyCombineFn} to compact the accumulator. + * + *

    It constructs a {@code CombineWithContext.Context} from + * {@link PipelineOptions} and {@link SideInputReader} if it is required. + */ + public AccumT compact(K key, AccumT accumulator, PipelineOptions options, + SideInputReader sideInputReader, Collection windows); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PerKeyCombineFnRunners.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PerKeyCombineFnRunners.java new file mode 100644 index 000000000000..6606c5451f7e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PerKeyCombineFnRunners.java @@ -0,0 +1,257 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.RequiresContextInternal; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.common.collect.Iterables; + +import java.util.Collection; + +/** + * Static utility methods that provide {@link PerKeyCombineFnRunner} implementations + * for different keyed combine functions. + */ +public class PerKeyCombineFnRunners { + /** + * Returns a {@link PerKeyCombineFnRunner} from a {@link PerKeyCombineFn}. + */ + public static PerKeyCombineFnRunner + create(PerKeyCombineFn perKeyCombineFn) { + if (perKeyCombineFn instanceof RequiresContextInternal) { + return new KeyedCombineFnWithContextRunner<>( + (KeyedCombineFnWithContext) perKeyCombineFn); + } else { + return new KeyedCombineFnRunner<>( + (KeyedCombineFn) perKeyCombineFn); + } + } + + /** + * An implementation of {@link PerKeyCombineFnRunner} with {@link KeyedCombineFn}. + * + * It forwards functions calls to the {@link KeyedCombineFn}. 
+ */ + private static class KeyedCombineFnRunner + implements PerKeyCombineFnRunner { + private final KeyedCombineFn keyedCombineFn; + + private KeyedCombineFnRunner( + KeyedCombineFn keyedCombineFn) { + this.keyedCombineFn = keyedCombineFn; + } + + @Override + public KeyedCombineFn fn() { + return keyedCombineFn; + } + + @Override + public AccumT createAccumulator(K key, DoFn.ProcessContext c) { + return keyedCombineFn.createAccumulator(key); + } + + @Override + public AccumT addInput( + K key, AccumT accumulator, InputT input, DoFn.ProcessContext c) { + return keyedCombineFn.addInput(key, accumulator, input); + } + + @Override + public AccumT mergeAccumulators( + K key, Iterable accumulators, DoFn.ProcessContext c) { + return keyedCombineFn.mergeAccumulators(key, accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, DoFn.ProcessContext c) { + return keyedCombineFn.extractOutput(key, accumulator); + } + + @Override + public AccumT compact(K key, AccumT accumulator, DoFn.ProcessContext c) { + return keyedCombineFn.compact(key, accumulator); + } + + @Override + public OutputT apply(K key, Iterable inputs, DoFn.ProcessContext c) { + return keyedCombineFn.apply(key, inputs); + } + + @Override + public AccumT addInputs(K key, Iterable inputs, DoFn.ProcessContext c) { + AccumT accum = keyedCombineFn.createAccumulator(key); + for (InputT input : inputs) { + accum = keyedCombineFn.addInput(key, accum, input); + } + return accum; + } + + @Override + public String toString() { + return keyedCombineFn.toString(); + } + + @Override + public AccumT createAccumulator(K key, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFn.createAccumulator(key); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT input, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFn.addInput(key, accumulator, input); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFn.mergeAccumulators(key, accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFn.extractOutput(key, accumulator); + } + + @Override + public AccumT compact(K key, AccumT accumulator, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFn.compact(key, accumulator); + } + } + + /** + * An implementation of {@link PerKeyCombineFnRunner} with {@link KeyedCombineFnWithContext}. + * + * It forwards functions calls to the {@link KeyedCombineFnWithContext}. 
+ */ + private static class KeyedCombineFnWithContextRunner + implements PerKeyCombineFnRunner { + private final KeyedCombineFnWithContext keyedCombineFnWithContext; + + private KeyedCombineFnWithContextRunner( + KeyedCombineFnWithContext keyedCombineFnWithContext) { + this.keyedCombineFnWithContext = keyedCombineFnWithContext; + } + + @Override + public KeyedCombineFnWithContext fn() { + return keyedCombineFnWithContext; + } + + @Override + public AccumT createAccumulator(K key, DoFn.ProcessContext c) { + return keyedCombineFnWithContext.createAccumulator(key, + CombineContextFactory.createFromProcessContext(c)); + } + + @Override + public AccumT addInput( + K key, AccumT accumulator, InputT value, DoFn.ProcessContext c) { + return keyedCombineFnWithContext.addInput(key, accumulator, value, + CombineContextFactory.createFromProcessContext(c)); + } + + @Override + public AccumT mergeAccumulators( + K key, Iterable accumulators, DoFn.ProcessContext c) { + return keyedCombineFnWithContext.mergeAccumulators( + key, accumulators, CombineContextFactory.createFromProcessContext(c)); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, DoFn.ProcessContext c) { + return keyedCombineFnWithContext.extractOutput(key, accumulator, + CombineContextFactory.createFromProcessContext(c)); + } + + @Override + public AccumT compact(K key, AccumT accumulator, DoFn.ProcessContext c) { + return keyedCombineFnWithContext.compact(key, accumulator, + CombineContextFactory.createFromProcessContext(c)); + } + + @Override + public OutputT apply(K key, Iterable inputs, DoFn.ProcessContext c) { + return keyedCombineFnWithContext.apply(key, inputs, + CombineContextFactory.createFromProcessContext(c)); + } + + @Override + public AccumT addInputs(K key, Iterable inputs, DoFn.ProcessContext c) { + CombineWithContext.Context combineContext = CombineContextFactory.createFromProcessContext(c); + AccumT accum = keyedCombineFnWithContext.createAccumulator(key, combineContext); + for (InputT input : inputs) { + accum = keyedCombineFnWithContext.addInput(key, accum, input, combineContext); + } + return accum; + } + + @Override + public String toString() { + return keyedCombineFnWithContext.toString(); + } + + @Override + public AccumT createAccumulator(K key, PipelineOptions options, SideInputReader sideInputReader, + Collection windows) { + return keyedCombineFnWithContext.createAccumulator(key, + CombineContextFactory.createFromComponents( + options, sideInputReader, Iterables.getOnlyElement(windows))); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT input, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFnWithContext.addInput(key, accumulator, input, + CombineContextFactory.createFromComponents( + options, sideInputReader, Iterables.getOnlyElement(windows))); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFnWithContext.mergeAccumulators(key, accumulators, + CombineContextFactory.createFromComponents( + options, sideInputReader, Iterables.getOnlyElement(windows))); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFnWithContext.extractOutput(key, accumulator, + CombineContextFactory.createFromComponents( + options, sideInputReader, 
Iterables.getOnlyElement(windows))); + } + + @Override + public AccumT compact(K key, AccumT accumulator, PipelineOptions options, + SideInputReader sideInputReader, Collection windows) { + return keyedCombineFnWithContext.compact(key, accumulator, + CombineContextFactory.createFromComponents( + options, sideInputReader, Iterables.getOnlyElement(windows))); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java new file mode 100644 index 000000000000..5611fabe28a7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * Constant property names used by the SDK in CloudWorkflow specifications. + */ +public class PropertyNames { + public static final String ALLOWED_ENCODINGS = "allowed_encodings"; + public static final String APPEND_TRAILING_NEWLINES = "append_trailing_newlines"; + public static final String BIGQUERY_CREATE_DISPOSITION = "create_disposition"; + public static final String BIGQUERY_DATASET = "dataset"; + public static final String BIGQUERY_PROJECT = "project"; + public static final String BIGQUERY_SCHEMA = "schema"; + public static final String BIGQUERY_TABLE = "table"; + public static final String BIGQUERY_QUERY = "bigquery_query"; + public static final String BIGQUERY_FLATTEN_RESULTS = "bigquery_flatten_results"; + public static final String BIGQUERY_WRITE_DISPOSITION = "write_disposition"; + public static final String BIGQUERY_EXPORT_FORMAT = "bigquery_export_format"; + public static final String BIGQUERY_EXPORT_SCHEMA = "bigquery_export_schema"; + public static final String CO_GBK_RESULT_SCHEMA = "co_gbk_result_schema"; + public static final String COMBINE_FN = "combine_fn"; + public static final String COMPONENT_ENCODINGS = "component_encodings"; + public static final String COMPRESSION_TYPE = "compression_type"; + public static final String CUSTOM_SOURCE_FORMAT = "custom_source"; + public static final String CONCAT_SOURCE_SOURCES = "sources"; + public static final String CONCAT_SOURCE_BASE_SPECS = "base_specs"; + public static final String SOURCE_STEP_INPUT = "custom_source_step_input"; + public static final String SOURCE_SPEC = "spec"; + public static final String SOURCE_METADATA = "metadata"; + public static final String SOURCE_DOES_NOT_NEED_SPLITTING = "does_not_need_splitting"; + public static final String SOURCE_PRODUCES_SORTED_KEYS = "produces_sorted_keys"; + public static final String SOURCE_IS_INFINITE = "is_infinite"; + public static final String SOURCE_ESTIMATED_SIZE_BYTES = "estimated_size_bytes"; + public static final String ELEMENT = "element"; + public static final String ELEMENTS = "elements"; + public static final String ENCODING = "encoding"; + public static final String ENCODING_ID = "encoding_id"; + public static final String 
END_INDEX = "end_index"; + public static final String END_OFFSET = "end_offset"; + public static final String END_SHUFFLE_POSITION = "end_shuffle_position"; + public static final String ENVIRONMENT_VERSION_JOB_TYPE_KEY = "job_type"; + public static final String ENVIRONMENT_VERSION_MAJOR_KEY = "major"; + public static final String FILENAME = "filename"; + public static final String FILENAME_PREFIX = "filename_prefix"; + public static final String FILENAME_SUFFIX = "filename_suffix"; + public static final String FILEPATTERN = "filepattern"; + public static final String FOOTER = "footer"; + public static final String FORMAT = "format"; + public static final String HEADER = "header"; + public static final String INPUTS = "inputs"; + public static final String INPUT_CODER = "input_coder"; + public static final String IS_GENERATED = "is_generated"; + public static final String IS_PAIR_LIKE = "is_pair_like"; + public static final String IS_STREAM_LIKE = "is_stream_like"; + public static final String IS_WRAPPER = "is_wrapper"; + public static final String DISALLOW_COMBINER_LIFTING = "disallow_combiner_lifting"; + public static final String NON_PARALLEL_INPUTS = "non_parallel_inputs"; + public static final String NUM_SHARD_CODERS = "num_shard_coders"; + public static final String NUM_METADATA_SHARD_CODERS = "num_metadata_shard_coders"; + public static final String NUM_SHARDS = "num_shards"; + public static final String OBJECT_TYPE_NAME = "@type"; + public static final String OUTPUT = "output"; + public static final String OUTPUT_INFO = "output_info"; + public static final String OUTPUT_NAME = "output_name"; + public static final String PARALLEL_INPUT = "parallel_input"; + public static final String PHASE = "phase"; + public static final String PUBSUB_ID_LABEL = "pubsub_id_label"; + public static final String PUBSUB_SUBSCRIPTION = "pubsub_subscription"; + public static final String PUBSUB_TIMESTAMP_LABEL = "pubsub_timestamp_label"; + public static final String PUBSUB_TOPIC = "pubsub_topic"; + public static final String SCALAR_FIELD_NAME = "value"; + public static final String SERIALIZED_FN = "serialized_fn"; + public static final String SHARD_NAME_TEMPLATE = "shard_template"; + public static final String SHUFFLE_KIND = "shuffle_kind"; + public static final String SHUFFLE_READER_CONFIG = "shuffle_reader_config"; + public static final String SHUFFLE_WRITER_CONFIG = "shuffle_writer_config"; + public static final String SORT_VALUES = "sort_values"; + public static final String START_INDEX = "start_index"; + public static final String START_OFFSET = "start_offset"; + public static final String START_SHUFFLE_POSITION = "start_shuffle_position"; + public static final String STRIP_TRAILING_NEWLINES = "strip_trailing_newlines"; + public static final String TUPLE_TAGS = "tuple_tags"; + public static final String USE_INDEXED_FORMAT = "use_indexed_format"; + public static final String USER_FN = "user_fn"; + public static final String USER_NAME = "user_name"; + public static final String USES_KEYED_STATE = "uses_keyed_state"; + public static final String VALIDATE_SINK = "validate_sink"; + public static final String VALIDATE_SOURCE = "validate_source"; + public static final String VALUE = "value"; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RandomAccessData.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RandomAccessData.java new file mode 100644 index 000000000000..6c96c8e7033b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RandomAccessData.java @@ 
-0,0 +1,352 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.common.base.MoreObjects; +import com.google.common.io.ByteStreams; +import com.google.common.primitives.UnsignedBytes; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Comparator; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * An elastic-sized byte array which allows you to manipulate it as a stream, or access + * it directly. This allows for a quick succession of moving bytes from an {@link InputStream} + * to this wrapper to be used as an {@link OutputStream} and vice versa. This wrapper + * also provides random access to bytes stored within. This wrapper allows users to finely + * control the number of byte copies that occur. + * + * Anything stored within the in-memory buffer from offset {@link #size()} is considered temporary + * unused storage. + */ +@NotThreadSafe +public class RandomAccessData { + /** + * A {@link Coder} which encodes the valid parts of this stream. + * This follows the same encoding scheme as {@link ByteArrayCoder}. + * This coder is deterministic and consistent with equals. + * + * This coder does not support encoding positive infinity. 
+ */ + public static class RandomAccessDataCoder extends AtomicCoder { + private static final RandomAccessDataCoder INSTANCE = new RandomAccessDataCoder(); + + @JsonCreator + public static RandomAccessDataCoder of() { + return INSTANCE; + } + + @Override + public void encode(RandomAccessData value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + if (value == POSITIVE_INFINITY) { + throw new CoderException("Positive infinity can not be encoded."); + } + if (!context.isWholeStream) { + VarInt.encode(value.size, outStream); + } + value.writeTo(outStream, 0, value.size); + } + + @Override + public RandomAccessData decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + RandomAccessData rval = new RandomAccessData(); + if (!context.isWholeStream) { + int length = VarInt.decodeInt(inStream); + rval.readFrom(inStream, 0, length); + } else { + ByteStreams.copy(inStream, rval.asOutputStream()); + } + return rval; + } + + @Override + public boolean consistentWithEquals() { + return true; + } + + @Override + public boolean isRegisterByteSizeObserverCheap( + RandomAccessData value, Coder.Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(RandomAccessData value, Coder.Context context) + throws Exception { + if (value == null) { + throw new CoderException("cannot encode a null in memory stream"); + } + long size = 0; + if (!context.isWholeStream) { + size += VarInt.getLength(value.size); + } + return size + value.size; + } + } + + public static final UnsignedLexicographicalComparator UNSIGNED_LEXICOGRAPHICAL_COMPARATOR = + new UnsignedLexicographicalComparator(); + + /** + * A {@link Comparator} that compares two byte arrays lexicographically. It compares + * values as a list of unsigned bytes. The first pair of values that follow any common prefix, + * or when one array is a prefix of the other, treats the shorter array as the lesser. + * For example, [] < [0x01] < [0x01, 0x7F] < [0x01, 0x80] < [0x02] < POSITIVE INFINITY. + * + *

    Note that a token type of positive infinity is supported and is greater than + * all other {@link RandomAccessData}. + */ + public static final class UnsignedLexicographicalComparator + implements Comparator { + // Do not instantiate + private UnsignedLexicographicalComparator() { + } + + @Override + public int compare(RandomAccessData o1, RandomAccessData o2) { + return compare(o1, o2, 0 /* start from the beginning */); + } + + /** + * Compare the two sets of bytes starting at the given offset. + */ + public int compare(RandomAccessData o1, RandomAccessData o2, int startOffset) { + if (o1 == o2) { + return 0; + } + if (o1 == POSITIVE_INFINITY) { + return 1; + } + if (o2 == POSITIVE_INFINITY) { + return -1; + } + + int minBytesLen = Math.min(o1.size, o2.size); + for (int i = startOffset; i < minBytesLen; i++) { + // unsigned comparison + int b1 = o1.buffer[i] & 0xFF; + int b2 = o2.buffer[i] & 0xFF; + if (b1 == b2) { + continue; + } + // Return the stream with the smaller byte as the smaller value. + return b1 - b2; + } + // If one is a prefix of the other, return the shorter one as the smaller one. + // If both lengths are equal, then both streams are equal. + return o1.size - o2.size; + } + + /** + * Compute the length of the common prefix of the two provided sets of bytes. + */ + public int commonPrefixLength(RandomAccessData o1, RandomAccessData o2) { + int minBytesLen = Math.min(o1.size, o2.size); + for (int i = 0; i < minBytesLen; i++) { + // unsigned comparison + int b1 = o1.buffer[i] & 0xFF; + int b2 = o2.buffer[i] & 0xFF; + if (b1 != b2) { + return i; + } + } + return minBytesLen; + } + } + + /** A token type representing positive infinity. */ + static final RandomAccessData POSITIVE_INFINITY = new RandomAccessData(0); + + /** + * Returns a RandomAccessData that is the smallest value of same length which + * is strictly greater than this. Note that if this is empty or is all 0xFF then + * a token value of positive infinity is returned. + * + * The {@link UnsignedLexicographicalComparator} supports comparing {@link RandomAccessData} + * with support for positive infinitiy. + */ + public RandomAccessData increment() throws IOException { + RandomAccessData copy = copy(); + for (int i = copy.size - 1; i >= 0; --i) { + if (copy.buffer[i] != UnsignedBytes.MAX_VALUE) { + copy.buffer[i] = UnsignedBytes.checkedCast(UnsignedBytes.toInt(copy.buffer[i]) + 1); + return copy; + } + } + return POSITIVE_INFINITY; + } + + private static final int DEFAULT_INITIAL_BUFFER_SIZE = 128; + + /** Constructs a RandomAccessData with a default buffer size. */ + public RandomAccessData() { + this(DEFAULT_INITIAL_BUFFER_SIZE); + } + + /** Constructs a RandomAccessData with the initial buffer. */ + public RandomAccessData(byte[] initialBuffer) { + checkNotNull(initialBuffer); + this.buffer = initialBuffer; + this.size = initialBuffer.length; + } + + /** Constructs a RandomAccessData with the given buffer size. */ + public RandomAccessData(int initialBufferSize) { + checkArgument(initialBufferSize >= 0, "Expected initial buffer size to be greater than zero."); + this.buffer = new byte[initialBufferSize]; + } + + private byte[] buffer; + private int size; + + /** Returns the backing array. */ + public byte[] array() { + return buffer; + } + + /** Returns the number of bytes in the backing array that are valid. */ + public int size() { + return size; + } + + /** Resets the end of the stream to the specified position. 
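As a sketch of how resetTo and the stream-oriented accessors below are meant to compose; someInputStream is a placeholder for any InputStream with at least 16 bytes available, and IOException is assumed to propagate:

    RandomAccessData buffer = new RandomAccessData();
    buffer.readFrom(someInputStream, 0, 16);                    // fills bytes [0, 16); size() becomes 16
    InputStream view = buffer.asInputStream(0, buffer.size());  // wrapper over the valid bytes, no copy
    buffer.resetTo(0);                                          // discards the logical contents;
                                                                // the backing capacity is retained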
*/ + public void resetTo(int position) { + ensureCapacity(position); + size = position; + } + + private final OutputStream outputStream = new OutputStream() { + @Override + public void write(int b) throws IOException { + ensureCapacity(size + 1); + buffer[size] = (byte) b; + size += 1; + } + + @Override + public void write(byte[] b, int offset, int length) throws IOException { + ensureCapacity(size + length); + System.arraycopy(b, offset, buffer, size, length); + size += length; + } + }; + + /** + * Returns an output stream which writes to the backing buffer from the current position. + * Note that the internal buffer will grow as required to accomodate all data written. + */ + public OutputStream asOutputStream() { + return outputStream; + } + + /** + * Returns an {@link InputStream} wrapper which supplies the portion of this backing byte buffer + * starting at {@code offset} and up to {@code length} bytes. Note that the returned + * {@link InputStream} is only a wrapper and any modifications to the underlying + * {@link RandomAccessData} will be visible by the {@link InputStream}. + */ + public InputStream asInputStream(final int offset, final int length) { + return new ByteArrayInputStream(buffer, offset, length); + } + + /** + * Writes {@code length} bytes starting at {@code offset} from the backing data store to the + * specified output stream. + */ + public void writeTo(OutputStream out, int offset, int length) throws IOException { + out.write(buffer, offset, length); + } + + /** + * Reads {@code length} bytes from the specified input stream writing them into the backing + * data store starting at {@code offset}. + * + *

    Note that the in memory stream will be grown to ensure there is enough capacity. + */ + public void readFrom(InputStream inStream, int offset, int length) throws IOException { + ensureCapacity(offset + length); + ByteStreams.readFully(inStream, buffer, offset, length); + size = offset + length; + } + + /** Returns a copy of this RandomAccessData. */ + public RandomAccessData copy() throws IOException { + RandomAccessData copy = new RandomAccessData(size); + writeTo(copy.asOutputStream(), 0, size); + return copy; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof RandomAccessData)) { + return false; + } + return UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare(this, (RandomAccessData) other) == 0; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < size; ++i) { + result = 31 * result + buffer[i]; + } + + return result; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("buffer", Arrays.copyOf(buffer, size)) + .add("size", size) + .toString(); + } + + private void ensureCapacity(int minCapacity) { + // If we have enough space, don't grow the buffer. + if (minCapacity <= buffer.length) { + return; + } + + // Try to double the size of the buffer, if thats not enough, just use the new capacity. + // Note that we use Math.min(long, long) to not cause overflow on the multiplication. + int newCapacity = (int) Math.min(Integer.MAX_VALUE, buffer.length * 2L); + if (newCapacity < minCapacity) { + newCapacity = minCapacity; + } + buffer = Arrays.copyOf(buffer, newCapacity); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFn.java new file mode 100644 index 000000000000..c5ef2ea12613 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFn.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; + +import org.joda.time.Instant; + +import java.io.Serializable; + +/** + * Specification for processing to happen after elements have been grouped by key. + * + * @param The type of key being processed. + * @param The type of input values associated with the key. + * @param The output type that will be produced for each key. + * @param The type of windows this operates on. + */ +public abstract class ReduceFn + implements Serializable { + + /** Information accessible to all the processing methods in this {@code ReduceFn}. 
*/ + public abstract class Context { + /** Return the key that is being processed. */ + public abstract K key(); + + /** The window that is being processed. */ + public abstract W window(); + + /** Access the current {@link WindowingStrategy}. */ + public abstract WindowingStrategy windowingStrategy(); + + /** Return the interface for accessing state. */ + public abstract StateAccessor state(); + + /** Return the interface for accessing timers. */ + public abstract Timers timers(); + } + + /** Information accessible within {@link #processValue}. */ + public abstract class ProcessValueContext extends Context { + /** Return the actual value being processed. */ + public abstract InputT value(); + + /** Return the timestamp associated with the value. */ + public abstract Instant timestamp(); + } + + /** Information accessible within {@link #onMerge}. */ + public abstract class OnMergeContext extends Context { + /** Return the interface for accessing state. */ + @Override + public abstract MergingStateAccessor state(); + } + + /** Information accessible within {@link #onTrigger}. */ + public abstract class OnTriggerContext extends Context { + /** Returns the {@link PaneInfo} for the trigger firing being processed. */ + public abstract PaneInfo paneInfo(); + + /** Output the given value in the current window. */ + public abstract void output(OutputT value); + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Called for each value of type {@code InputT} associated with the current key. + */ + public abstract void processValue(ProcessValueContext c) throws Exception; + + /** + * Called when windows are merged. + */ + public abstract void onMerge(OnMergeContext context) throws Exception; + + /** + * Called when triggers fire. + * + *

    Implementations of {@link ReduceFn} should call {@link OnTriggerContext#output} to emit + * any results that should be included in the pane produced by this trigger firing. + */ + public abstract void onTrigger(OnTriggerContext context) throws Exception; + + /** + * Called before {@link #onMerge} is invoked to provide an opportunity to prefetch any needed + * state. + * + * @param c Context to use prefetch from. + */ + public void prefetchOnMerge(MergingStateAccessor c) throws Exception {} + + /** + * Called before {@link #onTrigger} is invoked to provide an opportunity to prefetch any needed + * state. + * + * @param context Context to use prefetch from. + */ + public void prefetchOnTrigger(StateAccessor context) {} + + /** + * Called to clear any persisted state that the {@link ReduceFn} may be holding. This will be + * called when the windowing is closing and will receive no future interactions. + */ + public abstract void clearState(Context context) throws Exception; + + /** + * Returns true if the there is no buffered state. + */ + public abstract ReadableState isEmpty(StateAccessor context); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java new file mode 100644 index 000000000000..bdbaf1098e3a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java @@ -0,0 +1,495 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.State; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateContext; +import com.google.cloud.dataflow.sdk.util.state.StateContexts; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces.WindowNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; + +import org.joda.time.Instant; + +import java.util.Collection; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Factory for creating instances of the various {@link ReduceFn} contexts. 
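The {@code ReduceFn} contract just defined comes down to five required callbacks. A do-nothing skeleton is sketched below purely for illustration; the generic parameters ({@code K, InputT, OutputT, W}), which this rendering of the diff drops, are restored by hand, and the constant {@code ReadableState} assumes the interface exposes only {@code read()} and {@code readLater()}:

    import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
    import com.google.cloud.dataflow.sdk.util.ReduceFn;
    import com.google.cloud.dataflow.sdk.util.state.ReadableState;
    import com.google.cloud.dataflow.sdk.util.state.StateAccessor;

    /** Illustrative skeleton: buffers nothing and emits a placeholder on every firing. */
    class NoOpReduceFn<K, W extends BoundedWindow> extends ReduceFn<K, Integer, Integer, W> {
      @Override
      public void processValue(ProcessValueContext c) throws Exception {
        // A real implementation would buffer c.value() (with c.timestamp()) in state from c.state().
      }

      @Override
      public void onMerge(OnMergeContext c) throws Exception {
        // Consolidate per-window state here when the WindowFn merges windows.
      }

      @Override
      public void onTrigger(OnTriggerContext c) throws Exception {
        c.output(0);  // emit the pane contents; here just a placeholder value
      }

      @Override
      public void clearState(Context c) throws Exception {
        // Nothing buffered, so nothing to clear.
      }

      @Override
      public ReadableState<Boolean> isEmpty(StateAccessor<K> state) {
        // Report "always empty" since this skeleton never buffers anything.
        return new ReadableState<Boolean>() {
          @Override
          public Boolean read() {
            return true;
          }

          @Override
          public ReadableState<Boolean> readLater() {
            return this;
          }
        };
      }
    }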
+ */ +class ReduceFnContextFactory { + public interface OnTriggerCallbacks { + void output(OutputT toOutput); + } + + private final K key; + private final ReduceFn reduceFn; + private final WindowingStrategy windowingStrategy; + private final StateInternals stateInternals; + private final ActiveWindowSet activeWindows; + private final TimerInternals timerInternals; + private final WindowingInternals windowingInternals; + private final PipelineOptions options; + + ReduceFnContextFactory(K key, ReduceFn reduceFn, + WindowingStrategy windowingStrategy, StateInternals stateInternals, + ActiveWindowSet activeWindows, TimerInternals timerInternals, + WindowingInternals windowingInternals, PipelineOptions options) { + this.key = key; + this.reduceFn = reduceFn; + this.windowingStrategy = windowingStrategy; + this.stateInternals = stateInternals; + this.activeWindows = activeWindows; + this.timerInternals = timerInternals; + this.windowingInternals = windowingInternals; + this.options = options; + } + + /** Where should we look for state associated with a given window? */ + public static enum StateStyle { + /** All state is associated with the window itself. */ + DIRECT, + /** State is associated with the 'state address' windows tracked by the active window set. */ + RENAMED + } + + private StateAccessorImpl stateAccessor(W window, StateStyle style) { + return new StateAccessorImpl( + activeWindows, windowingStrategy.getWindowFn().windowCoder(), + stateInternals, StateContexts.createFromComponents(options, windowingInternals, window), + style); + } + + public ReduceFn.Context base(W window, StateStyle style) { + return new ContextImpl(stateAccessor(window, style)); + } + + public ReduceFn.ProcessValueContext forValue( + W window, InputT value, Instant timestamp, StateStyle style) { + return new ProcessValueContextImpl(stateAccessor(window, style), value, timestamp); + } + + public ReduceFn.OnTriggerContext forTrigger(W window, + ReadableState pane, StateStyle style, OnTriggerCallbacks callbacks) { + return new OnTriggerContextImpl(stateAccessor(window, style), pane, callbacks); + } + + public ReduceFn.OnMergeContext forMerge( + Collection activeToBeMerged, W mergeResult, StateStyle style) { + return new OnMergeContextImpl( + new MergingStateAccessorImpl(activeWindows, + windowingStrategy.getWindowFn().windowCoder(), + stateInternals, style, activeToBeMerged, mergeResult)); + } + + public ReduceFn.OnMergeContext forPremerge(W window) { + return new OnPremergeContextImpl(new PremergingStateAccessorImpl( + activeWindows, windowingStrategy.getWindowFn().windowCoder(), stateInternals, window)); + } + + private class TimersImpl implements Timers { + private final StateNamespace namespace; + + public TimersImpl(StateNamespace namespace) { + Preconditions.checkArgument(namespace instanceof WindowNamespace); + this.namespace = namespace; + } + + @Override + public void setTimer(Instant timestamp, TimeDomain timeDomain) { + timerInternals.setTimer(TimerData.of(namespace, timestamp, timeDomain)); + } + + @Override + public void deleteTimer(Instant timestamp, TimeDomain timeDomain) { + timerInternals.deleteTimer(TimerData.of(namespace, timestamp, timeDomain)); + } + + @Override + public Instant currentProcessingTime() { + return timerInternals.currentProcessingTime(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return timerInternals.currentSynchronizedProcessingTime(); + } + + @Override + @Nullable + public Instant currentEventTime() { + return 
timerInternals.currentInputWatermarkTime(); + } + } + + // ====================================================================== + // StateAccessors + // ====================================================================== + static class StateAccessorImpl implements StateAccessor { + + + protected final ActiveWindowSet activeWindows; + protected final StateContext context; + protected final StateNamespace windowNamespace; + protected final Coder windowCoder; + protected final StateInternals stateInternals; + protected final StateStyle style; + + public StateAccessorImpl(ActiveWindowSet activeWindows, Coder windowCoder, + StateInternals stateInternals, StateContext context, StateStyle style) { + + this.activeWindows = activeWindows; + this.windowCoder = windowCoder; + this.stateInternals = stateInternals; + this.context = checkNotNull(context); + this.windowNamespace = namespaceFor(context.window()); + this.style = style; + } + + protected StateNamespace namespaceFor(W window) { + return StateNamespaces.window(windowCoder, window); + } + + protected StateNamespace windowNamespace() { + return windowNamespace; + } + + W window() { + return context.window(); + } + + StateNamespace namespace() { + return windowNamespace(); + } + + @Override + public StateT access(StateTag address) { + switch (style) { + case DIRECT: + return stateInternals.state(windowNamespace(), address, context); + case RENAMED: + return stateInternals.state( + namespaceFor(activeWindows.writeStateAddress(context.window())), address, context); + } + throw new RuntimeException(); // cases are exhaustive. + } + } + + static class MergingStateAccessorImpl + extends StateAccessorImpl implements MergingStateAccessor { + private final Collection activeToBeMerged; + + public MergingStateAccessorImpl(ActiveWindowSet activeWindows, Coder windowCoder, + StateInternals stateInternals, StateStyle style, Collection activeToBeMerged, + W mergeResult) { + super(activeWindows, windowCoder, stateInternals, + StateContexts.windowOnly(mergeResult), style); + this.activeToBeMerged = activeToBeMerged; + } + + @Override + public StateT access(StateTag address) { + switch (style) { + case DIRECT: + return stateInternals.state(windowNamespace(), address, context); + case RENAMED: + return stateInternals.state( + namespaceFor(activeWindows.mergedWriteStateAddress( + activeToBeMerged, context.window())), + address, + context); + } + throw new RuntimeException(); // cases are exhaustive. + } + + @Override + public Map accessInEachMergingWindow( + StateTag address) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (W mergingWindow : activeToBeMerged) { + StateNamespace namespace = null; + switch (style) { + case DIRECT: + namespace = namespaceFor(mergingWindow); + break; + case RENAMED: + namespace = namespaceFor(activeWindows.writeStateAddress(mergingWindow)); + break; + } + Preconditions.checkNotNull(namespace); // cases are exhaustive. 
+ builder.put(mergingWindow, stateInternals.state(namespace, address, context)); + } + return builder.build(); + } + } + + static class PremergingStateAccessorImpl + extends StateAccessorImpl implements MergingStateAccessor { + public PremergingStateAccessorImpl(ActiveWindowSet activeWindows, Coder windowCoder, + StateInternals stateInternals, W window) { + super(activeWindows, windowCoder, stateInternals, + StateContexts.windowOnly(window), StateStyle.RENAMED); + } + + Collection mergingWindows() { + return activeWindows.readStateAddresses(context.window()); + } + + @Override + public Map accessInEachMergingWindow( + StateTag address) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (W stateAddressWindow : activeWindows.readStateAddresses(context.window())) { + StateT stateForWindow = + stateInternals.state(namespaceFor(stateAddressWindow), address, context); + builder.put(stateAddressWindow, stateForWindow); + } + return builder.build(); + } + } + + // ====================================================================== + // Contexts + // ====================================================================== + + private class ContextImpl extends ReduceFn.Context { + private final StateAccessorImpl state; + private final TimersImpl timers; + + private ContextImpl(StateAccessorImpl state) { + reduceFn.super(); + this.state = state; + this.timers = new TimersImpl(state.namespace()); + } + + @Override + public K key() { + return key; + } + + @Override + public W window() { + return state.window(); + } + + @Override + public WindowingStrategy windowingStrategy() { + return windowingStrategy; + } + + @Override + public StateAccessor state() { + return state; + } + + @Override + public Timers timers() { + return timers; + } + } + + private class ProcessValueContextImpl + extends ReduceFn.ProcessValueContext { + private final InputT value; + private final Instant timestamp; + private final StateAccessorImpl state; + private final TimersImpl timers; + + private ProcessValueContextImpl(StateAccessorImpl state, + InputT value, Instant timestamp) { + reduceFn.super(); + this.state = state; + this.value = value; + this.timestamp = timestamp; + this.timers = new TimersImpl(state.namespace()); + } + + @Override + public K key() { + return key; + } + + @Override + public W window() { + return state.window(); + } + + @Override + public WindowingStrategy windowingStrategy() { + return windowingStrategy; + } + + @Override + public StateAccessor state() { + return state; + } + + @Override + public InputT value() { + return value; + } + + @Override + public Instant timestamp() { + return timestamp; + } + + @Override + public Timers timers() { + return timers; + } + } + + private class OnTriggerContextImpl extends ReduceFn.OnTriggerContext { + private final StateAccessorImpl state; + private final ReadableState pane; + private final OnTriggerCallbacks callbacks; + private final TimersImpl timers; + + private OnTriggerContextImpl(StateAccessorImpl state, ReadableState pane, + OnTriggerCallbacks callbacks) { + reduceFn.super(); + this.state = state; + this.pane = pane; + this.callbacks = callbacks; + this.timers = new TimersImpl(state.namespace()); + } + + @Override + public K key() { + return key; + } + + @Override + public W window() { + return state.window(); + } + + @Override + public WindowingStrategy windowingStrategy() { + return windowingStrategy; + } + + @Override + public StateAccessor state() { + return state; + } + + @Override + public PaneInfo paneInfo() { + return pane.read(); + 
} + + @Override + public void output(OutputT value) { + callbacks.output(value); + } + + @Override + public Timers timers() { + return timers; + } + } + + private class OnMergeContextImpl extends ReduceFn.OnMergeContext { + private final MergingStateAccessorImpl state; + private final TimersImpl timers; + + private OnMergeContextImpl(MergingStateAccessorImpl state) { + reduceFn.super(); + this.state = state; + this.timers = new TimersImpl(state.namespace()); + } + + @Override + public K key() { + return key; + } + + @Override + public WindowingStrategy windowingStrategy() { + return windowingStrategy; + } + + @Override + public MergingStateAccessor state() { + return state; + } + + @Override + public W window() { + return state.window(); + } + + @Override + public Timers timers() { + return timers; + } + } + + private class OnPremergeContextImpl extends ReduceFn.OnMergeContext { + private final PremergingStateAccessorImpl state; + private final TimersImpl timers; + + private OnPremergeContextImpl(PremergingStateAccessorImpl state) { + reduceFn.super(); + this.state = state; + this.timers = new TimersImpl(state.namespace()); + } + + @Override + public K key() { + return key; + } + + @Override + public WindowingStrategy windowingStrategy() { + return windowingStrategy; + } + + @Override + public MergingStateAccessor state() { + return state; + } + + @Override + public W window() { + return state.window(); + } + + @Override + public Timers timers() { + return timers; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java new file mode 100644 index 000000000000..fe5c4742103e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java @@ -0,0 +1,846 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey.GroupByKeyOnly; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.ReduceFnContextFactory.OnTriggerCallbacks; +import com.google.cloud.dataflow.sdk.util.ReduceFnContextFactory.StateStyle; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces.WindowNamespace; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Manages the execution of a {@link ReduceFn} after a {@link GroupByKeyOnly} has partitioned the + * {@link PCollection} by key. + * + *

    The {@link #onTrigger} relies on a {@link TriggerRunner} to manage the execution of + * the triggering logic. The {@code ReduceFnRunner}s responsibilities are: + * + *

      + *
    • Tracking the windows that are active (have buffered data) as elements arrive and + * triggers are fired. + *
    • Holding the watermark based on the timestamps of elements in a pane and releasing it + * when the trigger fires. + *
    • Calling the appropriate callbacks on {@link ReduceFn} based on trigger execution, timer + * firings, etc., and providing appropriate contexts to the {@link ReduceFn} for actions + * such as output. + *
    • Scheduling garbage collection of state associated with a specific window, and making that + * happen when the appropriate timer fires. + *
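Those responsibilities are exercised through three public entry points added in this file: {@code processElements}, {@code onTimer}, and {@code persist}. A sketch of the calling pattern a harness might use is below; the runner instance, the bundle, and the fired timers are assumed to be supplied elsewhere, and the generic parameters are restored by hand:

    import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
    import com.google.cloud.dataflow.sdk.util.ReduceFnRunner;
    import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData;
    import com.google.cloud.dataflow.sdk.util.WindowedValue;

    class ReduceFnRunnerDriveSketch {
      static <K, InputT, OutputT, W extends BoundedWindow> void drive(
          ReduceFnRunner<K, InputT, OutputT, W> runner,
          Iterable<WindowedValue<InputT>> bundle,
          Iterable<TimerData> firedTimers) throws Exception {
        // Assign and merge windows, buffer the values, and fire any triggers that are ready.
        runner.processElements(bundle);
        // Deliver end-of-window, garbage-collection, and trigger timers that have fired.
        for (TimerData timer : firedTimers) {
          runner.onTimer(timer);
        }
        // Flush the active-window bookkeeping back to persistent state.
        runner.persist();
      }
    }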
    + * + * @param The type of key being processed. + * @param The type of values associated with the key. + * @param The output type that will be produced for each key. + * @param The type of windows this operates on. + */ +public class ReduceFnRunner { + + /** + * The {@link ReduceFnRunner} depends on most aspects of the {@link WindowingStrategy}. + * + *
      + *
    • It runs the trigger from the {@link WindowingStrategy}.
    • It merges windows according to the {@link WindowingStrategy}.
    • It chooses how to track active windows and clear out expired windows + * according to the {@link WindowingStrategy}, based on the allowed lateness and + * whether windows can merge.
    • It decides whether to emit empty final panes according to whether the + * {@link WindowingStrategy} requires it.
    • It uses discarding or accumulation mode according to the {@link WindowingStrategy}.
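For orientation, the sketch below shows how a pipeline typically populates the {@link WindowingStrategy} pieces listed above before a {@code GroupByKey}; the element types, window size, trigger, and allowed lateness are illustrative assumptions rather than values taken from this change:

    import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark;
    import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows;
    import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
    import com.google.cloud.dataflow.sdk.values.KV;
    import com.google.cloud.dataflow.sdk.values.PCollection;
    import org.joda.time.Duration;

    class WindowingStrategySketch {
      static PCollection<KV<String, Long>> windowForGrouping(PCollection<KV<String, Long>> input) {
        return input.apply(
            Window.<KV<String, Long>>into(FixedWindows.of(Duration.standardMinutes(1)))
                .triggering(AfterWatermark.pastEndOfWindow())     // trigger the runner will execute
                .withAllowedLateness(Duration.standardMinutes(5)) // bounds garbage collection
                .discardingFiredPanes());                         // accumulation mode
      }
    }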
    + */ + private final WindowingStrategy windowingStrategy; + + private final OutputWindowedValue> outputter; + + private final StateInternals stateInternals; + + private final Aggregator droppedDueToClosedWindow; + + private final K key; + + private final OnMergeCallback onMergeCallback = new OnMergeCallback(); + + /** + * Track which windows are still active and which 'state address' windows contain state + * for a merged window. + * + *
      + *
    • State: Global map for all active windows for this computation and key. + *
    • Lifetime: Cleared when no active windows need to be tracked. A window lives within + * the active window set until its trigger is closed or the window is garbage collected. + *
    + */ + private final ActiveWindowSet activeWindows; + + /** + * Always a {@link SystemReduceFn}. + * + *
      + *
    • State: A bag of accumulated values, or the intermediate result of a combiner. + *
    • State style: RENAMED + *
    • Merging: Concatenate or otherwise combine the state from each merged window. + *
    • Lifetime: Cleared when a pane fires if DISCARDING_FIRED_PANES. Otherwise cleared + * when trigger is finished or when the window is garbage collected. + *
    + */ + private final ReduceFn reduceFn; + + /** + * Manage the setting and firing of timer events. + * + *
      + *
    • Merging: End-of-window and garbage collection timers are cancelled when windows are + * merged away. Timers created by triggers are never garbage collected and are left to + * fire and be ignored. + *
    • Lifetime: Timers automatically disappear after they fire. + *
    + */ + private final TimerInternals timerInternals; + + /** + * Manage the execution and state for triggers. + * + *
      + *
    • State: Tracks which sub-triggers have finished, and any additional state needed to + * determine when the trigger should fire. + *
    • State style: DIRECT + *
    • Merging: Finished bits are explicitly managed. Other state is eagerly merged as + * needed. + *
    • Lifetime: Most trigger state is cleared when the final pane is emitted. However, + * the finished bits are left behind and must be cleared when the window is + * garbage collected. + *
    + */ + private final TriggerRunner triggerRunner; + + /** + * Store the output watermark holds for each window. + * + *
      + *
    • State: Bag of hold timestamps. + *
    • State style: RENAMED + *
    • Merging: Depending on {@link OutputTimeFn}, may need to be recalculated on merging. + * When a pane fires it may be necessary to add (back) an end-of-window or garbage collection + * hold. + *
    • Lifetime: Cleared when a pane fires or when the window is garbage collected. + *
    + */ + private final WatermarkHold watermarkHold; + + private final ReduceFnContextFactory contextFactory; + + /** + * Store the previously emitted pane (if any) for each window. + * + *
      + *
    • State: The previous {@link PaneInfo} passed to the user's {@link DoFn#processElement}, + * if any. + *
    • State style: DIRECT
    • Merging: Always keyed by actual window, so does not depend on {@link #activeWindows}. + * Cleared when window is merged away. + *
    • Lifetime: Cleared when trigger is closed or window is garbage collected. + *
    + */ + private final PaneInfoTracker paneInfoTracker; + + /** + * Store whether we've seen any elements for a window since the last pane was emitted. + * + *
      + *
    • State: Unless DISCARDING_FIRED_PANES, a count of the number of elements added so far.
    • State style: RENAMED. + *
    • Merging: Counts are summed when windows are merged. + *
    • Lifetime: Cleared when pane fires or window is garbage collected. + *
    + */ + private final NonEmptyPanes nonEmptyPanes; + + public ReduceFnRunner( + K key, + WindowingStrategy windowingStrategy, + StateInternals stateInternals, + TimerInternals timerInternals, + WindowingInternals> windowingInternals, + Aggregator droppedDueToClosedWindow, + ReduceFn reduceFn, + PipelineOptions options) { + this.key = key; + this.timerInternals = timerInternals; + this.paneInfoTracker = new PaneInfoTracker(timerInternals); + this.stateInternals = stateInternals; + this.outputter = new OutputViaWindowingInternals<>(windowingInternals); + this.droppedDueToClosedWindow = droppedDueToClosedWindow; + this.reduceFn = reduceFn; + + @SuppressWarnings("unchecked") + WindowingStrategy objectWindowingStrategy = + (WindowingStrategy) windowingStrategy; + this.windowingStrategy = objectWindowingStrategy; + + this.nonEmptyPanes = NonEmptyPanes.create(this.windowingStrategy, this.reduceFn); + + // Note this may incur I/O to load persisted window set data. + this.activeWindows = createActiveWindowSet(); + + this.contextFactory = + new ReduceFnContextFactory(key, reduceFn, this.windowingStrategy, + stateInternals, this.activeWindows, timerInternals, windowingInternals, options); + + this.watermarkHold = new WatermarkHold<>(timerInternals, windowingStrategy); + this.triggerRunner = + new TriggerRunner<>( + windowingStrategy.getTrigger(), + new TriggerContextFactory<>(windowingStrategy, stateInternals, activeWindows)); + } + + private ActiveWindowSet createActiveWindowSet() { + return windowingStrategy.getWindowFn().isNonMerging() + ? new NonMergingActiveWindowSet() + : new MergingActiveWindowSet(windowingStrategy.getWindowFn(), stateInternals); + } + + @VisibleForTesting + boolean isFinished(W window) { + return triggerRunner.isClosed(contextFactory.base(window, StateStyle.DIRECT).state()); + } + + /** + * Incorporate {@code values} into the underlying reduce function, and manage holds, timers, + * triggers, and window merging. + * + *

    The general strategy is: + *

      + *
    1. Use {@link WindowedValue#getWindows} (itself determined using + * {@link WindowFn#assignWindows}) to determine which windows each element belongs to. Some + * of those windows will already have state associated with them. The rest are considered + * NEW. + *
    2. Use {@link WindowFn#mergeWindows} to attempt to merge currently ACTIVE and NEW windows. + * Each NEW window will become either ACTIVE, MERGED, or EPHEMERAL. (See {@link + * ActiveWindowSet} for definitions of these terms.) + *
    3. If at all possible, eagerly substitute EPHEMERAL windows with their ACTIVE state address + * windows before any state is associated with the EPHEMERAL window. In the common case where + * windows for new elements are merged into existing ACTIVE windows, no additional + * storage or merging overhead will be incurred. + *
    4. Otherwise, keep track of the state address windows for ACTIVE windows so that their + * states can be merged on-demand when a pane fires. + *
    5. Process the element for each of the windows its original windows have been merged into, according + * to {@link ActiveWindowSet}. Processing may require running triggers, setting timers, + * setting holds, and invoking {@link ReduceFn#onTrigger}. + *
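As a small illustration of the merging in steps 2-4 (using made-up session-style windows and assuming {@code IntervalWindow} and its {@code span} helper), two overlapping windows collapse into one window that the runner then tracks as ACTIVE, with the originals kept only as state address windows:

    import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow;
    import org.joda.time.Instant;

    class WindowMergeSketch {
      public static void main(String[] args) {
        IntervalWindow first = new IntervalWindow(
            Instant.parse("2015-01-01T10:00:00Z"), Instant.parse("2015-01-01T10:10:00Z"));
        IntervalWindow second = new IntervalWindow(
            Instant.parse("2015-01-01T10:05:00Z"), Instant.parse("2015-01-01T10:15:00Z"));
        // The merge result covers both inputs: 10:00 to 10:15.
        System.out.println(first.span(second));
      }
    }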
    + */ + public void processElements(Iterable> values) throws Exception { + // If an incoming element introduces a new window, attempt to merge it into an existing + // window eagerly. The outcome is stored in the ActiveWindowSet. + collectAndMergeWindows(values); + + Set windowsToConsider = new HashSet<>(); + + // Process each element, using the updated activeWindows determined by collectAndMergeWindows. + for (WindowedValue value : values) { + windowsToConsider.addAll(processElement(value)); + } + + // Trigger output from any window for which the trigger is ready + for (W mergedWindow : windowsToConsider) { + ReduceFn.Context directContext = + contextFactory.base(mergedWindow, StateStyle.DIRECT); + ReduceFn.Context renamedContext = + contextFactory.base(mergedWindow, StateStyle.RENAMED); + triggerRunner.prefetchShouldFire(mergedWindow, directContext.state()); + emitIfAppropriate(directContext, renamedContext, false/* isEndOfWindow */); + } + + // We're all done with merging and emitting elements so can compress the activeWindow state. + activeWindows.removeEphemeralWindows(); + } + + public void persist() { + activeWindows.persist(); + } + + /** + * Extract the windows associated with the values, and invoke merge. + */ + private void collectAndMergeWindows(Iterable> values) throws Exception { + // No-op if no merging can take place + if (windowingStrategy.getWindowFn().isNonMerging()) { + return; + } + + // Collect the windows from all elements (except those which are too late) and + // make sure they are already in the active window set or are added as NEW windows. + for (WindowedValue value : values) { + for (BoundedWindow untypedWindow : value.getWindows()) { + @SuppressWarnings("unchecked") + W window = (W) untypedWindow; + + ReduceFn.Context directContext = + contextFactory.base(window, StateStyle.DIRECT); + if (triggerRunner.isClosed(directContext.state())) { + // This window has already been closed. + // We will update the counter for this in the corresponding processElement call. + continue; + } + + if (activeWindows.isActive(window)) { + Set stateAddressWindows = activeWindows.readStateAddresses(window); + if (stateAddressWindows.size() > 1) { + // This is a legacy window who's state has not been eagerly merged. + // Do that now. + ReduceFn.OnMergeContext premergeContext = + contextFactory.forPremerge(window); + reduceFn.onMerge(premergeContext); + watermarkHold.onMerge(premergeContext); + activeWindows.merged(window); + } + } + + // Add this window as NEW if we've not yet seen it. + activeWindows.addNew(window); + } + } + + // Merge all of the active windows and retain a mapping from source windows to result windows. + mergeActiveWindows(); + } + + private class OnMergeCallback implements ActiveWindowSet.MergeCallback { + /** + * Called from the active window set to indicate {@code toBeMerged} (of which only + * {@code activeToBeMerged} are ACTIVE and thus have state associated with them) will later + * be merged into {@code mergeResult}. + */ + @Override + public void prefetchOnMerge( + Collection toBeMerged, Collection activeToBeMerged, W mergeResult) throws Exception { + ReduceFn.OnMergeContext directMergeContext = + contextFactory.forMerge(activeToBeMerged, mergeResult, StateStyle.DIRECT); + ReduceFn.OnMergeContext renamedMergeContext = + contextFactory.forMerge(activeToBeMerged, mergeResult, StateStyle.RENAMED); + + // Prefetch various state. 
+ triggerRunner.prefetchForMerge(mergeResult, activeToBeMerged, directMergeContext.state()); + reduceFn.prefetchOnMerge(renamedMergeContext.state()); + watermarkHold.prefetchOnMerge(renamedMergeContext.state()); + nonEmptyPanes.prefetchOnMerge(renamedMergeContext.state()); + } + + /** + * Called from the active window set to indicate {@code toBeMerged} (of which only + * {@code activeToBeMerged} are ACTIVE and thus have state associated with them) are about + * to be merged into {@code mergeResult}. + */ + @Override + public void onMerge(Collection toBeMerged, Collection activeToBeMerged, W mergeResult) + throws Exception { + // At this point activeWindows has NOT incorporated the results of the merge. + ReduceFn.OnMergeContext directMergeContext = + contextFactory.forMerge(activeToBeMerged, mergeResult, StateStyle.DIRECT); + ReduceFn.OnMergeContext renamedMergeContext = + contextFactory.forMerge(activeToBeMerged, mergeResult, StateStyle.RENAMED); + + // Run the reduceFn to perform any needed merging. + reduceFn.onMerge(renamedMergeContext); + + // Merge the watermark holds. + watermarkHold.onMerge(renamedMergeContext); + + // Merge non-empty pane state. + nonEmptyPanes.onMerge(renamedMergeContext.state()); + + // Have the trigger merge state as needed + triggerRunner.onMerge( + directMergeContext.window(), directMergeContext.timers(), directMergeContext.state()); + + for (W active : activeToBeMerged) { + if (active.equals(mergeResult)) { + // Not merged away. + continue; + } + // Cleanup flavor A: Currently ACTIVE window is about to become MERGED. + // Clear any state not already cleared by the onMerge calls above. + WindowTracing.debug("ReduceFnRunner.onMerge: Merging {} into {}", active, mergeResult); + ReduceFn.Context directClearContext = + contextFactory.base(active, StateStyle.DIRECT); + // No need for the end-of-window or garbage collection timers. + // We will establish a new end-of-window or garbage collection timer for the mergeResult + // window in processElement below. There must be at least one element for the mergeResult + // window since a new element with a new window must have triggered this onMerge. + cancelEndOfWindowAndGarbageCollectionTimers(directClearContext); + // We no longer care about any previous panes of merged away windows. The + // merge result window gets to start fresh if it is new. + paneInfoTracker.clear(directClearContext.state()); + } + } + } + + private void mergeActiveWindows() throws Exception { + activeWindows.merge(onMergeCallback); + } + + /** + * Process an element. + * @param value the value being processed + * + * @return the set of windows in which the element was actually processed + */ + private Collection processElement(WindowedValue value) throws Exception { + // Redirect element windows to the ACTIVE windows they have been merged into. + // The compressed representation (value, {window1, window2, ...}) actually represents + // distinct elements (value, window1), (value, window2), ... + // so if window1 and window2 merge, the resulting window will contain both copies + // of the value. 
+ Collection windows = new ArrayList<>(); + for (BoundedWindow untypedWindow : value.getWindows()) { + @SuppressWarnings("unchecked") + W window = (W) untypedWindow; + W active = activeWindows.representative(window); + Preconditions.checkState(active != null, "Window %s should have been added", window); + windows.add(active); + } + + // Prefetch in each of the windows if we're going to need to process triggers + for (W window : windows) { + ReduceFn.ProcessValueContext directContext = contextFactory.forValue( + window, value.getValue(), value.getTimestamp(), StateStyle.DIRECT); + triggerRunner.prefetchForValue(window, directContext.state()); + } + + // Process the element for each (representative) window it belongs to. + for (W window : windows) { + ReduceFn.ProcessValueContext directContext = contextFactory.forValue( + window, value.getValue(), value.getTimestamp(), StateStyle.DIRECT); + ReduceFn.ProcessValueContext renamedContext = contextFactory.forValue( + window, value.getValue(), value.getTimestamp(), StateStyle.RENAMED); + + // Check to see if the triggerRunner thinks the window is closed. If so, drop that window. + if (triggerRunner.isClosed(directContext.state())) { + droppedDueToClosedWindow.addValue(1L); + WindowTracing.debug( + "ReduceFnRunner.processElement: Dropping element at {} for key:{}; window:{} " + + "since window is no longer active at inputWatermark:{}; outputWatermark:{}", + value.getTimestamp(), key, window, timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + continue; + } + + nonEmptyPanes.recordContent(renamedContext.state()); + + // Make sure we've scheduled the end-of-window or garbage collection timer for this window. + Instant timer = scheduleEndOfWindowOrGarbageCollectionTimer(directContext); + + // Hold back progress of the output watermark until we have processed the pane this + // element will be included within. If the element is too late for that, place a hold at + // the end-of-window or garbage collection time to allow empty panes to contribute elements + // which won't be dropped due to lateness by a following computation (assuming the following + // computation uses the same allowed lateness value...) + @Nullable Instant hold = watermarkHold.addHolds(renamedContext); + + if (hold != null) { + // Assert that holds have a proximate timer. + boolean holdInWindow = !hold.isAfter(window.maxTimestamp()); + boolean timerInWindow = !timer.isAfter(window.maxTimestamp()); + Preconditions.checkState( + holdInWindow == timerInWindow, + "set a hold at %s, a timer at %s, which disagree as to whether they are in window %s", + hold, + timer, + directContext.window()); + } + + // Execute the reduceFn, which will buffer the value as appropriate + reduceFn.processValue(renamedContext); + + // Run the trigger to update its state + triggerRunner.processValue( + directContext.window(), + directContext.timestamp(), + directContext.timers(), + directContext.state()); + } + + return windows; + } + + /** + * Called when an end-of-window, garbage collection, or trigger-specific timer fires. + */ + public void onTimer(TimerData timer) throws Exception { + // Which window is the timer for? 
+ Preconditions.checkArgument(timer.getNamespace() instanceof WindowNamespace, + "Expected timer to be in WindowNamespace, but was in %s", timer.getNamespace()); + @SuppressWarnings("unchecked") + WindowNamespace windowNamespace = (WindowNamespace) timer.getNamespace(); + W window = windowNamespace.getWindow(); + ReduceFn.Context directContext = + contextFactory.base(window, StateStyle.DIRECT); + ReduceFn.Context renamedContext = + contextFactory.base(window, StateStyle.RENAMED); + + // Has this window had its trigger finish? + // - The trigger may implement isClosed as constant false. + // - If the window function does not support windowing then all windows will be considered + // active. + // So we must combine the above. + boolean windowIsActive = + activeWindows.isActive(window) && !triggerRunner.isClosed(directContext.state()); + + if (!windowIsActive) { + WindowTracing.debug( + "ReduceFnRunner.onTimer: Note that timer {} is for non-ACTIVE window {}", timer, window); + } + + // If this is an end-of-window timer then: + // 1. We need to set a GC timer + // 2. We need to let the PaneInfoTracker know that we are transitioning from early to late, + // and possibly emitting an on-time pane. + boolean isEndOfWindow = + TimeDomain.EVENT_TIME == timer.getDomain() + && timer.getTimestamp().equals(window.maxTimestamp()); + + // If this is a garbage collection timer then we should trigger and garbage collect the window. + Instant cleanupTime = window.maxTimestamp().plus(windowingStrategy.getAllowedLateness()); + boolean isGarbageCollection = + TimeDomain.EVENT_TIME == timer.getDomain() && timer.getTimestamp().equals(cleanupTime); + + if (isGarbageCollection) { + WindowTracing.debug( + "ReduceFnRunner.onTimer: Cleaning up for key:{}; window:{} at {} with " + + "inputWatermark:{}; outputWatermark:{}", + key, window, timer.getTimestamp(), timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + + if (windowIsActive) { + // We need to call onTrigger to emit the final pane if required. + // The final pane *may* be ON_TIME if no prior ON_TIME pane has been emitted, + // and the watermark has passed the end of the window. + onTrigger(directContext, renamedContext, isEndOfWindow, true/* isFinished */); + } + + // Cleanup flavor B: Clear all the remaining state for this window since we'll never + // see elements for it again. + clearAllState(directContext, renamedContext, windowIsActive); + } else { + WindowTracing.debug( + "ReduceFnRunner.onTimer: Triggering for key:{}; window:{} at {} with " + + "inputWatermark:{}; outputWatermark:{}", + key, window, timer.getTimestamp(), timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + if (windowIsActive) { + emitIfAppropriate(directContext, renamedContext, isEndOfWindow); + } + + if (isEndOfWindow) { + // Since we are processing an on-time firing we should schedule the garbage collection + // timer. (If getAllowedLateness is zero then the timer event will be considered a + // cleanup event and handled by the above). + // Note we must do this even if the trigger is finished so that we are sure to cleanup + // any final trigger tombstones. 
+ Preconditions.checkState( + windowingStrategy.getAllowedLateness().isLongerThan(Duration.ZERO), + "Unexpected zero getAllowedLateness"); + WindowTracing.debug( + "ReduceFnRunner.onTimer: Scheduling cleanup timer for key:{}; window:{} at {} with " + + "inputWatermark:{}; outputWatermark:{}", + key, directContext.window(), cleanupTime, timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + directContext.timers().setTimer(cleanupTime, TimeDomain.EVENT_TIME); + } + } + } + + /** + * Clear all the state associated with {@code context}'s window. + * Should only be invoked if we know all future elements for this window will be considered + * beyond allowed lateness. + * This is a superset of the clearing done by {@link #emitIfAppropriate} below since: + *
      + *
    1. We can clear the trigger state tombstone since we'll never need to ask about it again. + *
    2. We can clear any remaining garbage collection hold. + *
    + */ + private void clearAllState( + ReduceFn.Context directContext, + ReduceFn.Context renamedContext, + boolean windowIsActive) + throws Exception { + if (windowIsActive) { + // Since window is still active the trigger has not closed. + reduceFn.clearState(renamedContext); + watermarkHold.clearHolds(renamedContext); + nonEmptyPanes.clearPane(renamedContext.state()); + triggerRunner.clearState( + directContext.window(), directContext.timers(), directContext.state()); + } else { + // Needed only for backwards compatibility over UPDATE. + // Clear any end-of-window or garbage collection holds keyed by the current window. + // Only needed if: + // - We have merging windows. + // - We are DISCARDING_FIRED_PANES. + // - A pane has fired. + // - But the trigger is not (yet) closed. + if (windowingStrategy.getMode() == AccumulationMode.DISCARDING_FIRED_PANES + && !windowingStrategy.getWindowFn().isNonMerging()) { + watermarkHold.clearHolds(directContext); + } + } + paneInfoTracker.clear(directContext.state()); + activeWindows.remove(directContext.window()); + // We'll never need to test for the trigger being closed again. + triggerRunner.clearFinished(directContext.state()); + } + + /** Should the reduce function state be cleared? */ + private boolean shouldDiscardAfterFiring(boolean isFinished) { + if (isFinished) { + // This is the last firing for trigger. + return true; + } + if (windowingStrategy.getMode() == AccumulationMode.DISCARDING_FIRED_PANES) { + // Nothing should be accumulated between panes. + return true; + } + return false; + } + + /** + * Possibly emit a pane if a trigger is ready to fire or timers require it, and cleanup state. + */ + private void emitIfAppropriate(ReduceFn.Context directContext, + ReduceFn.Context renamedContext, boolean isEndOfWindow) + throws Exception { + if (!triggerRunner.shouldFire( + directContext.window(), directContext.timers(), directContext.state())) { + // Ignore unless trigger is ready to fire + return; + } + + // Inform the trigger of the transition to see if it is finished + triggerRunner.onFire(directContext.window(), directContext.timers(), directContext.state()); + boolean isFinished = triggerRunner.isClosed(directContext.state()); + + // Will be able to clear all element state after triggering? + boolean shouldDiscard = shouldDiscardAfterFiring(isFinished); + + // Run onTrigger to produce the actual pane contents. + // As a side effect it will clear all element holds, but not necessarily any + // end-of-window or garbage collection holds. + onTrigger(directContext, renamedContext, isEndOfWindow, isFinished); + + // Now that we've triggered, the pane is empty. + nonEmptyPanes.clearPane(renamedContext.state()); + + // Cleanup buffered data if appropriate + if (shouldDiscard) { + // Cleanup flavor C: The user does not want any buffered data to persist between panes. + reduceFn.clearState(renamedContext); + } + + if (isFinished) { + // Cleanup flavor D: If trigger is closed we will ignore all new incoming elements. + // Clear state not otherwise cleared by onTrigger and clearPane above. + // Remember the trigger is, indeed, closed until the window is garbage collected. + triggerRunner.clearState( + directContext.window(), directContext.timers(), directContext.state()); + paneInfoTracker.clear(directContext.state()); + activeWindows.remove(directContext.window()); + } + } + + /** + * Do we need to emit a pane? 
+ */ + private boolean needToEmit( + boolean isEmpty, boolean isEndOfWindow, boolean isFinished, PaneInfo.Timing timing) { + if (!isEmpty) { + // The pane has elements. + return true; + } + if (isEndOfWindow && timing == Timing.ON_TIME) { + // This is the unique ON_TIME pane. + return true; + } + if (isFinished && windowingStrategy.getClosingBehavior() == ClosingBehavior.FIRE_ALWAYS) { + // This is known to be the final pane, and the user has requested it even when empty. + return true; + } + return false; + } + + /** + * Run the {@link ReduceFn#onTrigger} method and produce any necessary output. + */ + private void onTrigger( + final ReduceFn.Context directContext, + ReduceFn.Context renamedContext, + boolean isEndOfWindow, + boolean isFinished) + throws Exception { + // Prefetch necessary states + ReadableState outputTimestampFuture = + watermarkHold.extractAndRelease(renamedContext, isFinished).readLater(); + ReadableState paneFuture = + paneInfoTracker.getNextPaneInfo(directContext, isEndOfWindow, isFinished).readLater(); + ReadableState isEmptyFuture = + nonEmptyPanes.isEmpty(renamedContext.state()).readLater(); + + reduceFn.prefetchOnTrigger(directContext.state()); + triggerRunner.prefetchOnFire(directContext.window(), directContext.state()); + + // Calculate the pane info. + final PaneInfo pane = paneFuture.read(); + // Extract the window hold, and as a side effect clear it. + final Instant outputTimestamp = outputTimestampFuture.read(); + + // Only emit a pane if it has data or empty panes are observable. + if (needToEmit(isEmptyFuture.read(), isEndOfWindow, isFinished, pane.getTiming())) { + // Run reduceFn.onTrigger method. + final List windows = Collections.singletonList(directContext.window()); + ReduceFn.OnTriggerContext renamedTriggerContext = + contextFactory.forTrigger(directContext.window(), paneFuture, StateStyle.RENAMED, + new OnTriggerCallbacks() { + @Override + public void output(OutputT toOutput) { + // We're going to output panes, so commit the (now used) PaneInfo. + // TODO: This is unnecessary if the trigger isFinished since the saved + // state will be immediately deleted. + paneInfoTracker.storeCurrentPaneInfo(directContext, pane); + + // Output the actual value. + outputter.outputWindowedValue( + KV.of(key, toOutput), outputTimestamp, windows, pane); + } + }); + + reduceFn.onTrigger(renamedTriggerContext); + } + } + + /** + * Make sure we'll eventually have a timer fire which will tell us to garbage collect + * the window state. For efficiency we may need to do this in two steps rather + * than one. Return the time at which the timer will fire. + * + *
      + *
    • If allowedLateness is zero then we'll garbage collect at the end of the window. + * For simplicity we'll set our own timer for this situation even though an + * {@link AfterWatermark} trigger may have also set an end-of-window timer. + * ({@code setTimer} is idempotent.) + *
    • If allowedLateness is non-zero then we could just always set a timer for the garbage + * collection time. However, if the windows are large (e.g. hourly) and the allowedLateness is small + * (e.g. seconds), then we'll end up with nearly twice the number of timers in-flight. So we + * instead set an end-of-window timer and then roll that forward to a garbage collection timer + * when it fires. We use the input watermark to distinguish those cases. + *
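Concretely (with made-up times), the choice mirrors the comparison made in the method below: once the input watermark is past the end of the window we schedule the garbage collection time directly, otherwise the end-of-window timer is set and rolled forward when it fires:

    import org.joda.time.Duration;
    import org.joda.time.Instant;

    class TimerChoiceSketch {
      public static void main(String[] args) {
        Instant endOfWindow = Instant.parse("2015-01-01T10:00:00Z");
        Duration allowedLateness = Duration.standardMinutes(5);
        Instant inputWatermark = Instant.parse("2015-01-01T10:02:00Z");
        Instant fireTime = endOfWindow.isBefore(inputWatermark)
            ? endOfWindow.plus(allowedLateness)  // 10:05, garbage collection timer
            : endOfWindow;                       // 10:00, end-of-window timer
        System.out.println(fireTime);            // prints 2015-01-01T10:05:00.000Z for these inputs
      }
    }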
    + */ + private Instant scheduleEndOfWindowOrGarbageCollectionTimer( + ReduceFn.Context directContext) { + Instant inputWM = timerInternals.currentInputWatermarkTime(); + Instant endOfWindow = directContext.window().maxTimestamp(); + Instant fireTime; + String which; + if (inputWM != null && endOfWindow.isBefore(inputWM)) { + fireTime = endOfWindow.plus(windowingStrategy.getAllowedLateness()); + which = "garbage collection"; + } else { + fireTime = endOfWindow; + which = "end-of-window"; + } + WindowTracing.trace( + "ReduceFnRunner.scheduleEndOfWindowOrGarbageCollectionTimer: Scheduling {} timer at {} for " + + "key:{}; window:{} where inputWatermark:{}; outputWatermark:{}", + which, + fireTime, + key, + directContext.window(), + inputWM, + timerInternals.currentOutputWatermarkTime()); + directContext.timers().setTimer(fireTime, TimeDomain.EVENT_TIME); + return fireTime; + } + + private void cancelEndOfWindowAndGarbageCollectionTimers(ReduceFn.Context context) { + WindowTracing.debug( + "ReduceFnRunner.cancelEndOfWindowAndGarbageCollectionTimers: Deleting timers for " + + "key:{}; window:{} where inputWatermark:{}; outputWatermark:{}", + key, context.window(), timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + Instant timer = context.window().maxTimestamp(); + context.timers().deleteTimer(timer, TimeDomain.EVENT_TIME); + if (windowingStrategy.getAllowedLateness().isLongerThan(Duration.ZERO)) { + timer = timer.plus(windowingStrategy.getAllowedLateness()); + context.timers().deleteTimer(timer, TimeDomain.EVENT_TIME); + } + } + + /** + * An object that can output a value with all of its windowing information. This is a deliberately + * restricted subinterface of {@link WindowingInternals} to express how it is used here. + */ + private interface OutputWindowedValue { + void outputWindowedValue(OutputT output, Instant timestamp, + Collection windows, PaneInfo pane); + } + + private static class OutputViaWindowingInternals + implements OutputWindowedValue { + + private final WindowingInternals windowingInternals; + + public OutputViaWindowingInternals(WindowingInternals windowingInternals) { + this.windowingInternals = windowingInternals; + } + + @Override + public void outputWindowedValue( + OutputT output, + Instant timestamp, + Collection windows, + PaneInfo pane) { + windowingInternals.outputWindowedValue(output, timestamp, windows, pane); + } + + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReifyTimestampAndWindowsDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReifyTimestampAndWindowsDoFn.java new file mode 100644 index 000000000000..88a1c15eb70c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReifyTimestampAndWindowsDoFn.java @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.values.KV; + +/** + * DoFn that makes timestamps and window assignments explicit in the value part of each key/value + * pair. + * + * @param the type of the keys of the input and output {@code PCollection}s + * @param the type of the values of the input {@code PCollection} + */ +@SystemDoFnInternal +public class ReifyTimestampAndWindowsDoFn + extends DoFn, KV>> { + @Override + public void processElement(ProcessContext c) + throws Exception { + KV kv = c.element(); + K key = kv.getKey(); + V value = kv.getValue(); + c.output(KV.of( + key, + WindowedValue.of( + value, + c.timestamp(), + c.windowingInternals().windows(), + c.pane()))); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Reshuffle.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Reshuffle.java new file mode 100644 index 000000000000..367db2dc5bec --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Reshuffle.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.NonMergingWindowFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * A {@link PTransform} that returns a {@link PCollection} equivalent to its input but operationally + * provides some of the side effects of a {@link GroupByKey}, in particular preventing fusion of + * the surrounding transforms, checkpointing and deduplication by id (see + * {@link ValueWithRecordId}). + * + *

    Performs a {@link GroupByKey} so that the data is key-partitioned. Configures the + * {@link WindowingStrategy} so that no data is dropped, but doesn't affect the need for + * the user to specify allowed lateness and accumulation mode before a user-inserted GroupByKey. + * + * @param The type of key being reshuffled on. + * @param The type of value being reshuffled. + */ +public class Reshuffle extends PTransform>, PCollection>> { + + private Reshuffle() { + } + + public static Reshuffle of() { + return new Reshuffle(); + } + + @Override + public PCollection> apply(PCollection> input) { + WindowingStrategy originalStrategy = input.getWindowingStrategy(); + // If the input has already had its windows merged, then the GBK that performed the merge + // will have set originalStrategy.getWindowFn() to InvalidWindows, causing the GBK contained + // here to fail. Instead, we install a valid WindowFn that leaves all windows unchanged. + Window.Bound> rewindow = Window + .>into(new PassThroughWindowFn<>(originalStrategy.getWindowFn())) + .triggering(new ReshuffleTrigger<>()) + .discardingFiredPanes() + .withAllowedLateness(Duration.millis(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis())); + + return input.apply(rewindow) + .apply(GroupByKey.create()) + // Set the windowing strategy directly, so that it doesn't get counted as the user having + // set allowed lateness. + .setWindowingStrategyInternal(originalStrategy) + .apply(ParDo.named("ExpandIterable").of( + new DoFn>, KV>() { + @Override + public void processElement(ProcessContext c) { + K key = c.element().getKey(); + for (V value : c.element().getValue()) { + c.output(KV.of(key, value)); + } + } + })); + } + + /** + * A {@link WindowFn} that leaves all associations between elements and windows unchanged. + * + *
    In order to implement all the abstract methods of {@link WindowFn}, this requires the + * prior {@link WindowFn}, to which all auxiliary functionality is delegated. + */ + private static class PassThroughWindowFn extends NonMergingWindowFn { + + /** The WindowFn prior to this. Used for its windowCoder, etc. */ + private final WindowFn priorWindowFn; + + public PassThroughWindowFn(WindowFn priorWindowFn) { + // Safe because it is only used privately here. + // At every point where a window is returned or accepted, it has been provided + // by priorWindowFn, so it is of the type expected. + @SuppressWarnings("unchecked") + WindowFn internalWindowFn = (WindowFn) priorWindowFn; + this.priorWindowFn = internalWindowFn; + } + + @Override + public Collection assignWindows(WindowFn.AssignContext c) + throws Exception { + // The windows are provided by priorWindowFn, which also provides the coder for them + @SuppressWarnings("unchecked") + Collection priorWindows = (Collection) c.windows(); + return priorWindows; + } + + @Override + public boolean isCompatible(WindowFn other) { + throw new UnsupportedOperationException( + String.format("%s.isCompatible() should never be called." + + " It is a private implementation detail of Reshuffle." + + " This message indicates a bug in the Dataflow SDK.", + getClass().getCanonicalName())); + } + + @Override + public Coder windowCoder() { + // Safe because priorWindowFn provides the windows also. + // The Coder is _not_ actually a coder for an arbitrary BoundedWindow. + return priorWindowFn.windowCoder(); + } + + @Override + public BoundedWindow getSideInputWindow(BoundedWindow window) { + throw new UnsupportedOperationException( + String.format("%s.getSideInputWindow() should never be called." + + " It is a private implementation detail of Reshuffle." + + " This message indicates a bug in the Dataflow SDK.", + getClass().getCanonicalName())); + } + + @Override + public Instant getOutputTime(Instant inputTimestamp, BoundedWindow window) { + return inputTimestamp; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReshuffleTrigger.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReshuffleTrigger.java new file mode 100644 index 000000000000..248f00589d93 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReshuffleTrigger.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; + +import org.joda.time.Instant; + +import java.util.List; + +/** + * The trigger used with {@link Reshuffle} which triggers on every element + * and never buffers state. + * + * @param The kind of window that is being reshuffled. 
+ */ +public class ReshuffleTrigger extends Trigger { + + ReshuffleTrigger() { + super(null); + } + + @Override + public void onElement(Trigger.OnElementContext c) { } + + @Override + public void onMerge(Trigger.OnMergeContext c) { } + + @Override + protected Trigger getContinuationTrigger(List> continuationTriggers) { + return this; + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(W window) { + throw new UnsupportedOperationException( + "ReshuffleTrigger should not be used outside of Reshuffle"); + } + + @Override + public boolean shouldFire(Trigger.TriggerContext context) throws Exception { + return true; + } + + @Override + public void onFire(Trigger.TriggerContext context) throws Exception { } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializer.java new file mode 100644 index 000000000000..756dce0a9985 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializer.java @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.http.HttpBackOffIOExceptionHandler; +import com.google.api.client.http.HttpBackOffUnsuccessfulResponseHandler; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpResponseInterceptor; +import com.google.api.client.http.HttpUnsuccessfulResponseHandler; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Implements a request initializer that adds retry handlers to all + * HttpRequests. + * + *
    This allows chaining through to another HttpRequestInitializer, since + * clients have exactly one HttpRequestInitializer, and Credential is also + * a required HttpRequestInitializer. + * + *
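A sketch of wiring this initializer into a request factory; the transport choice and the extra ignored status code are assumptions made only for illustration:

    // Retry transient failures; additionally keep 404 responses out of the warning logs.
    RetryHttpRequestInitializer retryInitializer =
        new RetryHttpRequestInitializer(java.util.Arrays.asList(404));
    com.google.api.client.http.HttpRequestFactory requestFactory =
        new com.google.api.client.http.javanet.NetHttpTransport()
            .createRequestFactory(retryInitializer);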
    Also can take a HttpResponseInterceptor to be applied to the responses. + */ +public class RetryHttpRequestInitializer implements HttpRequestInitializer { + + private static final Logger LOG = LoggerFactory.getLogger(RetryHttpRequestInitializer.class); + + /** + * Http response codes that should be silently ignored. + */ + private static final Set DEFAULT_IGNORED_RESPONSE_CODES = new HashSet<>( + Arrays.asList(307 /* Redirect, handled by the client library */, + 308 /* Resume Incomplete, handled by the client library */)); + + /** + * Http response timeout to use for hanging gets. + */ + private static final int HANGING_GET_TIMEOUT_SEC = 80; + + private static class LoggingHttpBackOffIOExceptionHandler + extends HttpBackOffIOExceptionHandler { + public LoggingHttpBackOffIOExceptionHandler(BackOff backOff) { + super(backOff); + } + + @Override + public boolean handleIOException(HttpRequest request, boolean supportsRetry) + throws IOException { + boolean willRetry = super.handleIOException(request, supportsRetry); + if (willRetry) { + LOG.debug("Request failed with IOException, will retry: {}", request.getUrl()); + } else { + LOG.warn("Request failed with IOException, will NOT retry: {}", request.getUrl()); + } + return willRetry; + } + } + + private static class LoggingHttpBackoffUnsuccessfulResponseHandler + implements HttpUnsuccessfulResponseHandler { + private final HttpBackOffUnsuccessfulResponseHandler handler; + private final Set ignoredResponseCodes; + + public LoggingHttpBackoffUnsuccessfulResponseHandler(BackOff backoff, + Sleeper sleeper, Set ignoredResponseCodes) { + this.ignoredResponseCodes = ignoredResponseCodes; + handler = new HttpBackOffUnsuccessfulResponseHandler(backoff); + handler.setSleeper(sleeper); + handler.setBackOffRequired( + new HttpBackOffUnsuccessfulResponseHandler.BackOffRequired() { + @Override + public boolean isRequired(HttpResponse response) { + int statusCode = response.getStatusCode(); + return (statusCode / 100 == 5) || // 5xx: server error + statusCode == 429; // 429: Too many requests + } + }); + } + + @Override + public boolean handleResponse(HttpRequest request, HttpResponse response, + boolean supportsRetry) throws IOException { + boolean retry = handler.handleResponse(request, response, supportsRetry); + if (retry) { + LOG.debug("Request failed with code {} will retry: {}", + response.getStatusCode(), request.getUrl()); + + } else if (!ignoredResponseCodes.contains(response.getStatusCode())) { + LOG.warn("Request failed with code {}, will NOT retry: {}", + response.getStatusCode(), request.getUrl()); + } + + return retry; + } + } + + @Deprecated + private final HttpRequestInitializer chained; + + private final HttpResponseInterceptor responseInterceptor; // response Interceptor to use + + private final NanoClock nanoClock; // used for testing + + private final Sleeper sleeper; // used for testing + + private Set ignoredResponseCodes = new HashSet<>(DEFAULT_IGNORED_RESPONSE_CODES); + + public RetryHttpRequestInitializer() { + this(Collections.emptyList()); + } + + /** + * @param chained a downstream HttpRequestInitializer, which will also be + * applied to HttpRequest initialization. May be null. + * + * @deprecated use {@link #RetryHttpRequestInitializer}. + */ + @Deprecated + public RetryHttpRequestInitializer(@Nullable HttpRequestInitializer chained) { + this(chained, Collections.emptyList()); + } + + /** + * @param additionalIgnoredResponseCodes a list of HTTP status codes that should not be logged. 
+ */ + public RetryHttpRequestInitializer(Collection additionalIgnoredResponseCodes) { + this(additionalIgnoredResponseCodes, null); + } + + + /** + * @param chained a downstream HttpRequestInitializer, which will also be + * applied to HttpRequest initialization. May be null. + * @param additionalIgnoredResponseCodes a list of HTTP status codes that should not be logged. + * + * @deprecated use {@link #RetryHttpRequestInitializer(Collection)}. + */ + @Deprecated + public RetryHttpRequestInitializer(@Nullable HttpRequestInitializer chained, + Collection additionalIgnoredResponseCodes) { + this(chained, additionalIgnoredResponseCodes, null); + } + + /** + * @param additionalIgnoredResponseCodes a list of HTTP status codes that should not be logged. + * @param responseInterceptor HttpResponseInterceptor to be applied on all requests. May be null. + */ + public RetryHttpRequestInitializer( + Collection additionalIgnoredResponseCodes, + @Nullable HttpResponseInterceptor responseInterceptor) { + this(null, NanoClock.SYSTEM, Sleeper.DEFAULT, additionalIgnoredResponseCodes, + responseInterceptor); + } + + /** + * @param chained a downstream HttpRequestInitializer, which will also be applied to HttpRequest + * initialization. May be null. + * @param additionalIgnoredResponseCodes a list of HTTP status codes that should not be logged. + * @param responseInterceptor HttpResponseInterceptor to be applied on all requests. May be null. + * + * @deprecated use {@link #RetryHttpRequestInitializer(Collection, HttpResponseInterceptor)}. + */ + @Deprecated + public RetryHttpRequestInitializer( + @Nullable HttpRequestInitializer chained, + Collection additionalIgnoredResponseCodes, + @Nullable HttpResponseInterceptor responseInterceptor) { + this(chained, NanoClock.SYSTEM, Sleeper.DEFAULT, additionalIgnoredResponseCodes, + responseInterceptor); + } + + /** + * Visible for testing. + * + * @param chained a downstream HttpRequestInitializer, which will also be + * applied to HttpRequest initialization. May be null. + * @param nanoClock used as a timing source for knowing how much time has elapsed. + * @param sleeper used to sleep between retries. + * @param additionalIgnoredResponseCodes a list of HTTP status codes that should not be logged. + */ + RetryHttpRequestInitializer(@Nullable HttpRequestInitializer chained, + NanoClock nanoClock, Sleeper sleeper, Collection additionalIgnoredResponseCodes, + HttpResponseInterceptor responseInterceptor) { + this.chained = chained; + this.nanoClock = nanoClock; + this.sleeper = sleeper; + this.ignoredResponseCodes.addAll(additionalIgnoredResponseCodes); + this.responseInterceptor = responseInterceptor; + } + + @Override + public void initialize(HttpRequest request) throws IOException { + if (chained != null) { + chained.initialize(request); + } + + // Set a timeout for hanging-gets. + // TODO: Do this exclusively for work requests. + request.setReadTimeout(HANGING_GET_TIMEOUT_SEC * 1000); + + // Back off on retryable http errors. + request.setUnsuccessfulResponseHandler( + // A back-off multiplier of 2 raises the maximum request retrying time + // to approximately 5 minutes (keeping other back-off parameters to + // their default values). + new LoggingHttpBackoffUnsuccessfulResponseHandler( + new ExponentialBackOff.Builder().setNanoClock(nanoClock) + .setMultiplier(2).build(), + sleeper, ignoredResponseCodes)); + + // Retry immediately on IOExceptions. 
+ LoggingHttpBackOffIOExceptionHandler loggingBackoffHandler = + new LoggingHttpBackOffIOExceptionHandler(BackOff.ZERO_BACKOFF); + request.setIOExceptionHandler(loggingBackoffHandler); + + // Set response initializer + if (responseInterceptor != null) { + request.setResponseInterceptor(responseInterceptor); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SerializableUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SerializableUtils.java new file mode 100644 index 000000000000..cacba0ea1704 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SerializableUtils.java @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.decodeFromByteArray; +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.common.base.Preconditions; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Arrays; + +/** + * Utilities for working with Serializables. + */ +public class SerializableUtils { + /** + * Serializes the argument into an array of bytes, and returns it. + * + * @throws IllegalArgumentException if there are errors when serializing + */ + public static byte[] serializeToByteArray(Serializable value) { + try { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(buffer)) { + oos.writeObject(value); + } + return buffer.toByteArray(); + } catch (IOException exn) { + throw new IllegalArgumentException( + "unable to serialize " + value, + exn); + } + } + + /** + * Deserializes an object from the given array of bytes, e.g., as + * serialized using {@link #serializeToByteArray}, and returns it. 
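A round-trip sketch of the two helpers above, using a plain String payload (any Serializable value works the same way):

    byte[] encoded = SerializableUtils.serializeToByteArray("sample-value");
    // The description argument is only used in error messages.
    String decoded =
        (String) SerializableUtils.deserializeFromByteArray(encoded, "sample string");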
+ * + * @throws IllegalArgumentException if there are errors when + * deserializing, using the provided description to identify what + * was being deserialized + */ + public static Object deserializeFromByteArray(byte[] encodedValue, + String description) { + try { + try (ObjectInputStream ois = new ObjectInputStream( + new ByteArrayInputStream(encodedValue))) { + return ois.readObject(); + } + } catch (IOException | ClassNotFoundException exn) { + throw new IllegalArgumentException( + "unable to deserialize " + description, + exn); + } + } + + public static T ensureSerializable(T value) { + @SuppressWarnings("unchecked") + T copy = (T) deserializeFromByteArray(serializeToByteArray(value), + value.toString()); + return copy; + } + + public static T clone(T value) { + @SuppressWarnings("unchecked") + T copy = (T) deserializeFromByteArray(serializeToByteArray(value), + value.toString()); + return copy; + } + + /** + * Serializes a Coder and verifies that it can be correctly deserialized. + * + *
    Throws a RuntimeException if serialized Coder cannot be deserialized, or + * if the deserialized instance is not equal to the original. + * + * @return the serialized Coder, as a {@link CloudObject} + */ + public static CloudObject ensureSerializable(Coder coder) { + // Make sure that Coders are java serializable as well since + // they are regularly captured within DoFn's. + Coder copy = (Coder) ensureSerializable((Serializable) coder); + + CloudObject cloudObject = copy.asCloudObject(); + + Coder decoded; + try { + decoded = Serializer.deserialize(cloudObject, Coder.class); + } catch (RuntimeException e) { + throw new RuntimeException( + String.format("Unable to deserialize Coder: %s. " + + "Check that a suitable constructor is defined. " + + "See Coder for details.", coder), e + ); + } + Preconditions.checkState(coder.equals(decoded), + String.format("Coder not equal to original after serialization, " + + "indicating that the Coder may not implement serialization " + + "correctly. Before: %s, after: %s, cloud encoding: %s", + coder, decoded, cloudObject)); + + return cloudObject; + } + + /** + * Serializes an arbitrary T with the given {@code Coder} and verifies + * that it can be correctly deserialized. + */ + public static T ensureSerializableByCoder( + Coder coder, T value, String errorContext) { + byte[] encodedValue; + try { + encodedValue = encodeToByteArray(coder, value); + } catch (CoderException exn) { + // TODO: Put in better element printing: + // truncate if too long. + throw new IllegalArgumentException( + errorContext + ": unable to encode value " + + value + " using " + coder, + exn); + } + try { + return decodeFromByteArray(coder, encodedValue); + } catch (CoderException exn) { + // TODO: Put in better encoded byte array printing: + // use printable chars with escapes instead of codes, and + // truncate if too long. + throw new IllegalArgumentException( + errorContext + ": unable to decode " + Arrays.toString(encodedValue) + + ", encoding of value " + value + ", using " + coder, + exn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Serializer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Serializer.java new file mode 100644 index 000000000000..6a8a337ab14f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Serializer.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Utility for converting objects between Java and Cloud representations. + */ +public final class Serializer { + // Delay initialization of statics until the first call to Serializer. 
+ private static class SingletonHelper { + static final ObjectMapper OBJECT_MAPPER = createObjectMapper(); + static final ObjectMapper TREE_MAPPER = createTreeMapper(); + + /** + * Creates the object mapper that will be used for serializing Google API + * client maps into Jackson trees. + */ + private static ObjectMapper createTreeMapper() { + return new ObjectMapper(); + } + + /** + * Creates the object mapper that will be used for deserializing Jackson + * trees into objects. + */ + private static ObjectMapper createObjectMapper() { + ObjectMapper m = new ObjectMapper(); + // Ignore properties that are not used by the object. + m.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); + + // For parameters of type Object, use the @type property to determine the + // class to instantiate. + // + // TODO: It would be ideal to do this for all non-final classes. The + // problem with using DefaultTyping.NON_FINAL is that it insists on having + // type information in the JSON for classes with useful default + // implementations, such as List. Ideally, we'd combine these defaults + // with available type information if that information's present. + m.enableDefaultTypingAsProperty( + ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT, + PropertyNames.OBJECT_TYPE_NAME); + + m.registerModule(new CoderUtils.Jackson2Module()); + + return m; + } + } + + /** + * Deserializes an object from a Dataflow structured encoding (represented in + * Java as a map). + * + *
    The standard Dataflow SDK object serialization protocol is based on JSON. + * Data is typically encoded as a JSON object whose fields represent the + * object's data. + * + *
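A sketch of the map-to-object path just described; SampleSpec and its fields are hypothetical, and the literal "@type" key is assumed to be the value of PropertyNames.OBJECT_TYPE_NAME:

    // Hypothetical target type with public fields for Jackson to populate.
    public class SampleSpec {
      public String name;
      public int shardCount;
    }

    java.util.Map<String, Object> encoded = new java.util.HashMap<>();
    encoded.put("@type", SampleSpec.class.getName());  // assumed OBJECT_TYPE_NAME key
    encoded.put("name", "wordcount");
    encoded.put("shardCount", 4);
    SampleSpec spec = Serializer.deserialize(encoded, SampleSpec.class);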
    The actual deserialization is performed by Jackson, which can deserialize + * public fields, use JavaBean setters, or use injection annotations to + * indicate how to construct the object. The {@link ObjectMapper} used is + * configured to use the "@type" field as the name of the class to instantiate + * (supporting polymorphic types), and may be further configured by + * annotations or via {@link ObjectMapper#registerModule}. + * + * @see + * Jackson Data-Binding + * @see + * Jackson-Annotations + * @param serialized the object in untyped decoded form (i.e. a nested {@link Map}) + * @param clazz the expected object class + */ + public static T deserialize(Map serialized, Class clazz) { + try { + return SingletonHelper.OBJECT_MAPPER.treeToValue( + SingletonHelper.TREE_MAPPER.valueToTree( + deserializeCloudKnownTypes(serialized)), + clazz); + } catch (JsonProcessingException e) { + throw new RuntimeException( + "Unable to deserialize class " + clazz, e); + } + } + + /** + * Recursively walks the supplied map, looking for well-known cloud type + * information (keyed as {@link PropertyNames#OBJECT_TYPE_NAME}, matching a + * URI value from the {@link CloudKnownType} enum. Upon finding this type + * information, it converts it into the correspondingly typed Java value. + */ + @SuppressWarnings("unchecked") + private static Object deserializeCloudKnownTypes(Object src) { + if (src instanceof Map) { + Map srcMap = (Map) src; + @Nullable Object value = srcMap.get(PropertyNames.SCALAR_FIELD_NAME); + @Nullable CloudKnownType type = + CloudKnownType.forUri((String) srcMap.get(PropertyNames.OBJECT_TYPE_NAME)); + if (type != null && value != null) { + // It's a value of a well-known cloud type; let the known type handler + // handle the translation. + Object result = type.parse(value, type.defaultClass()); + return result; + } + // Otherwise, it's just an ordinary map. + Map dest = new HashMap<>(srcMap.size()); + for (Map.Entry entry : srcMap.entrySet()) { + dest.put(entry.getKey(), deserializeCloudKnownTypes(entry.getValue())); + } + return dest; + } + if (src instanceof List) { + List srcList = (List) src; + List dest = new ArrayList<>(srcList.size()); + for (Object obj : srcList) { + dest.add(deserializeCloudKnownTypes(obj)); + } + return dest; + } + // Neither a Map nor a List; no translation needed. + return src; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ShardingWritableByteChannel.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ShardingWritableByteChannel.java new file mode 100644 index 000000000000..54794ef04a3d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ShardingWritableByteChannel.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; + +/** + * Implements a WritableByteChannel that may contain multiple output shards. + * + *
    This provides {@link #writeToShard}, which takes a shard number for + * writing to a particular shard. + * + *
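A sketch of routing records by shard index, using two in-memory channels; it assumes the snippet runs inside a method that may throw IOException:

    java.io.ByteArrayOutputStream shard0 = new java.io.ByteArrayOutputStream();
    java.io.ByteArrayOutputStream shard1 = new java.io.ByteArrayOutputStream();

    ShardingWritableByteChannel sharded = new ShardingWritableByteChannel();
    sharded.addChannel(java.nio.channels.Channels.newChannel(shard0));
    sharded.addChannel(java.nio.channels.Channels.newChannel(shard1));

    // Each record goes only to the channel registered at the given index.
    sharded.writeToShard(0, java.nio.ByteBuffer.wrap(
        "to shard 0\n".getBytes(java.nio.charset.StandardCharsets.UTF_8)));
    sharded.writeToShard(1, java.nio.ByteBuffer.wrap(
        "to shard 1\n".getBytes(java.nio.charset.StandardCharsets.UTF_8)));
    sharded.close();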
    The channel is considered open if all downstream channels are open, and + * closes all downstream channels when closed. + */ +public class ShardingWritableByteChannel implements WritableByteChannel { + + /** + * Special shard number that causes a write to all shards. + */ + public static final int ALL_SHARDS = -2; + + + private final ArrayList writers = new ArrayList<>(); + + /** + * Returns the number of output shards. + */ + public int getNumShards() { + return writers.size(); + } + + /** + * Adds another shard output channel. + */ + public void addChannel(WritableByteChannel writer) { + writers.add(writer); + } + + /** + * Returns the WritableByteChannel associated with the given shard number. + */ + public WritableByteChannel getChannel(int shardNum) { + return writers.get(shardNum); + } + + /** + * Writes the buffer to the given shard. + * + *
    This does not change the current output shard. + * + * @return The total number of bytes written. If the shard number is + * {@link #ALL_SHARDS}, then the total is the sum of each individual shard + * write. + */ + public int writeToShard(int shardNum, ByteBuffer src) throws IOException { + if (shardNum >= 0) { + return writers.get(shardNum).write(src); + } + + switch (shardNum) { + case ALL_SHARDS: + int size = 0; + for (WritableByteChannel writer : writers) { + size += writer.write(src); + } + return size; + + default: + throw new IllegalArgumentException("Illegal shard number: " + shardNum); + } + } + + /** + * Writes a buffer to all shards. + * + *
    Same as calling {@code writeToShard(ALL_SHARDS, buf)}. + */ + @Override + public int write(ByteBuffer src) throws IOException { + return writeToShard(ALL_SHARDS, src); + } + + @Override + public boolean isOpen() { + for (WritableByteChannel writer : writers) { + if (!writer.isOpen()) { + return false; + } + } + + return true; + } + + @Override + public void close() throws IOException { + for (WritableByteChannel writer : writers) { + writer.close(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SideInputReader.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SideInputReader.java new file mode 100644 index 000000000000..37873f3136a2 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SideInputReader.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import javax.annotation.Nullable; + +/** + * The interface to objects that provide side inputs. Particular implementations + * may read a side input directly or use appropriate sorts of caching, etc. + */ +public interface SideInputReader { + /** + * Returns the value of the given {@link PCollectionView} for the given {@link BoundedWindow}. + * + *
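A minimal in-memory implementation sketch of this interface, shown only to illustrate the contract: values are looked up in a fixed map and the window is ignored.

    class FixedSideInputReader implements SideInputReader {
      private final java.util.Map<PCollectionView<?>, Object> values;

      FixedSideInputReader(java.util.Map<PCollectionView<?>, Object> values) {
        this.values = values;
      }

      @Override
      @SuppressWarnings("unchecked")
      public <T> T get(PCollectionView<T> view, BoundedWindow window) {
        return (T) values.get(view);  // same value for every window in this sketch
      }

      @Override
      public <T> boolean contains(PCollectionView<T> view) {
        return values.containsKey(view);
      }

      @Override
      public boolean isEmpty() {
        return values.isEmpty();
      }
    }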
    It is valid for a side input to be {@code null}. It is not valid for this to + * return {@code null} for any other reason. + */ + @Nullable + T get(PCollectionView view, BoundedWindow window); + + /** + * Returns true if the given {@link PCollectionView} is valid for this reader. + */ + boolean contains(PCollectionView view); + + /** + * Returns true if there are no side inputs in this reader. + */ + boolean isEmpty(); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SimpleDoFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SimpleDoFnRunner.java new file mode 100644 index 000000000000..15a5e518341f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SimpleDoFnRunner.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.util.DoFnRunners.OutputManager; +import com.google.cloud.dataflow.sdk.util.ExecutionContext.StepContext; +import com.google.cloud.dataflow.sdk.util.common.CounterSet.AddCounterMutator; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import java.util.List; + +/** + * Runs a {@link DoFn} by constructing the appropriate contexts and passing them in. + * + * @param the type of the DoFn's (main) input elements + * @param the type of the DoFn's (main) output elements + */ +public class SimpleDoFnRunner extends DoFnRunnerBase{ + + protected SimpleDoFnRunner(PipelineOptions options, DoFn fn, + SideInputReader sideInputReader, + OutputManager outputManager, + TupleTag mainOutputTag, List> sideOutputTags, StepContext stepContext, + AddCounterMutator addCounterMutator, WindowingStrategy windowingStrategy) { + super(options, fn, sideInputReader, outputManager, mainOutputTag, sideOutputTags, stepContext, + addCounterMutator, windowingStrategy); + } + + @Override + protected void invokeProcessElement(WindowedValue elem) { + final DoFn.ProcessContext processContext = createProcessContext(elem); + // This can contain user code. Wrap it in case it throws an exception. + try { + fn.processElement(processContext); + } catch (Exception ex) { + throw wrapUserCodeException(ex); + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Stager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Stager.java new file mode 100644 index 000000000000..04fd599ab513 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Stager.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.services.dataflow.model.DataflowPackage; + +import java.util.List; + +/** + * Interface for staging files needed for running a Dataflow pipeline. + */ +public interface Stager { + /* Stage files and return a list of packages. */ + public List stageFiles(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StreamUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StreamUtils.java new file mode 100644 index 000000000000..268eb7fe4e9c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StreamUtils.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.lang.ref.SoftReference; + +/** + * Utility functions for stream operations. + */ +public class StreamUtils { + + private StreamUtils() { + } + + private static final int BUF_SIZE = 8192; + + private static ThreadLocal> threadLocalBuffer = new ThreadLocal<>(); + + /** + * Efficient converting stream to bytes. + */ + public static byte[] getBytes(InputStream stream) throws IOException { + if (stream instanceof ExposedByteArrayInputStream) { + // Fast path for the exposed version. + return ((ExposedByteArrayInputStream) stream).readAll(); + } else if (stream instanceof ByteArrayInputStream) { + // Fast path for ByteArrayInputStream. + byte[] ret = new byte[stream.available()]; + stream.read(ret); + return ret; + } + // Falls back to normal stream copying. + SoftReference refBuffer = threadLocalBuffer.get(); + byte[] buffer = refBuffer == null ? null : refBuffer.get(); + if (buffer == null) { + buffer = new byte[BUF_SIZE]; + threadLocalBuffer.set(new SoftReference(buffer)); + } + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + while (true) { + int r = stream.read(buffer); + if (r == -1) { + break; + } + outStream.write(buffer, 0, r); + } + return outStream.toByteArray(); + } + +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StringUtils.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StringUtils.java new file mode 100644 index 000000000000..3a18336d0e1c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/StringUtils.java @@ -0,0 +1,242 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Utilities for working with JSON and other human-readable string formats. + */ +public class StringUtils { + /** + * Converts the given array of bytes into a legal JSON string. + * + *
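A round-trip sketch of the encoding just described: printable ASCII passes through unchanged, while other bytes (and the escape characters themselves) become %xx escapes.

    byte[] original = new byte[] {72, 105, 0, (byte) 0xff, 37};      // 'H', 'i', NUL, 0xff, '%'
    String jsonSafe = StringUtils.byteArrayToJsonString(original);   // "Hi%00%ff%25"
    byte[] roundTripped = StringUtils.jsonStringToByteArray(jsonSafe);
    assert java.util.Arrays.equals(original, roundTripped);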
    Uses a simple strategy of converting each byte to a single char, + * except for non-printable chars, non-ASCII chars, and '%', '\', + * and '"', which are encoded as three chars in '%xx' format, where + * 'xx' is the hexadecimal encoding of the byte. + */ + public static String byteArrayToJsonString(byte[] bytes) { + StringBuilder sb = new StringBuilder(bytes.length * 2); + for (byte b : bytes) { + if (b >= 32 && b < 127) { + // A printable ascii character. + char c = (char) b; + if (c != '%' && c != '\\' && c != '\"') { + // Not an escape prefix or special character, either. + // Send through unchanged. + sb.append(c); + continue; + } + } + // Send through escaped. Use '%xx' format. + sb.append(String.format("%%%02x", b)); + } + return sb.toString(); + } + + /** + * Converts the given string, encoded using {@link #byteArrayToJsonString}, + * into a byte array. + * + * @throws IllegalArgumentException if the argument string is not legal + */ + public static byte[] jsonStringToByteArray(String string) { + List bytes = new ArrayList<>(); + for (int i = 0; i < string.length(); ) { + char c = string.charAt(i); + Byte b; + if (c == '%') { + // Escaped. Expect '%xx' format. + try { + b = (byte) Integer.parseInt(string.substring(i + 1, i + 3), 16); + } catch (IndexOutOfBoundsException | NumberFormatException exn) { + throw new IllegalArgumentException( + "not in legal encoded format; " + + "substring [" + i + ".." + (i + 2) + "] not in format \"%xx\"", + exn); + } + i += 3; + } else { + // Send through unchanged. + b = (byte) c; + i++; + } + bytes.add(b); + } + byte[] byteArray = new byte[bytes.size()]; + int i = 0; + for (Byte b : bytes) { + byteArray[i++] = b; + } + return byteArray; + } + + private static final String[] STANDARD_NAME_SUFFIXES = + new String[]{"DoFn", "Fn"}; + + /** + * Pattern to match a non-anonymous inner class. + * Eg, matches "Foo$Bar", or even "Foo$1$Bar", but not "Foo$1" or "Foo$1$2". + */ + private static final Pattern NAMED_INNER_CLASS = + Pattern.compile(".+\\$(?[^0-9].*)"); + + private static final String ANONYMOUS_CLASS_REGEX = "\\$[0-9]+\\$"; + + /** + * Returns a simple name for a class. + * + *
    Note: this is non-invertible - the name may be simplified to an + * extent that it cannot be mapped back to the original class. + * + *
    This can be used to generate human-readable names. It + * removes the package and outer classes from the name, + * and removes common suffixes. + * + *
    Examples: + *
      + *
    • {@code some.package.Word.SummaryDoFn} -> "Summary" + *
    • {@code another.package.PairingFn} -> "Pairing" + *
    + * + * @throws IllegalArgumentException if the class is anonymous + */ + public static String approximateSimpleName(Class clazz) { + return approximateSimpleName(clazz, /* dropOuterClassNames */ true); + } + + /** + * Returns a name for a PTransform class. + * + *
    This can be used to generate human-readable transform names. It + * removes the package from the name, and removes common suffixes. + * + *
    It is different from approximateSimpleName: + *
      + *
    • 1. It keeps the outer class names. + *
    • 2. It removes the common transform inner class: "Bound". + *
    + * + *
    Examples: + *
      + *
    • {@code some.package.Word.Summary} -> "Word.Summary" + *
    • {@code another.package.Pairing.Bound} -> "Pairing" + *
    + */ + public static String approximatePTransformName(Class clazz) { + Preconditions.checkArgument(PTransform.class.isAssignableFrom(clazz)); + return approximateSimpleName(clazz, /* dropOuterClassNames */ false) + .replaceFirst("\\.Bound$", ""); + } + + /** + * Calculate the Levenshtein distance between two strings. + * + *
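A few concrete values for the edit-distance computation defined here:

    int d1 = StringUtils.getLevenshteinDistance("kitten", "sitting");  // 3 edits
    int d2 = StringUtils.getLevenshteinDistance("", "abc");            // 3 insertions
    int d3 = StringUtils.getLevenshteinDistance("same", "same");       // 0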
    The Levenshtein distance between two words is the minimum number of single-character edits + * (i.e. insertions, deletions or substitutions) required to change one string into the other. + */ + public static int getLevenshteinDistance(final String s, final String t) { + Preconditions.checkNotNull(s); + Preconditions.checkNotNull(t); + + // base cases + if (s.equals(t)) { + return 0; + } + if (s.length() == 0) { + return t.length(); + } + if (t.length() == 0) { + return s.length(); + } + + // create two work arrays to store integer distances + final int[] v0 = new int[t.length() + 1]; + final int[] v1 = new int[t.length() + 1]; + + // initialize v0 (the previous row of distances) + // this row is A[0][i]: edit distance for an empty s + // the distance is just the number of characters to delete from t + for (int i = 0; i < v0.length; i++) { + v0[i] = i; + } + + for (int i = 0; i < s.length(); i++) { + // calculate v1 (current row distances) from the previous row v0 + + // first element of v1 is A[i+1][0] + // edit distance is delete (i+1) chars from s to match empty t + v1[0] = i + 1; + + // use formula to fill in the rest of the row + for (int j = 0; j < t.length(); j++) { + int cost = (s.charAt(i) == t.charAt(j)) ? 0 : 1; + v1[j + 1] = Math.min(Math.min(v1[j] + 1, v0[j + 1] + 1), v0[j] + cost); + } + + // copy v1 (current row) to v0 (previous row) for next iteration + System.arraycopy(v1, 0, v0, 0, v0.length); + } + + return v1[t.length()]; + } + + private static String approximateSimpleName(Class clazz, boolean dropOuterClassNames) { + Preconditions.checkArgument(!clazz.isAnonymousClass(), + "Attempted to get simple name of anonymous class"); + + String fullName = clazz.getName(); + String shortName = fullName.substring(fullName.lastIndexOf('.') + 1); + + // Drop common suffixes for each named component. + String[] names = shortName.split("\\$"); + for (int i = 0; i < names.length; i++) { + names[i] = simplifyNameComponent(names[i]); + } + shortName = Joiner.on('$').join(names); + + if (dropOuterClassNames) { + // Simplify inner class name by dropping outer class prefixes. + Matcher m = NAMED_INNER_CLASS.matcher(shortName); + if (m.matches()) { + shortName = m.group("INNER"); + } + } else { + // Dropping anonymous outer classes + shortName = shortName.replaceAll(ANONYMOUS_CLASS_REGEX, "."); + shortName = shortName.replaceAll("\\$", "."); + } + return shortName; + } + + private static String simplifyNameComponent(String name) { + for (String suffix : STANDARD_NAME_SUFFIXES) { + if (name.endsWith(suffix) && name.length() > suffix.length()) { + return name.substring(0, name.length() - suffix.length()); + } + } + return name; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Structs.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Structs.java new file mode 100644 index 000000000000..c621c5564e1c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Structs.java @@ -0,0 +1,384 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.util.Data; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A collection of static methods for manipulating datastructure representations + * transferred via the Dataflow API. + */ +public final class Structs { + private Structs() {} // Non-instantiable + + public static String getString(Map map, String name) throws Exception { + return getValue(map, name, String.class, "a string"); + } + + public static String getString( + Map map, String name, @Nullable String defaultValue) + throws Exception { + return getValue(map, name, String.class, "a string", defaultValue); + } + + public static byte[] getBytes(Map map, String name) throws Exception { + @Nullable byte[] result = getBytes(map, name, null); + if (result == null) { + throw new ParameterNotFoundException(name, map); + } + return result; + } + + @Nullable + public static byte[] getBytes(Map map, String name, @Nullable byte[] defaultValue) + throws Exception { + @Nullable String jsonString = getString(map, name, null); + if (jsonString == null) { + return defaultValue; + } + // TODO: Need to agree on a format for encoding bytes in + // a string that can be sent to the backend, over the cloud + // map task work API. base64 encoding seems pretty common. Switch to it? + return StringUtils.jsonStringToByteArray(jsonString); + } + + public static Boolean getBoolean(Map map, String name) throws Exception { + return getValue(map, name, Boolean.class, "a boolean"); + } + + @Nullable + public static Boolean getBoolean( + Map map, String name, @Nullable Boolean defaultValue) + throws Exception { + return getValue(map, name, Boolean.class, "a boolean", defaultValue); + } + + public static Long getLong(Map map, String name) throws Exception { + return getValue(map, name, Long.class, "a long"); + } + + @Nullable + public static Long getLong(Map map, String name, @Nullable Long defaultValue) + throws Exception { + return getValue(map, name, Long.class, "a long", defaultValue); + } + + public static Integer getInt(Map map, String name) throws Exception { + return getValue(map, name, Integer.class, "an int"); + } + + @Nullable + public static Integer getInt(Map map, String name, @Nullable Integer defaultValue) + throws Exception { + return getValue(map, name, Integer.class, "an int", defaultValue); + } + + @Nullable + public static List getStrings( + Map map, String name, @Nullable List defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "a string or a list"); + } + return defaultValue; + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a list of strings, + // this is an empty list. 
+ return Collections.emptyList(); + } + @Nullable String singletonString = decodeValue(value, String.class); + if (singletonString != null) { + return Collections.singletonList(singletonString); + } + if (!(value instanceof List)) { + throw new IncorrectTypeException(name, map, "a string or a list"); + } + @SuppressWarnings("unchecked") + List elements = (List) value; + List result = new ArrayList<>(elements.size()); + for (Object o : elements) { + @Nullable String s = decodeValue(o, String.class); + if (s == null) { + throw new IncorrectTypeException(name, map, "a list of strings"); + } + result.add(s); + } + return result; + } + + public static Map getObject(Map map, String name) + throws Exception { + @Nullable Map result = getObject(map, name, null); + if (result == null) { + throw new ParameterNotFoundException(name, map); + } + return result; + } + + @Nullable + public static Map getObject( + Map map, String name, @Nullable Map defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "an object"); + } + return defaultValue; + } + return checkObject(value, map, name); + } + + private static Map checkObject( + Object value, Map map, String name) throws Exception { + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as an object, this is an + // empty map. + return Collections.emptyMap(); + } + if (!(value instanceof Map)) { + throw new IncorrectTypeException(name, map, "an object (not a map)"); + } + @SuppressWarnings("unchecked") + Map mapValue = (Map) value; + if (!mapValue.containsKey(PropertyNames.OBJECT_TYPE_NAME)) { + throw new IncorrectTypeException(name, map, + "an object (no \"" + PropertyNames.OBJECT_TYPE_NAME + "\" field)"); + } + return mapValue; + } + + @Nullable + public static List> getListOfMaps(Map map, String name, + @Nullable List> defaultValue) throws Exception { + @Nullable + Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "a list"); + } + return defaultValue; + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a list, + // this is an empty list. + return Collections.>emptyList(); + } + + if (!(value instanceof List)) { + throw new IncorrectTypeException(name, map, "a list"); + } + + List elements = (List) value; + for (Object elem : elements) { + if (!(elem instanceof Map)) { + throw new IncorrectTypeException(name, map, "a list of Map objects"); + } + } + + @SuppressWarnings("unchecked") + List> result = (List>) elements; + return result; + } + + public static Map getDictionary( + Map map, String name) throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + throw new ParameterNotFoundException(name, map); + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a dictionary, this is + // an empty map. 
+ return Collections.emptyMap(); + } + if (!(value instanceof Map)) { + throw new IncorrectTypeException(name, map, "a dictionary"); + } + @SuppressWarnings("unchecked") + Map result = (Map) value; + return result; + } + + @Nullable + public static Map getDictionary( + Map map, String name, @Nullable Map defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, "a dictionary"); + } + return defaultValue; + } + if (Data.isNull(value)) { + // This is a JSON literal null. When represented as a dictionary, this is + // an empty map. + return Collections.emptyMap(); + } + if (!(value instanceof Map)) { + throw new IncorrectTypeException(name, map, "a dictionary"); + } + @SuppressWarnings("unchecked") + Map result = (Map) value; + return result; + } + + // Builder operations. + + public static void addString(Map map, String name, String value) { + addObject(map, name, CloudObject.forString(value)); + } + + public static void addBoolean(Map map, String name, boolean value) { + addObject(map, name, CloudObject.forBoolean(value)); + } + + public static void addLong(Map map, String name, long value) { + addObject(map, name, CloudObject.forInteger(value)); + } + + public static void addObject( + Map map, String name, Map value) { + map.put(name, value); + } + + public static void addNull(Map map, String name) { + map.put(name, Data.nullOf(Object.class)); + } + + public static void addLongs(Map map, String name, long... longs) { + List> elements = new ArrayList<>(longs.length); + for (Long value : longs) { + elements.add(CloudObject.forInteger(value)); + } + map.put(name, elements); + } + + public static void addList( + Map map, String name, List> elements) { + map.put(name, elements); + } + + public static void addStringList(Map map, String name, List elements) { + ArrayList objects = new ArrayList<>(elements.size()); + for (String element : elements) { + objects.add(CloudObject.forString(element)); + } + addList(map, name, objects); + } + + public static > void addList( + Map map, String name, T[] elements) { + map.put(name, Arrays.asList(elements)); + } + + public static void addDictionary( + Map map, String name, Map value) { + map.put(name, value); + } + + public static void addDouble(Map map, String name, Double value) { + addObject(map, name, CloudObject.forFloat(value)); + } + + // Helper methods for a few of the accessor methods. + + private static T getValue(Map map, String name, Class clazz, String type) + throws Exception { + @Nullable T result = getValue(map, name, clazz, type, null); + if (result == null) { + throw new ParameterNotFoundException(name, map); + } + return result; + } + + @Nullable + private static T getValue( + Map map, String name, Class clazz, String type, @Nullable T defaultValue) + throws Exception { + @Nullable Object value = map.get(name); + if (value == null) { + if (map.containsKey(name)) { + throw new IncorrectTypeException(name, map, type); + } + return defaultValue; + } + T result = decodeValue(value, clazz); + if (result == null) { + // The value exists, but can't be decoded. + throw new IncorrectTypeException(name, map, type); + } + return result; + } + + @Nullable + private static T decodeValue(Object value, Class clazz) { + try { + if (value.getClass() == clazz) { + // decodeValue() is only called for final classes; if the class matches, + // it's safe to just return the value, and if it doesn't match, decoding + // is needed. 
+ return clazz.cast(value); + } + if (!(value instanceof Map)) { + return null; + } + @SuppressWarnings("unchecked") + Map map = (Map) value; + @Nullable String typeName = (String) map.get(PropertyNames.OBJECT_TYPE_NAME); + if (typeName == null) { + return null; + } + @Nullable CloudKnownType knownType = CloudKnownType.forUri(typeName); + if (knownType == null) { + return null; + } + @Nullable Object scalar = map.get(PropertyNames.SCALAR_FIELD_NAME); + if (scalar == null) { + return null; + } + return knownType.parse(scalar, clazz); + } catch (ClassCastException e) { + // If any class cast fails during decoding, the value's not decodable. + return null; + } + } + + private static final class ParameterNotFoundException extends Exception { + public ParameterNotFoundException(String name, Map map) { + super("didn't find required parameter " + name + " in " + map); + } + } + + private static final class IncorrectTypeException extends Exception { + public IncorrectTypeException(String name, Map map, String type) { + super("required parameter " + name + " in " + map + " not " + type); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemDoFnInternal.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemDoFnInternal.java new file mode 100644 index 000000000000..3255ede8755a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemDoFnInternal.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to mark {@link DoFn DoFns} as an internal component of the Dataflow SDK. + * + *
    Currently, the only effect of this is to mark any aggregators reported by an annotated + * {@code DoFn} as a system counter (as opposed to a user counter). + * + *
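An illustrative annotated DoFn; the IdentityDoFn below is hypothetical and exists only to show where the annotation goes:

    @SystemDoFnInternal
    class IdentityDoFn<T> extends DoFn<T, T> {
      @Override
      public void processElement(ProcessContext c) throws Exception {
        // Any aggregators created by this DoFn would be reported as system counters.
        c.output(c.element());
      }
    }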
    This is internal to the Dataflow SDK. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface SystemDoFnInternal {} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java new file mode 100644 index 000000000000..16657925e35f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.state.AccumulatorCombiningState; +import com.google.cloud.dataflow.sdk.util.state.BagState; +import com.google.cloud.dataflow.sdk.util.state.CombiningState; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateMerging; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; + +/** + * {@link ReduceFn} implementing the default reduction behaviors of {@link GroupByKey}. + * + * @param The type of key being processed. + * @param The type of values associated with the key. + * @param The output type that will be produced for each key. + * @param The type of windows this operates on. + */ +public abstract class SystemReduceFn + extends ReduceFn { + private static final String BUFFER_NAME = "buf"; + + /** + * Create a factory that produces {@link SystemReduceFn} instances that that buffer all of the + * input values in persistent state and produces an {@code Iterable}. + */ + public static SystemReduceFn, Iterable, W> + buffering(final Coder inputCoder) { + final StateTag> bufferTag = + StateTags.makeSystemTagInternal(StateTags.bag(BUFFER_NAME, inputCoder)); + return new SystemReduceFn, Iterable, W>(bufferTag) { + @Override + public void prefetchOnMerge(MergingStateAccessor state) throws Exception { + StateMerging.prefetchBags(state, bufferTag); + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + StateMerging.mergeBags(c.state(), bufferTag); + } + }; + } + + /** + * Create a factory that produces {@link SystemReduceFn} instances that combine all of the input + * values using a {@link CombineFn}. 
+ */ + public static SystemReduceFn + combining( + final Coder keyCoder, final AppliedCombineFn combineFn) { + final StateTag> bufferTag; + if (combineFn.getFn() instanceof KeyedCombineFnWithContext) { + bufferTag = StateTags.makeSystemTagInternal( + StateTags.keyedCombiningValueWithContext( + BUFFER_NAME, combineFn.getAccumulatorCoder(), + (KeyedCombineFnWithContext) combineFn.getFn())); + + } else { + bufferTag = StateTags.makeSystemTagInternal( + StateTags.keyedCombiningValue( + BUFFER_NAME, combineFn.getAccumulatorCoder(), + (KeyedCombineFn) combineFn.getFn())); + } + return new SystemReduceFn(bufferTag) { + @Override + public void prefetchOnMerge(MergingStateAccessor state) throws Exception { + StateMerging.prefetchCombiningValues(state, bufferTag); + } + + @Override + public void onMerge(OnMergeContext c) throws Exception { + StateMerging.mergeCombiningValues(c.state(), bufferTag); + } + }; + } + + private StateTag> bufferTag; + + public SystemReduceFn( + StateTag> bufferTag) { + this.bufferTag = bufferTag; + } + + @Override + public void processValue(ProcessValueContext c) throws Exception { + c.state().access(bufferTag).add(c.value()); + } + + @Override + public void prefetchOnTrigger(StateAccessor state) { + state.access(bufferTag).readLater(); + } + + @Override + public void onTrigger(OnTriggerContext c) throws Exception { + c.output(c.state().access(bufferTag).read()); + } + + @Override + public void clearState(Context c) throws Exception { + c.state().access(bufferTag).clear(); + } + + @Override + public ReadableState isEmpty(StateAccessor state) { + return state.access(bufferTag).isEmpty(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TestCredential.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TestCredential.java new file mode 100644 index 000000000000..359e15774fd0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TestCredential.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.BearerToken; +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.auth.oauth2.TokenResponse; +import com.google.api.client.testing.http.MockHttpTransport; + +import java.io.IOException; + +/** + * Fake credential, for use in testing. 
+ */ +public class TestCredential extends Credential { + + private final String token; + + public TestCredential() { + this("NULL"); + } + + public TestCredential(String token) { + super(new Builder( + BearerToken.authorizationHeaderAccessMethod()) + .setTransport(new MockHttpTransport())); + this.token = token; + } + + @Override + protected TokenResponse executeRefreshToken() throws IOException { + TokenResponse response = new TokenResponse(); + response.setExpiresInSeconds(5L * 60); + response.setAccessToken(token); + return response; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeDomain.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeDomain.java new file mode 100644 index 000000000000..4ff36f722a79 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeDomain.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +/** + * {@code TimeDomain} specifies whether an operation is based on + * timestamps of elements or current "real-world" time as reported while processing. + */ +public enum TimeDomain { + /** + * The {@code EVENT_TIME} domain corresponds to the timestamps on the elements. Time advances + * on the system watermark advances. + */ + EVENT_TIME, + + /** + * The {@code PROCESSING_TIME} domain corresponds to the current to the current (system) time. + * This is advanced during execution of the Dataflow pipeline. + */ + PROCESSING_TIME, + + /** + * Same as the {@code PROCESSING_TIME} domain, except it won't fire a timer set for time + * {@code T} until all timers from earlier stages set for a time earlier than {@code T} have + * fired. + */ + SYNCHRONIZED_PROCESSING_TIME; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeUtil.java new file mode 100644 index 000000000000..93195a763586 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimeUtil.java @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util; + +import org.joda.time.DateTime; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.ReadableDuration; +import org.joda.time.ReadableInstant; +import org.joda.time.chrono.ISOChronology; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * A helper class for converting between Dataflow API and SDK time + * representations. + * + *

    Dataflow API times are strings of the form + * {@code YYYY-MM-dd'T'HH:mm:ss[.nnnn]'Z'}: that is, RFC 3339 + * strings with optional fractional seconds and a 'Z' offset. + * + *

    Dataflow API durations are strings of the form {@code ['-']sssss[.nnnn]'s'}: + * that is, seconds with optional fractional seconds and a literal 's' at the end. + * + *
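A minimal sketch of these two string formats in action, using the TimeUtil helpers defined in this file; the literal values are illustrative only:

    import org.joda.time.Duration;
    import org.joda.time.Instant;

    import com.google.cloud.dataflow.sdk.util.TimeUtil;

    public class TimeUtilFormatSketch {
      public static void main(String[] args) {
        // An RFC 3339 time with millisecond-resolution fractional seconds.
        Instant instant = TimeUtil.fromCloudTime("2015-11-09T21:02:18.123Z");
        System.out.println(instant);  // 2015-11-09T21:02:18.123Z (null would indicate a parse error)

        // A duration: seconds, optional fractional seconds, and a trailing 's'.
        Duration duration = TimeUtil.fromCloudDuration("90.500s");
        System.out.println(duration.getMillis());                // 90500
        System.out.println(TimeUtil.toCloudDuration(duration));  // 90.500s
      }
    }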

    In both formats, fractional seconds are either three digits (millisecond + * resolution), six digits (microsecond resolution), or nine digits (nanosecond + * resolution). + */ +public final class TimeUtil { + private TimeUtil() {} // Non-instantiable. + + private static final Pattern DURATION_PATTERN = Pattern.compile("(\\d+)(?:\\.(\\d+))?s"); + private static final Pattern TIME_PATTERN = + Pattern.compile("(\\d{4})-(\\d{2})-(\\d{2})T(\\d{2}):(\\d{2}):(\\d{2})(?:\\.(\\d+))?Z"); + + /** + * Converts a {@link ReadableInstant} into a Dateflow API time value. + */ + public static String toCloudTime(ReadableInstant instant) { + // Note that since Joda objects use millisecond resolution, we always + // produce either no fractional seconds or fractional seconds with + // millisecond resolution. + + // Translate the ReadableInstant to a DateTime with ISOChronology. + DateTime time = new DateTime(instant); + + int millis = time.getMillisOfSecond(); + if (millis == 0) { + return String.format("%04d-%02d-%02dT%02d:%02d:%02dZ", + time.getYear(), + time.getMonthOfYear(), + time.getDayOfMonth(), + time.getHourOfDay(), + time.getMinuteOfHour(), + time.getSecondOfMinute()); + } else { + return String.format("%04d-%02d-%02dT%02d:%02d:%02d.%03dZ", + time.getYear(), + time.getMonthOfYear(), + time.getDayOfMonth(), + time.getHourOfDay(), + time.getMinuteOfHour(), + time.getSecondOfMinute(), + millis); + } + } + + /** + * Converts a time value received via the Dataflow API into the corresponding + * {@link Instant}. + * @return the parsed time, or null if a parse error occurs + */ + @Nullable + public static Instant fromCloudTime(String time) { + Matcher matcher = TIME_PATTERN.matcher(time); + if (!matcher.matches()) { + return null; + } + int year = Integer.valueOf(matcher.group(1)); + int month = Integer.valueOf(matcher.group(2)); + int day = Integer.valueOf(matcher.group(3)); + int hour = Integer.valueOf(matcher.group(4)); + int minute = Integer.valueOf(matcher.group(5)); + int second = Integer.valueOf(matcher.group(6)); + int millis = 0; + + String frac = matcher.group(7); + if (frac != null) { + int fracs = Integer.valueOf(frac); + if (frac.length() == 3) { // millisecond resolution + millis = fracs; + } else if (frac.length() == 6) { // microsecond resolution + millis = fracs / 1000; + } else if (frac.length() == 9) { // nanosecond resolution + millis = fracs / 1000000; + } else { + return null; + } + } + + return new DateTime(year, month, day, hour, minute, second, millis, + ISOChronology.getInstanceUTC()).toInstant(); + } + + /** + * Converts a {@link ReadableDuration} into a Dataflow API duration string. + */ + public static String toCloudDuration(ReadableDuration duration) { + // Note that since Joda objects use millisecond resolution, we always + // produce either no fractional seconds or fractional seconds with + // millisecond resolution. + long millis = duration.getMillis(); + long seconds = millis / 1000; + millis = millis % 1000; + if (millis == 0) { + return String.format("%ds", seconds); + } else { + return String.format("%d.%03ds", seconds, millis); + } + } + + /** + * Converts a Dataflow API duration string into a {@link Duration}. 
+ * @return the parsed duration, or null if a parse error occurs + */ + @Nullable + public static Duration fromCloudDuration(String duration) { + Matcher matcher = DURATION_PATTERN.matcher(duration); + if (!matcher.matches()) { + return null; + } + long millis = Long.valueOf(matcher.group(1)) * 1000; + String frac = matcher.group(2); + if (frac != null) { + long fracs = Long.valueOf(frac); + if (frac.length() == 3) { // millisecond resolution + millis += fracs; + } else if (frac.length() == 6) { // microsecond resolution + millis += fracs / 1000; + } else if (frac.length() == 9) { // nanosecond resolution + millis += fracs / 1000000; + } else { + return null; + } + } + return Duration.millis(millis); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimerInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimerInternals.java new file mode 100644 index 000000000000..c823ed39b167 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TimerInternals.java @@ -0,0 +1,269 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * Encapsulate interaction with time within the execution environment. + * + *

    This class allows setting and deleting timers, and also retrieving an + * estimate of the current time. + */ +public interface TimerInternals { + + /** + * Writes out a timer to be fired when the watermark reaches the given + * timestamp. + * + *

    The combination of {@code namespace}, {@code timestamp} and {@code domain} uniquely + * identify a timer. Multiple timers set for the same parameters can be safely deduplicated. + */ + void setTimer(TimerData timerKey); + + /** + * Deletes the given timer. + */ + void deleteTimer(TimerData timerKey); + + /** + * Returns the current timestamp in the {@link TimeDomain#PROCESSING_TIME} time domain. + */ + Instant currentProcessingTime(); + + /** + * Returns the current timestamp in the {@link TimeDomain#SYNCHRONIZED_PROCESSING_TIME} time + * domain or {@code null} if unknown. + */ + @Nullable + Instant currentSynchronizedProcessingTime(); + + /** + * Return the current, local input watermark timestamp for this computation + * in the {@link TimeDomain#EVENT_TIME} time domain. Return {@code null} if unknown. + * + *

This value:
+ * <ol>
+ *   <li>Is monotonically increasing.</li>
+ *   <li>May differ between workers due to network and other delays.</li>
+ *   <li>Will never be ahead of the global input watermark for this computation. But it
+ *       may be arbitrarily behind the global input watermark.</li>
+ *   <li>Any element with a timestamp before the local input watermark can be considered
+ *       'locally late' and be subject to special processing or be dropped entirely.</li>
+ * </ol>
+ *
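As a concrete reading of item 4 above, a minimal sketch of how a runner-side component might classify an element as locally late against the input watermark exposed by this interface; the helper class and method name are invented for illustration:

    import org.joda.time.Instant;

    import com.google.cloud.dataflow.sdk.util.TimerInternals;

    /** Illustrative helper; not part of the SDK. */
    class LocalLatenessCheck {
      static boolean isLocallyLate(TimerInternals timerInternals, Instant elementTimestamp) {
        Instant inputWatermark = timerInternals.currentInputWatermarkTime();
        // The local input watermark may be unknown (null); treat that case as on time here.
        return inputWatermark != null && elementTimestamp.isBefore(inputWatermark);
      }
    }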

    Note that because the local input watermark can be behind the global input watermark, + * it is possible for an element to be considered locally on-time even though it is + * globally late. + */ + @Nullable + Instant currentInputWatermarkTime(); + + /** + * Return the current, local output watermark timestamp for this computation + * in the {@link TimeDomain#EVENT_TIME} time domain. Return {@code null} if unknown. + * + *

This value:
+ * <ol>
+ *   <li>Is monotonically increasing.</li>
+ *   <li>Will never be ahead of {@link #currentInputWatermarkTime} as returned above.</li>
+ *   <li>May differ between workers due to network and other delays.</li>
+ *   <li>However, it will never be behind the global input watermark for any following
+ *       computation.</li>
+ * </ol>
+ *

In pictures:
+   * <pre>
+   *  |              |       |       |       |
+   *  |              |   D   |   C   |   B   |   A
+   *  |              |       |       |       |
+   * GIWM     <=    GOWM <= LOWM <= LIWM <= GIWM
+   * (next stage)
+   * -------------------------------------------------> event time
+   * </pre>
+   * where
+   * <ul>
+   *   <li>LOWM = local output watermark.</li>
+   *   <li>GOWM = global output watermark.</li>
+   *   <li>GIWM = global input watermark.</li>
+   *   <li>LIWM = local input watermark.</li>
+   *   <li>A = a globally on-time element.</li>
+   *   <li>B = a globally late, but locally on-time element.</li>
+   *   <li>C = a locally late element which may still contribute to the timestamp of a pane.</li>
+   *   <li>D = a locally late element which cannot contribute to the timestamp of a pane.</li>
+   * </ul>

    Note that if a computation emits an element which is not before the current output watermark + * then that element will always appear locally on-time in all following computations. However, + * it is possible for an element emitted before the current output watermark to appear locally + * on-time in a following computation. Thus we must be careful to never assume locally late data + * viewed on the output of a computation remains locally late on the input of a following + * computation. + */ + @Nullable + Instant currentOutputWatermarkTime(); + + /** + * Data about a timer as represented within {@link TimerInternals}. + */ + public static class TimerData implements Comparable { + private final StateNamespace namespace; + private final Instant timestamp; + private final TimeDomain domain; + + private TimerData(StateNamespace namespace, Instant timestamp, TimeDomain domain) { + this.namespace = checkNotNull(namespace); + this.timestamp = checkNotNull(timestamp); + this.domain = checkNotNull(domain); + } + + public StateNamespace getNamespace() { + return namespace; + } + + public Instant getTimestamp() { + return timestamp; + } + + public TimeDomain getDomain() { + return domain; + } + + /** + * Construct the {@code TimerKey} for the given parameters. + */ + public static TimerData of(StateNamespace namespace, Instant timestamp, TimeDomain domain) { + return new TimerData(namespace, timestamp, domain); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof TimerData)) { + return false; + } + + TimerData that = (TimerData) obj; + return Objects.equals(this.domain, that.domain) + && this.timestamp.isEqual(that.timestamp) + && Objects.equals(this.namespace, that.namespace); + } + + @Override + public int hashCode() { + return Objects.hash(domain, timestamp, namespace); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("namespace", namespace) + .add("timestamp", timestamp) + .add("domain", domain) + .toString(); + } + + @Override + public int compareTo(TimerData o) { + return Long.compare(timestamp.getMillis(), o.getTimestamp().getMillis()); + } + } + + /** + * A {@link Coder} for {@link TimerData}. 
+ */ + public class TimerDataCoder extends StandardCoder { + private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of(); + private static final InstantCoder INSTANT_CODER = InstantCoder.of(); + private final Coder windowCoder; + + public static TimerDataCoder of(Coder windowCoder) { + return new TimerDataCoder(windowCoder); + } + + @SuppressWarnings("unchecked") + @JsonCreator + public static TimerDataCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 components, got " + components.size()); + return of((Coder) components.get(0)); + } + + private TimerDataCoder(Coder windowCoder) { + this.windowCoder = windowCoder; + } + + @Override + public void encode(TimerData timer, OutputStream outStream, Context context) + throws CoderException, IOException { + Context nestedContext = context.nested(); + STRING_CODER.encode(timer.namespace.stringKey(), outStream, nestedContext); + INSTANT_CODER.encode(timer.timestamp, outStream, nestedContext); + STRING_CODER.encode(timer.domain.name(), outStream, nestedContext); + } + + @Override + public TimerData decode(InputStream inStream, Context context) + throws CoderException, IOException { + Context nestedContext = context.nested(); + StateNamespace namespace = + StateNamespaces.fromString(STRING_CODER.decode(inStream, nestedContext), windowCoder); + Instant timestamp = INSTANT_CODER.decode(inStream, nestedContext); + TimeDomain domain = TimeDomain.valueOf(STRING_CODER.decode(inStream, nestedContext)); + return TimerData.of(namespace, timestamp, domain); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(windowCoder); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic("window coder must be deterministic", windowCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Timers.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Timers.java new file mode 100644 index 000000000000..7d4b4f2abe57 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Timers.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; + +import org.joda.time.Instant; + +import javax.annotation.Nullable; + +/** + * Interface for interacting with time. + */ +@Experimental(Experimental.Kind.TIMERS) +public interface Timers { + /** + * Sets a timer to fire when the event time watermark, the current processing time, or + * the synchronized processing time watermark surpasses a given timestamp. + * + *

    See {@link TimeDomain} for details on the time domains available. + * + *

    Timers are not guaranteed to fire immediately, but will be delivered at some time + * afterwards. + * + *
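A minimal sketch of how a caller might use this interface, assuming it already holds a Timers instance and the window being processed; the wrapper class is invented for illustration:

    import org.joda.time.Duration;
    import org.joda.time.Instant;

    import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
    import com.google.cloud.dataflow.sdk.util.TimeDomain;
    import com.google.cloud.dataflow.sdk.util.Timers;

    /** Illustrative only; not part of the SDK. */
    class TimersUsageSketch {
      static void scheduleTimers(Timers timers, BoundedWindow window) {
        // Fire once the event-time watermark passes the end of the window.
        timers.setTimer(window.maxTimestamp(), TimeDomain.EVENT_TIME);

        // Fire roughly one minute of processing time from now.
        Instant wakeup = timers.currentProcessingTime().plus(Duration.standardMinutes(1));
        timers.setTimer(wakeup, TimeDomain.PROCESSING_TIME);
      }
    }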

    An implementation of {@link Timers} implicitly scopes timers that are set - they may + * be scoped to a key and window, or a key, window, and trigger, etc. + * + * @param timestamp the time at which the timer should be delivered + * @param timeDomain the domain that the {@code timestamp} applies to + */ + public abstract void setTimer(Instant timestamp, TimeDomain timeDomain); + + /** Removes the timer set in this context for the {@code timestmap} and {@code timeDomain}. */ + public abstract void deleteTimer(Instant timestamp, TimeDomain timeDomain); + + /** Returns the current processing time. */ + public abstract Instant currentProcessingTime(); + + /** Returns the current synchronized processing time or {@code null} if unknown. */ + @Nullable + public abstract Instant currentSynchronizedProcessingTime(); + + /** Returns the current event time or {@code null} if unknown. */ + @Nullable + public abstract Instant currentEventTime(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java new file mode 100644 index 000000000000..7735a9e01fcc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java @@ -0,0 +1,196 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.pubsub.Pubsub; +import com.google.api.services.storage.Storage; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineDebugOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.hadoop.util.ChainingHttpRequestInitializer; +import com.google.common.collect.ImmutableList; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.security.GeneralSecurityException; + +/** + * Helpers for cloud communication. + */ +public class Transport { + + private static class SingletonHelper { + /** Global instance of the JSON factory. */ + private static final JsonFactory JSON_FACTORY; + + /** Global instance of the HTTP transport. 
*/ + private static final HttpTransport HTTP_TRANSPORT; + + static { + try { + JSON_FACTORY = JacksonFactory.getDefaultInstance(); + HTTP_TRANSPORT = GoogleNetHttpTransport.newTrustedTransport(); + } catch (GeneralSecurityException | IOException e) { + throw new RuntimeException(e); + } + } + } + + public static HttpTransport getTransport() { + return SingletonHelper.HTTP_TRANSPORT; + } + + public static JsonFactory getJsonFactory() { + return SingletonHelper.JSON_FACTORY; + } + + private static class ApiComponents { + public String rootUrl; + public String servicePath; + + public ApiComponents(String root, String path) { + this.rootUrl = root; + this.servicePath = path; + } + } + + private static ApiComponents apiComponentsFromUrl(String urlString) { + try { + URL url = new URL(urlString); + String rootUrl = url.getProtocol() + "://" + url.getHost() + + (url.getPort() > 0 ? ":" + url.getPort() : ""); + return new ApiComponents(rootUrl, url.getPath()); + } catch (MalformedURLException e) { + throw new RuntimeException("Invalid URL: " + urlString); + } + } + + /** + * Returns a BigQuery client builder. + * + *

    Note: this client's endpoint is not modified by the + * {@link DataflowPipelineDebugOptions#getApiRootUrl()} option. + */ + public static Bigquery.Builder + newBigQueryClient(BigQueryOptions options) { + return new Bigquery.Builder(getTransport(), getJsonFactory(), + chainHttpRequestInitializer( + options.getGcpCredential(), + // Do not log 404. It clutters the output and is possibly even required by the caller. + new RetryHttpRequestInitializer(ImmutableList.of(404)))) + .setApplicationName(options.getAppName()) + .setGoogleClientRequestInitializer(options.getGoogleApiTrace()); + } + + /** + * Returns a Pubsub client builder. + * + *
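A minimal sketch of how these builders are typically consumed, assuming pipeline options parsed with PipelineOptionsFactory and a credential available in the environment; the wrapper class is invented for illustration:

    import com.google.api.services.bigquery.Bigquery;
    import com.google.cloud.dataflow.sdk.options.BigQueryOptions;
    import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
    import com.google.cloud.dataflow.sdk.util.Transport;

    /** Illustrative only; not part of the SDK. */
    class TransportUsageSketch {
      static Bigquery newBigQuery(String[] args) {
        BigQueryOptions options = PipelineOptionsFactory.fromArgs(args).as(BigQueryOptions.class);
        // newBigQueryClient returns a preconfigured Bigquery.Builder; build() yields the client.
        return Transport.newBigQueryClient(options).build();
      }
    }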

    Note: this client's endpoint is not modified by the + * {@link DataflowPipelineDebugOptions#getApiRootUrl()} option. + */ + public static Pubsub.Builder + newPubsubClient(DataflowPipelineOptions options) { + return new Pubsub.Builder(getTransport(), getJsonFactory(), + chainHttpRequestInitializer( + options.getGcpCredential(), + // Do not log 404. It clutters the output and is possibly even required by the caller. + new RetryHttpRequestInitializer(ImmutableList.of(404)))) + .setRootUrl(options.getPubsubRootUrl()) + .setApplicationName(options.getAppName()) + .setGoogleClientRequestInitializer(options.getGoogleApiTrace()); + } + + /** + * Returns a Google Cloud Dataflow client builder. + */ + public static Dataflow.Builder newDataflowClient(DataflowPipelineOptions options) { + String servicePath = options.getDataflowEndpoint(); + ApiComponents components; + if (servicePath.contains("://")) { + components = apiComponentsFromUrl(servicePath); + } else { + components = new ApiComponents(options.getApiRootUrl(), servicePath); + } + + return new Dataflow.Builder(getTransport(), + getJsonFactory(), + chainHttpRequestInitializer( + options.getGcpCredential(), + // Do not log 404. It clutters the output and is possibly even required by the caller. + new RetryHttpRequestInitializer(ImmutableList.of(404)))) + .setApplicationName(options.getAppName()) + .setRootUrl(components.rootUrl) + .setServicePath(components.servicePath) + .setGoogleClientRequestInitializer(options.getGoogleApiTrace()); + } + + /** + * Returns a Dataflow client that does not automatically retry failed + * requests. + */ + public static Dataflow.Builder + newRawDataflowClient(DataflowPipelineOptions options) { + return newDataflowClient(options) + .setHttpRequestInitializer(options.getGcpCredential()) + .setGoogleClientRequestInitializer(options.getGoogleApiTrace()); + } + + /** + * Returns a Cloud Storage client builder. + * + *

    Note: this client's endpoint is not modified by the + * {@link DataflowPipelineDebugOptions#getApiRootUrl()} option. + */ + public static Storage.Builder + newStorageClient(GcsOptions options) { + String servicePath = options.getGcsEndpoint(); + Storage.Builder storageBuilder = new Storage.Builder(getTransport(), getJsonFactory(), + chainHttpRequestInitializer( + options.getGcpCredential(), + // Do not log the code 404. Code up the stack will deal with 404's if needed, and + // logging it by default clutters the output during file staging. + new RetryHttpRequestInitializer( + ImmutableList.of(404), new UploadIdResponseInterceptor()))) + .setApplicationName(options.getAppName()) + .setGoogleClientRequestInitializer(options.getGoogleApiTrace()); + if (servicePath != null) { + ApiComponents components = apiComponentsFromUrl(servicePath); + storageBuilder.setRootUrl(components.rootUrl); + storageBuilder.setServicePath(components.servicePath); + } + return storageBuilder; + } + + private static HttpRequestInitializer chainHttpRequestInitializer( + Credential credential, HttpRequestInitializer httpRequestInitializer) { + if (credential == null) { + return httpRequestInitializer; + } else { + return new ChainingHttpRequestInitializer(credential, httpRequestInitializer); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java new file mode 100644 index 000000000000..64ff402a9aec --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java @@ -0,0 +1,522 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.MergingTriggerInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.TriggerInfo; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.State; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.common.base.Predicate; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Maps; + +import org.joda.time.Instant; + +import java.util.Collection; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * Factory for creating instances of the various {@link Trigger} contexts. + * + *

    These contexts are highly interdependent and share many fields; it is inadvisable + * to create them via any means other than this factory class. + */ +public class TriggerContextFactory { + + private final WindowingStrategy windowingStrategy; + private StateInternals stateInternals; + // Future triggers may be able to exploit the active window to state address window mapping. + @SuppressWarnings("unused") + private ActiveWindowSet activeWindows; + private final Coder windowCoder; + + public TriggerContextFactory(WindowingStrategy windowingStrategy, + StateInternals stateInternals, ActiveWindowSet activeWindows) { + this.windowingStrategy = windowingStrategy; + this.stateInternals = stateInternals; + this.activeWindows = activeWindows; + this.windowCoder = windowingStrategy.getWindowFn().windowCoder(); + } + + public Trigger.TriggerContext base(W window, Timers timers, + ExecutableTrigger rootTrigger, FinishedTriggers finishedSet) { + return new TriggerContextImpl(window, timers, rootTrigger, finishedSet); + } + + public Trigger.OnElementContext createOnElementContext( + W window, Timers timers, Instant elementTimestamp, + ExecutableTrigger rootTrigger, FinishedTriggers finishedSet) { + return new OnElementContextImpl(window, timers, rootTrigger, finishedSet, elementTimestamp); + } + + public Trigger.OnMergeContext createOnMergeContext(W window, Timers timers, + ExecutableTrigger rootTrigger, FinishedTriggers finishedSet, + Map finishedSets) { + return new OnMergeContextImpl(window, timers, rootTrigger, finishedSet, finishedSets); + } + + public StateAccessor createStateAccessor(W window, ExecutableTrigger trigger) { + return new StateAccessorImpl(window, trigger); + } + + public MergingStateAccessor createMergingStateAccessor( + W mergeResult, Collection mergingWindows, ExecutableTrigger trigger) { + return new MergingStateAccessorImpl(trigger, mergingWindows, mergeResult); + } + + private class TriggerInfoImpl implements Trigger.TriggerInfo { + + protected final ExecutableTrigger trigger; + protected final FinishedTriggers finishedSet; + private final Trigger.TriggerContext context; + + public TriggerInfoImpl(ExecutableTrigger trigger, FinishedTriggers finishedSet, + Trigger.TriggerContext context) { + this.trigger = trigger; + this.finishedSet = finishedSet; + this.context = context; + } + + @Override + public boolean isMerging() { + return !windowingStrategy.getWindowFn().isNonMerging(); + } + + @Override + public Iterable> subTriggers() { + return trigger.subTriggers(); + } + + @Override + public ExecutableTrigger subTrigger(int subtriggerIndex) { + return trigger.subTriggers().get(subtriggerIndex); + } + + @Override + public boolean isFinished() { + return finishedSet.isFinished(trigger); + } + + @Override + public boolean isFinished(int subtriggerIndex) { + return finishedSet.isFinished(subTrigger(subtriggerIndex)); + } + + @Override + public boolean areAllSubtriggersFinished() { + return Iterables.isEmpty(unfinishedSubTriggers()); + } + + @Override + public Iterable> unfinishedSubTriggers() { + return FluentIterable + .from(trigger.subTriggers()) + .filter(new Predicate>() { + @Override + public boolean apply(ExecutableTrigger trigger) { + return !finishedSet.isFinished(trigger); + } + }); + } + + @Override + public ExecutableTrigger firstUnfinishedSubTrigger() { + for (ExecutableTrigger subTrigger : trigger.subTriggers()) { + if (!finishedSet.isFinished(subTrigger)) { + return subTrigger; + } + } + return null; + } + + @Override + public void resetTree() throws Exception { + 
finishedSet.clearRecursively(trigger); + trigger.invokeClear(context); + } + + @Override + public void setFinished(boolean finished) { + finishedSet.setFinished(trigger, finished); + } + + @Override + public void setFinished(boolean finished, int subTriggerIndex) { + finishedSet.setFinished(subTrigger(subTriggerIndex), finished); + } + } + + private class TriggerTimers implements Timers { + + private final Timers timers; + private final W window; + + public TriggerTimers(W window, Timers timers) { + this.timers = timers; + this.window = window; + } + + @Override + public void setTimer(Instant timestamp, TimeDomain timeDomain) { + timers.setTimer(timestamp, timeDomain); + } + + @Override + public void deleteTimer(Instant timestamp, TimeDomain timeDomain) { + if (timeDomain == TimeDomain.EVENT_TIME + && timestamp.equals(window.maxTimestamp())) { + // Don't allow triggers to unset the at-max-timestamp timer. This is necessary for on-time + // state transitions. + return; + } + timers.deleteTimer(timestamp, timeDomain); + } + + @Override + public Instant currentProcessingTime() { + return timers.currentProcessingTime(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return timers.currentSynchronizedProcessingTime(); + } + + @Override + @Nullable + public Instant currentEventTime() { + return timers.currentEventTime(); + } + } + + private class MergingTriggerInfoImpl + extends TriggerInfoImpl implements Trigger.MergingTriggerInfo { + + private final Map finishedSets; + + public MergingTriggerInfoImpl( + ExecutableTrigger trigger, + FinishedTriggers finishedSet, + Trigger.TriggerContext context, + Map finishedSets) { + super(trigger, finishedSet, context); + this.finishedSets = finishedSets; + } + + @Override + public boolean finishedInAnyMergingWindow() { + for (FinishedTriggers finishedSet : finishedSets.values()) { + if (finishedSet.isFinished(trigger)) { + return true; + } + } + return false; + } + + @Override + public boolean finishedInAllMergingWindows() { + for (FinishedTriggers finishedSet : finishedSets.values()) { + if (!finishedSet.isFinished(trigger)) { + return false; + } + } + return true; + } + + @Override + public Iterable getFinishedMergingWindows() { + return Maps.filterValues(finishedSets, new Predicate() { + @Override + public boolean apply(FinishedTriggers finishedSet) { + return finishedSet.isFinished(trigger); + } + }).keySet(); + } + } + + private class StateAccessorImpl implements StateAccessor { + protected final int triggerIndex; + protected final StateNamespace windowNamespace; + + public StateAccessorImpl( + W window, + ExecutableTrigger trigger) { + this.triggerIndex = trigger.getTriggerIndex(); + this.windowNamespace = namespaceFor(window); + } + + protected StateNamespace namespaceFor(W window) { + return StateNamespaces.windowAndTrigger(windowCoder, window, triggerIndex); + } + + @Override + public StateT access(StateTag address) { + return stateInternals.state(windowNamespace, address); + } + } + + private class MergingStateAccessorImpl extends StateAccessorImpl + implements MergingStateAccessor { + private final Collection activeToBeMerged; + + public MergingStateAccessorImpl(ExecutableTrigger trigger, Collection activeToBeMerged, + W mergeResult) { + super(mergeResult, trigger); + this.activeToBeMerged = activeToBeMerged; + } + + @Override + public StateT access( + StateTag address) { + return stateInternals.state(windowNamespace, address); + } + + @Override + public Map accessInEachMergingWindow( + StateTag address) { + 
ImmutableMap.Builder builder = ImmutableMap.builder(); + for (W mergingWindow : activeToBeMerged) { + StateT stateForWindow = stateInternals.state(namespaceFor(mergingWindow), address); + builder.put(mergingWindow, stateForWindow); + } + return builder.build(); + } + } + + private class TriggerContextImpl extends Trigger.TriggerContext { + + private final W window; + private final StateAccessorImpl state; + private final Timers timers; + private final TriggerInfoImpl triggerInfo; + + private TriggerContextImpl( + W window, + Timers timers, + ExecutableTrigger trigger, + FinishedTriggers finishedSet) { + trigger.getSpec().super(); + this.window = window; + this.state = new StateAccessorImpl(window, trigger); + this.timers = new TriggerTimers(window, timers); + this.triggerInfo = new TriggerInfoImpl(trigger, finishedSet, this); + } + + @Override + public Trigger.TriggerContext forTrigger(ExecutableTrigger trigger) { + return new TriggerContextImpl(window, timers, trigger, triggerInfo.finishedSet); + } + + @Override + public TriggerInfo trigger() { + return triggerInfo; + } + + @Override + public StateAccessor state() { + return state; + } + + @Override + public W window() { + return window; + } + + @Override + public void deleteTimer(Instant timestamp, TimeDomain domain) { + timers.deleteTimer(timestamp, domain); + } + + @Override + public Instant currentProcessingTime() { + return timers.currentProcessingTime(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return timers.currentSynchronizedProcessingTime(); + } + + @Override + @Nullable + public Instant currentEventTime() { + return timers.currentEventTime(); + } + } + + private class OnElementContextImpl extends Trigger.OnElementContext { + + private final W window; + private final StateAccessorImpl state; + private final Timers timers; + private final TriggerInfoImpl triggerInfo; + private final Instant eventTimestamp; + + private OnElementContextImpl( + W window, + Timers timers, + ExecutableTrigger trigger, + FinishedTriggers finishedSet, + Instant eventTimestamp) { + trigger.getSpec().super(); + this.window = window; + this.state = new StateAccessorImpl(window, trigger); + this.timers = new TriggerTimers(window, timers); + this.triggerInfo = new TriggerInfoImpl(trigger, finishedSet, this); + this.eventTimestamp = eventTimestamp; + } + + + @Override + public Instant eventTimestamp() { + return eventTimestamp; + } + + @Override + public Trigger.OnElementContext forTrigger(ExecutableTrigger trigger) { + return new OnElementContextImpl( + window, timers, trigger, triggerInfo.finishedSet, eventTimestamp); + } + + @Override + public TriggerInfo trigger() { + return triggerInfo; + } + + @Override + public StateAccessor state() { + return state; + } + + @Override + public W window() { + return window; + } + + @Override + public void setTimer(Instant timestamp, TimeDomain domain) { + timers.setTimer(timestamp, domain); + } + + + @Override + public void deleteTimer(Instant timestamp, TimeDomain domain) { + timers.deleteTimer(timestamp, domain); + } + + @Override + public Instant currentProcessingTime() { + return timers.currentProcessingTime(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return timers.currentSynchronizedProcessingTime(); + } + + @Override + @Nullable + public Instant currentEventTime() { + return timers.currentEventTime(); + } + } + + private class OnMergeContextImpl extends Trigger.OnMergeContext { + private final MergingStateAccessor state; + 
private final W window; + private final Collection mergingWindows; + private final Timers timers; + private final MergingTriggerInfoImpl triggerInfo; + + private OnMergeContextImpl( + W window, + Timers timers, + ExecutableTrigger trigger, + FinishedTriggers finishedSet, + Map finishedSets) { + trigger.getSpec().super(); + this.mergingWindows = finishedSets.keySet(); + this.window = window; + this.state = new MergingStateAccessorImpl(trigger, mergingWindows, window); + this.timers = new TriggerTimers(window, timers); + this.triggerInfo = new MergingTriggerInfoImpl(trigger, finishedSet, this, finishedSets); + } + + @Override + public Trigger.OnMergeContext forTrigger(ExecutableTrigger trigger) { + return new OnMergeContextImpl( + window, timers, trigger, triggerInfo.finishedSet, triggerInfo.finishedSets); + } + + @Override + public MergingStateAccessor state() { + return state; + } + + @Override + public MergingTriggerInfo trigger() { + return triggerInfo; + } + + @Override + public W window() { + return window; + } + + @Override + public void setTimer(Instant timestamp, TimeDomain domain) { + timers.setTimer(timestamp, domain); + } + + @Override + public void deleteTimer(Instant timestamp, TimeDomain domain) { + timers.setTimer(timestamp, domain); + + } + + @Override + public Instant currentProcessingTime() { + return timers.currentProcessingTime(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return timers.currentSynchronizedProcessingTime(); + } + + @Override + @Nullable + public Instant currentEventTime() { + return timers.currentEventTime(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java new file mode 100644 index 000000000000..dcfd03516b74 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java @@ -0,0 +1,223 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.util.state.ValueState; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; + +import org.joda.time.Instant; + +import java.util.BitSet; +import java.util.Collection; +import java.util.Map; + +/** + * Executes a trigger while managing persistence of information about which subtriggers are + * finished. 
Subtriggers include all recursive trigger expressions as well as the entire trigger. + * + *

    Specifically, the responsibilities are: + * + *

      + *
+ * <ul>
+ *   <li>Invoking the trigger's methods via its {@link ExecutableTrigger} wrapper by
+ *       constructing the appropriate trigger contexts.</li>
+ *   <li>Committing a record of which subtriggers are finished to persistent state.</li>
+ *   <li>Restoring the record of which subtriggers are finished from persistent state.</li>
+ *   <li>Clearing out the persisted finished set when a caller indicates
+ *       (via {@link #clearFinished}) that it is no longer needed.</li>
+ * </ul>
    + * + *
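A minimal sketch of the per-element sequence a runner might drive through this class, assuming the window, timers, and state plumbing shown elsewhere in this file are already available; the wrapper class is invented for illustration:

    import org.joda.time.Instant;

    import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
    import com.google.cloud.dataflow.sdk.util.Timers;
    import com.google.cloud.dataflow.sdk.util.TriggerRunner;
    import com.google.cloud.dataflow.sdk.util.state.StateAccessor;

    /** Illustrative only; not part of the SDK. */
    class TriggerRunnerUsageSketch {
      static <W extends BoundedWindow> void onElement(
          TriggerRunner<W> runner,
          W window,
          Instant elementTimestamp,
          Timers timers,
          StateAccessor<?> state) throws Exception {
        runner.processValue(window, elementTimestamp, timers, state);
        if (runner.shouldFire(window, timers, state)) {
          runner.onFire(window, timers, state);
          // At this point the caller would emit the pane produced by its ReduceFn.
        }
      }
    }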

    These responsibilities are intertwined: trigger contexts include mutable information about + * which subtriggers are finished. This class provides the information when building the contexts + * and commits the information when the method of the {@link ExecutableTrigger} returns. + * + * @param The kind of windows being processed. + */ +public class TriggerRunner { + @VisibleForTesting + static final StateTag> FINISHED_BITS_TAG = + StateTags.makeSystemTagInternal(StateTags.value("closed", BitSetCoder.of())); + + private final ExecutableTrigger rootTrigger; + private final TriggerContextFactory contextFactory; + + public TriggerRunner(ExecutableTrigger rootTrigger, TriggerContextFactory contextFactory) { + Preconditions.checkState(rootTrigger.getTriggerIndex() == 0); + this.rootTrigger = rootTrigger; + this.contextFactory = contextFactory; + } + + private FinishedTriggersBitSet readFinishedBits(ValueState state) { + if (!isFinishedSetNeeded()) { + // If no trigger in the tree will ever have finished bits, then we don't need to read them. + // So that the code can be agnostic to that fact, we create a BitSet that is all 0 (not + // finished) for each trigger in the tree. + return FinishedTriggersBitSet.emptyWithCapacity(rootTrigger.getFirstIndexAfterSubtree()); + } + + BitSet bitSet = state.read(); + return bitSet == null + ? FinishedTriggersBitSet.emptyWithCapacity(rootTrigger.getFirstIndexAfterSubtree()) + : FinishedTriggersBitSet.fromBitSet(bitSet); + } + + /** Return true if the trigger is closed in the window corresponding to the specified state. */ + public boolean isClosed(StateAccessor state) { + return readFinishedBits(state.access(FINISHED_BITS_TAG)).isFinished(rootTrigger); + } + + public void prefetchForValue(W window, StateAccessor state) { + if (isFinishedSetNeeded()) { + state.access(FINISHED_BITS_TAG).readLater(); + } + rootTrigger.getSpec().prefetchOnElement( + contextFactory.createStateAccessor(window, rootTrigger)); + } + + public void prefetchOnFire(W window, StateAccessor state) { + if (isFinishedSetNeeded()) { + state.access(FINISHED_BITS_TAG).readLater(); + } + rootTrigger.getSpec().prefetchOnFire(contextFactory.createStateAccessor(window, rootTrigger)); + } + + public void prefetchShouldFire(W window, StateAccessor state) { + if (isFinishedSetNeeded()) { + state.access(FINISHED_BITS_TAG).readLater(); + } + rootTrigger.getSpec().prefetchShouldFire( + contextFactory.createStateAccessor(window, rootTrigger)); + } + + /** + * Run the trigger logic to deal with a new value. + */ + public void processValue(W window, Instant timestamp, Timers timers, StateAccessor state) + throws Exception { + // Clone so that we can detect changes and so that changes here don't pollute merging. + FinishedTriggersBitSet finishedSet = + readFinishedBits(state.access(FINISHED_BITS_TAG)).copy(); + Trigger.OnElementContext triggerContext = contextFactory.createOnElementContext( + window, timers, timestamp, rootTrigger, finishedSet); + rootTrigger.invokeOnElement(triggerContext); + persistFinishedSet(state, finishedSet); + } + + public void prefetchForMerge( + W window, Collection mergingWindows, MergingStateAccessor state) { + if (isFinishedSetNeeded()) { + for (ValueState value : state.accessInEachMergingWindow(FINISHED_BITS_TAG).values()) { + value.readLater(); + } + } + rootTrigger.getSpec().prefetchOnMerge(contextFactory.createMergingStateAccessor( + window, mergingWindows, rootTrigger)); + } + + /** + * Run the trigger merging logic as part of executing the specified merge. 
+ */ + public void onMerge(W window, Timers timers, MergingStateAccessor state) throws Exception { + // Clone so that we can detect changes and so that changes here don't pollute merging. + FinishedTriggersBitSet finishedSet = + readFinishedBits(state.access(FINISHED_BITS_TAG)).copy(); + + // And read the finished bits in each merging window. + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (Map.Entry> entry : + state.accessInEachMergingWindow(FINISHED_BITS_TAG).entrySet()) { + // Don't need to clone these, since the trigger context doesn't allow modification + builder.put(entry.getKey(), readFinishedBits(entry.getValue())); + } + ImmutableMap mergingFinishedSets = builder.build(); + + Trigger.OnMergeContext mergeContext = contextFactory.createOnMergeContext( + window, timers, rootTrigger, finishedSet, mergingFinishedSets); + + // Run the merge from the trigger + rootTrigger.invokeOnMerge(mergeContext); + + persistFinishedSet(state, finishedSet); + + // Clear the finished bits. + clearFinished(state); + } + + public boolean shouldFire(W window, Timers timers, StateAccessor state) throws Exception { + FinishedTriggers finishedSet = readFinishedBits(state.access(FINISHED_BITS_TAG)).copy(); + Trigger.TriggerContext context = contextFactory.base(window, timers, + rootTrigger, finishedSet); + return rootTrigger.invokeShouldFire(context); + } + + public void onFire(W window, Timers timers, StateAccessor state) throws Exception { + FinishedTriggersBitSet finishedSet = + readFinishedBits(state.access(FINISHED_BITS_TAG)).copy(); + Trigger.TriggerContext context = contextFactory.base(window, timers, + rootTrigger, finishedSet); + rootTrigger.invokeOnFire(context); + persistFinishedSet(state, finishedSet); + } + + private void persistFinishedSet( + StateAccessor state, FinishedTriggersBitSet modifiedFinishedSet) { + if (!isFinishedSetNeeded()) { + return; + } + + ValueState finishedSetState = state.access(FINISHED_BITS_TAG); + if (!readFinishedBits(finishedSetState).equals(modifiedFinishedSet)) { + if (modifiedFinishedSet.getBitSet().isEmpty()) { + finishedSetState.clear(); + } else { + finishedSetState.write(modifiedFinishedSet.getBitSet()); + } + } + } + + /** + * Clear finished bits. + */ + public void clearFinished(StateAccessor state) { + if (isFinishedSetNeeded()) { + state.access(FINISHED_BITS_TAG).clear(); + } + } + + /** + * Clear the state used for executing triggers, but leave the finished set to indicate + * the window is closed. + */ + public void clearState(W window, Timers timers, StateAccessor state) throws Exception { + // Don't need to clone, because we'll be clearing the finished bits anyways. + FinishedTriggers finishedSet = readFinishedBits(state.access(FINISHED_BITS_TAG)); + rootTrigger.invokeClear(contextFactory.base(window, timers, rootTrigger, finishedSet)); + } + + private boolean isFinishedSetNeeded() { + // TODO: If we know that no trigger in the tree will ever finish, we don't need to do the + // lookup. Right now, we special case this for the DefaultTrigger. + return !(rootTrigger.getSpec() instanceof DefaultTrigger); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UnownedInputStream.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UnownedInputStream.java new file mode 100644 index 000000000000..3d80230a52cf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UnownedInputStream.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.common.base.MoreObjects; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A {@link OutputStream} wrapper which protects against the user attempting to modify + * the underlying stream by closing it or using mark. + */ +public class UnownedInputStream extends FilterInputStream { + public UnownedInputStream(InputStream delegate) { + super(delegate); + } + + @Override + public void close() throws IOException { + throw new UnsupportedOperationException("Caller does not own the underlying input stream " + + " and should not call close()."); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof UnownedInputStream + && ((UnownedInputStream) obj).in.equals(in); + } + + @Override + public int hashCode() { + return in.hashCode(); + } + + @SuppressWarnings("UnsynchronizedOverridesSynchronized") + @Override + public void mark(int readlimit) { + throw new UnsupportedOperationException("Caller does not own the underlying input stream " + + " and should not call mark()."); + } + + @Override + public boolean markSupported() { + return false; + } + + @SuppressWarnings("UnsynchronizedOverridesSynchronized") + @Override + public void reset() throws IOException { + throw new UnsupportedOperationException("Caller does not own the underlying input stream " + + " and should not call reset()."); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(UnownedInputStream.class).add("in", in).toString(); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UnownedOutputStream.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UnownedOutputStream.java new file mode 100644 index 000000000000..29187a1b9da6 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UnownedOutputStream.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.common.base.MoreObjects; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +/** + * A {@link OutputStream} wrapper which protects against the user attempting to modify + * the underlying stream by closing it. 
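A minimal sketch of the intended usage of these unowned-stream wrappers: the owner of a stream wraps it before handing it to other code, so a stray close() fails loudly instead of silently closing the owner's stream; the class below is invented for illustration:

    import java.io.ByteArrayInputStream;
    import java.io.InputStream;

    import com.google.cloud.dataflow.sdk.util.UnownedInputStream;

    /** Illustrative only; not part of the SDK. */
    class UnownedInputStreamSketch {
      public static void main(String[] args) throws Exception {
        InputStream owned = new ByteArrayInputStream(new byte[] {1, 2, 3});
        InputStream unowned = new UnownedInputStream(owned);

        System.out.println(unowned.read());  // prints 1; reads pass through to the wrapped stream
        try {
          unowned.close();                   // the wrapper rejects attempts to close
        } catch (UnsupportedOperationException expected) {
          System.out.println("close() rejected, as intended");
        }
        owned.close();                       // only the owner closes the real stream
      }
    }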
+ */ +public class UnownedOutputStream extends FilterOutputStream { + public UnownedOutputStream(OutputStream delegate) { + super(delegate); + } + + @Override + public void close() throws IOException { + throw new UnsupportedOperationException("Caller does not own the underlying output stream " + + " and should not call close()."); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof UnownedOutputStream + && ((UnownedOutputStream) obj).out.equals(out); + } + + @Override + public int hashCode() { + return out.hashCode(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(UnownedOutputStream.class).add("out", out).toString(); + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UploadIdResponseInterceptor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UploadIdResponseInterceptor.java new file mode 100644 index 000000000000..da597e692e63 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UploadIdResponseInterceptor.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpResponseInterceptor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Implements a response intercepter that logs the upload id if the upload + * id header exists and it is the first request (does not have upload_id parameter in the request). + * Only logs if debug level is enabled. + */ +public class UploadIdResponseInterceptor implements HttpResponseInterceptor { + + private static final Logger LOG = LoggerFactory.getLogger(UploadIdResponseInterceptor.class); + private static final String UPLOAD_ID_PARAM = "upload_id"; + private static final String UPLOAD_TYPE_PARAM = "uploadType"; + private static final String UPLOAD_HEADER = "X-GUploader-UploadID"; + + @Override + public void interceptResponse(HttpResponse response) throws IOException { + if (!LOG.isDebugEnabled()) { + return; + } + String uploadId = response.getHeaders().getFirstHeaderStringValue(UPLOAD_HEADER); + if (uploadId == null) { + return; + } + + GenericUrl url = response.getRequest().getUrl(); + // The check for no upload id limits the output to one log line per upload. + // The check for upload type makes sure this is an upload and not a read. 
+ if (url.get(UPLOAD_ID_PARAM) == null && url.get(UPLOAD_TYPE_PARAM) != null) { + LOG.debug( + "Upload ID for url {} on worker {} is {}", + url, + System.getProperty("worker_id"), + uploadId); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UserCodeException.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UserCodeException.java new file mode 100644 index 000000000000..9b9c7a5919f3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/UserCodeException.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.util.Arrays; +import java.util.Objects; + +/** + * An exception that was thrown in user-code. Sets the stack trace + * from the first time execution enters user code down through the + * rest of the user's stack frames until the exception is + * reached. + */ +public class UserCodeException extends RuntimeException { + + public static UserCodeException wrap(Throwable t) { + if (t instanceof UserCodeException) { + return (UserCodeException) t; + } + + return new UserCodeException(t); + } + + public static RuntimeException wrapIf(boolean condition, Throwable t) { + if (condition) { + return wrap(t); + } + + if (t instanceof RuntimeException) { + return (RuntimeException) t; + } + + return new RuntimeException(t); + } + + private UserCodeException(Throwable t) { + super(t); + truncateStackTrace(t); + } + + /** + * Truncates the @{Throwable}'s stack trace to contain only user code, + * removing all frames below. + * + *
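 * (Editorial illustration, not part of the original source.) For example, if the throwable's
 * stack ends in the hypothetical frames [userFn.apply, runner.invoke, runner.run] and the
 * current thread's stack also ends in [runner.invoke, runner.run], the two shared runner
 * frames are dropped and the reported trace stops at {@code userFn.apply}.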

    This is to remove infrastructure noise below user code entry point. We do this + * by finding common stack frames between the throwable's captured stack and that + * of the current thread. + */ + private void truncateStackTrace(Throwable t) { + + StackTraceElement[] currentStack = Thread.currentThread().getStackTrace(); + StackTraceElement[] throwableStack = t.getStackTrace(); + + int currentStackSize = currentStack.length; + int throwableStackSize = throwableStack.length; + + int commonFrames = 0; + while (framesEqual(currentStack[currentStackSize - commonFrames - 1], + throwableStack[throwableStackSize - commonFrames - 1])) { + commonFrames++; + if (commonFrames >= Math.min(currentStackSize, throwableStackSize)) { + break; + } + } + + StackTraceElement[] truncatedStack = Arrays.copyOfRange(throwableStack, 0, + throwableStackSize - commonFrames); + t.setStackTrace(truncatedStack); + } + + /** + * Check if two frames are equal; Frames are considered equal if they point to the same method. + */ + private boolean framesEqual(StackTraceElement frame1, StackTraceElement frame2) { + boolean areEqual = Objects.equals(frame1.getClassName(), frame2.getClassName()); + areEqual &= Objects.equals(frame1.getMethodName(), frame2.getMethodName()); + + return areEqual; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ValueWithRecordId.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ValueWithRecordId.java new file mode 100644 index 000000000000..ac1f2ebcce9c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ValueWithRecordId.java @@ -0,0 +1,154 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * Immutable struct containing a value as well as a unique id identifying the value. 
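 *
 * (Editorial example, not part of the original source.) A record read from a source might be
 * wrapped as {@code new ValueWithRecordId<>(record, idBytes)}, encoded with
 * {@code ValueWithRecordIdCoder.of(recordCoder)}, and later unwrapped via the
 * {@code stripIds()} transform; {@code record}, {@code idBytes} and {@code recordCoder} are
 * placeholders.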
+ * + * @param the underlying value type + */ +public class ValueWithRecordId { + private final ValueT value; + private final byte[] id; + + public ValueWithRecordId(ValueT value, byte[] id) { + this.value = value; + this.id = id; + } + + public ValueT getValue() { + return value; + } + + public byte[] getId() { + return id; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("id", id) + .add("value", value) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ValueWithRecordId)) { + return false; + } + ValueWithRecordId otherRecord = (ValueWithRecordId) other; + return Objects.deepEquals(id, otherRecord.id) + && Objects.deepEquals(value, otherRecord.value); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(id), value); + } + + /** + * A {@link Coder} for {@code ValueWithRecordId}, using a wrapped value {@code Coder}. + */ + public static class ValueWithRecordIdCoder + extends StandardCoder> { + public static ValueWithRecordIdCoder of(Coder valueCoder) { + return new ValueWithRecordIdCoder<>(valueCoder); + } + + @JsonCreator + public static ValueWithRecordIdCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of(components.get(0)); + } + + protected ValueWithRecordIdCoder(Coder valueCoder) { + this.valueCoder = valueCoder; + this.idCoder = ByteArrayCoder.of(); + } + + @Override + public List> getCoderArguments() { + return Arrays.asList(valueCoder); + } + + @Override + public void encode(ValueWithRecordId value, OutputStream outStream, Context context) + throws IOException { + valueCoder.encode(value.value, outStream, context.nested()); + idCoder.encode(value.id, outStream, context); + } + + @Override + public ValueWithRecordId decode(InputStream inStream, Context context) + throws IOException { + return new ValueWithRecordId( + valueCoder.decode(inStream, context.nested()), + idCoder.decode(inStream, context)); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + valueCoder.verifyDeterministic(); + } + + public Coder getValueCoder() { + return valueCoder; + } + + Coder valueCoder; + ByteArrayCoder idCoder; + } + + public static + PTransform>, PCollection> stripIds() { + return ParDo.named("StripIds") + .of( + new DoFn, T>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getValue()); + } + }); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Values.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Values.java new file mode 100644 index 000000000000..d4440e76de6d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Values.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A collection of static methods for manipulating value representations + * transfered via the Dataflow API. + */ +public final class Values { + private Values() {} // Non-instantiable + + public static Boolean asBoolean(Object value) throws ClassCastException { + @Nullable Boolean knownResult = checkKnownValue(CloudKnownType.BOOLEAN, value, Boolean.class); + if (knownResult != null) { + return knownResult; + } + return Boolean.class.cast(value); + } + + public static Double asDouble(Object value) throws ClassCastException { + @Nullable Double knownResult = checkKnownValue(CloudKnownType.FLOAT, value, Double.class); + if (knownResult != null) { + return knownResult; + } + if (value instanceof Double) { + return (Double) value; + } + return ((Float) value).doubleValue(); + } + + public static Long asLong(Object value) throws ClassCastException { + @Nullable Long knownResult = checkKnownValue(CloudKnownType.INTEGER, value, Long.class); + if (knownResult != null) { + return knownResult; + } + if (value instanceof Long) { + return (Long) value; + } + return ((Integer) value).longValue(); + } + + public static String asString(Object value) throws ClassCastException { + @Nullable String knownResult = checkKnownValue(CloudKnownType.TEXT, value, String.class); + if (knownResult != null) { + return knownResult; + } + return String.class.cast(value); + } + + @Nullable + private static T checkKnownValue(CloudKnownType type, Object value, Class clazz) { + if (!(value instanceof Map)) { + return null; + } + Map map = (Map) value; + @Nullable String typeName = (String) map.get(PropertyNames.OBJECT_TYPE_NAME); + if (typeName == null) { + return null; + } + @Nullable CloudKnownType knownType = CloudKnownType.forUri(typeName); + if (knownType == null || knownType != type) { + return null; + } + @Nullable Object scalar = map.get(PropertyNames.SCALAR_FIELD_NAME); + if (scalar == null) { + return null; + } + return knownType.parse(scalar, clazz); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/VarInt.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/VarInt.java new file mode 100644 index 000000000000..af039112eaba --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/VarInt.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * Variable-length encoding for integers. + * + *
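 * (Editorial worked example, not part of the original source.) The int 300 (binary 100101100)
 * is written low-order 7-bit group first, with the high bit of each byte marking continuation,
 * so its encoding is the two bytes 0xAC 0x02:
 * <pre>
 *   ByteArrayOutputStream out = new ByteArrayOutputStream();
 *   VarInt.encode(300, out);   // writes 0xAC 0x02; VarInt.getLength(300) == 2
 *   int n = VarInt.decodeInt(new ByteArrayInputStream(out.toByteArray()));   // n == 300
 * </pre>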

    Handles, in a common encoding format, signed bytes, shorts, ints, and longs. + * Takes between 1 and 10 bytes. + * Less efficient than BigEndian{Int,Long} coder for negative or large numbers. + * All negative ints are encoded using 5 bytes, longs take 10 bytes. + */ +public class VarInt { + + private static long convertIntToLongNoSignExtend(int v) { + return v & 0xFFFFFFFFL; + } + + /** + * Encodes the given value onto the stream. + */ + public static void encode(int v, OutputStream stream) throws IOException { + encode(convertIntToLongNoSignExtend(v), stream); + } + + /** + * Encodes the given value onto the stream. + */ + public static void encode(long v, OutputStream stream) throws IOException { + do { + // Encode next 7 bits + terminator bit + long bits = v & 0x7F; + v >>>= 7; + byte b = (byte) (bits | ((v != 0) ? 0x80 : 0)); + stream.write(b); + } while (v != 0); + } + + /** + * Decodes an integer value from the given stream. + */ + public static int decodeInt(InputStream stream) throws IOException { + long r = decodeLong(stream); + if (r < 0 || r >= 1L << 32) { + throw new IOException("varint overflow " + r); + } + return (int) r; + } + + /** + * Decodes a long value from the given stream. + */ + public static long decodeLong(InputStream stream) throws IOException { + long result = 0; + int shift = 0; + int b; + do { + // Get 7 bits from next byte + b = stream.read(); + if (b < 0) { + if (shift == 0) { + throw new EOFException(); + } else { + throw new IOException("varint not terminated"); + } + } + long bits = b & 0x7F; + if (shift >= 64 || (shift == 63 && bits > 1)) { + // Out of range + throw new IOException("varint too long"); + } + result |= bits << shift; + shift += 7; + } while ((b & 0x80) != 0); + return result; + } + + /** + * Returns the length of the encoding of the given value (in bytes). + */ + public static int getLength(int v) { + return getLength(convertIntToLongNoSignExtend(v)); + } + + /** + * Returns the length of the encoding of the given value (in bytes). + */ + public static int getLength(long v) { + int result = 0; + do { + result++; + v >>>= 7; + } while (v != 0); + return result; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WatermarkHold.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WatermarkHold.java new file mode 100644 index 000000000000..d537ddb0e80d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WatermarkHold.java @@ -0,0 +1,450 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior; +import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; +import com.google.cloud.dataflow.sdk.util.state.ReadableState; +import com.google.cloud.dataflow.sdk.util.state.StateMerging; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.util.state.WatermarkHoldState; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.Serializable; + +import javax.annotation.Nullable; + +/** + * Implements the logic to hold the output watermark for a computation back + * until it has seen all the elements it needs based on the input watermark for the + * computation. + * + *

    The backend ensures the output watermark can never progress beyond the + * input watermark for a computation. GroupAlsoByWindows computations may add a 'hold' + * to the output watermark in order to prevent it progressing beyond a time within a window. + * The hold will be 'cleared' when the associated pane is emitted. + * + *

    This class is only intended for use by {@link ReduceFnRunner}. The two evolve together and + * will likely break any other uses. + * + * @param The kind of {@link BoundedWindow} the hold is for. + */ +class WatermarkHold implements Serializable { + /** + * Return tag for state containing the output watermark hold + * used for elements. + */ + public static + StateTag> watermarkHoldTagForOutputTimeFn( + OutputTimeFn outputTimeFn) { + return StateTags.>makeSystemTagInternal( + StateTags.watermarkStateInternal("hold", outputTimeFn)); + } + + /** + * Tag for state containing end-of-window and garbage collection output watermark holds. + * (We can't piggy-back on the data hold state since the outputTimeFn may be + * {@link OutputTimeFns#outputAtLatestInputTimestamp()}, in which case every pane will + * would take the end-of-window time as its element time.) + */ + @VisibleForTesting + public static final StateTag> EXTRA_HOLD_TAG = + StateTags.makeSystemTagInternal(StateTags.watermarkStateInternal( + "extra", OutputTimeFns.outputAtEarliestInputTimestamp())); + + private final TimerInternals timerInternals; + private final WindowingStrategy windowingStrategy; + private final StateTag> elementHoldTag; + + public WatermarkHold(TimerInternals timerInternals, WindowingStrategy windowingStrategy) { + this.timerInternals = timerInternals; + this.windowingStrategy = windowingStrategy; + this.elementHoldTag = watermarkHoldTagForOutputTimeFn(windowingStrategy.getOutputTimeFn()); + } + + /** + * Add a hold to prevent the output watermark progressing beyond the (possibly adjusted) timestamp + * of the element in {@code context}. We allow the actual hold time to be shifted later by + * {@link OutputTimeFn#assignOutputTime}, but no further than the end of the window. The hold will + * remain until cleared by {@link #extractAndRelease}. Return the timestamp at which the hold + * was placed, or {@literal null} if no hold was placed. + * + *

    In the following we'll write {@code E} to represent an element's timestamp after passing + * through the window strategy's output time function, {@code IWM} for the local input watermark, + * {@code OWM} for the local output watermark, and {@code GCWM} for the garbage collection + * watermark (which is at {@code IWM - getAllowedLateness}). Time progresses from left to right, + * and we write {@code [ ... ]} to denote a bounded window with implied lower bound. + * + *

    Note that the GCWM will be the same as the IWM if {@code getAllowedLateness} + * is {@code ZERO}. + * + *

    Here are the cases we need to handle. They are conceptually considered in the + * sequence written, since if getAllowedLateness is ZERO the GCWM is the same as the IWM. + *

      + *
    1. (Normal) + *
      +   *          |
      +   *      [   | E        ]
      +   *          |
      +   *         IWM
      +   * 
      + * This is, hopefully, the common and happy case. The element is locally on-time and can + * definitely make it to an {@code ON_TIME} pane which we can still set an end-of-window timer + * for. We place an element hold at E, which may contribute to the {@code ON_TIME} pane's + * timestamp (depending on the output time function). Thus the OWM will not proceed past E + * until the next pane fires. + * + *
    2. (Discard - no target window) + *
      +   *                       |                            |
      +   *      [     E        ] |                            |
      +   *                       |                            |
      +   *                     GCWM  <-getAllowedLateness->  IWM
      +   * 
      + * The element is very locally late. The window has been garbage collected, thus there + * is no target pane E could be assigned to. We discard E. + * + *
    3. (Unobservably late) + *
      +   *          |    |
      +   *      [   | E  |     ]
      +   *          |    |
      +   *         OWM  IWM
      +   * 
      + * The element is locally late; however, we can still treat this case as in 'Normal' above, + * since the IWM has not yet passed the end of the window and the element is ahead of the + * OWM. In effect, we get to 'launder' the locally late element and consider it locally + * on-time, because no downstream computation can observe the difference. + + *
    4. (Maybe late 1) + *
      +   *          |            |
      +   *      [   | E        ] |
      +   *          |            |
      +   *         OWM          IWM
      +   * 
      + * The end-of-window timer may have already fired for this window, and thus an {@code ON_TIME} + * pane may have already been emitted. However, if timer firings have been delayed, then it + * is possible the {@code ON_TIME} pane has not yet been emitted. We can't place an element + * hold since we can't be sure it will be cleared promptly. Thus this element *may* find + * its way into an {@code ON_TIME} pane, but if so it will *not* contribute to that pane's + * timestamp. We may, however, set a garbage collection hold if required. + + *
    5. (Maybe late 2) + *
      +   *               |   |
      +   *      [     E  |   | ]
      +   *               |   |
      +   *              OWM IWM
      +   * 
      + * The end-of-window timer has not yet fired, so this element may still appear in an + * {@code ON_TIME} pane. However, the element is too late to contribute to the output + * watermark hold, and thus won't contribute to the pane's timestamp. We can still place an + * end-of-window hold. + + *
    6. (Maybe late 3) + *
      +   *               |       |
      +   *      [     E  |     ] |
      +   *               |       |
      +   *              OWM     IWM
      +   * 
      + * As in the (Maybe late 2) case, except that here we don't even know whether the end-of-window + * timer has already fired or is about to fire. We can place only the garbage collection hold, + * if required. + + *
    7. (Definitely late) + *
      +   *                       |   |
      +   *      [     E        ] |   |
      +   *                       |   |
      +   *                      OWM IWM
      +   * 
      + * The element is definitely too late to make an {@code ON_TIME} pane. We are too late to + * place an end-of-window hold. We can still place a garbage collection hold if required. + * + *
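 *
 * (Editorial summary, not part of the original source.) In code terms the cases above reduce
 * to: place an element hold at {@code shift(timestamp, window)} when that instant is at or
 * after the OWM and {@code window.maxTimestamp()} is at or after the IWM; otherwise fall back
 * to an end-of-window hold at {@code window.maxTimestamp()} when that is still at or after
 * the IWM; otherwise place a garbage collection hold at
 * {@code window.maxTimestamp().plus(getAllowedLateness())}, but only if the window fires on
 * close ({@code ClosingBehavior.FIRE_ALWAYS}) and the allowed lateness is positive. A null
 * input or output watermark counts as satisfying the corresponding check.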
    + */ + @Nullable + public Instant addHolds(ReduceFn.ProcessValueContext context) { + Instant hold = addElementHold(context); + if (hold == null) { + hold = addEndOfWindowOrGarbageCollectionHolds(context); + } + return hold; + } + + /** + * Return {@code timestamp}, possibly shifted forward in time according to the window + * strategy's output time function. + */ + private Instant shift(Instant timestamp, W window) { + Instant shifted = windowingStrategy.getOutputTimeFn().assignOutputTime(timestamp, window); + if (shifted.isBefore(timestamp)) { + throw new IllegalStateException( + String.format("OutputTimeFn moved element from %s to earlier time %s for window %s", + timestamp, shifted, window)); + } + if (!timestamp.isAfter(window.maxTimestamp()) && shifted.isAfter(window.maxTimestamp())) { + throw new IllegalStateException( + String.format("OutputTimeFn moved element from %s to %s which is beyond end of window %s", + timestamp, shifted, window)); + } + + return shifted; + } + + /** + * Add an element hold if possible. Return instant at which hold was added, or {@literal null} + * if no hold was added. + */ + @Nullable + private Instant addElementHold(ReduceFn.ProcessValueContext context) { + // Give the window function a chance to move the hold timestamp forward to encourage progress. + // (A later hold implies less impediment to the output watermark making progress, which in + // turn encourages end-of-window triggers to fire earlier in following computations.) + Instant elementHold = shift(context.timestamp(), context.window()); + + Instant outputWM = timerInternals.currentOutputWatermarkTime(); + Instant inputWM = timerInternals.currentInputWatermarkTime(); + + // Only add the hold if we can be sure: + // - the backend will be able to respect it + // (ie the hold is at or ahead of the output watermark), AND + // - a timer will be set to clear it by the end of window + // (ie the end of window is at or ahead of the input watermark). + String which; + boolean tooLate; + // TODO: These case labels could be tightened. + // See the case analysis in addHolds above for the motivation. + if (outputWM != null && elementHold.isBefore(outputWM)) { + which = "too late to effect output watermark"; + tooLate = true; + } else if (inputWM != null && context.window().maxTimestamp().isBefore(inputWM)) { + which = "too late for end-of-window timer"; + tooLate = true; + } else { + which = "on time"; + tooLate = false; + context.state().access(elementHoldTag).add(elementHold); + } + WindowTracing.trace( + "WatermarkHold.addHolds: element hold at {} is {} for " + + "key:{}; window:{}; inputWatermark:{}; outputWatermark:{}", + elementHold, which, context.key(), context.window(), inputWM, + outputWM); + + return tooLate ? null : elementHold; + } + + /** + * Add an end-of-window hold or, if too late for that, a garbage collection hold (if required). + * Return the {@link Instant} at which hold was added, or {@literal null} if no hold was added. + * + *

    The end-of-window hold guarantees that an empty {@code ON_TIME} pane can be given + * a timestamp which will not be considered beyond allowed lateness by any downstream computation. + */ + @Nullable + private Instant addEndOfWindowOrGarbageCollectionHolds(ReduceFn.Context context) { + Instant hold = addEndOfWindowHold(context); + if (hold == null) { + hold = addGarbageCollectionHold(context); + } + return hold; + } + + /** + * Add an end-of-window hold. Return the {@link Instant} at which hold was added, + * or {@literal null} if no hold was added. + * + *

    The end-of-window hold guarantees that any empty {@code ON_TIME} pane can be given + * a timestamp which will not be considered beyond allowed lateness by any downstream computation. + */ + @Nullable + private Instant addEndOfWindowHold(ReduceFn.Context context) { + // Only add an end-of-window hold if we can be sure a timer will be set to clear it + // by the end of window (ie the end of window is at or ahead of the input watermark). + Instant outputWM = timerInternals.currentOutputWatermarkTime(); + Instant inputWM = timerInternals.currentInputWatermarkTime(); + String which; + boolean tooLate; + Instant eowHold = context.window().maxTimestamp(); + if (inputWM != null && eowHold.isBefore(inputWM)) { + which = "too late for end-of-window timer"; + tooLate = true; + } else { + which = "on time"; + tooLate = false; + Preconditions.checkState(outputWM == null || !eowHold.isBefore(outputWM), + "End-of-window hold %s cannot be before output watermark %s", eowHold, outputWM); + context.state().access(EXTRA_HOLD_TAG).add(eowHold); + } + WindowTracing.trace( + "WatermarkHold.addEndOfWindowHold: end-of-window hold at {} is {} for " + + "key:{}; window:{}; inputWatermark:{}; outputWatermark:{}", + eowHold, which, context.key(), context.window(), inputWM, + outputWM); + + return tooLate ? null : eowHold; + } + + /** + * Add a garbage collection hold, if required. Return the {@link Instant} at which hold was added, + * or {@literal null} if no hold was added. + * + *

    The garbage collection hold gurantees that any empty final pane can be given + * a timestamp which will not be considered beyond allowed lateness by any downstream + * computation. If we are sure no empty final panes can be emitted then there's no need + * for an additional hold. + */ + @Nullable + private Instant addGarbageCollectionHold(ReduceFn.Context context) { + // Only add a garbage collection hold if we may need to emit an empty pane + // at garbage collection time, and garbage collection time is strictly after the + // end of window. (All non-empty panes will have holds at their output + // time derived from their incoming elements and no additional hold is required.) + if (context.windowingStrategy().getClosingBehavior() == ClosingBehavior.FIRE_ALWAYS + && windowingStrategy.getAllowedLateness().isLongerThan(Duration.ZERO)) { + Instant gcHold = context.window().maxTimestamp().plus(windowingStrategy.getAllowedLateness()); + Instant outputWM = timerInternals.currentOutputWatermarkTime(); + Instant inputWM = timerInternals.currentInputWatermarkTime(); + WindowTracing.trace( + "WatermarkHold.addGarbageCollectionHold: garbage collection at {} hold for " + + "key:{}; window:{}; inputWatermark:{}; outputWatermark:{}", + gcHold, context.key(), context.window(), inputWM, outputWM); + Preconditions.checkState(inputWM == null || !gcHold.isBefore(inputWM), + "Garbage collection hold %s cannot be before input watermark %s", gcHold, inputWM); + context.state().access(EXTRA_HOLD_TAG).add(gcHold); + return gcHold; + } else { + return null; + } + } + + /** + * Prefetch watermark holds in preparation for merging. + */ + public void prefetchOnMerge(MergingStateAccessor state) { + StateMerging.prefetchWatermarks(state, elementHoldTag); + } + + /** + * Updates the watermark hold when windows merge if it is possible the merged value does + * not equal all of the existing holds. For example, if the new window implies a later + * watermark hold, then earlier holds may be released. + */ + public void onMerge(ReduceFn.OnMergeContext context) { + WindowTracing.debug("onMerge: for key:{}; window:{}; inputWatermark:{}; outputWatermark:{}", + context.key(), context.window(), timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + StateMerging.mergeWatermarks(context.state(), elementHoldTag, context.window()); + // If we had a cheap way to determine if we have an element hold then we could + // avoid adding an unnecessary end-of-window or garbage collection hold. + // Simply reading the above merged watermark would impose an additional read for the + // common case that the active window has just one undelying state address window and + // the hold depends on the min of the elemest timestamps. + StateMerging.clear(context.state(), EXTRA_HOLD_TAG); + addEndOfWindowOrGarbageCollectionHolds(context); + } + + /** + * Return (a future for) the earliest hold for {@code context}. Clear all the holds after + * reading, but add/restore an end-of-window or garbage collection hold if required. + * + *

    The returned timestamp is the output timestamp according to the {@link OutputTimeFn} + * from the windowing strategy of this {@link WatermarkHold}, combined across all the non-late + * elements in the current pane. If there is no such value the timestamp is the end + * of the window. + */ + public ReadableState extractAndRelease( + final ReduceFn.Context context, final boolean isFinished) { + WindowTracing.debug( + "extractAndRelease: for key:{}; window:{}; inputWatermark:{}; outputWatermark:{}", + context.key(), context.window(), timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + final WatermarkHoldState elementHoldState = context.state().access(elementHoldTag); + final WatermarkHoldState extraHoldState = context.state().access(EXTRA_HOLD_TAG); + return new ReadableState() { + @Override + public ReadableState readLater() { + elementHoldState.readLater(); + extraHoldState.readLater(); + return this; + } + + @Override + public Instant read() { + // Read both the element and extra holds. + Instant elementHold = elementHoldState.read(); + Instant extraHold = extraHoldState.read(); + Instant hold; + // Find the minimum, accounting for null. + if (elementHold == null) { + hold = extraHold; + } else if (extraHold == null) { + hold = elementHold; + } else if (elementHold.isBefore(extraHold)) { + hold = elementHold; + } else { + hold = extraHold; + } + if (hold == null || hold.isAfter(context.window().maxTimestamp())) { + // If no hold (eg because all elements came in behind the output watermark), or + // the hold was for garbage collection, take the end of window as the result. + WindowTracing.debug( + "WatermarkHold.extractAndRelease.read: clipping from {} to end of window " + + "for key:{}; window:{}", + hold, context.key(), context.window()); + hold = context.window().maxTimestamp(); + } + WindowTracing.debug("WatermarkHold.extractAndRelease.read: clearing for key:{}; window:{}", + context.key(), context.window()); + + // Clear the underlying state to allow the output watermark to progress. + elementHoldState.clear(); + extraHoldState.clear(); + + if (!isFinished) { + // Only need to leave behind an end-of-window or garbage collection hold + // if future elements will be processed. + addEndOfWindowOrGarbageCollectionHolds(context); + } + + return hold; + } + }; + } + + /** + * Clear any remaining holds. + */ + public void clearHolds(ReduceFn.Context context) { + WindowTracing.debug( + "WatermarkHold.clearHolds: For key:{}; window:{}; inputWatermark:{}; outputWatermark:{}", + context.key(), context.window(), timerInternals.currentInputWatermarkTime(), + timerInternals.currentOutputWatermarkTime()); + context.state().access(elementHoldTag).clear(); + context.state().access(EXTRA_HOLD_TAG).clear(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Weighted.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Weighted.java new file mode 100644 index 000000000000..c31ad7f861c4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Weighted.java @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * Interface representing an object that has a weight, in unspecified units. + */ +public interface Weighted { + /** + * Returns the weight of the object. + */ + long getWeight(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WeightedValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WeightedValue.java new file mode 100644 index 000000000000..4a6e84079faa --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WeightedValue.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +/** + * A {@code T} with an accompanying weight. Units are unspecified. + * + * @param the underlying type of object + */ +public final class WeightedValue implements Weighted { + + private final T value; + private final long weight; + + private WeightedValue(T value, long weight) { + this.value = value; + this.weight = weight; + } + + public static WeightedValue of(T value, long weight) { + return new WeightedValue<>(value, weight); + } + + public long getWeight() { + return weight; + } + + public T getValue() { + return value; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowTracing.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowTracing.java new file mode 100644 index 000000000000..6ae2f4206c48 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowTracing.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Logging for window operations. Generally only feasible to enable on hand-picked pipelines. + */ +public final class WindowTracing { + private static final Logger LOG = LoggerFactory.getLogger(WindowTracing.class); + + public static void debug(String format, Object... 
args) { + LOG.debug(format, args); + } + + @SuppressWarnings("unused") + public static void trace(String format, Object... args) { + LOG.trace(format, args); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowedValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowedValue.java new file mode 100644 index 000000000000..1e944e25133e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowedValue.java @@ -0,0 +1,720 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CollectionCoder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * An immutable triple of value, timestamp, and windows. + * + * @param the type of the value + */ +public abstract class WindowedValue { + + protected final T value; + protected final PaneInfo pane; + + /** + * Returns a {@code WindowedValue} with the given value, timestamp, + * and windows. 
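 *
 * (Editorial example, not part of the original source.) Assuming an {@code IntervalWindow w}
 * and an element timestamp {@code ts}:
 * <pre>
 *   WindowedValue<String> wv = WindowedValue.of("hello", ts, w, PaneInfo.NO_FIRING);
 *   WindowedValue<String> inGlobal = WindowedValue.valueInGlobalWindow("hello");
 * </pre>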
+ */ + public static WindowedValue of( + T value, + Instant timestamp, + Collection windows, + PaneInfo pane) { + Preconditions.checkNotNull(pane); + + if (windows.size() == 0 && BoundedWindow.TIMESTAMP_MIN_VALUE.equals(timestamp)) { + return valueInEmptyWindows(value, pane); + } else if (windows.size() == 1) { + return of(value, timestamp, windows.iterator().next(), pane); + } else { + return new TimestampedValueInMultipleWindows<>(value, timestamp, windows, pane); + } + } + + /** + * Returns a {@code WindowedValue} with the given value, timestamp, and window. + */ + public static WindowedValue of( + T value, + Instant timestamp, + BoundedWindow window, + PaneInfo pane) { + Preconditions.checkNotNull(pane); + + boolean isGlobal = GlobalWindow.INSTANCE.equals(window); + if (isGlobal && BoundedWindow.TIMESTAMP_MIN_VALUE.equals(timestamp)) { + return valueInGlobalWindow(value, pane); + } else if (isGlobal) { + return new TimestampedValueInGlobalWindow<>(value, timestamp, pane); + } else { + return new TimestampedValueInSingleWindow<>(value, timestamp, window, pane); + } + } + + /** + * Returns a {@code WindowedValue} with the given value in the {@link GlobalWindow} using the + * default timestamp and pane. + */ + public static WindowedValue valueInGlobalWindow(T value) { + return new ValueInGlobalWindow<>(value, PaneInfo.NO_FIRING); + } + + /** + * Returns a {@code WindowedValue} with the given value in the {@link GlobalWindow} using the + * default timestamp and the specified pane. + */ + public static WindowedValue valueInGlobalWindow(T value, PaneInfo pane) { + return new ValueInGlobalWindow<>(value, pane); + } + + /** + * Returns a {@code WindowedValue} with the given value and timestamp, + * {@code GlobalWindow} and default pane. + */ + public static WindowedValue timestampedValueInGlobalWindow(T value, Instant timestamp) { + if (BoundedWindow.TIMESTAMP_MIN_VALUE.equals(timestamp)) { + return valueInGlobalWindow(value); + } else { + return new TimestampedValueInGlobalWindow<>(value, timestamp, PaneInfo.NO_FIRING); + } + } + + /** + * Returns a {@code WindowedValue} with the given value in no windows, and the default timestamp + * and pane. + */ + public static WindowedValue valueInEmptyWindows(T value) { + return new ValueInEmptyWindows(value, PaneInfo.NO_FIRING); + } + + /** + * Returns a {@code WindowedValue} with the given value in no windows, and the default timestamp + * and the specified pane. + */ + public static WindowedValue valueInEmptyWindows(T value, PaneInfo pane) { + return new ValueInEmptyWindows(value, pane); + } + + private WindowedValue(T value, PaneInfo pane) { + this.value = value; + this.pane = checkNotNull(pane); + } + + /** + * Returns a new {@code WindowedValue} that is a copy of this one, but with a different value, + * which may have a new type {@code NewT}. + */ + public abstract WindowedValue withValue(NewT value); + + /** + * Returns the value of this {@code WindowedValue}. + */ + public T getValue() { + return value; + } + + /** + * Returns the timestamp of this {@code WindowedValue}. + */ + public abstract Instant getTimestamp(); + + /** + * Returns the windows of this {@code WindowedValue}. + */ + public abstract Collection getWindows(); + + /** + * Returns the pane of this {@code WindowedValue} in its window. 
+ */ + public PaneInfo getPane() { + return pane; + } + + @Override + public abstract boolean equals(Object o); + + @Override + public abstract int hashCode(); + + @Override + public abstract String toString(); + + private static final Collection GLOBAL_WINDOWS = + Collections.singletonList(GlobalWindow.INSTANCE); + + /** + * The abstract superclass of WindowedValue representations where + * timestamp == MIN. + */ + private abstract static class MinTimestampWindowedValue + extends WindowedValue { + public MinTimestampWindowedValue(T value, PaneInfo pane) { + super(value, pane); + } + + @Override + public Instant getTimestamp() { + return BoundedWindow.TIMESTAMP_MIN_VALUE; + } + } + + /** + * The representation of a WindowedValue where timestamp == MIN and + * windows == {GlobalWindow}. + */ + private static class ValueInGlobalWindow + extends MinTimestampWindowedValue { + public ValueInGlobalWindow(T value, PaneInfo pane) { + super(value, pane); + } + + @Override + public WindowedValue withValue(NewT value) { + return new ValueInGlobalWindow<>(value, pane); + } + + @Override + public Collection getWindows() { + return GLOBAL_WINDOWS; + } + + @Override + public boolean equals(Object o) { + if (o instanceof ValueInGlobalWindow) { + ValueInGlobalWindow that = (ValueInGlobalWindow) o; + return Objects.equals(that.pane, this.pane) + && Objects.equals(that.value, this.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(value, pane); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("value", value) + .add("pane", pane) + .toString(); + } + } + + /** + * The representation of a WindowedValue where timestamp == MIN and + * windows == {}. + */ + private static class ValueInEmptyWindows + extends MinTimestampWindowedValue { + public ValueInEmptyWindows(T value, PaneInfo pane) { + super(value, pane); + } + + @Override + public WindowedValue withValue(NewT value) { + return new ValueInEmptyWindows<>(value, pane); + } + + @Override + public Collection getWindows() { + return Collections.emptyList(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof ValueInEmptyWindows) { + ValueInEmptyWindows that = (ValueInEmptyWindows) o; + return Objects.equals(that.pane, this.pane) + && Objects.equals(that.value, this.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(value, pane); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("value", value) + .add("pane", pane) + .toString(); + } + } + + /** + * The abstract superclass of WindowedValue representations where + * timestamp is arbitrary. + */ + private abstract static class TimestampedWindowedValue + extends WindowedValue { + protected final Instant timestamp; + + public TimestampedWindowedValue(T value, + Instant timestamp, + PaneInfo pane) { + super(value, pane); + this.timestamp = checkNotNull(timestamp); + } + + @Override + public Instant getTimestamp() { + return timestamp; + } + } + + /** + * The representation of a WindowedValue where timestamp {@code >} + * MIN and windows == {GlobalWindow}. 
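 * (Editorial note, not part of the original source.) Such instances are what
 * {@code timestampedValueInGlobalWindow(value, timestamp)} returns when the timestamp is not
 * {@code BoundedWindow.TIMESTAMP_MIN_VALUE}.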
+ */ + private static class TimestampedValueInGlobalWindow + extends TimestampedWindowedValue { + public TimestampedValueInGlobalWindow(T value, + Instant timestamp, + PaneInfo pane) { + super(value, timestamp, pane); + } + + @Override + public WindowedValue withValue(NewT value) { + return new TimestampedValueInGlobalWindow<>(value, timestamp, pane); + } + + @Override + public Collection getWindows() { + return GLOBAL_WINDOWS; + } + + @Override + public boolean equals(Object o) { + if (o instanceof TimestampedValueInGlobalWindow) { + TimestampedValueInGlobalWindow that = + (TimestampedValueInGlobalWindow) o; + return this.timestamp.isEqual(that.timestamp) // don't compare chronology objects + && Objects.equals(that.pane, this.pane) + && Objects.equals(that.value, this.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(value, pane, timestamp.getMillis()); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("value", value) + .add("timestamp", timestamp) + .add("pane", pane) + .toString(); + } + } + + /** + * The representation of a WindowedValue where timestamp is arbitrary and + * windows == a single non-Global window. + */ + private static class TimestampedValueInSingleWindow + extends TimestampedWindowedValue { + private final BoundedWindow window; + + public TimestampedValueInSingleWindow(T value, + Instant timestamp, + BoundedWindow window, + PaneInfo pane) { + super(value, timestamp, pane); + this.window = checkNotNull(window); + } + + @Override + public WindowedValue withValue(NewT value) { + return new TimestampedValueInSingleWindow<>(value, timestamp, window, pane); + } + + @Override + public Collection getWindows() { + return Collections.singletonList(window); + } + + @Override + public boolean equals(Object o) { + if (o instanceof TimestampedValueInSingleWindow) { + TimestampedValueInSingleWindow that = + (TimestampedValueInSingleWindow) o; + return Objects.equals(that.value, this.value) + && this.timestamp.isEqual(that.timestamp) // don't compare chronology objects + && Objects.equals(that.pane, this.pane) + && Objects.equals(that.window, this.window); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(value, timestamp.getMillis(), pane, window); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("value", value) + .add("timestamp", timestamp) + .add("window", window) + .add("pane", pane) + .toString(); + } + } + + /** + * The representation of a WindowedValue, excluding the special + * cases captured above. 
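 * (Editorial note, not part of the original source.) This is the representation produced by
 * {@code WindowedValue.of(value, timestamp, windows, pane)} when {@code windows} contains
 * more than one window.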
+ */ + private static class TimestampedValueInMultipleWindows + extends TimestampedWindowedValue { + private Collection windows; + + public TimestampedValueInMultipleWindows( + T value, + Instant timestamp, + Collection windows, + PaneInfo pane) { + super(value, timestamp, pane); + this.windows = checkNotNull(windows); + } + + @Override + public WindowedValue withValue(NewT value) { + return new TimestampedValueInMultipleWindows<>(value, timestamp, windows, pane); + } + + @Override + public Collection getWindows() { + return windows; + } + + @Override + public boolean equals(Object o) { + if (o instanceof TimestampedValueInMultipleWindows) { + TimestampedValueInMultipleWindows that = + (TimestampedValueInMultipleWindows) o; + if (this.timestamp.isEqual(that.timestamp) // don't compare chronology objects + && Objects.equals(that.value, this.value) + && Objects.equals(that.pane, this.pane)) { + ensureWindowsAreASet(); + that.ensureWindowsAreASet(); + return that.windows.equals(this.windows); + } + } + return false; + } + + @Override + public int hashCode() { + ensureWindowsAreASet(); + return Objects.hash(value, timestamp.getMillis(), pane, windows); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("value", value) + .add("timestamp", timestamp) + .add("windows", windows) + .add("pane", pane) + .toString(); + } + + private void ensureWindowsAreASet() { + if (!(windows instanceof Set)) { + windows = new LinkedHashSet<>(windows); + } + } + } + + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns the {@code Coder} to use for a {@code WindowedValue}, + * using the given valueCoder and windowCoder. + */ + public static FullWindowedValueCoder getFullCoder( + Coder valueCoder, + Coder windowCoder) { + return FullWindowedValueCoder.of(valueCoder, windowCoder); + } + + /** + * Returns the {@code ValueOnlyCoder} from the given valueCoder. + */ + public static ValueOnlyWindowedValueCoder getValueOnlyCoder(Coder valueCoder) { + return ValueOnlyWindowedValueCoder.of(valueCoder); + } + + /** + * Abstract class for {@code WindowedValue} coder. + */ + public abstract static class WindowedValueCoder + extends StandardCoder> { + final Coder valueCoder; + + WindowedValueCoder(Coder valueCoder) { + this.valueCoder = checkNotNull(valueCoder); + } + + /** + * Returns the value coder. + */ + public Coder getValueCoder() { + return valueCoder; + } + + /** + * Returns a new {@code WindowedValueCoder} that is a copy of this one, + * but with a different value coder. + */ + public abstract WindowedValueCoder withValueCoder(Coder valueCoder); + } + + /** + * Coder for {@code WindowedValue}. + */ + public static class FullWindowedValueCoder extends WindowedValueCoder { + private final Coder windowCoder; + // Precompute and cache the coder for a list of windows. 
+ private final Coder> windowsCoder; + + public static FullWindowedValueCoder of( + Coder valueCoder, + Coder windowCoder) { + return new FullWindowedValueCoder<>(valueCoder, windowCoder); + } + + @JsonCreator + public static FullWindowedValueCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + checkArgument(components.size() == 2, + "Expecting 2 components, got " + components.size()); + @SuppressWarnings("unchecked") + Coder window = (Coder) components.get(1); + return of(components.get(0), window); + } + + FullWindowedValueCoder(Coder valueCoder, + Coder windowCoder) { + super(valueCoder); + this.windowCoder = checkNotNull(windowCoder); + // It's not possible to statically type-check correct use of the + // windowCoder (we have to ensure externally that we only get + // windows of the class handled by windowCoder), so type + // windowsCoder in a way that makes encode() and decode() work + // right, and cast the window type away here. + @SuppressWarnings({"unchecked", "rawtypes"}) + Coder> collectionCoder = + (Coder) CollectionCoder.of(this.windowCoder); + this.windowsCoder = collectionCoder; + } + + public Coder getWindowCoder() { + return windowCoder; + } + + public Coder> getWindowsCoder() { + return windowsCoder; + } + + @Override + public WindowedValueCoder withValueCoder(Coder valueCoder) { + return new FullWindowedValueCoder<>(valueCoder, windowCoder); + } + + @Override + public void encode(WindowedValue windowedElem, + OutputStream outStream, + Context context) + throws CoderException, IOException { + Context nestedContext = context.nested(); + valueCoder.encode(windowedElem.getValue(), outStream, nestedContext); + InstantCoder.of().encode( + windowedElem.getTimestamp(), outStream, nestedContext); + windowsCoder.encode(windowedElem.getWindows(), outStream, nestedContext); + PaneInfoCoder.INSTANCE.encode(windowedElem.getPane(), outStream, context); + } + + @Override + public WindowedValue decode(InputStream inStream, Context context) + throws CoderException, IOException { + Context nestedContext = context.nested(); + T value = valueCoder.decode(inStream, nestedContext); + Instant timestamp = InstantCoder.of().decode(inStream, nestedContext); + Collection windows = + windowsCoder.decode(inStream, nestedContext); + PaneInfo pane = PaneInfoCoder.INSTANCE.decode(inStream, nestedContext); + return WindowedValue.of(value, timestamp, windows, pane); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "FullWindowedValueCoder requires a deterministic valueCoder", + valueCoder); + verifyDeterministic( + "FullWindowedValueCoder requires a deterministic windowCoder", + windowCoder); + } + + @Override + public void registerByteSizeObserver(WindowedValue value, + ElementByteSizeObserver observer, + Context context) throws Exception { + valueCoder.registerByteSizeObserver(value.getValue(), observer, context); + InstantCoder.of().registerByteSizeObserver(value.getTimestamp(), observer, context); + windowsCoder.registerByteSizeObserver(value.getWindows(), observer, context); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_WRAPPER, true); + return result; + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return Arrays.>asList(valueCoder, windowCoder); + } + } + + /** + * Coder for {@code WindowedValue}. + * + *
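 *
 * (Editorial example, not part of the original source.) A sketch of obtaining both coder
 * flavors, assuming a string value coder and interval windows:
 * <pre>
 *   Coder<WindowedValue<String>> full =
 *       WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder());
 *   Coder<WindowedValue<String>> valueOnly =
 *       WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
 * </pre>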
<p>
    A {@code ValueOnlyWindowedValueCoder} only encodes and decodes the value. It drops + * timestamp and windows for encoding, and uses defaults timestamp, and windows for decoding. + */ + public static class ValueOnlyWindowedValueCoder extends WindowedValueCoder { + public static ValueOnlyWindowedValueCoder of( + Coder valueCoder) { + return new ValueOnlyWindowedValueCoder<>(valueCoder); + } + + @JsonCreator + public static ValueOnlyWindowedValueCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + checkArgument(components.size() == 1, "Expecting 1 component, got " + components.size()); + return of(components.get(0)); + } + + ValueOnlyWindowedValueCoder(Coder valueCoder) { + super(valueCoder); + } + + @Override + public WindowedValueCoder withValueCoder(Coder valueCoder) { + return new ValueOnlyWindowedValueCoder<>(valueCoder); + } + + @Override + public void encode(WindowedValue windowedElem, OutputStream outStream, Context context) + throws CoderException, IOException { + valueCoder.encode(windowedElem.getValue(), outStream, context); + } + + @Override + public WindowedValue decode(InputStream inStream, Context context) + throws CoderException, IOException { + T value = valueCoder.decode(inStream, context); + return WindowedValue.valueInGlobalWindow(value); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "ValueOnlyWindowedValueCoder requires a deterministic valueCoder", + valueCoder); + } + + @Override + public void registerByteSizeObserver( + WindowedValue value, ElementByteSizeObserver observer, Context context) + throws Exception { + valueCoder.registerByteSizeObserver(value.getValue(), observer, context); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + addBoolean(result, PropertyNames.IS_WRAPPER, true); + return result; + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(valueCoder); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java new file mode 100644 index 000000000000..12fcd532a275 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.Collection; + +/** + * Interface that may be required by some (internal) {@code DoFn}s to implement windowing. It should + * not be necessary for general user code to interact with this at all. + * + *
<p>
    This interface should be provided by runner implementors to support windowing on their runner. + * + * @param input type + * @param output type + */ +public interface WindowingInternals { + + /** + * Unsupported state internals. The key type is unknown. It is up to the user to use the + * correct type of key. + */ + StateInternals stateInternals(); + + /** + * Output the value at the specified timestamp in the listed windows. + */ + void outputWindowedValue(OutputT output, Instant timestamp, + Collection windows, PaneInfo pane); + + /** + * Return the timer manager provided by the underlying system, or null if Timers need + * to be emulated. + */ + TimerInternals timerInternals(); + + /** + * Access the windows the element is being processed in without "exploding" it. + */ + Collection windows(); + + /** + * Access the pane of the current window(s). + */ + PaneInfo pane(); + + /** + * Write the given {@link PCollectionView} data to a location accessible by other workers. + */ + void writePCollectionViewData( + TupleTag tag, + Iterable> data, + Coder elemCoder) throws IOException; + + /** + * Return the value of the side input for the window of a main input element. + */ + T sideInput(PCollectionView view, BoundedWindow mainInputWindow); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingStrategy.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingStrategy.java new file mode 100644 index 000000000000..c167b8c0cdff --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingStrategy.java @@ -0,0 +1,268 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.common.base.MoreObjects; + +import org.joda.time.Duration; + +import java.io.Serializable; +import java.util.Objects; + +/** + * A {@code WindowingStrategy} describes the windowing behavior for a specific collection of values. + * It has both a {@link WindowFn} describing how elements are assigned to windows and a + * {@link Trigger} that controls when output is produced for each window. 
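+ *
+ * <p>For illustration only, a hedged sketch of how a strategy might be assembled from the
+ * factory and {@code with*} methods declared below (the one-minute allowed lateness is an
+ * arbitrary, hypothetical choice):
+ * <pre>{@code
+ * WindowingStrategy<Object, GlobalWindow> strategy =
+ *     WindowingStrategy.of(new GlobalWindows())
+ *         .withTrigger(DefaultTrigger.of())
+ *         .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES)
+ *         .withAllowedLateness(Duration.standardMinutes(1));
+ * }</pre>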
+ * + * @param type of elements being windowed + * @param {@link BoundedWindow} subclass used to represent the + * windows used by this {@code WindowingStrategy} + */ +public class WindowingStrategy implements Serializable { + + /** + * The accumulation modes that can be used with windowing. + */ + public enum AccumulationMode { + DISCARDING_FIRED_PANES, + ACCUMULATING_FIRED_PANES; + } + + private static final Duration DEFAULT_ALLOWED_LATENESS = Duration.ZERO; + private static final WindowingStrategy DEFAULT = of(new GlobalWindows()); + + private final WindowFn windowFn; + private final OutputTimeFn outputTimeFn; + private final ExecutableTrigger trigger; + private final AccumulationMode mode; + private final Duration allowedLateness; + private final ClosingBehavior closingBehavior; + private final boolean triggerSpecified; + private final boolean modeSpecified; + private final boolean allowedLatenessSpecified; + private final boolean outputTimeFnSpecified; + + private WindowingStrategy( + WindowFn windowFn, + ExecutableTrigger trigger, boolean triggerSpecified, + AccumulationMode mode, boolean modeSpecified, + Duration allowedLateness, boolean allowedLatenessSpecified, + OutputTimeFn outputTimeFn, boolean outputTimeFnSpecified, + ClosingBehavior closingBehavior) { + this.windowFn = windowFn; + this.trigger = trigger; + this.triggerSpecified = triggerSpecified; + this.mode = mode; + this.modeSpecified = modeSpecified; + this.allowedLateness = allowedLateness; + this.allowedLatenessSpecified = allowedLatenessSpecified; + this.closingBehavior = closingBehavior; + this.outputTimeFn = outputTimeFn; + this.outputTimeFnSpecified = outputTimeFnSpecified; + } + + /** + * Return a fully specified, default windowing strategy. + */ + public static WindowingStrategy globalDefault() { + return DEFAULT; + } + + public static WindowingStrategy of( + WindowFn windowFn) { + return new WindowingStrategy<>(windowFn, + ExecutableTrigger.create(DefaultTrigger.of()), false, + AccumulationMode.DISCARDING_FIRED_PANES, false, + DEFAULT_ALLOWED_LATENESS, false, + windowFn.getOutputTimeFn(), false, + ClosingBehavior.FIRE_IF_NON_EMPTY); + } + + public WindowFn getWindowFn() { + return windowFn; + } + + public ExecutableTrigger getTrigger() { + return trigger; + } + + public boolean isTriggerSpecified() { + return triggerSpecified; + } + + public Duration getAllowedLateness() { + return allowedLateness; + } + + public boolean isAllowedLatenessSpecified() { + return allowedLatenessSpecified; + } + + public AccumulationMode getMode() { + return mode; + } + + public boolean isModeSpecified() { + return modeSpecified; + } + + public ClosingBehavior getClosingBehavior() { + return closingBehavior; + } + + public OutputTimeFn getOutputTimeFn() { + return outputTimeFn; + } + + public boolean isOutputTimeFnSpecified() { + return outputTimeFnSpecified; + } + + /** + * Returns a {@link WindowingStrategy} identical to {@code this} but with the trigger set to + * {@code wildcardTrigger}. + */ + public WindowingStrategy withTrigger(Trigger wildcardTrigger) { + @SuppressWarnings("unchecked") + Trigger typedTrigger = (Trigger) wildcardTrigger; + return new WindowingStrategy( + windowFn, + ExecutableTrigger.create(typedTrigger), true, + mode, modeSpecified, + allowedLateness, allowedLatenessSpecified, + outputTimeFn, outputTimeFnSpecified, + closingBehavior); + } + + /** + * Returns a {@link WindowingStrategy} identical to {@code this} but with the accumulation mode + * set to {@code mode}. 
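+ *
+ * <p>For example (illustrative), {@code withMode(AccumulationMode.ACCUMULATING_FIRED_PANES)}
+ * returns a copy for which {@code isModeSpecified()} reports {@code true}.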
+ */ + public WindowingStrategy withMode(AccumulationMode mode) { + return new WindowingStrategy( + windowFn, + trigger, triggerSpecified, + mode, true, + allowedLateness, allowedLatenessSpecified, + outputTimeFn, outputTimeFnSpecified, + closingBehavior); + } + + /** + * Returns a {@link WindowingStrategy} identical to {@code this} but with the window function + * set to {@code wildcardWindowFn}. + */ + public WindowingStrategy withWindowFn(WindowFn wildcardWindowFn) { + @SuppressWarnings("unchecked") + WindowFn typedWindowFn = (WindowFn) wildcardWindowFn; + + // The onus of type correctness falls on the callee. + @SuppressWarnings("unchecked") + OutputTimeFn newOutputTimeFn = (OutputTimeFn) + (outputTimeFnSpecified ? outputTimeFn : typedWindowFn.getOutputTimeFn()); + + return new WindowingStrategy( + typedWindowFn, + trigger, triggerSpecified, + mode, modeSpecified, + allowedLateness, allowedLatenessSpecified, + newOutputTimeFn, outputTimeFnSpecified, + closingBehavior); + } + + /** + * Returns a {@link WindowingStrategy} identical to {@code this} but with the allowed lateness + * set to {@code allowedLateness}. + */ + public WindowingStrategy withAllowedLateness(Duration allowedLateness) { + return new WindowingStrategy( + windowFn, + trigger, triggerSpecified, + mode, modeSpecified, + allowedLateness, true, + outputTimeFn, outputTimeFnSpecified, + closingBehavior); + } + + public WindowingStrategy withClosingBehavior(ClosingBehavior closingBehavior) { + return new WindowingStrategy( + windowFn, + trigger, triggerSpecified, + mode, modeSpecified, + allowedLateness, allowedLatenessSpecified, + outputTimeFn, outputTimeFnSpecified, + closingBehavior); + } + + @Experimental(Experimental.Kind.OUTPUT_TIME) + public WindowingStrategy withOutputTimeFn(OutputTimeFn outputTimeFn) { + + @SuppressWarnings("unchecked") + OutputTimeFn typedOutputTimeFn = (OutputTimeFn) outputTimeFn; + + return new WindowingStrategy( + windowFn, + trigger, triggerSpecified, + mode, modeSpecified, + allowedLateness, allowedLatenessSpecified, + typedOutputTimeFn, true, + closingBehavior); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("windowFn", windowFn) + .add("allowedLateness", allowedLateness) + .add("trigger", trigger) + .add("accumulationMode", mode) + .add("outputTimeFn", outputTimeFn) + .toString(); + } + + @Override + public boolean equals(Object object) { + if (!(object instanceof WindowingStrategy)) { + return false; + } + WindowingStrategy other = (WindowingStrategy) object; + return + isTriggerSpecified() == other.isTriggerSpecified() + && isAllowedLatenessSpecified() == other.isAllowedLatenessSpecified() + && isModeSpecified() == other.isModeSpecified() + && getMode().equals(other.getMode()) + && getAllowedLateness().equals(other.getAllowedLateness()) + && getClosingBehavior().equals(other.getClosingBehavior()) + && getTrigger().equals(other.getTrigger()) + && getWindowFn().equals(other.getWindowFn()); + } + + @Override + public int hashCode() { + return Objects.hash(triggerSpecified, allowedLatenessSpecified, modeSpecified, + windowFn, trigger, mode, allowedLateness, closingBehavior); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ZipFiles.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ZipFiles.java new file mode 100644 index 000000000000..773b65fb9822 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ZipFiles.java @@ -0,0 +1,294 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.FluentIterable; +import com.google.common.collect.Iterators; +import com.google.common.io.ByteSource; +import com.google.common.io.CharSource; +import com.google.common.io.Closer; +import com.google.common.io.Files; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Iterator; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; +import java.util.zip.ZipOutputStream; + +/** + * Functions for zipping a directory (including a subdirectory) into a ZIP-file + * or unzipping it again. + */ +public final class ZipFiles { + private ZipFiles() {} + + /** + * Returns a new {@link ByteSource} for reading the contents of the given + * entry in the given zip file. + */ + static ByteSource asByteSource(ZipFile file, ZipEntry entry) { + return new ZipEntryByteSource(file, entry); + } + + /** + * Returns a new {@link CharSource} for reading the contents of the given + * entry in the given zip file as text using the given charset. + */ + static CharSource asCharSource( + ZipFile file, ZipEntry entry, Charset charset) { + return asByteSource(file, entry).asCharSource(charset); + } + + private static final class ZipEntryByteSource extends ByteSource { + + private final ZipFile file; + private final ZipEntry entry; + + ZipEntryByteSource(ZipFile file, ZipEntry entry) { + this.file = checkNotNull(file); + this.entry = checkNotNull(entry); + } + + @Override + public InputStream openStream() throws IOException { + return file.getInputStream(entry); + } + + // TODO: implement size() to try calling entry.getSize()? + + @Override + public String toString() { + return "ZipFiles.asByteSource(" + file + ", " + entry + ")"; + } + } + + /** + * Returns a {@link FluentIterable} of all the entries in the given zip file. + */ + // unmodifiable Iterator can be safely cast + // to Iterator + @SuppressWarnings("unchecked") + static FluentIterable entries(final ZipFile file) { + checkNotNull(file); + return new FluentIterable() { + @Override + public Iterator iterator() { + return (Iterator) Iterators.forEnumeration(file.entries()); + } + }; + } + + /** + * Unzips the zip file specified by the path and creates the directory structure inside + * the target directory. Refuses to unzip files that refer to a parent directory, for security + * reasons. + * + * @param zipFile the source zip-file to unzip + * @param targetDirectory the directory to unzip to. If the zip-file contains + * any subdirectories, they will be created within our target directory. + * @throws IOException the unzipping failed, e.g. 
because the output was not writable, the {@code + * zipFile} was not readable, or contains an illegal entry (contains "..", pointing outside + * the target directory) + * @throws IllegalArgumentException the target directory is not a valid directory (e.g. does not + * exist, or is a file instead of a directory) + */ + static void unzipFile( + File zipFile, + File targetDirectory) throws IOException { + checkNotNull(zipFile); + checkNotNull(targetDirectory); + checkArgument( + targetDirectory.isDirectory(), + "%s is not a valid directory", + targetDirectory.getAbsolutePath()); + final ZipFile zipFileObj = new ZipFile(zipFile); + try { + for (ZipEntry entry : entries(zipFileObj)) { + checkName(entry.getName()); + File targetFile = new File(targetDirectory, entry.getName()); + if (entry.isDirectory()) { + if (!targetFile.isDirectory() && !targetFile.mkdirs()) { + throw new IOException( + "Failed to create directory: " + targetFile.getAbsolutePath()); + } + } else { + File parentFile = targetFile.getParentFile(); + if (!parentFile.isDirectory()) { + if (!parentFile.mkdirs()) { + throw new IOException( + "Failed to create directory: " + + parentFile.getAbsolutePath()); + } + } + // Write the file to the destination. + asByteSource(zipFileObj, entry).copyTo(Files.asByteSink(targetFile)); + } + } + } finally { + zipFileObj.close(); + } + } + + /** + * Checks that the given entry name is legal for unzipping: if it contains + * ".." as a name element, it could cause the entry to be unzipped outside + * the directory we're unzipping to. + * + * @throws IOException if the name is illegal + */ + private static void checkName(String name) throws IOException { + // First just check whether the entry name string contains "..". + // This should weed out the the vast majority of entries, which will not + // contain "..". + if (name.contains("..")) { + // If the string does contain "..", break it down into its actual name + // elements to ensure it actually contains ".." as a name, not just a + // name like "foo..bar" or even "foo..", which should be fine. + File file = new File(name); + while (file != null) { + if (file.getName().equals("..")) { + throw new IOException("Cannot unzip file containing an entry with " + + "\"..\" in the name: " + name); + } + file = file.getParentFile(); + } + } + } + + /** + * Zips an entire directory specified by the path. + * + * @param sourceDirectory the directory to read from. This directory and all + * subdirectories will be added to the zip-file. The path within the zip + * file is relative to the directory given as parameter, not absolute. + * @param zipFile the zip-file to write to. + * @throws IOException the zipping failed, e.g. because the input was not + * readable. + */ + static void zipDirectory( + File sourceDirectory, + File zipFile) throws IOException { + checkNotNull(sourceDirectory); + checkNotNull(zipFile); + checkArgument( + sourceDirectory.isDirectory(), + "%s is not a valid directory", + sourceDirectory.getAbsolutePath()); + checkArgument( + !zipFile.exists(), + "%s does already exist, files are not being overwritten", + zipFile.getAbsolutePath()); + Closer closer = Closer.create(); + try { + OutputStream outputStream = closer.register(new BufferedOutputStream( + new FileOutputStream(zipFile))); + zipDirectory(sourceDirectory, outputStream); + } catch (Throwable t) { + throw closer.rethrow(t); + } finally { + closer.close(); + } + } + + /** + * Zips an entire directory specified by the path. + * + * @param sourceDirectory the directory to read from. 
This directory and all + * subdirectories will be added to the zip-file. The path within the zip + * file is relative to the directory given as parameter, not absolute. + * @param outputStream the stream to write the zip-file to. This method does not close + * outputStream. + * @throws IOException the zipping failed, e.g. because the input was not + * readable. + */ + static void zipDirectory( + File sourceDirectory, + OutputStream outputStream) throws IOException { + checkNotNull(sourceDirectory); + checkNotNull(outputStream); + checkArgument( + sourceDirectory.isDirectory(), + "%s is not a valid directory", + sourceDirectory.getAbsolutePath()); + ZipOutputStream zos = new ZipOutputStream(outputStream); + for (File file : sourceDirectory.listFiles()) { + zipDirectoryInternal(file, "", zos); + } + zos.finish(); + } + + /** + * Private helper function for zipping files. This one goes recursively + * through the input directory and all of its subdirectories and adds the + * single zip entries. + * + * @param inputFile the file or directory to be added to the zip file + * @param directoryName the string-representation of the parent directory + * name. Might be an empty name, or a name containing multiple directory + * names separated by "/". The directory name must be a valid name + * according to the file system limitations. The directory name should be + * empty or should end in "/". + * @param zos the zipstream to write to + * @throws IOException the zipping failed, e.g. because the output was not + * writeable. + */ + private static void zipDirectoryInternal( + File inputFile, + String directoryName, + ZipOutputStream zos) throws IOException { + String entryName = directoryName + inputFile.getName(); + if (inputFile.isDirectory()) { + entryName += "/"; + + // We are hitting a sub-directory. Recursively add children to zip in deterministic, + // sorted order. + File[] childFiles = inputFile.listFiles(); + if (childFiles.length > 0) { + Arrays.sort(childFiles); + // loop through the directory content, and zip the files + for (File file : childFiles) { + zipDirectoryInternal(file, entryName, zos); + } + + // Since this directory has children, exit now without creating a zipentry specific to + // this directory. The entry for a non-entry directory is incompatible with certain + // implementations of unzip. + return; + } + } + + // Put the zip-entry for this file or empty directory into the zipoutputstream. + ZipEntry entry = new ZipEntry(entryName); + entry.setTime(inputFile.lastModified()); + zos.putNextEntry(entry); + + // Copy file contents into zipoutput stream. + if (inputFile.isFile()) { + Files.asByteSource(inputFile).copyTo(zos); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Counter.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Counter.java new file mode 100644 index 000000000000..2c1985c0535b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Counter.java @@ -0,0 +1,1103 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.AND; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.OR; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.util.concurrent.AtomicDouble; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import javax.annotation.Nullable; + +/** + * A Counter enables the aggregation of a stream of values over time. The + * cumulative aggregate value is updated as new values are added, or it can be + * reset to a new value. Multiple kinds of aggregation are supported depending + * on the type of the counter. + * + *
<p>
    Counters compare using value equality of their name, kind, and + * cumulative value. Equal counters should have equal toString()s. + * + * @param the type of values aggregated by this counter + */ +public abstract class Counter { + /** + * Possible kinds of counter aggregation. + */ + public static enum AggregationKind { + + /** + * Computes the sum of all added values. + * Applicable to {@link Integer}, {@link Long}, and {@link Double} values. + */ + SUM, + + /** + * Computes the maximum value of all added values. + * Applicable to {@link Integer}, {@link Long}, and {@link Double} values. + */ + MAX, + + /** + * Computes the minimum value of all added values. + * Applicable to {@link Integer}, {@link Long}, and {@link Double} values. + */ + MIN, + + /** + * Computes the arithmetic mean of all added values. Applicable to + * {@link Integer}, {@link Long}, and {@link Double} values. + */ + MEAN, + + /** + * Computes boolean AND over all added values. + * Applicable only to {@link Boolean} values. + */ + AND, + + /** + * Computes boolean OR over all added values. Applicable only to + * {@link Boolean} values. + */ + OR + // TODO: consider adding VECTOR_SUM, HISTOGRAM, KV_SET, PRODUCT, TOP. + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Integer}, values + * according to the desired aggregation kind. The supported aggregation kinds + * are {@link AggregationKind#SUM}, {@link AggregationKind#MIN}, + * {@link AggregationKind#MAX}, and {@link AggregationKind#MEAN}. + * This is a convenience wrapper over a + * {@link Counter} implementation that aggregates {@link Long} values. This is + * useful when the application handles (boxed) {@link Integer} values that + * are not readily convertible to the (boxed) {@link Long} values otherwise + * expected by the {@link Counter} implementation aggregating {@link Long} + * values. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter ints(String name, AggregationKind kind) { + return new IntegerCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Long} values + * according to the desired aggregation kind. The supported aggregation kinds + * are {@link AggregationKind#SUM}, {@link AggregationKind#MIN}, + * {@link AggregationKind#MAX}, and {@link AggregationKind#MEAN}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter longs(String name, AggregationKind kind) { + return new LongCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Double} values + * according to the desired aggregation kind. The supported aggregation kinds + * are {@link AggregationKind#SUM}, {@link AggregationKind#MIN}, + * {@link AggregationKind#MAX}, and {@link AggregationKind#MEAN}. 
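+ *
+ * <p>Illustrative sketch only (the counter name is hypothetical):
+ * <pre>{@code
+ * Counter<Double> latencyMs = Counter.doubles("latency-ms", AggregationKind.MEAN);
+ * latencyMs.addValue(12.5).addValue(7.5);
+ * Counter.CounterMean<Double> mean = latencyMs.getMean();  // aggregate 20.0 over a count of 2
+ * }</pre>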
+ * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter doubles(String name, AggregationKind kind) { + return new DoubleCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link Boolean} values + * according to the desired aggregation kind. The only supported aggregation + * kinds are {@link AggregationKind#AND} and {@link AggregationKind#OR}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + public static Counter booleans(String name, AggregationKind kind) { + return new BooleanCounter(name, kind); + } + + /** + * Constructs a new {@link Counter} that aggregates {@link String} values + * according to the desired aggregation kind. The only supported aggregation + * kind is {@link AggregationKind#MIN} and {@link AggregationKind#MAX}. + * + * @param name the name of the new counter + * @param kind the new counter's aggregation kind + * @return the newly constructed Counter + * @throws IllegalArgumentException if the aggregation kind is not supported + */ + @SuppressWarnings("unused") + private static Counter strings(String name, AggregationKind kind) { + return new StringCounter(name, kind); + } + + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Adds a new value to the aggregation stream. Returns this (to allow method + * chaining). + */ + public abstract Counter addValue(T value); + + /** + * Resets the aggregation stream to this new value. This aggregator must not + * be a MEAN aggregator. Returns this (to allow method chaining). + */ + public abstract Counter resetToValue(T value); + + /** + * Resets the aggregation stream to this new value. Returns this (to allow + * method chaining). The value of elementCount must be non-negative, and this + * aggregator must be a MEAN aggregator. + */ + public abstract Counter resetMeanToValue(long elementCount, T value); + + /** + * Resets the counter's delta value to have no values accumulated and returns + * the value of the delta prior to the reset. + * + * @return the aggregate delta at the time this method is called + */ + public abstract T getAndResetDelta(); + + /** + * Resets the counter's delta value to have no values accumulated and returns + * the value of the delta prior to the reset, for a MEAN counter. + * + * @return the mean delta t the time this method is called + */ + public abstract CounterMean getAndResetMeanDelta(); + + /** + * Returns the counter's name. + */ + public String getName() { + return name; + } + + /** + * Returns the counter's aggregation kind. + */ + public AggregationKind getKind() { + return kind; + } + + /** + * Returns the counter's type. + */ + public Class getType() { + return new TypeDescriptor(getClass()) {}.getRawType(); + } + + /** + * Returns the aggregated value, or the sum for MEAN aggregation, either + * total or, if delta, since the last update extraction or resetDelta. + */ + public abstract T getAggregate(); + + /** + * The mean value of a {@code Counter}, represented as an aggregate value and + * a count. + * + * @param the type of the aggregate + */ + public static interface CounterMean { + /** + * Gets the aggregate value of this {@code CounterMean}. 
+ */ + T getAggregate(); + + /** + * Gets the count of this {@code CounterMean}. + */ + long getCount(); + } + + /** + * Returns the mean in the form of a CounterMean, or null if this is not a + * MEAN counter. + */ + @Nullable + public abstract CounterMean getMean(); + + /** + * Returns a string representation of the Counter. Useful for debugging logs. + * Example return value: "ElementCount:SUM(15)". + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getName()); + sb.append(":"); + sb.append(getKind()); + sb.append("("); + switch (kind) { + case SUM: + case MAX: + case MIN: + case AND: + case OR: + sb.append(getAggregate()); + break; + case MEAN: + sb.append(getMean()); + break; + default: + throw illegalArgumentException(); + } + sb.append(")"); + + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof Counter) { + Counter that = (Counter) o; + if (this.name.equals(that.name) && this.kind == that.kind + && this.getClass().equals(that.getClass())) { + if (kind == MEAN) { + CounterMean thisMean = this.getMean(); + CounterMean thatMean = that.getMean(); + return thisMean == thatMean + || (Objects.equals(thisMean.getAggregate(), thatMean.getAggregate()) + && thisMean.getCount() == thatMean.getCount()); + } else { + return Objects.equals(this.getAggregate(), that.getAggregate()); + } + } + } + return false; + } + + @Override + public int hashCode() { + if (kind == MEAN) { + CounterMean mean = getMean(); + return Objects.hash(getClass(), name, kind, mean.getAggregate(), mean.getCount()); + } else { + return Objects.hash(getClass(), name, kind, getAggregate()); + } + } + + /** + * Returns whether this Counter is compatible with that Counter. If + * so, they can be merged into a single Counter. + */ + public boolean isCompatibleWith(Counter that) { + return this.name.equals(that.name) + && this.kind == that.kind + && this.getClass().equals(that.getClass()); + } + + /** + * Merges this counter with the provided counter, returning this counter with the combined value + * of both counters. This may reset the delta of this counter. + * + * @throws IllegalArgumentException if the provided Counter is not compatible with this Counter + */ + public abstract Counter merge(Counter that); + + ////////////////////////////////////////////////////////////////////////////// + + /** The name of this counter. */ + protected final String name; + + /** The kind of aggregation function to apply to this counter. */ + protected final AggregationKind kind; + + protected Counter(String name, AggregationKind kind) { + this.name = name; + this.kind = kind; + } + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Implements a {@link Counter} for {@link Long} values. + */ + private static class LongCounter extends Counter { + private final AtomicLong aggregate; + private final AtomicLong deltaAggregate; + private final AtomicReference mean; + private final AtomicReference deltaMean; + + /** Initializes a new {@link Counter} for {@link Long} values. 
*/ + private LongCounter(String name, AggregationKind kind) { + super(name, kind); + switch (kind) { + case MEAN: + mean = new AtomicReference<>(); + deltaMean = new AtomicReference<>(); + getAndResetMeanDelta(); + mean.set(deltaMean.get()); + aggregate = deltaAggregate = null; + break; + case SUM: + case MAX: + case MIN: + aggregate = new AtomicLong(); + deltaAggregate = new AtomicLong(); + getAndResetDelta(); + aggregate.set(deltaAggregate.get()); + mean = deltaMean = null; + break; + default: + throw illegalArgumentException(); + } + } + + @Override + public LongCounter addValue(Long value) { + switch (kind) { + case SUM: + aggregate.addAndGet(value); + deltaAggregate.addAndGet(value); + break; + case MEAN: + addToMeanAndSet(value, mean); + addToMeanAndSet(value, deltaMean); + break; + case MAX: + maxAndSet(value, aggregate); + maxAndSet(value, deltaAggregate); + break; + case MIN: + minAndSet(value, aggregate); + minAndSet(value, deltaAggregate); + break; + default: + throw illegalArgumentException(); + } + return this; + } + + private void minAndSet(Long value, AtomicLong target) { + long current; + long update; + do { + current = target.get(); + update = Math.min(value, current); + } while (update < current && !target.compareAndSet(current, update)); + } + + private void maxAndSet(Long value, AtomicLong target) { + long current; + long update; + do { + current = target.get(); + update = Math.max(value, current); + } while (update > current && !target.compareAndSet(current, update)); + } + + private void addToMeanAndSet(Long value, AtomicReference target) { + LongCounterMean current; + LongCounterMean update; + do { + current = target.get(); + update = new LongCounterMean(current.getAggregate() + value, current.getCount() + 1L); + } while (!target.compareAndSet(current, update)); + } + + @Override + public Long getAggregate() { + if (kind != MEAN) { + return aggregate.get(); + } else { + return getMean().getAggregate(); + } + } + + @Override + public Long getAndResetDelta() { + switch (kind) { + case SUM: + return deltaAggregate.getAndSet(0L); + case MAX: + return deltaAggregate.getAndSet(Long.MIN_VALUE); + case MIN: + return deltaAggregate.getAndSet(Long.MAX_VALUE); + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter resetToValue(Long value) { + if (kind == MEAN) { + throw illegalArgumentException(); + } + aggregate.set(value); + deltaAggregate.set(value); + return this; + } + + @Override + public Counter resetMeanToValue(long elementCount, Long value) { + if (kind != MEAN) { + throw illegalArgumentException(); + } + if (elementCount < 0) { + throw new IllegalArgumentException("elementCount must be non-negative"); + } + LongCounterMean counterMean = new LongCounterMean(value, elementCount); + mean.set(counterMean); + deltaMean.set(counterMean); + return this; + } + + @Override + public CounterMean getAndResetMeanDelta() { + if (kind != MEAN) { + throw illegalArgumentException(); + } + return deltaMean.getAndSet(new LongCounterMean(0L, 0L)); + } + + @Override + @Nullable + public CounterMean getMean() { + if (kind != MEAN) { + throw illegalArgumentException(); + } + return mean.get(); + } + + @Override + public Counter merge(Counter that) { + checkArgument(this.isCompatibleWith(that), "Counters %s and %s are incompatible", this, that); + switch (kind) { + case SUM: + case MIN: + case MAX: + return addValue(that.getAggregate()); + case MEAN: + CounterMean thisCounterMean = this.getMean(); + CounterMean thatCounterMean = that.getMean(); + return 
resetMeanToValue( + thisCounterMean.getCount() + thatCounterMean.getCount(), + thisCounterMean.getAggregate() + thatCounterMean.getAggregate()); + default: + throw illegalArgumentException(); + } + } + + private static class LongCounterMean implements CounterMean { + private final long aggregate; + private final long count; + + public LongCounterMean(long aggregate, long count) { + this.aggregate = aggregate; + this.count = count; + } + + @Override + public Long getAggregate() { + return aggregate; + } + + @Override + public long getCount() { + return count; + } + + @Override + public String toString() { + return aggregate + "/" + count; + } + } + } + + /** + * Implements a {@link Counter} for {@link Double} values. + */ + private static class DoubleCounter extends Counter { + AtomicDouble aggregate; + AtomicDouble deltaAggregate; + AtomicReference mean; + AtomicReference deltaMean; + + /** Initializes a new {@link Counter} for {@link Double} values. */ + private DoubleCounter(String name, AggregationKind kind) { + super(name, kind); + switch (kind) { + case MEAN: + aggregate = deltaAggregate = null; + mean = new AtomicReference<>(); + deltaMean = new AtomicReference<>(); + getAndResetMeanDelta(); + mean.set(deltaMean.get()); + break; + case SUM: + case MAX: + case MIN: + mean = deltaMean = null; + aggregate = new AtomicDouble(); + deltaAggregate = new AtomicDouble(); + getAndResetDelta(); + aggregate.set(deltaAggregate.get()); + break; + default: + throw illegalArgumentException(); + } + } + + @Override + public DoubleCounter addValue(Double value) { + switch (kind) { + case SUM: + aggregate.addAndGet(value); + deltaAggregate.addAndGet(value); + break; + case MEAN: + addToMeanAndSet(value, mean); + addToMeanAndSet(value, deltaMean); + break; + case MAX: + maxAndSet(value, aggregate); + maxAndSet(value, deltaAggregate); + break; + case MIN: + minAndSet(value, aggregate); + minAndSet(value, deltaAggregate); + break; + default: + throw illegalArgumentException(); + } + return this; + } + + private void addToMeanAndSet(Double value, AtomicReference target) { + DoubleCounterMean current; + DoubleCounterMean update; + do { + current = target.get(); + update = new DoubleCounterMean(current.getAggregate() + value, current.getCount() + 1); + } while (!target.compareAndSet(current, update)); + } + + private void maxAndSet(Double value, AtomicDouble target) { + double current; + double update; + do { + current = target.get(); + update = Math.max(current, value); + } while (update > current && !target.compareAndSet(current, update)); + } + + private void minAndSet(Double value, AtomicDouble target) { + double current; + double update; + do { + current = target.get(); + update = Math.min(current, value); + } while (update < current && !target.compareAndSet(current, update)); + } + + @Override + public Double getAndResetDelta() { + switch (kind) { + case SUM: + return deltaAggregate.getAndSet(0.0); + case MAX: + return deltaAggregate.getAndSet(Double.NEGATIVE_INFINITY); + case MIN: + return deltaAggregate.getAndSet(Double.POSITIVE_INFINITY); + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter resetToValue(Double value) { + if (kind == MEAN) { + throw illegalArgumentException(); + } + aggregate.set(value); + deltaAggregate.set(value); + return this; + } + + @Override + public Counter resetMeanToValue(long elementCount, Double value) { + if (kind != MEAN) { + throw illegalArgumentException(); + } + if (elementCount < 0) { + throw new 
IllegalArgumentException("elementCount must be non-negative"); + } + DoubleCounterMean counterMean = new DoubleCounterMean(value, elementCount); + mean.set(counterMean); + deltaMean.set(counterMean); + return this; + } + + @Override + public CounterMean getAndResetMeanDelta() { + if (kind != MEAN) { + throw illegalArgumentException(); + } + return deltaMean.getAndSet(new DoubleCounterMean(0.0, 0L)); + } + + @Override + public Double getAggregate() { + if (kind != MEAN) { + return aggregate.get(); + } else { + return getMean().getAggregate(); + } + } + + @Override + @Nullable + public CounterMean getMean() { + if (kind != MEAN) { + throw illegalArgumentException(); + } + return mean.get(); + } + + @Override + public Counter merge(Counter that) { + checkArgument(this.isCompatibleWith(that), "Counters %s and %s are incompatible", this, that); + switch (kind) { + case SUM: + case MIN: + case MAX: + return addValue(that.getAggregate()); + case MEAN: + CounterMean thisCounterMean = this.getMean(); + CounterMean thatCounterMean = that.getMean(); + return resetMeanToValue( + thisCounterMean.getCount() + thatCounterMean.getCount(), + thisCounterMean.getAggregate() + thatCounterMean.getAggregate()); + default: + throw illegalArgumentException(); + } + } + + private static class DoubleCounterMean implements CounterMean { + private final double aggregate; + private final long count; + + public DoubleCounterMean(double aggregate, long count) { + this.aggregate = aggregate; + this.count = count; + } + + @Override + public Double getAggregate() { + return aggregate; + } + + @Override + public long getCount() { + return count; + } + + @Override + public String toString() { + return aggregate + "/" + count; + } + } + } + + /** + * Implements a {@link Counter} for {@link Boolean} values. + */ + private static class BooleanCounter extends Counter { + private final AtomicBoolean aggregate; + private final AtomicBoolean deltaAggregate; + + /** Initializes a new {@link Counter} for {@link Boolean} values. */ + private BooleanCounter(String name, AggregationKind kind) { + super(name, kind); + aggregate = new AtomicBoolean(); + deltaAggregate = new AtomicBoolean(); + getAndResetDelta(); + aggregate.set(deltaAggregate.get()); + } + + @Override + public BooleanCounter addValue(Boolean value) { + if (kind.equals(AND) && !value) { + aggregate.set(value); + deltaAggregate.set(value); + } else if (kind.equals(OR) && value) { + aggregate.set(value); + deltaAggregate.set(value); + } + return this; + } + + @Override + public Boolean getAndResetDelta() { + switch (kind) { + case AND: + return deltaAggregate.getAndSet(true); + case OR: + return deltaAggregate.getAndSet(false); + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter resetToValue(Boolean value) { + aggregate.set(value); + deltaAggregate.set(value); + return this; + } + + @Override + public Counter resetMeanToValue(long elementCount, Boolean value) { + throw illegalArgumentException(); + } + + @Override + public CounterMean getAndResetMeanDelta() { + throw illegalArgumentException(); + } + + @Override + public Boolean getAggregate() { + return aggregate.get(); + } + + @Override + @Nullable + public CounterMean getMean() { + throw illegalArgumentException(); + } + + @Override + public Counter merge(Counter that) { + checkArgument(this.isCompatibleWith(that), "Counters %s and %s are incompatible", this, that); + return addValue(that.getAggregate()); + } + } + + /** + * Implements a {@link Counter} for {@link String} values. 
+ */ + private static class StringCounter extends Counter { + /** Initializes a new {@link Counter} for {@link String} values. */ + private StringCounter(String name, AggregationKind kind) { + super(name, kind); + // TODO: Support MIN, MAX of Strings. + throw illegalArgumentException(); + } + + @Override + public StringCounter addValue(String value) { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter resetToValue(String value) { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter resetMeanToValue(long elementCount, String value) { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + public String getAndResetDelta() { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + public CounterMean getAndResetMeanDelta() { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + public String getAggregate() { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + @Nullable + public CounterMean getMean() { + switch (kind) { + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter merge(Counter that) { + checkArgument(this.isCompatibleWith(that), "Counters %s and %s are incompatible", this, that); + switch (kind) { + default: + throw illegalArgumentException(); + } + } + } + + /** + * Implements a {@link Counter} for {@link Integer} values. + */ + private static class IntegerCounter extends Counter { + private final AtomicInteger aggregate; + private final AtomicInteger deltaAggregate; + private final AtomicReference mean; + private final AtomicReference deltaMean; + + /** Initializes a new {@link Counter} for {@link Integer} values. 
*/ + private IntegerCounter(String name, AggregationKind kind) { + super(name, kind); + switch (kind) { + case MEAN: + aggregate = deltaAggregate = null; + mean = new AtomicReference<>(); + deltaMean = new AtomicReference<>(); + getAndResetMeanDelta(); + mean.set(deltaMean.get()); + break; + case SUM: + case MAX: + case MIN: + mean = deltaMean = null; + aggregate = new AtomicInteger(); + deltaAggregate = new AtomicInteger(); + getAndResetDelta(); + aggregate.set(deltaAggregate.get()); + break; + default: + throw illegalArgumentException(); + } + } + + @Override + public IntegerCounter addValue(Integer value) { + switch (kind) { + case SUM: + aggregate.getAndAdd(value); + deltaAggregate.getAndAdd(value); + break; + case MEAN: + addToMeanAndSet(value, mean); + addToMeanAndSet(value, deltaMean); + break; + case MAX: + maxAndSet(value, aggregate); + maxAndSet(value, deltaAggregate); + break; + case MIN: + minAndSet(value, aggregate); + minAndSet(value, deltaAggregate); + break; + default: + throw illegalArgumentException(); + } + return this; + } + + private void addToMeanAndSet(int value, AtomicReference target) { + IntegerCounterMean current; + IntegerCounterMean update; + do { + current = target.get(); + update = new IntegerCounterMean(current.getAggregate() + value, current.getCount() + 1); + } while (!target.compareAndSet(current, update)); + } + + private void maxAndSet(int value, AtomicInteger target) { + int current; + int update; + do { + current = target.get(); + update = Math.max(value, current); + } while (update > current && !target.compareAndSet(current, update)); + } + + private void minAndSet(int value, AtomicInteger target) { + int current; + int update; + do { + current = target.get(); + update = Math.min(value, current); + } while (update < current && !target.compareAndSet(current, update)); + } + + @Override + public Integer getAndResetDelta() { + switch (kind) { + case SUM: + return deltaAggregate.getAndSet(0); + case MAX: + return deltaAggregate.getAndSet(Integer.MIN_VALUE); + case MIN: + return deltaAggregate.getAndSet(Integer.MAX_VALUE); + default: + throw illegalArgumentException(); + } + } + + @Override + public Counter resetToValue(Integer value) { + if (kind == MEAN) { + throw illegalArgumentException(); + } + aggregate.set(value); + deltaAggregate.set(value); + return this; + } + + @Override + public Counter resetMeanToValue(long elementCount, Integer value) { + if (kind != MEAN) { + throw illegalArgumentException(); + } + if (elementCount < 0) { + throw new IllegalArgumentException("elementCount must be non-negative"); + } + IntegerCounterMean counterMean = new IntegerCounterMean(value, elementCount); + mean.set(counterMean); + deltaMean.set(counterMean); + return this; + } + + @Override + public CounterMean getAndResetMeanDelta() { + if (kind != MEAN) { + throw illegalArgumentException(); + } + return deltaMean.getAndSet(new IntegerCounterMean(0, 0L)); + } + + @Override + public Integer getAggregate() { + if (kind != MEAN) { + return aggregate.get(); + } else { + return getMean().getAggregate(); + } + } + + @Override + @Nullable + public CounterMean getMean() { + if (kind != MEAN) { + throw illegalArgumentException(); + } + return mean.get(); + } + + @Override + public Counter merge(Counter that) { + checkArgument(this.isCompatibleWith(that), "Counters %s and %s are incompatible", this, that); + switch (kind) { + case SUM: + case MIN: + case MAX: + return addValue(that.getAggregate()); + case MEAN: + CounterMean thisCounterMean = this.getMean(); + CounterMean 
thatCounterMean = that.getMean(); + return resetMeanToValue( + thisCounterMean.getCount() + thatCounterMean.getCount(), + thisCounterMean.getAggregate() + thatCounterMean.getAggregate()); + default: + throw illegalArgumentException(); + } + } + + private static class IntegerCounterMean implements CounterMean { + private final int aggregate; + private final long count; + + public IntegerCounterMean(int aggregate, long count) { + this.aggregate = aggregate; + this.count = count; + } + + @Override + public Integer getAggregate() { + return aggregate; + } + + @Override + public long getCount() { + return count; + } + + @Override + public String toString() { + return aggregate + "/" + count; + } + } + } + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Constructs an {@link IllegalArgumentException} explaining that this + * {@link Counter}'s aggregation kind is not supported by its value type. + */ + protected IllegalArgumentException illegalArgumentException() { + return new IllegalArgumentException("Cannot compute " + kind + + " aggregation over " + getType().getSimpleName() + " values."); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterProvider.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterProvider.java new file mode 100644 index 000000000000..ba53f80d208a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterProvider.java @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common; + +/** + * A counter provider can provide {@link Counter} instances. + * + * @param the input type of the counter. + */ +public interface CounterProvider { + Counter getCounter(String name); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterSet.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterSet.java new file mode 100644 index 000000000000..9e9638ff33eb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/CounterSet.java @@ -0,0 +1,177 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.common.base.Preconditions.checkArgument; + +import java.util.AbstractSet; +import java.util.HashMap; +import java.util.Iterator; + +/** + * A CounterSet maintains a set of {@link Counter}s. + * + *
<p>
    Thread-safe. + */ +public class CounterSet extends AbstractSet> { + + /** Registered counters. */ + private final HashMap> counters = new HashMap<>(); + + private final AddCounterMutator addCounterMutator = new AddCounterMutator(); + + /** + * Constructs a CounterSet containing the given Counters. + */ + public CounterSet(Counter... counters) { + for (Counter counter : counters) { + addNewCounter(counter); + } + } + + /** + * Returns an object that supports adding additional counters into + * this CounterSet. + */ + public AddCounterMutator getAddCounterMutator() { + return addCounterMutator; + } + + /** + * Adds a new counter, throwing an exception if a counter of the + * same name already exists. + */ + public void addNewCounter(Counter counter) { + if (!addCounter(counter)) { + throw new IllegalArgumentException( + "Counter " + counter + " duplicates an existing counter in " + this); + } + } + + /** + * Adds the given Counter to this CounterSet. + * + *
<p>
    If a counter with the same name already exists, it will be + * reused, as long as it is compatible. + * + * @return the Counter that was reused, or added + * @throws IllegalArgumentException if a counter with the same + * name but an incompatible kind had already been added + */ + public synchronized Counter addOrReuseCounter(Counter counter) { + Counter oldCounter = counters.get(counter.getName()); + if (oldCounter == null) { + // A new counter. + counters.put(counter.getName(), counter); + return counter; + } + if (counter.isCompatibleWith(oldCounter)) { + // Return the counter to reuse. + @SuppressWarnings("unchecked") + Counter compatibleCounter = (Counter) oldCounter; + return compatibleCounter; + } + throw new IllegalArgumentException( + "Counter " + counter + " duplicates incompatible counter " + + oldCounter + " in " + this); + } + + /** + * Adds a counter. Returns {@code true} if the counter was added to the set + * and false if the given counter was {@code null} or it already existed in + * the set. + * + * @param counter to register + */ + public boolean addCounter(Counter counter) { + return add(counter); + } + + /** + * Returns the Counter with the given name in this CounterSet; + * returns null if no such Counter exists. + */ + public synchronized Counter getExistingCounter(String name) { + return counters.get(name); + } + + @Override + public synchronized Iterator> iterator() { + return counters.values().iterator(); + } + + @Override + public synchronized int size() { + return counters.size(); + } + + @Override + public synchronized boolean add(Counter e) { + if (null == e) { + return false; + } + if (counters.containsKey(e.getName())) { + return false; + } + counters.put(e.getName(), e); + return true; + } + + public synchronized void merge(CounterSet that) { + for (Counter theirCounter : that) { + Counter myCounter = counters.get(theirCounter.getName()); + if (myCounter != null) { + mergeCounters(myCounter, theirCounter); + } else { + addCounter(theirCounter); + } + } + } + + private void mergeCounters(Counter mine, Counter theirCounter) { + checkArgument( + mine.isCompatibleWith(theirCounter), + "Can't merge CounterSets containing incompatible counters with the same name: " + + "%s (existing) and %s (merged)", + mine, + theirCounter); + @SuppressWarnings("unchecked") + Counter theirs = (Counter) theirCounter; + mine.merge(theirs); + } + + /** + * A nested class that supports adding additional counters into the + * enclosing CounterSet. This is useful as a mutator, hiding other + * public methods of the CounterSet. + */ + public class AddCounterMutator { + /** + * Adds the given Counter into the enclosing CounterSet. + * + *

    If a counter with the same name already exists, it will be + * reused, as long as it has the same type. + * + * @return the Counter that was reused, or added + * @throws IllegalArgumentException if a counter with the same + * name but an incompatible kind had already been added + */ + public Counter addCounter(Counter counter) { + return addOrReuseCounter(counter); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservable.java new file mode 100644 index 000000000000..fee673774570 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservable.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.common; + +/** + * An interface for things that allow observing the size in bytes of + * encoded values of type {@code T}. + * + * @param the type of the values being observed + */ +public interface ElementByteSizeObservable { + /** + * Returns whether {@link #registerByteSizeObserver} is cheap enough + * to call for every element, that is, if this + * {@code ElementByteSizeObservable} can calculate the byte size of + * the element to be coded in roughly constant time (or lazily). + */ + public boolean isRegisterByteSizeObserverCheap(T value); + + /** + * Notifies the {@code ElementByteSizeObserver} about the byte size + * of the encoded value using this {@code ElementByteSizeObservable}. + */ + public void registerByteSizeObserver(T value, + ElementByteSizeObserver observer) + throws Exception; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterable.java new file mode 100644 index 000000000000..591d2be28abc --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterable.java @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.ArrayList; +import java.util.List; +import java.util.Observer; + +/** + * An abstract class used for iterables that notify observers about size in + * bytes of their elements, as they are being iterated over. + * + * @param the type of elements returned by this iterable + * @param type type of iterator returned by this iterable + */ +public abstract class ElementByteSizeObservableIterable< + V, InputT extends ElementByteSizeObservableIterator> + implements Iterable { + private List observers = new ArrayList<>(); + + /** + * Derived classes override this method to return an iterator for this + * iterable. + */ + protected abstract InputT createIterator(); + + /** + * Sets the observer, which will observe the iterator returned in + * the next call to iterator() method. Future calls to iterator() + * won't be observed, unless an observer is set again. + */ + public void addObserver(Observer observer) { + observers.add(observer); + } + + /** + * Returns a new iterator for this iterable. If an observer was set in + * a previous call to setObserver(), it will observe the iterator returned. + */ + @Override + public InputT iterator() { + InputT iterator = createIterator(); + for (Observer observer : observers) { + iterator.addObserver(observer); + } + observers.clear(); + return iterator; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterator.java new file mode 100644 index 000000000000..c0949003017a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObservableIterator.java @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.Iterator; +import java.util.Observable; + +/** + * An abstract class used for iterators that notify observers about size in + * bytes of their elements, as they are being iterated over. The subclasses + * need to implement the standard Iterator interface and call method + * notifyValueReturned() for each element read and/or iterated over. 
+ * + * @param value type + */ +public abstract class ElementByteSizeObservableIterator + extends Observable implements Iterator { + protected final void notifyValueReturned(long byteSize) { + setChanged(); + notifyObservers(byteSize); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObserver.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObserver.java new file mode 100644 index 000000000000..6c764d99bb15 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ElementByteSizeObserver.java @@ -0,0 +1,92 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.Observable; +import java.util.Observer; + +/** + * An observer that gets notified when additional bytes are read + * and/or used. It adds all bytes into a local counter. When the + * observer gets advanced via the next() call, it adds the total byte + * count to the specified counter, and prepares for the next element. + */ +public class ElementByteSizeObserver implements Observer { + private final Counter counter; + private boolean isLazy = false; + private long totalSize = 0; + private double scalingFactor = 1.0; + + public ElementByteSizeObserver(Counter counter) { + this.counter = counter; + } + + /** + * Sets byte counting for the current element as lazy. That is, the + * observer will get notified of the element's byte count only as + * element's pieces are being processed or iterated over. + */ + public void setLazy() { + isLazy = true; + } + + /** + * Returns whether byte counting for the current element is lazy, that is, + * whether the observer gets notified of the element's byte count only as + * element's pieces are being processed or iterated over. + */ + public boolean getIsLazy() { + return isLazy; + } + + /** + * Updates the observer with a context specified, but without an instance of + * the Observable. + */ + public void update(Object obj) { + update(null, obj); + } + + /** + * Sets a multiplier to use on observed sizes. + */ + public void setScalingFactor(double scalingFactor) { + this.scalingFactor = scalingFactor; + } + + @Override + public void update(Observable obs, Object obj) { + if (obj instanceof Long) { + totalSize += scalingFactor * (Long) obj; + } else if (obj instanceof Integer) { + totalSize += scalingFactor * (Integer) obj; + } else { + throw new AssertionError("unexpected parameter object"); + } + } + + /** + * Advances the observer to the next element. Adds the current total byte + * size to the counter, and prepares the observer for the next element. 
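// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). Shows how the
// pieces fit together: an ElementByteSizeObservableIterator subclass reports
// each element's approximate byte size via notifyValueReturned(), and an
// ElementByteSizeObserver accumulates the sizes into a Counter. The element
// type, the "size" metric (string length), and the counter name are invented
// for the example; the Counter<Long> type parameter is an assumption.
import com.google.cloud.dataflow.sdk.util.common.Counter;
import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObservableIterator;
import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

class ByteSizeObservationSketch {
  /** Reports the length of each string as its "encoded size". */
  static class StringSizeIterator extends ElementByteSizeObservableIterator<String> {
    private final Iterator<String> wrapped;
    StringSizeIterator(List<String> source) { this.wrapped = source.iterator(); }
    @Override public boolean hasNext() { return wrapped.hasNext(); }
    @Override public String next() {
      String value = wrapped.next();
      notifyValueReturned(value.length());  // notify observers of this element's size
      return value;
    }
    @Override public void remove() { throw new UnsupportedOperationException(); }
  }

  public static void main(String[] args) {
    Counter<Long> byteCounter =
        Counter.longs("observed-bytes", Counter.AggregationKind.SUM);
    ElementByteSizeObserver observer = new ElementByteSizeObserver(byteCounter);

    StringSizeIterator it = new StringSizeIterator(Arrays.asList("abc", "de"));
    it.addObserver(observer);  // java.util.Observable wiring
    while (it.hasNext()) {
      it.next();               // notifies the observer with the element's size
      observer.advance();      // flushes the per-element total into the counter
    }
    // byteCounter now holds 3 + 2 = 5.
  }
}
// ---------------------------------------------------------------------------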
+ */ + public void advance() { + counter.addValue(totalSize); + + totalSize = 0; + isLazy = false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/PeekingReiterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/PeekingReiterator.java new file mode 100644 index 000000000000..094874791825 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/PeekingReiterator.java @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import java.util.NoSuchElementException; + +/** + * A {@link Reiterator} that supports one-element lookahead during iteration. + * + * @param the type of elements returned by this iterator + */ +public final class PeekingReiterator implements Reiterator { + private T nextElement; + private boolean nextElementComputed; + private final Reiterator iterator; + + public PeekingReiterator(Reiterator iterator) { + this.iterator = checkNotNull(iterator); + } + + PeekingReiterator(PeekingReiterator it) { + this.iterator = checkNotNull(checkNotNull(it).iterator.copy()); + this.nextElement = it.nextElement; + this.nextElementComputed = it.nextElementComputed; + } + + @Override + public boolean hasNext() { + computeNext(); + return nextElementComputed; + } + + @Override + public T next() { + T result = peek(); + nextElementComputed = false; + return result; + } + + /** + * {@inheritDoc} + * + *

    If {@link #peek} is called, {@code remove} is disallowed until + * {@link #next} has been subsequently called. + */ + @Override + public void remove() { + checkState(!nextElementComputed, + "After peek(), remove() is disallowed until next() is called"); + iterator.remove(); + } + + @Override + public PeekingReiterator copy() { + return new PeekingReiterator<>(this); + } + + /** + * Returns the element that would be returned by {@link #next}, without + * actually consuming the element. + * @throws NoSuchElementException if there is no next element + */ + public T peek() { + computeNext(); + if (!nextElementComputed) { + throw new NoSuchElementException(); + } + return nextElement; + } + + private void computeNext() { + if (nextElementComputed) { + return; + } + if (!iterator.hasNext()) { + return; + } + nextElement = iterator.next(); + nextElementComputed = true; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ReflectHelpers.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ReflectHelpers.java new file mode 100644 index 000000000000..f87242f3ca2f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/ReflectHelpers.java @@ -0,0 +1,209 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.common; + +import static java.util.Arrays.asList; + +import com.google.common.base.Function; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Queues; + +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.lang.reflect.WildcardType; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Queue; + +import javax.annotation.Nullable; + +/** + * Utilities for working with with {@link Class Classes} and {@link Method Methods}. + */ +public class ReflectHelpers { + + private static final Joiner COMMA_SEPARATOR = Joiner.on(", "); + + /** A {@link Function} that turns a method into a simple method signature. */ + public static final Function METHOD_FORMATTER = new Function() { + @Override + public String apply(Method input) { + String parameterTypes = FluentIterable.from(asList(input.getParameterTypes())) + .transform(CLASS_SIMPLE_NAME) + .join(COMMA_SEPARATOR); + return String.format("%s(%s)", + input.getName(), + parameterTypes); + } + }; + + /** A {@link Function} that turns a method into the declaring class + method signature. 
*/ + public static final Function CLASS_AND_METHOD_FORMATTER = + new Function() { + @Override + public String apply(Method input) { + return String.format("%s#%s", + CLASS_NAME.apply(input.getDeclaringClass()), + METHOD_FORMATTER.apply(input)); + } + }; + + /** A {@link Function} with returns the classes name. */ + public static final Function, String> CLASS_NAME = + new Function, String>() { + @Override + public String apply(Class input) { + return input.getName(); + } + }; + + /** A {@link Function} with returns the classes name. */ + public static final Function, String> CLASS_SIMPLE_NAME = + new Function, String>() { + @Override + public String apply(Class input) { + return input.getSimpleName(); + } + }; + + /** A {@link Function} that formats types. */ + public static final Function TYPE_SIMPLE_DESCRIPTION = + new Function() { + @Override + @Nullable + public String apply(@Nullable Type input) { + StringBuilder builder = new StringBuilder(); + format(builder, input); + return builder.toString(); + } + + private void format(StringBuilder builder, Type t) { + if (t instanceof Class) { + formatClass(builder, (Class) t); + } else if (t instanceof TypeVariable) { + formatTypeVariable(builder, (TypeVariable) t); + } else if (t instanceof WildcardType) { + formatWildcardType(builder, (WildcardType) t); + } else if (t instanceof ParameterizedType) { + formatParameterizedType(builder, (ParameterizedType) t); + } else if (t instanceof GenericArrayType) { + formatGenericArrayType(builder, (GenericArrayType) t); + } else { + builder.append(t.toString()); + } + } + + private void formatClass(StringBuilder builder, Class clazz) { + builder.append(clazz.getSimpleName()); + } + + private void formatTypeVariable(StringBuilder builder, TypeVariable t) { + builder.append(t.getName()); + } + + private void formatWildcardType(StringBuilder builder, WildcardType t) { + builder.append("?"); + for (Type lowerBound : t.getLowerBounds()) { + builder.append(" super "); + format(builder, lowerBound); + } + for (Type upperBound : t.getUpperBounds()) { + if (!Object.class.equals(upperBound)) { + builder.append(" extends "); + format(builder, upperBound); + } + } + } + + private void formatParameterizedType(StringBuilder builder, ParameterizedType t) { + format(builder, t.getRawType()); + builder.append('<'); + COMMA_SEPARATOR.appendTo(builder, + FluentIterable.from(asList(t.getActualTypeArguments())) + .transform(TYPE_SIMPLE_DESCRIPTION)); + builder.append('>'); + } + + private void formatGenericArrayType(StringBuilder builder, GenericArrayType t) { + format(builder, t.getGenericComponentType()); + builder.append("[]"); + } + }; + + /** + * Returns all interfaces of the given clazz. + * @param clazz + * @return + */ + public static FluentIterable> getClosureOfInterfaces(Class clazz) { + Preconditions.checkNotNull(clazz); + Queue> interfacesToProcess = Queues.newArrayDeque(); + Collections.addAll(interfacesToProcess, clazz.getInterfaces()); + + LinkedHashSet> interfaces = new LinkedHashSet<>(); + while (!interfacesToProcess.isEmpty()) { + Class current = interfacesToProcess.remove(); + if (interfaces.add(current)) { + Collections.addAll(interfacesToProcess, current.getInterfaces()); + } + } + return FluentIterable.from(interfaces); + } + + /** + * Returns all the methods visible from the provided interfaces. + * + * @param interfaces The interfaces to use when searching for all their methods. + * @return An iterable of {@link Method}s which interfaces expose. 
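// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). Formats the methods
// exposed by an interface and describes a generic type using the helpers in
// ReflectHelpers. java.util.List is an arbitrary choice for the example.
import com.google.cloud.dataflow.sdk.util.common.ReflectHelpers;

import java.lang.reflect.Method;
import java.util.List;

class ReflectHelpersSketch {
  public static void main(String[] args) throws Exception {
    // Prints lines such as "java.util.List#size()" and "java.util.List#subList(int, int)".
    for (Method method : ReflectHelpers.getClosureOfMethodsOnInterface(List.class)) {
      System.out.println(ReflectHelpers.CLASS_AND_METHOD_FORMATTER.apply(method));
    }
    // Prints "List<E>": the simple description of List#subList's generic return type.
    System.out.println(
        ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(
            List.class.getMethod("subList", int.class, int.class).getGenericReturnType()));
  }
}
// ---------------------------------------------------------------------------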
+ */ + public static Iterable getClosureOfMethodsOnInterfaces( + Iterable> interfaces) { + return FluentIterable.from(interfaces).transformAndConcat( + new Function, Iterable>() { + @Override + public Iterable apply(Class input) { + return getClosureOfMethodsOnInterface(input); + } + }); + } + + /** + * Returns all the methods visible from {@code iface}. + * + * @param iface The interface to use when searching for all its methods. + * @return An iterable of {@link Method}s which {@code iface} exposes. + */ + public static Iterable getClosureOfMethodsOnInterface(Class iface) { + Preconditions.checkNotNull(iface); + Preconditions.checkArgument(iface.isInterface()); + ImmutableSet.Builder builder = ImmutableSet.builder(); + Queue> interfacesToProcess = Queues.newArrayDeque(); + interfacesToProcess.add(iface); + while (!interfacesToProcess.isEmpty()) { + Class current = interfacesToProcess.remove(); + builder.add(current.getMethods()); + interfacesToProcess.addAll(Arrays.asList(current.getInterfaces())); + } + return builder.build(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterable.java new file mode 100644 index 000000000000..01c5775a3f88 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterable.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +/** + * An {@link Iterable} that returns {@link Reiterator} iterators. + * + * @param the type of elements returned by the iterator + */ +public interface Reiterable extends Iterable { + @Override + public Reiterator iterator(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterator.java new file mode 100644 index 000000000000..dd8036da02a0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/Reiterator.java @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import java.util.Iterator; + +/** + * An {@link Iterator} with the ability to copy its iteration state. + * + * @param the type of elements returned by this iterator + */ +public interface Reiterator extends Iterator { + /** + * Returns a copy of the current {@link Reiterator}. The copy's iteration + * state is logically independent of the current iterator; each may be + * advanced without affecting the other. + * + *

    The returned {@code Reiterator} is not guaranteed to return + * referentially identical iteration results as the original + * {@link Reiterator}, although {@link Object#equals} will typically return + * true for the corresponding elements of each if the original source is + * logically immutable. + */ + public Reiterator copy(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/package-info.java new file mode 100644 index 000000000000..7fb16c58fde0 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities shared by multiple PipelineRunner implementations. **/ +package com.google.cloud.dataflow.sdk.util.common; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSampler.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSampler.java new file mode 100644 index 000000000000..00d3b3b904ca --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/StateSampler.java @@ -0,0 +1,365 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common.worker; + +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * A StateSampler object may be used to obtain an approximate + * breakdown of the time spent by an execution context in various + * states, as a fraction of the total time. The sampling is taken at + * regular intervals, with adjustment for scheduling delay. + */ +@ThreadSafe +public class StateSampler implements AutoCloseable { + + /** Different kinds of states. 
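// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). A minimal
// list-backed Reiterator whose copy() snapshots the current position, wrapped
// in the PeekingReiterator added earlier in this commit. The list contents are
// made up.
import com.google.cloud.dataflow.sdk.util.common.PeekingReiterator;
import com.google.cloud.dataflow.sdk.util.common.Reiterator;

import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;

class ReiteratorSketch {
  static class ListReiterator<T> implements Reiterator<T> {
    private final List<T> list;
    private int position;

    ListReiterator(List<T> list, int position) {
      this.list = list;
      this.position = position;
    }

    @Override public boolean hasNext() { return position < list.size(); }
    @Override public T next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      return list.get(position++);
    }
    @Override public void remove() { throw new UnsupportedOperationException(); }
    @Override public ListReiterator<T> copy() { return new ListReiterator<>(list, position); }
  }

  public static void main(String[] args) {
    PeekingReiterator<String> it =
        new PeekingReiterator<>(new ListReiterator<>(Arrays.asList("a", "b", "c"), 0));
    System.out.println(it.peek());    // "a" (peeked, not consumed)
    System.out.println(it.next());    // "a"
    PeekingReiterator<String> fork = it.copy();
    System.out.println(it.next());    // "b"
    System.out.println(fork.next());  // "b" (the copy advances independently)
  }
}
// ---------------------------------------------------------------------------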
*/ + public enum StateKind { + /** IO, user code, etc. */ + USER, + /** Reading/writing from/to shuffle service, etc. */ + FRAMEWORK + } + + public static final long DEFAULT_SAMPLING_PERIOD_MS = 200; + + private final String prefix; + private final CounterSet.AddCounterMutator counterSetMutator; + + /** Array of counters indexed by their state. */ + private ArrayList> countersByState = new ArrayList<>(); + + /** Map of state name to state. */ + private Map statesByName = new HashMap<>(); + + /** Map of state id to kind. */ + private Map kindsByState = new HashMap<>(); + + /** The current state. */ + private volatile int currentState; + + /** Special value of {@code currentState} that means we do not sample. */ + public static final int DO_NOT_SAMPLE = -1; + + /** + * A counter that increments with each state transition. May be used + * to detect a context being stuck in a state for some amount of + * time. + */ + private volatile long stateTransitionCount; + + /** + * The timestamp (in nanoseconds) corresponding to the last time the + * state was sampled (and recorded). + */ + private long stateTimestampNs = 0; + + /** Using a fixed number of timers for all StateSampler objects. */ + private static final int NUM_EXECUTOR_THREADS = 16; + + private static final ScheduledExecutorService executorService = + Executors.newScheduledThreadPool(NUM_EXECUTOR_THREADS, + new ThreadFactoryBuilder().setDaemon(true).build()); + + private Random rand = new Random(); + + private List callbacks = new ArrayList<>(); + + private ScheduledFuture invocationTriggerFuture = null; + + private ScheduledFuture invocationFuture = null; + + /** + * Constructs a new {@link StateSampler} that can be used to obtain + * an approximate breakdown of the time spent by an execution + * context in various states, as a fraction of the total time. + * + * @param prefix the prefix of the counter names for the states + * @param counterSetMutator the {@link CounterSet.AddCounterMutator} + * used to create a counter for each distinct state + * @param samplingPeriodMs the sampling period in milliseconds + */ + public StateSampler(String prefix, + CounterSet.AddCounterMutator counterSetMutator, + final long samplingPeriodMs) { + this.prefix = prefix; + this.counterSetMutator = counterSetMutator; + currentState = DO_NOT_SAMPLE; + scheduleSampling(samplingPeriodMs); + } + + /** + * Constructs a new {@link StateSampler} that can be used to obtain + * an approximate breakdown of the time spent by an execution + * context in various states, as a fraction of the total time. + * + * @param prefix the prefix of the counter names for the states + * @param counterSetMutator the {@link CounterSet.AddCounterMutator} + * used to create a counter for each distinct state + */ + public StateSampler(String prefix, + CounterSet.AddCounterMutator counterSetMutator) { + this(prefix, counterSetMutator, DEFAULT_SAMPLING_PERIOD_MS); + } + + /** + * Called by the constructor to schedule sampling at the given period. + * + *

    Should not be overridden by sub-classes unless they want to change + * or disable the automatic sampling of state. + */ + protected void scheduleSampling(final long samplingPeriodMs) { + // Here "stratified sampling" is used, which makes sure that there's 1 uniformly chosen sampled + // point in every bucket of samplingPeriodMs, to prevent pathological behavior in case some + // states happen to occur at a similar period. + // The current implementation uses a fixed-rate timer with a period samplingPeriodMs as a + // trampoline to a one-shot random timer which fires with a random delay within + // samplingPeriodMs. + stateTimestampNs = System.nanoTime(); + invocationTriggerFuture = + executorService.scheduleAtFixedRate( + new Runnable() { + @Override + public void run() { + long delay = rand.nextInt((int) samplingPeriodMs); + synchronized (StateSampler.this) { + if (invocationFuture != null) { + invocationFuture.cancel(false); + } + invocationFuture = + executorService.schedule( + new Runnable() { + @Override + public void run() { + StateSampler.this.run(); + } + }, + delay, + TimeUnit.MILLISECONDS); + } + } + }, + 0, + samplingPeriodMs, + TimeUnit.MILLISECONDS); + } + + public synchronized void run() { + long startTimestampNs = System.nanoTime(); + int state = currentState; + if (state != DO_NOT_SAMPLE) { + StateKind kind = null; + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(startTimestampNs - stateTimestampNs); + kind = kindsByState.get(state); + countersByState.get(state).addValue(elapsedMs); + // Invoke all callbacks. + for (SamplingCallback c : callbacks) { + c.run(state, kind, elapsedMs); + } + } + stateTimestampNs = startTimestampNs; + } + + @Override + public synchronized void close() { + currentState = DO_NOT_SAMPLE; + if (invocationTriggerFuture != null) { + invocationTriggerFuture.cancel(false); + } + if (invocationFuture != null) { + invocationFuture.cancel(false); + } + } + + /** + * Returns the state associated with a name; creating a new state if + * necessary. Using states instead of state names during state + * transitions is done for efficiency. + * + * @name the name for the state + * @kind kind of the state, see {#code StateKind} + * @return the state associated with the state name + */ + public int stateForName(String name, StateKind kind) { + if (name.isEmpty()) { + return DO_NOT_SAMPLE; + } + + synchronized (this) { + Integer state = statesByName.get(name); + if (state == null) { + String counterName = prefix + name + "-msecs"; + Counter counter = counterSetMutator.addCounter( + Counter.longs(counterName, Counter.AggregationKind.SUM)); + state = countersByState.size(); + statesByName.put(name, state); + countersByState.add(counter); + kindsByState.put(state, kind); + } + StateKind originalKind = kindsByState.get(state); + if (originalKind != kind) { + throw new IllegalArgumentException( + "for state named " + name + + ", requested kind " + kind + " different from the original kind " + originalKind); + } + return state; + } + } + + /** + * An internal class for representing StateSampler information + * typically used for debugging. 
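// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). Attributes the time
// spent in a block of work to a named state. The prefix "stage1-" and state
// name "process" are made up; per the code above, the sampled time lands in a
// counter named "stage1-process-msecs".
import com.google.cloud.dataflow.sdk.util.common.CounterSet;
import com.google.cloud.dataflow.sdk.util.common.worker.StateSampler;

class StateSamplerSketch {
  public static void main(String[] args) throws Exception {
    CounterSet counters = new CounterSet();
    try (StateSampler sampler =
            new StateSampler("stage1-", counters.getAddCounterMutator())) {
      int processState = sampler.stateForName("process", StateSampler.StateKind.USER);
      try (StateSampler.ScopedState scope = sampler.scopedState(processState)) {
        Thread.sleep(500);  // stand-in for real work done while in the "process" state
      }
      // Closing the scope restores the previous state; closing the sampler stops sampling.
    }
  }
}
// ---------------------------------------------------------------------------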
+ */ + public static class StateSamplerInfo { + public final String state; + public final Long transitionCount; + public final Long stateDurationMillis; + + public StateSamplerInfo(String state, Long transitionCount, + Long stateDurationMillis) { + this.state = state; + this.transitionCount = transitionCount; + this.stateDurationMillis = stateDurationMillis; + } + } + + /** + * Returns information about the current state of this state sampler + * into a {@link StateSamplerInfo} object, or null if sampling is + * not turned on. + * + * @return information about this state sampler or null if sampling is off + */ + public synchronized StateSamplerInfo getInfo() { + return currentState == DO_NOT_SAMPLE ? null + : new StateSamplerInfo(countersByState.get(currentState).getName(), + stateTransitionCount, null); + } + + /** + * Returns the current state of this state sampler. + */ + public int getCurrentState() { + return currentState; + } + + /** + * Sets the current thread state. + * + * @param state the new state to transition to + * @return the previous state + */ + public int setState(int state) { + // Updates to stateTransitionCount are always done by the same + // thread, making the non-atomic volatile update below safe. The + // count is updated first to avoid incorrectly attributing + // stuckness occuring in an old state to the new state. + long previousStateTransitionCount = this.stateTransitionCount; + this.stateTransitionCount = previousStateTransitionCount + 1; + int previousState = currentState; + currentState = state; + return previousState; + } + + /** + * Sets the current thread state. + * + * @param name the name of the new state to transition to + * @param kind kind of the new state + * @return the previous state + */ + public int setState(String name, StateKind kind) { + return setState(stateForName(name, kind)); + } + + /** + * Returns an AutoCloseable {@link ScopedState} that will perform a + * state transition to the given state, and will automatically reset + * the state to the prior state upon closing. + * + * @param state the new state to transition to + * @return a {@link ScopedState} that automatically resets the state + * to the prior state + */ + public ScopedState scopedState(int state) { + return new ScopedState(this, setState(state)); + } + + /** + * Add a callback to the sampler. + * The callbacks will be executed sequentially upon {@link StateSampler#run}. + */ + public synchronized void addSamplingCallback(SamplingCallback callback) { + callbacks.add(callback); + } + + /** Get the counter prefix associated with this sampler. */ + public String getPrefix() { + return prefix; + } + + /** + * A nested class that is used to account for states and state + * transitions based on lexical scopes. + * + *

    Thread-safe. + */ + public class ScopedState implements AutoCloseable { + private StateSampler sampler; + private int previousState; + + private ScopedState(StateSampler sampler, int previousState) { + this.sampler = sampler; + this.previousState = previousState; + } + + @Override + public void close() { + sampler.setState(previousState); + } + } + + /** + * Callbacks which supposed to be called sequentially upon {@link StateSampler#run}. + * They should be registered via {@link #addSamplingCallback}. + */ + public static interface SamplingCallback { + /** + * The entrance method of the callback, it is called in {@link StateSampler#run}, + * once per sample. This method should be thread safe. + * + * @param state The state of the StateSampler at the time of sample. + * @param kind The kind associated with the state, see {@link StateKind}. + * @param elapsedMs Milliseconds since last sample. + */ + public void run(int state, StateKind kind, long elapsedMs); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/package-info.java new file mode 100644 index 000000000000..c3da9ed8b82a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/common/worker/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities used to implement the harness that runs user code. **/ +package com.google.cloud.dataflow.sdk.util.common.worker; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPath.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPath.java new file mode 100644 index 000000000000..f72ba4c2bc18 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPath.java @@ -0,0 +1,619 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util.gcsfs; + +import com.google.api.services.storage.model.StorageObject; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.FileSystem; +import java.nio.file.LinkOption; +import java.nio.file.Path; +import java.nio.file.WatchEvent; +import java.nio.file.WatchKey; +import java.nio.file.WatchService; +import java.util.Iterator; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Implements the Java NIO {@link Path} API for Google Cloud Storage paths. + * + *

    GcsPath uses a slash ('/') as a directory separator. Below is + * a summary of how slashes are treated: + *

      + *
    • A GCS bucket may not contain a slash. An object may contain zero or + * more slashes. + *
    • A trailing slash always indicates a directory, which is compliant + * with POSIX.1-2008. + *
    • Slashes separate components of a path. Empty components are allowed; + * these are represented as repeated slashes. An empty component always + * refers to a directory, and always ends in a slash. + *
    • {@link #getParent()} always returns a path ending in a slash, as the + * parent of a GcsPath is always a directory. + *
    • Use {@link #resolve(String)} to append elements to a GcsPath -- this + * applies the rules consistently and is highly recommended over any + * custom string concatenation. + *
    + * + *
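// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). Demonstrates the
// slash rules listed above; the bucket and object names are made up.
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath;

class GcsPathSlashRulesSketch {
  public static void main(String[] args) {
    GcsPath dir = GcsPath.fromComponents("my-bucket", "logs/2015/");
    GcsPath file = dir.resolve("part-00000.txt");
    System.out.println(file);              // gs://my-bucket/logs/2015/part-00000.txt
    System.out.println(file.getParent());  // gs://my-bucket/logs/2015/ (parents end in '/')
    System.out.println(dir.getParent());   // gs://my-bucket/logs/
  }
}
// ---------------------------------------------------------------------------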

    GcsPath treats all GCS objects and buckets as belonging to the same + * filesystem, so the root of a GcsPath is the GcsPath bucket="", object="". + * + *

    Relative paths are not associated with any bucket. This matches common + * treatment of Path in which relative paths can be constructed from one + * filesystem and appended to another filesystem. + * + * @see Java Tutorials: Path Operations + */ +public class GcsPath implements Path { + + public static final String SCHEME = "gs"; + + /** + * Creates a GcsPath from a URI. + * + *

    The URI must be in the form {@code gs://[bucket]/[path]}, and may not + * contain a port, user info, a query, or a fragment. + */ + public static GcsPath fromUri(URI uri) { + Preconditions.checkArgument(uri.getScheme().equalsIgnoreCase(SCHEME), + "URI: %s is not a GCS URI", uri); + Preconditions.checkArgument(uri.getPort() == -1, + "GCS URI may not specify port: %s (%i)", uri, uri.getPort()); + Preconditions.checkArgument( + Strings.isNullOrEmpty(uri.getUserInfo()), + "GCS URI may not specify userInfo: %s (%s)", uri, uri.getUserInfo()); + Preconditions.checkArgument( + Strings.isNullOrEmpty(uri.getQuery()), + "GCS URI may not specify query: %s (%s)", uri, uri.getQuery()); + Preconditions.checkArgument( + Strings.isNullOrEmpty(uri.getFragment()), + "GCS URI may not specify fragment: %s (%s)", uri, uri.getFragment()); + + return fromUri(uri.toString()); + } + + /** + * Pattern that is used to parse a GCS URL. + * + *

    This is used to separate the components. Verification is handled + * separately. + */ + public static final Pattern GCS_URI = + Pattern.compile("(?[^:]+)://(?[^/]+)(/(?.*))?"); + + /** + * Creates a GcsPath from a URI in string form. + * + *

    This does not use URI parsing, which means it may accept patterns that + * the URI parser would not accept. + */ + public static GcsPath fromUri(String uri) { + Matcher m = GCS_URI.matcher(uri); + Preconditions.checkArgument(m.matches(), "Invalid GCS URI: %s", uri); + + Preconditions.checkArgument(m.group("SCHEME").equalsIgnoreCase(SCHEME), + "URI: %s is not a GCS URI", uri); + return new GcsPath(null, m.group("BUCKET"), m.group("OBJECT")); + } + + /** + * Pattern that is used to parse a GCS resource name. + */ + private static final Pattern GCS_RESOURCE_NAME = + Pattern.compile("storage.googleapis.com/(?[^/]+)(/(?.*))?"); + + /** + * Creates a GcsPath from a OnePlatform resource name in string form. + */ + public static GcsPath fromResourceName(String name) { + Matcher m = GCS_RESOURCE_NAME.matcher(name); + Preconditions.checkArgument(m.matches(), "Invalid GCS resource name: %s", name); + + return new GcsPath(null, m.group("BUCKET"), m.group("OBJECT")); + } + + /** + * Creates a GcsPath from a {@linkplain StorageObject}. + */ + public static GcsPath fromObject(StorageObject object) { + return new GcsPath(null, object.getBucket(), object.getName()); + } + + /** + * Creates a GcsPath from bucket and object components. + * + *
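// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). Equivalent ways of
// constructing the same path; the bucket and object names are made up.
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath;

class GcsPathConstructionSketch {
  public static void main(String[] args) {
    GcsPath fromUri = GcsPath.fromUri("gs://my-bucket/staging/pipeline.json");
    GcsPath fromParts = GcsPath.fromComponents("my-bucket", "staging/pipeline.json");
    System.out.println(fromUri.equals(fromParts));  // true
    System.out.println(fromUri.getBucket());        // my-bucket
    System.out.println(fromUri.getObject());        // staging/pipeline.json
    System.out.println(fromUri.toUri());            // gs://my-bucket/staging/pipeline.json
  }
}
// ---------------------------------------------------------------------------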

    A GcsPath without a bucket name is treated as a relative path, which + * is a path component with no linkage to the root element. This is similar + * to a Unix path that does not begin with the root marker (a slash). + * GCS has different naming constraints and APIs for working with buckets and + * objects, so these two concepts are kept separate to avoid accidental + * attempts to treat objects as buckets, or vice versa, as much as possible. + * + *

    A GcsPath without an object name is a bucket reference. + * A bucket is always a directory, which could be used to lookup or add + * files to a bucket, but could not be opened as a file. + * + *

    A GcsPath containing neither bucket nor object names is treated as + * the root of the GCS filesystem. A listing on the root element would return + * the buckets available to the user. + * + *

    If {@code null} is passed as either parameter, it is converted to an + * empty string internally for consistency. There is no distinction between + * an empty string and a {@code null}, as neither are allowed by GCS. + * + * @param bucket a GCS bucket name, or none ({@code null} or an empty string) + * if the object is not associated with a bucket + * (e.g. relative paths or the root node). + * @param object a GCS object path, or none ({@code null} or an empty string) + * for no object. + */ + public static GcsPath fromComponents(@Nullable String bucket, + @Nullable String object) { + return new GcsPath(null, bucket, object); + } + + @Nullable + private FileSystem fs; + @Nonnull + private final String bucket; + @Nonnull + private final String object; + + /** + * Constructs a GcsPath. + * + * @param fs the associated FileSystem, if any + * @param bucket the associated bucket, or none ({@code null} or an empty + * string) for a relative path component + * @param object the object, which is a fully-qualified object name if bucket + * was also provided, or none ({@code null} or an empty string) + * for no object + * @throws java.lang.IllegalArgumentException if the bucket of object names + * are invalid. + */ + public GcsPath(@Nullable FileSystem fs, + @Nullable String bucket, + @Nullable String object) { + if (bucket == null) { + bucket = ""; + } + Preconditions.checkArgument(!bucket.contains("/"), + "GCS bucket may not contain a slash"); + Preconditions + .checkArgument(bucket.isEmpty() + || bucket.matches("[a-z0-9][-_a-z0-9.]+[a-z0-9]"), + "GCS bucket names must contain only lowercase letters, numbers, " + + "dashes (-), underscores (_), and dots (.). Bucket names " + + "must start and end with a number or letter. " + + "See https://developers.google.com/storage/docs/bucketnaming " + + "for more details. Bucket name: " + bucket); + + if (object == null) { + object = ""; + } + Preconditions.checkArgument( + object.indexOf('\n') < 0 && object.indexOf('\r') < 0, + "GCS object names must not contain Carriage Return or " + + "Line Feed characters."); + + this.fs = fs; + this.bucket = bucket; + this.object = object; + } + + /** + * Returns the bucket name associated with this GCS path, or an empty string + * if this is a relative path component. + */ + public String getBucket() { + return bucket; + } + + /** + * Returns the object name associated with this GCS path, or an empty string + * if no object is specified. + */ + public String getObject() { + return object; + } + + public void setFileSystem(FileSystem fs) { + this.fs = fs; + } + + @Override + public FileSystem getFileSystem() { + return fs; + } + + // Absolute paths are those that have a bucket and the root path. + @Override + public boolean isAbsolute() { + return !bucket.isEmpty() || object.isEmpty(); + } + + @Override + public GcsPath getRoot() { + return new GcsPath(fs, "", ""); + } + + @Override + public GcsPath getFileName() { + throw new UnsupportedOperationException(); + } + + /** + * Returns the parent path, or {@code null} if this path does not + * have a parent. + * + *

    Returns a path that ends in '/', as the parent path always refers to + * a directory. + */ + @Override + public GcsPath getParent() { + if (bucket.isEmpty() && object.isEmpty()) { + // The root path has no parent, by definition. + return null; + } + + if (object.isEmpty()) { + // A GCS bucket. All buckets come from a common root. + return getRoot(); + } + + // Skip last character, in case it is a trailing slash. + int i = object.lastIndexOf('/', object.length() - 2); + if (i <= 0) { + if (bucket.isEmpty()) { + // Relative paths are not attached to the root node. + return null; + } + return new GcsPath(fs, bucket, ""); + } + + // Retain trailing slash. + return new GcsPath(fs, bucket, object.substring(0, i + 1)); + } + + @Override + public int getNameCount() { + int count = bucket.isEmpty() ? 0 : 1; + if (object.isEmpty()) { + return count; + } + + // Add another for each separator found. + int index = -1; + while ((index = object.indexOf('/', index + 1)) != -1) { + count++; + } + + return object.endsWith("/") ? count : count + 1; + } + + @Override + public GcsPath getName(int count) { + Preconditions.checkArgument(count >= 0); + + Iterator iterator = iterator(); + for (int i = 0; i < count; ++i) { + Preconditions.checkArgument(iterator.hasNext()); + iterator.next(); + } + + Preconditions.checkArgument(iterator.hasNext()); + return (GcsPath) iterator.next(); + } + + @Override + public GcsPath subpath(int beginIndex, int endIndex) { + Preconditions.checkArgument(beginIndex >= 0); + Preconditions.checkArgument(endIndex > beginIndex); + + Iterator iterator = iterator(); + for (int i = 0; i < beginIndex; ++i) { + Preconditions.checkArgument(iterator.hasNext()); + iterator.next(); + } + + GcsPath path = null; + while (beginIndex < endIndex) { + Preconditions.checkArgument(iterator.hasNext()); + if (path == null) { + path = (GcsPath) iterator.next(); + } else { + path = path.resolve(iterator.next()); + } + ++beginIndex; + } + + return path; + } + + @Override + public boolean startsWith(Path other) { + if (other instanceof GcsPath) { + GcsPath gcsPath = (GcsPath) other; + return startsWith(gcsPath.bucketAndObject()); + } else { + return startsWith(other.toString()); + } + } + + @Override + public boolean startsWith(String prefix) { + return bucketAndObject().startsWith(prefix); + } + + @Override + public boolean endsWith(Path other) { + if (other instanceof GcsPath) { + GcsPath gcsPath = (GcsPath) other; + return endsWith(gcsPath.bucketAndObject()); + } else { + return endsWith(other.toString()); + } + } + + @Override + public boolean endsWith(String suffix) { + return bucketAndObject().endsWith(suffix); + } + + // TODO: support "." and ".." path components? + @Override + public GcsPath normalize() { + return this; + } + + @Override + public GcsPath resolve(Path other) { + if (other instanceof GcsPath) { + GcsPath path = (GcsPath) other; + if (path.isAbsolute()) { + return path; + } else { + return resolve(path.getObject()); + } + } else { + return resolve(other.toString()); + } + } + + @Override + public GcsPath resolve(String other) { + if (bucket.isEmpty() && object.isEmpty()) { + // Resolve on a root path is equivalent to looking up a bucket and object. + other = SCHEME + "://" + other; + } + + if (other.startsWith(SCHEME + "://")) { + GcsPath path = GcsPath.fromUri(other); + path.setFileSystem(getFileSystem()); + return path; + } + + if (other.isEmpty()) { + // An empty component MUST refer to a directory. 
+ other = "/"; + } + + if (object.isEmpty()) { + return new GcsPath(fs, bucket, other); + } else if (object.endsWith("/")) { + return new GcsPath(fs, bucket, object + other); + } else { + return new GcsPath(fs, bucket, object + "/" + other); + } + } + + @Override + public Path resolveSibling(Path other) { + throw new UnsupportedOperationException(); + } + + @Override + public Path resolveSibling(String other) { + throw new UnsupportedOperationException(); + } + + @Override + public Path relativize(Path other) { + throw new UnsupportedOperationException(); + } + + @Override + public GcsPath toAbsolutePath() { + return this; + } + + @Override + public GcsPath toRealPath(LinkOption... options) throws IOException { + return this; + } + + @Override + public File toFile() { + throw new UnsupportedOperationException(); + } + + @Override + public WatchKey register(WatchService watcher, WatchEvent.Kind[] events, + WatchEvent.Modifier... modifiers) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public WatchKey register(WatchService watcher, WatchEvent.Kind... events) + throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() { + return new NameIterator(fs, !bucket.isEmpty(), bucketAndObject()); + } + + private static class NameIterator implements Iterator { + private final FileSystem fs; + private boolean fullPath; + private String name; + + NameIterator(FileSystem fs, boolean fullPath, String name) { + this.fs = fs; + this.fullPath = fullPath; + this.name = name; + } + + @Override + public boolean hasNext() { + return !Strings.isNullOrEmpty(name); + } + + @Override + public GcsPath next() { + int i = name.indexOf('/'); + String component; + if (i >= 0) { + component = name.substring(0, i); + name = name.substring(i + 1); + } else { + component = name; + name = null; + } + if (fullPath) { + fullPath = false; + return new GcsPath(fs, component, ""); + } else { + // Relative paths have no bucket. + return new GcsPath(fs, "", component); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + + @Override + public int compareTo(Path other) { + if (!(other instanceof GcsPath)) { + throw new ClassCastException(); + } + + GcsPath path = (GcsPath) other; + int b = bucket.compareTo(path.bucket); + if (b != 0) { + return b; + } + + // Compare a component at a time, so that the separator char doesn't + // get compared against component contents. Eg, "a/b" < "a-1/b". + Iterator left = iterator(); + Iterator right = path.iterator(); + + while (left.hasNext() && right.hasNext()) { + String leftStr = left.next().toString(); + String rightStr = right.next().toString(); + int c = leftStr.compareTo(rightStr); + if (c != 0) { + return c; + } + } + + if (!left.hasNext() && !right.hasNext()) { + return 0; + } else { + return left.hasNext() ? 
1 : -1; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + GcsPath paths = (GcsPath) o; + return bucket.equals(paths.bucket) && object.equals(paths.object); + } + + @Override + public int hashCode() { + int result = bucket.hashCode(); + result = 31 * result + object.hashCode(); + return result; + } + + @Override + public String toString() { + if (!isAbsolute()) { + return object; + } + StringBuilder sb = new StringBuilder(); + sb.append(SCHEME) + .append("://"); + if (!bucket.isEmpty()) { + sb.append(bucket) + .append('/'); + } + sb.append(object); + return sb.toString(); + } + + // TODO: Consider using resource names for all GCS paths used by the SDK. + public String toResourceName() { + StringBuilder sb = new StringBuilder(); + sb.append("storage.googleapis.com/"); + if (!bucket.isEmpty()) { + sb.append(bucket).append('/'); + } + sb.append(object); + return sb.toString(); + } + + @Override + public URI toUri() { + try { + return new URI(SCHEME, "//" + bucketAndObject(), null); + } catch (URISyntaxException e) { + throw new RuntimeException("Unable to create URI for GCS path " + this); + } + } + + private String bucketAndObject() { + if (bucket.isEmpty()) { + return object; + } else { + return bucket + "/" + object; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/package-info.java new file mode 100644 index 000000000000..2f57938e2007 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/gcsfs/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities used to interact with Google Cloud Storage. **/ +package com.google.cloud.dataflow.sdk.util.gcsfs; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/package-info.java new file mode 100644 index 000000000000..c92adab9b6eb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** Defines utilities used by the Dataflow SDK. 
**/ +package com.google.cloud.dataflow.sdk.util; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/AccumulatorCombiningState.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/AccumulatorCombiningState.java new file mode 100644 index 000000000000..0d78b13bad64 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/AccumulatorCombiningState.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; + +/** + * State for a single value that is managed by a {@link CombineFn}. This is an internal extension + * to {@link CombiningState} that includes the {@code AccumT} type. + * + * @param the type of values added to the state + * @param the type of accumulator + * @param the type of value extracted from the state + */ +public interface AccumulatorCombiningState + extends CombiningState { + + /** + * Read the merged accumulator for this combining value. It is implied that reading the + * state involes reading the accumulator, so {@link #readLater} is sufficient to prefetch for + * this. + */ + AccumT getAccum(); + + /** + * Add an accumulator to this combining value. Depending on implementation this may immediately + * merge it with the previous accumulator, or may buffer this accumulator for a future merge. + */ + void addAccum(AccumT accum); + + /** + * Merge the given accumulators according to the underlying combiner. + */ + AccumT mergeAccumulators(Iterable accumulators); + + @Override + AccumulatorCombiningState readLater(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/BagState.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/BagState.java new file mode 100644 index 000000000000..363e48079798 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/BagState.java @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +/** + * State containing a bag values. Items can be added to the bag and the contents read out. + * + * @param The type of elements in the bag. 
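// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the files added above). How code that has
// obtained a BagState instance (from some StateInternals implementation, not
// shown in this excerpt) would typically use it. The read() method is assumed
// to come from ReadableState in the same package, which is not part of this
// excerpt.
import com.google.cloud.dataflow.sdk.util.state.BagState;

class BagStateUsageSketch {
  static void recordAndDump(BagState<String> bag) {
    bag.add("first");                    // buffer values into the bag
    bag.add("second");
    if (!bag.isEmpty().read()) {         // isEmpty() is itself a readable state
      for (String value : bag.read()) {  // read back everything added so far
        System.out.println(value);
      }
    }
  }
}
// ---------------------------------------------------------------------------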
+ */ +public interface BagState extends CombiningState> { + @Override + BagState readLater(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CombiningState.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CombiningState.java new file mode 100644 index 000000000000..673bebbfb42a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CombiningState.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; + +/** + * State that combines multiple {@code InputT} values using a {@link CombineFn} to produce a single + * {@code OutputT} value. + * + * @param the type of values added to the state + * @param the type of value extracted from the state + */ +public interface CombiningState extends ReadableState, State { + /** + * Add a value to the buffer. + */ + void add(InputT value); + + /** + * Return true if this state is empty. + */ + ReadableState isEmpty(); + + @Override + CombiningState readLater(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java new file mode 100644 index 000000000000..3683b74d9adf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java @@ -0,0 +1,454 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util.state; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.CombineFnUtil; +import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals.InMemoryState; +import com.google.cloud.dataflow.sdk.util.state.StateTag.StateBinder; +import com.google.common.base.Optional; +import com.google.common.collect.Iterables; + +import org.joda.time.Instant; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * {@link StateInternals} built on top of an underlying {@link StateTable} that contains instances + * of {@link InMemoryState}. Whenever state that exists in the underlying {@link StateTable} is + * accessed, an independent copy will be created within this table. + */ +public class CopyOnAccessInMemoryStateInternals implements StateInternals { + private final K key; + private final CopyOnAccessInMemoryStateTable table; + + /** + * Creates a new {@link CopyOnAccessInMemoryStateInternals} with the underlying (possibly null) + * StateInternals. + */ + public static CopyOnAccessInMemoryStateInternals withUnderlying( + K key, @Nullable CopyOnAccessInMemoryStateInternals underlying) { + return new CopyOnAccessInMemoryStateInternals(key, underlying); + } + + private CopyOnAccessInMemoryStateInternals( + K key, CopyOnAccessInMemoryStateInternals underlying) { + this.key = key; + table = + new CopyOnAccessInMemoryStateTable(key, underlying == null ? null : underlying.table); + } + + /** + * Ensures this {@link CopyOnAccessInMemoryStateInternals} is complete. Other copies of state for + * the same Step and Key may be discarded after invoking this method. + * + *

    For each {@link StateNamespace}, for each {@link StateTag address} in that namespace that + * has not been bound in this {@link CopyOnAccessInMemoryStateInternals}, put a reference to that + * state within this {@link StateInternals}. + * + *

    Additionally, stores the {@link WatermarkHoldState} with the earliest time bound in the + * state table after the commit is completed, enabling calls to + * {@link #getEarliestWatermarkHold()}. + * + * @return this table + */ + public CopyOnAccessInMemoryStateInternals commit() { + table.commit(); + return this; + } + + /** + * Gets the earliest Watermark Hold present in this table. + * + *
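A rough sketch of the intended call sequence, assuming a {@code String} key and a possibly-null + * {@code underlying} copy left over from an earlier bundle (variable names are illustrative): + * <pre>{@code + * CopyOnAccessInMemoryStateInternals<String> internals = + *     CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + * // ... read and write state through internals.state(namespace, tag) ... + * internals.commit(); + * Instant earliestHold = internals.getEarliestWatermarkHold(); + * }</pre> + * + *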

    Must be called after this state has been committed. Will throw an + * {@link IllegalStateException} if the state has not been committed. + */ + public Instant getEarliestWatermarkHold() { + // After commit, the watermark hold is always present, but may be + // BoundedWindow#TIMESTAMP_MAX_VALUE if there is no hold set. + checkState( + table.earliestWatermarkHold.isPresent(), + "Can't get the earliest watermark hold in a %s before it is committed", + getClass().getSimpleName()); + return table.earliestWatermarkHold.get(); + } + + @Override + public T state(StateNamespace namespace, StateTag address) { + return state(namespace, address, StateContexts.nullContext()); + } + + @Override + public T state( + StateNamespace namespace, StateTag address, StateContext c) { + return table.get(namespace, address, c); + } + + @Override + public K getKey() { + return key; + } + + public boolean isEmpty() { + return Iterables.isEmpty(table.values()); + } + + /** + * A {@link StateTable} that, when a value is retrieved with + * {@link StateTable#get(StateNamespace, StateTag)}, first attempts to obtain a copy of existing + * {@link State} from an underlying {@link StateTable}. + */ + private static class CopyOnAccessInMemoryStateTable extends StateTable { + private final K key; + private Optional> underlying; + + /** + * The StateBinderFactory currently in use by this {@link CopyOnAccessInMemoryStateTable}. + * + *

    There are three {@link StateBinderFactory} implementations used by the {@link + * CopyOnAccessInMemoryStateTable}. + *

      + *
• The default {@link StateBinderFactory} is a {@link CopyOnBindBinderFactory}, allowing + * the table to copy any existing {@link State} values to this {@link StateTable} from the + * underlying table when accessed, at which point mutations will not be visible to the + * underlying table - effectively a "Copy by Value" binder. + * • During the execution of the {@link #commit()} method, this is a + * {@link ReadThroughBinderFactory}, which copies the references to the existing + * {@link State} objects to this {@link StateTable}. + * • After the execution of the {@link #commit()} method, this is an + * instance of {@link InMemoryStateBinderFactory}, which constructs new instances of state + * when a {@link StateTag} is bound. + *
    + */ + private StateBinderFactory binderFactory; + + /** + * The earliest watermark hold in this table. + */ + private Optional earliestWatermarkHold; + + public CopyOnAccessInMemoryStateTable(K key, StateTable underlying) { + this.key = key; + this.underlying = Optional.fromNullable(underlying); + binderFactory = new CopyOnBindBinderFactory<>(key, this.underlying); + earliestWatermarkHold = Optional.absent(); + } + + /** + * Copies all values in the underlying table to this table, then discards the underlying table. + * + *

    If there is an underlying table, this replaces the existing + * {@link CopyOnBindBinderFactory} with a {@link ReadThroughBinderFactory}, then reads all of + * the values in the existing table, binding the state values to this table. The old StateTable + * should be discarded after the call to {@link #commit()}. + * + *

    After copying all of the existing values, replace the binder factory with an instance of + * {@link InMemoryStateBinderFactory} to construct new values, since all existing values + * are bound in this {@link StateTable table} and this table represents the canonical state. + */ + private void commit() { + Instant earliestHold = getEarliestWatermarkHold(); + if (underlying.isPresent()) { + ReadThroughBinderFactory readThroughBinder = + new ReadThroughBinderFactory<>(underlying.get()); + binderFactory = readThroughBinder; + Instant earliestUnderlyingHold = readThroughBinder.readThroughAndGetEarliestHold(this); + if (earliestUnderlyingHold.isBefore(earliestHold)) { + earliestHold = earliestUnderlyingHold; + } + } + earliestWatermarkHold = Optional.of(earliestHold); + clearEmpty(); + binderFactory = new InMemoryStateBinderFactory<>(key); + underlying = Optional.absent(); + } + + /** + * Get the earliest watermark hold in this table. Ignores the contents of any underlying table. + */ + private Instant getEarliestWatermarkHold() { + Instant earliest = BoundedWindow.TIMESTAMP_MAX_VALUE; + for (State existingState : this.values()) { + if (existingState instanceof WatermarkHoldState) { + Instant hold = ((WatermarkHoldState) existingState).read(); + if (hold != null && hold.isBefore(earliest)) { + earliest = hold; + } + } + } + return earliest; + } + + /** + * Clear all empty {@link StateNamespace StateNamespaces} from this table. If all states are + * empty, clear the entire table. + * + *

    Because {@link InMemoryState} is not removed from the {@link StateTable} after it is + * cleared, in case contents are modified after being cleared, the table must be explicitly + * checked to ensure that it contains state and removed if not (otherwise we may never use + * the table again). + */ + private void clearEmpty() { + Collection emptyNamespaces = new HashSet<>(this.getNamespacesInUse()); + for (StateNamespace namespace : this.getNamespacesInUse()) { + for (State existingState : this.getTagsInUse(namespace).values()) { + if (!((InMemoryState) existingState).isCleared()) { + emptyNamespaces.remove(namespace); + break; + } + } + } + for (StateNamespace empty : emptyNamespaces) { + this.clearNamespace(empty); + } + } + + @Override + protected StateBinder binderForNamespace(final StateNamespace namespace, StateContext c) { + return binderFactory.forNamespace(namespace, c); + } + + private static interface StateBinderFactory { + StateBinder forNamespace(StateNamespace namespace, StateContext c); + } + + /** + * {@link StateBinderFactory} that creates a copy of any existing state when the state is bound. + */ + private static class CopyOnBindBinderFactory implements StateBinderFactory { + private final K key; + private final Optional> underlying; + + public CopyOnBindBinderFactory(K key, Optional> underlying) { + this.key = key; + this.underlying = underlying; + } + + private boolean containedInUnderlying(StateNamespace namespace, StateTag tag) { + return underlying.isPresent() && underlying.get().isNamespaceInUse(namespace) + && underlying.get().getTagsInUse(namespace).containsKey(tag); + } + + @Override + public StateBinder forNamespace(final StateNamespace namespace, final StateContext c) { + return new StateBinder() { + @Override + public WatermarkHoldState bindWatermark( + StateTag> address, + OutputTimeFn outputTimeFn) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState> existingState = + (InMemoryStateInternals.InMemoryState>) + underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryStateInternals.InMemoryWatermarkHold<>( + outputTimeFn); + } + } + + @Override + public ValueState bindValue( + StateTag> address, Coder coder) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState> existingState = + (InMemoryStateInternals.InMemoryState>) + underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryStateInternals.InMemoryValue<>(); + } + } + + @Override + public AccumulatorCombiningState + bindCombiningValue( + StateTag> address, + Coder accumCoder, CombineFn combineFn) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState> + existingState = ( + InMemoryStateInternals + .InMemoryState>) underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryStateInternals.InMemoryCombiningValue<>( + key, combineFn.asKeyedFn()); + } + } + + @Override + public BagState bindBag( + StateTag> address, Coder elemCoder) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState> existingState = + (InMemoryStateInternals.InMemoryState>) + underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryStateInternals.InMemoryBag<>(); + } + } + + @Override + public AccumulatorCombiningState + bindKeyedCombiningValue( + 
StateTag> address, + Coder accumCoder, + KeyedCombineFn combineFn) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState> + existingState = ( + InMemoryStateInternals + .InMemoryState>) underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryStateInternals.InMemoryCombiningValue<>(key, combineFn); + } + } + + @Override + public AccumulatorCombiningState + bindKeyedCombiningValueWithContext( + StateTag> address, + Coder accumCoder, + KeyedCombineFnWithContext combineFn) { + return bindKeyedCombiningValue( + address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + } + }; + } + } + + /** + * {@link StateBinderFactory} that reads directly from the underlying table. Used during calls + * to {@link CopyOnAccessInMemoryStateTable#commit()} to read all values from + * the underlying table. + */ + private static class ReadThroughBinderFactory implements StateBinderFactory { + private final StateTable underlying; + + public ReadThroughBinderFactory(StateTable underlying) { + this.underlying = underlying; + } + + public Instant readThroughAndGetEarliestHold(StateTable readTo) { + Instant earliestHold = BoundedWindow.TIMESTAMP_MAX_VALUE; + for (StateNamespace namespace : underlying.getNamespacesInUse()) { + for (Map.Entry, ? extends State> existingState : + underlying.getTagsInUse(namespace).entrySet()) { + if (!((InMemoryState) existingState.getValue()).isCleared()) { + // Only read through non-cleared values to ensure that completed windows are + // eventually discarded, and remember the earliest watermark hold from among those + // values. + State state = + readTo.get(namespace, existingState.getKey(), StateContexts.nullContext()); + if (state instanceof WatermarkHoldState) { + Instant hold = ((WatermarkHoldState) state).read(); + if (hold != null && hold.isBefore(earliestHold)) { + earliestHold = hold; + } + } + } + } + } + return earliestHold; + } + + @Override + public StateBinder forNamespace(final StateNamespace namespace, final StateContext c) { + return new StateBinder() { + @Override + public WatermarkHoldState bindWatermark( + StateTag> address, + OutputTimeFn outputTimeFn) { + return underlying.get(namespace, address, c); + } + + @Override + public ValueState bindValue( + StateTag> address, Coder coder) { + return underlying.get(namespace, address, c); + } + + @Override + public AccumulatorCombiningState + bindCombiningValue( + StateTag> address, + Coder accumCoder, CombineFn combineFn) { + return underlying.get(namespace, address, c); + } + + @Override + public BagState bindBag( + StateTag> address, Coder elemCoder) { + return underlying.get(namespace, address, c); + } + + @Override + public AccumulatorCombiningState + bindKeyedCombiningValue( + StateTag> address, + Coder accumCoder, + KeyedCombineFn combineFn) { + return underlying.get(namespace, address, c); + } + + @Override + public AccumulatorCombiningState + bindKeyedCombiningValueWithContext( + StateTag> address, + Coder accumCoder, + KeyedCombineFnWithContext combineFn) { + return bindKeyedCombiningValue( + address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + } + }; + } + } + + private static class InMemoryStateBinderFactory implements StateBinderFactory { + private final K key; + + public InMemoryStateBinderFactory(K key) { + this.key = key; + } + + @Override + public StateBinder forNamespace(StateNamespace namespace, StateContext c) { + return new InMemoryStateInternals.InMemoryStateBinder<>(key, c); + } + } + 
} +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java new file mode 100644 index 000000000000..840480126022 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java @@ -0,0 +1,414 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.CombineFnUtil; +import com.google.cloud.dataflow.sdk.util.state.StateTag.StateBinder; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import javax.annotation.Nullable; + +/** + * In-memory implementation of {@link StateInternals}. Used in {@code BatchModeExecutionContext} + * and for running tests that need state. + */ +@Experimental(Kind.STATE) +public class InMemoryStateInternals implements StateInternals { + + public static InMemoryStateInternals forKey(K key) { + return new InMemoryStateInternals<>(key); + } + + private final K key; + + protected InMemoryStateInternals(K key) { + this.key = key; + } + + @Override + public K getKey() { + return key; + } + + interface InMemoryState> { + boolean isCleared(); + T copy(); + } + + protected final StateTable inMemoryState = new StateTable() { + @Override + protected StateBinder binderForNamespace(StateNamespace namespace, StateContext c) { + return new InMemoryStateBinder(key, c); + } + }; + + public void clear() { + inMemoryState.clear(); + } + + /** + * Return true if the given state is empty. This is used by the test framework to make sure + * that the state has been properly cleaned up. + */ + protected boolean isEmptyForTesting(State state) { + return ((InMemoryState) state).isCleared(); + } + + @Override + public T state(StateNamespace namespace, StateTag address) { + return inMemoryState.get(namespace, address, StateContexts.nullContext()); + } + + @Override + public T state( + StateNamespace namespace, StateTag address, final StateContext c) { + return inMemoryState.get(namespace, address, c); + } + + /** + * A {@link StateBinder} that returns In Memory {@link State} objects. 
+ */ + static class InMemoryStateBinder implements StateBinder { + private final K key; + private final StateContext c; + + InMemoryStateBinder(K key, StateContext c) { + this.key = key; + this.c = c; + } + + @Override + public ValueState bindValue( + StateTag> address, Coder coder) { + return new InMemoryValue(); + } + + @Override + public BagState bindBag( + final StateTag> address, Coder elemCoder) { + return new InMemoryBag(); + } + + @Override + public AccumulatorCombiningState + bindCombiningValue( + StateTag> address, + Coder accumCoder, + final CombineFn combineFn) { + return new InMemoryCombiningValue(key, combineFn.asKeyedFn()); + } + + @Override + public WatermarkHoldState bindWatermark( + StateTag> address, + OutputTimeFn outputTimeFn) { + return new InMemoryWatermarkHold(outputTimeFn); + } + + @Override + public AccumulatorCombiningState + bindKeyedCombiningValue( + StateTag> address, + Coder accumCoder, + KeyedCombineFn combineFn) { + return new InMemoryCombiningValue(key, combineFn); + } + + @Override + public AccumulatorCombiningState + bindKeyedCombiningValueWithContext( + StateTag> address, + Coder accumCoder, + KeyedCombineFnWithContext combineFn) { + return bindKeyedCombiningValue(address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + } + } + + static final class InMemoryValue implements ValueState, InMemoryState> { + private boolean isCleared = true; + private T value = null; + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this Value. + value = null; + isCleared = true; + } + + @Override + public InMemoryValue readLater() { + return this; + } + + @Override + public T read() { + return value; + } + + @Override + public void write(T input) { + isCleared = false; + this.value = input; + } + + @Override + public InMemoryValue copy() { + InMemoryValue that = new InMemoryValue<>(); + if (!this.isCleared) { + that.isCleared = this.isCleared; + that.value = this.value; + } + return that; + } + + @Override + public boolean isCleared() { + return isCleared; + } + } + + static final class InMemoryWatermarkHold + implements WatermarkHoldState, InMemoryState> { + + private final OutputTimeFn outputTimeFn; + + @Nullable + private Instant combinedHold = null; + + public InMemoryWatermarkHold(OutputTimeFn outputTimeFn) { + this.outputTimeFn = outputTimeFn; + } + + @Override + public InMemoryWatermarkHold readLater() { + return this; + } + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this WatermarkBagInternal. + combinedHold = null; + } + + @Override + public Instant read() { + return combinedHold; + } + + @Override + public void add(Instant outputTime) { + combinedHold = combinedHold == null ? 
outputTime + : outputTimeFn.combine(combinedHold, outputTime); + } + + @Override + public boolean isCleared() { + return combinedHold == null; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public ReadableState readLater() { + return this; + } + @Override + public Boolean read() { + return combinedHold == null; + } + }; + } + + @Override + public OutputTimeFn getOutputTimeFn() { + return outputTimeFn; + } + + @Override + public String toString() { + return Objects.toString(combinedHold); + } + + @Override + public InMemoryWatermarkHold copy() { + InMemoryWatermarkHold that = + new InMemoryWatermarkHold<>(outputTimeFn); + that.combinedHold = this.combinedHold; + return that; + } + } + + static final class InMemoryCombiningValue + implements AccumulatorCombiningState, + InMemoryState> { + private final K key; + private boolean isCleared = true; + private final KeyedCombineFn combineFn; + private AccumT accum; + + InMemoryCombiningValue( + K key, KeyedCombineFn combineFn) { + this.key = key; + this.combineFn = combineFn; + accum = combineFn.createAccumulator(key); + } + + @Override + public InMemoryCombiningValue readLater() { + return this; + } + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this CombiningValue. + accum = combineFn.createAccumulator(key); + isCleared = true; + } + + @Override + public OutputT read() { + return combineFn.extractOutput(key, accum); + } + + @Override + public void add(InputT input) { + isCleared = false; + accum = combineFn.addInput(key, accum, input); + } + + @Override + public AccumT getAccum() { + return accum; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public ReadableState readLater() { + return this; + } + @Override + public Boolean read() { + return isCleared; + } + }; + } + + @Override + public void addAccum(AccumT accum) { + isCleared = false; + this.accum = combineFn.mergeAccumulators(key, Arrays.asList(this.accum, accum)); + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(key, accumulators); + } + + @Override + public boolean isCleared() { + return isCleared; + } + + @Override + public InMemoryCombiningValue copy() { + InMemoryCombiningValue that = + new InMemoryCombiningValue<>(key, combineFn); + if (!this.isCleared) { + that.isCleared = this.isCleared; + that.addAccum(accum); + } + return that; + } + } + + static final class InMemoryBag implements BagState, InMemoryState> { + private List contents = new ArrayList<>(); + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this Bag. + // The result of get/read below must be stable for the lifetime of the bundle within which it + // was generated. In batch and direct runners the bundle lifetime can be + // greater than the window lifetime, in which case this method can be called while + // the result is still in use. We protect against this by hot-swapping instead of + // clearing the contents. 
+ contents = new ArrayList<>(); + } + + @Override + public InMemoryBag readLater() { + return this; + } + + @Override + public Iterable read() { + return contents; + } + + @Override + public void add(T input) { + contents.add(input); + } + + @Override + public boolean isCleared() { + return contents.isEmpty(); + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public ReadableState readLater() { + return this; + } + + @Override + public Boolean read() { + return contents.isEmpty(); + } + }; + } + + @Override + public InMemoryBag copy() { + InMemoryBag that = new InMemoryBag<>(); + that.contents.addAll(this.contents); + return that; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/MergingStateAccessor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/MergingStateAccessor.java new file mode 100644 index 000000000000..40211d739bea --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/MergingStateAccessor.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; + +import java.util.Map; + +/** + * Interface for accessing persistent state while windows are merging. + * + *
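For example, a runner might prefetch the bag state held in every window that is about to be + * merged; a minimal sketch, assuming a {@code context} of this type and a {@code bagTag} created + * elsewhere (both names are illustrative): + * <pre>{@code + * Map<W, BagState<T>> bags = context.accessInEachMergingWindow(bagTag); + * for (BagState<T> bag : bags.values()) { + *   bag.readLater(); + * } + * }</pre> + * + *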

    For internal use only. + */ +@Experimental(Kind.STATE) +public interface MergingStateAccessor + extends StateAccessor { + /** + * Analogous to {@link #access}, but returned as a map from each window which is + * about to be merged to the corresponding state. Only includes windows which + * are known to have state. + */ + Map accessInEachMergingWindow( + StateTag address); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/ReadableState.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/ReadableState.java new file mode 100644 index 000000000000..8f690a33de4f --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/ReadableState.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; + +/** + * A {@code StateContents} is produced by the read methods on all {@link State} objects. + * Calling {@link #read} returns the associated value. + * + *

This interface is similar to {@link java.util.concurrent.Future}, but each invocation of + * {@link #read} need not return the same value. + * + *

Getting the {@code StateContents} from a read method indicates the desire to eventually + * read a value. Depending on the runner this may or may not immediately start the read. + * + * @param <T> The type of value returned by {@link #read}. + */ +@Experimental(Kind.STATE) +public interface ReadableState<T> { + /** + * Read the current value, blocking until it is available. + * + *

If there will be many calls to {@link #read} for different state in short succession, + * you should first call {@link #readLater} for all of them so the reads can potentially be + * batched (depending on the underlying {@link StateInternals} implementation). + */ + T read(); + + /** + * Indicate that the value will be read later. + * + *
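For example, assuming a {@code BagState<String> bag} and a {@code ValueState<Integer> value} + * obtained from the same {@link StateInternals} (illustrative names): + * <pre>{@code + * bag.readLater(); + * value.readLater(); + * Iterable<String> elements = bag.read();  // both reads may now be served by one fetch + * Integer latest = value.read(); + * }</pre> + * + *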

    This allows a {@link StateInternals} implementation to start an asynchronous prefetch or + * to include this state in the next batch of reads. + * + * @return this for convenient chaining + */ + ReadableState readLater(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/State.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/State.java new file mode 100644 index 000000000000..0cef786ad5cf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/State.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +/** + * Base interface for all state locations. + * + *

    Specific types of state add appropriate accessors for reading and writing values, see + * {@link ValueState}, {@link BagState}, and {@link CombiningState}. + */ +public interface State { + + /** + * Clear out the state location. + */ + void clear(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateAccessor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateAccessor.java new file mode 100644 index 000000000000..6cfbecff7097 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateAccessor.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; + +/** + * Interface for accessing a {@link StateTag} in the current context. + * + *

For internal use only. + */ +@Experimental(Kind.STATE) +public interface StateAccessor<K> { + /** + * Access the storage for the given {@code address} in the current window. + * + *

    Never accounts for merged windows. When windows are merged, any state accessed via + * this method must be eagerly combined and written into the result window. + */ + StateT access(StateTag address); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java new file mode 100644 index 000000000000..96387d85084a --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +/** + * Information accessible the state API. + */ +public interface StateContext { + /** + * Returns the {@code PipelineOptions} specified with the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner}. + */ + public abstract PipelineOptions getPipelineOptions(); + + /** + * Returns the value of the side input for the corresponding state window. + */ + public abstract T sideInput(PCollectionView view); + + /** + * Returns the window corresponding to the state. + */ + public abstract W window(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java new file mode 100644 index 000000000000..e301d438cdf3 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import javax.annotation.Nullable; + +/** + * Factory that produces {@link StateContext} based on different inputs. 
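+ * + * For example, a context that carries only the state window can be obtained as in this sketch, + * assuming a {@code BoundedWindow window} is in scope (illustrative name): + * <pre>{@code + * StateContext<BoundedWindow> context = StateContexts.windowOnly(window); + * BoundedWindow stateWindow = context.window(); + * // getPipelineOptions() and sideInput(...) would throw in this context. + * }</pre>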
+ */ +public class StateContexts { + private static final StateContext NULL_CONTEXT = + new StateContext() { + @Override + public PipelineOptions getPipelineOptions() { + throw new IllegalArgumentException("cannot call getPipelineOptions() in a null context"); + } + + @Override + public T sideInput(PCollectionView view) { + throw new IllegalArgumentException("cannot call sideInput() in a null context"); + } + + @Override + public BoundedWindow window() { + throw new IllegalArgumentException("cannot call window() in a null context"); + }}; + + /** + * Returns a fake {@link StateContext}. + */ + @SuppressWarnings("unchecked") + public static StateContext nullContext() { + return (StateContext) NULL_CONTEXT; + } + + /** + * Returns a {@link StateContext} that only contains the state window. + */ + public static StateContext windowOnly(final W window) { + return new StateContext() { + @Override + public PipelineOptions getPipelineOptions() { + throw new IllegalArgumentException( + "cannot call getPipelineOptions() in a window only context"); + } + @Override + public T sideInput(PCollectionView view) { + throw new IllegalArgumentException("cannot call sideInput() in a window only context"); + } + @Override + public W window() { + return window; + } + }; + } + + /** + * Returns a {@link StateContext} from {@code PipelineOptions}, {@link WindowingInternals}, + * and the state window. + */ + public static StateContext createFromComponents( + @Nullable final PipelineOptions options, + final WindowingInternals windowingInternals, + final W window) { + @SuppressWarnings("unchecked") + StateContext typedNullContext = (StateContext) NULL_CONTEXT; + if (options == null) { + return typedNullContext; + } else { + return new StateContext() { + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public T sideInput(PCollectionView view) { + return windowingInternals.sideInput(view, window); + } + + @Override + public W window() { + return window; + } + }; + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java new file mode 100644 index 000000000000..b31afb469802 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; + +/** + * {@code StateInternals} describes the functionality a runner needs to provide for the + * State API to be supported. + * + *

The SDK will only use this after elements have been partitioned by key, for instance after a + * {@link GroupByKey} operation. The runner implementation must ensure that any writes using + * {@link StateInternals} are implicitly scoped to the key being processed and the specific step + * accessing state. + * + *

    The runner implementation must also ensure that any writes to the associated state objects + * are persisted together with the completion status of the processing that produced these + * writes. + * + *
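For example, a test might exercise a transform's state handling through the in-memory + * implementation; a minimal sketch, assuming a {@code namespace} (e.g. from StateNamespaces) and + * a {@code bagTag} created elsewhere, e.g. via the SDK's StateTags helpers (illustrative names): + * <pre>{@code + * StateInternals<String> internals = InMemoryStateInternals.forKey("key"); + * BagState<Integer> bag = internals.state(namespace, bagTag); + * bag.add(1); + * bag.add(2); + * Iterable<Integer> contents = bag.read(); + * }</pre> + * + *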

    This is a low-level API intended for use by the Dataflow SDK. It should not be + * used directly, and is highly likely to change. + */ +@Experimental(Kind.STATE) +public interface StateInternals { + + /** The key for this {@link StateInternals}. */ + K getKey(); + + /** + * Return the state associated with {@code address} in the specified {@code namespace}. + */ + T state(StateNamespace namespace, StateTag address); + + /** + * Return the state associated with {@code address} in the specified {@code namespace} + * with the {@link StateContext}. + */ + T state( + StateNamespace namespace, StateTag address, StateContext c); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateMerging.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateMerging.java new file mode 100644 index 000000000000..0b33ea9b8e61 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateMerging.java @@ -0,0 +1,254 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.common.base.Preconditions; + +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +/** + * Helpers for merging state. + */ +public class StateMerging { + /** + * Clear all state in {@code address} in all windows under merge (even result windows) + * in {@code context}. + */ + public static void clear( + MergingStateAccessor context, StateTag address) { + for (StateT state : context.accessInEachMergingWindow(address).values()) { + state.clear(); + } + } + + /** + * Prefetch all bag state in {@code address} across all windows under merge in + * {@code context}, except for the bag state in the final state address window which we can + * blindly append to. + */ + public static void prefetchBags( + MergingStateAccessor context, StateTag> address) { + Map> map = context.accessInEachMergingWindow(address); + if (map.isEmpty()) { + // Nothing to prefetch. + return; + } + BagState result = context.access(address); + // Prefetch everything except what's already in result. + for (BagState source : map.values()) { + if (!source.equals(result)) { + source.readLater(); + } + } + } + + /** + * Merge all bag state in {@code address} across all windows under merge. + */ + public static void mergeBags( + MergingStateAccessor context, StateTag> address) { + mergeBags(context.accessInEachMergingWindow(address).values(), context.access(address)); + } + + /** + * Merge all bag state in {@code sources} (which may include {@code result}) into {@code result}. + */ + public static void mergeBags( + Collection> sources, BagState result) { + if (sources.isEmpty()) { + // Nothing to merge. + return; + } + // Prefetch everything except what's already in result. 
+ List>> futures = new ArrayList<>(sources.size()); + for (BagState source : sources) { + if (!source.equals(result)) { + source.readLater(); + futures.add(source); + } + } + if (futures.isEmpty()) { + // Result already holds all the values. + return; + } + // Transfer from sources to result. + for (ReadableState> future : futures) { + for (T element : future.read()) { + result.add(element); + } + } + // Clear sources except for result. + for (BagState source : sources) { + if (!source.equals(result)) { + source.clear(); + } + } + } + + /** + * Prefetch all combining value state for {@code address} across all merging windows in {@code + * context}. + */ + public static , W extends BoundedWindow> void + prefetchCombiningValues(MergingStateAccessor context, + StateTag address) { + for (StateT state : context.accessInEachMergingWindow(address).values()) { + state.readLater(); + } + } + + /** + * Merge all value state in {@code address} across all merging windows in {@code context}. + */ + public static void mergeCombiningValues( + MergingStateAccessor context, + StateTag> address) { + mergeCombiningValues( + context.accessInEachMergingWindow(address).values(), context.access(address)); + } + + /** + * Merge all value state from {@code sources} (which may include {@code result}) into + * {@code result}. + */ + public static void mergeCombiningValues( + Collection> sources, + AccumulatorCombiningState result) { + if (sources.isEmpty()) { + // Nothing to merge. + return; + } + if (sources.size() == 1 && sources.contains(result)) { + // Result already holds combined value. + return; + } + // Prefetch. + List> futures = new ArrayList<>(sources.size()); + for (AccumulatorCombiningState source : sources) { + source.readLater(); + } + // Read. + List accumulators = new ArrayList<>(futures.size()); + for (AccumulatorCombiningState source : sources) { + accumulators.add(source.getAccum()); + } + // Merge (possibly update and return one of the existing accumulators). + AccumT merged = result.mergeAccumulators(accumulators); + // Clear sources. + for (AccumulatorCombiningState source : sources) { + source.clear(); + } + // Update result. + result.addAccum(merged); + } + + /** + * Prefetch all watermark state for {@code address} across all merging windows in + * {@code context}. + */ + public static void prefetchWatermarks( + MergingStateAccessor context, + StateTag> address) { + Map> map = context.accessInEachMergingWindow(address); + WatermarkHoldState result = context.access(address); + if (map.isEmpty()) { + // Nothing to prefetch. + return; + } + if (map.size() == 1 && map.values().contains(result) + && result.getOutputTimeFn().dependsOnlyOnEarliestInputTimestamp()) { + // Nothing to change. + return; + } + if (result.getOutputTimeFn().dependsOnlyOnWindow()) { + // No need to read existing holds. + return; + } + // Prefetch. + for (WatermarkHoldState source : map.values()) { + source.readLater(); + } + } + + /** + * Merge all watermark state in {@code address} across all merging windows in {@code context}, + * where the final merge result window is {@code mergeResult}. + */ + public static void mergeWatermarks( + MergingStateAccessor context, + StateTag> address, + W mergeResult) { + mergeWatermarks( + context.accessInEachMergingWindow(address).values(), context.access(address), mergeResult); + } + + /** + * Merge all watermark state in {@code sources} (which must include {@code result} if non-empty) + * into {@code result}, where the final merge result window is {@code mergeResult}. 
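+ * + * In a runner, the {@link MergingStateAccessor} overload above is the usual entry point; a rough + * sketch of an onMerge step, assuming a merging {@code context}, a {@code holdTag}, a + * {@code bagTag}, and the {@code mergeResult} window (illustrative names): + * <pre>{@code + * StateMerging.prefetchWatermarks(context, holdTag); + * StateMerging.mergeWatermarks(context, holdTag, mergeResult); + * StateMerging.mergeBags(context, bagTag); + * }</pre>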
+ */ + public static void mergeWatermarks( + Collection> sources, WatermarkHoldState result, + W resultWindow) { + if (sources.isEmpty()) { + // Nothing to merge. + return; + } + if (sources.size() == 1 && sources.contains(result) + && result.getOutputTimeFn().dependsOnlyOnEarliestInputTimestamp()) { + // Nothing to merge. + return; + } + if (result.getOutputTimeFn().dependsOnlyOnWindow()) { + // Clear sources. + for (WatermarkHoldState source : sources) { + source.clear(); + } + // Update directly from window-derived hold. + Instant hold = result.getOutputTimeFn().assignOutputTime( + BoundedWindow.TIMESTAMP_MIN_VALUE, resultWindow); + Preconditions.checkState(hold.isAfter(BoundedWindow.TIMESTAMP_MIN_VALUE)); + result.add(hold); + } else { + // Prefetch. + List> futures = new ArrayList<>(sources.size()); + for (WatermarkHoldState source : sources) { + futures.add(source); + } + // Read. + List outputTimesToMerge = new ArrayList<>(sources.size()); + for (ReadableState future : futures) { + Instant sourceOutputTime = future.read(); + if (sourceOutputTime != null) { + outputTimesToMerge.add(sourceOutputTime); + } + } + // Clear sources. + for (WatermarkHoldState source : sources) { + source.clear(); + } + if (!outputTimesToMerge.isEmpty()) { + // Merge and update. + result.add(result.getOutputTimeFn().merge(resultWindow, outputTimesToMerge)); + } + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespace.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespace.java new file mode 100644 index 000000000000..f972e312f9ee --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespace.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import java.io.IOException; + +/** + * A namespace used for scoping state stored with {@link StateInternals}. + * + *

    Instances of {@code StateNamespace} are guaranteed to have a {@link #hashCode} and + * {@link #equals} that uniquely identify the namespace. + */ +public interface StateNamespace { + + /** + * Return a {@link String} representation of the key. It is guaranteed that this + * {@code String} will uniquely identify the key. + * + *

    This will encode the actual namespace as a {@code String}. It is + * preferable to use the {@code StateNamespace} object when possible. + * + *

    The string produced by the standard implementations will not contain a '+' character. This + * enables adding a '+' between the actual namespace and other information, if needed, to separate + * the two. + */ + String stringKey(); + + /** + * Append the string representation of this key to the {@link Appendable}. + */ + void appendTo(Appendable sb) throws IOException; + + /** + * Return an {@code Object} to use as a key in a cache. + * + *

    Different namespaces may use the same key in order to be treated as a unit in the cache. + * The {@code Object}'s {@code hashCode} and {@code equals} methods will be used to determine + * equality. + */ + Object getCacheKey(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespaceForTest.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespaceForTest.java new file mode 100644 index 000000000000..09b86d67e9bf --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespaceForTest.java @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import java.io.IOException; +import java.util.Objects; + +/** + * A simple {@link StateNamespace} used for testing. + */ +public class StateNamespaceForTest implements StateNamespace { + private String key; + + public StateNamespaceForTest(String key) { + this.key = key; + } + + @Override + public String stringKey() { + return key; + } + + @Override + public Object getCacheKey() { + return key; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof StateNamespaceForTest)) { + return false; + } + + return Objects.equals(this.key, ((StateNamespaceForTest) obj).key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + + @Override + public void appendTo(Appendable sb) throws IOException { + sb.append(key); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespaces.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespaces.java new file mode 100644 index 000000000000..8fee9959b944 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateNamespaces.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.base.Splitter; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * Factory methods for creating the {@link StateNamespace StateNamespaces}. 
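+ * + * A short sketch of the three factories, assuming a {@code windowCoder} and a {@code window} for + * the window type in use (illustrative names): + * <pre>{@code + * StateNamespace global = StateNamespaces.global(); + * StateNamespace perWindow = StateNamespaces.window(windowCoder, window); + * StateNamespace perTrigger = StateNamespaces.windowAndTrigger(windowCoder, window, 0); + * }</pre>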
+ */ +public class StateNamespaces { + + private enum Namespace { + GLOBAL, + WINDOW, + WINDOW_AND_TRIGGER; + } + + public static StateNamespace global() { + return new GlobalNamespace(); + } + + public static StateNamespace window(Coder windowCoder, W window) { + return new WindowNamespace<>(windowCoder, window); + } + + public static + StateNamespace windowAndTrigger(Coder windowCoder, W window, int triggerIdx) { + return new WindowAndTriggerNamespace<>(windowCoder, window, triggerIdx); + } + + private StateNamespaces() {} + + /** + * {@link StateNamespace} that is global to the current key being processed. + */ + public static class GlobalNamespace implements StateNamespace { + + private static final String GLOBAL_STRING = "/"; + + @Override + public String stringKey() { + return GLOBAL_STRING; + } + + @Override + public Object getCacheKey() { + return GLOBAL_STRING; + } + + @Override + public boolean equals(Object obj) { + return obj == this || obj instanceof GlobalNamespace; + } + + @Override + public int hashCode() { + return Objects.hash(Namespace.GLOBAL); + } + + @Override + public String toString() { + return "Global"; + } + + @Override + public void appendTo(Appendable sb) throws IOException { + sb.append(GLOBAL_STRING); + } + } + + /** + * {@link StateNamespace} that is scoped to a specific window. + */ + public static class WindowNamespace implements StateNamespace { + + private static final String WINDOW_FORMAT = "/%s/"; + + private Coder windowCoder; + private W window; + + private WindowNamespace(Coder windowCoder, W window) { + this.windowCoder = windowCoder; + this.window = window; + } + + public W getWindow() { + return window; + } + + @Override + public String stringKey() { + try { + return String.format(WINDOW_FORMAT, CoderUtils.encodeToBase64(windowCoder, window)); + } catch (CoderException e) { + throw new RuntimeException("Unable to generate string key from window " + window, e); + } + } + + @Override + public void appendTo(Appendable sb) throws IOException { + sb.append('/').append(CoderUtils.encodeToBase64(windowCoder, window)).append('/'); + } + + /** + * State in the same window will all be evicted together. + */ + @Override + public Object getCacheKey() { + return window; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof WindowNamespace)) { + return false; + } + + WindowNamespace that = (WindowNamespace) obj; + return Objects.equals(this.window, that.window); + } + + @Override + public int hashCode() { + return Objects.hash(Namespace.WINDOW, window); + } + + @Override + public String toString() { + return "Window(" + window + ")"; + } + } + + /** + * {@link StateNamespace} that is scoped to a particular window and trigger index. 
+ */ + public static class WindowAndTriggerNamespace + implements StateNamespace { + + private static final String WINDOW_AND_TRIGGER_FORMAT = "/%s/%s/"; + + private static final int TRIGGER_RADIX = 36; + private Coder windowCoder; + private W window; + private int triggerIndex; + + private WindowAndTriggerNamespace(Coder windowCoder, W window, int triggerIndex) { + this.windowCoder = windowCoder; + this.window = window; + this.triggerIndex = triggerIndex; + } + + public W getWindow() { + return window; + } + + public int getTriggerIndex() { + return triggerIndex; + } + + @Override + public String stringKey() { + try { + return String.format(WINDOW_AND_TRIGGER_FORMAT, + CoderUtils.encodeToBase64(windowCoder, window), + // Use base 36 so that can address 36 triggers in a single byte and still be human + // readable. + Integer.toString(triggerIndex, TRIGGER_RADIX).toUpperCase()); + } catch (CoderException e) { + throw new RuntimeException("Unable to generate string key from window " + window, e); + } + } + + @Override + public void appendTo(Appendable sb) throws IOException { + sb.append('/').append(CoderUtils.encodeToBase64(windowCoder, window)); + sb.append('/').append(Integer.toString(triggerIndex, TRIGGER_RADIX).toUpperCase()); + sb.append('/'); + } + + /** + * State in the same window will all be evicted together. + */ + @Override + public Object getCacheKey() { + return window; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof WindowAndTriggerNamespace)) { + return false; + } + + WindowAndTriggerNamespace that = (WindowAndTriggerNamespace) obj; + return this.triggerIndex == that.triggerIndex + && Objects.equals(this.window, that.window); + } + + @Override + public int hashCode() { + return Objects.hash(Namespace.WINDOW_AND_TRIGGER, window, triggerIndex); + } + + @Override + public String toString() { + return "WindowAndTrigger(" + window + "," + triggerIndex + ")"; + } + } + + private static final Splitter SLASH_SPLITTER = Splitter.on('/'); + + /** + * Convert a {@code stringKey} produced using {@link StateNamespace#stringKey} + * on one of the namespaces produced by this class into the original + * {@link StateNamespace}. 
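+ *
+ * For example, a global-namespace key round-trips as follows (a sketch; the
+ * {@code GlobalWindow.Coder} used here is just one window coder from this SDK,
+ * chosen only for illustration):
+ * {@code
+ * StateNamespace ns = StateNamespaces.global();
+ * String key = ns.stringKey();   // "/"
+ * StateNamespace same = StateNamespaces.fromString(key, GlobalWindow.Coder.INSTANCE);
+ * }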
+ */ + public static StateNamespace fromString( + String stringKey, Coder windowCoder) { + if (!stringKey.startsWith("/") || !stringKey.endsWith("/")) { + throw new RuntimeException("Invalid namespace string: '" + stringKey + "'"); + } + + if (GlobalNamespace.GLOBAL_STRING.equals(stringKey)) { + return global(); + } + + List parts = SLASH_SPLITTER.splitToList(stringKey); + if (parts.size() != 3 && parts.size() != 4) { + throw new RuntimeException("Invalid namespace string: '" + stringKey + "'"); + } + // Ends should be empty (we start and end with /) + if (!parts.get(0).isEmpty() || !parts.get(parts.size() - 1).isEmpty()) { + throw new RuntimeException("Invalid namespace string: '" + stringKey + "'"); + } + + try { + W window = CoderUtils.decodeFromBase64(windowCoder, parts.get(1)); + if (parts.size() > 3) { + int index = Integer.parseInt(parts.get(2), WindowAndTriggerNamespace.TRIGGER_RADIX); + return windowAndTrigger(windowCoder, window, index); + } else { + return window(windowCoder, window); + } + } catch (Exception e) { + throw new RuntimeException("Invalid namespace string: '" + stringKey + "'", e); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java new file mode 100644 index 000000000000..edd1dae279e9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.util.state.StateTag.StateBinder; +import com.google.common.base.Supplier; +import com.google.common.collect.Table; +import com.google.common.collect.Tables; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Table mapping {@code StateNamespace} and {@code StateTag} to a {@code State} instance. + */ +public abstract class StateTable { + + private final Table, State> stateTable = + Tables.newCustomTable(new HashMap, State>>(), + new Supplier, State>>() { + @Override + public Map, State> get() { + return new HashMap<>(); + } + }); + + /** + * Gets the {@link State} in the specified {@link StateNamespace} with the specified {@link + * StateTag}, binding it using the {@link #binderForNamespace} if it is not + * already present in this {@link StateTable}. 
+ */ + public StateT get( + StateNamespace namespace, StateTag tag, StateContext c) { + State storage = stateTable.get(namespace, tag); + if (storage != null) { + @SuppressWarnings("unchecked") + StateT typedStorage = (StateT) storage; + return typedStorage; + } + + StateT typedStorage = tag.bind(binderForNamespace(namespace, c)); + stateTable.put(namespace, tag, typedStorage); + return typedStorage; + } + + public void clearNamespace(StateNamespace namespace) { + stateTable.rowKeySet().remove(namespace); + } + + public void clear() { + stateTable.clear(); + } + + public Iterable values() { + return stateTable.values(); + } + + public boolean isNamespaceInUse(StateNamespace namespace) { + return stateTable.containsRow(namespace); + } + + public Map, State> getTagsInUse(StateNamespace namespace) { + return stateTable.row(namespace); + } + + public Set getNamespacesInUse() { + return stateTable.rowKeySet(); + } + + /** + * Provide the {@code StateBinder} to use for creating {@code Storage} instances + * in the specified {@code namespace}. + */ + protected abstract StateBinder binderForNamespace(StateNamespace namespace, StateContext c); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java new file mode 100644 index 000000000000..c87bdb788c36 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; + +import java.io.IOException; +import java.io.Serializable; + +/** + * An address for persistent state. This includes a unique identifier for the location, the + * information necessary to encode the value, and details about the intended access pattern. + * + *

    State can be thought of as a sparse table, with each {@code StateTag} defining a column + * that has cells of type {@code StateT}. + * + *
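+ *
+ * As a sketch, two tags declare two independent columns for the same key (the
+ * ids and coders below are illustrative only):
+ * {@code
+ * StateTags.value("count", VarIntCoder.of());      // a column of ValueState cells
+ * StateTags.bag("buffer", StringUtf8Coder.of());   // a column of BagState cells
+ * }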

    Currently, this can only be used in a step immediately following a {@link GroupByKey}. + * + * @param The type of key that must be used with the state tag. Contravariant: methods should + * accept values of type {@code KeyedStateTag}. + * @param The type of state being tagged. + */ +@Experimental(Kind.STATE) +public interface StateTag extends Serializable { + + /** + * Visitor for binding a {@link StateTag} and to the associated {@link State}. + * + * @param the type of key this binder embodies. + */ + public interface StateBinder { + ValueState bindValue(StateTag> address, Coder coder); + + BagState bindBag(StateTag> address, Coder elemCoder); + + AccumulatorCombiningState + bindCombiningValue( + StateTag> address, + Coder accumCoder, CombineFn combineFn); + + AccumulatorCombiningState + bindKeyedCombiningValue( + StateTag> address, + Coder accumCoder, KeyedCombineFn combineFn); + + AccumulatorCombiningState + bindKeyedCombiningValueWithContext( + StateTag> address, + Coder accumCoder, + KeyedCombineFnWithContext combineFn); + + /** + * Bind to a watermark {@link StateTag}. + * + *

    This accepts the {@link OutputTimeFn} that dictates how watermark hold timestamps + * added to the returned {@link WatermarkHoldState} are to be combined. + */ + WatermarkHoldState bindWatermark( + StateTag> address, + OutputTimeFn outputTimeFn); + } + + /** Append the UTF-8 encoding of this tag to the given {@link Appendable}. */ + void appendTo(Appendable sb) throws IOException; + + /** + * Returns the user-provided name of this state cell. + */ + String getId(); + + /** + * Use the {@code binder} to create an instance of {@code StateT} appropriate for this address. + */ + StateT bind(StateBinder binder); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java new file mode 100644 index 000000000000..0cbaa5236922 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java @@ -0,0 +1,569 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.common.base.MoreObjects; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Objects; + +/** + * Static utility methods for creating {@link StateTag} instances. + */ +@Experimental(Kind.STATE) +public class StateTags { + + private static final CoderRegistry STANDARD_REGISTRY = new CoderRegistry(); + + static { + STANDARD_REGISTRY.registerStandardCoders(); + } + + private enum StateKind { + SYSTEM('s'), + USER('u'); + + private char prefix; + + StateKind(char prefix) { + this.prefix = prefix; + } + } + + private StateTags() { } + + private interface SystemStateTag { + StateTag asKind(StateKind kind); + } + + /** + * Create a simple state tag for values of type {@code T}. + */ + public static StateTag> value(String id, Coder valueCoder) { + return new ValueStateTag<>(new StructuredId(id), valueCoder); + } + + /** + * Create a state tag for values that use a {@link CombineFn} to automatically merge + * multiple {@code InputT}s into a single {@code OutputT}. 
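+ *
+ * A sketch of creating such a tag (the accumulator type, its coder, and the
+ * {@code CombineFn} are placeholders, not prescribed by this method):
+ * {@code
+ * CombineFn<Integer, List<Integer>, Integer> sumFn = ...;
+ * StateTags.combiningValue("sum", ListCoder.of(VarIntCoder.of()), sumFn);
+ * }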
+ */ + public static + StateTag> + combiningValue( + String id, Coder accumCoder, CombineFn combineFn) { + return combiningValueInternal(id, accumCoder, combineFn); + } + + /** + * Create a state tag for values that use a {@link KeyedCombineFn} to automatically merge + * multiple {@code InputT}s into a single {@code OutputT}. The key provided to the + * {@link KeyedCombineFn} comes from the keyed {@link StateAccessor}. + */ + public static StateTag> + keyedCombiningValue(String id, Coder accumCoder, + KeyedCombineFn combineFn) { + return keyedCombiningValueInternal(id, accumCoder, combineFn); + } + + /** + * Create a state tag for values that use a {@link KeyedCombineFnWithContext} to automatically + * merge multiple {@code InputT}s into a single {@code OutputT}. The key provided to the + * {@link KeyedCombineFn} comes from the keyed {@link StateAccessor}, the context provided comes + * from the {@link StateContext}. + */ + public static + StateTag> + keyedCombiningValueWithContext( + String id, + Coder accumCoder, + KeyedCombineFnWithContext combineFn) { + return new KeyedCombiningValueWithContextStateTag( + new StructuredId(id), + accumCoder, + combineFn); + } + + /** + * Create a state tag for values that use a {@link CombineFn} to automatically merge + * multiple {@code InputT}s into a single {@code OutputT}. + * + *

    This determines the {@code Coder} from the given {@code Coder}, and + * should only be used to initialize static values. + */ + public static + StateTag> + combiningValueFromInputInternal( + String id, Coder inputCoder, CombineFn combineFn) { + try { + Coder accumCoder = combineFn.getAccumulatorCoder(STANDARD_REGISTRY, inputCoder); + return combiningValueInternal(id, accumCoder, combineFn); + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException( + "Unable to determine accumulator coder for " + combineFn.getClass().getSimpleName() + + " from " + inputCoder, e); + } + } + + private static StateTag> + combiningValueInternal( + String id, Coder accumCoder, CombineFn combineFn) { + return + new CombiningValueStateTag( + new StructuredId(id), accumCoder, combineFn); + } + + private static + StateTag> keyedCombiningValueInternal( + String id, + Coder accumCoder, + KeyedCombineFn combineFn) { + return new KeyedCombiningValueStateTag( + new StructuredId(id), accumCoder, combineFn); + } + + /** + * Create a state tag that is optimized for adding values frequently, and + * occasionally retrieving all the values that have been added. + */ + public static StateTag> bag(String id, Coder elemCoder) { + return new BagStateTag(new StructuredId(id), elemCoder); + } + + /** + * Create a state tag for holding the watermark. + */ + public static StateTag> + watermarkStateInternal(String id, OutputTimeFn outputTimeFn) { + return new WatermarkStateTagInternal(new StructuredId(id), outputTimeFn); + } + + /** + * Convert an arbitrary {@link StateTag} to a system-internal tag that is guaranteed not to + * collide with any user tags. + */ + public static StateTag makeSystemTagInternal( + StateTag tag) { + if (!(tag instanceof SystemStateTag)) { + throw new IllegalArgumentException("Expected subclass of StateTagBase, got " + tag); + } + // Checked above + @SuppressWarnings("unchecked") + SystemStateTag typedTag = (SystemStateTag) tag; + return typedTag.asKind(StateKind.SYSTEM); + } + + public static StateTag> + convertToBagTagInternal( + StateTag> combiningTag) { + if (!(combiningTag instanceof KeyedCombiningValueStateTag)) { + throw new IllegalArgumentException("Unexpected StateTag " + combiningTag); + } + // Checked above; conversion to a bag tag depends on the provided tag being one of those + // created via the factory methods in this class. 
+ @SuppressWarnings("unchecked") + KeyedCombiningValueStateTag typedTag = + (KeyedCombiningValueStateTag) combiningTag; + return typedTag.asBagTag(); + } + + private static class StructuredId implements Serializable { + private final StateKind kind; + private final String rawId; + + private StructuredId(String rawId) { + this(StateKind.USER, rawId); + } + + private StructuredId(StateKind kind, String rawId) { + this.kind = kind; + this.rawId = rawId; + } + + public StructuredId asKind(StateKind kind) { + return new StructuredId(kind, rawId); + } + + public void appendTo(Appendable sb) throws IOException { + sb.append(kind.prefix).append(rawId); + } + + public String getRawId() { + return rawId; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("id", rawId) + .add("kind", kind) + .toString(); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof StructuredId)) { + return false; + } + + StructuredId that = (StructuredId) obj; + return Objects.equals(this.kind, that.kind) + && Objects.equals(this.rawId, that.rawId); + } + + @Override + public int hashCode() { + return Objects.hash(kind, rawId); + } + } + + /** + * A base class that just manages the structured ids. + */ + private abstract static class StateTagBase + implements StateTag, SystemStateTag { + + protected final StructuredId id; + + protected StateTagBase(StructuredId id) { + this.id = id; + } + + @Override + public String getId() { + return id.getRawId(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("id", id) + .toString(); + } + + @Override + public void appendTo(Appendable sb) throws IOException { + id.appendTo(sb); + } + + @Override + public abstract StateTag asKind(StateKind kind); + } + + /** + * A value state cell for values of type {@code T}. + * + * @param the type of value being stored + */ + private static class ValueStateTag extends StateTagBase> + implements StateTag> { + + private final Coder coder; + + private ValueStateTag(StructuredId id, Coder coder) { + super(id); + this.coder = coder; + } + + @Override + public ValueState bind(StateBinder visitor) { + return visitor.bindValue(this, coder); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof ValueStateTag)) { + return false; + } + + ValueStateTag that = (ValueStateTag) obj; + return Objects.equals(this.id, that.id) + && Objects.equals(this.coder, that.coder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), id, coder); + } + + @Override + public StateTag> asKind(StateKind kind) { + return new ValueStateTag(id.asKind(kind), coder); + } + } + + /** + * A state cell for values that are combined according to a {@link CombineFn}. 
+ * + * @param the type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + private static class CombiningValueStateTag + extends KeyedCombiningValueStateTag + implements StateTag>, + SystemStateTag> { + + private final Coder accumCoder; + private final CombineFn combineFn; + + private CombiningValueStateTag( + StructuredId id, + Coder accumCoder, CombineFn combineFn) { + super(id, accumCoder, combineFn.asKeyedFn()); + this.combineFn = combineFn; + this.accumCoder = accumCoder; + } + + @Override + public StateTag> + asKind(StateKind kind) { + return new CombiningValueStateTag( + id.asKind(kind), accumCoder, combineFn); + } + } + + /** + * A state cell for values that are combined according to a {@link KeyedCombineFnWithContext}. + * + * @param the type of keys + * @param the type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + private static class KeyedCombiningValueWithContextStateTag + extends StateTagBase> + implements SystemStateTag> { + + private final Coder accumCoder; + private final KeyedCombineFnWithContext combineFn; + + protected KeyedCombiningValueWithContextStateTag( + StructuredId id, + Coder accumCoder, + KeyedCombineFnWithContext combineFn) { + super(id); + this.combineFn = combineFn; + this.accumCoder = accumCoder; + } + + @Override + public AccumulatorCombiningState bind( + StateBinder visitor) { + return visitor.bindKeyedCombiningValueWithContext(this, accumCoder, combineFn); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof KeyedCombiningValueWithContextStateTag)) { + return false; + } + + KeyedCombiningValueWithContextStateTag that = + (KeyedCombiningValueWithContextStateTag) obj; + return Objects.equals(this.id, that.id) + && Objects.equals(this.accumCoder, that.accumCoder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), id, accumCoder); + } + + @Override + public StateTag> asKind( + StateKind kind) { + return new KeyedCombiningValueWithContextStateTag<>( + id.asKind(kind), accumCoder, combineFn); + } + } + + /** + * A state cell for values that are combined according to a {@link KeyedCombineFn}. 
+ * + * @param the type of keys + * @param the type of input values + * @param type of mutable accumulator values + * @param type of output values + */ + private static class KeyedCombiningValueStateTag + extends StateTagBase> + implements SystemStateTag> { + + private final Coder accumCoder; + private final KeyedCombineFn keyedCombineFn; + + protected KeyedCombiningValueStateTag( + StructuredId id, + Coder accumCoder, KeyedCombineFn keyedCombineFn) { + super(id); + this.keyedCombineFn = keyedCombineFn; + this.accumCoder = accumCoder; + } + + @Override + public AccumulatorCombiningState bind( + StateBinder visitor) { + return visitor.bindKeyedCombiningValue(this, accumCoder, keyedCombineFn); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof CombiningValueStateTag)) { + return false; + } + + KeyedCombiningValueStateTag that = (KeyedCombiningValueStateTag) obj; + return Objects.equals(this.id, that.id) + && Objects.equals(this.accumCoder, that.accumCoder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), id, accumCoder); + } + + @Override + public StateTag> asKind( + StateKind kind) { + return new KeyedCombiningValueStateTag<>(id.asKind(kind), accumCoder, keyedCombineFn); + } + + private StateTag> asBagTag() { + return new BagStateTag(id, accumCoder); + } + } + + /** + * A state cell optimized for bag-like access patterns (frequent additions, occasional reads + * of all the values). + * + * @param the type of value in the bag + */ + private static class BagStateTag extends StateTagBase> + implements StateTag>{ + + private final Coder elemCoder; + + private BagStateTag(StructuredId id, Coder elemCoder) { + super(id); + this.elemCoder = elemCoder; + } + + @Override + public BagState bind(StateBinder visitor) { + return visitor.bindBag(this, elemCoder); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof BagStateTag)) { + return false; + } + + BagStateTag that = (BagStateTag) obj; + return Objects.equals(this.id, that.id) + && Objects.equals(this.elemCoder, that.elemCoder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), id, elemCoder); + } + + @Override + public StateTag> asKind(StateKind kind) { + return new BagStateTag<>(id.asKind(kind), elemCoder); + } + } + + private static class WatermarkStateTagInternal + extends StateTagBase> { + + /** + * When multiple output times are added to hold the watermark, this determines how they are + * combined, and also the behavior when merging windows. Does not contribute to equality/hash + * since we have at most one watermark hold tag per computation. 
+ */ + private final OutputTimeFn outputTimeFn; + + private WatermarkStateTagInternal(StructuredId id, OutputTimeFn outputTimeFn) { + super(id); + this.outputTimeFn = outputTimeFn; + } + + @Override + public WatermarkHoldState bind(StateBinder visitor) { + return visitor.bindWatermark(this, outputTimeFn); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof WatermarkStateTagInternal)) { + return false; + } + + WatermarkStateTagInternal that = (WatermarkStateTagInternal) obj; + return Objects.equals(this.id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), id); + } + + @Override + public StateTag> asKind(StateKind kind) { + return new WatermarkStateTagInternal(id.asKind(kind), outputTimeFn); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/ValueState.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/ValueState.java new file mode 100644 index 000000000000..19c12bb164a9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/ValueState.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; + +/** + * State holding a single value. + * + * @param The type of values being stored. + */ +@Experimental(Kind.STATE) +public interface ValueState extends ReadableState, State { + /** + * Set the value of the buffer. + */ + void write(T input); + + @Override + ValueState readLater(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/WatermarkHoldState.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/WatermarkHoldState.java new file mode 100644 index 000000000000..8a1adc95585b --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/WatermarkHoldState.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; + +import org.joda.time.Instant; + +/** + * A {@link State} accepting and aggregating output timestamps, which determines + * the time to which the output watermark must be held. + * + *

    For internal use only. This API may change at any time. + */ +@Experimental(Kind.STATE) +public interface WatermarkHoldState + extends CombiningState { + /** + * Return the {@link OutputTimeFn} which will be used to determine a watermark hold time given + * an element timestamp, and to combine watermarks from windows which are about to be merged. + */ + OutputTimeFn getOutputTimeFn(); + + @Override + WatermarkHoldState readLater(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/KV.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/KV.java new file mode 100644 index 000000000000..23cee07cfe05 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/KV.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.SerializableComparator; +import com.google.common.base.MoreObjects; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Objects; + +/** + * An immutable key/value pair. + * + *
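+ * For example (a minimal sketch):
+ * {@code
+ * KV<String, Integer> pair = KV.of("user", 3);
+ * String k = pair.getKey();     // "user"
+ * Integer v = pair.getValue();  // 3
+ * }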

    Various {@link PTransform PTransforms} like {@link GroupByKey} and {@link Combine#perKey} + * operate on {@link PCollection PCollections} of {@link KV KVs}. + * + * @param the type of the key + * @param the type of the value + */ +public class KV implements Serializable { + /** Returns a {@link KV} with the given key and value. */ + public static KV of(K key, V value) { + return new KV<>(key, value); + } + + /** Returns the key of this {@link KV}. */ + public K getKey() { + return key; + } + + /** Returns the value of this {@link KV}. */ + public V getValue() { + return value; + } + + + ///////////////////////////////////////////////////////////////////////////// + + final K key; + final V value; + + private KV(K key, V value) { + this.key = key; + this.value = value; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof KV)) { + return false; + } + KV otherKv = (KV) other; + // Arrays are very common as values and keys, so deepEquals is mandatory + return Objects.deepEquals(this.key, otherKv.key) + && Objects.deepEquals(this.value, otherKv.value); + } + + /** + * A {@link Comparator} that orders {@link KV KVs} by the natural ordering of their keys. + * + *

    A {@code null} key is less than any non-{@code null} key. + */ + public static class OrderByKey, V> implements + SerializableComparator> { + @Override + public int compare(KV a, KV b) { + if (a.key == null) { + return b.key == null ? 0 : -1; + } else if (b.key == null) { + return 1; + } else { + return a.key.compareTo(b.key); + } + } + } + + /** + * A {@link Comparator} that orders {@link KV KVs} by the natural ordering of their values. + * + *
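+ * For instance, a list of KVs can be sorted by value (a sketch):
+ * {@code
+ * List<KV<String, Integer>> kvs = ...;
+ * Collections.sort(kvs, new KV.OrderByValue<String, Integer>());
+ * }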

    A {@code null} value is less than any non-{@code null} value. + */ + public static class OrderByValue> + implements SerializableComparator> { + @Override + public int compare(KV a, KV b) { + if (a.value == null) { + return b.value == null ? 0 : -1; + } else if (b.value == null) { + return 1; + } else { + return a.value.compareTo(b.value); + } + } + } + + @Override + public int hashCode() { + // Objects.deepEquals requires Arrays.deepHashCode for correctness + return Arrays.deepHashCode(new Object[]{key, value}); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .addValue(key) + .addValue(value) + .toString(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PBegin.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PBegin.java new file mode 100644 index 000000000000..23ac3aed32d8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PBegin.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO.Read; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; +import java.util.Collections; + +/** + * {@link PBegin} is the "input" to a root {@link PTransform}, such as {@link Read Read} or + * {@link Create}. + * + *

    Typically created by calling {@link Pipeline#begin} on a Pipeline. + */ +public class PBegin implements PInput { + /** + * Returns a {@link PBegin} in the given {@link Pipeline}. + */ + public static PBegin in(Pipeline pipeline) { + return new PBegin(pipeline); + } + + /** + * Like {@link #apply(String, PTransform)} but defaulting to the name + * of the {@link PTransform}. + */ + public OutputT apply( + PTransform t) { + return Pipeline.applyTransform(this, t); + } + + /** + * Applies the given {@link PTransform} to this input {@link PBegin}, + * using {@code name} to identify this specific application of the transform. + * This name is used in various places, including the monitoring UI, logging, + * and to stably identify this application node in the job graph. + */ + public OutputT apply( + String name, PTransform t) { + return Pipeline.applyTransform(name, this, t); + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Collection expand() { + // A PBegin contains no PValues. + return Collections.emptyList(); + } + + @Override + public void finishSpecifying() { + // Nothing more to be done. + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Constructs a {@link PBegin} in the given {@link Pipeline}. + */ + protected PBegin(Pipeline pipeline) { + this.pipeline = pipeline; + } + + private final Pipeline pipeline; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollection.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollection.java new file mode 100644 index 000000000000..6fffddfeb960 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollection.java @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.PubsubIO; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; + +/** + * A {@link PCollection PCollection<T>} is an immutable collection of values of type + * {@code T}. A {@link PCollection} can contain either a bounded or unbounded + * number of elements. 
Bounded and unbounded {@link PCollection PCollections} are produced + * as the output of {@link PTransform PTransforms} + * (including root PTransforms like {@link Read} and {@link Create}), and can + * be passed as the inputs of other PTransforms. + * + *

    Some root transforms produce bounded {@code PCollections} and others + * produce unbounded ones. For example, {@link TextIO.Read} reads a static set + * of files, so it produces a bounded {@link PCollection}. + * {@link PubsubIO.Read}, on the other hand, receives a potentially infinite stream + * of Pubsub messages, so it produces an unbounded {@link PCollection}. + * + *
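+ * For example (a sketch; the file pattern and topic name are placeholders):
+ * {@code
+ * Pipeline p = ...;
+ * PCollection<String> lines =
+ *     p.apply(TextIO.Read.from("gs://your-bucket/input-*.txt"));       // bounded
+ * PCollection<String> messages =
+ *     p.apply(PubsubIO.Read.topic("/topics/your-project/your-topic")); // unbounded
+ * }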

    Each element in a {@link PCollection} may have an associated implicit + * timestamp. Readers assign timestamps to elements when they create + * {@link PCollection PCollections}, and other {@link PTransform PTransforms} propagate these + * timestamps from their input to their output. For example, {@link PubsubIO.Read} + * assigns pubsub message timestamps to elements, and {@link TextIO.Read} assigns + * the default value {@link BoundedWindow#TIMESTAMP_MIN_VALUE} to elements. User code can + * explicitly assign timestamps to elements with + * {@link com.google.cloud.dataflow.sdk.transforms.DoFn.Context#outputWithTimestamp}. + * + *
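+ * For example, within a {@code DoFn} (a sketch; {@code c} is the process context):
+ * {@code
+ * c.outputWithTimestamp(c.element(), new Instant(...));
+ * }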

    Additionally, a {@link PCollection} has an associated + * {@link WindowFn} and each element is assigned to a set of windows. + * By default, the windowing function is {@link GlobalWindows} + * and all elements are assigned into a single default window. + * This default can be overridden with the {@link Window} + * {@link PTransform}. + * + *
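+ * For example (a sketch; {@code lines} is any {@code PCollection<String>}):
+ * {@code
+ * PCollection<String> windowed =
+ *     lines.apply(Window.<String>into(FixedWindows.of(Duration.standardMinutes(1))));
+ * }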

See the individual {@link PTransform} subclasses for specific information + * on how they propagate timestamps and windowing. + * + * @param the type of the elements of this {@link PCollection} + */ +public class PCollection extends TypedPValue { + + /** + * The enumeration of cases for whether a {@link PCollection} is bounded. + */ + public enum IsBounded { + /** + * Indicates that a {@link PCollection} contains bounded data elements, such as + * {@link PCollection PCollections} from {@link TextIO}, {@link BigQueryIO}, + * {@link Create}, etc. + */ + BOUNDED, + /** + * Indicates that a {@link PCollection} contains unbounded data elements, such as + * {@link PCollection PCollections} from {@link PubsubIO}. + */ + UNBOUNDED; + + /** + * Returns the composed IsBounded property. + * + *

    The composed property is {@link #BOUNDED} only if all components are {@link #BOUNDED}. + * Otherwise, it is {@link #UNBOUNDED}. + */ + public IsBounded and(IsBounded that) { + if (this == BOUNDED && that == BOUNDED) { + return BOUNDED; + } else { + return UNBOUNDED; + } + } + } + + /** + * Returns the name of this {@link PCollection}. + * + *

    By default, the name of a {@link PCollection} is based on the name of the + * {@link PTransform} that produces it. It can be specified explicitly by + * calling {@link #setName}. + * + * @throws IllegalStateException if the name hasn't been set yet + */ + @Override + public String getName() { + return super.getName(); + } + + /** + * Sets the name of this {@link PCollection}. Returns {@code this}. + * + * @throws IllegalStateException if this {@link PCollection} has already been + * finalized and may no longer be set. + * Once {@link #apply} has been called, this will be the case. + */ + @Override + public PCollection setName(String name) { + super.setName(name); + return this; + } + + /** + * Returns the {@link Coder} used by this {@link PCollection} to encode and decode + * the values stored in it. + * + * @throws IllegalStateException if the {@link Coder} hasn't been set, and + * couldn't be inferred. + */ + @Override + public Coder getCoder() { + return super.getCoder(); + } + + /** + * Sets the {@link Coder} used by this {@link PCollection} to encode and decode the + * values stored in it. Returns {@code this}. + * + * @throws IllegalStateException if this {@link PCollection} has already + * been finalized and may no longer be set. + * Once {@link #apply} has been called, this will be the case. + */ + @Override + public PCollection setCoder(Coder coder) { + super.setCoder(coder); + return this; + } + + /** + * Like {@link IsBounded#apply(String, PTransform)} but defaulting to the name + * of the {@link PTransform}. + * + * @return the output of the applied {@link PTransform} + */ + public OutputT apply(PTransform, OutputT> t) { + return Pipeline.applyTransform(this, t); + } + + /** + * Applies the given {@link PTransform} to this input {@link PCollection}, + * using {@code name} to identify this specific application of the transform. + * This name is used in various places, including the monitoring UI, logging, + * and to stably identify this application node in the job graph. + * + * @return the output of the applied {@link PTransform} + */ + public OutputT apply( + String name, PTransform, OutputT> t) { + return Pipeline.applyTransform(name, this, t); + } + + /** + * Returns the {@link WindowingStrategy} of this {@link PCollection}. + */ + public WindowingStrategy getWindowingStrategy() { + return windowingStrategy; + } + + public IsBounded isBounded() { + return isBounded; + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + /** + * {@link WindowingStrategy} that will be used for merging windows and triggering output in this + * {@link PCollection} and subsequence {@link PCollection PCollections} produced from this one. + * + *

    By default, no merging is performed. + */ + private WindowingStrategy windowingStrategy; + + private IsBounded isBounded; + + private PCollection(Pipeline p) { + super(p); + } + + /** + * Sets the {@link TypeDescriptor TypeDescriptor<T>} for this + * {@link PCollection PCollection<T>}. This may allow the enclosing + * {@link PCollectionTuple}, {@link PCollectionList}, or {@code PTransform>}, + * etc., to provide more detailed reflective information. + */ + @Override + public PCollection setTypeDescriptorInternal(TypeDescriptor typeDescriptor) { + super.setTypeDescriptorInternal(typeDescriptor); + return this; + } + + /** + * Sets the {@link WindowingStrategy} of this {@link PCollection}. + * + *

    For use by primitive transformations only. + */ + public PCollection setWindowingStrategyInternal(WindowingStrategy windowingStrategy) { + this.windowingStrategy = windowingStrategy; + return this; + } + + /** + * Sets the {@link PCollection.IsBounded} of this {@link PCollection}. + * + *

    For use by internal transformations only. + */ + public PCollection setIsBoundedInternal(IsBounded isBounded) { + this.isBounded = isBounded; + return this; + } + + /** + * Creates and returns a new {@link PCollection} for a primitive output. + * + *

    For use by primitive transformations only. + */ + public static PCollection createPrimitiveOutputInternal( + Pipeline pipeline, + WindowingStrategy windowingStrategy, + IsBounded isBounded) { + return new PCollection(pipeline) + .setWindowingStrategyInternal(windowingStrategy) + .setIsBoundedInternal(isBounded); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionList.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionList.java new file mode 100644 index 000000000000..b99af020bfc8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionList.java @@ -0,0 +1,238 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.Partition; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +/** + * A {@link PCollectionList PCollectionList<T>} is an immutable list of homogeneously + * typed {@link PCollection PCollection<T>s}. A {@link PCollectionList} is used, for + * instance, as the input to + * {@link Flatten} or the output of {@link Partition}. + * + *

PCollectionLists can be created and accessed as follows: + *

     {@code
+ * PCollection<String> pc1 = ...;
+ * PCollection<String> pc2 = ...;
+ * PCollection<String> pc3 = ...;
+ *
+ * // Create a PCollectionList with three PCollections:
+ * PCollectionList<String> pcs = PCollectionList.of(pc1).and(pc2).and(pc3);
+ *
+ * // Create an empty PCollectionList:
+ * Pipeline p = ...;
+ * PCollectionList<String> pcs2 = PCollectionList.empty(p);
+ *
+ * // Get PCollections out of a PCollectionList, by index (origin 0):
+ * PCollection<String> pcX = pcs.get(1);
+ * PCollection<String> pcY = pcs.get(0);
+ * PCollection<String> pcZ = pcs.get(2);
+ *
+ * // Get a list of all PCollections in a PCollectionList:
+ * List<PCollection<String>> allPcs = pcs.getAll();
    + * } 
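+ *
+ * A {@link PCollectionList} is commonly consumed by {@link Flatten}, e.g.
+ * (continuing the sketch above):
+ * {@code
+ * PCollection<String> flattened = pcs.apply(Flatten.<String>pCollections());
+ * }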
    + * + * @param the type of the elements of all the {@link PCollection PCollections} in this list + */ +public class PCollectionList implements PInput, POutput { + /** + * Returns an empty {@link PCollectionList} that is part of the given {@link Pipeline}. + * + *

    Longer {@link PCollectionList PCollectionLists} can be created by calling + * {@link #and} on the result. + */ + public static PCollectionList empty(Pipeline pipeline) { + return new PCollectionList<>(pipeline); + } + + /** + * Returns a singleton {@link PCollectionList} containing the given {@link PCollection}. + * + *

    Longer {@link PCollectionList PCollectionLists} can be created by calling + * {@link #and} on the result. + */ + public static PCollectionList of(PCollection pc) { + return new PCollectionList(pc.getPipeline()).and(pc); + } + + /** + * Returns a {@link PCollectionList} containing the given {@link PCollection PCollections}, + * in order. + * + *

    The argument list cannot be empty. + * + *

    All the {@link PCollection PCollections} in the resulting {@link PCollectionList} must be + * part of the same {@link Pipeline}. + * + *

    Longer PCollectionLists can be created by calling + * {@link #and} on the result. + */ + public static PCollectionList of(Iterable> pcs) { + Iterator> pcsIter = pcs.iterator(); + if (!pcsIter.hasNext()) { + throw new IllegalArgumentException( + "must either have a non-empty list of PCollections, " + + "or must first call empty(Pipeline)"); + } + return new PCollectionList(pcsIter.next().getPipeline()).and(pcs); + } + + /** + * Returns a new {@link PCollectionList} that has all the {@link PCollection PCollections} of + * this {@link PCollectionList} plus the given {@link PCollection} appended to the end. + * + *

    All the {@link PCollection PCollections} in the resulting {@link PCollectionList} must be + * part of the same {@link Pipeline}. + */ + public PCollectionList and(PCollection pc) { + if (pc.getPipeline() != pipeline) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + return new PCollectionList<>(pipeline, + new ImmutableList.Builder>() + .addAll(pcollections) + .add(pc) + .build()); + } + + /** + * Returns a new {@link PCollectionList} that has all the {@link PCollection PCollections} of + * this {@link PCollectionList} plus the given {@link PCollection PCollections} appended to the + * end, in order. + * + *

    All the {@link PCollections} in the resulting {@link PCollectionList} must be + * part of the same {@link Pipeline}. + */ + public PCollectionList and(Iterable> pcs) { + List> copy = new ArrayList<>(pcollections); + for (PCollection pc : pcs) { + if (pc.getPipeline() != pipeline) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + copy.add(pc); + } + return new PCollectionList<>(pipeline, copy); + } + + /** + * Returns the number of {@link PCollection PCollections} in this {@link PCollectionList}. + */ + public int size() { + return pcollections.size(); + } + + /** + * Returns the {@link PCollection} at the given index (origin zero). + * + * @throws IndexOutOfBoundsException if the index is out of the range + * {@code [0..size()-1]}. + */ + public PCollection get(int index) { + return pcollections.get(index); + } + + /** + * Returns an immutable List of all the {@link PCollection PCollections} in this + * {@link PCollectionList}. + */ + public List> getAll() { + return pcollections; + } + + /** + * Like {@link #apply(String, PTransform)} but defaulting to the name + * of the {@code PTransform}. + */ + public OutputT apply( + PTransform, OutputT> t) { + return Pipeline.applyTransform(this, t); + } + + /** + * Applies the given {@link PTransform} to this input {@link PCollectionList}, + * using {@code name} to identify this specific application of the transform. + * This name is used in various places, including the monitoring UI, logging, + * and to stably identify this application node in the job graph. + * + * @return the output of the applied {@link PTransform} + */ + public OutputT apply( + String name, PTransform, OutputT> t) { + return Pipeline.applyTransform(name, this, t); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + final Pipeline pipeline; + final List> pcollections; + + PCollectionList(Pipeline pipeline) { + this(pipeline, new ArrayList>()); + } + + PCollectionList(Pipeline pipeline, List> pcollections) { + this.pipeline = pipeline; + this.pcollections = Collections.unmodifiableList(pcollections); + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Collection expand() { + return pcollections; + } + + @Override + public void recordAsOutput(AppliedPTransform transform) { + int i = 0; + for (PCollection pc : pcollections) { + pc.recordAsOutput(transform, "out" + i); + i++; + } + } + + @Override + public void finishSpecifying() { + for (PCollection pc : pcollections) { + pc.finishSpecifying(); + } + } + + @Override + public void finishSpecifyingOutput() { + for (PCollection pc : pcollections) { + pc.finishSpecifyingOutput(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionTuple.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionTuple.java new file mode 100644 index 000000000000..58550e4182c7 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionTuple.java @@ -0,0 +1,264 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.common.collect.ImmutableMap; + +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A {@link PCollectionTuple} is an immutable tuple of + * heterogeneously-typed {@link PCollection PCollections}, "keyed" by + * {@link TupleTag TupleTags}. A {@link PCollectionTuple} can be used as the input or + * output of a + * {@link PTransform} taking + * or producing multiple PCollection inputs or outputs that can be of + * different types, for instance a + * {@link ParDo} with side + * outputs. + * + *

A {@link PCollectionTuple} can be created and accessed as follows: + *

     {@code
+ * PCollection<String> pc1 = ...;
+ * PCollection<Integer> pc2 = ...;
+ * PCollection<Iterable<String>> pc3 = ...;
+ *
+ * // Create TupleTags for each of the PCollections to put in the
+ * // PCollectionTuple (the type of the TupleTag enables tracking the
+ * // static type of each of the PCollections in the PCollectionTuple):
+ * TupleTag<String> tag1 = new TupleTag<>();
+ * TupleTag<Integer> tag2 = new TupleTag<>();
+ * TupleTag<Iterable<String>> tag3 = new TupleTag<>();
+ *
+ * // Create a PCollectionTuple with three PCollections:
+ * PCollectionTuple pcs =
+ *     PCollectionTuple.of(tag1, pc1)
+ *                     .and(tag2, pc2)
+ *                     .and(tag3, pc3);
+ *
+ * // Create an empty PCollectionTuple:
+ * Pipeline p = ...;
+ * PCollectionTuple pcs2 = PCollectionTuple.empty(p);
+ *
+ * // Get PCollections out of a PCollectionTuple, using the same tags
+ * // that were used to put them in:
+ * PCollection<Integer> pcX = pcs.get(tag2);
+ * PCollection<String> pcY = pcs.get(tag1);
+ * PCollection<Iterable<String>> pcZ = pcs.get(tag3);
+ *
+ * // Get a map of all PCollections in a PCollectionTuple:
+ * Map<TupleTag<?>, PCollection<?>> allPcs = pcs.getAll();
    + * } 
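+ *
+ * A {@link PCollectionTuple} is also what a multi-output {@link ParDo} produces,
+ * e.g. (a sketch; {@code words}, the tags, and {@code someDoFn} are placeholders):
+ * {@code
+ * PCollectionTuple results = words.apply(
+ *     ParDo.withOutputTags(mainTag, TupleTagList.of(sideTag)).of(someDoFn));
+ * }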
    + */ +public class PCollectionTuple implements PInput, POutput { + /** + * Returns an empty {@link PCollectionTuple} that is part of the given {@link Pipeline}. + * + *

    A {@link PCollectionTuple} containing additional elements can be created by calling + * {@link #and} on the result. + */ + public static PCollectionTuple empty(Pipeline pipeline) { + return new PCollectionTuple(pipeline); + } + + /** + * Returns a singleton {@link PCollectionTuple} containing the given + * {@link PCollection} keyed by the given {@link TupleTag}. + * + *

    A {@link PCollectionTuple} containing additional elements can be created by calling + * {@link #and} on the result. + */ + public static PCollectionTuple of(TupleTag tag, PCollection pc) { + return empty(pc.getPipeline()).and(tag, pc); + } + + /** + * Returns a new {@link PCollectionTuple} that has each {@link PCollection} and + * {@link TupleTag} of this {@link PCollectionTuple} plus the given {@link PCollection} + * associated with the given {@link TupleTag}. + * + *

    The given {@link TupleTag} should not already be mapped to a + * {@link PCollection} in this {@link PCollectionTuple}. + * + *

    Each {@link PCollection} in the resulting {@link PCollectionTuple} must be + * part of the same {@link Pipeline}. + */ + public PCollectionTuple and(TupleTag tag, PCollection pc) { + if (pc.getPipeline() != pipeline) { + throw new IllegalArgumentException( + "PCollections come from different Pipelines"); + } + + return new PCollectionTuple(pipeline, + new ImmutableMap.Builder, PCollection>() + .putAll(pcollectionMap) + .put(tag, pc) + .build()); + } + + /** + * Returns whether this {@link PCollectionTuple} contains a {@link PCollection} with + * the given tag. + */ + public boolean has(TupleTag tag) { + return pcollectionMap.containsKey(tag); + } + + /** + * Returns the {@link PCollection} associated with the given {@link TupleTag} + * in this {@link PCollectionTuple}. Throws {@link IllegalArgumentException} if there is no + * such {@link PCollection}, i.e., {@code !has(tag)}. + */ + public PCollection get(TupleTag tag) { + @SuppressWarnings("unchecked") + PCollection pcollection = (PCollection) pcollectionMap.get(tag); + if (pcollection == null) { + throw new IllegalArgumentException( + "TupleTag not found in this PCollectionTuple tuple"); + } + return pcollection; + } + + /** + * Returns an immutable Map from {@link TupleTag} to corresponding + * {@link PCollection}, for all the members of this {@link PCollectionTuple}. + */ + public Map, PCollection> getAll() { + return pcollectionMap; + } + + /** + * Like {@link #apply(String, PTransform)} but defaulting to the name + * of the {@link PTransform}. + * + * @return the output of the applied {@link PTransform} + */ + public OutputT apply( + PTransform t) { + return Pipeline.applyTransform(this, t); + } + + /** + * Applies the given {@link PTransform} to this input {@link PCollectionTuple}, + * using {@code name} to identify this specific application of the transform. + * This name is used in various places, including the monitoring UI, logging, + * and to stably identify this application node in the job graph. + * + * @return the output of the applied {@link PTransform} + */ + public OutputT apply( + String name, PTransform t) { + return Pipeline.applyTransform(name, this, t); + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + Pipeline pipeline; + final Map, PCollection> pcollectionMap; + + PCollectionTuple(Pipeline pipeline) { + this(pipeline, new LinkedHashMap, PCollection>()); + } + + PCollectionTuple(Pipeline pipeline, + Map, PCollection> pcollectionMap) { + this.pipeline = pipeline; + this.pcollectionMap = Collections.unmodifiableMap(pcollectionMap); + } + + /** + * Returns a {@link PCollectionTuple} with each of the given tags mapping to a new + * output {@link PCollection}. + * + *

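For illustration, a minimal, hypothetical usage sketch of the PCollectionTuple API defined in this file (of, and, has, get); it is not part of this change, and the pipeline, tags, and element values are invented for the example.

// A minimal, hypothetical sketch (not part of this diff) of the PCollectionTuple API
// defined above; the pipeline, tags, and element values are invented for illustration.
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
import com.google.cloud.dataflow.sdk.values.TupleTag;

public class PCollectionTupleSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Two PCollections of different element types, built from in-memory data.
    PCollection<String> names = p.apply("Names", Create.of("alice", "bob"));
    PCollection<Integer> sizes = p.apply("Sizes", Create.of(1, 2, 3));

    // TupleTags key the collections inside the tuple; the trailing {} preserves type info.
    TupleTag<String> nameTag = new TupleTag<String>() {};
    TupleTag<Integer> sizeTag = new TupleTag<Integer>() {};

    // Build with of()/and(); look entries up again with has()/get().
    PCollectionTuple tuple = PCollectionTuple.of(nameTag, names).and(sizeTag, sizes);
    boolean hasSizes = tuple.has(sizeTag);               // true
    PCollection<String> namesAgain = tuple.get(nameTag); // the same PCollection instance

    System.out.println(hasSizes + " / " + namesAgain.getName());
  }
}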
    For use by primitive transformations only. + */ + public static PCollectionTuple ofPrimitiveOutputsInternal( + Pipeline pipeline, + TupleTagList outputTags, + WindowingStrategy windowingStrategy, + IsBounded isBounded) { + Map, PCollection> pcollectionMap = new LinkedHashMap<>(); + for (TupleTag outputTag : outputTags.tupleTags) { + if (pcollectionMap.containsKey(outputTag)) { + throw new IllegalArgumentException( + "TupleTag already present in this tuple"); + } + + // In fact, `token` and `outputCollection` should have + // types TypeDescriptor and PCollection for some + // unknown T. It is safe to create `outputCollection` + // with type PCollection because it has the same + // erasure as the correct type. When a transform adds + // elements to `outputCollection` they will be of type T. + @SuppressWarnings("unchecked") + TypeDescriptor token = (TypeDescriptor) outputTag.getTypeDescriptor(); + PCollection outputCollection = PCollection + .createPrimitiveOutputInternal(pipeline, windowingStrategy, isBounded) + .setTypeDescriptorInternal(token); + + pcollectionMap.put(outputTag, outputCollection); + } + return new PCollectionTuple(pipeline, pcollectionMap); + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Collection expand() { + return pcollectionMap.values(); + } + + @Override + public void recordAsOutput(AppliedPTransform transform) { + int i = 0; + for (Map.Entry, PCollection> entry + : pcollectionMap.entrySet()) { + TupleTag tag = entry.getKey(); + PCollection pc = entry.getValue(); + pc.recordAsOutput(transform, tag.getOutName(i)); + i++; + } + } + + @Override + public void finishSpecifying() { + for (PCollection pc : pcollectionMap.values()) { + pc.finishSpecifying(); + } + } + + @Override + public void finishSpecifyingOutput() { + for (PCollection pc : pcollectionMap.values()) { + pc.finishSpecifyingOutput(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionView.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionView.java new file mode 100644 index 000000000000..515e21ba6df9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PCollectionView.java @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; + +import java.io.Serializable; + +/** + * A {@link PCollectionView PCollectionView<T>} is an immutable view of a {@link PCollection} + * as a value of type {@code T} that can be accessed + * as a side input to a {@link ParDo} transform. + * + *

    A {@link PCollectionView} should always be the output of a + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform}. It is the joint responsibility of + * this transform and each {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner} to implement + * the view in a runner-specific manner. + * + *

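For illustration, a minimal, hypothetical sketch of the side-input pattern this Javadoc describes, using View.asSingleton() and ParDo.withSideInputs(); the transform names, element values, and filtering logic are invented and not part of this change.

// Illustrative sketch only (not from this diff): turning a PCollection into a
// PCollectionView and reading it as a side input inside a ParDo.
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionView;

public class SideInputSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    PCollection<String> words = p.apply("Words", Create.of("a", "bb", "ccc"));
    PCollection<Integer> limit = p.apply("Limit", Create.of(2));

    // A singleton view makes the single Integer available to every bundle as a side input.
    final PCollectionView<Integer> limitView = limit.apply(View.<Integer>asSingleton());

    PCollection<String> shortWords = words.apply(
        ParDo.withSideInputs(limitView).of(new DoFn<String, String>() {
          @Override
          public void processElement(ProcessContext c) {
            int maxLength = c.sideInput(limitView);  // read the side input
            if (c.element().length() <= maxLength) {
              c.output(c.element());
            }
          }
        }));
  }
}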
    The most common case is using the {@link View} transforms to prepare a {@link PCollection} + * for use as a side input to {@link ParDo}. See {@link View#asSingleton()}, + * {@link View#asIterable()}, and {@link View#asMap()} for more detail on specific views + * available in the SDK. + * + * @param the type of the value(s) accessible via this {@link PCollectionView} + */ +public interface PCollectionView extends PValue, Serializable { + /** + * A unique identifier, for internal use. + */ + public TupleTag>> getTagInternal(); + + /** + * For internal use only. + */ + public T fromIterableInternal(Iterable> contents); + + /** + * For internal use only. + */ + public WindowingStrategy getWindowingStrategyInternal(); + + /** + * For internal use only. + */ + public Coder>> getCoderInternal(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PDone.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PDone.java new file mode 100644 index 000000000000..39a00616bf71 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PDone.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; +import java.util.Collections; + +/** + * {@link PDone} is the output of a {@link PTransform} that has a trivial result, + * such as a {@link Write}. + */ +public class PDone extends POutputValueBase { + + /** + * Creates a {@link PDone} in the given {@link Pipeline}. + */ + public static PDone in(Pipeline pipeline) { + return new PDone(pipeline); + } + + @Override + public Collection expand() { + // A PDone contains no PValues. + return Collections.emptyList(); + } + + private PDone(Pipeline pipeline) { + super(pipeline); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PInput.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PInput.java new file mode 100644 index 000000000000..89b097a65318 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PInput.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; + +import java.util.Collection; + +/** + * The interface for things that might be input to a + * {@link com.google.cloud.dataflow.sdk.transforms.PTransform}. + */ +public interface PInput { + /** + * Returns the owning {@link Pipeline} of this {@link PInput}. + */ + public Pipeline getPipeline(); + + /** + * Expands this {@link PInput} into a list of its component output + * {@link PValue PValues}. + * + *

+ * <ul>
+ * <li>A {@link PValue} expands to itself.</li>
+ * <li>A tuple or list of {@link PValue PValues} (such as + * {@link PCollectionTuple} or {@link PCollectionList}) + * expands to its component {@code PValue PValues}.</li>
+ * </ul>
+ *

    Not intended to be invoked directly by user code. + */ + public Collection expand(); + + /** + *

After building, finalizes this {@code PInput} to make it ready to be + * used as an input to a {@link com.google.cloud.dataflow.sdk.transforms.PTransform}. + * + *

    Automatically invoked whenever {@code apply()} is invoked on + * this {@code PInput}, so users do not normally call this explicitly. + */ + public void finishSpecifying(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutput.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutput.java new file mode 100644 index 000000000000..f99bc0b09dda --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutput.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; + +/** + * The interface for things that might be output from a {@link PTransform}. + */ +public interface POutput { + + /** + * Returns the owning {@link Pipeline} of this {@link POutput}. + */ + public Pipeline getPipeline(); + + /** + * Expands this {@link POutput} into a list of its component output + * {@link PValue PValues}. + * + *

+ * <ul>
+ * <li>A {@link PValue} expands to itself.</li>
+ * <li>A tuple or list of {@link PValue PValues} (such as + * {@link PCollectionTuple} or {@link PCollectionList}) + * expands to its component {@code PValue PValues}.</li>
+ * </ul>
+ *

    Not intended to be invoked directly by user code. + */ + public Collection expand(); + + /** + * Records that this {@code POutput} is an output of the given + * {@code PTransform}. + * + *

    For a compound {@code POutput}, it is advised to call + * this method on each component {@code POutput}. + * + *

    This is not intended to be invoked by user code, but + * is automatically invoked as part of applying the + * producing {@link PTransform}. + */ + public void recordAsOutput(AppliedPTransform transform); + + /** + * As part of applying the producing {@link PTransform}, finalizes this + * output to make it ready for being used as an input and for running. + * + *

    This includes ensuring that all {@link PCollection PCollections} + * have {@link Coder Coders} specified or defaulted. + * + *

    Automatically invoked whenever this {@link POutput} is used + * as a {@link PInput} to another {@link PTransform}, or if never + * used as a {@link PInput}, when {@link Pipeline#run} + * is called, so users do not normally call this explicitly. + */ + public void finishSpecifyingOutput(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutputValueBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutputValueBase.java new file mode 100644 index 000000000000..69e04c343642 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/POutputValueBase.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * A {@link POutputValueBase} is the abstract base class of + * {@code PTransform} outputs. + * + *

    A {@link PValueBase} that adds tracking of its producing + * {@link AppliedPTransform}. + * + *

    For internal use. + */ +public abstract class POutputValueBase implements POutput { + + private final Pipeline pipeline; + + protected POutputValueBase(Pipeline pipeline) { + this.pipeline = pipeline; + } + + /** + * No-arg constructor for Java serialization only. + * The resulting {@link POutputValueBase} is unlikely to be + * valid. + */ + protected POutputValueBase() { + pipeline = null; + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + /** + * Returns the {@link AppliedPTransform} that this {@link POutputValueBase} + * is an output of. + * + *

    For internal use only. + */ + public AppliedPTransform getProducingTransformInternal() { + return producingTransform; + } + + /** + * Records that this {@link POutputValueBase} is an output with the + * given name of the given {@link AppliedPTransform}. + * + *

    To be invoked only by {@link POutput#recordAsOutput} + * implementations. Not to be invoked directly by user code. + */ + @Override + public void recordAsOutput(AppliedPTransform transform) { + if (producingTransform != null) { + // Already used this POutput as a PTransform output. This can + // happen if the POutput is an output of a transform within a + // composite transform, and is also the result of the composite. + // We want to record the "immediate" atomic transform producing + // this output, and ignore all later composite transforms that + // also produce this output. + // + // Pipeline.applyInternal() uses !hasProducingTransform() to + // avoid calling this operation redundantly, but + // hasProducingTransform() doesn't apply to POutputValueBases + // that aren't PValues or composites of PValues, e.g., PDone. + return; + } + producingTransform = transform; + } + + /** + * Default behavior for {@link #finishSpecifyingOutput()} is + * to do nothing. Override if your {@link PValue} requires + * finalization. + */ + @Override + public void finishSpecifyingOutput() { } + + /** + * The {@link PTransform} that produces this {@link POutputValueBase}. + */ + private AppliedPTransform producingTransform; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValue.java new file mode 100644 index 000000000000..eb95a23f50f4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValue.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * The interface for values that can be input to and output from {@link PTransform PTransforms}. + */ +public interface PValue extends POutput, PInput { + + /** + * Returns the name of this {@link PValue}. + */ + public String getName(); + + /** + * Returns the {@link AppliedPTransform} that this {@link PValue} is an output of. + * + *

    For internal use only. + */ + public AppliedPTransform getProducingTransformInternal(); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValueBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValueBase.java new file mode 100644 index 000000000000..7e57204f3306 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/PValueBase.java @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.util.StringUtils; + +import java.util.Collection; +import java.util.Collections; + +/** + * A {@link PValueBase} is an abstract base class that provides + * sensible default implementations for methods of {@link PValue}. + * In particular, this includes functionality for getting/setting: + * + *

+ * <ul>
+ * <li>The {@link Pipeline} that the {@link PValue} is part of.</li>
+ * <li>Whether the {@link PValue} has been finalized (as an input + * or an output), after which its properties can no longer be changed.</li>
+ * </ul>
+ *

    For internal use. + */ +public abstract class PValueBase extends POutputValueBase implements PValue { + /** + * Returns the name of this {@link PValueBase}. + * + *

    By default, the name of a {@link PValueBase} is based on the + * name of the {@link PTransform} that produces it. It can be + * specified explicitly by calling {@link #setName}. + * + * @throws IllegalStateException if the name hasn't been set yet + */ + @Override + public String getName() { + if (name == null) { + throw new IllegalStateException("name not set"); + } + return name; + } + + /** + * Sets the name of this {@link PValueBase}. Returns {@code this}. + * + * @throws IllegalStateException if this {@link PValueBase} has + * already been finalized and may no longer be set. + */ + public PValueBase setName(String name) { + if (finishedSpecifying) { + throw new IllegalStateException( + "cannot change the name of " + this + " once it's been used"); + } + this.name = name; + return this; + } + + ///////////////////////////////////////////////////////////////////////////// + + protected PValueBase(Pipeline pipeline) { + super(pipeline); + } + + /** + * No-arg constructor for Java serialization only. + * The resulting {@link PValueBase} is unlikely to be + * valid. + */ + protected PValueBase() { + super(); + } + + /** + * The name of this {@link PValueBase}, or null if not yet set. + */ + private String name; + + /** + * Whether this {@link PValueBase} has been finalized, and its core + * properties, e.g., name, can no longer be changed. + */ + private boolean finishedSpecifying = false; + + @Override + public void recordAsOutput(AppliedPTransform transform) { + recordAsOutput(transform, "out"); + } + + /** + * Records that this {@link POutputValueBase} is an output with the + * given name of the given {@link AppliedPTransform} in the given + * {@link Pipeline}. + * + *

    To be invoked only by {@link POutput#recordAsOutput} + * implementations. Not to be invoked directly by user code. + */ + protected void recordAsOutput(AppliedPTransform transform, + String outName) { + super.recordAsOutput(transform); + if (name == null) { + name = transform.getFullName() + "." + outName; + } + } + + /** + * Returns whether this {@link PValueBase} has been finalized, and + * its core properties, e.g., name, can no longer be changed. + * + *

    For internal use only. + */ + public boolean isFinishedSpecifyingInternal() { + return finishedSpecifying; + } + + @Override + public Collection expand() { + return Collections.singletonList(this); + } + + @Override + public void finishSpecifying() { + finishSpecifyingOutput(); + finishedSpecifying = true; + } + + @Override + public String toString() { + return (name == null ? "" : getName()) + + " [" + getKindString() + "]"; + } + + /** + * Returns a {@link String} capturing the kind of this + * {@link PValueBase}. + * + *

    By default, uses the base name of the current class as its kind string. + */ + protected String getKindString() { + return StringUtils.approximateSimpleName(getClass()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TimestampedValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TimestampedValue.java new file mode 100644 index 000000000000..1085d44b135c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TimestampedValue.java @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.util.PropertyNames; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * An immutable pair of a value and a timestamp. + * + *

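For illustration, a small, hypothetical sketch of constructing a TimestampedValue and round-tripping it through TimestampedValueCoder; the value, timestamp, and choice of StringUtf8Coder are invented for the example and not part of this change.

// Illustrative sketch only (not from this diff): creating a timestamped value and
// encoding/decoding it with TimestampedValueCoder.
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.values.TimestampedValue;
import com.google.cloud.dataflow.sdk.values.TimestampedValue.TimestampedValueCoder;

import org.joda.time.Instant;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

public class TimestampedValueSketch {
  public static void main(String[] args) throws Exception {
    // Pair a value with the event time it should be treated as occurring at.
    TimestampedValue<String> tv =
        TimestampedValue.of("click", new Instant(1420070400000L));
    System.out.println(tv.getValue() + " @ " + tv.getTimestamp());

    // The coder delegates to a value coder plus InstantCoder for the timestamp.
    TimestampedValueCoder<String> coder = TimestampedValueCoder.of(StringUtf8Coder.of());
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    coder.encode(tv, out, Coder.Context.OUTER);
    TimestampedValue<String> decoded =
        coder.decode(new ByteArrayInputStream(out.toByteArray()), Coder.Context.OUTER);
    System.out.println("decoded: " + decoded);
  }
}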
    The timestamp of a value determines many properties, such as its assignment to + * windows and whether the value is late (with respect to the watermark of a {@link PCollection}). + * + * @param the type of the value + */ +public class TimestampedValue { + + /** + * Returns a new {@code TimestampedValue} with the given value and timestamp. + */ + public static TimestampedValue of(V value, Instant timestamp) { + return new TimestampedValue<>(value, timestamp); + } + + public V getValue() { + return value; + } + + public Instant getTimestamp() { + return timestamp; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof TimestampedValue)) { + return false; + } + TimestampedValue that = (TimestampedValue) other; + return Objects.equals(value, that.value) && Objects.equals(timestamp, that.timestamp); + } + + @Override + public int hashCode() { + return Objects.hash(value, timestamp); + } + + @Override + public String toString() { + return "TimestampedValue(" + value + ", " + timestamp + ")"; + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@link Coder} for {@link TimestampedValue}. + */ + public static class TimestampedValueCoder + extends StandardCoder> { + + private final Coder valueCoder; + + public static TimestampedValueCoder of(Coder valueCoder) { + return new TimestampedValueCoder<>(valueCoder); + } + + @JsonCreator + public static TimestampedValueCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List components) { + checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + @SuppressWarnings("unchecked") + TimestampedValueCoder(Coder valueCoder) { + this.valueCoder = checkNotNull(valueCoder); + } + + @Override + public void encode(TimestampedValue windowedElem, + OutputStream outStream, + Context context) + throws IOException { + valueCoder.encode(windowedElem.getValue(), outStream, context.nested()); + InstantCoder.of().encode( + windowedElem.getTimestamp(), outStream, context); + } + + @Override + public TimestampedValue decode(InputStream inStream, Context context) + throws IOException { + T value = valueCoder.decode(inStream, context.nested()); + Instant timestamp = InstantCoder.of().decode(inStream, context); + return TimestampedValue.of(value, timestamp); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "TimestampedValueCoder requires a deterministic valueCoder", + valueCoder); + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(valueCoder); + } + + public static List getInstanceComponents(TimestampedValue exampleValue) { + return Arrays.asList(exampleValue.getValue()); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private final V value; + private final Instant timestamp; + + protected TimestampedValue(V value, Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTag.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTag.java new file mode 100644 index 000000000000..74949211325c --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTag.java @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.Random; + +/** + * A {@link TupleTag} is a typed tag to use as the key of a + * heterogeneously typed tuple, like {@link PCollectionTuple}. + * Its generic type parameter allows tracking + * the static type of things stored in tuples. + * + *

    To aid in assigning default {@link Coder Coders} for results of + * side outputs of {@link ParDo}, an output + * {@link TupleTag} should be instantiated with an extra {@code {}} so + * it is an instance of an anonymous subclass without generic type + * parameters. Input {@link TupleTag TupleTags} require no such extra + * instantiation (although it doesn't hurt). For example: + * + *

+ * <pre> {@code
+ * TupleTag<SomeType> inputTag = new TupleTag<>();
+ * TupleTag<SomeOtherType> outputTag = new TupleTag<SomeOtherType>(){};
+ * } </pre>
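For illustration, a fuller, hypothetical sketch of the side-output pattern that motivates the extra {@code {}} above, combining TupleTag, TupleTagList, ParDo.withOutputTags, and PCollectionTuple; the tag names, element values, and filtering logic are invented and not part of this change.

// Illustrative sketch only (not from this diff): ParDo with a main output and one side output.
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.cloud.dataflow.sdk.values.TupleTagList;

public class SideOutputSketch {
  // Output tags are instantiated with {} so coder inference can see the element type.
  static final TupleTag<String> SHORT_WORDS = new TupleTag<String>() {};
  static final TupleTag<String> LONG_WORDS = new TupleTag<String>() {};

  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
    PCollection<String> words = p.apply(Create.of("a", "bb", "ccc", "dddd"));

    PCollectionTuple results = words.apply(
        ParDo.withOutputTags(SHORT_WORDS, TupleTagList.of(LONG_WORDS))
            .of(new DoFn<String, String>() {
              @Override
              public void processElement(ProcessContext c) {
                if (c.element().length() <= 2) {
                  c.output(c.element());                  // goes to SHORT_WORDS
                } else {
                  c.sideOutput(LONG_WORDS, c.element());  // goes to LONG_WORDS
                }
              }
            }));

    PCollection<String> shortWords = results.get(SHORT_WORDS);
    PCollection<String> longWords = results.get(LONG_WORDS);
  }
}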
    + * + * @param the type of the elements or values of the tagged thing, + * e.g., a {@code PCollection}. + */ +public class TupleTag implements Serializable { + /** + * Constructs a new {@code TupleTag}, with a fresh unique id. + * + *

    This is the normal way {@code TupleTag}s are constructed. + */ + public TupleTag() { + this(genId(), true); + } + + /** + * Constructs a new {@code TupleTag} with the given id. + * + *

    It is up to the user to ensure that two {@code TupleTag}s + * with the same id actually mean the same tag and carry the same + * generic type parameter. Violating this invariant can lead to + * hard-to-diagnose runtime type errors. Consequently, this + * operation should be used very sparingly, such as when the + * producer and consumer of {@code TupleTag}s are written in + * separate modules and can only coordinate via ids rather than + * shared {@code TupleTag} instances. Most of the time, + * {@link #TupleTag()} should be preferred. + */ + public TupleTag(String id) { + this(id, false); + } + + /** + * Returns the id of this {@code TupleTag}. + * + *

    Two {@code TupleTag}s with the same id are considered equal. + * + *

{@code TupleTag}s are not ordered, i.e., the class does not implement the + * {@code Comparable} interface. TupleTags implement equals and hashCode, making them + * suitable for use as keys in HashMap and HashSet. + */ + public String getId() { + return id; + } + + /** + * If this {@code TupleTag} is tagging output {@code outIndex} of + * a {@code PTransform}, returns the name that should be used by + * default for the output. + */ + public String getOutName(int outIndex) { + if (generated) { + return "out" + outIndex; + } else { + return id; + } + } + + /** + * Returns a {@code TypeDescriptor} capturing what is known statically + * about the type of this {@code TupleTag} instance's most-derived + * class. + * + *

    This is useful for a {@code TupleTag} constructed as an + * instance of an anonymous subclass with a trailing {@code {}}, + * e.g., {@code new TupleTag(){}}. + */ + public TypeDescriptor getTypeDescriptor() { + return new TypeDescriptor(getClass()) {}; + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + static final Random RANDOM = new Random(0); + private static final Multiset staticInits = HashMultiset.create(); + + final String id; + final boolean generated; + + /** Generates and returns a fresh unique id for a TupleTag's id. */ + static synchronized String genId() { + // It is a common pattern to store tags that are shared between the main + // program and workers in static variables, but such references are not + // serialized as part of the *Fns state. Fortunately, most such tags + // are constructed in static class initializers, e.g. + // + // static final TupleTag MY_TAG = new TupleTag<>(); + // + // and class initialization order is well defined by the JVM spec, so in + // this case we can assign deterministic ids. + StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); + for (StackTraceElement frame : stackTrace) { + if (frame.getMethodName().equals("")) { + int counter = staticInits.add(frame.getClassName(), 1); + return frame.getClassName() + "#" + counter; + } + } + // Otherwise, assume it'll be serialized and choose a random value to reduce + // the chance of collision. + String nonce = Long.toHexString(RANDOM.nextLong()); + // [Thread.getStackTrace, TupleTag.getId, TupleTag., caller, ...] + String caller = stackTrace.length >= 4 + ? stackTrace[3].getClassName() + "." + stackTrace[3].getMethodName() + + ":" + stackTrace[3].getLineNumber() + : "unknown"; + return caller + "#" + nonce; + } + + @JsonCreator + @SuppressWarnings("unused") + private static TupleTag fromJson( + @JsonProperty(PropertyNames.VALUE) String id, + @JsonProperty(PropertyNames.IS_GENERATED) boolean generated) { + return new TupleTag<>(id, generated); + } + + private TupleTag(String id, boolean generated) { + this.id = id; + this.generated = generated; + } + + public CloudObject asCloudObject() { + CloudObject result = CloudObject.forClass(getClass()); + addString(result, PropertyNames.VALUE, id); + addBoolean(result, PropertyNames.IS_GENERATED, generated); + return result; + } + + @Override + public boolean equals(Object that) { + if (that instanceof TupleTag) { + return this.id.equals(((TupleTag) that).id); + } else { + return false; + } + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public String toString() { + return "Tag<" + id + ">"; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTagList.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTagList.java new file mode 100644 index 000000000000..f019fc26e4cb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TupleTagList.java @@ -0,0 +1,148 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.collect.ImmutableList; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A {@link TupleTagList} is an immutable list of heterogeneously + * typed {@link TupleTag TupleTags}. A {@link TupleTagList} is used, for instance, to + * specify the tags of the side outputs of a + * {@link ParDo}. + * + *

A {@link TupleTagList} can be created and accessed as follows: + *
+ * <pre> {@code
+ * TupleTag<String> tag1 = ...;
+ * TupleTag<Integer> tag2 = ...;
+ * TupleTag<Iterable<String>> tag3 = ...;
+ *
+ * // Create a TupleTagList with three TupleTags:
+ * TupleTagList tags = TupleTagList.of(tag1).and(tag2).and(tag3);
+ *
+ * // Create an empty TupleTagList:
+ * TupleTagList tags2 = TupleTagList.empty();
+ *
+ * // Get TupleTags out of a TupleTagList, by index (origin 0):
+ * TupleTag<?> tagX = tags.get(1);
+ * TupleTag<?> tagY = tags.get(0);
+ * TupleTag<?> tagZ = tags.get(2);
+ *
+ * // Get a list of all TupleTags in a TupleTagList:
+ * List<TupleTag<?>> allTags = tags.getAll();
+ * } </pre>
    + */ +public class TupleTagList implements Serializable { + /** + * Returns an empty {@link TupleTagList}. + * + *

    Longer {@link TupleTagList TupleTagLists} can be created by calling + * {@link #and} on the result. + */ + public static TupleTagList empty() { + return new TupleTagList(); + } + + /** + * Returns a singleton {@link TupleTagList} containing the given {@link TupleTag}. + * + *

    Longer {@link TupleTagList TupleTagLists} can be created by calling + * {@link #and} on the result. + */ + public static TupleTagList of(TupleTag tag) { + return empty().and(tag); + } + + /** + * Returns a {@link TupleTagList} containing the given {@link TupleTag TupleTags}, in order. + * + *

    Longer {@link TupleTagList TupleTagLists} can be created by calling + * {@link #and} on the result. + */ + public static TupleTagList of(List> tags) { + return empty().and(tags); + } + + /** + * Returns a new {@link TupleTagList} that has all the {@link TupleTag TupleTags} of + * this {@link TupleTagList} plus the given {@link TupleTag} appended to the end. + */ + public TupleTagList and(TupleTag tag) { + return new TupleTagList( + new ImmutableList.Builder>() + .addAll(tupleTags) + .add(tag) + .build()); + } + + /** + * Returns a new {@link TupleTagList} that has all the {@link TupleTag TupleTags} of + * this {@link TupleTagList} plus the given {@link TupleTag TupleTags} appended to the end, + * in order. + */ + public TupleTagList and(List> tags) { + return new TupleTagList( + new ImmutableList.Builder>() + .addAll(tupleTags) + .addAll(tags) + .build()); + } + + /** + * Returns the number of TupleTags in this TupleTagList. + */ + public int size() { + return tupleTags.size(); + } + + /** + * Returns the {@link TupleTag} at the given index (origin zero). + * + * @throws IndexOutOfBoundsException if the index is out of the range + * {@code [0..size()-1]}. + */ + public TupleTag get(int index) { + return tupleTags.get(index); + } + + /** + * Returns an immutable List of all the {@link TupleTag TupleTags} in this {@link TupleTagList}. + */ + public List> getAll() { + return tupleTags; + } + + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + final List> tupleTags; + + TupleTagList() { + this(new ArrayList>()); + } + + TupleTagList(List> tupleTags) { + this.tupleTags = Collections.unmodifiableList(tupleTags); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypeDescriptor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypeDescriptor.java new file mode 100644 index 000000000000..559d67ce05c1 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypeDescriptor.java @@ -0,0 +1,351 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.common.collect.Lists; +import com.google.common.reflect.Invokable; +import com.google.common.reflect.Parameter; +import com.google.common.reflect.TypeToken; + +import java.io.Serializable; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * A description of a Java type, including actual generic parameters where possible. + * + *

    To prevent losing actual type arguments due to erasure, create an anonymous subclass + * with concrete types: + *

+ * <pre>
+ * {@code
+ * TypeDescriptor<List<String>> descriptor = new TypeDescriptor<List<String>>() {};
+ * }
+ * </pre>
    + * + *

    If the above were not an anonymous subclass, the type {@code List} + * would be erased and unavailable at run time. + * + * @param the type represented by this {@link TypeDescriptor} + */ +public abstract class TypeDescriptor implements Serializable { + + // This class is just a wrapper for TypeToken + private final TypeToken token; + + /** + * Creates a {@link TypeDescriptor} wrapping the provided token. + * This constructor is private so Guava types do not leak. + */ + private TypeDescriptor(TypeToken token) { + this.token = token; + } + + /** + * Creates a {@link TypeDescriptor} representing + * the type parameter {@code T}. To use this constructor + * properly, the type parameter must be a concrete type, for example + * {@code new TypeDescriptor>(){}}. + */ + protected TypeDescriptor() { + token = new TypeToken(getClass()) {}; + } + + /** + * Creates a {@link TypeDescriptor} representing the type parameter {@code T}, which should + * resolve to a concrete type in the context of the class {@code clazz}. + * + *

    Unlike {@link TypeDescriptor#TypeDescriptor(Class)} this will also use context's of the + * enclosing instances while attempting to resolve the type. This means that the types of any + * classes instantiated in the concrete instance should be resolvable. + */ + protected TypeDescriptor(Object instance) { + TypeToken unresolvedToken = new TypeToken(getClass()) {}; + + // While we haven't fully resolved the parameters, refine it using the captured + // enclosing instance of the object. + unresolvedToken = TypeToken.of(instance.getClass()).resolveType(unresolvedToken.getType()); + + if (hasUnresolvedParameters(unresolvedToken.getType())) { + for (Field field : instance.getClass().getDeclaredFields()) { + Object fieldInstance = getEnclosingInstance(field, instance); + if (fieldInstance != null) { + unresolvedToken = + TypeToken.of(fieldInstance.getClass()).resolveType(unresolvedToken.getType()); + if (!hasUnresolvedParameters(unresolvedToken.getType())) { + break; + } + } + } + } + + // Once we've either fully resolved the parameters or exhausted enclosing instances, we have + // the best approximation to the token we can get. + @SuppressWarnings("unchecked") + TypeToken typedToken = (TypeToken) unresolvedToken; + token = typedToken; + } + + private boolean hasUnresolvedParameters(Type type) { + if (type instanceof TypeVariable) { + return true; + } else if (type instanceof ParameterizedType) { + ParameterizedType param = (ParameterizedType) type; + for (Type arg : param.getActualTypeArguments()) { + if (hasUnresolvedParameters(arg)) { + return true; + } + } + } + return false; + } + + /** + * Returns the enclosing instance if the field is synthetic and it is able to access it, or + * {@literal null} if not. + */ + @Nullable + private Object getEnclosingInstance(Field field, Object instance) { + if (!field.isSynthetic()) { + return null; + } + + boolean accessible = field.isAccessible(); + try { + field.setAccessible(true); + return field.get(instance); + } catch (IllegalArgumentException | IllegalAccessException e) { + // If we fail to get the enclosing instance field, do nothing. In the worst case, we won't + // refine the type based on information in this enclosing class -- that is consistent with + // previous behavior and is still a correct answer that can be fixed by returning the correct + // type descriptor. + return null; + } finally { + field.setAccessible(accessible); + } + } + + /** + * Creates a {@link TypeDescriptor} representing the type parameter + * {@code T}, which should resolve to a concrete type in the context + * of the class {@code clazz}. + */ + @SuppressWarnings("unchecked") + protected TypeDescriptor(Class clazz) { + TypeToken unresolvedToken = new TypeToken(getClass()) {}; + token = (TypeToken) TypeToken.of(clazz).resolveType(unresolvedToken.getType()); + } + + /** + * Returns a {@link TypeDescriptor} representing the given type. + */ + public static TypeDescriptor of(Class type) { + return new SimpleTypeDescriptor<>(TypeToken.of(type)); + } + + /** + * Returns a {@link TypeDescriptor} representing the given type. + */ + @SuppressWarnings("unchecked") + public static TypeDescriptor of(Type type) { + return new SimpleTypeDescriptor<>((TypeToken) TypeToken.of(type)); + } + + /** + * Returns the {@link Type} represented by this {@link TypeDescriptor}. + */ + public Type getType() { + return token.getType(); + } + + /** + * Returns the {@link Class} underlying the {@link Type} represented by + * this {@link TypeDescriptor}. 
+ */ + public Class getRawType() { + return token.getRawType(); + } + + /** + * Returns the component type if this type is an array type, + * otherwise returns {@code null}. + */ + public TypeDescriptor getComponentType() { + return new SimpleTypeDescriptor<>(token.getComponentType()); + } + + /** + * Returns the generic form of a supertype. + */ + public final TypeDescriptor getSupertype(Class superclass) { + return new SimpleTypeDescriptor<>(token.getSupertype(superclass)); + } + + /** + * Returns true if this type is known to be an array type. + */ + public final boolean isArray() { + return token.isArray(); + } + + /** + * Returns a {@link TypeVariable} for the named type parameter. Throws + * {@link IllegalArgumentException} if a type variable by the requested type parameter is not + * found. + * + *

    For example, {@code new TypeDescriptor(){}.getTypeParameter("T")} returns a + * {@code TypeVariable} representing the formal type parameter {@code T}. + * + *

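For illustration, a small, hypothetical sketch contrasting capture via an anonymous subclass with TypeDescriptor.of(Class), and looking up a formal type parameter by name; the types chosen are arbitrary and not part of this change.

// Illustrative sketch only (not from this diff): capturing versus erasing type arguments.
import com.google.cloud.dataflow.sdk.values.TypeDescriptor;

import java.util.List;

public class TypeDescriptorSketch {
  public static void main(String[] args) {
    // The anonymous-subclass trick keeps the full parameterized type.
    TypeDescriptor<List<String>> captured = new TypeDescriptor<List<String>>() {};
    System.out.println(captured);               // java.util.List<java.lang.String>
    System.out.println(captured.getRawType());  // interface java.util.List

    // of(Class) only knows the raw class, so the element type is lost.
    TypeDescriptor<List> raw = TypeDescriptor.of(List.class);
    System.out.println(raw);                    // java.util.List

    // The formal type parameter "E" of List is still reachable by name.
    System.out.println(captured.getTypeParameter("E"));  // E
  }
}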
    Do not mistake the type parameters (formal type argument list) with the actual + * type arguments. For example, if a class {@code Foo} extends {@code List}, it + * does not make sense to ask for a type parameter, because {@code Foo} does not have any. + */ + public final TypeVariable> getTypeParameter(String paramName) { + // Cannot convert TypeVariable>[] to TypeVariable>[] + // due to how they are used here, so the result of getTypeParameters() cannot be used + // without upcast. + Class rawType = getRawType(); + for (TypeVariable param : rawType.getTypeParameters()) { + if (param.getName().equals(paramName)) { + @SuppressWarnings("unchecked") + TypeVariable> typedParam = (TypeVariable>) param; + return typedParam; + } + } + throw new IllegalArgumentException( + "No type parameter named " + paramName + " found on " + getRawType()); + } + + /** + * Returns true if this type is assignable from the given type. + */ + public final boolean isSupertypeOf(TypeDescriptor source) { + return token.isSupertypeOf(source.token); + } + + /** + * Return true if this type is a subtype of the given type. + */ + public final boolean isSubtypeOf(TypeDescriptor parent) { + return token.isSubtypeOf(parent.token); + } + + /** + * Returns a list of argument types for the given method, which must + * be a part of the class. + */ + public List> getArgumentTypes(Method method) { + Invokable typedMethod = token.method(method); + + List> argTypes = Lists.newArrayList(); + for (Parameter parameter : typedMethod.getParameters()) { + argTypes.add(new SimpleTypeDescriptor<>(parameter.getType())); + } + return argTypes; + } + + /** + * Returns a {@link TypeDescriptor} representing the given + * type, with type variables resolved according to the specialization + * in this type. + * + *

    For example, consider the following class: + *

+   * <pre>
+   * {@code
+   * class MyList implements List<String> { ... }
+   * }
+   * </pre>
+ *
+ * <p>The {@link TypeDescriptor} returned by
+   * <pre>
+   * {@code
+   * TypeDescriptor.of(MyList.class)
+   *     .resolveType(MyList.class.getMethod("get", int.class).getGenericReturnType())
+   * }
+   * </pre>
    + * will represent the type {@code String}. + */ + public TypeDescriptor resolveType(Type type) { + return new SimpleTypeDescriptor<>(token.resolveType(type)); + } + + /** + * Returns a set of {@link TypeDescriptor}s, one for each + * interface implemented by this class. + */ + @SuppressWarnings("rawtypes") + public Iterable getInterfaces() { + List interfaces = Lists.newArrayList(); + for (TypeToken interfaceToken : token.getTypes().interfaces()) { + interfaces.add(new SimpleTypeDescriptor<>(interfaceToken)); + } + return interfaces; + } + + /** + * Returns a set of {@link TypeDescriptor}s, one for each + * superclass (including this class). + */ + @SuppressWarnings("rawtypes") + public Iterable getClasses() { + List classes = Lists.newArrayList(); + for (TypeToken classToken : token.getTypes().classes()) { + classes.add(new SimpleTypeDescriptor<>(classToken)); + } + return classes; + } + + @Override + public String toString() { + return token.toString(); + } + + /** + * Two type descriptor are equal if and only if they + * represent the same type. + */ + @Override + public boolean equals(Object other) { + if (!(other instanceof TypeDescriptor)) { + return false; + } else { + @SuppressWarnings("unchecked") + TypeDescriptor descriptor = (TypeDescriptor) other; + return token.equals(descriptor.token); + } + } + + @Override + public int hashCode() { + return token.hashCode(); + } + + /** + * A non-abstract {@link TypeDescriptor} for construction directly from an existing + * {@link TypeToken}. + */ + private static final class SimpleTypeDescriptor extends TypeDescriptor { + SimpleTypeDescriptor(TypeToken typeToken) { + super(typeToken); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypedPValue.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypedPValue.java new file mode 100644 index 000000000000..29fd639409ec --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/TypedPValue.java @@ -0,0 +1,197 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException.ReasonCode; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +/** + * A {@link TypedPValue TypedPValue<T>} is the abstract base class of things that + * store some number of values of type {@code T}. + * + *

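For illustration, a hypothetical sketch of setting a Coder explicitly on a PCollection (a TypedPValue) via setCoder rather than relying on inference; the element class and coder choice are invented for the example and not part of this change.

// Illustrative sketch only (not from this diff): explicitly setting a Coder on a PCollection.
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.values.PCollection;

import java.io.Serializable;
import java.util.Arrays;

public class SetCoderSketch {
  // A user-defined element type with no specially registered Coder.
  static class MyRecord implements Serializable {
    final String id;
    MyRecord(String id) { this.id = id; }
  }

  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Setting the Coder explicitly documents the encoding and avoids relying on
    // whatever fallback inference would otherwise pick for MyRecord.
    PCollection<MyRecord> records = p
        .apply(Create.of(Arrays.asList(new MyRecord("a"), new MyRecord("b"))))
        .setCoder(SerializableCoder.of(MyRecord.class));
  }
}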
    Because we know the type {@code T}, this is the layer of the inheritance hierarchy where + * we store a coder for objects of type {@code T}. + * + * @param the type of the values stored in this {@link TypedPValue} + */ +public abstract class TypedPValue extends PValueBase implements PValue { + + /** + * Returns the {@link Coder} used by this {@link TypedPValue} to encode and decode + * the values stored in it. + * + * @throws IllegalStateException if the {@link Coder} hasn't been set, and + * couldn't be inferred. + */ + public Coder getCoder() { + if (coder == null) { + coder = inferCoderOrFail(); + } + return coder; + } + + /** + * Sets the {@link Coder} used by this {@link TypedPValue} to encode and decode the + * values stored in it. Returns {@code this}. + * + * @throws IllegalStateException if this {@link TypedPValue} has already + * been finalized and is no longer settable, e.g., by having + * {@code apply()} called on it + */ + public TypedPValue setCoder(Coder coder) { + if (isFinishedSpecifyingInternal()) { + throw new IllegalStateException( + "cannot change the Coder of " + this + " once it's been used"); + } + if (coder == null) { + throw new IllegalArgumentException( + "Cannot setCoder(null)"); + } + this.coder = coder; + return this; + } + + /** + * After building, finalizes this {@link PValue} to make it ready for + * running. Automatically invoked whenever the {@link PValue} is "used" + * (e.g., when apply() is called on it) and when the Pipeline is + * run (useful if this is a {@link PValue} with no consumers). + */ + @Override + public void finishSpecifying() { + if (isFinishedSpecifyingInternal()) { + return; + } + super.finishSpecifying(); + // Ensure that this TypedPValue has a coder by inferring the coder if none exists; If not, + // this will throw an exception. + getCoder(); + } + + ///////////////////////////////////////////////////////////////////////////// + // Internal details below here. + + /** + * The {@link Coder} used by this {@link TypedPValue} to encode and decode the + * values stored in it, or null if not specified nor inferred yet. + */ + private Coder coder; + + protected TypedPValue(Pipeline p) { + super(p); + } + + private TypeDescriptor typeDescriptor; + + /** + * Returns a {@link TypeDescriptor TypeDescriptor<T>} with some reflective information + * about {@code T}, if possible. May return {@code null} if no information + * is available. Subclasses may override this to enable better + * {@code Coder} inference. + */ + public TypeDescriptor getTypeDescriptor() { + return typeDescriptor; + } + + /** + * Sets the {@link TypeDescriptor TypeDescriptor<T>} associated with this class. Better + * reflective type information will lead to better {@link Coder} + * inference. + */ + public TypedPValue setTypeDescriptorInternal(TypeDescriptor typeDescriptor) { + this.typeDescriptor = typeDescriptor; + return this; + } + + /** + * If the coder is not explicitly set, this sets the coder for + * this {@link TypedPValue} to the best coder that can be inferred + * based upon the known {@link TypeDescriptor}. By default, this is null, + * but can and should be improved by subclasses. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + private Coder inferCoderOrFail() { + // First option for a coder: use the Coder set on this PValue. + if (coder != null) { + return coder; + } + + AppliedPTransform application = getProducingTransformInternal(); + + // Second option for a coder: Look in the coder registry. 
+ CoderRegistry registry = getPipeline().getCoderRegistry(); + TypeDescriptor token = getTypeDescriptor(); + CannotProvideCoderException inferFromTokenException = null; + if (token != null) { + try { + return registry.getDefaultCoder(token); + } catch (CannotProvideCoderException exc) { + inferFromTokenException = exc; + // Attempt to detect when the token came from a TupleTag used for a ParDo side output, + // and provide a better error message if so. Unfortunately, this information is not + // directly available from the TypeDescriptor, so infer based on the type of the PTransform + // and the error message itself. + if (application.getTransform() instanceof ParDo.BoundMulti + && exc.getReason() == ReasonCode.TYPE_ERASURE) { + inferFromTokenException = new CannotProvideCoderException(exc.getMessage() + + " If this error occurs for a side output of the producing ParDo, verify that the " + + "TupleTag for this output is constructed with proper type information (see " + + "TupleTag Javadoc) or explicitly set the Coder to use if this is not possible."); + } + } + } + + // Third option for a coder: use the default Coder from the producing PTransform. + CannotProvideCoderException inputCoderException; + try { + return ((PTransform) application.getTransform()).getDefaultOutputCoder( + application.getInput(), this); + } catch (CannotProvideCoderException exc) { + inputCoderException = exc; + } + + // Build up the error message and list of causes. + StringBuilder messageBuilder = new StringBuilder() + .append("Unable to return a default Coder for ").append(this) + .append(". Correct one of the following root causes:"); + + // No exception, but give the user a message about .setCoder() has not been called. + messageBuilder.append("\n No Coder has been manually specified; ") + .append(" you may do so using .setCoder()."); + + if (inferFromTokenException != null) { + messageBuilder + .append("\n Inferring a Coder from the CoderRegistry failed: ") + .append(inferFromTokenException.getMessage()); + } + + if (inputCoderException != null) { + messageBuilder + .append("\n Using the default output Coder from the producing PTransform failed: ") + .append(inputCoderException.getMessage()); + } + + // Build and throw the exception. + throw new IllegalStateException(messageBuilder.toString()); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/package-info.java new file mode 100644 index 000000000000..b8ca756f0ab4 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/values/package-info.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines {@link com.google.cloud.dataflow.sdk.values.PCollection} and other classes for + * representing data in a {@link com.google.cloud.dataflow.sdk.Pipeline}. + * + *

+ * <p>In particular, see these collection abstractions:
+ *
+ * <ul>
+ *   <li>{@link com.google.cloud.dataflow.sdk.values.PCollection} - an immutable collection of
+ *   values of type {@code T} and the main representation for data in Dataflow.</li>
+ *   <li>{@link com.google.cloud.dataflow.sdk.values.PCollectionView} - an immutable view of a
+ *   {@link com.google.cloud.dataflow.sdk.values.PCollection} that can be accessed as a
+ *   side input of a {@link com.google.cloud.dataflow.sdk.transforms.ParDo}
+ *   {@link com.google.cloud.dataflow.sdk.transforms.PTransform}.</li>
+ *   <li>{@link com.google.cloud.dataflow.sdk.values.PCollectionTuple} - a heterogeneous tuple of
+ *   {@link com.google.cloud.dataflow.sdk.values.PCollection PCollections}
+ *   used in cases where a {@link com.google.cloud.dataflow.sdk.transforms.PTransform} takes
+ *   or returns multiple
+ *   {@link com.google.cloud.dataflow.sdk.values.PCollection PCollections}.</li>
+ *   <li>{@link com.google.cloud.dataflow.sdk.values.PCollectionList} - a homogeneous list of
+ *   {@link com.google.cloud.dataflow.sdk.values.PCollection PCollections} used, for example,
+ *   as input to {@link com.google.cloud.dataflow.sdk.transforms.Flatten}.</li>
+ * </ul>
+ *
+ * <p>And these classes for individual values play particular roles in Dataflow:
+ *
+ * <ul>
+ *   <li>{@link com.google.cloud.dataflow.sdk.values.KV} - a key/value pair that is used by
+ *   keyed transforms, most notably
+ *   {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey}.</li>
+ *   <li>{@link com.google.cloud.dataflow.sdk.values.TimestampedValue} - a timestamp/value pair
+ *   that is used for windowing and handling out-of-order data in streaming execution.</li>
+ * </ul>
+ *
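+ * <p>For example, the following sketch (illustrative only; it assumes a {@code Pipeline p} has
+ * already been created and that the referenced transforms are imported) shows how several of
+ * these types typically appear together:
+ *
+ * <pre>{@code
+ * PCollection<String> words = p.apply(Create.of("to", "be", "or", "not", "to", "be"));
+ *
+ * // KV is the element type consumed by keyed transforms such as GroupByKey.
+ * PCollection<KV<String, Integer>> pairs = words.apply(
+ *     ParDo.of(new DoFn<String, KV<String, Integer>>() {
+ *       public void processElement(ProcessContext c) {
+ *         c.output(KV.of(c.element(), 1));
+ *       }
+ *     }));
+ *
+ * // A homogeneous list of PCollections, e.g. as the input to Flatten.
+ * PCollection<String> flattened =
+ *     PCollectionList.of(words).and(words).apply(Flatten.pCollections());
+ * }</pre>
+ *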

    For further details, see the documentation for each class in this package. + */ +package com.google.cloud.dataflow.sdk.values; diff --git a/sdk/src/main/proto/README.md b/sdk/src/main/proto/README.md new file mode 100644 index 000000000000..fa4e925c982c --- /dev/null +++ b/sdk/src/main/proto/README.md @@ -0,0 +1,27 @@ +## Protocol Buffers in Google Cloud Dataflow + +This directory contains the Protocol Buffer messages used in Google Cloud +Dataflow. + +They aren't, however, used during the Maven build process, and are included here +for completeness only. Instead, the following artifact on Maven Central contains +the binary version of the generated code from these Protocol Buffers: + + + com.google.cloud.dataflow + google-cloud-dataflow-java-proto-library-all + LATEST + + +Please follow this process for testing changes: + +* Make changes to the Protocol Buffer messages in this directory. +* Use `protoc` to generate the new code, and compile it into a new Java library. +* Install that Java library into your local Maven repository. +* Update SDK's `pom.xml` to pick up the newly installed library, instead of +downloading it from Maven Central. + +Once the changes are ready for submission, please separate them into two +commits. The first commit should update the Protocol Buffer messages only. After +that, we need to update the generated artifact on Maven Central. Finally, +changes that make use of the Protocol Buffer changes may be committed. diff --git a/sdk/src/main/proto/proto2_coder_test_messages.proto b/sdk/src/main/proto/proto2_coder_test_messages.proto new file mode 100644 index 000000000000..eb3c3dfa9b90 --- /dev/null +++ b/sdk/src/main/proto/proto2_coder_test_messages.proto @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/* + * Protocol Buffer messages used for testing Proto2Coder implementation. + */ + +syntax = "proto2"; + +package proto2_coder_test_messages; + +option java_package = "com.google.cloud.dataflow.sdk.coders"; + +message MessageA { + optional string field1 = 1; + repeated MessageB field2 = 2; +} + +message MessageB { + optional bool field1 = 1; +} + +message MessageC { + extensions 100 to 105; +} + +extend MessageC { + optional MessageA field1 = 101; + optional MessageB field2 = 102; +} + +message MessageWithMap { + map field1 = 1; +} + +message ReferencesMessageWithMap { + repeated MessageWithMap field1 = 1; +} diff --git a/sdk/src/main/proto/windmill.proto b/sdk/src/main/proto/windmill.proto new file mode 100644 index 000000000000..d9d9706cd0f6 --- /dev/null +++ b/sdk/src/main/proto/windmill.proto @@ -0,0 +1,327 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/* + * Protocol Buffers describing the interface between streaming Dataflow workers + * and the Windmill servers. + */ + +syntax = "proto2"; + +package windmill; + +option java_package = "com.google.cloud.dataflow.sdk.runners.worker.windmill"; +option java_outer_classname = "Windmill"; + +//////////////////////////////////////////////////////////////////////////////// +// API Data types + +message Message { + required int64 timestamp = 1 [default=-0x8000000000000000]; + required bytes data = 2; + optional bytes metadata = 3; +} + +message Timer { + required bytes tag = 1; + optional int64 timestamp = 2 [default=-0x8000000000000000]; + enum Type { + WATERMARK = 0; + REALTIME = 1; + DEPENDENT_REALTIME = 2; + } + optional Type type = 3 [default = WATERMARK]; + optional string state_family = 4; +} + +message InputMessageBundle { + required string source_computation_id = 1; + repeated Message messages = 2; +} + +message KeyedMessageBundle { + required bytes key = 1; + repeated Message messages = 2; + repeated bytes messages_ids = 3; +} + +message OutputMessageBundle { + optional string destination_computation_id = 1; + optional string destination_stream_id = 3; + repeated KeyedMessageBundle bundles = 2; +} + +message PubSubMessageBundle { + required string topic = 1; + repeated Message messages = 2; + optional string timestamp_label = 3; + optional string id_label = 4; +} + +message TimerBundle { + repeated Timer timers = 1; +} + +message Value { + required int64 timestamp = 1 [default=-0x8000000000000000]; + required bytes data = 2; +} + +message TagValue { + required bytes tag = 1; + optional Value value = 2; + optional string state_family = 3; +} + +message TagList { + required bytes tag = 1; + // In request: All items till this timestamp (inclusive) are deleted before + // adding the new ones listed below. + optional int64 end_timestamp = 2 [default=-0x8000000000000000]; + repeated Value values = 3; + optional string state_family = 4; + + // In request: A previously returned continuation_token from an + // earlier request. Indicates we wish to fetch the next page of + // values. + // In response: Copied from request. + optional bytes request_token = 7; + // In response only: Set when there are values after those returned + // above, but they were suppressed to respect the fetch_max_bytes + // limit. Subsequent requests should copy this to request_token to + // retrieve the next page of values. + optional bytes continuation_token = 5; + // For a TagList fetch request, attempt to limit the size of each fetched tag + // list to this byte limit. 
+ optional int64 fetch_max_bytes = 6 [default = 0x7fffffffffffffff]; +} + +message GlobalDataId { + required string tag = 1; + required bytes version = 2; +} + +message GlobalData { + required GlobalDataId data_id = 1; + optional bool is_ready = 2; + optional bytes data = 3; + optional string state_family = 4; +} + +message SourceState { + optional bytes state = 1; + repeated fixed64 finalize_ids = 2; +} + +message WatermarkHold { + required bytes tag = 1; + repeated int64 timestamps = 2 [packed=true]; + optional bool reset = 3; + optional string state_family = 4; +} + +message WorkItem { + required bytes key = 1; + required fixed64 work_token = 2; + + optional fixed64 cache_token = 7; + + repeated InputMessageBundle message_bundles = 3; + optional TimerBundle timers = 4; + repeated GlobalDataId global_data_id_notifications = 5; + optional SourceState source_state = 6; + optional int64 output_data_watermark = 8 [default=-0x8000000000000000]; +} + +message ComputationWorkItems { + required string computation_id = 1; + repeated WorkItem work = 2; + optional int64 input_data_watermark = 3 [default=-0x8000000000000000]; + optional int64 dependent_realtime_input_watermark = 4 + [default = -0x8000000000000000]; +} + +//////////////////////////////////////////////////////////////////////////////// +// API calls + +// GetWork + +message GetWorkRequest { + required fixed64 client_id = 1; + optional string worker_id = 4; + optional string job_id = 5; + optional int64 max_items = 2 [default = 0xffffffff]; + optional int64 max_bytes = 3 [default = 0x7fffffffffffffff]; + // reserved field number = 6 +} + +message GetWorkResponse { + repeated ComputationWorkItems work = 1; +} + +// GetData + +message KeyedGetDataRequest { + required bytes key = 1; + required fixed64 work_token = 2; + repeated TagValue values_to_fetch = 3; + repeated TagList lists_to_fetch = 4; + repeated WatermarkHold watermark_holds_to_fetch = 5; +} + +message ComputationGetDataRequest { + required string computation_id = 1; + repeated KeyedGetDataRequest requests = 2; +} + +message GetDataRequest { + optional string job_id = 4; + repeated ComputationGetDataRequest requests = 1; + repeated GlobalDataRequest global_data_fetch_requests = 3; + + // DEPRECATED + repeated GlobalDataId global_data_to_fetch = 2; +} + +message KeyedGetDataResponse { + required bytes key = 1; + // The response for this key is not populated due to the fetch failing. + optional bool failed = 2; + repeated TagValue values = 3; + repeated TagList lists = 4; + repeated WatermarkHold watermark_holds = 5; +} + +message ComputationGetDataResponse { + required string computation_id = 1; + repeated KeyedGetDataResponse data = 2; +} + +message GetDataResponse { + repeated ComputationGetDataResponse data = 1; + repeated GlobalData global_data = 2; +} + +// CommitWork + +message Counter { + optional string name = 1; + enum Kind { + SUM = 0; + MAX = 1; + MIN = 2; + MEAN = 3; + }; + optional Kind kind = 2; + + // For SUM, MAX, MIN, AND, OR, MEAN at most one of the following should be + // set. For MEAN it is the sum + optional double double_scalar = 3; + optional int64 int_scalar = 4; + + // Only set for MEAN. Count of elements contributing to the sum. + optional int64 mean_count = 6; + + // True if this metric is reported as the total cumulative aggregate + // value accumulated since the worker started working on this WorkItem. + // By default this is false, indicating that this metric is reported + // as a delta that is not associated with any WorkItem. 
+ optional bool cumulative = 7; +} + +message GlobalDataRequest { + required GlobalDataId data_id = 1; + optional int64 existence_watermark_deadline = 2 [default=0x7FFFFFFFFFFFFFFF]; + optional string state_family = 3; +} + +// next id: 15 +message WorkItemCommitRequest { + required bytes key = 1; + required fixed64 work_token = 2; + repeated OutputMessageBundle output_messages = 3; + repeated PubSubMessageBundle pubsub_messages = 7; + repeated Timer output_timers = 4; + repeated TagValue value_updates = 5; + repeated TagList list_updates = 6; + repeated Counter counter_updates = 8; + repeated GlobalDataRequest global_data_requests = 11; + repeated GlobalData global_data_updates = 10; + optional SourceState source_state_updates = 12; + optional int64 source_watermark = 13 [default=-0x8000000000000000]; + repeated WatermarkHold watermark_holds = 14; + + // DEPRECATED + repeated GlobalDataId global_data_id_requests = 9; +} + +message ComputationCommitWorkRequest { + required string computation_id = 1; + repeated WorkItemCommitRequest requests = 2; +} + +message CommitWorkRequest { + optional string job_id = 2; + repeated ComputationCommitWorkRequest requests = 1; +} + +message CommitWorkResponse {} + +// Configuration + +message GetConfigRequest { + optional string job_id = 2; + repeated string computations = 1; +} + +message GetConfigResponse { + repeated string cloud_works = 1; + + message NameMapEntry { + optional string user_name = 1; + optional string system_name = 2; + } + + // Map of user names to system names + repeated NameMapEntry name_map = 2; + + message SystemNameToComputationIdMapEntry { + optional string system_name = 1; + optional string computation_id = 2; + } + repeated SystemNameToComputationIdMapEntry + system_name_to_computation_id_map = 3; +} + +// Reporting + +message Exception { + repeated string stack_frames = 1; + optional Exception cause = 2; +} + +message ReportStatsRequest { + optional string job_id = 6; + optional string computation_id = 1; + optional bytes key = 2; + optional fixed64 work_token = 3; + repeated Exception exceptions = 4; + repeated Counter counter_updates = 5; +} + +message ReportStatsResponse { + optional bool failed = 1; +} diff --git a/sdk/src/main/proto/windmill_service.proto b/sdk/src/main/proto/windmill_service.proto new file mode 100644 index 000000000000..bd25fe5efc60 --- /dev/null +++ b/sdk/src/main/proto/windmill_service.proto @@ -0,0 +1,27 @@ +syntax = "proto2"; + +import "windmill.proto"; + +package google.dataflow.windmillservice.v1alpha1; + +// The Cloud Windmill Service API used by GCE to acquire and process streaming +// Dataflow work. +service CloudWindmillServiceV1Alpha1 { + // Gets streaming Dataflow work. + rpc GetWork(.windmill.GetWorkRequest) returns(.windmill.GetWorkResponse); + + // Gets data from Windmill. + rpc GetData(.windmill.GetDataRequest) returns(.windmill.GetDataResponse); + + // Commits previously acquired work. + rpc CommitWork(.windmill.CommitWorkRequest) + returns(.windmill.CommitWorkResponse); + + // Gets dependant configuration from windmill. + rpc GetConfig(.windmill.GetConfigRequest) + returns(.windmill.GetConfigResponse); + + // Reports stats to Windmill. 
+ rpc ReportStats(.windmill.ReportStatsRequest) + returns(.windmill.ReportStatsResponse); +} \ No newline at end of file diff --git a/sdk/src/main/resources/com/google/cloud/dataflow/sdk/sdk.properties b/sdk/src/main/resources/com/google/cloud/dataflow/sdk/sdk.properties new file mode 100644 index 000000000000..5b0a720b215d --- /dev/null +++ b/sdk/src/main/resources/com/google/cloud/dataflow/sdk/sdk.properties @@ -0,0 +1,5 @@ +# SDK source version. +version=${pom.version} + +build.date=${timestamp} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/DataflowMatchers.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/DataflowMatchers.java new file mode 100644 index 000000000000..ad21072dc4c9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/DataflowMatchers.java @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +import com.google.protobuf.ByteString; + +import org.hamcrest.Description; +import org.hamcrest.TypeSafeMatcher; + +import java.io.Serializable; + +/** + * Matchers that are useful when writing Dataflow tests. + */ +public class DataflowMatchers { + /** + * Matcher for {@link ByteString} that prints the strings in UTF8. + */ + public static class ByteStringMatcher extends TypeSafeMatcher + implements Serializable { + private ByteString expected; + private ByteStringMatcher(ByteString expected) { + this.expected = expected; + } + + public static ByteStringMatcher byteStringEq(ByteString expected) { + return new ByteStringMatcher(expected); + } + + @Override + public void describeTo(Description description) { + description + .appendText("ByteString(") + .appendText(expected.toStringUtf8()) + .appendText(")"); + } + + @Override + public void describeMismatchSafely(ByteString actual, Description description) { + description + .appendText("was ByteString(") + .appendText(actual.toStringUtf8()) + .appendText(")"); + } + + @Override + protected boolean matchesSafely(ByteString actual) { + return actual.equals(expected); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/PipelineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/PipelineTest.java new file mode 100644 index 000000000000..e311252021b6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/PipelineTest.java @@ -0,0 +1,296 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions.CheckEnabled; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for Pipeline. + */ +@RunWith(JUnit4.class) +public class PipelineTest { + + @Rule public ExpectedLogs logged = ExpectedLogs.none(Pipeline.class); + @Rule public ExpectedException thrown = ExpectedException.none(); + + static class PipelineWrapper extends Pipeline { + protected PipelineWrapper(PipelineRunner runner) { + super(runner, PipelineOptionsFactory.create()); + } + } + + // Mock class that throws a user code exception during the call to + // Pipeline.run(). + static class TestPipelineRunnerThrowingUserException + extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + Throwable t = new IllegalStateException("user code exception"); + throw UserCodeException.wrap(t); + } + } + + // Mock class that throws an SDK or API client code exception during + // the call to Pipeline.run(). + static class TestPipelineRunnerThrowingSDKException + extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + throw new IllegalStateException("SDK exception"); + } + } + + @Test + public void testPipelineUserExceptionHandling() { + Pipeline p = new PipelineWrapper( + new TestPipelineRunnerThrowingUserException()); + + // Check pipeline runner correctly catches user errors. 
+ thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + thrown.expectMessage("user code exception"); + p.run(); + } + + @Test + public void testPipelineSDKExceptionHandling() { + Pipeline p = new PipelineWrapper(new TestPipelineRunnerThrowingSDKException()); + + // Check pipeline runner correctly catches SDK errors. + try { + p.run(); + fail("Should have thrown an exception."); + } catch (RuntimeException exn) { + // Make sure the exception isn't a UserCodeException. + Assert.assertThat(exn, not(instanceOf(UserCodeException.class))); + // Assert that the message is correct. + Assert.assertThat(exn.getMessage(), containsString("SDK exception")); + // RuntimeException should be IllegalStateException. + Assert.assertThat(exn, instanceOf(IllegalStateException.class)); + } + } + + @Test + @Category(com.google.cloud.dataflow.sdk.testing.RunnableOnService.class) + public void testMultipleApply() { + PTransform, PCollection> myTransform = + addSuffix("+"); + + Pipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(ImmutableList.of("a", "b"))); + + PCollection left = input.apply("Left1", myTransform).apply("Left2", myTransform); + PCollection right = input.apply("Right", myTransform); + + PCollection both = PCollectionList.of(left).and(right) + .apply(Flatten.pCollections()); + + DataflowAssert.that(both).containsInAnyOrder("a++", "b++", "a+", "b+"); + + p.run(); + } + + private static PTransform, PCollection> addSuffix( + final String suffix) { + return ParDo.of(new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) { + c.output(c.element() + suffix); + } + }); + } + + @Test + public void testToString() { + PipelineOptions options = PipelineOptionsFactory.as(PipelineOptions.class); + options.setRunner(DirectPipelineRunner.class); + Pipeline pipeline = Pipeline.create(options); + assertEquals("Pipeline#" + pipeline.hashCode(), pipeline.toString()); + } + + @Test + public void testStableUniqueNameOff() { + Pipeline p = TestPipeline.create(); + p.getOptions().setStableUniqueNames(CheckEnabled.OFF); + + p.apply(Create.of(5, 6, 7)); + p.apply(Create.of(5, 6, 7)); + + logged.verifyNotLogged("does not have a stable unique name."); + } + + @Test + public void testStableUniqueNameWarning() { + Pipeline p = TestPipeline.create(); + p.getOptions().setStableUniqueNames(CheckEnabled.WARNING); + + p.apply(Create.of(5, 6, 7)); + p.apply(Create.of(5, 6, 7)); + + logged.verifyWarn("does not have a stable unique name."); + } + + @Test + public void testStableUniqueNameError() { + Pipeline p = TestPipeline.create(); + p.getOptions().setStableUniqueNames(CheckEnabled.ERROR); + + p.apply(Create.of(5, 6, 7)); + + thrown.expectMessage("does not have a stable unique name."); + p.apply(Create.of(5, 6, 7)); + } + + /** + * Tests that Pipeline supports a pass-through identity function. + */ + @Test + @Category(RunnableOnService.class) + public void testIdentityTransform() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4)) + .apply("IdentityTransform", new IdentityTransform>()); + + DataflowAssert.that(output).containsInAnyOrder(1, 2, 3, 4); + pipeline.run(); + } + + private static class IdentityTransform + extends PTransform { + @Override + public T apply(T input) { + return input; + } + } + + /** + * Tests that Pipeline supports pulling an element out of a tuple as a transform. 
+ */ + @Test + @Category(RunnableOnService.class) + public void testTupleProjectionTransform() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection input = pipeline + .apply(Create.of(1, 2, 3, 4)); + + TupleTag tag = new TupleTag(); + PCollectionTuple tuple = PCollectionTuple.of(tag, input); + + PCollection output = tuple + .apply("ProjectTag", new TupleProjectionTransform(tag)); + + DataflowAssert.that(output).containsInAnyOrder(1, 2, 3, 4); + pipeline.run(); + } + + private static class TupleProjectionTransform + extends PTransform> { + private TupleTag tag; + + public TupleProjectionTransform(TupleTag tag) { + this.tag = tag; + } + + @Override + public PCollection apply(PCollectionTuple input) { + return input.get(tag); + } + } + + /** + * Tests that Pipeline supports putting an element into a tuple as a transform. + */ + @Test + @Category(RunnableOnService.class) + public void testTupleInjectionTransform() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection input = pipeline + .apply(Create.of(1, 2, 3, 4)); + + TupleTag tag = new TupleTag(); + + PCollectionTuple output = input + .apply("ProjectTag", new TupleInjectionTransform(tag)); + + DataflowAssert.that(output.get(tag)).containsInAnyOrder(1, 2, 3, 4); + pipeline.run(); + } + + private static class TupleInjectionTransform + extends PTransform, PCollectionTuple> { + private TupleTag tag; + + public TupleInjectionTransform(TupleTag tag) { + this.tag = tag; + } + + @Override + public PCollectionTuple apply(PCollection input) { + return PCollectionTuple.of(tag, input); + } + } + + /** + * Tests that an empty pipeline runs. + */ + @Test + public void testEmptyPipeline() throws Exception { + Pipeline pipeline = TestPipeline.create(); + pipeline.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/TestUtils.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/TestUtils.java new file mode 100644 index 000000000000..257ecbbd4e33 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/TestUtils.java @@ -0,0 +1,213 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk; + +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Utilities for tests. + */ +public class TestUtils { + // Do not instantiate. 
+ private TestUtils() {} + + public static final String[] NO_LINES_ARRAY = new String[] { }; + + public static final List NO_LINES = Arrays.asList(NO_LINES_ARRAY); + + public static final String[] LINES_ARRAY = new String[] { + "To be, or not to be: that is the question: ", + "Whether 'tis nobler in the mind to suffer ", + "The slings and arrows of outrageous fortune, ", + "Or to take arms against a sea of troubles, ", + "And by opposing end them? To die: to sleep; ", + "No more; and by a sleep to say we end ", + "The heart-ache and the thousand natural shocks ", + "That flesh is heir to, 'tis a consummation ", + "Devoutly to be wish'd. To die, to sleep; ", + "To sleep: perchance to dream: ay, there's the rub; ", + "For in that sleep of death what dreams may come ", + "When we have shuffled off this mortal coil, ", + "Must give us pause: there's the respect ", + "That makes calamity of so long life; ", + "For who would bear the whips and scorns of time, ", + "The oppressor's wrong, the proud man's contumely, ", + "The pangs of despised love, the law's delay, ", + "The insolence of office and the spurns ", + "That patient merit of the unworthy takes, ", + "When he himself might his quietus make ", + "With a bare bodkin? who would fardels bear, ", + "To grunt and sweat under a weary life, ", + "But that the dread of something after death, ", + "The undiscover'd country from whose bourn ", + "No traveller returns, puzzles the will ", + "And makes us rather bear those ills we have ", + "Than fly to others that we know not of? ", + "Thus conscience does make cowards of us all; ", + "And thus the native hue of resolution ", + "Is sicklied o'er with the pale cast of thought, ", + "And enterprises of great pith and moment ", + "With this regard their currents turn awry, ", + "And lose the name of action.--Soft you now! ", + "The fair Ophelia! Nymph, in thy orisons ", + "Be all my sins remember'd." }; + + public static final List LINES = Arrays.asList(LINES_ARRAY); + + public static final String[] LINES2_ARRAY = new String[] { + "hi", "there", "bob!" }; + + public static final List LINES2 = Arrays.asList(LINES2_ARRAY); + + public static final Integer[] NO_INTS_ARRAY = new Integer[] { }; + + public static final List NO_INTS = Arrays.asList(NO_INTS_ARRAY); + + public static final Integer[] INTS_ARRAY = new Integer[] { + 3, 42, Integer.MAX_VALUE, 0, -1, Integer.MIN_VALUE, 666 }; + + public static final List INTS = Arrays.asList(INTS_ARRAY); + + /** + * Matcher for KVs. + */ + public static class KvMatcher + extends TypeSafeMatcher> { + final Matcher keyMatcher; + final Matcher valueMatcher; + + public static KvMatcher isKv(Matcher keyMatcher, + Matcher valueMatcher) { + return new KvMatcher<>(keyMatcher, valueMatcher); + } + + public KvMatcher(Matcher keyMatcher, + Matcher valueMatcher) { + this.keyMatcher = keyMatcher; + this.valueMatcher = valueMatcher; + } + + @Override + public boolean matchesSafely(KV kv) { + return keyMatcher.matches(kv.getKey()) + && valueMatcher.matches(kv.getValue()); + } + + @Override + public void describeTo(Description description) { + description + .appendText("a KV(").appendValue(keyMatcher) + .appendText(", ").appendValue(valueMatcher) + .appendText(")"); + } + } + + //////////////////////////////////////////////////////////////////////////// + // Utilities for testing CombineFns, ensuring they give correct results + // across various permutations and shardings of the input. 
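+ // For example, for a hypothetical summing CombineFn<Integer, ?, Integer> named sumFn,
+ //   checkCombineFn(sumFn, Arrays.asList(1, 2, 3, 4), 10);
+ // asserts that the same result is produced for the unsharded input, for several even and
+ // exponentially-sized shardings of it, and for shuffled orderings of both the input and
+ // the shards.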
+ + public static void checkCombineFn( + CombineFn fn, List input, final OutputT expected) { + checkCombineFn(fn, input, CoreMatchers.is(expected)); + } + + public static void checkCombineFn( + CombineFn fn, List input, Matcher matcher) { + checkCombineFnInternal(fn, input, matcher); + Collections.shuffle(input); + checkCombineFnInternal(fn, input, matcher); + } + + private static void checkCombineFnInternal( + CombineFn fn, List input, Matcher matcher) { + int size = input.size(); + checkCombineFnShards(fn, Collections.singletonList(input), matcher); + checkCombineFnShards(fn, shardEvenly(input, 2), matcher); + if (size > 4) { + checkCombineFnShards(fn, shardEvenly(input, size / 2), matcher); + checkCombineFnShards( + fn, shardEvenly(input, (int) (size / Math.sqrt(size))), matcher); + } + checkCombineFnShards(fn, shardExponentially(input, 1.4), matcher); + checkCombineFnShards(fn, shardExponentially(input, 2), matcher); + checkCombineFnShards(fn, shardExponentially(input, Math.E), matcher); + } + + public static void checkCombineFnShards( + CombineFn fn, + List> shards, + Matcher matcher) { + checkCombineFnShardsInternal(fn, shards, matcher); + Collections.shuffle(shards); + checkCombineFnShardsInternal(fn, shards, matcher); + } + + private static void checkCombineFnShardsInternal( + CombineFn fn, + Iterable> shards, + Matcher matcher) { + List accumulators = new ArrayList<>(); + int maybeCompact = 0; + for (Iterable shard : shards) { + AccumT accumulator = fn.createAccumulator(); + for (InputT elem : shard) { + accumulator = fn.addInput(accumulator, elem); + } + if (maybeCompact++ % 2 == 0) { + accumulator = fn.compact(accumulator); + } + accumulators.add(accumulator); + } + AccumT merged = fn.mergeAccumulators(accumulators); + assertThat(fn.extractOutput(merged), matcher); + } + + private static List> shardEvenly(List input, int numShards) { + List> shards = new ArrayList<>(numShards); + for (int i = 0; i < numShards; i++) { + shards.add(input.subList(i * input.size() / numShards, + (i + 1) * input.size() / numShards)); + } + return shards; + } + + private static List> shardExponentially( + List input, double base) { + assert base > 1.0; + List> shards = new ArrayList<>(); + int end = input.size(); + while (end > 0) { + int start = (int) (end / base); + shards.add(input.subList(start, end)); + end = start; + } + return shards; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/WindowMatchers.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/WindowMatchers.java new file mode 100644 index 000000000000..9d7cfc869458 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/WindowMatchers.java @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.hamcrest.TypeSafeMatcher; +import org.joda.time.Instant; + +import java.util.Collection; +import java.util.Objects; + +/** + * Matchers that are useful for working with Windowing, Timestamps, etc. + */ +public class WindowMatchers { + + public static Matcher> isWindowedValue( + Matcher valueMatcher, Matcher timestampMatcher, + Matcher> windowsMatcher) { + return new WindowedValueMatcher<>(valueMatcher, timestampMatcher, windowsMatcher); + } + + public static Matcher> isWindowedValue( + Matcher valueMatcher, Matcher timestampMatcher) { + return new WindowedValueMatcher<>(valueMatcher, timestampMatcher, Matchers.anything()); + } + + public static Matcher> isSingleWindowedValue( + T value, long timestamp, long windowStart, long windowEnd) { + return WindowMatchers.isSingleWindowedValue( + Matchers.equalTo(value), timestamp, windowStart, windowEnd); + } + + public static Matcher> isSingleWindowedValue( + Matcher valueMatcher, long timestamp, long windowStart, long windowEnd) { + IntervalWindow intervalWindow = + new IntervalWindow(new Instant(windowStart), new Instant(windowEnd)); + return WindowMatchers.isSingleWindowedValue( + valueMatcher, + Matchers.describedAs("%0", Matchers.equalTo(new Instant(timestamp)), timestamp), + Matchers.equalTo(intervalWindow)); + } + + public static Matcher> isSingleWindowedValue( + Matcher valueMatcher, Matcher timestampMatcher, + Matcher windowMatcher) { + return new WindowedValueMatcher( + valueMatcher, timestampMatcher, Matchers.contains(windowMatcher)); + } + + public static Matcher intervalWindow(long start, long end) { + return Matchers.equalTo(new IntervalWindow(new Instant(start), new Instant(end))); + } + + public static Matcher> valueWithPaneInfo(final PaneInfo paneInfo) { + return new TypeSafeMatcher>() { + @Override + public void describeTo(Description description) { + description + .appendText("WindowedValue(paneInfo = ").appendValue(paneInfo).appendText(")"); + } + + @Override + protected boolean matchesSafely(WindowedValue item) { + return Objects.equals(item.getPane(), paneInfo); + } + + @Override + protected void describeMismatchSafely( + WindowedValue item, Description mismatchDescription) { + mismatchDescription.appendValue(item.getPane()); + } + }; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @SafeVarargs + public static final Matcher> ofWindows( + Matcher... 
windows) { + return (Matcher) Matchers.containsInAnyOrder(windows); + } + + private WindowMatchers() {} + + private static class WindowedValueMatcher extends TypeSafeMatcher> { + + private Matcher valueMatcher; + private Matcher timestampMatcher; + private Matcher> windowsMatcher; + + private WindowedValueMatcher( + Matcher valueMatcher, + Matcher timestampMatcher, + Matcher> windowsMatcher) { + this.valueMatcher = valueMatcher; + this.timestampMatcher = timestampMatcher; + this.windowsMatcher = windowsMatcher; + } + + @Override + public void describeTo(Description description) { + description + .appendText("a WindowedValue(").appendValue(valueMatcher) + .appendText(", ").appendValue(timestampMatcher) + .appendText(", ").appendValue(windowsMatcher) + .appendText(")"); + } + + @Override + protected boolean matchesSafely(WindowedValue windowedValue) { + return valueMatcher.matches(windowedValue.getValue()) + && timestampMatcher.matches(windowedValue.getTimestamp()) + && windowsMatcher.matches(windowedValue.getWindows()); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/AvroCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/AvroCoderTest.java new file mode 100644 index 000000000000..db6e9449abb4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/AvroCoderTest.java @@ -0,0 +1,754 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.apache.avro.AvroTypeException; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.reflect.AvroName; +import org.apache.avro.reflect.AvroSchema; +import org.apache.avro.reflect.Nullable; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.Stringable; +import org.apache.avro.reflect.Union; +import org.apache.avro.specific.SpecificData; +import org.apache.avro.util.Utf8; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.hamcrest.TypeSafeMatcher; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; + +/** Tests for {@link AvroCoder}. */ +@RunWith(JUnit4.class) +public class AvroCoderTest { + + @DefaultCoder(AvroCoder.class) + private static class Pojo { + public String text; + public int count; + + // Empty constructor required for Avro decoding. + @SuppressWarnings("unused") + public Pojo() { + } + + public Pojo(String text, int count) { + this.text = text; + this.count = count; + } + + // auto-generated + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Pojo pojo = (Pojo) o; + + if (count != pojo.count) { + return false; + } + if (text != null + ? 
!text.equals(pojo.text) + : pojo.text != null) { + return false; + } + + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public String toString() { + return "Pojo{" + + "text='" + text + '\'' + + ", count=" + count + + '}'; + } + } + + private static class GetTextFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().text); + } + } + + @Test + public void testAvroCoderEncoding() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + CloudObject encoding = coder.asCloudObject(); + + Assert.assertThat(encoding.keySet(), + Matchers.containsInAnyOrder("@type", "type", "schema", "encoding_id")); + } + + @Test + public void testPojoEncoding() throws Exception { + Pojo value = new Pojo("Hello", 42); + AvroCoder coder = AvroCoder.of(Pojo.class); + + CoderProperties.coderDecodeEncodeEqual(coder, value); + } + + @Test + public void testPojoEncodingId() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + CoderProperties.coderHasEncodingId(coder, Pojo.class.getName()); + } + + @Test + public void testGenericRecordEncoding() throws Exception { + String schemaString = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"User\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},\n" + + " {\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + Schema schema = (new Schema.Parser()).parse(schemaString); + + GenericRecord before = new GenericData.Record(schema); + before.put("name", "Bob"); + before.put("favorite_number", 256); + // Leave favorite_color null + + AvroCoder coder = AvroCoder.of(GenericRecord.class, schema); + + CoderProperties.coderDecodeEncodeEqual(coder, before); + Assert.assertEquals(schema, coder.getSchema()); + } + + @Test + public void testEncodingNotBuffered() throws Exception { + // This test ensures that the coder doesn't read ahead and buffer data. + // Reading ahead causes a problem if the stream consists of records of different + // types. + Pojo before = new Pojo("Hello", 42); + + AvroCoder coder = AvroCoder.of(Pojo.class); + SerializableCoder intCoder = SerializableCoder.of(Integer.class); + + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + + Context context = Context.NESTED; + coder.encode(before, outStream, context); + intCoder.encode(10, outStream, context); + + ByteArrayInputStream inStream = new ByteArrayInputStream(outStream.toByteArray()); + + Pojo after = coder.decode(inStream, context); + Assert.assertEquals(before, after); + + Integer intAfter = intCoder.decode(inStream, context); + Assert.assertEquals(new Integer(10), intAfter); + } + + @Test + public void testDefaultCoder() throws Exception { + Pipeline p = TestPipeline.create(); + + // Use MyRecord as input and output types without explicitly specifying + // a coder (this uses the default coders, which may not be AvroCoder). + PCollection output = + p.apply(Create.of(new Pojo("hello", 1), new Pojo("world", 2))) + .apply(ParDo.of(new GetTextFn())); + + DataflowAssert.that(output) + .containsInAnyOrder("hello", "world"); + p.run(); + } + + @Test + public void testAvroCoderIsSerializable() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + + // Check that the coder is serializable using the regular JSON approach. 
+ SerializableUtils.ensureSerializable(coder); + } + + private final void assertDeterministic(AvroCoder coder) { + try { + coder.verifyDeterministic(); + } catch (NonDeterministicException e) { + fail("Expected " + coder + " to be deterministic, but got:\n" + e); + } + } + + private final void assertNonDeterministic(AvroCoder coder, + Matcher reason1) { + try { + coder.verifyDeterministic(); + fail("Expected " + coder + " to be non-deterministic."); + } catch (NonDeterministicException e) { + assertThat(e.getReasons(), Matchers.iterableWithSize(1)); + assertThat(e.getReasons(), Matchers.contains(reason1)); + } + } + + @Test + public void testDeterministicInteger() { + assertDeterministic(AvroCoder.of(Integer.class)); + } + + @Test + public void testDeterministicInt() { + assertDeterministic(AvroCoder.of(int.class)); + } + + private static class SimpleDeterministicClass { + @SuppressWarnings("unused") + private Integer intField; + @SuppressWarnings("unused") + private char charField; + @SuppressWarnings("unused") + private Integer[] intArray; + @SuppressWarnings("unused") + private Utf8 utf8field; + } + + @Test + public void testDeterministicSimple() { + assertDeterministic(AvroCoder.of(SimpleDeterministicClass.class)); + } + + private static class UnorderedMapClass { + @SuppressWarnings("unused") + private Map mapField; + } + + private Matcher reason(final String prefix, final String messagePart) { + return new TypeSafeMatcher(String.class) { + @Override + public void describeTo(Description description) { + description.appendText(String.format("Reason starting with '%s:' containing '%s'", + prefix, messagePart)); + } + + @Override + protected boolean matchesSafely(String item) { + return item.startsWith(prefix + ":") && item.contains(messagePart); + } + }; + } + + private Matcher reasonClass(Class clazz, String message) { + return reason(clazz.getName(), message); + } + + private Matcher reasonField( + Class clazz, String field, String message) { + return reason(clazz.getName() + "#" + field, message); + } + + @Test + public void testDeterministicUnorderedMap() { + assertNonDeterministic(AvroCoder.of(UnorderedMapClass.class), + reasonField(UnorderedMapClass.class, "mapField", + "java.util.Map " + + "may not be deterministically ordered")); + } + + private static class NonDeterministicArray { + @SuppressWarnings("unused") + private UnorderedMapClass[] arrayField; + } + @Test + public void testDeterministicNonDeterministicArray() { + assertNonDeterministic(AvroCoder.of(NonDeterministicArray.class), + reasonField(UnorderedMapClass.class, "mapField", + "java.util.Map" + + " may not be deterministically ordered")); + } + + private static class SubclassOfUnorderedMapClass extends UnorderedMapClass {} + + + @Test + public void testDeterministicNonDeterministicChild() { + // Super class has non deterministic fields. + assertNonDeterministic(AvroCoder.of(SubclassOfUnorderedMapClass.class), + reasonField(UnorderedMapClass.class, "mapField", + "may not be deterministically ordered")); + } + + private static class SubclassHidingParent extends UnorderedMapClass { + @SuppressWarnings("unused") + @AvroName("mapField2") // AvroName is not enough + private int mapField; + } + + @Test + public void testAvroProhibitsShadowing() { + // This test verifies that Avro won't serialize a class with two fields of + // the same name. This is important for our error reporting, and also how + // we lookup a field. 
+ try { + ReflectData.get().getSchema(SubclassHidingParent.class); + fail("Expected AvroTypeException"); + } catch (AvroTypeException e) { + assertThat(e.getMessage(), containsString("mapField")); + assertThat(e.getMessage(), containsString("two fields named")); + } + } + + private static class FieldWithAvroName { + @AvroName("name") + @SuppressWarnings("unused") + private int someField; + } + + @Test + public void testDeterministicWithAvroName() { + assertDeterministic(AvroCoder.of(FieldWithAvroName.class)); + } + + @Test + public void testDeterminismSortedMap() { + assertDeterministic(AvroCoder.of(StringSortedMapField.class)); + } + + private static class StringSortedMapField { + @SuppressWarnings("unused") + SortedMap sortedMapField; + } + + @Test + public void testDeterminismTreeMapValue() { + // The value is non-deterministic, so we should fail. + assertNonDeterministic(AvroCoder.of(TreeMapNonDetValue.class), + reasonField(UnorderedMapClass.class, "mapField", + "java.util.Map " + + "may not be deterministically ordered")); + } + + private static class TreeMapNonDetValue { + @SuppressWarnings("unused") + TreeMap nonDeterministicField; + } + + @Test + public void testDeterminismUnorderedMap() { + // LinkedHashMap is not deterministically ordered, so we should fail. + assertNonDeterministic(AvroCoder.of(LinkedHashMapField.class), + reasonField(LinkedHashMapField.class, "nonDeterministicMap", + "java.util.LinkedHashMap " + + "may not be deterministically ordered")); + } + + private static class LinkedHashMapField { + @SuppressWarnings("unused") + LinkedHashMap nonDeterministicMap; + } + + @Test + public void testDeterminismCollection() { + assertNonDeterministic(AvroCoder.of(StringCollection.class), + reasonField(StringCollection.class, "stringCollection", + "java.util.Collection may not be deterministically ordered")); + } + + private static class StringCollection { + @SuppressWarnings("unused") + Collection stringCollection; + } + + @Test + public void testDeterminismList() { + assertDeterministic(AvroCoder.of(StringList.class)); + assertDeterministic(AvroCoder.of(StringArrayList.class)); + } + + private static class StringList { + @SuppressWarnings("unused") + List stringCollection; + } + + private static class StringArrayList { + @SuppressWarnings("unused") + ArrayList stringCollection; + } + + @Test + public void testDeterminismSet() { + assertDeterministic(AvroCoder.of(StringSortedSet.class)); + assertDeterministic(AvroCoder.of(StringTreeSet.class)); + assertNonDeterministic(AvroCoder.of(StringHashSet.class), + reasonField(StringHashSet.class, "stringCollection", + "java.util.HashSet may not be deterministically ordered")); + } + + private static class StringSortedSet{ + @SuppressWarnings("unused") + SortedSet stringCollection; + } + + private static class StringTreeSet { + @SuppressWarnings("unused") + TreeSet stringCollection; + } + + private static class StringHashSet { + @SuppressWarnings("unused") + HashSet stringCollection; + } + + @Test + public void testDeterminismCollectionValue() { + assertNonDeterministic(AvroCoder.of(OrderedSetOfNonDetValues.class), + reasonField(UnorderedMapClass.class, "mapField", + "may not be deterministically ordered")); + assertNonDeterministic(AvroCoder.of(ListOfNonDetValues.class), + reasonField(UnorderedMapClass.class, "mapField", + "may not be deterministically ordered")); + } + + private static class OrderedSetOfNonDetValues { + @SuppressWarnings("unused") + SortedSet set; + } + + private static class ListOfNonDetValues { + 
@SuppressWarnings("unused") + List set; + } + + @Test + public void testDeterminismUnion() { + assertDeterministic(AvroCoder.of(DeterministicUnionBase.class)); + assertNonDeterministic(AvroCoder.of(NonDeterministicUnionBase.class), + reasonField(UnionCase3.class, "mapField", "may not be deterministically ordered")); + } + + @Test + public void testDeterminismStringable() { + assertDeterministic(AvroCoder.of(String.class)); + assertNonDeterministic(AvroCoder.of(StringableClass.class), + reasonClass(StringableClass.class, "may not have deterministic #toString()")); + } + + @Stringable + private static class StringableClass { + } + + @Test + public void testDeterminismCyclicClass() { + assertNonDeterministic(AvroCoder.of(Cyclic.class), + reasonField(Cyclic.class, "cyclicField", "appears recursively")); + assertNonDeterministic(AvroCoder.of(CyclicField.class), + reasonField(Cyclic.class, "cyclicField", + Cyclic.class.getName() + " appears recursively")); + assertNonDeterministic(AvroCoder.of(IndirectCycle1.class), + reasonField(IndirectCycle2.class, "field2", + IndirectCycle1.class.getName() + " appears recursively")); + } + + private static class Cyclic { + @SuppressWarnings("unused") + int intField; + @SuppressWarnings("unused") + Cyclic cyclicField; + } + + private static class CyclicField { + @SuppressWarnings("unused") + Cyclic cyclicField2; + } + + private static class IndirectCycle1 { + @SuppressWarnings("unused") + IndirectCycle2 field1; + } + + private static class IndirectCycle2 { + @SuppressWarnings("unused") + IndirectCycle1 field2; + } + + @Test + public void testDeterminismHasGenericRecord() { + assertDeterministic(AvroCoder.of(HasGenericRecord.class)); + } + + private static class HasGenericRecord { + @AvroSchema("{\"name\": \"bar\", \"type\": \"record\", \"fields\": [" + + "{\"name\": \"foo\", \"type\": \"int\"}]}") + GenericRecord genericRecord; + } + + @Test + public void testDeterminismHasCustomSchema() { + assertNonDeterministic(AvroCoder.of(HasCustomSchema.class), + reasonField(HasCustomSchema.class, "withCustomSchema", + "Custom schemas are only supported for subtypes of IndexedRecord.")); + } + + private static class HasCustomSchema { + @AvroSchema("{\"name\": \"bar\", \"type\": \"record\", \"fields\": [" + + "{\"name\": \"foo\", \"type\": \"int\"}]}") + int withCustomSchema; + } + + @Test + public void testAvroCoderTreeMapDeterminism() + throws Exception, NonDeterministicException { + TreeMapField size1 = new TreeMapField(); + TreeMapField size2 = new TreeMapField(); + + // Different order for entries + size1.field.put("hello", "world"); + size1.field.put("another", "entry"); + + size2.field.put("another", "entry"); + size2.field.put("hello", "world"); + + AvroCoder coder = AvroCoder.of(TreeMapField.class); + coder.verifyDeterministic(); + + ByteArrayOutputStream outStream1 = new ByteArrayOutputStream(); + ByteArrayOutputStream outStream2 = new ByteArrayOutputStream(); + + Context context = Context.NESTED; + coder.encode(size1, outStream1, context); + coder.encode(size2, outStream2, context); + + assertTrue(Arrays.equals( + outStream1.toByteArray(), outStream2.toByteArray())); + } + + private static class TreeMapField { + private TreeMap field = new TreeMap<>(); + } + + @Union({ UnionCase1.class, UnionCase2.class }) + private abstract static class DeterministicUnionBase {} + + @Union({ UnionCase1.class, UnionCase2.class, UnionCase3.class }) + private abstract static class NonDeterministicUnionBase {} + private static class UnionCase1 extends DeterministicUnionBase {} + 
private static class UnionCase2 extends DeterministicUnionBase { + @SuppressWarnings("unused") + String field; + } + + private static class UnionCase3 extends NonDeterministicUnionBase { + @SuppressWarnings("unused") + private Map mapField; + } + + @Test + public void testAvroCoderSimpleSchemaDeterminism() { + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .endRecord())); + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("int").type().intType().noDefault() + .endRecord())); + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("string").type().stringType().noDefault() + .endRecord())); + + assertNonDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("map").type().map().values().stringType().noDefault() + .endRecord()), + reason("someRecord.map", "HashMap to represent MAPs")); + + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("array").type().array().items().stringType().noDefault() + .endRecord())); + + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("enum").type().enumeration("anEnum").symbols("s1", "s2").enumDefault("s1") + .endRecord())); + + assertDeterministic(AvroCoder.of(SchemaBuilder.unionOf() + .intType().and() + .record("someRecord").fields().nullableString("someField", "").endRecord() + .endUnion())); + } + + @Test + public void testAvroCoderStrings() { + // Custom Strings in Records + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("string").prop(SpecificData.CLASS_PROP, "java.lang.String") + .type().stringType().noDefault() + .endRecord())); + assertNonDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields() + .name("string").prop(SpecificData.CLASS_PROP, "unknownString") + .type().stringType().noDefault() + .endRecord()), + reason("someRecord.string", "unknownString is not known to be deterministic")); + + // Custom Strings in Unions + assertNonDeterministic(AvroCoder.of(SchemaBuilder.unionOf() + .intType().and() + .record("someRecord").fields() + .name("someField").prop(SpecificData.CLASS_PROP, "unknownString") + .type().stringType().noDefault().endRecord() + .endUnion()), + reason("someRecord.someField", "unknownString is not known to be deterministic")); + } + + @Test + public void testAvroCoderNestedRecords() { + // Nested Record + assertDeterministic(AvroCoder.of(SchemaBuilder.record("nestedRecord").fields() + .name("subRecord").type().record("subRecord").fields() + .name("innerField").type().stringType().noDefault() + .endRecord().noDefault() + .endRecord())); + } + + @Test + public void testAvroCoderCyclicRecords() { + // Recursive record + assertNonDeterministic(AvroCoder.of(SchemaBuilder.record("cyclicRecord").fields() + .name("cycle").type("cyclicRecord").noDefault() + .endRecord()), + reason("cyclicRecord.cycle", "cyclicRecord appears recursively")); + } + + private static class NullableField { + @SuppressWarnings("unused") + @Nullable private String nullable; + } + + @Test + public void testNullableField() { + assertDeterministic(AvroCoder.of(NullableField.class)); + } + + private static class NullableNonDeterministicField { + @SuppressWarnings("unused") + @Nullable private NonDeterministicArray nullableNonDetArray; + } + + private static class NullableCyclic { + @SuppressWarnings("unused") + @Nullable private NullableCyclic nullableNullableCyclicField; + } + + private static class NullableCyclicField { + 
@SuppressWarnings("unused") + @Nullable private Cyclic nullableCyclicField; + } + + @Test + public void testNullableNonDeterministicField() { + assertNonDeterministic(AvroCoder.of(NullableCyclic.class), + reasonField(NullableCyclic.class, "nullableNullableCyclicField", + NullableCyclic.class.getName() + " appears recursively")); + assertNonDeterministic(AvroCoder.of(NullableCyclicField.class), + reasonField(Cyclic.class, "cyclicField", + Cyclic.class.getName() + " appears recursively")); + assertNonDeterministic(AvroCoder.of(NullableNonDeterministicField.class), + reasonField(UnorderedMapClass.class, "mapField", + " may not be deterministically ordered")); + } + + /** + * Tests that a parameterized class can have an automatically generated schema if the generic + * field is annotated with a union tag. + */ + @Test + public void testGenericClassWithUnionAnnotation() throws Exception { + // Cast is safe as long as the same coder is used for encoding and decoding. + @SuppressWarnings({"unchecked", "rawtypes"}) + AvroCoder> coder = + (AvroCoder) AvroCoder.of(GenericWithAnnotation.class); + + assertThat(coder.getSchema().getField("onlySomeTypesAllowed").schema().getType(), + equalTo(Schema.Type.UNION)); + + CoderProperties.coderDecodeEncodeEqual(coder, new GenericWithAnnotation<>("hello")); + } + + private static class GenericWithAnnotation { + @AvroSchema("[\"string\", \"int\"]") + private T onlySomeTypesAllowed; + + public GenericWithAnnotation(T value) { + onlySomeTypesAllowed = value; + } + + // For deserialization only + @SuppressWarnings("unused") + protected GenericWithAnnotation() { } + + @Override + public boolean equals(Object other) { + return other instanceof GenericWithAnnotation + && onlySomeTypesAllowed.equals(((GenericWithAnnotation) other).onlySomeTypesAllowed); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), onlySomeTypesAllowed); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoderTest.java new file mode 100644 index 000000000000..d96c20805df7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/BigEndianIntegerCoderTest.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link BigEndianIntegerCoder}. 
+ */ +@RunWith(JUnit4.class) +public class BigEndianIntegerCoderTest { + + private static final Coder TEST_CODER = BigEndianIntegerCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + -11, -3, -1, 0, 1, 5, 13, 29, + Integer.MAX_VALUE, + Integer.MIN_VALUE); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Integer value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // This should never change. The definition of big endian encoding is fixed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "____9Q", + "_____Q", + "_____w", + "AAAAAA", + "AAAAAQ", + "AAAABQ", + "AAAADQ", + "AAAAHQ", + "f____w", + "gAAAAA"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Integer"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoderTest.java new file mode 100644 index 000000000000..ea486c18230f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/BigEndianLongCoderTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link BigEndianLongCoder}. + */ +@RunWith(JUnit4.class) +public class BigEndianLongCoderTest { + + private static final Coder TEST_CODER = BigEndianLongCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + -11L, -3L, -1L, 0L, 1L, 5L, 13L, 29L, + Integer.MAX_VALUE + 131L, + Integer.MIN_VALUE - 29L, + Long.MAX_VALUE, + Long.MIN_VALUE); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Long value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // This should never change. The definition of big endian is fixed. 
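+  // As in the other coder tests in this change, the encoding id is pinned to the
+  // empty string because the wire format is not expected to evolve.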
+ private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "__________U", + "__________0", + "__________8", + "AAAAAAAAAAA", + "AAAAAAAAAAE", + "AAAAAAAAAAU", + "AAAAAAAAAA0", + "AAAAAAAAAB0", + "AAAAAIAAAII", + "_____3___-M", + "f_________8", + "gAAAAAAAAAA"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Long"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoderTest.java new file mode 100644 index 000000000000..989bc7f80b14 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteArrayCoderTest.java @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.common.CounterTestUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for {@link ByteArrayCoder}. 
+ */ +@RunWith(JUnit4.class) +public class ByteArrayCoderTest { + + private static final ByteArrayCoder TEST_CODER = ByteArrayCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + new byte[]{0xa, 0xb, 0xc}, + new byte[]{0xd, 0x3}, + new byte[]{0xd, 0xe}, + new byte[]{}); + + @Test + public void testDecodeEncodeEquals() throws Exception { + for (byte[] value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testRegisterByteSizeObserver() throws Exception { + CounterTestUtils.testByteCount(ByteArrayCoder.of(), Coder.Context.OUTER, + new byte[][]{{ 0xa, 0xb, 0xc }}); + + CounterTestUtils.testByteCount(ByteArrayCoder.of(), Coder.Context.NESTED, + new byte[][]{{ 0xa, 0xb, 0xc }, {}, {}, { 0xd, 0xe }, {}}); + } + + @Test + public void testStructuralValueConsistentWithEquals() throws Exception { + // We know that byte array coders are NOT compatible with equals + // (aka injective w.r.t. Object.equals) + for (byte[] value1 : TEST_VALUES) { + for (byte[] value2 : TEST_VALUES) { + CoderProperties.structuralValueConsistentWithEquals(TEST_CODER, value1, value2); + } + } + } + + @Test + public void testEncodeThenMutate() throws Exception { + byte[] input = { 0x7, 0x3, 0xA, 0xf }; + byte[] encoded = CoderUtils.encodeToByteArray(TEST_CODER, input); + input[1] = 0x9; + byte[] decoded = CoderUtils.decodeFromByteArray(TEST_CODER, encoded); + + // now that I have mutated the input, the output should NOT match + assertThat(input, not(equalTo(decoded))); + } + + @Test + public void testEncodeAndOwn() throws Exception { + for (byte[] value : TEST_VALUES) { + byte[] encodedSlow = CoderUtils.encodeToByteArray(TEST_CODER, value); + byte[] encodedFast = encodeToByteArrayAndOwn(TEST_CODER, value); + assertThat(encodedSlow, equalTo(encodedFast)); + } + } + + private static byte[] encodeToByteArrayAndOwn(ByteArrayCoder coder, byte[] value) + throws IOException { + return encodeToByteArrayAndOwn(coder, value, Coder.Context.OUTER); + } + + private static byte[] encodeToByteArrayAndOwn( + ByteArrayCoder coder, byte[] value, Coder.Context context) throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + coder.encodeAndOwn(value, os, context); + return os.toByteArray(); + } + + // If this changes, it implies the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "CgsM", + "DQM", + "DQ4", + ""); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null byte[]"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteCoderTest.java new file mode 100644 index 000000000000..6cb852e23613 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteCoderTest.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link ByteCoder}. + */ +@RunWith(JUnit4.class) +public class ByteCoderTest { + + private static final Coder TEST_CODER = ByteCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + (byte) 1, + (byte) 4, + (byte) 6, + (byte) 50, + (byte) 124, + Byte.MAX_VALUE, + Byte.MIN_VALUE); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Byte value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // This should never change. The format is fixed by Java. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AQ", + "BA", + "Bg", + "Mg", + "fA", + "fw", + "gA"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Byte"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteStringCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteStringCoderTest.java new file mode 100644 index 000000000000..debae71fddf7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ByteStringCoderTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link ByteStringCoder}. + */ +@RunWith(JUnit4.class) +public class ByteStringCoderTest { + + private static final ByteStringCoder TEST_CODER = ByteStringCoder.of(); + + private static final List TEST_STRING_VALUES = Arrays.asList( + "", "a", "13", "hello", + "a longer string with spaces and all that", + "a string with a \n newline", + "???????????????"); + private static final ImmutableList TEST_VALUES; + static { + ImmutableList.Builder builder = ImmutableList.builder(); + for (String s : TEST_STRING_VALUES) { + builder.add(ByteString.copyFrom(s.getBytes())); + } + TEST_VALUES = builder.build(); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "", + "YQ", + "MTM", + "aGVsbG8", + "YSBsb25nZXIgc3RyaW5nIHdpdGggc3BhY2VzIGFuZCBhbGwgdGhhdA", + "YSBzdHJpbmcgd2l0aCBhIAogbmV3bGluZQ", + "Pz8_Pz8_Pz8_Pz8_Pz8_"); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testDecodeEncodeEqualInAllContexts() throws Exception { + for (ByteString value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Test + public void testCoderDeterministic() throws Throwable { + TEST_CODER.verifyDeterministic(); + } + + @Test + public void testConsistentWithEquals() { + assertTrue(TEST_CODER.consistentWithEquals()); + } + + @Test + public void testEncodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null ByteString"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } + + @Test + public void testNestedCoding() throws Throwable { + Coder> listCoder = ListCoder.of(TEST_CODER); + CoderProperties.coderDecodeEncodeContentsEqual(listCoder, TEST_VALUES); + CoderProperties.coderDecodeEncodeContentsInSameOrder(listCoder, TEST_VALUES); + } + + @Test + public void testEncodedElementByteSizeInAllContexts() throws Throwable { + for (Context context : CoderProperties.ALL_CONTEXTS) { + for (ByteString value : TEST_VALUES) { + byte[] encoded = CoderUtils.encodeToByteArray(TEST_CODER, value, context); + assertEquals(encoded.length, TEST_CODER.getEncodedElementByteSize(value, context)); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderFactoriesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderFactoriesTest.java new file mode 100644 index 000000000000..8d702bf259cb --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderFactoriesTest.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; + +/** + * Tests for {@link CoderFactories}. + */ +@RunWith(JUnit4.class) +public class CoderFactoriesTest { + + /** + * Ensures that a few of our standard atomic coder classes + * can each be built into a factory that works as expected. + * It is presumed that testing a few, not all, suffices to + * exercise CoderFactoryFromStaticMethods. 
+   */
+  @Test
+  public void testAtomicCoderClassFactories() {
+    checkAtomicCoderFactory(StringUtf8Coder.class, StringUtf8Coder.of());
+    checkAtomicCoderFactory(DoubleCoder.class, DoubleCoder.of());
+    checkAtomicCoderFactory(ByteArrayCoder.class, ByteArrayCoder.of());
+  }
+
+  /**
+   * Checks that {@link CoderFactories#fromStaticMethods} successfully
+   * builds a working {@link CoderFactory} from {@link KvCoder KvCoder.class}.
+   */
+  @Test
+  public void testKvCoderFactory() {
+    CoderFactory kvCoderFactory = CoderFactories.fromStaticMethods(KvCoder.class);
+    assertEquals(
+        KvCoder.of(DoubleCoder.of(), DoubleCoder.of()),
+        kvCoderFactory.create(Arrays.asList(DoubleCoder.of(), DoubleCoder.of())));
+  }
+
+  /**
+   * Checks that {@link CoderFactories#fromStaticMethods} successfully
+   * builds a working {@link CoderFactory} from {@link ListCoder ListCoder.class}.
+   */
+  @Test
+  public void testListCoderFactory() {
+    CoderFactory listCoderFactory = CoderFactories.fromStaticMethods(ListCoder.class);
+
+    assertEquals(
+        ListCoder.of(DoubleCoder.of()),
+        listCoderFactory.create(Arrays.asList(DoubleCoder.of())));
+  }
+
+  /**
+   * Checks that {@link CoderFactories#fromStaticMethods} successfully
+   * builds a working {@link CoderFactory} from {@link IterableCoder IterableCoder.class}.
+   */
+  @Test
+  public void testIterableCoderFactory() {
+    CoderFactory iterableCoderFactory = CoderFactories.fromStaticMethods(IterableCoder.class);
+
+    assertEquals(
+        IterableCoder.of(DoubleCoder.of()),
+        iterableCoderFactory.create(Arrays.asList(DoubleCoder.of())));
+  }
+
+  ///////////////////////////////////////////////////////////////////////
+
+  /**
+   * Checks that an atomic coder class can be converted into
+   * a factory that then yields a coder equal to the example
+   * provided.
+   */
+  private <T> void checkAtomicCoderFactory(
+      Class<? extends Coder<T>> coderClazz,
+      Coder<T> expectedCoder) {
+    CoderFactory factory = CoderFactories.fromStaticMethods(coderClazz);
+    @SuppressWarnings("unchecked")
+    Coder<T> actualCoder = (Coder<T>) factory.create(Collections.<Coder<?>>emptyList());
+    assertEquals(expectedCoder, actualCoder);
+  }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderProvidersTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderProvidersTest.java
new file mode 100644
index 000000000000..1c0a89ed1b76
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderProvidersTest.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2014 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.coders;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Map;
+
+/**
+ * Tests for {@link CoderProviders}.
+ */ +@RunWith(JUnit4.class) +public class CoderProvidersTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testAvroThenSerializableStringMap() throws Exception { + CoderProvider provider = CoderProviders.firstOf(AvroCoder.PROVIDER, SerializableCoder.PROVIDER); + Coder> coder = + provider.getCoder(new TypeDescriptor>(){}); + assertThat(coder, instanceOf(AvroCoder.class)); + } + + @Test + public void testThrowingThenSerializable() throws Exception { + CoderProvider provider = + CoderProviders.firstOf(new ThrowingCoderProvider(), SerializableCoder.PROVIDER); + Coder coder = provider.getCoder(new TypeDescriptor(){}); + assertThat(coder, instanceOf(SerializableCoder.class)); + } + + @Test + public void testNullThrows() throws Exception { + CoderProvider provider = CoderProviders.firstOf(new ThrowingCoderProvider()); + thrown.expect(CannotProvideCoderException.class); + thrown.expectMessage("ThrowingCoderProvider"); + provider.getCoder(new TypeDescriptor(){}); + } + + private static class ThrowingCoderProvider implements CoderProvider { + @Override + public Coder getCoder(TypeDescriptor type) throws CannotProvideCoderException { + throw new CannotProvideCoderException("ThrowingCoderProvider cannot ever provide a Coder"); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderRegistryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderRegistryTest.java new file mode 100644 index 000000000000..2f350b288dc4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderRegistryTest.java @@ -0,0 +1,521 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry.IncompatibleCoderException; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageA; +import com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Duration; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Tests for CoderRegistry. + */ +@RunWith(JUnit4.class) +public class CoderRegistryTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + public static CoderRegistry getStandardRegistry() { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + return registry; + } + + private static class SerializableClass implements Serializable { + } + + private static class NotSerializableClass { } + + @Test + public void testSerializableFallbackCoderProvider() throws Exception { + CoderRegistry registry = getStandardRegistry(); + registry.setFallbackCoderProvider(SerializableCoder.PROVIDER); + Coder serializableCoder = registry.getDefaultCoder(SerializableClass.class); + + assertEquals(serializableCoder, SerializableCoder.of(SerializableClass.class)); + } + + @Test + public void testProtoCoderFallbackCoderProvider() throws Exception { + CoderRegistry registry = getStandardRegistry(); + + // MessageA is a Protocol Buffers test message with syntax 2 + assertEquals(registry.getDefaultCoder(MessageA.class), ProtoCoder.of(MessageA.class)); + + // Duration is a Protocol Buffers default type with syntax 3 + assertEquals(registry.getDefaultCoder(Duration.class), ProtoCoder.of(Duration.class)); + } + + @Test + public void testAvroFallbackCoderProvider() throws Exception { + CoderRegistry registry = getStandardRegistry(); + registry.setFallbackCoderProvider(AvroCoder.PROVIDER); + Coder avroCoder = registry.getDefaultCoder(NotSerializableClass.class); + + assertEquals(avroCoder, AvroCoder.of(NotSerializableClass.class)); + } + + @Test + public void testRegisterInstantiatedCoder() throws Exception { + CoderRegistry registry = new CoderRegistry(); + registry.registerCoder(MyValue.class, MyValueCoder.of()); + assertEquals(registry.getDefaultCoder(MyValue.class), MyValueCoder.of()); + } + + @SuppressWarnings("rawtypes") // this class exists to fail a 
test because of its rawtypes + private class MyListCoder extends DeterministicStandardCoder { + @Override + public void encode(List value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public List decode(InputStream inStream, Context context) + throws CoderException, IOException { + return Collections.emptyList(); + } + + @Override + public List> getCoderArguments() { + return Collections.emptyList(); + } + } + + @Test + public void testRegisterInstantiatedCoderInvalidRawtype() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("may not be used with unspecialized generic classes"); + CoderRegistry registry = new CoderRegistry(); + registry.registerCoder(List.class, new MyListCoder()); + } + + @Test + public void testSimpleDefaultCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + assertEquals(StringUtf8Coder.of(), registry.getDefaultCoder(String.class)); + } + + @Test + public void testSimpleUnknownDefaultCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + thrown.expect(CannotProvideCoderException.class); + thrown.expectMessage(allOf( + containsString(UnknownType.class.getCanonicalName()), + containsString("No CoderFactory has been registered"), + containsString("does not have a @DefaultCoder annotation"), + containsString("does not implement Serializable"))); + registry.getDefaultCoder(UnknownType.class); + } + + @Test + public void testParameterizedDefaultListCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + TypeDescriptor> listToken = new TypeDescriptor>() {}; + assertEquals(ListCoder.of(VarIntCoder.of()), + registry.getDefaultCoder(listToken)); + + registry.registerCoder(MyValue.class, MyValueCoder.class); + TypeDescriptor>> kvToken = + new TypeDescriptor>>() {}; + assertEquals(KvCoder.of(StringUtf8Coder.of(), + ListCoder.of(MyValueCoder.of())), + registry.getDefaultCoder(kvToken)); + + } + + @Test + public void testParameterizedDefaultMapCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + TypeDescriptor> mapToken = new TypeDescriptor>() {}; + assertEquals(MapCoder.of(VarIntCoder.of(), StringUtf8Coder.of()), + registry.getDefaultCoder(mapToken)); + } + + @Test + public void testParameterizedDefaultNestedMapCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + TypeDescriptor>> mapToken = + new TypeDescriptor>>() {}; + assertEquals( + MapCoder.of(VarIntCoder.of(), MapCoder.of(StringUtf8Coder.of(), DoubleCoder.of())), + registry.getDefaultCoder(mapToken)); + } + + @Test + public void testParameterizedDefaultSetCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + TypeDescriptor> setToken = new TypeDescriptor>() {}; + assertEquals(SetCoder.of(VarIntCoder.of()), registry.getDefaultCoder(setToken)); + } + + @Test + public void testParameterizedDefaultNestedSetCoder() throws Exception { + CoderRegistry registry = getStandardRegistry(); + TypeDescriptor>> setToken = new TypeDescriptor>>() {}; + assertEquals(SetCoder.of(SetCoder.of(VarIntCoder.of())), registry.getDefaultCoder(setToken)); + } + + @Test + public void testParameterizedDefaultCoderUnknown() throws Exception { + CoderRegistry registry = getStandardRegistry(); + TypeDescriptor> listUnknownToken = new TypeDescriptor>() {}; + + thrown.expect(CannotProvideCoderException.class); + thrown.expectMessage(String.format( + "Cannot provide coder for parameterized type %s: Unable to provide a 
default Coder for %s", + listUnknownToken, + UnknownType.class.getCanonicalName())); + + registry.getDefaultCoder(listUnknownToken); + } + + @Test + public void testTypeParameterInferenceForward() throws Exception { + CoderRegistry registry = getStandardRegistry(); + MyGenericClass> instance = + new MyGenericClass>() {}; + + Coder bazCoder = registry.getDefaultCoder( + instance.getClass(), + MyGenericClass.class, + Collections.>singletonMap( + TypeDescriptor.of(MyGenericClass.class).getTypeParameter("FooT"), MyValueCoder.of()), + TypeDescriptor.of(MyGenericClass.class).getTypeParameter("BazT")); + + assertEquals(ListCoder.of(MyValueCoder.of()), bazCoder); + } + + @Test + public void testTypeParameterInferenceBackward() throws Exception { + CoderRegistry registry = getStandardRegistry(); + MyGenericClass> instance = + new MyGenericClass>() {}; + + Coder fooCoder = registry.getDefaultCoder( + instance.getClass(), + MyGenericClass.class, + Collections.>singletonMap( + TypeDescriptor.of(MyGenericClass.class).getTypeParameter("BazT"), + ListCoder.of(MyValueCoder.of())), + TypeDescriptor.of(MyGenericClass.class).getTypeParameter("FooT")); + + assertEquals(MyValueCoder.of(), fooCoder); + } + + @Test + public void testGetDefaultCoderFromIntegerValue() throws Exception { + CoderRegistry registry = getStandardRegistry(); + Integer i = 13; + Coder coder = registry.getDefaultCoder(i); + assertEquals(VarIntCoder.of(), coder); + } + + @Test + public void testGetDefaultCoderFromNullValue() throws Exception { + CoderRegistry registry = getStandardRegistry(); + assertEquals(VoidCoder.of(), registry.getDefaultCoder((Void) null)); + } + + @Test + public void testGetDefaultCoderFromKvValue() throws Exception { + CoderRegistry registry = getStandardRegistry(); + KV kv = KV.of(13, "hello"); + Coder> coder = registry.getDefaultCoder(kv); + assertEquals(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()), + coder); + } + + @Test + public void testGetDefaultCoderFromKvNullValue() throws Exception { + CoderRegistry registry = getStandardRegistry(); + KV kv = KV.of((Void) null, (Void) null); + assertEquals(KvCoder.of(VoidCoder.of(), VoidCoder.of()), + registry.getDefaultCoder(kv)); + } + + @Test + public void testGetDefaultCoderFromNestedKvValue() throws Exception { + CoderRegistry registry = getStandardRegistry(); + KV>> kv = KV.of(13, KV.of(17L, KV.of("hello", "goodbye"))); + Coder>>> coder = registry.getDefaultCoder(kv); + assertEquals( + KvCoder.of(VarIntCoder.of(), + KvCoder.of(VarLongCoder.of(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))), + coder); + } + + @Test + public void testTypeCompatibility() throws Exception { + CoderRegistry.verifyCompatible(BigEndianIntegerCoder.of(), Integer.class); + CoderRegistry.verifyCompatible( + ListCoder.of(BigEndianIntegerCoder.of()), + new TypeDescriptor>() {}.getType()); + } + + @Test + public void testIntVersusStringIncompatibility() throws Exception { + thrown.expect(IncompatibleCoderException.class); + thrown.expectMessage("not assignable"); + CoderRegistry.verifyCompatible(BigEndianIntegerCoder.of(), String.class); + } + + private static class TooManyComponentCoders extends ListCoder { + public TooManyComponentCoders(Coder actualComponentCoder) { + super(actualComponentCoder); + } + + @Override + public List> getCoderArguments() { + return ImmutableList.>builder() + .addAll(super.getCoderArguments()) + .add(BigEndianLongCoder.of()) + .build(); + } + } + + @Test + public void testTooManyCoderArguments() throws Exception { + 
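+    // TooManyComponentCoders reports two component coders (the element coder plus
+    // an extra BigEndianLongCoder) for List, which has only one type parameter, so
+    // compatibility verification should fail.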
thrown.expect(IncompatibleCoderException.class); + thrown.expectMessage("type parameters"); + thrown.expectMessage("less than the number of coder arguments"); + CoderRegistry.verifyCompatible( + new TooManyComponentCoders<>(BigEndianIntegerCoder.of()), List.class); + } + + @Test + public void testComponentIncompatibility() throws Exception { + thrown.expect(IncompatibleCoderException.class); + thrown.expectMessage("component coder is incompatible"); + CoderRegistry.verifyCompatible( + ListCoder.of(BigEndianIntegerCoder.of()), + new TypeDescriptor>() {}.getType()); + } + + @Test + public void testDefaultCoderAnnotationGenericRawtype() throws Exception { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + assertEquals( + registry.getDefaultCoder(MySerializableGeneric.class), + SerializableCoder.of(MySerializableGeneric.class)); + } + + @Test + public void testDefaultCoderAnnotationGeneric() throws Exception { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + assertEquals( + registry.getDefaultCoder(new TypeDescriptor>() {}), + SerializableCoder.of(MySerializableGeneric.class)); + } + + private static class PTransformOutputingMySerializableGeneric + extends PTransform, PCollection>>> { + + private class OutputDoFn extends DoFn>> { + @Override + public void processElement(ProcessContext c) { } + } + + @Override + public PCollection>> + apply(PCollection input) { + return input.apply(ParDo.of(new OutputDoFn())); + } + } + + /** + * Tests that the error message for a type variable includes a mention of where the + * type variable was declared. + */ + @Test + public void testTypeVariableErrorMessage() throws Exception { + CoderRegistry registry = new CoderRegistry(); + + thrown.expect(CannotProvideCoderException.class); + thrown.expectMessage(allOf( + containsString("TestGenericT"), + containsString("erasure"), + containsString("com.google.cloud.dataflow.sdk.coders.CoderRegistryTest$TestGenericClass"))); + registry.getDefaultCoder(TypeDescriptor.of( + TestGenericClass.class.getTypeParameters()[0])); + } + + private static class TestGenericClass { } + + /** + * In-context test that assures the functionality tested in + * {@link #testDefaultCoderAnnotationGeneric} is invoked in the right ways. 
+ */ + @Test + public void testSpecializedButIgnoredGenericInPipeline() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of("hello", "goodbye")) + .apply(new PTransformOutputingMySerializableGeneric()); + + pipeline.run(); + } + + private static class GenericOutputMySerializedGeneric + extends PTransform< + PCollection, + PCollection>>> { + + private class OutputDoFn extends DoFn>> { + @Override + public void processElement(ProcessContext c) { } + } + + @Override + public PCollection>> + apply(PCollection input) { + return input.apply(ParDo.of(new OutputDoFn())); + } + } + + @Test + public void testIgnoredGenericInPipeline() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of("hello", "goodbye")) + .apply(new GenericOutputMySerializedGeneric()); + + pipeline.run(); + } + + private static class MyGenericClass { } + + private static class MyValue { } + + private static class MyValueCoder implements Coder { + + private static final MyValueCoder INSTANCE = new MyValueCoder(); + + public static MyValueCoder of() { + return INSTANCE; + } + + @SuppressWarnings("unused") + public static List getInstanceComponents( + @SuppressWarnings("unused") MyValue exampleValue) { + return Arrays.asList(); + } + + @Override + public void encode(MyValue value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public MyValue decode(InputStream inStream, Context context) + throws CoderException, IOException { + return new MyValue(); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public CloudObject asCloudObject() { + return null; + } + + @Override + public void verifyDeterministic() { } + + @Override + public boolean consistentWithEquals() { + return true; + } + + @Override + public Object structuralValue(MyValue value) { + return value; + } + + @Override + public boolean isRegisterByteSizeObserverCheap(MyValue value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + MyValue value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(0L); + } + + @Override + public String getEncodingId() { + return getClass().getName(); + } + + @Override + public Collection getAllowedEncodings() { + return Collections.singletonList(getEncodingId()); + } + } + + private static class UnknownType { } + + @DefaultCoder(SerializableCoder.class) + private static class MySerializableGeneric implements Serializable { + @SuppressWarnings("unused") + private T foo; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderTest.java new file mode 100644 index 000000000000..c5d275cf63f1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CoderTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */
+package com.google.cloud.dataflow.sdk.coders;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.coders.Coder.Context;
+import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Collections;
+
+/** Tests for constructs defined within {@link Coder}. */
+@RunWith(JUnit4.class)
+public class CoderTest {
+  @Rule public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testContextEqualsAndHashCode() {
+    assertEquals(Context.NESTED, new Context(false));
+    assertEquals(Context.OUTER, new Context(true));
+    assertNotEquals(Context.NESTED, Context.OUTER);
+
+    assertEquals(Context.NESTED.hashCode(), new Context(false).hashCode());
+    assertEquals(Context.OUTER.hashCode(), new Context(true).hashCode());
+    // Even though this isn't strictly required by the hashCode contract,
+    // we still want this to be true.
+    assertNotEquals(Context.NESTED.hashCode(), Context.OUTER.hashCode());
+  }
+
+  @Test
+  public void testContextToString() {
+    assertEquals("Context{NESTED}", Context.NESTED.toString());
+    assertEquals("Context{OUTER}", Context.OUTER.toString());
+  }
+
+  @Test
+  public void testNonDeterministicExceptionRequiresReason() {
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("Reasons must not be empty");
+    new NonDeterministicException(VoidCoder.of(), Collections.<String>emptyList());
+  }
+
+  @Test
+  public void testNonDeterministicException() {
+    NonDeterministicException rootCause =
+        new NonDeterministicException(VoidCoder.of(), "Root Cause");
+    NonDeterministicException exception =
+        new NonDeterministicException(StringUtf8Coder.of(), "Problem", rootCause);
+    assertEquals(rootCause, exception.getCause());
+    assertThat(exception.getReasons(), contains("Problem"));
+    assertThat(exception.toString(), containsString("Problem"));
+    assertThat(exception.toString(), containsString("is not deterministic"));
+  }
+}
+
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CollectionCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CollectionCoderTest.java
new file mode 100644
index 000000000000..ae1d167df15e
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CollectionCoderTest.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.TreeSet; + +/** + * Test case for {@link CollectionCoder}. + */ +@RunWith(JUnit4.class) +public class CollectionCoderTest { + + private static final Coder> TEST_CODER = CollectionCoder.of(VarIntCoder.of()); + + private static final List> TEST_VALUES = Arrays.>asList( + Collections.emptyList(), + Collections.emptySet(), + Collections.singletonList(13), + Arrays.asList(1, 2, 3, 4), + new LinkedList<>(Arrays.asList(7, 6, 5)), + new TreeSet<>(Arrays.asList(31, -5, 83))); + + @Test + public void testDecodeEncodeContentsEqual() throws Exception { + for (Collection value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeContentsEqual(TEST_CODER, value); + } + } + + // If this becomes nonempty, it implies the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAAA", + "AAAAAA", + "AAAAAQ0", + "AAAABAECAwQ", + "AAAAAwcGBQ", + "AAAAA_v___8PH1M"); + + @Test + public void testWireFormat() throws Exception { + CoderProperties.coderDecodesBase64ContentsEqual(TEST_CODER, TEST_ENCODINGS, TEST_VALUES); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Collection"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CustomCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CustomCoderTest.java new file mode 100644 index 000000000000..21f750362e0c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/CustomCoderTest.java @@ -0,0 +1,135 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** Unit tests for {@link CustomCoder}. */ +@RunWith(JUnit4.class) +public class CustomCoderTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private static class MyCustomCoder extends CustomCoder> { + private final String key; + + public MyCustomCoder(String key) { + this.key = key; + } + + @Override + public void encode(KV kv, OutputStream out, Context context) + throws IOException { + new DataOutputStream(out).writeLong(kv.getValue()); + } + + @Override + public KV decode(InputStream inStream, Context context) + throws IOException { + return KV.of(key, new DataInputStream(inStream).readLong()); + } + + @Override + public boolean equals(Object other) { + return other instanceof MyCustomCoder + && key.equals(((MyCustomCoder) other).key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + } + + @Test + public void testEncodeDecode() throws Exception { + MyCustomCoder coder = new MyCustomCoder("key"); + CoderProperties.coderDecodeEncodeEqual(coder, KV.of("key", 3L)); + + byte[] encoded2 = CoderUtils.encodeToByteArray(coder, KV.of("ignored", 3L)); + Assert.assertEquals( + KV.of("key", 3L), CoderUtils.decodeFromByteArray(coder, encoded2)); + } + + @Test + public void testEncodable() throws Exception { + SerializableUtils.ensureSerializable(new MyCustomCoder("key")); + } + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(new MyCustomCoder("foo"), + MyCustomCoder.class.getCanonicalName()); + } + + @Test + public void testAnonymousEncodingIdError() throws Exception { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage("Anonymous CustomCoder subclass"); + thrown.expectMessage("must override getEncodingId()"); + new CustomCoder() { + + @Override + public void encode(Integer kv, OutputStream out, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public Integer decode(InputStream inStream, Context context) { + throw new UnsupportedOperationException(); + } + }.getEncodingId(); + } + + @Test + public void testAnonymousEncodingIdOk() throws Exception { + new CustomCoder() { + + @Override + public void encode(Integer kv, OutputStream out, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public Integer decode(InputStream inStream, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public String getEncodingId() { + return "A user must specify this. 
It can contain any character, including these: !$#%$@."; + } + }.getEncodingId(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DefaultCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DefaultCoderTest.java new file mode 100644 index 000000000000..498b64db572c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DefaultCoderTest.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.common.base.Preconditions.checkArgument; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.List; + +/** + * Tests of Coder defaults. + */ +@RunWith(JUnit4.class) +public class DefaultCoderTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + public CoderRegistry registry = new CoderRegistry(); + + @Before + public void registerStandardCoders() { + registry.registerStandardCoders(); + } + + @DefaultCoder(AvroCoder.class) + private static class AvroRecord { + } + + private static class SerializableBase implements Serializable { + } + + @DefaultCoder(SerializableCoder.class) + private static class SerializableRecord extends SerializableBase { + } + + @DefaultCoder(CustomSerializableCoder.class) + private static class CustomRecord extends SerializableBase { + } + + @DefaultCoder(OldCustomSerializableCoder.class) + private static class OldCustomRecord extends SerializableBase { + } + + private static class Unknown { + } + + private static class CustomSerializableCoder extends SerializableCoder { + // Extending SerializableCoder isn't trivial, but it can be done. + @SuppressWarnings("unchecked") + public static SerializableCoder of(TypeDescriptor recordType) { + checkArgument(recordType.isSupertypeOf(new TypeDescriptor() {})); + return (SerializableCoder) new CustomSerializableCoder(); + } + + protected CustomSerializableCoder() { + super(CustomRecord.class); + } + } + + private static class OldCustomSerializableCoder extends SerializableCoder { + // Extending SerializableCoder isn't trivial, but it can be done. 
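+    // The @DefaultCoder machinery is assumed to resolve this coder through a
+    // static of() factory; this variant keeps the deprecated Class-based signature
+    // alongside the TypeDescriptor-based form used by CustomSerializableCoder above.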
+ @Deprecated // old form using a Class + @SuppressWarnings("unchecked") + public static SerializableCoder of(Class recordType) { + checkArgument(OldCustomRecord.class.isAssignableFrom(recordType)); + return (SerializableCoder) new OldCustomSerializableCoder(); + } + + protected OldCustomSerializableCoder() { + super(OldCustomRecord.class); + } + } + + @Test + public void testDefaultCoderClasses() throws Exception { + assertThat(registry.getDefaultCoder(AvroRecord.class), instanceOf(AvroCoder.class)); + assertThat(registry.getDefaultCoder(SerializableBase.class), + instanceOf(SerializableCoder.class)); + assertThat(registry.getDefaultCoder(SerializableRecord.class), + instanceOf(SerializableCoder.class)); + assertThat(registry.getDefaultCoder(CustomRecord.class), + instanceOf(CustomSerializableCoder.class)); + assertThat(registry.getDefaultCoder(OldCustomRecord.class), + instanceOf(OldCustomSerializableCoder.class)); + } + + @Test + public void testDefaultCoderInCollection() throws Exception { + assertThat(registry.getDefaultCoder(new TypeDescriptor>(){}), + instanceOf(ListCoder.class)); + assertThat(registry.getDefaultCoder(new TypeDescriptor>(){}), + equalTo((Coder>) + ListCoder.of(SerializableCoder.of(SerializableRecord.class)))); + } + + @Test + public void testUnknown() throws Exception { + thrown.expect(CannotProvideCoderException.class); + registry.getDefaultCoder(Unknown.class); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DelegateCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DelegateCoderTest.java new file mode 100644 index 000000000000..3397818c90a5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DelegateCoderTest.java @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Unit tests for {@link DelegateCoder}. 
*/ +@RunWith(JUnit4.class) +public class DelegateCoderTest implements Serializable { + + private static final List<Set<Integer>> TEST_VALUES = Arrays.<Set<Integer>>asList( + Collections.<Integer>emptySet(), + Collections.singleton(13), + new HashSet<>(Arrays.asList(31, -5, 83))); + + private static final DelegateCoder<Set<Integer>, List<Integer>> TEST_CODER = DelegateCoder.of( + ListCoder.of(VarIntCoder.of()), + new DelegateCoder.CodingFunction<Set<Integer>, List<Integer>>() { + @Override + public List<Integer> apply(Set<Integer> input) { + return Lists.newArrayList(input); + } + }, + new DelegateCoder.CodingFunction<List<Integer>, Set<Integer>>() { + @Override + public Set<Integer> apply(List<Integer> input) { + return Sets.newHashSet(input); + } + }); + + @Test + public void testDeterministic() throws Exception { + for (Set<Integer> value : TEST_VALUES) { + CoderProperties.coderDeterministic( + TEST_CODER, value, Sets.newHashSet(value)); + } + } + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Set<Integer> value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testSerializable() throws Exception { + CoderProperties.coderSerializable(TEST_CODER); + } + + private static final String TEST_ENCODING_ID = "test-encoding-id"; + private static final String TEST_ALLOWED_ENCODING = "test-allowed-encoding"; + + private static class TestAllowedEncodingsCoder extends StandardCoder<Integer> { + + @Override + public void encode(Integer value, OutputStream outstream, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public Integer decode(InputStream instream, Context context) { + throw new UnsupportedOperationException(); + } + + @Override + public void verifyDeterministic() { + throw new UnsupportedOperationException(); + } + + @Override + public List<? extends Coder<?>> getCoderArguments() { + return Collections.emptyList(); + } + + @Override + public String getEncodingId() { + return TEST_ENCODING_ID; + } + + @Override + public Collection<String> getAllowedEncodings() { + return Collections.singletonList(TEST_ALLOWED_ENCODING); + } + } + + @Test + public void testEncodingId() throws Exception { + Coder<Integer> underlyingCoder = new TestAllowedEncodingsCoder(); + + Coder<Integer> trivialDelegateCoder = DelegateCoder.of( + underlyingCoder, + new DelegateCoder.CodingFunction<Integer, Integer>() { + @Override + public Integer apply(Integer input) { + return input; + } + }, + new DelegateCoder.CodingFunction<Integer, Integer>() { + @Override + public Integer apply(Integer input) { + return input; + } + }); + CoderProperties.coderHasEncodingId( + trivialDelegateCoder, TestAllowedEncodingsCoder.class.getName() + ":" + TEST_ENCODING_ID); + CoderProperties.coderAllowsEncoding( + trivialDelegateCoder, + TestAllowedEncodingsCoder.class.getName() + ":" + TEST_ALLOWED_ENCODING); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DoubleCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DoubleCoderTest.java new file mode 100644 index 000000000000..8791eb4d46e9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DoubleCoderTest.java @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link DoubleCoder}. + */ +@RunWith(JUnit4.class) +public class DoubleCoderTest { + + private static final Coder TEST_CODER = DoubleCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + 0.0, -0.5, 0.5, 0.3, -0.3, 1.0, -43.89568740, 3.14159, + Double.MAX_VALUE, + Double.MIN_VALUE, + Double.POSITIVE_INFINITY, + Double.NEGATIVE_INFINITY, + Double.NaN); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Double value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // This should never change. The format is fixed by Java. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAAAAAAAA", + "v-AAAAAAAAA", + "P-AAAAAAAAA", + "P9MzMzMzMzM", + "v9MzMzMzMzM", + "P_AAAAAAAAA", + "wEXypeJ9ODo", + "QAkh-fAbhm4", + "f-________8", + "AAAAAAAAAAE", + "f_AAAAAAAAA", + "__AAAAAAAAA", + "f_gAAAAAAAA"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Double"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DurationCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DurationCoderTest.java new file mode 100644 index 000000000000..3d831b38c5b5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/DurationCoderTest.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.collect.Lists; + +import org.joda.time.Duration; +import org.joda.time.ReadableDuration; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link DurationCoder}. */ +@RunWith(JUnit4.class) +public class DurationCoderTest { + + private static final DurationCoder TEST_CODER = DurationCoder.of(); + private static final List TEST_MILLIS = + Lists.newArrayList(0L, 1L, -1L, -255L, 256L, Long.MIN_VALUE, Long.MAX_VALUE); + + private static final List TEST_VALUES; + + static { + TEST_VALUES = Lists.newArrayList(); + for (long millis : TEST_MILLIS) { + TEST_VALUES.add(Duration.millis(millis)); + } + } + + @Test + public void testBasicEncoding() throws Exception { + for (ReadableDuration value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AA", + "AQ", + "____________AQ", + "gf7_________AQ", + "gAI", + "gICAgICAgICAAQ", + "__________9_"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null ReadableDuration"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/EntityCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/EntityCoderTest.java new file mode 100644 index 000000000000..8ced1dc9d0a5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/EntityCoderTest.java @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import static com.google.api.services.datastore.client.DatastoreHelper.makeKey; +import static com.google.api.services.datastore.client.DatastoreHelper.makeProperty; +import static com.google.api.services.datastore.client.DatastoreHelper.makeValue; + +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link EntityCoder}. + */ +@RunWith(JUnit4.class) +public class EntityCoderTest { + + private static final Coder TEST_CODER = EntityCoder.of(); + + // Presumably if anything works, everything works, + // as actual serialization is fully delegated to + // autogenerated code from a well-tested library. + private static final List TEST_VALUES = Arrays.asList( + Entity.newBuilder() + .setKey(makeKey("TestKind", "emptyEntity")) + .build(), + Entity.newBuilder() + .setKey(makeKey("TestKind", "testSimpleProperties")) + .addProperty(makeProperty("trueProperty", makeValue(true))) + .addProperty(makeProperty("falseProperty", makeValue(false))) + .addProperty(makeProperty("stringProperty", makeValue("hello"))) + .addProperty(makeProperty("integerProperty", makeValue(3))) + .addProperty(makeProperty("doubleProperty", makeValue(-1.583257))) + .build(), + Entity.newBuilder() + .setKey(makeKey("TestKind", "testNestedEntity")) + .addProperty(makeProperty("entityProperty", + makeValue(Entity.newBuilder() + .addProperty(makeProperty("stringProperty", makeValue("goodbye")))))) + .build()); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Entity value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // If this changes, it implies the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAGwoZEhcKCFRlc3RLaW5kGgtlbXB0eUVudGl0eQ", + "AAAAnQoiEiAKCFRlc3RLaW5kGhR0ZXN0U2ltcGxlUHJvcGVydGllcxISCgx0cnVlUHJvcGVydHkiAggBEhMKDWZhbHNl" + + "UHJvcGVydHkiAggAEhoKDnN0cmluZ1Byb3BlcnR5IgiKAQVoZWxsbxIVCg9pbnRlZ2VyUHJvcGVydHkiAhADEh" + + "sKDmRvdWJsZVByb3BlcnR5IgkZ8ZvCSgVV-b8", + "AAAAVAoeEhwKCFRlc3RLaW5kGhB0ZXN0TmVzdGVkRW50aXR5EjIKDmVudGl0eVByb3BlcnR5IiAyHhIcCg5zdHJpbmdQ" + + "cm9wZXJ0eSIKigEHZ29vZGJ5ZQ"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Entity"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/InstantCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/InstantCoderTest.java new file mode 100644 index 000000000000..454500bf2418 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/InstantCoderTest.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.collect.Lists; +import com.google.common.primitives.UnsignedBytes; + +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** Unit tests for {@link InstantCoder}. 
*/ +@RunWith(JUnit4.class) +public class InstantCoderTest { + + private static final InstantCoder TEST_CODER = InstantCoder.of(); + + private static final List TEST_TIMESTAMPS = + Arrays.asList(0L, 1L, -1L, -255L, 256L, Long.MIN_VALUE, Long.MAX_VALUE); + + private static final List TEST_VALUES; + + static { + TEST_VALUES = Lists.newArrayList(); + for (long timestamp : TEST_TIMESTAMPS) { + TEST_VALUES.add(new Instant(timestamp)); + } + } + + @Test + public void testBasicEncoding() throws Exception { + for (Instant value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testOrderedEncoding() throws Exception { + List sortedTimestamps = new ArrayList<>(TEST_TIMESTAMPS); + Collections.sort(sortedTimestamps); + + List encodings = new ArrayList<>(sortedTimestamps.size()); + for (long timestamp : sortedTimestamps) { + encodings.add(CoderUtils.encodeToByteArray(TEST_CODER, new Instant(timestamp))); + } + + // Verify that the encodings were already sorted, since they were generated + // in the correct order. + List sortedEncodings = new ArrayList<>(encodings); + Collections.sort(sortedEncodings, UnsignedBytes.lexicographicalComparator()); + + Assert.assertEquals(encodings, sortedEncodings); + } + + // If this changes, it implies that the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "gAAAAAAAAAA", + "gAAAAAAAAAE", + "f_________8", + "f________wE", + "gAAAAAAAAQA", + "AAAAAAAAAAA", + "__________8"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Instant"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/IterableCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/IterableCoderTest.java new file mode 100644 index 000000000000..9afddc1b998b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/IterableCoderTest.java @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +/** Unit tests for {@link IterableCoder}. */ +@RunWith(JUnit4.class) +public class IterableCoderTest { + + private static final Coder> TEST_CODER = IterableCoder.of(VarIntCoder.of()); + + private static final List> TEST_VALUES = Arrays.>asList( + Collections.emptyList(), + Collections.singletonList(13), + Arrays.asList(1, 2, 3, 4), + new LinkedList(Arrays.asList(7, 6, 5))); + + @Test + public void testDecodeEncodeContentsInSameOrder() throws Exception { + for (Iterable value : TEST_VALUES) { + CoderProperties.>coderDecodeEncodeContentsInSameOrder( + TEST_CODER, value); + } + } + + @Test + public void testGetInstanceComponentsNonempty() { + Iterable iterable = Arrays.asList(2, 58, 99, 5); + List components = IterableCoder.getInstanceComponents(iterable); + assertEquals(1, components.size()); + assertEquals(2, components.get(0)); + } + + @Test + public void testGetInstanceComponentsEmpty() { + Iterable iterable = Arrays.asList(); + List components = IterableCoder.getInstanceComponents(iterable); + assertNull(components); + } + + @Test + public void testCoderSerializable() throws Exception { + CoderProperties.coderSerializable(TEST_CODER); + } + + // If this changes, it implies that the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAAA", + "AAAAAQ0", + "AAAABAECAwQ", + "AAAAAwcGBQ"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Iterable"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java new file mode 100644 index 000000000000..7e1effff5df1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import javax.xml.bind.annotation.XmlRootElement; + +/** Unit tests for {@link JAXBCoder}. */ +@RunWith(JUnit4.class) +public class JAXBCoderTest { + + @XmlRootElement + static class TestType { + private String testString = null; + private int testInt; + + public TestType() {} + + public TestType(String testString, int testInt) { + this.testString = testString; + this.testInt = testInt; + } + + public String getTestString() { + return testString; + } + + public void setTestString(String testString) { + this.testString = testString; + } + + public int getTestInt() { + return testInt; + } + + public void setTestInt(int testInt) { + this.testInt = testInt; + } + + @Override + public int hashCode() { + int hashCode = 1; + hashCode = 31 * hashCode + (testString == null ? 0 : testString.hashCode()); + hashCode = 31 * hashCode + testInt; + return hashCode; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof TestType)) { + return false; + } + + TestType other = (TestType) obj; + return (testString == null || testString.equals(other.testString)) + && (testInt == other.testInt); + } + } + + @Test + public void testEncodeDecode() throws Exception { + JAXBCoder coder = JAXBCoder.of(TestType.class); + + byte[] encoded = CoderUtils.encodeToByteArray(coder, new TestType("abc", 9999)); + Assert.assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); + } + + @Test + public void testEncodable() throws Exception { + CoderProperties.coderSerializable(JAXBCoder.of(TestType.class)); + } + + @Test + public void testEncodingId() throws Exception { + Coder coder = JAXBCoder.of(TestType.class); + CoderProperties.coderHasEncodingId( + coder, TestType.class.getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/KvCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/KvCoderTest.java new file mode 100644 index 000000000000..0fd4c1b4d8ed --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/KvCoderTest.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.ImmutableMap; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Test case for {@link KvCoder}. + */ +@RunWith(JUnit4.class) +public class KvCoderTest { + + private static final Map, Iterable> TEST_DATA = + new ImmutableMap.Builder, Iterable>() + .put(VarIntCoder.of(), + Arrays.asList(-1, 0, 1, 13, Integer.MAX_VALUE, Integer.MIN_VALUE)) + .put(BigEndianLongCoder.of(), + Arrays.asList(-1L, 0L, 1L, 13L, Long.MAX_VALUE, Long.MIN_VALUE)) + .put(StringUtf8Coder.of(), + Arrays.asList("", "hello", "goodbye", "1")) + .put(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()), + Arrays.asList(KV.of("", -1), KV.of("hello", 0), KV.of("goodbye", Integer.MAX_VALUE))) + .put(ListCoder.of(VarLongCoder.of()), + Arrays.asList( + Arrays.asList(1L, 2L, 3L), + Collections.emptyList())) + .build(); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Map.Entry, Iterable> entry : TEST_DATA.entrySet()) { + // The coder and corresponding values must be the same type. + // If someone messes this up in the above test data, the test + // will fail anyhow (unless the coder magically works on data + // it does not understand). + @SuppressWarnings("unchecked") + Coder coder = (Coder) entry.getKey(); + Iterable values = entry.getValue(); + for (Object value : values) { + CoderProperties.coderDecodeEncodeEqual(coder, value); + } + } + } + + // If this changes, it implies the binary format has changed! + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId( + KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), + EXPECTED_ENCODING_ID); + } + + /** + * Homogeneously typed test value for ease of use with the wire format test utility. + */ + private static final Coder> TEST_CODER = + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()); + + private static final List> TEST_VALUES = Arrays.asList( + KV.of("", -1), + KV.of("hello", 0), + KV.of("goodbye", Integer.MAX_VALUE)); + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AP____8P", + "BWhlbGxvAA", + "B2dvb2RieWX_____Bw"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null KV"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ListCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ListCoderTest.java new file mode 100644 index 000000000000..6993f323e9ce --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/ListCoderTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +/** Unit tests for {@link ListCoder}. */ +@RunWith(JUnit4.class) +public class ListCoderTest { + + private static final Coder> TEST_CODER = ListCoder.of(VarIntCoder.of()); + + private static final List> TEST_VALUES = Arrays.>asList( + Collections.emptyList(), + Collections.singletonList(43), + Arrays.asList(1, 2, 3, 4), + new LinkedList(Arrays.asList(7, 6, 5))); + + @Test + public void testDecodeEncodeContentsInSameOrder() throws Exception { + for (List value : TEST_VALUES) { + CoderProperties.>coderDecodeEncodeContentsInSameOrder( + TEST_CODER, value); + } + } + + @Test + public void testGetInstanceComponentsNonempty() throws Exception { + List list = Arrays.asList(21, 5, 3, 5); + List components = ListCoder.getInstanceComponents(list); + assertEquals(1, components.size()); + assertEquals(21, components.get(0)); + } + + @Test + public void testGetInstanceComponentsEmpty() throws Exception { + List list = Arrays.asList(); + List components = ListCoder.getInstanceComponents(list); + assertNull(components); + } + + @Test + public void testEmptyList() throws Exception { + List list = Collections.emptyList(); + Coder> coder = ListCoder.of(VarIntCoder.of()); + CoderProperties.>coderDecodeEncodeEqual(coder, list); + } + + @Test + public void testCoderSerializable() throws Exception { + CoderProperties.coderSerializable(TEST_CODER); + } + + // If this changes, it implies the binary format has changed. 
+ private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAAA", + "AAAAASs", + "AAAABAECAwQ", + "AAAAAwcGBQ"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null List"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } + + @Test + public void testListWithNullsAndVarIntCoderThrowsException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Integer"); + + List list = Arrays.asList(1, 2, 3, null, 4); + Coder> coder = ListCoder.of(VarIntCoder.of()); + CoderProperties.>coderDecodeEncodeEqual(coder, list); + } + + @Test + public void testListWithNullsAndSerializableCoder() throws Exception { + List list = Arrays.asList(1, 2, 3, null, 4); + Coder> coder = ListCoder.of(SerializableCoder.of(Integer.class)); + CoderProperties.>coderDecodeEncodeEqual(coder, list); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/MapCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/MapCoderTest.java new file mode 100644 index 000000000000..c263d2722d76 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/MapCoderTest.java @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.collect.ImmutableMap; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** Unit tests for {@link MapCoder}. 
*/ +@RunWith(JUnit4.class) +public class MapCoderTest { + + private static final Coder> TEST_CODER = + MapCoder.of(VarIntCoder.of(), StringUtf8Coder.of()); + + private static final List> TEST_VALUES = Arrays.>asList( + Collections.emptyMap(), + new TreeMap(new ImmutableMap.Builder() + .put(1, "hello").put(-1, "foo").build())); + + @Test + public void testDecodeEncodeContentsInSameOrder() throws Exception { + for (Map value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testGetInstanceComponentsNonempty() { + Map map = new HashMap<>(); + map.put(17, "foozle"); + List components = MapCoder.getInstanceComponents(map); + assertEquals(2, components.size()); + assertEquals(17, components.get(0)); + assertEquals("foozle", components.get(1)); + } + + @Test + public void testGetInstanceComponentsEmpty() { + Map map = new HashMap<>(); + List components = MapCoder.getInstanceComponents(map); + assertNull(components); + } + + // If this changes, it implies the binary format has changed! + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAAA", + "AAAAAv____8PA2ZvbwEFaGVsbG8"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Map"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/NullableCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/NullableCoderTest.java new file mode 100644 index 000000000000..644930e72121 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/NullableCoderTest.java @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link NullableCoder}. */ +@RunWith(JUnit4.class) +public class NullableCoderTest { + + private static final Coder TEST_CODER = NullableCoder.of(StringUtf8Coder.of()); + + private static final List TEST_VALUES = Arrays.asList( + "", "a", "13", "hello", + null, + "a longer string with spaces and all that", + "a string with a \n newline", + "スタリング"); + + @Test + public void testDecodeEncodeContentsInSameOrder() throws Exception { + for (String value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testCoderSerializable() throws Exception { + CoderProperties.coderSerializable(TEST_CODER); + } + + // If this changes, it implies the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@code PrintBase64Encodings}. + * + * @see com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AQA", + "AQFh", + "AQIxMw", + "AQVoZWxsbw", + "AA", + "AShhIGxvbmdlciBzdHJpbmcgd2l0aCBzcGFjZXMgYW5kIGFsbCB0aGF0", + "ARlhIHN0cmluZyB3aXRoIGEgCiBuZXdsaW5l", + "AQ_jgrnjgr_jg6rjg7PjgrA"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Test + public void testEncodedSize() throws Exception { + NullableCoder coder = NullableCoder.of(DoubleCoder.of()); + assertEquals(1, coder.getEncodedElementByteSize(null, Coder.Context.OUTER)); + assertEquals(9, coder.getEncodedElementByteSize(5.0, Coder.Context.OUTER)); + } + + @Test + public void testObserverIsCheap() throws Exception { + NullableCoder coder = NullableCoder.of(DoubleCoder.of()); + assertTrue(coder.isRegisterByteSizeObserverCheap(null, Coder.Context.OUTER)); + assertTrue(coder.isRegisterByteSizeObserverCheap(5.0, Coder.Context.OUTER)); + } + + @Test + public void testObserverIsNotCheap() throws Exception { + NullableCoder> coder = NullableCoder.of(ListCoder.of(StringUtf8Coder.of())); + assertFalse(coder.isRegisterByteSizeObserverCheap(null, Coder.Context.OUTER)); + assertFalse(coder.isRegisterByteSizeObserverCheap( + ImmutableList.of("hi", "test"), Coder.Context.OUTER)); + } + + @Test + public void testStructuralValueConsistentWithEquals() throws Exception { + CoderProperties.structuralValueConsistentWithEquals(TEST_CODER, null, null); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testDecodingError() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage(equalTo("NullableCoder expects either a byte valued 0 (null) " + + "or 1 (present), got 5")); + + 
InputStream input = new ByteArrayInputStream(new byte[] {5}); + TEST_CODER.decode(input, Coder.Context.OUTER); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/PrintBase64Encodings.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/PrintBase64Encodings.java new file mode 100644 index 000000000000..0f08262aa2d6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/PrintBase64Encodings.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.List; + +/** + * A command-line utility for printing the base-64 encodings of test values, for generating exact + * wire format tests. + * + *
<p>For internal use only. + * + * <p>
    Example invocation via maven: + * {@code + * mvn test-compile exec:java \ + * -Dexec.classpathScope=test \ + * -Dexec.mainClass=com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings + * -Dexec.args='com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoderTest.TEST_CODER \ + * com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoderTest.TEST_VALUES' + * } + */ +public class PrintBase64Encodings { + + /** + * Gets a field even if it is private, which the test data generally will be. + */ + private static Field getField(Class clazz, String fieldName) throws Exception { + for (Field field : clazz.getDeclaredFields()) { + if (field.getName().equals(fieldName)) { + if (!Modifier.isPublic(field.getModifiers())) { + field.setAccessible(true); + } + return field; + } + } + throw new NoSuchFieldException(clazz.getCanonicalName() + "." + fieldName); + } + + private static Object getFullyQualifiedValue(String fullyQualifiedName) throws Exception { + int lastDot = fullyQualifiedName.lastIndexOf("."); + String className = fullyQualifiedName.substring(0, lastDot); + String fieldName = fullyQualifiedName.substring(lastDot + 1); + + Class clazz = Class.forName(className); + Field field = getField(clazz, fieldName); + return field.get(null); + } + + public static void main(String[] argv) throws Exception { + @SuppressWarnings("unchecked") + Coder testCoder = (Coder) getFullyQualifiedValue(argv[0]); + @SuppressWarnings("unchecked") + List testValues = (List) getFullyQualifiedValue(argv[1]); + + List base64Encodings = Lists.newArrayList(); + for (Object value : testValues) { + base64Encodings.add(CoderUtils.encodeToBase64(testCoder, value)); + } + System.out.println(String.format("\"%s\"", Joiner.on("\",\n\"").join(base64Encodings))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/Proto2CoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/Proto2CoderTest.java new file mode 100644 index 000000000000..91ebc6543827 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/Proto2CoderTest.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageA; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageB; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageC; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for Proto2Coder. + */ +@SuppressWarnings("deprecation") // test of a deprecated coder. 
+@RunWith(JUnit4.class) +public class Proto2CoderTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testFactoryMethodAgreement() throws Exception { + assertEquals( + Proto2Coder.of(new TypeDescriptor() {}), + Proto2Coder.of(MessageA.class)); + + assertEquals( + Proto2Coder.of(new TypeDescriptor() {}), + Proto2Coder.coderProvider().getCoder(new TypeDescriptor() {})); + } + + @Test + public void testProviderCannotProvideCoder() throws Exception { + thrown.expect(CannotProvideCoderException.class); + Proto2Coder.coderProvider().getCoder(new TypeDescriptor() {}); + } + + @Test + public void testCoderEncodeDecodeEqual() throws Exception { + MessageA value = MessageA.newBuilder() + .setField1("hello") + .addField2(MessageB.newBuilder() + .setField1(true).build()) + .addField2(MessageB.newBuilder() + .setField1(false).build()) + .build(); + CoderProperties.coderDecodeEncodeEqual(Proto2Coder.of(MessageA.class), value); + } + + @Test + public void testCoderEncodeDecodeEqualNestedContext() throws Exception { + MessageA value1 = MessageA.newBuilder() + .setField1("hello") + .addField2(MessageB.newBuilder() + .setField1(true).build()) + .addField2(MessageB.newBuilder() + .setField1(false).build()) + .build(); + MessageA value2 = MessageA.newBuilder() + .setField1("world") + .addField2(MessageB.newBuilder() + .setField1(false).build()) + .addField2(MessageB.newBuilder() + .setField1(true).build()) + .build(); + CoderProperties.coderDecodeEncodeEqual( + ListCoder.of(Proto2Coder.of(MessageA.class)), + ImmutableList.of(value1, value2)); + } + + @Test + public void testCoderEncodeDecodeExtensionsEqual() throws Exception { + MessageC value = MessageC.newBuilder() + .setExtension(Proto2CoderTestMessages.field1, + MessageA.newBuilder() + .setField1("hello") + .addField2(MessageB.newBuilder() + .setField1(true) + .build()) + .build()) + .setExtension(Proto2CoderTestMessages.field2, + MessageB.newBuilder() + .setField1(false) + .build()) + .build(); + CoderProperties.coderDecodeEncodeEqual( + Proto2Coder.of(MessageC.class).withExtensionsFrom(Proto2CoderTestMessages.class), + value); + } + + @Test + public void testCoderSerialization() throws Exception { + Proto2Coder coder = Proto2Coder.of(MessageA.class); + CoderProperties.coderSerializable(coder); + } + + @Test + public void testCoderExtensionsSerialization() throws Exception { + Proto2Coder coder = Proto2Coder.of(MessageC.class) + .withExtensionsFrom(Proto2CoderTestMessages.class); + CoderProperties.coderSerializable(coder); + } + + @Test + public void testEncodingId() throws Exception { + Coder coderA = Proto2Coder.of(MessageA.class); + CoderProperties.coderHasEncodingId(coderA, MessageA.class.getName()); + + Proto2Coder coder = Proto2Coder.of(MessageC.class) + .withExtensionsFrom(Proto2CoderTestMessages.class); + CoderProperties.coderHasEncodingId(coder, MessageC.class.getName()); + } + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null MessageA"); + + CoderUtils.encodeToBase64(Proto2Coder.of(MessageA.class), null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SerializableCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SerializableCoderTest.java new file mode 100644 index 000000000000..b819967bdc7c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SerializableCoderTest.java @@ -0,0 +1,222 @@ +/* + * Copyright (C) 2015 
Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.CloudObject; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.Serializer; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests SerializableCoder. + */ +@RunWith(JUnit4.class) +public class SerializableCoderTest implements Serializable { + + @DefaultCoder(SerializableCoder.class) + static class MyRecord implements Serializable { + private static final long serialVersionUID = 42L; + + public String value; + + public MyRecord(String value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + MyRecord myRecord = (MyRecord) o; + return value.equals(myRecord.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + } + + static class StringToRecord extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(new MyRecord(c.element())); + } + } + + static class RecordToString extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().value); + } + } + + static final List LINES = Arrays.asList( + "To be,", + "or not to be"); + + @Test + public void testSerializableCoder() throws Exception { + IterableCoder coder = IterableCoder + .of(SerializableCoder.of(MyRecord.class)); + + List records = new LinkedList<>(); + for (String l : LINES) { + records.add(new MyRecord(l)); + } + + byte[] encoded = CoderUtils.encodeToByteArray(coder, records); + Iterable decoded = CoderUtils.decodeFromByteArray(coder, encoded); + + assertEquals(records, decoded); + } + + @Test + public void testSerializableCoderConstruction() throws Exception { + SerializableCoder coder = SerializableCoder.of(MyRecord.class); + assertEquals(coder.getRecordType(), MyRecord.class); + + CloudObject encoding = coder.asCloudObject(); + Assert.assertThat(encoding.getClassName(), + Matchers.containsString(SerializableCoder.class.getSimpleName())); + + Coder decoded = 
Serializer.deserialize(encoding, Coder.class); + Assert.assertThat(decoded, Matchers.instanceOf(SerializableCoder.class)); + } + + @Test + public void testDefaultCoder() throws Exception { + Pipeline p = TestPipeline.create(); + + // Use MyRecord as input and output types without explicitly specifying + // a coder (this uses the default coders, which may not be + // SerializableCoder). + PCollection output = + p.apply(Create.of("Hello", "World")) + .apply(ParDo.of(new StringToRecord())) + .apply(ParDo.of(new RecordToString())); + + DataflowAssert.that(output) + .containsInAnyOrder("Hello", "World"); + } + + @Test + public void testLongStringEncoding() throws Exception { + StringUtf8Coder coder = StringUtf8Coder.of(); + + // Java's DataOutputStream.writeUTF fails at 64k, so test well beyond that. + char[] chars = new char[100 * 1024]; + Arrays.fill(chars, 'o'); + String source = new String(chars); + + // Verify OUTER encoding. + assertEquals(source, CoderUtils.decodeFromByteArray(coder, + CoderUtils.encodeToByteArray(coder, source))); + + // Second string uses a UTF8 character. Each codepoint is translated into + // 4 characters in UTF8. + int[] codePoints = new int[20 * 1024]; + Arrays.fill(codePoints, 0x1D50A); // "MATHEMATICAL_FRAKTUR_CAPITAL_G" + String source2 = new String(codePoints, 0, codePoints.length); + + // Verify OUTER encoding. + assertEquals(source2, CoderUtils.decodeFromByteArray(coder, + CoderUtils.encodeToByteArray(coder, source2))); + + + // Encode both strings into NESTED form. + byte[] nestedEncoding; + try (ByteArrayOutputStream os = new ByteArrayOutputStream()) { + coder.encode(source, os, Coder.Context.NESTED); + coder.encode(source2, os, Coder.Context.NESTED); + nestedEncoding = os.toByteArray(); + } + + // Decode from NESTED form. 
+ try (ByteArrayInputStream is = new ByteArrayInputStream(nestedEncoding)) { + assertEquals(source, coder.decode(is, Coder.Context.NESTED)); + assertEquals(source2, coder.decode(is, Coder.Context.NESTED)); + assertEquals(0, is.available()); + } + } + + @Test + public void testNullEncoding() throws Exception { + Coder coder = SerializableCoder.of(String.class); + byte[] encodedBytes = CoderUtils.encodeToByteArray(coder, null); + assertNull(CoderUtils.decodeFromByteArray(coder, encodedBytes)); + } + + @Test + public void testMixedWithNullsEncoding() throws Exception { + Coder coder = SerializableCoder.of(String.class); + byte[] encodedBytes; + try (ByteArrayOutputStream os = new ByteArrayOutputStream()) { + coder.encode(null, os, Coder.Context.NESTED); + coder.encode("TestValue", os, Coder.Context.NESTED); + coder.encode(null, os, Coder.Context.NESTED); + coder.encode("TestValue2", os, Coder.Context.NESTED); + coder.encode(null, os, Coder.Context.NESTED); + encodedBytes = os.toByteArray(); + } + + try (ByteArrayInputStream is = new ByteArrayInputStream(encodedBytes)) { + assertNull(coder.decode(is, Coder.Context.NESTED)); + assertEquals("TestValue", coder.decode(is, Coder.Context.NESTED)); + assertNull(coder.decode(is, Coder.Context.NESTED)); + assertEquals("TestValue2", coder.decode(is, Coder.Context.NESTED)); + assertNull(coder.decode(is, Coder.Context.NESTED)); + assertEquals(0, is.available()); + } + } + + @Test + public void testPojoEncodingId() throws Exception { + Coder coder = SerializableCoder.of(MyRecord.class); + CoderProperties.coderHasEncodingId( + coder, + String.format("%s:%s", MyRecord.class.getName(), MyRecord.serialVersionUID)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SetCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SetCoderTest.java new file mode 100644 index 000000000000..42e560ca902a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/SetCoderTest.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +/** + * Test case for {@link SetCoder}. 
+ */ +@RunWith(JUnit4.class) +public class SetCoderTest { + + private static final Coder> TEST_CODER = SetCoder.of(VarIntCoder.of()); + + private static final List> TEST_VALUES = Arrays.>asList( + Collections.emptySet(), + Collections.singleton(13), + new TreeSet<>(Arrays.asList(31, -5, 83))); + + @Test + public void testDecodeEncodeContentsEqual() throws Exception { + for (Set value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeContentsEqual(TEST_CODER, value); + } + } + + // If this changes, it implies the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "AAAAAA", + "AAAAAQ0", + "AAAAA_v___8PH1M"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Set"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StandardCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StandardCoderTest.java new file mode 100644 index 000000000000..8a1437471294 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StandardCoderTest.java @@ -0,0 +1,176 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Test case for {@link StandardCoder}. + */ +@RunWith(JUnit4.class) +public class StandardCoderTest { + + /** + * A coder for nullable {@code Boolean} values that is consistent with equals. 
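+ * Encodes false, true, and null as the single bytes 0, 1, and 2, respectively.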
+ */ + private static class NullBooleanCoder extends StandardCoder { + + private static final long serialVersionUID = 0L; + + @Override + public void encode(@Nullable Boolean value, OutputStream outStream, Context context) + throws CoderException, IOException { + if (value == null) { + outStream.write(2); + } else if (value) { + outStream.write(1); + } else { + outStream.write(0); + } + } + + @Override + @Nullable + public Boolean decode( + InputStream inStream, com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException { + int value = inStream.read(); + if (value == 0) { + return false; + } else if (value == 1) { + return true; + } else if (value == 2) { + return null; + } + throw new CoderException("Invalid value for nullable Boolean: " + value); + } + + @Override + public List> getCoderArguments() { + return Collections.emptyList(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { } + + @Override + public boolean consistentWithEquals() { + return true; + } + } + + /** + * A boxed {@code int} with {@code equals()} that compares object identity. + */ + private static class ObjectIdentityBoolean { + private final boolean value; + public ObjectIdentityBoolean(boolean value) { + this.value = value; + } + public boolean getValue() { + return value; + } + } + + /** + * A coder for nullable boxed {@code Boolean} values that is not consistent with equals. + */ + private static class ObjectIdentityBooleanCoder extends StandardCoder { + + private static final long serialVersionUID = 0L; + + @Override + public void encode( + @Nullable ObjectIdentityBoolean value, OutputStream outStream, Context context) + throws CoderException, IOException { + if (value == null) { + outStream.write(2); + } else if (value.getValue()){ + outStream.write(1); + } else { + outStream.write(0); + } + } + + @Override + @Nullable + public ObjectIdentityBoolean decode( + InputStream inStream, com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException { + int value = inStream.read(); + if (value == 0) { + return new ObjectIdentityBoolean(false); + } else if (value == 1) { + return new ObjectIdentityBoolean(true); + } else if (value == 2) { + return null; + } + throw new CoderException("Invalid value for nullable Boolean: " + value); + } + + @Override + public List> getCoderArguments() { + return Collections.emptyList(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { } + + @Override + public boolean consistentWithEquals() { + return false; + } + } + + /** + * Tests that {@link StandardCoder#structuralValue()} is correct whenever a subclass has a correct + * {@link Coder#consistentWithEquals()}. 
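+ *
+ * <p>A coder that is consistent with equals may return the value itself as its structural
+ * value; one that is not is expected to fall back to something equality-friendly, such as
+ * the encoded bytes (compare {@link StructuralByteArray}), so that structural equality still
+ * tracks value equality.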
+ */ + @Test + public void testStructuralValue() throws Exception { + List testBooleans = Arrays.asList(null, true, false); + List testInconsistentBooleans = + Arrays.asList(null, new ObjectIdentityBoolean(true), new ObjectIdentityBoolean(false)); + + Coder consistentCoder = new NullBooleanCoder(); + for (Boolean value1 : testBooleans) { + for (Boolean value2 : testBooleans) { + CoderProperties.structuralValueConsistentWithEquals(consistentCoder, value1, value2); + } + } + + Coder inconsistentCoder = new ObjectIdentityBooleanCoder(); + for (ObjectIdentityBoolean value1 : testInconsistentBooleans) { + for (ObjectIdentityBoolean value2 : testInconsistentBooleans) { + CoderProperties.structuralValueConsistentWithEquals(inconsistentCoder, value1, value2); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StringDelegateCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StringDelegateCoderTest.java new file mode 100644 index 000000000000..46b3997b20ca --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StringDelegateCoderTest.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link StringDelegateCoder}. 
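+ * A {@code StringDelegateCoder} encodes a value as the string returned by {@code toString()}
+ * and reconstructs the value from that string on decode, which is why {@code URI} values
+ * round-trip losslessly here.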
*/ +@RunWith(JUnit4.class) +public class StringDelegateCoderTest { + + // Test data + + private static final Coder uriCoder = StringDelegateCoder.of(URI.class); + + private static final List TEST_URI_STRINGS = Arrays.asList( + "http://www.example.com", + "gs://myproject/mybucket/some/gcs/path", + "/just/some/path", + "file:/path/with/no/authority", + "file:///path/with/empty/authority"); + + // Tests + + @Test + public void testDeterministic() throws Exception, NonDeterministicException { + uriCoder.verifyDeterministic(); + for (String uriString : TEST_URI_STRINGS) { + CoderProperties.coderDeterministic(uriCoder, new URI(uriString), new URI(uriString)); + } + } + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (String uriString : TEST_URI_STRINGS) { + CoderProperties.coderDecodeEncodeEqual(uriCoder, new URI(uriString)); + } + } + + @Test + public void testSerializable() throws Exception { + CoderProperties.coderSerializable(uriCoder); + } + + @Test + public void testEncodingId() throws Exception { + StringDelegateCoder coder = StringDelegateCoder.of(URI.class); + CoderProperties.coderHasEncodingId(coder, URI.class.getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StringUtf8CoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StringUtf8CoderTest.java new file mode 100644 index 000000000000..7f40fc0bfd6c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StringUtf8CoderTest.java @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link StringUtf8Coder}. + */ +@RunWith(JUnit4.class) +public class StringUtf8CoderTest { + + private static final Coder TEST_CODER = StringUtf8Coder.of(); + + private static final List TEST_VALUES = Arrays.asList( + "", "a", "13", "hello", + "a longer string with spaces and all that", + "a string with a \n newline", + "スタリング"); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (String value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
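+ *
+ * <p>Each entry is the URL-safe, unpadded base64 form of the coder's output for the
+ * corresponding entry of {@code TEST_VALUES}; conceptually (a sketch, not the utility itself):
+ * <pre>{@code
+ * for (String value : TEST_VALUES) {
+ *   System.out.println(CoderUtils.encodeToBase64(TEST_CODER, value));
+ * }
+ * }</pre>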
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "", + "YQ", + "MTM", + "aGVsbG8", + "YSBsb25nZXIgc3RyaW5nIHdpdGggc3BhY2VzIGFuZCBhbGwgdGhhdA", + "YSBzdHJpbmcgd2l0aCBhIAogbmV3bGluZQ", + "44K544K_44Oq44Oz44Kw"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null String"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StructuralByteArrayTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StructuralByteArrayTest.java new file mode 100644 index 000000000000..8f8cd8cb081f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/StructuralByteArrayTest.java @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.coders; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link StructuralByteArray}. + */ +@RunWith(JUnit4.class) +public final class StructuralByteArrayTest { + + @Test + public void testStructuralByteArray() throws Exception { + assertEquals( + new StructuralByteArray("test string".getBytes()), + new StructuralByteArray("test string".getBytes())); + assertFalse(new StructuralByteArray("test string".getBytes()).equals( + new StructuralByteArray("diff string".getBytes()))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoderTest.java new file mode 100644 index 000000000000..d37692882365 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/TableRowJsonCoderTest.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link TableRowJsonCoder}. + */ +@RunWith(JUnit4.class) +public class TableRowJsonCoderTest { + + private static class TableRowBuilder { + private TableRow row; + public TableRowBuilder() { + row = new TableRow(); + } + public TableRowBuilder set(String fieldName, Object value) { + row.set(fieldName, value); + return this; + } + public TableRow build() { + return row; + } + } + + private static final Coder TEST_CODER = TableRowJsonCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + new TableRowBuilder().build(), + new TableRowBuilder().set("a", "1").build(), + new TableRowBuilder().set("b", 3.14).build(), + new TableRowBuilder().set("a", "1").set("b", true).set("c", "hi").build()); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (TableRow value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // This identifier should only change if the JSON format of results from the BigQuery API changes. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "e30", + "eyJhIjoiMSJ9", + "eyJiIjozLjE0fQ", + "eyJhIjoiMSIsImIiOnRydWUsImMiOiJoaSJ9"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoderTest.java new file mode 100644 index 000000000000..5ccff309d604 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/TextualIntegerCoderTest.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link TextualIntegerCoder}. 
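+ * The coder writes the decimal text of the value rather than a binary form, so 13 is encoded
+ * as the UTF-8 bytes of "13" (base64 "MTM" below).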
+ */ +@RunWith(JUnit4.class) +public class TextualIntegerCoderTest { + + private static final Coder TEST_CODER = TextualIntegerCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + -11, -3, -1, 0, 1, 5, 13, 29, + Integer.MAX_VALUE, + Integer.MIN_VALUE); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Integer value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // This should never change. The textual representation of an integer is fixed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "LTEx", + "LTM", + "LTE", + "MA", + "MQ", + "NQ", + "MTM", + "Mjk", + "MjE0NzQ4MzY0Nw", + "LTIxNDc0ODM2NDg"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Integer"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/VarIntCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/VarIntCoderTest.java new file mode 100644 index 000000000000..cce328036179 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/VarIntCoderTest.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link VarIntCoder}. + */ +@RunWith(JUnit4.class) +public class VarIntCoderTest { + + private static final Coder TEST_CODER = VarIntCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + -11, -3, -1, 0, 1, 5, 13, 29, + Integer.MAX_VALUE, + Integer.MIN_VALUE); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Integer value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // If this changes, it implies the binary format has changed. 
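+ // (An empty string here simply pins the coder's current encoding id.)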
+ private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. + */ + private static final List TEST_ENCODINGS = Arrays.asList( + "9f___w8", + "_f___w8", + "_____w8", + "AA", + "AQ", + "BQ", + "DQ", + "HQ", + "_____wc", + "gICAgAg"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Integer"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/VarLongCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/VarLongCoderTest.java new file mode 100644 index 000000000000..a371af3dfe4d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/VarLongCoderTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders; + +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Test case for {@link VarLongCoder}. + */ +@RunWith(JUnit4.class) +public class VarLongCoderTest { + + private static final Coder TEST_CODER = VarLongCoder.of(); + + private static final List TEST_VALUES = Arrays.asList( + -11L, -3L, -1L, 0L, 1L, 5L, 13L, 29L, + Integer.MAX_VALUE + 131L, + Integer.MIN_VALUE - 29L, + Long.MAX_VALUE, + Long.MIN_VALUE); + + @Test + public void testDecodeEncodeEqual() throws Exception { + for (Long value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + // If this changes, it implies the binary format has changed. + private static final String EXPECTED_ENCODING_ID = ""; + + @Test + public void testEncodingId() throws Exception { + CoderProperties.coderHasEncodingId(TEST_CODER, EXPECTED_ENCODING_ID); + } + + /** + * Generated data to check that the wire format has not changed. To regenerate, see + * {@link com.google.cloud.dataflow.sdk.coders.PrintBase64Encodings}. 
+ */ + private static final List TEST_ENCODINGS = Arrays.asList( + "9f__________AQ", + "_f__________AQ", + "____________AQ", + "AA", + "AQ", + "BQ", + "DQ", + "HQ", + "goGAgAg", + "4_____f_____AQ", + "__________9_", + "gICAgICAgICAAQ"); + + @Test + public void testWireFormatEncode() throws Exception { + CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null Long"); + + CoderUtils.encodeToBase64(TEST_CODER, null); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtoCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtoCoderTest.java new file mode 100644 index 000000000000..6f4e99d388c6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtoCoderTest.java @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.coders.protobuf; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageA; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageB; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageC; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageWithMap; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ProtoCoder}. 
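+ *
+ * <p>{@code ProtoCoder} serializes messages in their protocol-buffer wire format; proto2
+ * extensions are only honored when registered via {@code withExtensionsFrom}, and messages
+ * that transitively contain map fields are reported as non-deterministic (see the tests
+ * below).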
+ */ +@RunWith(JUnit4.class) +public class ProtoCoderTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testFactoryMethodAgreement() throws Exception { + assertEquals(ProtoCoder.of(new TypeDescriptor() {}), ProtoCoder.of(MessageA.class)); + + assertEquals( + ProtoCoder.of(new TypeDescriptor() {}), + ProtoCoder.coderProvider().getCoder(new TypeDescriptor() {})); + } + + @Test + public void testProviderCannotProvideCoder() throws Exception { + thrown.expect(CannotProvideCoderException.class); + thrown.expectMessage("java.lang.Integer is not a subclass of com.google.protobuf.Message"); + + ProtoCoder.coderProvider().getCoder(new TypeDescriptor() {}); + } + + @Test + public void testCoderEncodeDecodeEqual() throws Exception { + MessageA value = + MessageA.newBuilder() + .setField1("hello") + .addField2(MessageB.newBuilder().setField1(true).build()) + .addField2(MessageB.newBuilder().setField1(false).build()) + .build(); + CoderProperties.coderDecodeEncodeEqual(ProtoCoder.of(MessageA.class), value); + } + + @Test + public void testCoderEncodeDecodeEqualNestedContext() throws Exception { + MessageA value1 = + MessageA.newBuilder() + .setField1("hello") + .addField2(MessageB.newBuilder().setField1(true).build()) + .addField2(MessageB.newBuilder().setField1(false).build()) + .build(); + MessageA value2 = + MessageA.newBuilder() + .setField1("world") + .addField2(MessageB.newBuilder().setField1(false).build()) + .addField2(MessageB.newBuilder().setField1(true).build()) + .build(); + CoderProperties.coderDecodeEncodeEqual( + ListCoder.of(ProtoCoder.of(MessageA.class)), ImmutableList.of(value1, value2)); + } + + @Test + public void testCoderEncodeDecodeExtensionsEqual() throws Exception { + MessageC value = + MessageC.newBuilder() + .setExtension( + Proto2CoderTestMessages.field1, + MessageA.newBuilder() + .setField1("hello") + .addField2(MessageB.newBuilder().setField1(true).build()) + .build()) + .setExtension( + Proto2CoderTestMessages.field2, MessageB.newBuilder().setField1(false).build()) + .build(); + CoderProperties.coderDecodeEncodeEqual( + ProtoCoder.of(MessageC.class).withExtensionsFrom(Proto2CoderTestMessages.class), value); + } + + @Test + public void testCoderSerialization() throws Exception { + ProtoCoder coder = ProtoCoder.of(MessageA.class); + CoderProperties.coderSerializable(coder); + } + + @Test + public void testCoderExtensionsSerialization() throws Exception { + ProtoCoder coder = + ProtoCoder.of(MessageC.class).withExtensionsFrom(Proto2CoderTestMessages.class); + CoderProperties.coderSerializable(coder); + } + + @Test + public void testEncodingId() throws Exception { + Coder coderA = ProtoCoder.of(MessageA.class); + CoderProperties.coderHasEncodingId(coderA, MessageA.class.getName() + "[]"); + + ProtoCoder coder = + ProtoCoder.of(MessageC.class).withExtensionsFrom(Proto2CoderTestMessages.class); + CoderProperties.coderHasEncodingId( + coder, + String.format("%s[%s]", MessageC.class.getName(), Proto2CoderTestMessages.class.getName())); + } + + @Test + public void encodeNullThrowsCoderException() throws Exception { + thrown.expect(CoderException.class); + thrown.expectMessage("cannot encode a null MessageA"); + + CoderUtils.encodeToBase64(ProtoCoder.of(MessageA.class), null); + } + + @Test + public void testDeterministicCoder() throws NonDeterministicException { + Coder coder = ProtoCoder.of(MessageA.class); + coder.verifyDeterministic(); + } + + @Test + public void testNonDeterministicCoder() throws NonDeterministicException { + 
thrown.expect(NonDeterministicException.class); + thrown.expectMessage(MessageWithMap.class.getName() + " transitively includes Map field"); + + Coder coder = ProtoCoder.of(MessageWithMap.class); + coder.verifyDeterministic(); + } + + @Test + public void testNonDeterministicProperty() throws CoderException { + MessageWithMap.Builder msg1B = MessageWithMap.newBuilder(); + MessageWithMap.Builder msg2B = MessageWithMap.newBuilder(); + + // Built in reverse order but with equal contents. + for (int i = 0; i < 10; ++i) { + msg1B.getMutableField1().put("key" + i, MessageA.getDefaultInstance()); + msg2B.getMutableField1().put("key" + (9 - i), MessageA.getDefaultInstance()); + } + + // Assert the messages are equal. + MessageWithMap msg1 = msg1B.build(); + MessageWithMap msg2 = msg2B.build(); + assertEquals(msg2, msg1); + + // Assert the encoded messages are not equal. + Coder coder = ProtoCoder.of(MessageWithMap.class); + assertNotEquals(CoderUtils.encodeToBase64(coder, msg2), CoderUtils.encodeToBase64(coder, msg1)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtobufUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtobufUtilTest.java new file mode 100644 index 000000000000..f2192e62e640 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/coders/protobuf/ProtobufUtilTest.java @@ -0,0 +1,195 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.coders.protobuf; + +import static com.google.cloud.dataflow.sdk.coders.protobuf.ProtobufUtil.checkProto2Syntax; +import static com.google.cloud.dataflow.sdk.coders.protobuf.ProtobufUtil.getRecursiveDescriptorsForClass; +import static com.google.cloud.dataflow.sdk.coders.protobuf.ProtobufUtil.verifyDeterministic; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageA; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageB; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageC; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.MessageWithMap; +import com.google.cloud.dataflow.sdk.coders.Proto2CoderTestMessages.ReferencesMessageWithMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; +import com.google.protobuf.Any; +import com.google.protobuf.Descriptors.GenericDescriptor; +import com.google.protobuf.Duration; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.Message; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.HashSet; +import java.util.Set; + +/** + * Tests for {@link ProtobufUtil}. + */ +@RunWith(JUnit4.class) +public class ProtobufUtilTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private static final Set MESSAGE_A_ONLY = + ImmutableSet.of("proto2_coder_test_messages.MessageA"); + + private static final Set MESSAGE_B_ONLY = + ImmutableSet.of("proto2_coder_test_messages.MessageB"); + + private static final Set MESSAGE_C_ONLY = + ImmutableSet.of("proto2_coder_test_messages.MessageC"); + + // map fields are actually represented as a nested Message in generated Java code. + private static final Set WITH_MAP_ONLY = + ImmutableSet.of( + "proto2_coder_test_messages.MessageWithMap", + "proto2_coder_test_messages.MessageWithMap.Field1Entry"); + + private static final Set REFERS_MAP_ONLY = + ImmutableSet.of("proto2_coder_test_messages.ReferencesMessageWithMap"); + + // A references A and B. + private static final Set MESSAGE_A_ALL = Sets.union(MESSAGE_A_ONLY, MESSAGE_B_ONLY); + + // C, only with registered extensions, references A. + private static final Set MESSAGE_C_EXT = Sets.union(MESSAGE_C_ONLY, MESSAGE_A_ALL); + + // MessageWithMap references A. + private static final Set WITH_MAP_ALL = Sets.union(WITH_MAP_ONLY, MESSAGE_A_ALL); + + // ReferencesMessageWithMap references MessageWithMap. 
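+ // Its closure therefore also pulls in MessageWithMap, its Field1Entry map-entry type, and
+ // MessageA/MessageB.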
+ private static final Set REFERS_MAP_ALL = Sets.union(REFERS_MAP_ONLY, WITH_MAP_ALL); + + @Test + public void testRecursiveDescriptorsMessageA() { + assertThat(getRecursiveDescriptorFullNames(MessageA.class), equalTo(MESSAGE_A_ALL)); + } + + @Test + public void testRecursiveDescriptorsMessageB() { + assertThat(getRecursiveDescriptorFullNames(MessageB.class), equalTo(MESSAGE_B_ONLY)); + } + + @Test + public void testRecursiveDescriptorsMessageC() { + assertThat(getRecursiveDescriptorFullNames(MessageC.class), equalTo(MESSAGE_C_ONLY)); + } + + @Test + public void testRecursiveDescriptorsMessageCWithExtensions() { + // With extensions, Message C has a reference to Message A and Message B. + ExtensionRegistry registry = ExtensionRegistry.newInstance(); + Proto2CoderTestMessages.registerAllExtensions(registry); + assertThat(getRecursiveDescriptorFullNames(MessageC.class, registry), equalTo(MESSAGE_C_EXT)); + } + + @Test + public void testRecursiveDescriptorsMessageWithMap() { + assertThat(getRecursiveDescriptorFullNames(MessageWithMap.class), equalTo(WITH_MAP_ALL)); + } + + @Test + public void testRecursiveDescriptorsReferencesMessageWithMap() { + assertThat( + getRecursiveDescriptorFullNames(ReferencesMessageWithMap.class), equalTo(REFERS_MAP_ALL)); + } + + @Test + public void testVerifyProto2() { + // Everything in Dataflow's Proto2TestMessages uses Proto2 syntax. + checkProto2Syntax(MessageA.class, ExtensionRegistry.getEmptyRegistry()); + checkProto2Syntax(MessageB.class, ExtensionRegistry.getEmptyRegistry()); + checkProto2Syntax(MessageC.class, ExtensionRegistry.getEmptyRegistry()); + checkProto2Syntax(MessageWithMap.class, ExtensionRegistry.getEmptyRegistry()); + checkProto2Syntax(ReferencesMessageWithMap.class, ExtensionRegistry.getEmptyRegistry()); + } + + @Test + public void testAnyIsNotProto2() { + // Any is a core Protocol Buffers type that uses proto3 syntax. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(Any.class.getCanonicalName()); + thrown.expectMessage("in file " + Any.getDescriptor().getFile().getName()); + + checkProto2Syntax(Any.class, ExtensionRegistry.getEmptyRegistry()); + } + + @Test + public void testDurationIsNotProto2() { + // Duration is a core Protocol Buffers type that uses proto3 syntax. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(Duration.class.getCanonicalName()); + thrown.expectMessage("in file " + Duration.getDescriptor().getFile().getName()); + + checkProto2Syntax(Duration.class, ExtensionRegistry.getEmptyRegistry()); + } + + @Test + public void testEntityIsDeterministic() throws NonDeterministicException { + // Cloud Datastore's Entities can be encoded deterministically. 
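+ // The determinism check rejects messages that transitively contain map fields; Entity has
+ // none, so it passes.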
+ verifyDeterministic(ProtoCoder.of(Entity.class)); + } + + @Test + public void testMessageWithMapIsNotDeterministic() throws NonDeterministicException { + String mapFieldName = MessageWithMap.getDescriptor().findFieldByNumber(1).getFullName(); + thrown.expect(NonDeterministicException.class); + thrown.expectMessage(MessageWithMap.class.getName()); + thrown.expectMessage("transitively includes Map field " + mapFieldName); + thrown.expectMessage("file " + MessageWithMap.getDescriptor().getFile().getName()); + + verifyDeterministic(ProtoCoder.of(MessageWithMap.class)); + } + + @Test + public void testMessageWithTransitiveMapIsNotDeterministic() throws NonDeterministicException { + String mapFieldName = MessageWithMap.getDescriptor().findFieldByNumber(1).getFullName(); + thrown.expect(NonDeterministicException.class); + thrown.expectMessage(ReferencesMessageWithMap.class.getName()); + thrown.expectMessage("transitively includes Map field " + mapFieldName); + thrown.expectMessage("file " + MessageWithMap.getDescriptor().getFile().getName()); + + verifyDeterministic(ProtoCoder.of(ReferencesMessageWithMap.class)); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + + /** Helper used to test the recursive class traversal and print good error messages. */ + private static Set getRecursiveDescriptorFullNames(Class clazz) { + return getRecursiveDescriptorFullNames(clazz, ExtensionRegistry.getEmptyRegistry()); + } + + /** Helper used to test the recursive class traversal and print good error messages. */ + private static Set getRecursiveDescriptorFullNames( + Class clazz, ExtensionRegistry registry) { + Set result = new HashSet<>(); + for (GenericDescriptor d : getRecursiveDescriptorsForClass(clazz, registry)) { + result.add(d.getFullName()); + } + return result; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOGeneratedClassTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOGeneratedClassTest.java new file mode 100644 index 000000000000..6a7679f5edec --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOGeneratedClassTest.java @@ -0,0 +1,374 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner.EvaluationResults; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.specific.SpecificDatumReader; +import org.apache.avro.specific.SpecificDatumWriter; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for AvroIO Read and Write transforms, using classes generated from {@code user.avsc}. + */ +@RunWith(JUnit4.class) +public class AvroIOGeneratedClassTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + private File avroFile; + + @Before + public void prepareAvroFileBeforeAnyTest() throws IOException { + avroFile = tmpFolder.newFile("file.avro"); + } + + private final String schemaString = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"AvroGeneratedUser\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},\n" + + " {\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + private final Schema.Parser parser = new Schema.Parser(); + private final Schema schema = parser.parse(schemaString); + + private AvroGeneratedUser[] generateAvroObjects() { + AvroGeneratedUser user1 = new AvroGeneratedUser(); + user1.setName("Bob"); + user1.setFavoriteNumber(256); + + AvroGeneratedUser user2 = new AvroGeneratedUser(); + user2.setName("Alice"); + user2.setFavoriteNumber(128); + + AvroGeneratedUser user3 = new AvroGeneratedUser(); + user3.setName("Ted"); + user3.setFavoriteColor("white"); + + return new AvroGeneratedUser[] { user1, user2, user3 }; + } + + private GenericRecord[] generateAvroGenericRecords() { + GenericRecord user1 = new GenericData.Record(schema); + user1.put("name", "Bob"); + user1.put("favorite_number", 256); + + GenericRecord user2 = new GenericData.Record(schema); + user2.put("name", "Alice"); + user2.put("favorite_number", 128); + + GenericRecord user3 = new GenericData.Record(schema); + user3.put("name", "Ted"); + user3.put("favorite_color", "white"); + + return new GenericRecord[] { user1, user2, user3 }; + } + + private void generateAvroFile(AvroGeneratedUser[] elements) throws IOException { + DatumWriter userDatumWriter = + new SpecificDatumWriter<>(AvroGeneratedUser.class); + try (DataFileWriter dataFileWriter = new DataFileWriter<>(userDatumWriter)) { + dataFileWriter.create(elements[0].getSchema(), avroFile); + for (AvroGeneratedUser user : elements) { + 
dataFileWriter.append(user); + } + } + } + + private List readAvroFile() throws IOException { + DatumReader userDatumReader = + new SpecificDatumReader<>(AvroGeneratedUser.class); + List users = new ArrayList<>(); + try (DataFileReader dataFileReader = + new DataFileReader<>(avroFile, userDatumReader)) { + while (dataFileReader.hasNext()) { + users.add(dataFileReader.next()); + } + } + return users; + } + + void runTestRead(AvroIO.Read.Bound read, String expectedName, T[] expectedOutput) + throws Exception { + generateAvroFile(generateAvroObjects()); + + DirectPipeline p = DirectPipeline.createForTest(); + PCollection output = p.apply(read); + EvaluationResults results = p.run(); + assertEquals(expectedName, output.getName()); + assertThat(results.getPCollection(output), + containsInAnyOrder(expectedOutput)); + } + + @Test + public void testReadFromGeneratedClass() throws Exception { + runTestRead( + AvroIO.Read.from(avroFile.getPath()).withSchema(AvroGeneratedUser.class), + "AvroIO.Read/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.withSchema(AvroGeneratedUser.class).from(avroFile.getPath()), + "AvroIO.Read/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.named("MyRead").from(avroFile.getPath()).withSchema(AvroGeneratedUser.class), + "MyRead/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.named("MyRead").withSchema(AvroGeneratedUser.class).from(avroFile.getPath()), + "MyRead/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.from(avroFile.getPath()).withSchema(AvroGeneratedUser.class).named("HerRead"), + "HerRead/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.from(avroFile.getPath()).named("HerRead").withSchema(AvroGeneratedUser.class), + "HerRead/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.withSchema(AvroGeneratedUser.class).named("HerRead").from(avroFile.getPath()), + "HerRead/Read.out", + generateAvroObjects()); + runTestRead( + AvroIO.Read.withSchema(AvroGeneratedUser.class).from(avroFile.getPath()).named("HerRead"), + "HerRead/Read.out", + generateAvroObjects()); + } + + @Test + public void testReadFromSchema() throws Exception { + runTestRead( + AvroIO.Read.from(avroFile.getPath()).withSchema(schema), + "AvroIO.Read/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.withSchema(schema).from(avroFile.getPath()), + "AvroIO.Read/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.named("MyRead").from(avroFile.getPath()).withSchema(schema), + "MyRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.named("MyRead").withSchema(schema).from(avroFile.getPath()), + "MyRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.from(avroFile.getPath()).withSchema(schema).named("HerRead"), + "HerRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.from(avroFile.getPath()).named("HerRead").withSchema(schema), + "HerRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.withSchema(schema).named("HerRead").from(avroFile.getPath()), + "HerRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.withSchema(schema).from(avroFile.getPath()).named("HerRead"), + "HerRead/Read.out", + generateAvroGenericRecords()); + } + + @Test + public void testReadFromSchemaString() throws Exception { + runTestRead( + AvroIO.Read.from(avroFile.getPath()).withSchema(schemaString), + "AvroIO.Read/Read.out", + generateAvroGenericRecords()); + runTestRead( + 
AvroIO.Read.withSchema(schemaString).from(avroFile.getPath()), + "AvroIO.Read/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.named("MyRead").from(avroFile.getPath()).withSchema(schemaString), + "MyRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.named("MyRead").withSchema(schemaString).from(avroFile.getPath()), + "MyRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.from(avroFile.getPath()).withSchema(schemaString).named("HerRead"), + "HerRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.from(avroFile.getPath()).named("HerRead").withSchema(schemaString), + "HerRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.withSchema(schemaString).named("HerRead").from(avroFile.getPath()), + "HerRead/Read.out", + generateAvroGenericRecords()); + runTestRead( + AvroIO.Read.withSchema(schemaString).from(avroFile.getPath()).named("HerRead"), + "HerRead/Read.out", + generateAvroGenericRecords()); + } + + void runTestWrite(AvroIO.Write.Bound write, String expectedName) + throws Exception { + AvroGeneratedUser[] users = generateAvroObjects(); + + DirectPipeline p = DirectPipeline.createForTest(); + @SuppressWarnings("unchecked") + PCollection input = p.apply(Create.of(Arrays.asList((T[]) users)) + .withCoder((Coder) AvroCoder.of(AvroGeneratedUser.class))); + input.apply(write.withoutSharding()); + p.run(); + assertEquals(expectedName, write.getName()); + + assertThat(readAvroFile(), containsInAnyOrder(users)); + } + + @Test + public void testWriteFromGeneratedClass() throws Exception { + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(AvroGeneratedUser.class), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.withSchema(AvroGeneratedUser.class) + .to(avroFile.getPath()), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.named("MyWrite") + .to(avroFile.getPath()) + .withSchema(AvroGeneratedUser.class), + "MyWrite"); + runTestWrite(AvroIO.Write.named("MyWrite") + .withSchema(AvroGeneratedUser.class) + .to(avroFile.getPath()), + "MyWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(AvroGeneratedUser.class) + .named("HerWrite"), + "HerWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .named("HerWrite") + .withSchema(AvroGeneratedUser.class), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(AvroGeneratedUser.class) + .named("HerWrite") + .to(avroFile.getPath()), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(AvroGeneratedUser.class) + .to(avroFile.getPath()) + .named("HerWrite"), + "HerWrite"); + } + + @Test + public void testWriteFromSchema() throws Exception { + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schema), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.withSchema(schema) + .to(avroFile.getPath()), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.named("MyWrite") + .to(avroFile.getPath()) + .withSchema(schema), + "MyWrite"); + runTestWrite(AvroIO.Write.named("MyWrite") + .withSchema(schema) + .to(avroFile.getPath()), + "MyWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schema) + .named("HerWrite"), + "HerWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .named("HerWrite") + .withSchema(schema), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schema) + .named("HerWrite") + .to(avroFile.getPath()), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schema) + .to(avroFile.getPath()) + .named("HerWrite"), + "HerWrite"); + } + + @Test + public void 
testWriteFromSchemaString() throws Exception { + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schemaString), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.withSchema(schemaString) + .to(avroFile.getPath()), + "AvroIO.Write"); + runTestWrite(AvroIO.Write.named("MyWrite") + .to(avroFile.getPath()) + .withSchema(schemaString), + "MyWrite"); + runTestWrite(AvroIO.Write.named("MyWrite") + .withSchema(schemaString) + .to(avroFile.getPath()), + "MyWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .withSchema(schemaString) + .named("HerWrite"), + "HerWrite"); + runTestWrite(AvroIO.Write.to(avroFile.getPath()) + .named("HerWrite") + .withSchema(schemaString), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schemaString) + .named("HerWrite") + .to(avroFile.getPath()), + "HerWrite"); + runTestWrite(AvroIO.Write.withSchema(schemaString) + .to(avroFile.getPath()) + .named("HerWrite"), + "HerWrite"); + } + + // TODO: for Write only, test withSuffix, withNumShards, + // withShardNameTemplate and withoutSharding. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java new file mode 100644 index 000000000000..2258a9136f24 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java @@ -0,0 +1,226 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterators; + +import org.apache.avro.file.DataFileReader; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.reflect.Nullable; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * Tests for AvroIO Read and Write transforms. 
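+ *
+ * <p>Unlike {@code AvroIOGeneratedClassTest}, these tests use plain POJOs annotated with
+ * {@code @DefaultCoder(AvroCoder.class)} rather than classes generated from {@code user.avsc}.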
+ */ +@RunWith(JUnit4.class) +public class AvroIOTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testReadWithoutValidationFlag() throws Exception { + AvroIO.Read.Bound read = AvroIO.Read.from("gs://bucket/foo*/baz"); + assertTrue(read.needsValidation()); + assertFalse(read.withoutValidation().needsValidation()); + } + + @Test + public void testWriteWithoutValidationFlag() throws Exception { + AvroIO.Write.Bound write = AvroIO.Write.to("gs://bucket/foo/baz"); + assertTrue(write.needsValidation()); + assertFalse(write.withoutValidation().needsValidation()); + } + + @Test + public void testAvroIOGetName() { + assertEquals("AvroIO.Read", AvroIO.Read.from("gs://bucket/foo*/baz").getName()); + assertEquals("AvroIO.Write", AvroIO.Write.to("gs://bucket/foo/baz").getName()); + assertEquals("ReadMyFile", + AvroIO.Read.named("ReadMyFile").from("gs://bucket/foo*/baz").getName()); + assertEquals("WriteMyFile", + AvroIO.Write.named("WriteMyFile").to("gs://bucket/foo/baz").getName()); + } + + @DefaultCoder(AvroCoder.class) + static class GenericClass { + int intField; + String stringField; + public GenericClass() {} + public GenericClass(int intValue, String stringValue) { + this.intField = intValue; + this.stringField = stringValue; + } + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("intField", intField) + .add("stringField", stringField) + .toString(); + } + @Override + public int hashCode() { + return Objects.hash(intField, stringField); + } + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof GenericClass)) { + return false; + } + GenericClass o = (GenericClass) other; + return Objects.equals(intField, o.intField) && Objects.equals(stringField, o.stringField); + } + } + + @Test + public void testAvroIOWriteAndReadASingleFile() throws Throwable { + DirectPipeline p = DirectPipeline.createForTest(); + List values = ImmutableList.of(new GenericClass(3, "hi"), + new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + p.apply(Create.of(values)) + .apply(AvroIO.Write.to(outputFile.getAbsolutePath()) + .withoutSharding() + .withSchema(GenericClass.class)); + p.run(); + + p = DirectPipeline.createForTest(); + PCollection input = p + .apply(AvroIO.Read.from(outputFile.getAbsolutePath()).withSchema(GenericClass.class)); + + DataflowAssert.that(input).containsInAnyOrder(values); + p.run(); + } + + @DefaultCoder(AvroCoder.class) + static class GenericClassV2 { + int intField; + String stringField; + @Nullable String nullableField; + public GenericClassV2() {} + public GenericClassV2(int intValue, String stringValue, String nullableValue) { + this.intField = intValue; + this.stringField = stringValue; + this.nullableField = nullableValue; + } + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("intField", intField) + .add("stringField", stringField) + .add("nullableField", nullableField) + .toString(); + } + @Override + public int hashCode() { + return Objects.hash(intField, stringField, nullableField); + } + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof GenericClassV2)) { + return false; + } + GenericClassV2 o = (GenericClassV2) other; + return Objects.equals(intField, o.intField) + && Objects.equals(stringField, o.stringField) + && Objects.equals(nullableField, o.nullableField); + } + } + + /** + * Tests that {@code AvroIO} can read an upgraded version 
of an old class, as long as the + * schema resolution process succeeds. This test covers the case when a new, {@code @Nullable} + * field has been added. + * + *
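+ * <p>In Avro terms (a rough sketch of the mechanism, not SDK-specific behavior): the data file
+ * carries the writer schema, the reader supplies the {@code GenericClassV2} schema, and schema
+ * resolution fills the new {@code @Nullable} field with {@code null} for records written before
+ * the field existed.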
<p>
    For more information, see http://avro.apache.org/docs/1.7.7/spec.html#Schema+Resolution + */ + @Test + public void testAvroIOWriteAndReadSchemaUpgrade() throws Throwable { + DirectPipeline p = DirectPipeline.createForTest(); + List values = ImmutableList.of(new GenericClass(3, "hi"), + new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + p.apply(Create.of(values)) + .apply(AvroIO.Write.to(outputFile.getAbsolutePath()) + .withoutSharding() + .withSchema(GenericClass.class)); + p.run(); + + List expected = ImmutableList.of(new GenericClassV2(3, "hi", null), + new GenericClassV2(5, "bar", null)); + p = DirectPipeline.createForTest(); + PCollection input = p + .apply(AvroIO.Read.from(outputFile.getAbsolutePath()).withSchema(GenericClassV2.class)); + + DataflowAssert.that(input).containsInAnyOrder(expected); + p.run(); + } + + @SuppressWarnings("deprecation") // using AvroCoder#createDatumReader for tests. + @Test + public void testAvroSinkWrite() throws Exception { + String outputFilePrefix = new File(tmpFolder.getRoot(), "prefix").getAbsolutePath(); + String[] expectedElements = new String[] {"first", "second", "third"}; + + TestPipeline p = TestPipeline.create(); + p.apply(Create.of(expectedElements)) + .apply(AvroIO.Write.to(outputFilePrefix).withSchema(String.class)); + p.run(); + + // Validate that the data written matches the expected elements in the expected order + String expectedName = + IOChannelUtils.constructName( + outputFilePrefix, ShardNameTemplate.INDEX_OF_MAX, "" /* no suffix */, 0, 1); + File outputFile = new File(expectedName); + assertTrue("Expected output file " + expectedName, outputFile.exists()); + try (DataFileReader reader = + new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) { + List actualElements = new ArrayList<>(); + Iterators.addAll(actualElements, reader); + assertThat(actualElements, containsInAnyOrder(expectedElements)); + } + } + + // TODO: for Write only, test withSuffix, withNumShards, + // withShardNameTemplate and withoutSharding. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroSourceTest.java new file mode 100644 index 000000000000..0990294d209f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroSourceTest.java @@ -0,0 +1,692 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.AvroSource.AvroReader; +import com.google.cloud.dataflow.sdk.io.AvroSource.AvroReader.Seeker; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.SourceTestUtils; +import com.google.common.base.MoreObjects; + +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.reflect.AvroDefault; +import org.apache.avro.reflect.Nullable; +import org.apache.avro.reflect.ReflectData; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.PushbackInputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Random; + +/** + * Tests for AvroSource. + */ +@RunWith(JUnit4.class) +public class AvroSourceTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + private enum SyncBehavior { + SYNC_REGULAR, // Sync at regular, user defined intervals + SYNC_RANDOM, // Sync at random intervals + SYNC_DEFAULT; // Sync at default intervals (i.e., no manual syncing). + } + + private static final int DEFAULT_RECORD_COUNT = 1000; + + /** + * Generates an input Avro file containing the given records in the temporary directory and + * returns the full path of the file. + */ + private String generateTestFile(String filename, List elems, SyncBehavior syncBehavior, + int syncInterval, AvroCoder coder, String codec) throws IOException { + Random random = new Random(0); + File tmpFile = tmpFolder.newFile(filename); + String path = tmpFile.toString(); + + FileOutputStream os = new FileOutputStream(tmpFile); + DatumWriter datumWriter = coder.createDatumWriter(); + try (DataFileWriter writer = new DataFileWriter<>(datumWriter)) { + writer.setCodec(CodecFactory.fromString(codec)); + writer.create(coder.getSchema(), os); + + int recordIndex = 0; + int syncIndex = syncBehavior == SyncBehavior.SYNC_RANDOM ? random.nextInt(syncInterval) : 0; + + for (T elem : elems) { + writer.append(elem); + recordIndex++; + + switch (syncBehavior) { + case SYNC_REGULAR: + if (recordIndex == syncInterval) { + recordIndex = 0; + writer.sync(); + } + break; + case SYNC_RANDOM: + if (recordIndex == syncIndex) { + recordIndex = 0; + writer.sync(); + syncIndex = random.nextInt(syncInterval); + } + break; + case SYNC_DEFAULT: + default: + } + } + } + return path; + } + + @Test + public void testReadWithDifferentCodecs() throws Exception { + // Test reading files generated using all codecs. 
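+ // Each iteration writes a fresh file via generateTestFile(..., codec) and reads it back through
+ // SourceTestUtils.readFromSource, so every codec's decompression path gets exercised.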
+ String codecs[] = {DataFileConstants.NULL_CODEC, DataFileConstants.BZIP2_CODEC, + DataFileConstants.DEFLATE_CODEC, DataFileConstants.SNAPPY_CODEC, + DataFileConstants.XZ_CODEC}; + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + + for (String codec : codecs) { + String filename = generateTestFile( + codec, expected, SyncBehavior.SYNC_DEFAULT, 0, AvroCoder.of(Bird.class), codec); + AvroSource source = AvroSource.from(filename).withSchema(Bird.class); + List actual = SourceTestUtils.readFromSource(source, null); + assertThat(expected, containsInAnyOrder(actual.toArray())); + } + } + + @Test + public void testSplitAtFraction() throws Exception { + // A reduced dataset is enough here. + List expected = createFixedRecords(DEFAULT_RECORD_COUNT); + // Create an AvroSource where each block is 1/10th of the total set of records. + String filename = generateTestFile( + "tmp.avro", expected, SyncBehavior.SYNC_REGULAR, + DEFAULT_RECORD_COUNT / 10 /* max records per block */, + AvroCoder.of(FixedRecord.class), DataFileConstants.NULL_CODEC); + File file = new File(filename); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + List> splits = + source.splitIntoBundles(file.length() / 3, null); + for (BoundedSource subSource : splits) { + int items = SourceTestUtils.readFromSource(subSource, null).size(); + // Shouldn't split while unstarted. + SourceTestUtils.assertSplitAtFractionFails(subSource, 0, 0.0, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, 0, 0.7, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent(subSource, 1, 0.7, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent( + subSource, DEFAULT_RECORD_COUNT / 100, 0.7, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent( + subSource, DEFAULT_RECORD_COUNT / 10, 0.1, null); + SourceTestUtils.assertSplitAtFractionFails( + subSource, DEFAULT_RECORD_COUNT / 10 + 1, 0.1, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, DEFAULT_RECORD_COUNT / 3, 0.3, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, items, 0.9, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, items, 1.0, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent(subSource, items, 0.999, null); + } + } + + @Test + public void testGetProgressFromUnstartedReader() throws Exception { + List records = createFixedRecords(DEFAULT_RECORD_COUNT); + String filename = generateTestFile("tmp.avro", records, SyncBehavior.SYNC_DEFAULT, 1000, + AvroCoder.of(FixedRecord.class), DataFileConstants.NULL_CODEC); + File file = new File(filename); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BoundedSource.BoundedReader reader = source.createReader(null)) { + assertEquals(new Double(0.0), reader.getFractionConsumed()); + } + + List> splits = + source.splitIntoBundles(file.length() / 3, null); + for (BoundedSource subSource : splits) { + try (BoundedSource.BoundedReader reader = subSource.createReader(null)) { + assertEquals(new Double(0.0), reader.getFractionConsumed()); + } + } + } + + @Test + public void testGetCurrentFromUnstartedReader() throws Exception { + List records = createFixedRecords(DEFAULT_RECORD_COUNT); + String filename = generateTestFile("tmp.avro", records, SyncBehavior.SYNC_DEFAULT, 1000, + AvroCoder.of(FixedRecord.class), DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BlockBasedSource.BlockBasedReader reader = + 
(BlockBasedSource.BlockBasedReader) source.createReader(null)) { + assertEquals(null, reader.getCurrentBlock()); + + expectedException.expect(NoSuchElementException.class); + expectedException.expectMessage("No block has been successfully read from"); + reader.getCurrent(); + } + } + + @Test + public void testSplitAtFractionExhaustive() throws Exception { + // A small-sized input is sufficient, because the test verifies that splitting is non-vacuous. + List expected = createFixedRecords(20); + String filename = generateTestFile("tmp.avro", expected, SyncBehavior.SYNC_REGULAR, 5, + AvroCoder.of(FixedRecord.class), DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + SourceTestUtils.assertSplitAtFractionExhaustive(source, null); + } + + @Test + public void testSplitsWithSmallBlocks() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + // Test reading from an object file with many small random-sized blocks. + // The file itself doesn't have to be big; we can use a decreased record count. + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + String filename = generateTestFile("tmp.avro", expected, SyncBehavior.SYNC_RANDOM, + DEFAULT_RECORD_COUNT / 20 /* max records/block */, + AvroCoder.of(Bird.class), DataFileConstants.NULL_CODEC); + File file = new File(filename); + + // Small minimum bundle size + AvroSource source = + AvroSource.from(filename).withSchema(Bird.class).withMinBundleSize(100L); + + // Assert that the source produces the expected records + assertEquals(expected, SourceTestUtils.readFromSource(source, options)); + + List> splits; + int nonEmptySplits; + + // Split with the minimum bundle size + splits = source.splitIntoBundles(100L, options); + assertTrue(splits.size() > 2); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + nonEmptySplits = 0; + for (BoundedSource subSource : splits) { + if (SourceTestUtils.readFromSource(subSource, options).size() > 0) { + nonEmptySplits += 1; + } + } + assertTrue(nonEmptySplits > 2); + + // Split with larger bundle size + splits = source.splitIntoBundles(file.length() / 4, options); + assertTrue(splits.size() > 2); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + nonEmptySplits = 0; + for (BoundedSource subSource : splits) { + if (SourceTestUtils.readFromSource(subSource, options).size() > 0) { + nonEmptySplits += 1; + } + } + assertTrue(nonEmptySplits > 2); + + // Split with the file length + splits = source.splitIntoBundles(file.length(), options); + assertTrue(splits.size() == 1); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + } + + @Test + public void testMultipleFiles() throws Exception { + String baseName = "tmp-"; + List expected = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + List contents = createRandomRecords(DEFAULT_RECORD_COUNT / 10); + expected.addAll(contents); + generateTestFile(baseName + i, contents, SyncBehavior.SYNC_DEFAULT, 0, + AvroCoder.of(Bird.class), DataFileConstants.NULL_CODEC); + } + + AvroSource source = + AvroSource.from(new File(tmpFolder.getRoot().toString(), baseName + "*").toString()) + .withSchema(Bird.class); + List actual = SourceTestUtils.readFromSource(source, null); + assertThat(actual, containsInAnyOrder(expected.toArray())); + } + + @Test + public void testCreationWithSchema() throws Exception { + List expected = createRandomRecords(100); + String filename = generateTestFile("tmp.avro", expected, 
SyncBehavior.SYNC_DEFAULT, 0, + AvroCoder.of(Bird.class), DataFileConstants.NULL_CODEC); + + // Create a source with a schema object + Schema schema = ReflectData.get().getSchema(Bird.class); + AvroSource source = AvroSource.from(filename).withSchema(schema); + List records = SourceTestUtils.readFromSource(source, null); + assertEqualsWithGeneric(expected, records); + + // Create a source with a JSON schema + String schemaString = ReflectData.get().getSchema(Bird.class).toString(); + source = AvroSource.from(filename).withSchema(schemaString); + records = SourceTestUtils.readFromSource(source, null); + assertEqualsWithGeneric(expected, records); + + // Create a source with no schema + source = AvroSource.from(filename); + records = SourceTestUtils.readFromSource(source, null); + assertEqualsWithGeneric(expected, records); + } + + @Test + public void testSchemaUpdate() throws Exception { + List birds = createRandomRecords(100); + String filename = generateTestFile("tmp.avro", birds, SyncBehavior.SYNC_DEFAULT, 0, + AvroCoder.of(Bird.class), DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FancyBird.class); + List actual = SourceTestUtils.readFromSource(source, null); + + List expected = new ArrayList<>(); + for (Bird bird : birds) { + expected.add(new FancyBird( + bird.number, bird.species, bird.quality, bird.quantity, null, "MAXIMUM OVERDRIVE")); + } + + assertThat(actual, containsInAnyOrder(expected.toArray())); + } + + private void assertEqualsWithGeneric(List expected, List actual) { + assertEquals(expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i++) { + Bird fixed = expected.get(i); + GenericRecord generic = actual.get(i); + assertEquals(fixed.number, generic.get("number")); + assertEquals(fixed.quality, generic.get("quality").toString()); // From Avro util.Utf8 + assertEquals(fixed.quantity, generic.get("quantity")); + assertEquals(fixed.species, generic.get("species").toString()); + } + } + + /** + * Creates a haystack byte array of the give size with a needle that starts at the given position. + */ + private byte[] createHaystack(byte[] needle, int position, int size) { + byte[] haystack = new byte[size]; + for (int i = position, j = 0; i < size && j < needle.length; i++, j++) { + haystack[i] = needle[j]; + } + return haystack; + } + + /** + * Asserts that advancePastNextSyncMarker advances an input stream past a sync marker and + * correctly returns the number of bytes consumed from the stream. + * Creates a haystack of size bytes and places a 16-byte sync marker at the position specified. + */ + private void testAdvancePastNextSyncMarkerAt(int position, int size) throws IOException { + byte sentinel = (byte) 0xFF; + byte[] marker = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}; + byte[] haystack = createHaystack(marker, position, size); + PushbackInputStream stream = + new PushbackInputStream(new ByteArrayInputStream(haystack), marker.length); + if (position + marker.length < size) { + haystack[position + marker.length] = sentinel; + assertEquals(position + marker.length, AvroReader.advancePastNextSyncMarker(stream, marker)); + assertEquals(sentinel, (byte) stream.read()); + } else { + assertEquals(size, AvroReader.advancePastNextSyncMarker(stream, marker)); + assertEquals(-1, stream.read()); + } + } + + @Test + public void testAdvancePastNextSyncMarker() throws IOException { + // Test placing the sync marker at different locations at the start and in the middle of the + // buffer. 
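+ // The marker used by testAdvancePastNextSyncMarkerAt is 16 bytes, so i in [0, 16] covers every
+ // offset up to one full marker length from the start of the 1000-byte haystack.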
+ for (int i = 0; i <= 16; i++) { + testAdvancePastNextSyncMarkerAt(i, 1000); + testAdvancePastNextSyncMarkerAt(160 + i, 1000); + } + // Test placing the sync marker at the end of the buffer. + testAdvancePastNextSyncMarkerAt(983, 1000); + // Test placing the sync marker so that it begins at the end of the buffer. + testAdvancePastNextSyncMarkerAt(984, 1000); + testAdvancePastNextSyncMarkerAt(985, 1000); + testAdvancePastNextSyncMarkerAt(999, 1000); + // Test with no sync marker. + testAdvancePastNextSyncMarkerAt(1000, 1000); + } + + // Tests for Seeker. + @Test + public void testSeekerFind() { + byte[] marker = {0, 1, 2, 3}; + byte[] buffer; + Seeker s; + s = new Seeker(marker); + + buffer = new byte[] {0, 1, 2, 3, 4, 5, 6, 7}; + assertEquals(3, s.find(buffer, buffer.length)); + + buffer = new byte[] {0, 0, 0, 0, 0, 1, 2, 3}; + assertEquals(7, s.find(buffer, buffer.length)); + + buffer = new byte[] {0, 1, 2, 0, 0, 1, 2, 3}; + assertEquals(7, s.find(buffer, buffer.length)); + + buffer = new byte[] {0, 1, 2, 3}; + assertEquals(3, s.find(buffer, buffer.length)); + } + + @Test + public void testSeekerFindResume() { + byte[] marker = {0, 1, 2, 3}; + byte[] buffer; + Seeker s; + s = new Seeker(marker); + + buffer = new byte[] {0, 0, 0, 0, 0, 0, 0, 0}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {1, 2, 3, 0, 0, 0, 0, 0}; + assertEquals(2, s.find(buffer, buffer.length)); + + buffer = new byte[] {0, 0, 0, 0, 0, 0, 1, 2}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {3, 0, 1, 2, 3, 0, 1, 2}; + assertEquals(0, s.find(buffer, buffer.length)); + + buffer = new byte[] {0}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {1}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {2}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {3}; + assertEquals(0, s.find(buffer, buffer.length)); + } + + @Test + public void testSeekerUsesBufferLength() { + byte[] marker = {0, 0, 1}; + byte[] buffer; + Seeker s; + s = new Seeker(marker); + + buffer = new byte[] {0, 0, 0, 1}; + assertEquals(-1, s.find(buffer, 3)); + + s = new Seeker(marker); + buffer = new byte[] {0, 0}; + assertEquals(-1, s.find(buffer, 1)); + buffer = new byte[] {1, 0}; + assertEquals(-1, s.find(buffer, 1)); + + s = new Seeker(marker); + buffer = new byte[] {0, 2}; + assertEquals(-1, s.find(buffer, 1)); + buffer = new byte[] {0, 2}; + assertEquals(-1, s.find(buffer, 1)); + buffer = new byte[] {1, 2}; + assertEquals(0, s.find(buffer, 1)); + } + + + @Test + public void testSeekerFindPartial() { + byte[] marker = {0, 0, 1}; + byte[] buffer; + Seeker s; + s = new Seeker(marker); + + buffer = new byte[] {0, 0, 0, 1}; + assertEquals(3, s.find(buffer, buffer.length)); + + marker = new byte[] {1, 1, 1, 2}; + s = new Seeker(marker); + + buffer = new byte[] {1, 1, 1, 1, 1}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {1, 1, 2}; + assertEquals(2, s.find(buffer, buffer.length)); + + buffer = new byte[] {1, 1, 1, 1, 1}; + assertEquals(-1, s.find(buffer, buffer.length)); + buffer = new byte[] {2, 1, 1, 1, 2}; + assertEquals(0, s.find(buffer, buffer.length)); + } + + @Test + public void testSeekerFindAllLocations() { + byte[] marker = {1, 1, 2}; + byte[] allOnes = new byte[] {1, 1, 1, 1}; + byte[] findIn = new byte[] {1, 1, 1, 1}; + Seeker s = new Seeker(marker); + + for (int i = 0; i < findIn.length; i++) { + assertEquals(-1, s.find(allOnes, allOnes.length)); + findIn[i] = 2; + assertEquals(i, s.find(findIn, 
findIn.length)); + findIn[i] = 1; + } + } + + /** + * Class that will encode to a fixed size: 16 bytes. + * + *
<p>
    Each object has a 15-byte array. Avro encodes an object of this type as + * a byte array, so each encoded object will consist of 1 byte that encodes the + * length of the array, followed by 15 bytes. + */ + @DefaultCoder(AvroCoder.class) + public static class FixedRecord { + private byte[] value = new byte[15]; + + public FixedRecord() { + this(0); + } + + public FixedRecord(int i) { + value[0] = (byte) i; + value[1] = (byte) (i >> 8); + value[2] = (byte) (i >> 16); + value[3] = (byte) (i >> 24); + } + + public int asInt() { + return value[0] | (value[1] << 8) | (value[2] << 16) | (value[3] << 24); + } + + @Override + public boolean equals(Object o) { + if (o instanceof FixedRecord) { + FixedRecord other = (FixedRecord) o; + return this.asInt() == other.asInt(); + } + return false; + } + + @Override + public int hashCode() { + return toString().hashCode(); + } + + @Override + public String toString() { + return Integer.toString(this.asInt()); + } + } + + /** + * Create a list of count 16-byte records. + */ + private static List createFixedRecords(int count) { + List records = new ArrayList<>(); + for (int i = 0; i < count; i++) { + records.add(new FixedRecord(i)); + } + return records; + } + + /** + * Class used as the record type in tests. + */ + @DefaultCoder(AvroCoder.class) + static class Bird { + long number; + String species; + String quality; + long quantity; + + public Bird() {} + + public Bird(long number, String species, String quality, long quantity) { + this.number = number; + this.species = species; + this.quality = quality; + this.quantity = quantity; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(Bird.class) + .addValue(number) + .addValue(species) + .addValue(quantity) + .addValue(quality) + .toString(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof Bird) { + Bird other = (Bird) obj; + return Objects.equals(species, other.species) && Objects.equals(quality, other.quality) + && quantity == other.quantity && number == other.number; + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(number, species, quality, quantity); + } + } + + /** + * Class used as the record type in tests. + * + *
<p>
    Contains nullable fields and fields with default values. Can be read using a file written + * with the Bird schema. + */ + @DefaultCoder(AvroCoder.class) + public static class FancyBird { + long number; + String species; + String quality; + long quantity; + + @Nullable + String habitat; + + @AvroDefault("\"MAXIMUM OVERDRIVE\"") + String fancinessLevel; + + public FancyBird() {} + + public FancyBird(long number, String species, String quality, long quantity, String habitat, + String fancinessLevel) { + this.number = number; + this.species = species; + this.quality = quality; + this.quantity = quantity; + this.habitat = habitat; + this.fancinessLevel = fancinessLevel; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(FancyBird.class) + .addValue(number) + .addValue(species) + .addValue(quality) + .addValue(quantity) + .addValue(habitat) + .addValue(fancinessLevel) + .toString(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof FancyBird) { + FancyBird other = (FancyBird) obj; + return Objects.equals(species, other.species) && Objects.equals(quality, other.quality) + && quantity == other.quantity && number == other.number + && Objects.equals(fancinessLevel, other.fancinessLevel) + && Objects.equals(habitat, other.habitat); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(number, species, quality, quantity, habitat, fancinessLevel); + } + } + + /** + * Create a list of n random records. + */ + private static List createRandomRecords(long n) { + String[] qualities = { + "miserable", "forelorn", "fidgity", "squirrelly", "fanciful", "chipper", "lazy"}; + String[] species = {"pigeons", "owls", "gulls", "hawks", "robins", "jays"}; + Random random = new Random(0); + + List records = new ArrayList<>(); + for (long i = 0; i < n; i++) { + Bird bird = new Bird(); + bird.quality = qualities[random.nextInt(qualities.length)]; + bird.species = species[random.nextInt(species.length)]; + bird.number = i; + bird.quantity = random.nextLong(); + records.add(bird); + } + return records; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BigQueryIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BigQueryIOTest.java new file mode 100644 index 000000000000..a081de095c75 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BigQueryIOTest.java @@ -0,0 +1,445 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.api.client.util.Data; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.CreateDisposition; +import com.google.cloud.dataflow.sdk.io.BigQueryIO.Write.WriteDisposition; +import com.google.cloud.dataflow.sdk.options.BigQueryOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.util.CoderUtils; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for BigQueryIO. + */ +@RunWith(JUnit4.class) +public class BigQueryIOTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private void checkReadTableObject( + BigQueryIO.Read.Bound bound, String project, String dataset, String table) { + checkReadTableObjectWithValidate(bound, project, dataset, table, true); + } + + private void checkReadQueryObject( + BigQueryIO.Read.Bound bound, String query) { + checkReadQueryObjectWithValidate(bound, query, true); + } + + private void checkReadTableObjectWithValidate( + BigQueryIO.Read.Bound bound, String project, String dataset, String table, boolean validate) { + assertEquals(project, bound.table.getProjectId()); + assertEquals(dataset, bound.table.getDatasetId()); + assertEquals(table, bound.table.getTableId()); + assertNull(bound.query); + assertEquals(validate, bound.getValidate()); + } + + private void checkReadQueryObjectWithValidate( + BigQueryIO.Read.Bound bound, String query, boolean validate) { + assertNull(bound.table); + assertEquals(query, bound.query); + assertEquals(validate, bound.getValidate()); + } + + private void checkWriteObject( + BigQueryIO.Write.Bound bound, String project, String dataset, String table, + TableSchema schema, CreateDisposition createDisposition, + WriteDisposition writeDisposition) { + checkWriteObjectWithValidate( + bound, project, dataset, table, schema, createDisposition, writeDisposition, true); + } + + private void checkWriteObjectWithValidate( + BigQueryIO.Write.Bound bound, String project, String dataset, String table, + TableSchema schema, CreateDisposition createDisposition, + WriteDisposition writeDisposition, boolean validate) { + assertEquals(project, bound.table.getProjectId()); + assertEquals(dataset, bound.table.getDatasetId()); + assertEquals(table, bound.table.getTableId()); + assertEquals(schema, bound.schema); + assertEquals(createDisposition, bound.createDisposition); + assertEquals(writeDisposition, bound.writeDisposition); + assertEquals(validate, bound.validate); + } + + @Before + public void setUp() { + BigQueryOptions options = PipelineOptionsFactory.as(BigQueryOptions.class); + 
options.setProject("defaultProject"); + } + + @Test + public void testBuildTableBasedSource() { + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from("foo.com:project:somedataset.sometable"); + checkReadTableObject(bound, "foo.com:project", "somedataset", "sometable"); + } + + @Test + public void testBuildQueryBasedSource() { + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyQuery") + .fromQuery("foo_query"); + checkReadQueryObject(bound, "foo_query"); + } + + @Test + public void testBuildTableBasedSourceWithoutValidation() { + // This test just checks that using withoutValidation will not trigger object + // construction errors. + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from("foo.com:project:somedataset.sometable").withoutValidation(); + checkReadTableObjectWithValidate(bound, "foo.com:project", "somedataset", "sometable", false); + } + + @Test + public void testBuildQueryBasedSourceWithoutValidation() { + // This test just checks that using withoutValidation will not trigger object + // construction errors. + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .fromQuery("some_query").withoutValidation(); + checkReadQueryObjectWithValidate(bound, "some_query", false); + } + + @Test + public void testBuildTableBasedSourceWithDefaultProject() { + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from("somedataset.sometable"); + checkReadTableObject(bound, null, "somedataset", "sometable"); + } + + @Test + public void testBuildSourceWithTableReference() { + TableReference table = new TableReference() + .setProjectId("foo.com:project") + .setDatasetId("somedataset") + .setTableId("sometable"); + BigQueryIO.Read.Bound bound = BigQueryIO.Read.named("ReadMyTable") + .from(table); + checkReadTableObject(bound, "foo.com:project", "somedataset", "sometable"); + } + + @Test + public void testValidateReadSetsDefaultProject() { + BigQueryOptions options = PipelineOptionsFactory.as(BigQueryOptions.class); + options.setProject("someproject"); + + Pipeline p = Pipeline.create(options); + + TableReference tableRef = new TableReference(); + tableRef.setDatasetId("somedataset"); + tableRef.setTableId("sometable"); + + thrown.expect(RuntimeException.class); + // Message will be one of following depending on the execution environment. + thrown.expectMessage( + Matchers.either(Matchers.containsString("Unable to confirm BigQuery dataset presence")) + .or(Matchers.containsString("BigQuery dataset not found for table"))); + try { + p.apply(BigQueryIO.Read.named("ReadMyTable").from(tableRef)); + } finally { + Assert.assertEquals("someproject", tableRef.getProjectId()); + } + } + + @Test + @Category(RunnableOnService.class) + public void testBuildSourceWithoutTableOrQuery() { + Pipeline p = TestPipeline.create(); + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "Invalid BigQuery read operation, either table reference or query has to be set"); + p.apply(BigQueryIO.Read.named("ReadMyTable")); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testBuildSourceWithTableAndQuery() { + Pipeline p = TestPipeline.create(); + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "Invalid BigQuery read operation. 
Specifies both a query and a table, only one of these" + + " should be provided"); + p.apply( + BigQueryIO.Read.named("ReadMyTable") + .from("foo.com:project:somedataset.sometable") + .fromQuery("query")); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testBuildSourceWithTableAndFlatten() { + Pipeline p = TestPipeline.create(); + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "Invalid BigQuery read operation. Specifies a" + + " table with a result flattening preference, which is not configurable"); + p.apply( + BigQueryIO.Read.named("ReadMyTable") + .from("foo.com:project:somedataset.sometable") + .withoutResultFlattening()); + p.run(); + } + + @Test + public void testBuildSink() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable"); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkwithoutValidation() { + // This test just checks that using withoutValidation will not trigger object + // construction errors. + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable").withoutValidation(); + checkWriteObjectWithValidate( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY, false); + } + + @Test + public void testBuildSinkDefaultProject() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("somedataset.sometable"); + checkWriteObject( + bound, null, "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithTableReference() { + TableReference table = new TableReference() + .setProjectId("foo.com:project") + .setDatasetId("somedataset") + .setTableId("sometable"); + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to(table); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + @Category(RunnableOnService.class) + public void testBuildSinkWithoutTable() { + Pipeline p = TestPipeline.create(); + thrown.expect(IllegalStateException.class); + thrown.expectMessage("must set the table reference"); + p.apply(Create.of().withCoder(TableRowJsonCoder.of())) + .apply(BigQueryIO.Write.named("WriteMyTable")); + } + + @Test + public void testBuildSinkWithSchema() { + TableSchema schema = new TableSchema(); + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable").withSchema(schema); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + schema, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithCreateDispositionNever() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withCreateDisposition(CreateDisposition.CREATE_NEVER); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_NEVER, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithCreateDispositionIfNeeded() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + 
.to("foo.com:project:somedataset.sometable") + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + @Test + public void testBuildSinkWithWriteDispositionTruncate() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_TRUNCATE); + } + + @Test + public void testBuildSinkWithWriteDispositionAppend() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withWriteDisposition(WriteDisposition.WRITE_APPEND); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_APPEND); + } + + @Test + public void testBuildSinkWithWriteDispositionEmpty() { + BigQueryIO.Write.Bound bound = BigQueryIO.Write.named("WriteMyTable") + .to("foo.com:project:somedataset.sometable") + .withWriteDisposition(WriteDisposition.WRITE_EMPTY); + checkWriteObject( + bound, "foo.com:project", "somedataset", "sometable", + null, CreateDisposition.CREATE_IF_NEEDED, WriteDisposition.WRITE_EMPTY); + } + + + private void testWriteValidatesDataset(boolean streaming) { + BigQueryOptions options = PipelineOptionsFactory.as(BigQueryOptions.class); + options.setProject("someproject"); + options.setStreaming(streaming); + + Pipeline p = Pipeline.create(options); + + TableReference tableRef = new TableReference(); + tableRef.setDatasetId("somedataset"); + tableRef.setTableId("sometable"); + + thrown.expect(RuntimeException.class); + // Message will be one of following depending on the execution environment. 
+ thrown.expectMessage( + Matchers.either(Matchers.containsString("Unable to confirm BigQuery dataset presence")) + .or(Matchers.containsString("BigQuery dataset not found for table"))); + try { + p.apply(Create.of().withCoder(TableRowJsonCoder.of())) + .apply(BigQueryIO.Write.named("WriteMyTable") + .to(tableRef) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withSchema(new TableSchema())); + } finally { + Assert.assertEquals("someproject", tableRef.getProjectId()); + } + } + + @Test + public void testWriteValidatesDatasetBatch() { + testWriteValidatesDataset(false); + } + + @Test + public void testWriteValidatesDatasetStreaming() { + testWriteValidatesDataset(true); + } + + @Test + public void testTableParsing() { + TableReference ref = BigQueryIO + .parseTableSpec("my-project:data_set.table_name"); + Assert.assertEquals("my-project", ref.getProjectId()); + Assert.assertEquals("data_set", ref.getDatasetId()); + Assert.assertEquals("table_name", ref.getTableId()); + } + + @Test + public void testTableParsing_validPatterns() { + BigQueryIO.parseTableSpec("a123-456:foo_bar.d"); + BigQueryIO.parseTableSpec("a12345:b.c"); + BigQueryIO.parseTableSpec("b12345.c"); + } + + @Test + public void testTableParsing_noProjectId() { + TableReference ref = BigQueryIO + .parseTableSpec("data_set.table_name"); + Assert.assertEquals(null, ref.getProjectId()); + Assert.assertEquals("data_set", ref.getDatasetId()); + Assert.assertEquals("table_name", ref.getTableId()); + } + + @Test + public void testTableParsingError() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec("0123456:foo.bar"); + } + + @Test + public void testTableParsingError_2() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec("myproject:.bar"); + } + + @Test + public void testTableParsingError_3() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec(":a.b"); + } + + @Test + public void testTableParsingError_slash() { + thrown.expect(IllegalArgumentException.class); + BigQueryIO.parseTableSpec("a\\b12345:c.d"); + } + + // Test that BigQuery's special null placeholder objects can be encoded. + @Test + public void testCoder_nullCell() throws CoderException { + TableRow row = new TableRow(); + row.set("temperature", Data.nullOf(Object.class)); + row.set("max_temperature", Data.nullOf(Object.class)); + + byte[] bytes = CoderUtils.encodeToByteArray(TableRowJsonCoder.of(), row); + + TableRow newRow = CoderUtils.decodeFromByteArray(TableRowJsonCoder.of(), bytes); + byte[] newBytes = CoderUtils.encodeToByteArray(TableRowJsonCoder.of(), newRow); + + Assert.assertArrayEquals(bytes, newBytes); + } + + @Test + public void testBigQueryIOGetName() { + assertEquals("BigQueryIO.Read", BigQueryIO.Read.from("somedataset.sometable").getName()); + assertEquals("BigQueryIO.Write", BigQueryIO.Write.to("somedataset.sometable").getName()); + assertEquals("ReadMyTable", BigQueryIO.Read.named("ReadMyTable").getName()); + assertEquals("WriteMyTable", BigQueryIO.Write.named("WriteMyTable").getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java new file mode 100644 index 000000000000..d01c03634460 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.dataflow.TestCountingSource; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Unit tests for {@link BoundedReadFromUnboundedSource}. */ +@RunWith(JUnit4.class) +public class BoundedReadFromUnboundedSourceTest { + private static final int NUM_RECORDS = 100; + private static List finalizeTracker = null; + + @Test + @Category(RunnableOnService.class) + public void testNoDedup() throws Exception { + test(false, false); + } + + @Test + @Category(RunnableOnService.class) + public void testDedup() throws Exception { + test(true, false); + } + + @Test + @Category(RunnableOnService.class) + public void testTimeBound() throws Exception { + test(false, true); + } + + private static class Checker + implements SerializableFunction>, Void> { + private final boolean dedup; + private final boolean timeBound; + + Checker(boolean dedup, boolean timeBound) { + this.dedup = dedup; + this.timeBound = timeBound; + } + + @Override + public Void apply(Iterable> input) { + List values = new ArrayList<>(); + for (KV kv : input) { + assertEquals(0, (int) kv.getKey()); + values.add(kv.getValue()); + } + if (timeBound) { + assertTrue(values.size() > 2); + } else if (dedup) { + // Verify that at least some data came through. The chance of 90% of the input + // being duplicates is essentially zero. 
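+ // (NUM_RECORDS is 100, so the check below demands more than 10 distinct values.)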
+ assertTrue(values.size() > NUM_RECORDS / 10 && values.size() <= NUM_RECORDS); + } else { + assertEquals(NUM_RECORDS, values.size()); + } + Collections.sort(values); + for (int i = 0; i < values.size(); i++) { + assertEquals(i, (int) values.get(i)); + } + if (finalizeTracker != null) { + assertThat(finalizeTracker, containsInAnyOrder(values.size() - 1)); + } + return null; + } + } + + private void test(boolean dedup, boolean timeBound) throws Exception { + Pipeline p = TestPipeline.create(); + + if (p.getOptions().getRunner() == DirectPipelineRunner.class) { + finalizeTracker = new ArrayList<>(); + TestCountingSource.setFinalizeTracker(finalizeTracker); + } + TestCountingSource source = new TestCountingSource(Integer.MAX_VALUE); + if (dedup) { + source = source.withDedup(); + } + PCollection> output = + timeBound + ? p.apply(Read.from(source).withMaxReadTime(Duration.millis(200))) + : p.apply(Read.from(source).withMaxNumRecords(NUM_RECORDS)); + + List> expectedOutput = new ArrayList<>(); + for (int i = 0; i < NUM_RECORDS; i++) { + expectedOutput.add(KV.of(0, i)); + } + + // Because some of the NUM_RECORDS elements read are dupes, the final output + // will only have output from 0 to n where n < NUM_RECORDS. + DataflowAssert.that(output).satisfies(new Checker(dedup, timeBound)); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CompressedSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CompressedSourceTest.java new file mode 100644 index 000000000000..14c8fe9acad9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CompressedSourceTest.java @@ -0,0 +1,430 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.io.CompressedSource.CompressionMode; +import com.google.cloud.dataflow.sdk.io.CompressedSource.DecompressingChannelFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.SourceTestUtils; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.io.Files; +import com.google.common.primitives.Bytes; + +import org.apache.commons.compress.compressors.bzip2.BZip2CompressorOutputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream; +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Random; + +import javax.annotation.Nullable; + +/** + * Tests for CompressedSource. + */ +@RunWith(JUnit4.class) +public class CompressedSourceTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + /** + * Test reading nonempty input with gzip. + */ + @Test + public void testReadGzip() throws Exception { + byte[] input = generateInput(5000); + runReadTest(input, CompressionMode.GZIP); + } + + /** + * Test reading nonempty input with bzip2. + */ + @Test + public void testReadBzip2() throws Exception { + byte[] input = generateInput(5000); + runReadTest(input, CompressionMode.BZIP2); + } + + /** + * Test reading empty input with gzip. + */ + @Test + public void testEmptyReadGzip() throws Exception { + byte[] input = generateInput(0); + runReadTest(input, CompressionMode.GZIP); + } + + /** + * Test reading empty input with bzip2. + */ + @Test + public void testCompressedReadBzip2() throws Exception { + byte[] input = generateInput(0); + runReadTest(input, CompressionMode.BZIP2); + } + + /** + * Test reading according to filepattern when the file is bzipped. + */ + @Test + public void testCompressedAccordingToFilepatternGzip() throws Exception { + byte[] input = generateInput(100); + File tmpFile = tmpFolder.newFile("test.gz"); + writeFile(tmpFile, input, CompressionMode.GZIP); + verifyReadContents(input, tmpFile, null /* default auto decompression factory */); + } + + /** + * Test reading according to filepattern when the file is gzipped. 
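+ * <p>No decompression factory is passed, so the source is expected to pick the decompressor
+ * automatically (its default behavior).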
+ */ + @Test + public void testCompressedAccordingToFilepatternBzip2() throws Exception { + byte[] input = generateInput(100); + File tmpFile = tmpFolder.newFile("test.bz2"); + writeFile(tmpFile, input, CompressionMode.BZIP2); + verifyReadContents(input, tmpFile, null /* default auto decompression factory */); + } + + /** + * Test reading multiple files with different compression. + */ + @Test + public void testHeterogeneousCompression() throws Exception { + String baseName = "test-input"; + + // Expected data + byte[] generated = generateInput(1000); + List expected = new ArrayList<>(); + + // Every sort of compression + File uncompressedFile = tmpFolder.newFile(baseName + ".bin"); + generated = generateInput(1000); + Files.write(generated, uncompressedFile); + expected.addAll(Bytes.asList(generated)); + + File gzipFile = tmpFolder.newFile(baseName + ".gz"); + generated = generateInput(1000); + writeFile(gzipFile, generated, CompressionMode.GZIP); + expected.addAll(Bytes.asList(generated)); + + File bzip2File = tmpFolder.newFile(baseName + ".bz2"); + generated = generateInput(1000); + writeFile(bzip2File, generateInput(1000), CompressionMode.BZIP2); + expected.addAll(Bytes.asList(generated)); + + String filePattern = new File(tmpFolder.getRoot().toString(), baseName + ".*").toString(); + + Pipeline p = TestPipeline.create(); + + CompressedSource source = + CompressedSource.from(new ByteSource(filePattern, 1)); + PCollection output = p.apply(Read.from(source)); + + DataflowAssert.that(output).containsInAnyOrder(expected); + p.run(); + } + + @Test + public void testUncompressedFileIsSplittable() throws Exception { + String baseName = "test-input"; + + File uncompressedFile = tmpFolder.newFile(baseName + ".bin"); + Files.write(generateInput(10), uncompressedFile); + + CompressedSource source = + CompressedSource.from(new ByteSource(uncompressedFile.getPath(), 1)); + assertTrue(source.isSplittable()); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testGzipFileIsNotSplittable() throws Exception { + String baseName = "test-input"; + + File compressedFile = tmpFolder.newFile(baseName + ".gz"); + writeFile(compressedFile, generateInput(10), CompressionMode.GZIP); + + CompressedSource source = + CompressedSource.from(new ByteSource(compressedFile.getPath(), 1)); + assertFalse(source.isSplittable()); + } + + @Test + public void testBzip2FileIsNotSplittable() throws Exception { + String baseName = "test-input"; + + File compressedFile = tmpFolder.newFile(baseName + ".bz2"); + writeFile(compressedFile, generateInput(10), CompressionMode.BZIP2); + + CompressedSource source = + CompressedSource.from(new ByteSource(compressedFile.getPath(), 1)); + assertFalse(source.isSplittable()); + } + + /** + * Test reading an uncompressed file with {@link CompressionMode#GZIP}, since we must support + * this due to properties of services that we read from. + */ + @Test + public void testFalseGzipStream() throws Exception { + byte[] input = generateInput(1000); + File tmpFile = tmpFolder.newFile("test.gz"); + Files.write(input, tmpFile); + verifyReadContents(input, tmpFile, CompressionMode.GZIP); + } + + /** + * Test reading an uncompressed file with {@link CompressionMode#BZIP2}, and show that + * we fail. 
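+ * <p>Unlike {@link CompressionMode#GZIP} above, which falls back to reading the bytes as
+ * uncompressed, the BZip2 decompressor rejects a stream that is not in the BZip2 format.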
+ */ + @Test + public void testFalseBzip2Stream() throws Exception { + byte[] input = generateInput(1000); + File tmpFile = tmpFolder.newFile("test.bz2"); + Files.write(input, tmpFile); + thrown.expectCause(Matchers.allOf( + instanceOf(IOException.class), + ThrowableMessageMatcher.hasMessage( + containsString("Stream is not in the BZip2 format")))); + verifyReadContents(input, tmpFile, CompressionMode.BZIP2); + } + + /** + * Test reading an empty input file with gzip; it must be interpreted as uncompressed because + * the gzip header is two bytes. + */ + @Test + public void testEmptyReadGzipUncompressed() throws Exception { + byte[] input = generateInput(0); + File tmpFile = tmpFolder.newFile("test.gz"); + Files.write(input, tmpFile); + verifyReadContents(input, tmpFile, CompressionMode.GZIP); + } + + /** + * Test reading single byte input with gzip; it must be interpreted as uncompressed because + * the gzip header is two bytes. + */ + @Test + public void testOneByteReadGzipUncompressed() throws Exception { + byte[] input = generateInput(1); + File tmpFile = tmpFolder.newFile("test.gz"); + Files.write(input, tmpFile); + verifyReadContents(input, tmpFile, CompressionMode.GZIP); + } + + /** + * Test reading multiple files. + */ + @Test + public void testCompressedReadMultipleFiles() throws Exception { + int numFiles = 10; + String baseName = "test_input-"; + String filePattern = new File(tmpFolder.getRoot().toString(), baseName + "*").toString(); + List expected = new ArrayList<>(); + + for (int i = 0; i < numFiles; i++) { + byte[] generated = generateInput(1000); + File tmpFile = tmpFolder.newFile(baseName + i); + writeFile(tmpFile, generated, CompressionMode.GZIP); + expected.addAll(Bytes.asList(generated)); + } + + Pipeline p = TestPipeline.create(); + + CompressedSource source = + CompressedSource.from(new ByteSource(filePattern, 1)) + .withDecompression(CompressionMode.GZIP); + PCollection output = p.apply(Read.from(source)); + + DataflowAssert.that(output).containsInAnyOrder(expected); + p.run(); + } + + /** + * Generate byte array of given size. + */ + private byte[] generateInput(int size) { + // Arbitrary but fixed seed + Random random = new Random(285930); + byte[] buff = new byte[size]; + for (int i = 0; i < size; i++) { + buff[i] = (byte) (random.nextInt() % Byte.MAX_VALUE); + } + return buff; + } + + /** + * Get a compressing stream for a given compression mode. + */ + private OutputStream getOutputStreamForMode(CompressionMode mode, OutputStream stream) + throws IOException { + switch (mode) { + case GZIP: + return new GzipCompressorOutputStream(stream); + case BZIP2: + return new BZip2CompressorOutputStream(stream); + default: + throw new RuntimeException("Unexpected compression mode"); + } + } + + /** + * Writes a single output file. + */ + private void writeFile(File file, byte[] input, CompressionMode mode) throws IOException { + try (OutputStream os = getOutputStreamForMode(mode, new FileOutputStream(file))) { + os.write(input); + } + } + + /** + * Run a single read test, writing and reading back input with the given compression mode. 
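+ * <p>The file is written with {@code inputCompressionMode}; {@code decompressionFactory} may be
+ * null, in which case the source's default (automatic) decompression is used.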
+ */ + private void runReadTest(byte[] input, + CompressionMode inputCompressionMode, + @Nullable DecompressingChannelFactory decompressionFactory) + throws IOException { + File tmpFile = tmpFolder.newFile(); + writeFile(tmpFile, input, inputCompressionMode); + verifyReadContents(input, tmpFile, decompressionFactory); + } + + private void verifyReadContents(byte[] expected, File inputFile, + @Nullable DecompressingChannelFactory decompressionFactory) { + Pipeline p = TestPipeline.create(); + CompressedSource source = + CompressedSource.from(new ByteSource(inputFile.toPath().toString(), 1)); + if (decompressionFactory != null) { + source = source.withDecompression(decompressionFactory); + } + PCollection output = p.apply(Read.from(source)); + DataflowAssert.that(output).containsInAnyOrder(Bytes.asList(expected)); + p.run(); + } + + /** + * Run a single read test, writing and reading back input with the given compression mode. + */ + private void runReadTest(byte[] input, CompressionMode mode) throws IOException { + runReadTest(input, mode, mode); + } + + /** + * Dummy source for use in tests. + */ + private static class ByteSource extends FileBasedSource { + public ByteSource(String fileOrPatternSpec, long minBundleSize) { + super(fileOrPatternSpec, minBundleSize); + } + + public ByteSource(String fileName, long minBundleSize, long startOffset, long endOffset) { + super(fileName, minBundleSize, startOffset, endOffset); + } + + @Override + protected FileBasedSource createForSubrangeOfFile(String fileName, long start, long end) { + return new ByteSource(fileName, getMinBundleSize(), start, end); + } + + @Override + protected FileBasedReader createSingleFileReader(PipelineOptions options) { + return new ByteReader(this); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public Coder getDefaultOutputCoder() { + return SerializableCoder.of(Byte.class); + } + + private static class ByteReader extends FileBasedReader { + ByteBuffer buff = ByteBuffer.allocate(1); + Byte current; + long offset = -1; + ReadableByteChannel channel; + + public ByteReader(ByteSource source) { + super(source); + } + + @Override + public Byte getCurrent() throws NoSuchElementException { + return current; + } + + @Override + protected boolean isAtSplitPoint() { + return true; + } + + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + this.channel = channel; + } + + @Override + protected boolean readNextRecord() throws IOException { + buff.clear(); + if (channel.read(buff) != 1) { + return false; + } + current = new Byte(buff.get(0)); + offset += 1; + return true; + } + + @Override + protected long getCurrentOffset() { + return offset; + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java new file mode 100644 index 000000000000..cc9db7978f78 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java @@ -0,0 +1,216 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.CountingSource.CounterMark; +import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** + * Tests of {@link CountingSource}. + */ +@RunWith(JUnit4.class) +public class CountingSourceTest { + + public static void addCountingAsserts(PCollection input, long numElements) { + // Count == numElements + DataflowAssert + .thatSingleton(input.apply("Count", Count.globally())) + .isEqualTo(numElements); + // Unique count == numElements + DataflowAssert + .thatSingleton(input.apply(RemoveDuplicates.create()) + .apply("UniqueCount", Count.globally())) + .isEqualTo(numElements); + // Min == 0 + DataflowAssert + .thatSingleton(input.apply("Min", Min.globally())) + .isEqualTo(0L); + // Max == numElements-1 + DataflowAssert + .thatSingleton(input.apply("Max", Max.globally())) + .isEqualTo(numElements - 1); + } + + @Test + @Category(RunnableOnService.class) + public void testBoundedSource() { + Pipeline p = TestPipeline.create(); + long numElements = 1000; + PCollection input = p.apply(Read.from(CountingSource.upTo(numElements))); + + addCountingAsserts(input, numElements); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testBoundedSourceSplits() throws Exception { + Pipeline p = TestPipeline.create(); + long numElements = 1000; + long numSplits = 10; + long splitSizeBytes = numElements * 8 / numSplits; // 8 bytes per long element. + + BoundedSource initial = CountingSource.upTo(numElements); + List> splits = + initial.splitIntoBundles(splitSizeBytes, p.getOptions()); + assertEquals("Expected exact splitting", numSplits, splits.size()); + + // Assemble all the splits into one flattened PCollection, also verify their sizes. 
+ PCollectionList pcollections = PCollectionList.empty(p); + for (int i = 0; i < splits.size(); ++i) { + BoundedSource split = splits.get(i); + pcollections = pcollections.and(p.apply("split" + i, Read.from(split))); + assertEquals("Expected even splitting", + splitSizeBytes, split.getEstimatedSizeBytes(p.getOptions())); + } + PCollection input = pcollections.apply(Flatten.pCollections()); + + addCountingAsserts(input, numElements); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSource() { + Pipeline p = TestPipeline.create(); + long numElements = 1000; + + PCollection input = p + .apply(Read.from(CountingSource.unbounded()).withMaxNumRecords(numElements)); + + addCountingAsserts(input, numElements); + p.run(); + } + + private static class ElementValueDiff extends DoFn { + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element() - c.timestamp().getMillis()); + } + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSourceTimestamps() { + Pipeline p = TestPipeline.create(); + long numElements = 1000; + + PCollection input = p.apply( + Read.from(CountingSource.unboundedWithTimestampFn(new ValueAsTimestampFn())) + .withMaxNumRecords(numElements)); + addCountingAsserts(input, numElements); + + PCollection diffs = input + .apply("TimestampDiff", ParDo.of(new ElementValueDiff())) + .apply("RemoveDuplicateTimestamps", RemoveDuplicates.create()); + // This assert also confirms that diffs only has one unique value. + DataflowAssert.thatSingleton(diffs).isEqualTo(0L); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSourceSplits() throws Exception { + Pipeline p = TestPipeline.create(); + long numElements = 1000; + int numSplits = 10; + + UnboundedSource initial = CountingSource.unbounded(); + List> splits = + initial.generateInitialSplits(numSplits, p.getOptions()); + assertEquals("Expected exact splitting", numSplits, splits.size()); + + long elementsPerSplit = numElements / numSplits; + assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits); + PCollectionList pcollections = PCollectionList.empty(p); + for (int i = 0; i < splits.size(); ++i) { + pcollections = pcollections.and( + p.apply("split" + i, Read.from(splits.get(i)).withMaxNumRecords(elementsPerSplit))); + } + PCollection input = pcollections.apply(Flatten.pCollections()); + + addCountingAsserts(input, numElements); + p.run(); + } + + /** + * A timestamp function that uses the given value as the timestamp. Because the input values will + * not wrap, this function is non-decreasing and meets the timestamp function criteria laid out + * in {@link CountingSource#unboundedWithTimestampFn(SerializableFunction)}. + */ + private static class ValueAsTimestampFn implements SerializableFunction { + @Override + public Instant apply(Long input) { + return new Instant(input); + } + } + + @Test + public void testUnboundedSourceCheckpointMark() throws Exception { + UnboundedSource source = + CountingSource.unboundedWithTimestampFn(new ValueAsTimestampFn()); + UnboundedReader reader = source.createReader(null, null); + final long numToSkip = 3; + assertTrue(reader.start()); + + // Advance the source numToSkip elements and manually save state. + for (long l = 0; l < numToSkip; ++l) { + reader.advance(); + } + + // Confirm that we get the expected element in sequence before checkpointing. 
+ assertEquals(numToSkip, (long) reader.getCurrent()); + assertEquals(numToSkip, reader.getCurrentTimestamp().getMillis()); + + // Checkpoint and restart, and confirm that the source continues correctly. + CounterMark mark = CoderUtils.clone( + source.getCheckpointMarkCoder(), (CounterMark) reader.getCheckpointMark()); + reader = source.createReader(null, mark); + assertTrue(reader.start()); + + // Confirm that we get the next element in sequence. + assertEquals(numToSkip + 1, (long) reader.getCurrent()); + assertEquals(numToSkip + 1, reader.getCurrentTimestamp().getMillis()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/DatastoreIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/DatastoreIOTest.java new file mode 100644 index 000000000000..4cc3ace1b578 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/DatastoreIOTest.java @@ -0,0 +1,631 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.api.services.datastore.client.DatastoreHelper.makeKey; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.services.datastore.DatastoreV1.Entity; +import com.google.api.services.datastore.DatastoreV1.EntityResult; +import com.google.api.services.datastore.DatastoreV1.Key; +import com.google.api.services.datastore.DatastoreV1.KindExpression; +import com.google.api.services.datastore.DatastoreV1.PartitionId; +import com.google.api.services.datastore.DatastoreV1.PropertyFilter; +import com.google.api.services.datastore.DatastoreV1.Query; +import com.google.api.services.datastore.DatastoreV1.QueryResultBatch; +import com.google.api.services.datastore.DatastoreV1.RunQueryRequest; +import com.google.api.services.datastore.DatastoreV1.RunQueryResponse; +import com.google.api.services.datastore.DatastoreV1.Value; +import com.google.api.services.datastore.client.Datastore; +import com.google.api.services.datastore.client.DatastoreHelper; +import com.google.api.services.datastore.client.QuerySplitter; +import com.google.cloud.dataflow.sdk.io.DatastoreIO.DatastoreReader; +import com.google.cloud.dataflow.sdk.io.DatastoreIO.DatastoreWriter; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import 
com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.common.collect.Lists; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +/** + * Tests for {@link DatastoreIO}. + */ +@RunWith(JUnit4.class) +public class DatastoreIOTest { + private static final String HOST = "testHost"; + private static final String DATASET = "testDataset"; + private static final String NAMESPACE = "testNamespace"; + private static final String KIND = "testKind"; + private static final Query QUERY; + static { + Query.Builder q = Query.newBuilder(); + q.addKindBuilder().setName(KIND); + QUERY = q.build(); + } + private DatastoreIO.Source initialSource; + + @Mock + Datastore mockDatastore; + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + @Rule public final ExpectedLogs logged = ExpectedLogs.none(DatastoreIO.Source.class); + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + initialSource = DatastoreIO.source() + .withHost(HOST).withDataset(DATASET).withQuery(QUERY).withNamespace(NAMESPACE); + } + + /** + * Helper function to create a test {@code DataflowPipelineOptions}. + */ + static final DataflowPipelineOptions testPipelineOptions(@Nullable Integer numWorkers) { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + if (numWorkers != null) { + options.setNumWorkers(numWorkers); + } + return options; + } + + @Test + public void testBuildSource() throws Exception { + DatastoreIO.Source source = DatastoreIO.source() + .withHost(HOST).withDataset(DATASET).withQuery(QUERY).withNamespace(NAMESPACE); + assertEquals(QUERY, source.getQuery()); + assertEquals(DATASET, source.getDataset()); + assertEquals(HOST, source.getHost()); + assertEquals(NAMESPACE, source.getNamespace()); + } + + /** + * {@link #testBuildSource} but constructed in a different order. 
+ */ + @Test + public void testBuildSourceAlt() throws Exception { + DatastoreIO.Source source = DatastoreIO.source() + .withDataset(DATASET).withNamespace(NAMESPACE).withQuery(QUERY).withHost(HOST); + assertEquals(QUERY, source.getQuery()); + assertEquals(DATASET, source.getDataset()); + assertEquals(HOST, source.getHost()); + assertEquals(NAMESPACE, source.getNamespace()); + } + + @Test + public void testSourceValidationFailsHost() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage("host"); + + DatastoreIO.Source source = initialSource.withHost(null); + source.validate(); + } + + @Test + public void testSourceValidationFailsDataset() throws Exception { + DatastoreIO.Source source = DatastoreIO.source().withQuery(QUERY); + thrown.expect(NullPointerException.class); + thrown.expectMessage("dataset"); + source.validate(); + } + + @Test + public void testSourceValidationFailsQuery() throws Exception { + DatastoreIO.Source source = DatastoreIO.source().withDataset(DATASET); + thrown.expect(NullPointerException.class); + thrown.expectMessage("query"); + source.validate(); + } + + @Test + public void testSourceValidationFailsQueryLimitZero() throws Exception { + Query invalidLimit = Query.newBuilder().setLimit(0).build(); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Invalid query limit 0"); + + DatastoreIO.source().withQuery(invalidLimit); + } + + @Test + public void testSourceValidationFailsQueryLimitNegative() throws Exception { + Query invalidLimit = Query.newBuilder().setLimit(-5).build(); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Invalid query limit -5"); + + DatastoreIO.source().withQuery(invalidLimit); + } + + @Test + public void testSourceValidationSucceedsNamespace() throws Exception { + DatastoreIO.Source source = DatastoreIO.source().withDataset(DATASET).withQuery(QUERY); + /* Should succeed, as a null namespace is fine. 
*/ + source.validate(); + } + + @Test + public void testSinkDoesNotAllowNullHost() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage("host"); + + DatastoreIO.sink().withDataset(DATASET).withHost(null); + } + + @Test + public void testSinkDoesNotAllowNullDataset() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage("datasetId"); + + DatastoreIO.sink().withDataset(null); + } + + @Test + public void testSinkValidationFailsWithNoDataset() throws Exception { + DatastoreIO.Sink sink = DatastoreIO.sink(); + + thrown.expect(NullPointerException.class); + thrown.expectMessage("Dataset"); + + sink.validate(testPipelineOptions(null)); + } + + @Test + public void testSinkValidationSucceedsWithDataset() throws Exception { + DatastoreIO.Sink sink = DatastoreIO.sink().withDataset(DATASET); + sink.validate(testPipelineOptions(null)); + } + + @Test + public void testQuerySplitBasic() throws Exception { + KindExpression mykind = KindExpression.newBuilder().setName("mykind").build(); + Query query = Query.newBuilder().addKind(mykind).build(); + + List mockSplits = new ArrayList<>(); + for (int i = 0; i < 8; ++i) { + mockSplits.add( + Query.newBuilder() + .addKind(mykind) + .setFilter( + DatastoreHelper.makeFilter("foo", PropertyFilter.Operator.EQUAL, + Value.newBuilder().setIntegerValue(i).build())) + .build()); + } + + QuerySplitter splitter = mock(QuerySplitter.class); + /* No namespace */ + PartitionId partition = PartitionId.newBuilder().build(); + when(splitter.getSplits(any(Query.class), eq(partition), eq(8), any(Datastore.class))) + .thenReturn(mockSplits); + + DatastoreIO.Source io = initialSource + .withNamespace(null) + .withQuery(query) + .withMockSplitter(splitter) + .withMockEstimateSizeBytes(8 * 1024L); + + List bundles = io.splitIntoBundles(1024, testPipelineOptions(null)); + assertEquals(8, bundles.size()); + for (int i = 0; i < 8; ++i) { + DatastoreIO.Source bundle = bundles.get(i); + Query bundleQuery = bundle.getQuery(); + assertEquals("mykind", bundleQuery.getKind(0).getName()); + assertEquals(i, bundleQuery.getFilter().getPropertyFilter().getValue().getIntegerValue()); + } + } + + /** + * Verifies that when namespace is set in the source, the split request includes the namespace. 
+ */ + @Test + public void testSourceWithNamespace() throws Exception { + QuerySplitter splitter = mock(QuerySplitter.class); + DatastoreIO.Source io = initialSource + .withMockSplitter(splitter) + .withMockEstimateSizeBytes(8 * 1024L); + + io.splitIntoBundles(1024, testPipelineOptions(null)); + + PartitionId partition = PartitionId.newBuilder().setNamespace(NAMESPACE).build(); + verify(splitter).getSplits(eq(QUERY), eq(partition), eq(8), any(Datastore.class)); + verifyNoMoreInteractions(splitter); + } + + @Test + public void testQuerySplitWithZeroSize() throws Exception { + KindExpression mykind = KindExpression.newBuilder().setName("mykind").build(); + Query query = Query.newBuilder().addKind(mykind).build(); + + List mockSplits = Lists.newArrayList( + Query.newBuilder() + .addKind(mykind) + .build()); + + QuerySplitter splitter = mock(QuerySplitter.class); + when(splitter.getSplits(any(Query.class), any(PartitionId.class), eq(1), any(Datastore.class))) + .thenReturn(mockSplits); + + DatastoreIO.Source io = initialSource + .withQuery(query) + .withMockSplitter(splitter) + .withMockEstimateSizeBytes(0L); + + List bundles = io.splitIntoBundles(1024, testPipelineOptions(null)); + assertEquals(1, bundles.size()); + verify(splitter, never()) + .getSplits(any(Query.class), any(PartitionId.class), eq(1), any(Datastore.class)); + DatastoreIO.Source bundle = bundles.get(0); + Query bundleQuery = bundle.getQuery(); + assertEquals("mykind", bundleQuery.getKind(0).getName()); + assertFalse(bundleQuery.hasFilter()); + } + + /** + * Tests that a query with a user-provided limit field does not split, and does not even + * interact with a query splitter. + */ + @Test + public void testQueryDoesNotSplitWithLimitSet() throws Exception { + // Minimal query with a limit + Query query = Query.newBuilder().setLimit(5).build(); + + // Mock query splitter, should not be invoked. + QuerySplitter splitter = mock(QuerySplitter.class); + when(splitter.getSplits(any(Query.class), any(PartitionId.class), eq(2), any(Datastore.class))) + .thenThrow(new AssertionError("Splitter should not be invoked")); + + List bundles = + initialSource + .withQuery(query) + .withMockSplitter(splitter) + .splitIntoBundles(1024, testPipelineOptions(null)); + + assertEquals(1, bundles.size()); + assertEquals(query, bundles.get(0).getQuery()); + verifyNoMoreInteractions(splitter); + } + + /** + * Tests that when {@link QuerySplitter} cannot split a query, {@link DatastoreIO} falls back to + * a single split. 
+ */ + @Test + public void testQuerySplitterThrows() throws Exception { + // Mock query splitter that throws IllegalArgumentException + IllegalArgumentException exception = + new IllegalArgumentException("query not supported by splitter"); + QuerySplitter splitter = mock(QuerySplitter.class); + when( + splitter.getSplits( + any(Query.class), any(PartitionId.class), any(Integer.class), any(Datastore.class))) + .thenThrow(exception); + + Query query = Query.newBuilder().addKind(KindExpression.newBuilder().setName("myKind")).build(); + List bundles = + initialSource + .withQuery(query) + .withMockSplitter(splitter) + .withMockEstimateSizeBytes(10240L) + .splitIntoBundles(1024, testPipelineOptions(null)); + + assertEquals(1, bundles.size()); + assertEquals(query, bundles.get(0).getQuery()); + verify(splitter, times(1)) + .getSplits( + any(Query.class), any(PartitionId.class), any(Integer.class), any(Datastore.class)); + logged.verifyWarn("Unable to parallelize the given query", exception); + } + + @Test + public void testQuerySplitSizeUnavailable() throws Exception { + KindExpression mykind = KindExpression.newBuilder().setName("mykind").build(); + Query query = Query.newBuilder().addKind(mykind).build(); + + List mockSplits = new ArrayList<>(); + for (int i = 0; i < 2; i++) { + mockSplits.add( + Query.newBuilder() + .addKind(mykind) + .setFilter( + DatastoreHelper.makeFilter("foo", PropertyFilter.Operator.EQUAL, + Value.newBuilder().setIntegerValue(i).build())) + .build()); + } + + QuerySplitter splitter = mock(QuerySplitter.class); + when(splitter.getSplits(any(Query.class), any(PartitionId.class), eq(2), any(Datastore.class))) + .thenReturn(mockSplits); + + DatastoreIO.Source io = initialSource + .withQuery(query) + .withMockSplitter(splitter) + .withMockEstimateSizeBytes(8 * 1024L); + + DatastoreIO.Source spiedIo = spy(io); + when(spiedIo.getEstimatedSizeBytes(any(PipelineOptions.class))).thenThrow(new IOException()); + + List bundles = spiedIo.splitIntoBundles(1024, testPipelineOptions(2)); + assertEquals(2, bundles.size()); + for (int i = 0; i < 2; ++i) { + DatastoreIO.Source bundle = bundles.get(i); + Query bundleQuery = bundle.getQuery(); + assertEquals("mykind", bundleQuery.getKind(0).getName()); + assertEquals(i, bundleQuery.getFilter().getPropertyFilter().getValue().getIntegerValue()); + } + } + + @Test + public void testQuerySplitNoWorkers() throws Exception { + KindExpression mykind = KindExpression.newBuilder().setName("mykind").build(); + Query query = Query.newBuilder().addKind(mykind).build(); + + List mockSplits = Lists.newArrayList(Query.newBuilder().addKind(mykind).build()); + + QuerySplitter splitter = mock(QuerySplitter.class); + when(splitter.getSplits(any(Query.class), any(PartitionId.class), eq(12), any(Datastore.class))) + .thenReturn(mockSplits); + + DatastoreIO.Source io = initialSource + .withQuery(query) + .withMockSplitter(splitter) + .withMockEstimateSizeBytes(8 * 1024L); + + DatastoreIO.Source spiedIo = spy(io); + when(spiedIo.getEstimatedSizeBytes(any(PipelineOptions.class))) + .thenThrow(new NoSuchElementException()); + + List bundles = spiedIo.splitIntoBundles(1024, testPipelineOptions(0)); + assertEquals(1, bundles.size()); + verify(splitter, never()) + .getSplits(any(Query.class), any(PartitionId.class), eq(1), any(Datastore.class)); + DatastoreIO.Source bundle = bundles.get(0); + Query bundleQuery = bundle.getQuery(); + assertEquals("mykind", bundleQuery.getKind(0).getName()); + assertFalse(bundleQuery.hasFilter()); + } + + /** + * Test building a Sink 
using builder methods. + */ + @Test + public void testBuildSink() throws Exception { + DatastoreIO.Sink sink = DatastoreIO.sink().withDataset(DATASET).withHost(HOST); + assertEquals(HOST, sink.host); + assertEquals(DATASET, sink.datasetId); + + sink = DatastoreIO.sink().withHost(HOST).withDataset(DATASET); + assertEquals(HOST, sink.host); + assertEquals(DATASET, sink.datasetId); + + sink = DatastoreIO.sink().withDataset(DATASET).withHost(HOST); + assertEquals(HOST, sink.host); + assertEquals(DATASET, sink.datasetId); + } + + /** + * Test building a sink using the default host. + */ + @Test + public void testBuildSinkDefaults() throws Exception { + DatastoreIO.Sink sink = DatastoreIO.sink().withDataset(DATASET); + assertEquals(DatastoreIO.DEFAULT_HOST, sink.host); + assertEquals(DATASET, sink.datasetId); + + sink = DatastoreIO.sink().withDataset(DATASET); + assertEquals(DatastoreIO.DEFAULT_HOST, sink.host); + assertEquals(DATASET, sink.datasetId); + } + + /** + * Test the detection of complete and incomplete keys. + */ + @Test + public void testHasNameOrId() { + Key key; + // Complete with name, no ancestor + key = DatastoreHelper.makeKey("bird", "finch").build(); + assertTrue(DatastoreWriter.isValidKey(key)); + + // Complete with id, no ancestor + key = DatastoreHelper.makeKey("bird", 123).build(); + assertTrue(DatastoreWriter.isValidKey(key)); + + // Incomplete, no ancestor + key = DatastoreHelper.makeKey("bird").build(); + assertFalse(DatastoreWriter.isValidKey(key)); + + // Complete with name and ancestor + key = DatastoreHelper.makeKey("bird", "owl").build(); + key = DatastoreHelper.makeKey(key, "bird", "horned").build(); + assertTrue(DatastoreWriter.isValidKey(key)); + + // Complete with id and ancestor + key = DatastoreHelper.makeKey("bird", "owl").build(); + key = DatastoreHelper.makeKey(key, "bird", 123).build(); + assertTrue(DatastoreWriter.isValidKey(key)); + + // Incomplete with ancestor + key = DatastoreHelper.makeKey("bird", "owl").build(); + key = DatastoreHelper.makeKey(key, "bird").build(); + assertFalse(DatastoreWriter.isValidKey(key)); + + key = DatastoreHelper.makeKey().build(); + assertFalse(DatastoreWriter.isValidKey(key)); + } + + /** + * Test that entities with incomplete keys cannot be updated. + */ + @Test + public void testAddEntitiesWithIncompleteKeys() throws Exception { + Key key = DatastoreHelper.makeKey("bird").build(); + Entity entity = Entity.newBuilder().setKey(key).build(); + DatastoreWriter writer = new DatastoreIO.DatastoreWriter(null, mockDatastore); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Entities to be written to the Datastore must have complete keys"); + + writer.write(entity); + } + + /** + * Test that entities are added to the batch to update. 
+ */ + @Test + public void testAddingEntities() throws Exception { + List expected = Lists.newArrayList( + Entity.newBuilder().setKey(DatastoreHelper.makeKey("bird", "jay").build()).build(), + Entity.newBuilder().setKey(DatastoreHelper.makeKey("bird", "condor").build()).build(), + Entity.newBuilder().setKey(DatastoreHelper.makeKey("bird", "robin").build()).build()); + + List allEntities = Lists.newArrayList(expected); + Collections.shuffle(allEntities); + + DatastoreWriter writer = new DatastoreIO.DatastoreWriter(null, mockDatastore); + writer.open("test_id"); + for (Entity entity : allEntities) { + writer.write(entity); + } + + assertEquals(expected.size(), writer.entities.size()); + assertThat(writer.entities, containsInAnyOrder(expected.toArray())); + } + + /** Datastore batch API limit in number of records per query. */ + private static final int DATASTORE_QUERY_BATCH_LIMIT = 500; + + /** + * A helper function that creates mock {@link Entity} results in response to a query. Always + * indicates that more results are available, unless the batch is limited to fewer than + * {@link #DATASTORE_QUERY_BATCH_LIMIT} results. + */ + private static RunQueryResponse mockResponseForQuery(Query q) { + // Every query DatastoreIO sends should have a limit. + assertTrue(q.hasLimit()); + + // The limit should be in the range [1, DATASTORE_QUERY_BATCH_LIMIT] + int limit = q.getLimit(); + assertThat(limit, greaterThanOrEqualTo(1)); + assertThat(limit, lessThanOrEqualTo(DATASTORE_QUERY_BATCH_LIMIT)); + + // Create the requested number of entities. + List entities = new ArrayList<>(limit); + for (int i = 0; i < limit; ++i) { + entities.add( + EntityResult.newBuilder() + .setEntity(Entity.newBuilder().setKey(makeKey("key" + i, i + 1))) + .build()); + } + + // Fill out the other parameters on the returned result batch. + RunQueryResponse.Builder ret = RunQueryResponse.newBuilder(); + ret.getBatchBuilder() + .addAllEntityResult(entities) + .setEntityResultType(EntityResult.ResultType.FULL) + .setMoreResults( + limit == DATASTORE_QUERY_BATCH_LIMIT + ? QueryResultBatch.MoreResultsType.NOT_FINISHED + : QueryResultBatch.MoreResultsType.NO_MORE_RESULTS); + + return ret.build(); + } + + /** Helper function to run a test reading from a limited-result query. */ + private void runQueryLimitReadTest(int numEntities) throws Exception { + // An empty query to read entities. + Query query = Query.newBuilder().setLimit(numEntities).build(); + DatastoreIO.Source source = DatastoreIO.source().withQuery(query).withDataset("mockDataset"); + + // Use mockResponseForQuery to generate results. + when(mockDatastore.runQuery(any(RunQueryRequest.class))) + .thenAnswer( + new Answer() { + @Override + public RunQueryResponse answer(InvocationOnMock invocation) throws Throwable { + Query q = ((RunQueryRequest) invocation.getArguments()[0]).getQuery(); + return mockResponseForQuery(q); + } + }); + + // Actually instantiate the reader. + DatastoreReader reader = new DatastoreReader(source, mockDatastore); + + // Simply count the number of results returned by the reader. + assertTrue(reader.start()); + int resultCount = 1; + while (reader.advance()) { + resultCount++; + } + reader.close(); + + // Validate the number of results. + assertEquals(numEntities, resultCount); + } + + /** Tests reading with a query limit less than one batch. */ + @Test + public void testReadingWithLimitOneBatch() throws Exception { + runQueryLimitReadTest(5); + } + + /** Tests reading with a query limit more than one batch, and not a multiple. 
*/ + @Test + public void testReadingWithLimitMultipleBatches() throws Exception { + runQueryLimitReadTest(DATASTORE_QUERY_BATCH_LIMIT + 5); + } + + /** Tests reading several batches, using an exact multiple of batch size results. */ + @Test + public void testReadingWithLimitMultipleBatchesExactMultiple() throws Exception { + runQueryLimitReadTest(5 * DATASTORE_QUERY_BATCH_LIMIT); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java new file mode 100644 index 000000000000..da23f3a56229 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java @@ -0,0 +1,512 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriteOperation; +import com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriteOperation.TemporaryFileRetention; +import com.google.cloud.dataflow.sdk.io.FileBasedSink.FileResult; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileOutputStream; +import java.io.FileReader; +import java.io.PrintWriter; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for FileBasedSink. + */ +@RunWith(JUnit4.class) +public class FileBasedSinkTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private String baseOutputFilename = "output"; + private String baseTemporaryFilename = "temp"; + + private String appendToTempFolder(String filename) { + return Paths.get(tmpFolder.getRoot().getPath(), filename).toString(); + } + + private String getBaseOutputFilename() { + return appendToTempFolder(baseOutputFilename); + } + + private String getBaseTempFilename() { + return appendToTempFolder(baseTemporaryFilename); + } + + /** + * FileBasedWriter opens the correct file, writes the header, footer, and elements in the + * correct order, and returns the correct filename. 
+ */ + @Test + public void testWriter() throws Exception { + String testUid = "testId"; + String expectedFilename = + getBaseTempFilename() + FileBasedWriteOperation.TEMPORARY_FILENAME_SEPARATOR + testUid; + SimpleSink.SimpleWriter writer = buildWriter(); + + List values = Arrays.asList("sympathetic vulture", "boresome hummingbird"); + List expected = new ArrayList<>(); + expected.add(SimpleSink.SimpleWriter.HEADER); + expected.addAll(values); + expected.add(SimpleSink.SimpleWriter.FOOTER); + + writer.open(testUid); + for (String value : values) { + writer.write(value); + } + FileResult result = writer.close(); + + assertEquals(expectedFilename, result.getFilename()); + assertFileContains(expected, expectedFilename); + } + + /** + * Assert that a file contains the lines provided, in the same order as expected. + */ + private void assertFileContains(List expected, String filename) throws Exception { + try (BufferedReader reader = new BufferedReader(new FileReader(filename))) { + List actual = new ArrayList<>(); + for (;;) { + String line = reader.readLine(); + if (line == null) { + break; + } + actual.add(line); + } + assertEquals(expected, actual); + } + } + + /** + * Write lines to a file. + */ + private void writeFile(List lines, File file) throws Exception { + try (PrintWriter writer = new PrintWriter(new FileOutputStream(file))) { + for (String line : lines) { + writer.println(line); + } + } + } + + /** + * Removes temporary files when temporary and output filenames differ. + */ + @Test + public void testRemoveWithTempFilename() throws Exception { + testRemoveTemporaryFiles(3, baseTemporaryFilename); + } + + /** + * Removes only temporary files, even if temporary and output files share the same base filename. + */ + @Test + public void testRemoveWithSameFilename() throws Exception { + testRemoveTemporaryFiles(3, baseOutputFilename); + } + + /** + * Finalize copies temporary files to output files and removes any temporary files. + */ + @Test + public void testFinalizeWithNoRetention() throws Exception { + List files = generateTemporaryFilesForFinalize(3); + boolean retainTemporaryFiles = false; + runFinalize(buildWriteOperationForFinalize(retainTemporaryFiles), files, retainTemporaryFiles); + } + + /** + * Finalize retains temporary files when requested. + */ + @Test + public void testFinalizeWithRetention() throws Exception { + List files = generateTemporaryFilesForFinalize(3); + boolean retainTemporaryFiles = true; + runFinalize(buildWriteOperationForFinalize(retainTemporaryFiles), files, retainTemporaryFiles); + } + + /** + * Finalize can be called repeatedly. + */ + @Test + public void testFinalizeMultipleCalls() throws Exception { + List files = generateTemporaryFilesForFinalize(3); + SimpleSink.SimpleWriteOperation writeOp = buildWriteOperationForFinalize(false); + runFinalize(writeOp, files, false); + runFinalize(writeOp, files, false); + } + + /** + * Finalize can be called when some temporary files do not exist and output files exist. + */ + @Test + public void testFinalizeWithIntermediateState() throws Exception { + List files = generateTemporaryFilesForFinalize(3); + SimpleSink.SimpleWriteOperation writeOp = buildWriteOperationForFinalize(false); + runFinalize(writeOp, files, false); + + // create a temporary file + tmpFolder.newFile( + baseTemporaryFilename + FileBasedWriteOperation.TEMPORARY_FILENAME_SEPARATOR + "1"); + + runFinalize(writeOp, files, false); + } + + /** + * Build a SimpleWriteOperation with default values and the specified retention policy. 
+ */ + private SimpleSink.SimpleWriteOperation buildWriteOperationForFinalize( + boolean retainTemporaryFiles) throws Exception { + TemporaryFileRetention retentionPolicy = + retainTemporaryFiles ? TemporaryFileRetention.KEEP : TemporaryFileRetention.REMOVE; + return buildWriteOperation(retentionPolicy); + } + + /** + * Generate n temporary files using the temporary file pattern of FileBasedWriter. + */ + private List generateTemporaryFilesForFinalize(int numFiles) throws Exception { + List temporaryFiles = new ArrayList<>(); + for (int i = 0; i < numFiles; i++) { + String temporaryFilename = + FileBasedWriteOperation.buildTemporaryFilename(baseTemporaryFilename, "" + i); + File tmpFile = tmpFolder.newFile(temporaryFilename); + temporaryFiles.add(tmpFile); + } + + return temporaryFiles; + } + + /** + * Finalize and verify that files are copied and temporary files are optionally removed. + */ + private void runFinalize(SimpleSink.SimpleWriteOperation writeOp, List temporaryFiles, + boolean retainTemporaryFiles) throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + + int numFiles = temporaryFiles.size(); + + List outputFiles = new ArrayList<>(); + List fileResults = new ArrayList<>(); + List outputFilenames = writeOp.generateDestinationFilenames(numFiles); + + // Create temporary output bundles and output File objects + for (int i = 0; i < numFiles; i++) { + fileResults.add(new FileResult(temporaryFiles.get(i).toString())); + outputFiles.add(new File(outputFilenames.get(i))); + } + + writeOp.finalize(fileResults, options); + + for (int i = 0; i < numFiles; i++) { + assertTrue(outputFiles.get(i).exists()); + assertEquals(retainTemporaryFiles, temporaryFiles.get(i).exists()); + } + } + + /** + * Create n temporary and output files and verify that removeTemporaryFiles only + * removes temporary files. + */ + private void testRemoveTemporaryFiles(int numFiles, String baseTemporaryFilename) + throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + SimpleSink.SimpleWriteOperation writeOp = buildWriteOperation(baseTemporaryFilename); + + List temporaryFiles = new ArrayList<>(); + List outputFiles = new ArrayList<>(); + for (int i = 0; i < numFiles; i++) { + File tmpFile = tmpFolder.newFile( + FileBasedWriteOperation.buildTemporaryFilename(baseTemporaryFilename, "" + i)); + temporaryFiles.add(tmpFile); + File outputFile = tmpFolder.newFile(baseOutputFilename + i); + outputFiles.add(outputFile); + } + + writeOp.removeTemporaryFiles(options); + + for (int i = 0; i < numFiles; i++) { + assertFalse(temporaryFiles.get(i).exists()); + assertTrue(outputFiles.get(i).exists()); + } + } + + /** + * Output files are copied to the destination location with the correct names and contents. + */ + @Test + public void testCopyToOutputFiles() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + SimpleSink.SimpleWriteOperation writeOp = buildWriteOperation(); + + List inputFilenames = Arrays.asList("input-3", "input-2", "input-1"); + List inputContents = Arrays.asList("3", "2", "1"); + List expectedOutputFilenames = Arrays.asList( + "output-00002-of-00003.test", "output-00001-of-00003.test", "output-00000-of-00003.test"); + + List inputFilePaths = new ArrayList<>(); + List expectedOutputPaths = new ArrayList<>(); + + for (int i = 0; i < inputFilenames.size(); i++) { + // Generate output paths. 
+ File outputFile = tmpFolder.newFile(expectedOutputFilenames.get(i)); + expectedOutputPaths.add(outputFile.toString()); + + // Generate and write to input paths. + File inputTmpFile = tmpFolder.newFile(inputFilenames.get(i)); + List lines = Arrays.asList(inputContents.get(i)); + writeFile(lines, inputTmpFile); + inputFilePaths.add(inputTmpFile.toString()); + } + + // Copy input files to output files. + List actual = writeOp.copyToOutputFiles(inputFilePaths, options); + + // Assert that the expected paths are returned. + assertThat(expectedOutputPaths, containsInAnyOrder(actual.toArray())); + + // Assert that the contents were copied. + for (int i = 0; i < expectedOutputPaths.size(); i++) { + assertFileContains(Arrays.asList(inputContents.get(i)), expectedOutputPaths.get(i)); + } + } + + /** + * Output filenames use the supplied naming template. + */ + @Test + public void testGenerateOutputFilenamesWithTemplate() { + List expected; + List actual; + SimpleSink sink = new SimpleSink(getBaseOutputFilename(), "test", ".SS.of.NN"); + SimpleSink.SimpleWriteOperation writeOp = new SimpleSink.SimpleWriteOperation(sink); + + expected = Arrays.asList(appendToTempFolder("output.00.of.03.test"), + appendToTempFolder("output.01.of.03.test"), appendToTempFolder("output.02.of.03.test")); + actual = writeOp.generateDestinationFilenames(3); + assertEquals(expected, actual); + + expected = Arrays.asList(appendToTempFolder("output.00.of.01.test")); + actual = writeOp.generateDestinationFilenames(1); + assertEquals(expected, actual); + + expected = new ArrayList<>(); + actual = writeOp.generateDestinationFilenames(0); + assertEquals(expected, actual); + + // Also validate that we handle the case where the user specified "." that we do + // not prefix an additional "." making "..test" + sink = new SimpleSink(getBaseOutputFilename(), ".test", ".SS.of.NN"); + writeOp = new SimpleSink.SimpleWriteOperation(sink); + expected = Arrays.asList(appendToTempFolder("output.00.of.03.test"), + appendToTempFolder("output.01.of.03.test"), appendToTempFolder("output.02.of.03.test")); + actual = writeOp.generateDestinationFilenames(3); + assertEquals(expected, actual); + + expected = Arrays.asList(appendToTempFolder("output.00.of.01.test")); + actual = writeOp.generateDestinationFilenames(1); + assertEquals(expected, actual); + + expected = new ArrayList<>(); + actual = writeOp.generateDestinationFilenames(0); + assertEquals(expected, actual); + } + + /** + * Output filenames are generated correctly when an extension is supplied. + */ + @Test + public void testGenerateOutputFilenamesWithExtension() { + List expected; + List actual; + SimpleSink.SimpleWriteOperation writeOp = buildWriteOperation(); + + expected = Arrays.asList( + appendToTempFolder("output-00000-of-00003.test"), + appendToTempFolder("output-00001-of-00003.test"), + appendToTempFolder("output-00002-of-00003.test")); + actual = writeOp.generateDestinationFilenames(3); + assertEquals(expected, actual); + + expected = Arrays.asList(appendToTempFolder("output-00000-of-00001.test")); + actual = writeOp.generateDestinationFilenames(1); + assertEquals(expected, actual); + + expected = new ArrayList<>(); + actual = writeOp.generateDestinationFilenames(0); + assertEquals(expected, actual); + } + + /** + * Output filenames are generated correctly when an extension is not supplied. 
+ */ + @Test + public void testGenerateOutputFilenamesWithoutExtension() { + List expected; + List actual; + SimpleSink sink = new SimpleSink(appendToTempFolder(baseOutputFilename), ""); + SimpleSink.SimpleWriteOperation writeOp = new SimpleSink.SimpleWriteOperation(sink); + + expected = Arrays.asList(appendToTempFolder("output-00000-of-00003"), + appendToTempFolder("output-00001-of-00003"), appendToTempFolder("output-00002-of-00003")); + actual = writeOp.generateDestinationFilenames(3); + assertEquals(expected, actual); + + expected = Arrays.asList(appendToTempFolder("output-00000-of-00001")); + actual = writeOp.generateDestinationFilenames(1); + assertEquals(expected, actual); + + expected = new ArrayList<>(); + actual = writeOp.generateDestinationFilenames(0); + assertEquals(expected, actual); + } + + /** + * A simple FileBasedSink that writes String values as lines with header and footer lines. + */ + private static final class SimpleSink extends FileBasedSink { + public SimpleSink(String baseOutputFilename, String extension) { + super(baseOutputFilename, extension); + } + + public SimpleSink(String baseOutputFilename, String extension, String fileNamingTemplate) { + super(baseOutputFilename, extension, fileNamingTemplate); + } + + @Override + public SimpleWriteOperation createWriteOperation(PipelineOptions options) { + return new SimpleWriteOperation(this); + } + + private static final class SimpleWriteOperation extends FileBasedWriteOperation { + public SimpleWriteOperation( + SimpleSink sink, String tempOutputFilename, TemporaryFileRetention retentionPolicy) { + super(sink, tempOutputFilename, retentionPolicy); + } + + public SimpleWriteOperation(SimpleSink sink, String tempOutputFilename) { + super(sink, tempOutputFilename); + } + + public SimpleWriteOperation(SimpleSink sink) { + super(sink); + } + + @Override + public SimpleWriter createWriter(PipelineOptions options) throws Exception { + return new SimpleWriter(this); + } + } + + private static final class SimpleWriter extends FileBasedWriter { + static final String HEADER = "header"; + static final String FOOTER = "footer"; + + private WritableByteChannel channel; + + public SimpleWriter(SimpleWriteOperation writeOperation) { + super(writeOperation); + } + + private static ByteBuffer wrap(String value) throws Exception { + return ByteBuffer.wrap((value + "\n").getBytes("UTF-8")); + } + + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + this.channel = channel; + } + + @Override + protected void writeHeader() throws Exception { + channel.write(wrap(HEADER)); + } + + @Override + protected void writeFooter() throws Exception { + channel.write(wrap(FOOTER)); + } + + @Override + public void write(String value) throws Exception { + channel.write(wrap(value)); + } + } + } + + /** + * Build a SimpleSink with default options. + */ + private SimpleSink buildSink() { + return new SimpleSink(getBaseOutputFilename(), "test"); + } + + /** + * Build a SimpleWriteOperation with default options and the given file retention policy. + */ + private SimpleSink.SimpleWriteOperation buildWriteOperation( + TemporaryFileRetention fileRetention) { + SimpleSink sink = buildSink(); + return new SimpleSink.SimpleWriteOperation(sink, getBaseTempFilename(), fileRetention); + } + + /** + * Build a SimpleWriteOperation with default options and the given base temporary filename. 
+ */ + private SimpleSink.SimpleWriteOperation buildWriteOperation(String baseTemporaryFilename) { + SimpleSink sink = buildSink(); + return new SimpleSink.SimpleWriteOperation(sink, appendToTempFolder(baseTemporaryFilename)); + } + + /** + * Build a write operation with the default options for it and its parent sink. + */ + private SimpleSink.SimpleWriteOperation buildWriteOperation() { + SimpleSink sink = buildSink(); + return new SimpleSink.SimpleWriteOperation( + sink, getBaseTempFilename(), TemporaryFileRetention.REMOVE); + } + + /** + * Build a writer with the default options for its parent write operation and sink. + */ + private SimpleSink.SimpleWriter buildWriter() { + SimpleSink.SimpleWriteOperation writeOp = buildWriteOperation(TemporaryFileRetention.REMOVE); + return new SimpleSink.SimpleWriter(writeOp); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSourceTest.java new file mode 100644 index 000000000000..7cf4398393d9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSourceTest.java @@ -0,0 +1,914 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionFails; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.readFromSource; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.FileBasedSource.FileBasedReader; +import com.google.cloud.dataflow.sdk.io.Source.Reader; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.IOChannelFactory; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +import java.io.ByteArrayOutputStream; +import 
java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Random; + +/** + * Tests code common to all file-based sources. + */ +@RunWith(JUnit4.class) +public class FileBasedSourceTest { + + Random random = new Random(0L); + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + /** + * If {@code splitHeader} is null, this is just a simple line-based reader. Otherwise, the file is + * considered to consist of blocks beginning with {@code splitHeader}. The header itself is not + * returned as a record. The first record after the header is considered to be a split point. + * + *
<p>
    E.g., if {@code splitHeader} is "h" and the lines of the file are: h, a, b, h, h, c, then + * the records in this source are a,b,c, and records a and c are split points. + */ + static class TestFileBasedSource extends FileBasedSource { + + final String splitHeader; + + public TestFileBasedSource(String fileOrPattern, long minBundleSize, String splitHeader) { + super(fileOrPattern, minBundleSize); + this.splitHeader = splitHeader; + } + + public TestFileBasedSource( + String fileOrPattern, + long minBundleSize, + long startOffset, + long endOffset, + String splitHeader) { + super(fileOrPattern, minBundleSize, startOffset, endOffset); + this.splitHeader = splitHeader; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return StringUtf8Coder.of(); + } + + @Override + protected FileBasedSource createForSubrangeOfFile( + String fileName, long start, long end) { + return new TestFileBasedSource(fileName, getMinBundleSize(), start, end, splitHeader); + } + + @Override + protected FileBasedReader createSingleFileReader(PipelineOptions options) { + if (splitHeader == null) { + return new TestReader(this); + } else { + return new TestReaderWithSplits(this); + } + } + } + + /** + * A utility class that starts reading lines from a given offset in a file until EOF. + */ + private static class LineReader { + private ReadableByteChannel channel = null; + private long nextLineStart = 0; + private long currentLineStart = 0; + private final ByteBuffer buf; + private static final int BUF_SIZE = 1024; + private String currentValue = null; + + public LineReader(ReadableByteChannel channel) throws IOException { + buf = ByteBuffer.allocate(BUF_SIZE); + buf.flip(); + + boolean removeLine = false; + // If we are not at the beginning of a line, we should ignore the current line. + if (channel instanceof SeekableByteChannel) { + SeekableByteChannel seekChannel = (SeekableByteChannel) channel; + if (seekChannel.position() > 0) { + // Start from one character back and read till we find a new line. + seekChannel.position(seekChannel.position() - 1); + removeLine = true; + } + nextLineStart = seekChannel.position(); + } + this.channel = channel; + if (removeLine) { + nextLineStart += readNextLine(new ByteArrayOutputStream()); + } + } + + private int readNextLine(ByteArrayOutputStream out) throws IOException { + int byteCount = 0; + while (true) { + if (!buf.hasRemaining()) { + buf.clear(); + int read = channel.read(buf); + if (read < 0) { + break; + } + buf.flip(); + } + byte b = buf.get(); + byteCount++; + if (b == '\n') { + break; + } + out.write(b); + } + return byteCount; + } + + public boolean readNextLine() throws IOException { + currentLineStart = nextLineStart; + + ByteArrayOutputStream buf = new ByteArrayOutputStream(); + int offsetAdjustment = readNextLine(buf); + if (offsetAdjustment == 0) { + // EOF + return false; + } + nextLineStart += offsetAdjustment; + // When running on Windows, each line obtained from 'readNextLine()' will end with a '\r' + // since we use '\n' as the line boundary of the reader. So we trim it off here. 
+ currentValue = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), buf.toByteArray()).trim(); + return true; + } + + public String getCurrent() { + return currentValue; + } + + public long getCurrentLineStart() { + return currentLineStart; + } + } + + /** + * A reader that can read lines of text from a {@link TestFileBasedSource}. This reader does not + * consider {@code splitHeader} defined by {@code TestFileBasedSource} hence every line can be the + * first line of a split. + */ + private static class TestReader extends FileBasedReader { + private LineReader lineReader = null; + + public TestReader(TestFileBasedSource source) { + super(source); + } + + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + this.lineReader = new LineReader(channel); + } + + @Override + protected boolean readNextRecord() throws IOException { + return lineReader.readNextLine(); + } + + @Override + protected boolean isAtSplitPoint() { + return true; + } + + @Override + protected long getCurrentOffset() { + return lineReader.getCurrentLineStart(); + } + + @Override + public String getCurrent() throws NoSuchElementException { + return lineReader.getCurrent(); + } + } + + /** + * A reader that can read lines of text from a {@link TestFileBasedSource}. This reader considers + * {@code splitHeader} defined by {@code TestFileBasedSource} hence only lines that immediately + * follow a {@code splitHeader} are split points. + */ + private static class TestReaderWithSplits extends FileBasedReader { + private LineReader lineReader; + private final String splitHeader; + private boolean foundFirstSplitPoint = false; + private boolean isAtSplitPoint = false; + private long currentOffset; + + public TestReaderWithSplits(TestFileBasedSource source) { + super(source); + this.splitHeader = source.splitHeader; + } + + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + this.lineReader = new LineReader(channel); + } + + @Override + protected boolean readNextRecord() throws IOException { + if (!foundFirstSplitPoint) { + while (!isAtSplitPoint) { + if (!readNextRecordInternal()) { + return false; + } + } + foundFirstSplitPoint = true; + return true; + } + return readNextRecordInternal(); + } + + private boolean readNextRecordInternal() throws IOException { + isAtSplitPoint = false; + if (!lineReader.readNextLine()) { + return false; + } + currentOffset = lineReader.getCurrentLineStart(); + while (getCurrent().equals(splitHeader)) { + currentOffset = lineReader.getCurrentLineStart(); + if (!lineReader.readNextLine()) { + return false; + } + isAtSplitPoint = true; + } + return true; + } + + @Override + protected boolean isAtSplitPoint() { + return isAtSplitPoint; + } + + @Override + protected long getCurrentOffset() { + return currentOffset; + } + + @Override + public String getCurrent() throws NoSuchElementException { + return lineReader.getCurrent(); + } + } + + public File createFileWithData(String fileName, List data) throws IOException { + File file = tempFolder.newFile(fileName); + Files.write(file.toPath(), data, StandardCharsets.UTF_8); + return file; + } + + private String createRandomString(int length) { + char[] chars = "abcdefghijklmnopqrstuvwxyz".toCharArray(); + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < length; i++) { + builder.append(chars[random.nextInt(chars.length)]); + } + return builder.toString(); + } + + public List createStringDataset(int dataItemLength, int numItems) { + List list = new ArrayList(); + for 
(int i = 0; i < numItems; i++) { + list.add(createRandomString(dataItemLength)); + } + return list; + } + + @Test + public void testFullyReadSingleFile() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List data = createStringDataset(3, 50); + + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 64, null); + assertEquals(data, readFromSource(source, options)); + } + + @Test + public void testFullyReadFilePattern() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List data1 = createStringDataset(3, 50); + File file1 = createFileWithData("file1", data1); + + List data2 = createStringDataset(3, 50); + createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 50); + createFileWithData("file3", data3); + + List data4 = createStringDataset(3, 50); + createFileWithData("otherfile", data4); + + TestFileBasedSource source = + new TestFileBasedSource(new File(file1.getParent(), "file*").getPath(), 64, null); + List expectedResults = new ArrayList(); + expectedResults.addAll(data1); + expectedResults.addAll(data2); + expectedResults.addAll(data3); + assertThat(expectedResults, containsInAnyOrder(readFromSource(source, options).toArray())); + } + + @Test + public void testCloseUnstartedFilePatternReader() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List data1 = createStringDataset(3, 50); + File file1 = createFileWithData("file1", data1); + + List data2 = createStringDataset(3, 50); + createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 50); + createFileWithData("file3", data3); + + List data4 = createStringDataset(3, 50); + createFileWithData("otherfile", data4); + + TestFileBasedSource source = + new TestFileBasedSource(new File(file1.getParent(), "file*").getPath(), 64, null); + Reader reader = source.createReader(options); + // Closing an unstarted FilePatternReader should not throw an exception. + try { + reader.close(); + } catch (Exception e) { + fail("Closing an unstarted FilePatternReader should not throw an exception"); + } + } + + @Test + public void testSplittingUsingFullThreadPool() throws Exception { + int numFiles = FileBasedSource.THREAD_POOL_SIZE * 5; + File file0 = null; + for (int i = 0; i < numFiles; i++) { + List data = createStringDataset(3, 1000); + File file = createFileWithData("file" + i, data); + if (i == 0) { + file0 = file; + } + } + + TestFileBasedSource source = + new TestFileBasedSource(file0.getParent() + "/" + "file*", Long.MAX_VALUE, null); + List> splits = source.splitIntoBundles(Long.MAX_VALUE, null); + assertEquals(numFiles, splits.size()); + } + + @Test + public void testFractionConsumedWhenReadingFilepattern() throws IOException { + List data1 = createStringDataset(3, 1000); + File file1 = createFileWithData("file1", data1); + + List data2 = createStringDataset(3, 1000); + createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 1000); + createFileWithData("file3", data3); + + TestFileBasedSource source = + new TestFileBasedSource(file1.getParent() + "/" + "file*", 1024, null); + try (BoundedSource.BoundedReader reader = source.createReader(null)) { + double lastFractionConsumed = 0.0; + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertTrue(reader.start()); + assertTrue(reader.advance()); + assertTrue(reader.advance()); + // We're inside the first file. Should be in [0, 1/3). 
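+ // (Each of the three matched files holds 1000 records of equal length, so each file accounts for roughly a third of the pattern.)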
+ assertTrue(reader.getFractionConsumed() > 0.0); + assertTrue(reader.getFractionConsumed() < 1.0 / 3.0); + while (reader.advance()) { + double fractionConsumed = reader.getFractionConsumed(); + assertTrue(fractionConsumed > lastFractionConsumed); + lastFractionConsumed = fractionConsumed; + } + assertTrue(reader.getFractionConsumed() < 1.0); + } + } + + @Test + public void testFullyReadFilePatternFirstRecordEmpty() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + File file1 = createFileWithData("file1", new ArrayList()); + + IOChannelFactory mockIOFactory = Mockito.mock(IOChannelFactory.class); + String parent = file1.getParent(); + String pattern = "mocked://test"; + when(mockIOFactory.match(pattern)) + .thenReturn( + ImmutableList.of( + new File(parent, "file1").getPath(), + new File(parent, "file2").getPath(), + new File(parent, "file3").getPath())); + IOChannelUtils.setIOFactory("mocked", mockIOFactory); + + List data2 = createStringDataset(3, 50); + createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 50); + createFileWithData("file3", data3); + + List data4 = createStringDataset(3, 50); + createFileWithData("otherfile", data4); + + TestFileBasedSource source = new TestFileBasedSource(pattern, 64, null); + + List expectedResults = new ArrayList(); + expectedResults.addAll(data2); + expectedResults.addAll(data3); + assertThat(expectedResults, containsInAnyOrder(readFromSource(source, options).toArray())); + } + + @Test + public void testReadRangeAtStart() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List data = createStringDataset(3, 50); + + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source1 = new TestFileBasedSource(file.getPath(), 64, 0, 25, null); + TestFileBasedSource source2 = + new TestFileBasedSource(file.getPath(), 64, 25, Long.MAX_VALUE, null); + + List results = new ArrayList(); + results.addAll(readFromSource(source1, options)); + results.addAll(readFromSource(source2, options)); + assertThat(data, containsInAnyOrder(results.toArray())); + } + + @Test + public void testReadEverythingFromFileWithSplits() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + String header = ""; + List data = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + data.add(header); + data.addAll(createStringDataset(3, 9)); + } + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 64, header); + + List expectedResults = new ArrayList(); + expectedResults.addAll(data); + // Remove all occurrences of header from expected results. 
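+ // (The reader treats header lines purely as split markers and never returns them as records.)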
+ expectedResults.removeAll(Arrays.asList(header)); + + assertEquals(expectedResults, readFromSource(source, options)); + } + + @Test + public void testReadRangeFromFileWithSplitsFromStart() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + String header = ""; + List data = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + data.add(header); + data.addAll(createStringDataset(3, 9)); + } + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source1 = new TestFileBasedSource(file.getPath(), 64, 0, 60, header); + TestFileBasedSource source2 = + new TestFileBasedSource(file.getPath(), 64, 60, Long.MAX_VALUE, header); + + List expectedResults = new ArrayList(); + expectedResults.addAll(data); + // Remove all occurrences of header from expected results. + expectedResults.removeAll(Arrays.asList(header)); + + List results = new ArrayList<>(); + results.addAll(readFromSource(source1, options)); + results.addAll(readFromSource(source2, options)); + + assertThat(expectedResults, containsInAnyOrder(results.toArray())); + } + + @Test + public void testReadRangeFromFileWithSplitsFromMiddle() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + String header = ""; + List data = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + data.add(header); + data.addAll(createStringDataset(3, 9)); + } + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source1 = new TestFileBasedSource(file.getPath(), 64, 0, 42, header); + TestFileBasedSource source2 = new TestFileBasedSource(file.getPath(), 64, 42, 112, header); + TestFileBasedSource source3 = + new TestFileBasedSource(file.getPath(), 64, 112, Long.MAX_VALUE, header); + + List expectedResults = new ArrayList(); + + expectedResults.addAll(data); + // Remove all occurrences of header from expected results. + expectedResults.removeAll(Arrays.asList(header)); + + List results = new ArrayList<>(); + results.addAll(readFromSource(source1, options)); + results.addAll(readFromSource(source2, options)); + results.addAll(readFromSource(source3, options)); + + assertThat(expectedResults, containsInAnyOrder(results.toArray())); + } + + @Test + public void testReadFileWithSplitsWithEmptyRange() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + String header = ""; + List data = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + data.add(header); + data.addAll(createStringDataset(3, 9)); + } + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source1 = new TestFileBasedSource(file.getPath(), 64, 0, 42, header); + TestFileBasedSource source2 = new TestFileBasedSource(file.getPath(), 64, 42, 62, header); + TestFileBasedSource source3 = + new TestFileBasedSource(file.getPath(), 64, 62, Long.MAX_VALUE, header); + + List expectedResults = new ArrayList(); + + expectedResults.addAll(data); + // Remove all occurrences of header from expected results. 
+ expectedResults.removeAll(Arrays.asList(header)); + + List results = new ArrayList<>(); + results.addAll(readFromSource(source1, options)); + results.addAll(readFromSource(source2, options)); + results.addAll(readFromSource(source3, options)); + + assertThat(expectedResults, containsInAnyOrder(results.toArray())); + } + + @Test + public void testReadRangeFromFileWithSplitsFromMiddleOfHeader() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + String header = ""; + List data = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + data.add(header); + data.addAll(createStringDataset(3, 9)); + } + String fileName = "file"; + File file = createFileWithData(fileName, data); + + List expectedResults = new ArrayList(); + expectedResults.addAll(data.subList(10, data.size())); + // Remove all occurrences of header from expected results. + expectedResults.removeAll(Arrays.asList(header)); + + // Split starts after "<" of the header + TestFileBasedSource source = + new TestFileBasedSource(file.getPath(), 64, 1, Long.MAX_VALUE, header); + assertThat(expectedResults, containsInAnyOrder(readFromSource(source, options).toArray())); + + // Split starts after "" of the header + source = new TestFileBasedSource(file.getPath(), 64, 3, Long.MAX_VALUE, header); + assertThat(expectedResults, containsInAnyOrder(readFromSource(source, options).toArray())); + } + + @Test + public void testReadRangeAtMiddle() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List data = createStringDataset(3, 50); + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source1 = new TestFileBasedSource(file.getPath(), 64, 0, 52, null); + TestFileBasedSource source2 = new TestFileBasedSource(file.getPath(), 64, 52, 72, null); + TestFileBasedSource source3 = + new TestFileBasedSource(file.getPath(), 64, 72, Long.MAX_VALUE, null); + + List results = new ArrayList<>(); + results.addAll(readFromSource(source1, options)); + results.addAll(readFromSource(source2, options)); + results.addAll(readFromSource(source3, options)); + + assertThat(data, containsInAnyOrder(results.toArray())); + } + + @Test + public void testReadRangeAtEnd() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + List data = createStringDataset(3, 50); + + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source1 = new TestFileBasedSource(file.getPath(), 64, 0, 162, null); + TestFileBasedSource source2 = + new TestFileBasedSource(file.getPath(), 1024, 162, Long.MAX_VALUE, null); + + List results = new ArrayList<>(); + results.addAll(readFromSource(source1, options)); + results.addAll(readFromSource(source2, options)); + + assertThat(data, containsInAnyOrder(results.toArray())); + } + + @Test + public void testReadAllSplitsOfSingleFile() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + List data = createStringDataset(3, 50); + + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 16, null); + + List> sources = source.splitIntoBundles(32, null); + + // Not a trivial split. 
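+ // (The file holds 50 four-byte lines, so a 32-byte desired bundle size with a 16-byte minimum bundle size should produce several bundles.)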
+ assertTrue(sources.size() > 1); + + List results = new ArrayList(); + for (BoundedSource split : sources) { + results.addAll(readFromSource(split, options)); + } + + assertThat(data, containsInAnyOrder(results.toArray())); + } + + @Test + public void testDataflowFile() throws IOException { + Pipeline p = TestPipeline.create(); + List data = createStringDataset(3, 50); + + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 64, null); + PCollection output = p.apply(Read.from(source).named("ReadFileData")); + + DataflowAssert.that(output).containsInAnyOrder(data); + p.run(); + } + + @Test + public void testDataflowFilePattern() throws IOException { + Pipeline p = TestPipeline.create(); + + List data1 = createStringDataset(3, 50); + File file1 = createFileWithData("file1", data1); + + List data2 = createStringDataset(3, 50); + createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 50); + createFileWithData("file3", data3); + + List data4 = createStringDataset(3, 50); + createFileWithData("otherfile", data4); + + TestFileBasedSource source = + new TestFileBasedSource(new File(file1.getParent(), "file*").getPath(), 64, null); + + PCollection output = p.apply(Read.from(source).named("ReadFileData")); + + List expectedResults = new ArrayList(); + expectedResults.addAll(data1); + expectedResults.addAll(data2); + expectedResults.addAll(data3); + + DataflowAssert.that(output).containsInAnyOrder(expectedResults); + p.run(); + } + + @Test + public void testEstimatedSizeOfFile() throws Exception { + List data = createStringDataset(3, 50); + String fileName = "file"; + File file = createFileWithData(fileName, data); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 64, null); + assertEquals(file.length(), source.getEstimatedSizeBytes(null)); + } + + @Test + public void testEstimatedSizeOfFilePattern() throws Exception { + List data1 = createStringDataset(3, 20); + File file1 = createFileWithData("file1", data1); + + List data2 = createStringDataset(3, 40); + File file2 = createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 30); + File file3 = createFileWithData("file3", data3); + + List data4 = createStringDataset(3, 45); + createFileWithData("otherfile", data4); + + List data5 = createStringDataset(3, 53); + createFileWithData("anotherfile", data5); + + TestFileBasedSource source = + new TestFileBasedSource(new File(file1.getParent(), "file*").getPath(), 64, null); + + // Estimated size of the file pattern based source should be the total size of files that the + // corresponding pattern is expanded into. + assertEquals( + file1.length() + file2.length() + file3.length(), source.getEstimatedSizeBytes(null)); + } + + @Test + public void testEstimatedSizeOfFilePatternAllThreads() throws Exception { + File file0 = null; + int numFiles = FileBasedSource.THREAD_POOL_SIZE * 5; + long totalSize = 0; + for (int i = 0; i < numFiles; i++) { + List data = createStringDataset(3, 20); + File file = createFileWithData("file" + i, data); + if (i == 0) { + file0 = file; + } + totalSize += file.length(); + } + + TestFileBasedSource source = + new TestFileBasedSource(new File(file0.getParent(), "file*").getPath(), 64, null); + + // Since all files are of equal size, sampling should produce the exact result. 
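+ // (Each file is written with 20 records of identical length, so even an estimate based on a sample of the files matches the exact total.)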
+ assertEquals(totalSize, source.getEstimatedSizeBytes(null)); + } + + @Test + public void testEstimatedSizeOfFilePatternThroughSamplingEqualSize() throws Exception { + // When all files are of equal size, we should get the exact size. + int numFilesToTest = FileBasedSource.MAX_NUMBER_OF_FILES_FOR_AN_EXACT_STAT * 2; + File file0 = null; + for (int i = 0; i < numFilesToTest; i++) { + List data = createStringDataset(3, 20); + File file = createFileWithData("file" + i, data); + if (i == 0) { + file0 = file; + } + } + + long actualTotalSize = file0.length() * numFilesToTest; + TestFileBasedSource source = + new TestFileBasedSource(new File(file0.getParent(), "file*").getPath(), 64, null); + assertEquals(actualTotalSize, source.getEstimatedSizeBytes(null)); + } + + @Test + public void testEstimatedSizeOfFilePatternThroughSamplingDifferentSizes() throws Exception { + float tolerableError = 0.2f; + int numFilesToTest = FileBasedSource.MAX_NUMBER_OF_FILES_FOR_AN_EXACT_STAT * 2; + File file0 = null; + + // Keeping sizes of files close to each other to make sure that the test passes reliably. + Random rand = new Random(System.currentTimeMillis()); + int dataSizeBase = 100; + int dataSizeDelta = 10; + + long actualTotalSize = 0; + for (int i = 0; i < numFilesToTest; i++) { + List data = createStringDataset( + 3, (int) (dataSizeBase + rand.nextFloat() * dataSizeDelta * 2 - dataSizeDelta)); + File file = createFileWithData("file" + i, data); + if (i == 0) { + file0 = file; + } + actualTotalSize += file.length(); + } + + TestFileBasedSource source = + new TestFileBasedSource(new File(file0.getParent(), "file*").getPath(), 64, null); + assertEquals((double) actualTotalSize, (double) source.getEstimatedSizeBytes(null), + actualTotalSize * tolerableError); + } + + @Test + public void testReadAllSplitsOfFilePattern() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + List data1 = createStringDataset(3, 50); + File file1 = createFileWithData("file1", data1); + + List data2 = createStringDataset(3, 50); + createFileWithData("file2", data2); + + List data3 = createStringDataset(3, 50); + createFileWithData("file3", data3); + + List data4 = createStringDataset(3, 50); + createFileWithData("otherfile", data4); + + TestFileBasedSource source = + new TestFileBasedSource(new File(file1.getParent(), "file*").getPath(), 64, null); + List> sources = source.splitIntoBundles(512, null); + + // Not a trivial split. + assertTrue(sources.size() > 1); + + List results = new ArrayList(); + for (BoundedSource split : sources) { + results.addAll(readFromSource(split, options)); + } + + List expectedResults = new ArrayList(); + expectedResults.addAll(data1); + expectedResults.addAll(data2); + expectedResults.addAll(data3); + + assertThat(expectedResults, containsInAnyOrder(results.toArray())); + } + + @Test + public void testSplitAtFraction() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + File file = createFileWithData("file", createStringDataset(3, 100)); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 1, 0, file.length(), null); + // Shouldn't be able to split while unstarted. 
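+ // (In these assertions, the second argument is the number of records consumed before the split attempt and the third is the fraction to split at.)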
+ assertSplitAtFractionFails(source, 0, 0.7, options); + assertSplitAtFractionSucceedsAndConsistent(source, 1, 0.7, options); + assertSplitAtFractionSucceedsAndConsistent(source, 30, 0.7, options); + assertSplitAtFractionFails(source, 0, 0.0, options); + assertSplitAtFractionFails(source, 70, 0.3, options); + assertSplitAtFractionFails(source, 100, 1.0, options); + assertSplitAtFractionFails(source, 100, 0.99, options); + assertSplitAtFractionSucceedsAndConsistent(source, 100, 0.995, options); + } + + @Test + public void testSplitAtFractionExhaustive() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + // Smaller file for exhaustive testing. + File file = createFileWithData("file", createStringDataset(3, 20)); + + TestFileBasedSource source = new TestFileBasedSource(file.getPath(), 1, 0, file.length(), null); + assertSplitAtFractionExhaustive(source, options); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/OffsetBasedSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/OffsetBasedSourceTest.java new file mode 100644 index 000000000000..bdbd00ecb3fa --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/OffsetBasedSourceTest.java @@ -0,0 +1,278 @@ +/* + * Copyright (C) 2014 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.readFromSource; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Tests code common to all offset-based sources. + */ +@RunWith(JUnit4.class) +public class OffsetBasedSourceTest { + // An offset-based source with 4 bytes per offset that yields its own current offset + // and rounds the start and end offset to the nearest multiple of a given number, + // e.g. reading [13, 48) with granularity 10 gives records with values [20, 50). 
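+ // (Concretely: the reader skips forward from the start offset to the next multiple of the granularity before emitting its first record, and only offsets that are exact multiples of the granularity are split points.)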
+ private static class CoarseRangeSource extends OffsetBasedSource { + private long granularity; + + public CoarseRangeSource( + long startOffset, long endOffset, long minBundleSize, long granularity) { + super(startOffset, endOffset, minBundleSize); + this.granularity = granularity; + } + + @Override + public OffsetBasedSource createSourceForSubrange(long start, long end) { + return new CoarseRangeSource(start, end, getMinBundleSize(), granularity); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return BigEndianIntegerCoder.of(); + } + + @Override + public long getBytesPerOffset() { + return 4; + } + + @Override + public long getMaxEndOffset(PipelineOptions options) { + return getEndOffset(); + } + + @Override + public BoundedReader createReader(PipelineOptions options) throws IOException { + return new CoarseRangeReader(this); + } + } + + private static class CoarseRangeReader + extends OffsetBasedSource.OffsetBasedReader { + private long current = -1; + private long granularity; + + public CoarseRangeReader(CoarseRangeSource source) { + super(source); + this.granularity = source.granularity; + } + + @Override + protected long getCurrentOffset() { + return current; + } + + @Override + public boolean startImpl() throws IOException { + current = getCurrentSource().getStartOffset(); + while (current % granularity != 0) { + ++current; + } + return true; + } + + @Override + public boolean advanceImpl() throws IOException { + ++current; + return true; + } + + @Override + public Integer getCurrent() throws NoSuchElementException { + return (int) current; + } + + @Override + public boolean isAtSplitPoint() { + return current % granularity == 0; + } + + @Override + public void close() throws IOException { } + } + + public static void assertSplitsAre(List> splits, + long[] expectedBoundaries) { + assertEquals(splits.size(), expectedBoundaries.length - 1); + int i = 0; + for (OffsetBasedSource split : splits) { + assertEquals(split.getStartOffset(), expectedBoundaries[i]); + assertEquals(split.getEndOffset(), expectedBoundaries[i + 1]); + i++; + } + } + + @Test + public void testSplitPositionsZeroStart() throws Exception { + long start = 0; + long end = 1000; + long minBundleSize = 50; + CoarseRangeSource testSource = new CoarseRangeSource(start, end, minBundleSize, 1); + long[] boundaries = {0, 150, 300, 450, 600, 750, 900, 1000}; + assertSplitsAre( + testSource.splitIntoBundles(150 * testSource.getBytesPerOffset(), null), + boundaries); + } + + @Test + public void testSplitPositionsNonZeroStart() throws Exception { + long start = 300; + long end = 1000; + long minBundleSize = 50; + CoarseRangeSource testSource = new CoarseRangeSource(start, end, minBundleSize, 1); + long[] boundaries = {300, 450, 600, 750, 900, 1000}; + assertSplitsAre( + testSource.splitIntoBundles(150 * testSource.getBytesPerOffset(), null), + boundaries); + } + + @Test + public void testEstimatedSizeBytes() throws Exception { + long start = 300; + long end = 1000; + long minBundleSize = 150; + CoarseRangeSource testSource = new CoarseRangeSource(start, end, minBundleSize, 1); + PipelineOptions options = PipelineOptionsFactory.create(); + assertEquals( + (end - start) * testSource.getBytesPerOffset(), testSource.getEstimatedSizeBytes(options)); + } + + @Test + public void testMinBundleSize() throws Exception { + long start = 300; + long end = 1000; + long 
minBundleSize = 150; + CoarseRangeSource testSource = new CoarseRangeSource(start, end, minBundleSize, 1); + long[] boundaries = {300, 450, 600, 750, 1000}; + assertSplitsAre( + testSource.splitIntoBundles(100 * testSource.getBytesPerOffset(), null), + boundaries); + } + + @Test + public void testSplitPositionsCollapseEndBundle() throws Exception { + long start = 0; + long end = 1000; + long minBundleSize = 50; + CoarseRangeSource testSource = new CoarseRangeSource(start, end, minBundleSize, 1); + // Last 10 bytes should collapse to the previous bundle. + long[] boundaries = {0, 110, 220, 330, 440, 550, 660, 770, 880, 1000}; + assertSplitsAre( + testSource.splitIntoBundles(110 * testSource.getBytesPerOffset(), null), + boundaries); + } + + @Test + public void testReadingGranularityAndFractionConsumed() throws IOException { + // Tests that the reader correctly snaps to multiples of the given granularity + // (note: this is testing test code), and that getFractionConsumed works sensibly + // in the face of that. + PipelineOptions options = PipelineOptionsFactory.create(); + CoarseRangeSource source = new CoarseRangeSource(13, 35, 1, 10); + try (BoundedSource.BoundedReader reader = source.createReader(options)) { + List items = new ArrayList<>(); + + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertTrue(reader.start()); + do { + Double fraction = reader.getFractionConsumed(); + assertNotNull(fraction); + assertTrue(fraction.toString(), fraction > 0.0); + assertTrue(fraction.toString(), fraction <= 1.0); + items.add(reader.getCurrent()); + } while (reader.advance()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + + assertEquals(20, items.size()); + assertEquals(20, items.get(0).intValue()); + assertEquals(39, items.get(items.size() - 1).intValue()); + + source = new CoarseRangeSource(13, 17, 1, 10); + } + try (BoundedSource.BoundedReader reader = source.createReader(options)) { + assertFalse(reader.start()); + } + } + + @Test + public void testSplitAtFraction() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + CoarseRangeSource source = new CoarseRangeSource(13, 35, 1, 10); + try (CoarseRangeReader reader = (CoarseRangeReader) source.createReader(options)) { + List originalItems = new ArrayList<>(); + assertTrue(reader.start()); + originalItems.add(reader.getCurrent()); + assertTrue(reader.advance()); + originalItems.add(reader.getCurrent()); + assertTrue(reader.advance()); + originalItems.add(reader.getCurrent()); + assertTrue(reader.advance()); + originalItems.add(reader.getCurrent()); + assertNull(reader.splitAtFraction(0.0)); + assertNull(reader.splitAtFraction(reader.getFractionConsumed() - 0.1)); + + BoundedSource residual = reader.splitAtFraction(reader.getFractionConsumed() + 0.1); + BoundedSource primary = reader.getCurrentSource(); + List primaryItems = readFromSource(primary, options); + List residualItems = readFromSource(residual, options); + for (Integer item : residualItems) { + assertTrue(item > reader.getCurrentOffset()); + } + assertFalse(primaryItems.isEmpty()); + assertFalse(residualItems.isEmpty()); + assertTrue(primaryItems.get(primaryItems.size() - 1) <= residualItems.get(0)); + + while (reader.advance()) { + originalItems.add(reader.getCurrent()); + } + assertEquals(originalItems, primaryItems); + } + } + + @Test + public void testSplitAtFractionExhaustive() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + CoarseRangeSource original = new CoarseRangeSource(13, 35, 1, 10); + 
assertSplitAtFractionExhaustive(original, options); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java new file mode 100644 index 000000000000..8e7ad29bbe24 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.junit.Assert.assertEquals; + +import com.google.api.client.testing.http.FixedClock; +import com.google.api.client.util.Clock; +import com.google.api.services.pubsub.model.PubsubMessage; + +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.HashMap; + +import javax.annotation.Nullable; + +/** + * Tests for PubsubIO Read and Write transforms. + */ +@RunWith(JUnit4.class) +public class PubsubIOTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testPubsubIOGetName() { + assertEquals("PubsubIO.Read", + PubsubIO.Read.topic("projects/myproject/topics/mytopic").getName()); + assertEquals("PubsubIO.Write", + PubsubIO.Write.topic("projects/myproject/topics/mytopic").getName()); + assertEquals("ReadMyTopic", + PubsubIO.Read.named("ReadMyTopic").topic("projects/myproject/topics/mytopic").getName()); + assertEquals("WriteMyTopic", + PubsubIO.Write.named("WriteMyTopic").topic("projects/myproject/topics/mytopic").getName()); + } + + @Test + public void testTopicValidationSuccess() throws Exception { + PubsubIO.Read.topic("projects/my-project/topics/abc"); + PubsubIO.Read.topic("projects/my-project/topics/ABC"); + PubsubIO.Read.topic("projects/my-project/topics/AbC-DeF"); + PubsubIO.Read.topic("projects/my-project/topics/AbC-1234"); + PubsubIO.Read.topic("projects/my-project/topics/AbC-1234-_.~%+-_.~%+-_.~%+-abc"); + PubsubIO.Read.topic(new StringBuilder().append("projects/my-project/topics/A-really-long-one-") + .append("111111111111111111111111111111111111111111111111111111111111111111111111111111111") + .append("111111111111111111111111111111111111111111111111111111111111111111111111111111111") + .append("11111111111111111111111111111111111111111111111111111111111111111111111111") + .toString()); + } + + @Test + public void testTopicValidationBadCharacter() throws Exception { + thrown.expect(IllegalArgumentException.class); + PubsubIO.Read.topic("projects/my-project/topics/abc-*-abc"); + } + + @Test + public void testTopicValidationTooLong() throws Exception { + thrown.expect(IllegalArgumentException.class); + PubsubIO.Read.topic(new StringBuilder().append("projects/my-project/topics/A-really-long-one-") + .append("111111111111111111111111111111111111111111111111111111111111111111111111111111111") + 
.append("111111111111111111111111111111111111111111111111111111111111111111111111111111111") + .append("1111111111111111111111111111111111111111111111111111111111111111111111111111") + .toString()); + } + + /** + * Helper function that creates a {@link PubsubMessage} with the given timestamp registered as + * an attribute with the specified label. + * + *
<p>
    If {@code label} is {@code null}, then the attributes are {@code null}. + * + *
<p>
    Else, if {@code timestamp} is {@code null}, then attributes are present but have no key for + * the label. + */ + private static PubsubMessage messageWithTimestamp( + @Nullable String label, @Nullable String timestamp) { + PubsubMessage message = new PubsubMessage(); + if (label == null) { + message.setAttributes(null); + return message; + } + + message.setAttributes(new HashMap()); + + if (timestamp == null) { + return message; + } + + message.getAttributes().put(label, timestamp); + return message; + } + + /** + * Helper function that parses the given string to a timestamp through the PubSubIO plumbing. + */ + private static Instant parseTimestamp(@Nullable String timestamp) { + PubsubMessage message = messageWithTimestamp("mylabel", timestamp); + return PubsubIO.assignMessageTimestamp(message, "mylabel", Clock.SYSTEM); + } + + @Test + public void noTimestampLabelReturnsNow() { + final long time = 987654321L; + Instant timestamp = PubsubIO.assignMessageTimestamp( + messageWithTimestamp(null, null), null, new FixedClock(time)); + + assertEquals(new Instant(time), timestamp); + } + + @Test + public void timestampLabelWithNullAttributesThrowsError() { + PubsubMessage message = messageWithTimestamp(null, null); + thrown.expect(RuntimeException.class); + thrown.expectMessage("PubSub message is missing a timestamp in label: myLabel"); + + PubsubIO.assignMessageTimestamp(message, "myLabel", Clock.SYSTEM); + } + + @Test + public void timestampLabelSetWithMissingAttributeThrowsError() { + PubsubMessage message = messageWithTimestamp("notMyLabel", "ignored"); + thrown.expect(RuntimeException.class); + thrown.expectMessage("PubSub message is missing a timestamp in label: myLabel"); + + PubsubIO.assignMessageTimestamp(message, "myLabel", Clock.SYSTEM); + } + + @Test + public void timestampLabelParsesMillisecondsSinceEpoch() { + Long millis = 1446162101123L; + assertEquals(new Instant(millis), parseTimestamp(millis.toString())); + } + + @Test + public void timestampLabelParsesRfc3339Seconds() { + String rfc3339 = "2015-10-29T23:41:41Z"; + assertEquals(Instant.parse(rfc3339), parseTimestamp(rfc3339)); + } + + @Test + public void timestampLabelParsesRfc3339Tenths() { + String rfc3339tenths = "2015-10-29T23:41:41.1Z"; + assertEquals(Instant.parse(rfc3339tenths), parseTimestamp(rfc3339tenths)); + } + + @Test + public void timestampLabelParsesRfc3339Hundredths() { + String rfc3339hundredths = "2015-10-29T23:41:41.12Z"; + assertEquals(Instant.parse(rfc3339hundredths), parseTimestamp(rfc3339hundredths)); + } + + @Test + public void timestampLabelParsesRfc3339Millis() { + String rfc3339millis = "2015-10-29T23:41:41.123Z"; + assertEquals(Instant.parse(rfc3339millis), parseTimestamp(rfc3339millis)); + } + + @Test + public void timestampLabelParsesRfc3339Micros() { + String rfc3339micros = "2015-10-29T23:41:41.123456Z"; + assertEquals(Instant.parse(rfc3339micros), parseTimestamp(rfc3339micros)); + // Note: micros part 456/1000 is dropped. + assertEquals(Instant.parse("2015-10-29T23:41:41.123Z"), parseTimestamp(rfc3339micros)); + } + + @Test + public void timestampLabelParsesRfc3339MicrosRounding() { + String rfc3339micros = "2015-10-29T23:41:41.123999Z"; + assertEquals(Instant.parse(rfc3339micros), parseTimestamp(rfc3339micros)); + // Note: micros part 999/1000 is dropped, not rounded up. 
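+ // (Joda-Time Instant has millisecond resolution, so sub-millisecond digits are truncated rather than rounded.)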
+ assertEquals(Instant.parse("2015-10-29T23:41:41.123Z"), parseTimestamp(rfc3339micros)); + } + + @Test + public void timestampLabelWithInvalidFormatThrowsError() { + thrown.expect(NumberFormatException.class); + parseTimestamp("not-a-timestamp"); + } + + @Test + public void timestampLabelWithInvalidFormat2ThrowsError() { + thrown.expect(NumberFormatException.class); + parseTimestamp("null"); + } + + @Test + public void timestampLabelWithInvalidFormat3ThrowsError() { + thrown.expect(NumberFormatException.class); + parseTimestamp("2015-10"); + } + + @Test + public void timestampLabelParsesRfc3339WithSmallYear() { + // Google and JodaTime agree on dates after 1582-10-15, when the Gregorian Calendar was adopted + // This is therefore a "small year" until this difference is reconciled. + String rfc3339SmallYear = "1582-10-15T01:23:45.123Z"; + assertEquals(Instant.parse(rfc3339SmallYear), parseTimestamp(rfc3339SmallYear)); + } + + @Test + public void timestampLabelParsesRfc3339WithLargeYear() { + // Year 9999 in range. + String rfc3339LargeYear = "9999-10-29T23:41:41.123999Z"; + assertEquals(Instant.parse(rfc3339LargeYear), parseTimestamp(rfc3339LargeYear)); + } + + @Test + public void timestampLabelRfc3339WithTooLargeYearThrowsError() { + thrown.expect(NumberFormatException.class); + // Year 10000 out of range. + parseTimestamp("10000-10-29T23:41:41.123999Z"); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/ReadTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/ReadTest.java new file mode 100644 index 000000000000..8dc517a8e8f6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/ReadTest.java @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.UnboundedSource.CheckpointMark; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.List; + +import javax.annotation.Nullable; + +/** + * Tests for {@link Read}. 
+ */ +@RunWith(JUnit4.class) +public class ReadTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void failsWhenCustomBoundedSourceIsNotSerializable() { + thrown.expect(IllegalArgumentException.class); + Read.from(new NotSerializableBoundedSource()); + } + + @Test + public void succeedsWhenCustomBoundedSourceIsSerializable() { + Read.from(new SerializableBoundedSource()); + } + + @Test + public void failsWhenCustomUnboundedSourceIsNotSerializable() { + thrown.expect(IllegalArgumentException.class); + Read.from(new NotSerializableUnboundedSource()); + } + + @Test + public void succeedsWhenCustomUnboundedSourceIsSerializable() { + Read.from(new SerializableUnboundedSource()); + } + + private abstract static class CustomBoundedSource extends BoundedSource { + @Override + public List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + return null; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 0; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public BoundedReader createReader(PipelineOptions options) throws IOException { + return null; + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return null; + } + } + + private static class NotSerializableBoundedSource extends CustomBoundedSource { + @SuppressWarnings("unused") + private final NotSerializableClass notSerializableClass = new NotSerializableClass(); + } + + private static class SerializableBoundedSource extends CustomBoundedSource {} + + private abstract static class CustomUnboundedSource + extends UnboundedSource { + @Override + public List> generateInitialSplits( + int desiredNumSplits, PipelineOptions options) throws Exception { + return null; + } + + @Override + public UnboundedReader createReader( + PipelineOptions options, NoOpCheckpointMark checkpointMark) { + return null; + } + + @Override + @Nullable + public Coder getCheckpointMarkCoder() { + return null; + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return null; + } + } + + private static class NoOpCheckpointMark implements CheckpointMark { + @Override + public void finalizeCheckpoint() throws IOException {} + } + + private static class NotSerializableUnboundedSource extends CustomUnboundedSource { + @SuppressWarnings("unused") + private final NotSerializableClass notSerializableClass = new NotSerializableClass(); + } + + private static class SerializableUnboundedSource extends CustomUnboundedSource {} + + private static class NotSerializableClass {} +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java new file mode 100644 index 000000000000..0a8e3811085f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java @@ -0,0 +1,562 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.TestUtils.INTS_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_INTS_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES_ARRAY; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.TextualIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.TextIO.CompressionType; +import com.google.cloud.dataflow.sdk.io.TextIO.TextSource; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.SourceTestUtils; +import com.google.cloud.dataflow.sdk.testing.TestDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileOutputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.channels.FileChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.zip.GZIPOutputStream; + +/** + * Tests for TextIO Read and Write transforms. 
+ */ +@RunWith(JUnit4.class) +@SuppressWarnings("unchecked") +public class TextIOTest { + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule public ExpectedException expectedException = ExpectedException.none(); + + private GcsUtil buildMockGcsUtil() throws IOException { + GcsUtil mockGcsUtil = Mockito.mock(GcsUtil.class); + + // Any request to open gets a new bogus channel + Mockito + .when(mockGcsUtil.open(Mockito.any(GcsPath.class))) + .then(new Answer() { + @Override + public SeekableByteChannel answer(InvocationOnMock invocation) throws Throwable { + return FileChannel.open( + Files.createTempFile("channel-", ".tmp"), + StandardOpenOption.CREATE, StandardOpenOption.DELETE_ON_CLOSE); + } + }); + + // Any request for expansion returns a list containing the original GcsPath + // This is required to pass validation that occurs in TextIO during apply() + Mockito + .when(mockGcsUtil.expand(Mockito.any(GcsPath.class))) + .then(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return ImmutableList.of((GcsPath) invocation.getArguments()[0]); + } + }); + + return mockGcsUtil; + } + + private TestDataflowPipelineOptions buildTestPipelineOptions() { + TestDataflowPipelineOptions options = + PipelineOptionsFactory.as(TestDataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + return options; + } + + void runTestRead(T[] expected, Coder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + try (PrintStream writer = new PrintStream(new FileOutputStream(tmpFile))) { + for (T elem : expected) { + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, elem); + String line = new String(encodedElem); + writer.println(line); + } + } + + Pipeline p = TestPipeline.create(); + + TextIO.Read.Bound read; + if (coder.equals(StringUtf8Coder.of())) { + TextIO.Read.Bound readStrings = TextIO.Read.from(filename); + // T==String + read = (TextIO.Read.Bound) readStrings; + } else { + read = TextIO.Read.from(filename).withCoder(coder); + } + + PCollection output = p.apply(read); + + DataflowAssert.that(output).containsInAnyOrder(expected); + p.run(); + } + + @Test + public void testReadStrings() throws Exception { + runTestRead(LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testReadEmptyStrings() throws Exception { + runTestRead(NO_LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testReadInts() throws Exception { + runTestRead(INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testReadEmptyInts() throws Exception { + runTestRead(NO_INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testReadNulls() throws Exception { + runTestRead(new Void[]{ null, null, null }, VoidCoder.of()); + } + + @Test + public void testReadNamed() throws Exception { + String file = tmpFolder.newFile().getAbsolutePath(); + Pipeline p = TestPipeline.create(); + + { + PCollection output1 = + p.apply(TextIO.Read.from(file)); + assertEquals("TextIO.Read/Read.out", output1.getName()); + } + + { + PCollection output2 = + p.apply(TextIO.Read.named("MyRead").from(file)); + assertEquals("MyRead/Read.out", output2.getName()); + } + + { + PCollection output3 = + p.apply(TextIO.Read.from(file).named("HerRead")); + assertEquals("HerRead/Read.out", output3.getName()); + } + } + + void runTestWrite(T[] elems, Coder coder) throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + 
Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(elems)).withCoder(coder)); + + TextIO.Write.Bound write; + if (coder.equals(StringUtf8Coder.of())) { + TextIO.Write.Bound writeStrings = + TextIO.Write.to(filename).withoutSharding(); + // T==String + write = (TextIO.Write.Bound) writeStrings; + } else { + write = TextIO.Write.to(filename).withCoder(coder).withoutSharding(); + } + + input.apply(write); + + p.run(); + + List actual = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new FileReader(tmpFile))) { + for (;;) { + String line = reader.readLine(); + if (line == null) { + break; + } + actual.add(line); + } + } + + String[] expected = new String[elems.length]; + for (int i = 0; i < elems.length; i++) { + T elem = elems[i]; + byte[] encodedElem = CoderUtils.encodeToByteArray(coder, elem); + String line = new String(encodedElem); + expected[i] = line; + } + + assertThat(actual, + containsInAnyOrder(expected)); + } + + @Test + public void testWriteStrings() throws Exception { + runTestWrite(LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testWriteEmptyStrings() throws Exception { + runTestWrite(NO_LINES_ARRAY, StringUtf8Coder.of()); + } + + @Test + public void testWriteInts() throws Exception { + runTestWrite(INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testWriteEmptyInts() throws Exception { + runTestWrite(NO_INTS_ARRAY, TextualIntegerCoder.of()); + } + + @Test + public void testWriteNamed() { + { + PTransform, PDone> transform1 = + TextIO.Write.to("/tmp/file.txt"); + assertEquals("TextIO.Write", transform1.getName()); + } + + { + PTransform, PDone> transform2 = + TextIO.Write.named("MyWrite").to("/tmp/file.txt"); + assertEquals("MyWrite", transform2.getName()); + } + + { + PTransform, PDone> transform3 = + TextIO.Write.to("/tmp/file.txt").named("HerWrite"); + assertEquals("HerWrite", transform3.getName()); + } + } + + @Test + public void testUnsupportedFilePattern() throws IOException { + File outFolder = tmpFolder.newFolder(); + // Windows doesn't like resolving paths with * in them. + String filename = outFolder.toPath().resolve("output@5").toString(); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(LINES_ARRAY)) + .withCoder(StringUtf8Coder.of())); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Output name components are not allowed to contain"); + input.apply(TextIO.Write.to(filename)); + } + + /** + * This tests a few corner cases that should not crash. + */ + @Test + public void testGoodWildcards() throws Exception { + TestDataflowPipelineOptions options = buildTestPipelineOptions(); + options.setGcsUtil(buildMockGcsUtil()); + + Pipeline pipeline = Pipeline.create(options); + + applyRead(pipeline, "gs://bucket/foo"); + applyRead(pipeline, "gs://bucket/foo/"); + applyRead(pipeline, "gs://bucket/foo/*"); + applyRead(pipeline, "gs://bucket/foo/?"); + applyRead(pipeline, "gs://bucket/foo/[0-9]"); + applyRead(pipeline, "gs://bucket/foo/*baz*"); + applyRead(pipeline, "gs://bucket/foo/*baz?"); + applyRead(pipeline, "gs://bucket/foo/[0-9]baz?"); + applyRead(pipeline, "gs://bucket/foo/baz/*"); + applyRead(pipeline, "gs://bucket/foo/baz/*wonka*"); + applyRead(pipeline, "gs://bucket/foo/*baz/wonka*"); + applyRead(pipeline, "gs://bucket/foo*/baz"); + applyRead(pipeline, "gs://bucket/foo?/baz"); + applyRead(pipeline, "gs://bucket/foo[0-9]/baz"); + + // Check that running doesn't fail. 
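+ // (All of the gs:// paths above resolve through the mocked GcsUtil, so running the pipeline should not touch real GCS.)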
+ pipeline.run(); + } + + private void applyRead(Pipeline pipeline, String path) { + pipeline.apply("Read(" + path + ")", TextIO.Read.from(path)); + } + + /** + * Recursive wildcards are not supported. + * This tests "**". + */ + @Test + public void testBadWildcardRecursive() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + // Check that applying does fail. + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("wildcard"); + + pipeline.apply(TextIO.Read.from("gs://bucket/foo**/baz")); + } + + @Test + public void testReadWithoutValidationFlag() throws Exception { + TextIO.Read.Bound read = TextIO.Read.from("gs://bucket/foo*/baz"); + assertTrue(read.needsValidation()); + assertFalse(read.withoutValidation().needsValidation()); + } + + @Test + public void testWriteWithoutValidationFlag() throws Exception { + TextIO.Write.Bound write = TextIO.Write.to("gs://bucket/foo/baz"); + assertTrue(write.needsValidation()); + assertFalse(write.withoutValidation().needsValidation()); + } + + @Test + public void testCompressionTypeIsSet() throws Exception { + TextIO.Read.Bound read = TextIO.Read.from("gs://bucket/test"); + assertEquals(CompressionType.AUTO, read.getCompressionType()); + read = TextIO.Read.from("gs://bucket/test").withCompressionType(CompressionType.GZIP); + assertEquals(CompressionType.GZIP, read.getCompressionType()); + } + + @Test + public void testCompressedRead() throws Exception { + String[] lines = {"Irritable eagle", "Optimistic jay", "Fanciful hawk"}; + File tmpFile = tmpFolder.newFile(); + String filename = tmpFile.getPath(); + + List expected = new ArrayList<>(); + try (PrintStream writer = + new PrintStream(new GZIPOutputStream(new FileOutputStream(tmpFile)))) { + for (String line : lines) { + writer.println(line); + expected.add(line); + } + } + + Pipeline p = TestPipeline.create(); + + TextIO.Read.Bound read = + TextIO.Read.from(filename).withCompressionType(CompressionType.GZIP); + PCollection output = p.apply(read); + + DataflowAssert.that(output).containsInAnyOrder(expected); + p.run(); + } + + @Test + public void testGZIPReadWhenUncompressed() throws Exception { + String[] lines = {"Meritorious condor", "Obnoxious duck"}; + File tmpFile = tmpFolder.newFile(); + String filename = tmpFile.getPath(); + + List expected = new ArrayList<>(); + try (PrintStream writer = new PrintStream(new FileOutputStream(tmpFile))) { + for (String line : lines) { + writer.println(line); + expected.add(line); + } + } + + Pipeline p = TestPipeline.create(); + TextIO.Read.Bound read = + TextIO.Read.from(filename).withCompressionType(CompressionType.GZIP); + PCollection output = p.apply(read); + + DataflowAssert.that(output).containsInAnyOrder(expected); + p.run(); + } + + @Test + public void testTextIOGetName() { + assertEquals("TextIO.Read", TextIO.Read.from("somefile").getName()); + assertEquals("TextIO.Write", TextIO.Write.to("somefile").getName()); + assertEquals("ReadMyFile", TextIO.Read.named("ReadMyFile").from("somefile").getName()); + assertEquals("WriteMyFile", TextIO.Write.named("WriteMyFile").to("somefile").getName()); + + assertEquals("TextIO.Read", TextIO.Read.from("somefile").toString()); + assertEquals( + "ReadMyFile [TextIO.Read]", TextIO.Read.named("ReadMyFile").from("somefile").toString()); + } + + @Test + public void testReadEmptyLines() throws Exception { + runTestReadWithData("\n\n\n".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("", "", "")); + } + + @Test + public void testReadFileWithLineFeedDelimiter() 
throws Exception { + runTestReadWithData("asdf\nhjkl\nxyz\n".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithCarriageReturnDelimiter() throws Exception { + runTestReadWithData("asdf\rhjkl\rxyz\r".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithCarriageReturnAndLineFeedDelimiter() throws Exception { + runTestReadWithData("asdf\r\nhjkl\r\nxyz\r\n".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithMixedDelimiters() throws Exception { + runTestReadWithData("asdf\rhjkl\r\nxyz\n".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithLineFeedDelimiterAndNonEmptyBytesAtEnd() throws Exception { + runTestReadWithData("asdf\nhjkl\nxyz".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithCarriageReturnDelimiterAndNonEmptyBytesAtEnd() throws Exception { + runTestReadWithData("asdf\rhjkl\rxyz".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithCarriageReturnAndLineFeedDelimiterAndNonEmptyBytesAtEnd() + throws Exception { + runTestReadWithData("asdf\r\nhjkl\r\nxyz".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + @Test + public void testReadFileWithMixedDelimitersAndNonEmptyBytesAtEnd() throws Exception { + runTestReadWithData("asdf\rhjkl\r\nxyz".getBytes(StandardCharsets.UTF_8), + ImmutableList.of("asdf", "hjkl", "xyz")); + } + + private void runTestReadWithData(byte[] data, List expectedResults) throws Exception { + TextSource source = prepareSource(data); + List actual = SourceTestUtils.readFromSource(source, PipelineOptionsFactory.create()); + assertThat(actual, containsInAnyOrder(new ArrayList<>(expectedResults).toArray(new String[0]))); + } + + @Test + public void testSplittingSourceWithEmptyLines() throws Exception { + TextSource source = prepareSource("\n\n\n".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithLineFeedDelimiter() throws Exception { + TextSource source = prepareSource("asdf\nhjkl\nxyz\n".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithCarriageReturnDelimiter() throws Exception { + TextSource source = prepareSource("asdf\rhjkl\rxyz\r".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithCarriageReturnAndLineFeedDelimiter() throws Exception { + TextSource source = prepareSource( + "asdf\r\nhjkl\r\nxyz\r\n".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithMixedDelimiters() throws Exception { + TextSource source = prepareSource( + "asdf\rhjkl\r\nxyz\n".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithLineFeedDelimiterAndNonEmptyBytesAtEnd() throws Exception { + TextSource source = 
prepareSource("asdf\nhjkl\nxyz".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithCarriageReturnDelimiterAndNonEmptyBytesAtEnd() + throws Exception { + TextSource source = prepareSource("asdf\rhjkl\rxyz".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithCarriageReturnAndLineFeedDelimiterAndNonEmptyBytesAtEnd() + throws Exception { + TextSource source = prepareSource( + "asdf\r\nhjkl\r\nxyz".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + @Test + public void testSplittingSourceWithMixedDelimitersAndNonEmptyBytesAtEnd() throws Exception { + TextSource source = prepareSource("asdf\rhjkl\r\nxyz".getBytes(StandardCharsets.UTF_8)); + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } + + private TextSource prepareSource(byte[] data) throws IOException { + File file = tmpFolder.newFile(); + Files.write(file.toPath(), data); + + TextSource source = new TextSource<>(file.toPath().toString(), StringUtf8Coder.of()); + + return source; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java new file mode 100644 index 000000000000..b92f2f18ea15 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java @@ -0,0 +1,341 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.Sink.WriteOperation; +import com.google.cloud.dataflow.sdk.io.Sink.Writer; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest.TestPipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.MoreObjects; + +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; + +/** + * Tests for the Write PTransform. + */ +@RunWith(JUnit4.class) +public class WriteTest { + // Static store that can be accessed within the writer + static List sinkContents = new ArrayList<>(); + + /** + * Test a Write transform with a PCollection of elements. + */ + @Test + public void testWrite() { + List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", + "Intimidating pigeon", "Pedantic gull", "Frisky finch"); + runWrite(inputs, /* not windowed */ false); + } + + /** + * Test a Write transform with an empty PCollection. + */ + @Test + public void testWriteWithEmptyPCollection() { + List inputs = new ArrayList<>(); + runWrite(inputs, /* not windowed */ false); + } + + /** + * Test a Write with a windowed PCollection. + */ + @Test + public void testWriteWindowed() { + List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", + "Intimidating pigeon", "Pedantic gull", "Frisky finch"); + runWrite(inputs, /* windowed */ true); + } + + /** + * Performs a Write transform and verifies the Write transform calls the appropriate methods on + * a test sink in the correct order, as well as verifies that the elements of a PCollection are + * written to the sink. + */ + public void runWrite(List inputs, boolean windowed) { + // Flag to validate that the pipeline options are passed to the Sink + String[] args = {"--testFlag=test_value"}; + PipelineOptions options = PipelineOptionsFactory.fromArgs(args).as(WriteOptions.class); + Pipeline p = Pipeline.create(options); + + // Clear the sink's contents. + sinkContents.clear(); + + // Construct the input PCollection and test Sink. 
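+    // In the windowed case the elements get timestamps 1..n and are placed into 2ms fixed
+    // windows, so the Write runs over several windows; otherwise the input stays in the
+    // default global window.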
+ PCollection input; + if (windowed) { + List timestamps = new ArrayList<>(); + for (long i = 0; i < inputs.size(); i++) { + timestamps.add(i + 1); + } + input = p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of())) + .apply(Window.into(FixedWindows.of(new Duration(2)))); + } else { + input = p.apply(Create.of(inputs).withCoder(StringUtf8Coder.of())); + } + TestSink sink = new TestSink(); + + input.apply(Write.to(sink)); + + p.run(); + assertThat(sinkContents, containsInAnyOrder(inputs.toArray())); + assertTrue(sink.hasCorrectState()); + } + + // Test sink and associated write operation and writer. TestSink, TestWriteOperation, and + // TestWriter each verify that the sequence of method calls is consistent with the specification + // of the Write PTransform. + private static class TestSink extends Sink { + private boolean createCalled = false; + private boolean validateCalled = false; + + @Override + public WriteOperation createWriteOperation(PipelineOptions options) { + assertTrue(validateCalled); + assertTestFlagPresent(options); + createCalled = true; + return new TestSinkWriteOperation(this); + } + + @Override + public void validate(PipelineOptions options) { + assertTestFlagPresent(options); + validateCalled = true; + } + + private void assertTestFlagPresent(PipelineOptions options) { + assertEquals("test_value", options.as(WriteOptions.class).getTestFlag()); + } + + private boolean hasCorrectState() { + return validateCalled && createCalled; + } + + /** + * Implementation of equals() that indicates all test sinks are equal. + */ + @Override + public boolean equals(Object other) { + if (!(other instanceof TestSink)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + return Objects.hash(getClass()); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("createCalled", createCalled) + .add("validateCalled", validateCalled) + .toString(); + } + } + + private static class TestSinkWriteOperation extends WriteOperation { + private enum State { + INITIAL, + INITIALIZED, + FINALIZED + } + + // Must be static in case the WriteOperation is serialized before the its coder is obtained. + // If this occurs, the value will be modified but not reflected in the WriteOperation that is + // executed by the runner, and the finalize method will fail. + private static volatile boolean coderCalled = false; + + private State state = State.INITIAL; + + private final TestSink sink; + private final UUID id = UUID.randomUUID(); + + public TestSinkWriteOperation(TestSink sink) { + this.sink = sink; + } + + @Override + public TestSink getSink() { + return sink; + } + + @Override + public void initialize(PipelineOptions options) throws Exception { + assertEquals("test_value", options.as(WriteOptions.class).getTestFlag()); + assertThat(state, anyOf(equalTo(State.INITIAL), equalTo(State.INITIALIZED))); + state = State.INITIALIZED; + } + + @Override + public void finalize(Iterable bundleResults, PipelineOptions options) + throws Exception { + assertEquals("test_value", options.as(WriteOptions.class).getTestFlag()); + assertEquals(State.INITIALIZED, state); + // The coder for the test writer results should've been called. + assertTrue(coderCalled); + Set idSet = new HashSet<>(); + int resultCount = 0; + state = State.FINALIZED; + for (TestWriterResult result : bundleResults) { + resultCount += 1; + idSet.add(result.uId); + // Add the elements that were written to the sink's contents. 
+ sinkContents.addAll(result.elementsWritten); + } + // Each result came from a unique id. + assertEquals(resultCount, idSet.size()); + } + + @Override + public Writer createWriter(PipelineOptions options) { + return new TestSinkWriter(this); + } + + @Override + public Coder getWriterResultCoder() { + coderCalled = true; + return SerializableCoder.of(TestWriterResult.class); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("id", id) + .add("sink", sink) + .add("state", state) + .add("coderCalled", coderCalled) + .toString(); + } + + /** + * Implementation of equals() that does not depend on the state of the write operation, + * but only its specification. In general, write operations will have interesting + * specifications, but for a {@link TestSinkWriteOperation}, it is not the case. Instead, + * a unique identifier (that is serialized along with it) is used to simulate such a + * specification. + */ + @Override + public boolean equals(Object other) { + if (!(other instanceof TestSinkWriteOperation)) { + return false; + } + TestSinkWriteOperation otherOperation = (TestSinkWriteOperation) other; + return sink.equals(otherOperation.sink) + && id.equals(otherOperation.id); + } + + @Override + public int hashCode() { + return Objects.hash(id, sink); + } + } + + private static class TestWriterResult implements Serializable { + String uId; + List elementsWritten; + + public TestWriterResult(String uId, List elementsWritten) { + this.uId = uId; + this.elementsWritten = elementsWritten; + } + } + + private static class TestSinkWriter extends Writer { + private enum State { + INITIAL, + OPENED, + WRITING, + CLOSED + } + + private State state = State.INITIAL; + private List elementsWritten = new ArrayList<>(); + private String uId; + + private final TestSinkWriteOperation writeOperation; + + public TestSinkWriter(TestSinkWriteOperation writeOperation) { + this.writeOperation = writeOperation; + } + + @Override + public TestSinkWriteOperation getWriteOperation() { + return writeOperation; + } + + @Override + public void open(String uId) throws Exception { + this.uId = uId; + assertEquals(State.INITIAL, state); + state = State.OPENED; + } + + @Override + public void write(String value) throws Exception { + assertThat(state, anyOf(equalTo(State.OPENED), equalTo(State.WRITING))); + state = State.WRITING; + elementsWritten.add(value); + } + + @Override + public TestWriterResult close() throws Exception { + assertThat(state, anyOf(equalTo(State.OPENED), equalTo(State.WRITING))); + state = State.CLOSED; + return new TestWriterResult(uId, elementsWritten); + } + } + + /** + * Options for test, exposed for PipelineOptionsFactory. + */ + public static interface WriteOptions extends TestPipelineOptions { + @Description("Test flag and value") + String getTestFlag(); + + void setTestFlag(String value); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/XmlSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/XmlSinkTest.java new file mode 100644 index 000000000000..4b8380d89e53 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/XmlSinkTest.java @@ -0,0 +1,235 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import com.google.cloud.dataflow.sdk.io.XmlSink.XmlWriteOperation; +import com.google.cloud.dataflow.sdk.io.XmlSink.XmlWriter; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.common.collect.Lists; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileOutputStream; +import java.io.FileReader; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlType; + +/** + * Tests for XmlSink. + */ +@RunWith(JUnit4.class) +public class XmlSinkTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private Class testClass = Bird.class; + private String testRootElement = "testElement"; + private String testFilePrefix = "testPrefix"; + + /** + * An XmlWriter correctly writes objects as Xml elements with an enclosing root element. + */ + @Test + public void testXmlWriter() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + XmlWriteOperation writeOp = + XmlSink.writeOf(Bird.class, "birds", testFilePrefix).createWriteOperation(options); + XmlWriter writer = writeOp.createWriter(options); + + List bundle = + Lists.newArrayList(new Bird("bemused", "robin"), new Bird("evasive", "goose")); + List lines = Arrays.asList("", "", "robin", + "bemused", "", "", "goose", + "evasive", "", ""); + runTestWrite(writer, bundle, lines); + } + + /** + * Builder methods correctly initialize an XML Sink. + */ + @Test + public void testBuildXmlSink() { + XmlSink.Bound sink = + XmlSink.write() + .toFilenamePrefix(testFilePrefix) + .ofRecordClass(testClass) + .withRootElement(testRootElement); + assertEquals(testClass, sink.classToBind); + assertEquals(testRootElement, sink.rootElementName); + assertEquals(testFilePrefix, sink.baseOutputFilename); + } + + /** + * Alternate builder method correctly initializes an XML Sink. + */ + @Test + public void testBuildXmlSinkDirect() { + XmlSink.Bound sink = + XmlSink.writeOf(Bird.class, testRootElement, testFilePrefix); + assertEquals(testClass, sink.classToBind); + assertEquals(testRootElement, sink.rootElementName); + assertEquals(testFilePrefix, sink.baseOutputFilename); + } + + /** + * Validation ensures no fields are missing. 
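+   * A null record class, root element name, or filename prefix should each cause validate()
+   * to throw a NullPointerException.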
+ */ + @Test + public void testValidateXmlSinkMissingFields() { + XmlSink.Bound sink; + sink = XmlSink.writeOf(null, testRootElement, testFilePrefix); + validateAndFailIfSucceeds(sink, NullPointerException.class); + sink = XmlSink.writeOf(testClass, null, testFilePrefix); + validateAndFailIfSucceeds(sink, NullPointerException.class); + sink = XmlSink.writeOf(testClass, testRootElement, null); + validateAndFailIfSucceeds(sink, NullPointerException.class); + } + + /** + * Call validate and fail if validation does not throw the expected exception. + */ + private void validateAndFailIfSucceeds( + XmlSink.Bound sink, Class expected) { + thrown.expect(expected); + PipelineOptions options = PipelineOptionsFactory.create(); + sink.validate(options); + } + + /** + * An XML Sink correctly creates an XmlWriteOperation. + */ + @Test + public void testCreateWriteOperations() { + PipelineOptions options = PipelineOptionsFactory.create(); + XmlSink.Bound sink = + XmlSink.writeOf(testClass, testRootElement, testFilePrefix); + XmlWriteOperation writeOp = sink.createWriteOperation(options); + assertEquals(testClass, writeOp.getSink().classToBind); + assertEquals(testFilePrefix, writeOp.getSink().baseOutputFilename); + assertEquals(testRootElement, writeOp.getSink().rootElementName); + assertEquals(XmlSink.XML_EXTENSION, writeOp.getSink().extension); + assertEquals(testFilePrefix, writeOp.baseTemporaryFilename); + } + + /** + * An XmlWriteOperation correctly creates an XmlWriter. + */ + @Test + public void testCreateWriter() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + XmlWriteOperation writeOp = + XmlSink.writeOf(testClass, testRootElement, testFilePrefix) + .createWriteOperation(options); + XmlWriter writer = writeOp.createWriter(options); + assertEquals(testFilePrefix, writer.getWriteOperation().baseTemporaryFilename); + assertEquals(testRootElement, writer.getWriteOperation().getSink().rootElementName); + assertNotNull(writer.marshaller); + } + + /** + * Write a bundle with an XmlWriter and verify the output is expected. + */ + private void runTestWrite(XmlWriter writer, List bundle, List expected) + throws Exception { + File tmpFile = tmpFolder.newFile("foo.txt"); + try (FileOutputStream fileOutputStream = new FileOutputStream(tmpFile)) { + writeBundle(writer, bundle, fileOutputStream.getChannel()); + } + List lines = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new FileReader(tmpFile))) { + for (;;) { + String line = reader.readLine(); + if (line == null) { + break; + } + line = line.trim(); + if (line.length() > 0) { + lines.add(line); + } + } + assertEquals(expected, lines); + } + } + + /** + * Write a bundle with an XmlWriter. + */ + private void writeBundle(XmlWriter writer, List elements, WritableByteChannel channel) + throws Exception { + writer.prepareWrite(channel); + writer.writeHeader(); + for (T elem : elements) { + writer.write(elem); + } + writer.writeFooter(); + } + + /** + * Test JAXB annotated class. 
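+   * The name property is written as the "species" element, and properties are emitted in the
+   * declared propOrder (name, then adjective).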
+ */ + @SuppressWarnings("unused") + @XmlRootElement(name = "bird") + @XmlType(propOrder = {"name", "adjective"}) + private static final class Bird { + private String name; + private String adjective; + + @XmlElement(name = "species") + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getAdjective() { + return adjective; + } + + public void setAdjective(String adjective) { + this.adjective = adjective; + } + + public Bird() {} + + public Bird(String adjective, String name) { + this.adjective = adjective; + this.name = name; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/XmlSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/XmlSourceTest.java new file mode 100644 index 000000000000..5618ec7a10c0 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/XmlSourceTest.java @@ -0,0 +1,822 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io; + +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionFails; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.Source.Reader; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import javax.xml.bind.annotation.XmlAttribute; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * Tests XmlSource. 
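+ * Most cases write a small XML file of train records to a temporary folder, read it back
+ * through an XmlSource configured with withRootElement/withRecordElement/withRecordClass,
+ * and compare the parsed Train objects against the expected list.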
+ */ +@RunWith(JUnit4.class) +public class XmlSourceTest { + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Rule + public ExpectedException exception = ExpectedException.none(); + + String tinyXML = + "ThomasHenry" + + "James"; + + String xmlWithMultiByteElementName = + "<දුම්රියන්><දුම්රිය>Thomas<දුම්රිය>Henry" + + "<දුම්රිය>James"; + + String xmlWithMultiByteChars = + "Thomas¥Hen¶ry" + + "Jamßes"; + + String trainXML = + "" + + "Thomas1blue" + + "Henry3green" + + "Toby7brown" + + "Gordon4blue" + + "Emily-1red" + + "Percy6green" + + ""; + + String trainXMLWithEmptyTags = + "" + + "" + + "Thomas1blue" + + "Henry3green" + + "" + + "Toby7brown" + + "Gordon4blue" + + "Emily-1red" + + "Percy6green" + + ""; + + String trainXMLWithAttributes = + "" + + "Thomas1blue" + + "Henry3green" + + "Toby7brown" + + "Gordon4blue" + + "Emily-1red" + + "Percy6green" + + ""; + + String trainXMLWithSpaces = + "" + + "Thomas 1blue" + + "Henry3green\n" + + "Toby7 brown " + + "Gordon 4blue\n\t" + + "Emily-1\tred" + + "\nPercy 6 green" + + ""; + + String trainXMLWithAllFeaturesMultiByte = + "<දුම්රියන්>" + + "<දුම්රිය/>" + + "<දුම්රිය size=\"small\"> Thomas¥1blue" + + "" + + "<දුම්රිය size=\"big\">He nry3green" + + "<දුම්රිය size=\"small\">Toby 7br¶own" + + "" + + "<දුම්රිය/>" + + "<දුම්රිය size=\"big\">Gordon4 blue" + + "<දුම්රිය size=\"small\">Emily-1red" + + "<දුම්රිය size=\"small\">Percy6green" + + "" + + ""; + + String trainXMLWithAllFeaturesSingleByte = + "" + + "" + + " Thomas1blue" + + "" + + "He nry3green" + + "Toby 7brown" + + "" + + "" + + "Gordon4 blue" + + "Emily-1red" + + "Percy6green" + + "" + + ""; + + @XmlRootElement + static class Train { + public static final int TRAIN_NUMBER_UNDEFINED = -1; + public String name = null; + public String color = null; + public int number = TRAIN_NUMBER_UNDEFINED; + + @XmlAttribute(name = "size") + public String size = null; + + public Train() {} + + public Train(String name, int number, String color, String size) { + this.name = name; + this.number = number; + this.color = color; + this.size = size; + } + + @Override + public int hashCode() { + int hashCode = 1; + hashCode = 31 * hashCode + (name == null ? 0 : name.hashCode()); + hashCode = 31 * hashCode + number; + hashCode = 31 * hashCode + (color == null ? 0 : name.hashCode()); + hashCode = 31 * hashCode + (size == null ? 
0 : name.hashCode()); + return hashCode; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Train)) { + return false; + } + + Train other = (Train) obj; + return (name == null || name.equals(other.name)) && (number == other.number) + && (color == null || color.equals(other.color)) + && (size == null || size.equals(other.size)); + } + + @Override + public String toString() { + String str = "Train["; + boolean first = true; + if (name != null) { + str = str + "name=" + name; + first = false; + } + if (number != Integer.MIN_VALUE) { + if (!first) { + str = str + ","; + } + str = str + "number=" + number; + first = false; + } + if (color != null) { + if (!first) { + str = str + ","; + } + str = str + "color=" + color; + first = false; + } + if (size != null) { + if (!first) { + str = str + ","; + } + str = str + "size=" + size; + } + str = str + "]"; + return str; + } + } + + private List generateRandomTrainList(int size) { + String[] names = {"Thomas", "Henry", "Gordon", "Emily", "Toby", "Percy", "Mavis", "Edward", + "Bertie", "Harold", "Hiro", "Terence", "Salty", "Trevor"}; + int[] numbers = {-1, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + String[] colors = {"red", "blue", "green", "orange", "brown", "black", "white"}; + String[] sizes = {"small", "medium", "big"}; + + Random random = new Random(System.currentTimeMillis()); + + List trains = new ArrayList<>(); + for (int i = 0; i < size; i++) { + trains.add(new Train(names[random.nextInt(names.length - 1)], + numbers[random.nextInt(numbers.length - 1)], colors[random.nextInt(colors.length - 1)], + sizes[random.nextInt(sizes.length - 1)])); + } + + return trains; + } + + private String trainToXMLElement(Train train) { + return "" + train.name + "" + + train.number + "" + train.color + ""; + } + + private File createRandomTrainXML(String fileName, List trains) throws IOException { + File file = tempFolder.newFile(fileName); + try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) { + writer.write(""); + writer.newLine(); + for (Train train : trains) { + String str = trainToXMLElement(train); + writer.write(str); + writer.newLine(); + } + writer.write(""); + writer.newLine(); + } + return file; + } + + private List readEverythingFromReader(Reader reader) throws IOException { + List results = new ArrayList<>(); + for (boolean available = reader.start(); available; available = reader.advance()) { + Train train = reader.getCurrent(); + results.add(train); + } + return results; + } + + @Test + public void testReadXMLTiny() throws IOException { + File file = tempFolder.newFile("trainXMLTiny"); + Files.write(file.toPath(), tinyXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = ImmutableList.of( + new Train("Thomas", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("Henry", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("James", Train.TRAIN_NUMBER_UNDEFINED, null, null)); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testReadXMLWithMultiByteChars() throws IOException { + File file = tempFolder.newFile("trainXMLTiny"); + Files.write(file.toPath(), xmlWithMultiByteChars.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + 
.withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = ImmutableList.of( + new Train("Thomas¥", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("Hen¶ry", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("Jamßes", Train.TRAIN_NUMBER_UNDEFINED, null, null)); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + @Ignore( + "Multi-byte characters in XML are not supported because the parser " + + "currently does not correctly report byte offsets") + public void testReadXMLWithMultiByteElementName() throws IOException { + File file = tempFolder.newFile("trainXMLTiny"); + Files.write(file.toPath(), xmlWithMultiByteElementName.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("දුම්රියන්") + .withRecordElement("දුම්රිය") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = ImmutableList.of( + new Train("Thomas", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("Henry", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("James", Train.TRAIN_NUMBER_UNDEFINED, null, null)); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testSplitWithEmptyBundleAtEnd() throws Exception { + File file = tempFolder.newFile("trainXMLTiny"); + Files.write(file.toPath(), tinyXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(10); + List> splits = source.splitIntoBundles(50, null); + + assertTrue(splits.size() > 2); + + List results = new ArrayList<>(); + for (FileBasedSource split : splits) { + results.addAll(readEverythingFromReader(split.createReader(null))); + } + + List expectedResults = ImmutableList.of( + new Train("Thomas", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("Henry", Train.TRAIN_NUMBER_UNDEFINED, null, null), + new Train("James", Train.TRAIN_NUMBER_UNDEFINED, null, null)); + + assertThat( + trainsToStrings(expectedResults), containsInAnyOrder(trainsToStrings(results).toArray())); + } + + List trainsToStrings(List input) { + List strings = new ArrayList<>(); + for (Object data : input) { + strings.add(data.toString()); + } + return strings; + } + + @Test + public void testReadXMLSmall() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = + ImmutableList.of(new Train("Thomas", 1, "blue", null), new Train("Henry", 3, "green", null), + new Train("Toby", 7, "brown", null), new Train("Gordon", 4, "blue", null), + new Train("Emily", -1, "red", null), new Train("Percy", 6, "green", null)); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testReadXMLNoRootElement() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + 
Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRecordElement("train") + .withRecordClass(Train.class); + + exception.expect(NullPointerException.class); + exception.expectMessage( + "rootElement is null. Use builder method withRootElement() to set this."); + readEverythingFromReader(source.createReader(null)); + } + + @Test + public void testReadXMLNoRecordElement() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordClass(Train.class); + + exception.expect(NullPointerException.class); + exception.expectMessage( + "recordElement is null. Use builder method withRecordElement() to set this."); + readEverythingFromReader(source.createReader(null)); + } + + @Test + public void testReadXMLNoRecordClass() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train"); + + exception.expect(NullPointerException.class); + exception.expectMessage( + "recordClass is null. Use builder method withRecordClass() to set this."); + readEverythingFromReader(source.createReader(null)); + } + + @Test + public void testReadXMLIncorrectRootElement() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("something") + .withRecordElement("train") + .withRecordClass(Train.class); + + exception.expectMessage("Unexpected close tag ; expected ."); + readEverythingFromReader(source.createReader(null)); + } + + @Test + public void testReadXMLIncorrectRecordElement() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("something") + .withRecordClass(Train.class); + + assertEquals(readEverythingFromReader(source.createReader(null)), new ArrayList()); + } + + @XmlRootElement + private static class WrongTrainType { + @SuppressWarnings("unused") + public String something; + } + + @Test + public void testReadXMLInvalidRecordClass() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(WrongTrainType.class); + + exception.expect(RuntimeException.class); + + // JAXB internationalizes the error message. So this is all we can match for. 
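+    // "name" is an element of the train records in the XML, and "something" is the only field
+    // of WrongTrainType, so both should appear in the unmarshalling error regardless of locale.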
+ exception.expectMessage(both(containsString("name")).and(Matchers.containsString("something"))); + try (Reader reader = source.createReader(null)) { + + List results = new ArrayList<>(); + for (boolean available = reader.start(); available; available = reader.advance()) { + WrongTrainType train = reader.getCurrent(); + results.add(train); + } + } + } + + @Test + public void testReadXMLNoBundleSize() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class); + + List expectedResults = + ImmutableList.of(new Train("Thomas", 1, "blue", null), new Train("Henry", 3, "green", null), + new Train("Toby", 7, "brown", null), new Train("Gordon", 4, "blue", null), + new Train("Emily", -1, "red", null), new Train("Percy", 6, "green", null)); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + + @Test + public void testReadXMLWithEmptyTags() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXMLWithEmptyTags.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = ImmutableList.of(new Train("Thomas", 1, "blue", null), + new Train("Henry", 3, "green", null), new Train("Toby", 7, "brown", null), + new Train("Gordon", 4, "blue", null), new Train("Emily", -1, "red", null), + new Train("Percy", 6, "green", null), new Train(), new Train()); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testReadXMLSmallDataflow() throws IOException { + Pipeline p = TestPipeline.create(); + + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXML.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + PCollection output = p.apply(Read.from(source).named("ReadFileData")); + + List expectedResults = + ImmutableList.of(new Train("Thomas", 1, "blue", null), new Train("Henry", 3, "green", null), + new Train("Toby", 7, "brown", null), new Train("Gordon", 4, "blue", null), + new Train("Emily", -1, "red", null), new Train("Percy", 6, "green", null)); + + DataflowAssert.that(output).containsInAnyOrder(expectedResults); + p.run(); + } + + @Test + public void testReadXMLWithAttributes() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXMLWithAttributes.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = ImmutableList.of(new Train("Thomas", 1, "blue", "small"), + new Train("Henry", 3, "green", "big"), new Train("Toby", 7, "brown", "small"), + new Train("Gordon", 4, "blue", "big"), new Train("Emily", -1, "red", "small"), + new Train("Percy", 6, "green", "small")); + 
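+    // The expected size values come from the size="..." attribute on each train element, which
+    // is bound through the @XmlAttribute annotation on Train.size.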
+ assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testReadXMLWithWhitespaces() throws IOException { + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXMLWithSpaces.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + List expectedResults = ImmutableList.of(new Train("Thomas ", 1, "blue", null), + new Train("Henry", 3, "green", null), new Train("Toby", 7, " brown ", null), + new Train("Gordon", 4, "blue", null), new Train("Emily", -1, "red", null), + new Train("Percy", 6, "green", null)); + + assertThat( + trainsToStrings(expectedResults), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testReadXMLLarge() throws IOException { + String fileName = "temp.xml"; + List trains = generateRandomTrainList(100); + File file = createRandomTrainXML(fileName, trains); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + + assertThat( + trainsToStrings(trains), + containsInAnyOrder( + trainsToStrings(readEverythingFromReader(source.createReader(null))).toArray())); + } + + @Test + public void testReadXMLLargeDataflow() throws IOException { + String fileName = "temp.xml"; + List trains = generateRandomTrainList(100); + File file = createRandomTrainXML(fileName, trains); + + Pipeline p = TestPipeline.create(); + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + PCollection output = p.apply(Read.from(source).named("ReadFileData")); + + DataflowAssert.that(output).containsInAnyOrder(trains); + p.run(); + } + + @Test + public void testSplitWithEmptyBundles() throws Exception { + String fileName = "temp.xml"; + List trains = generateRandomTrainList(10); + File file = createRandomTrainXML(fileName, trains); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(10); + List> splits = source.splitIntoBundles(100, null); + + assertTrue(splits.size() > 2); + + List results = new ArrayList<>(); + for (FileBasedSource split : splits) { + results.addAll(readEverythingFromReader(split.createReader(null))); + } + + assertThat(trainsToStrings(trains), containsInAnyOrder(trainsToStrings(results).toArray())); + } + + @Test + public void testXMLWithSplits() throws Exception { + String fileName = "temp.xml"; + List trains = generateRandomTrainList(100); + File file = createRandomTrainXML(fileName, trains); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(10); + List> splits = source.splitIntoBundles(256, null); + + // Not a trivial split + assertTrue(splits.size() > 2); + + List results = new ArrayList<>(); + for (FileBasedSource split : splits) { + results.addAll(readEverythingFromReader(split.createReader(null))); + } + assertThat(trainsToStrings(trains), 
containsInAnyOrder(trainsToStrings(results).toArray())); + } + + @Test + public void testSplitAtFraction() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + String fileName = "temp.xml"; + List trains = generateRandomTrainList(100); + File file = createRandomTrainXML(fileName, trains); + + XmlSource fileSource = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(10); + + List> splits = + fileSource.splitIntoBundles(file.length() / 3, null); + for (BoundedSource splitSource : splits) { + int numItems = readEverythingFromReader(splitSource.createReader(null)).size(); + // Should not split while unstarted. + assertSplitAtFractionFails(splitSource, 0, 0.7, options); + assertSplitAtFractionSucceedsAndConsistent(splitSource, 1, 0.7, options); + assertSplitAtFractionSucceedsAndConsistent(splitSource, 15, 0.7, options); + assertSplitAtFractionFails(splitSource, 0, 0.0, options); + assertSplitAtFractionFails(splitSource, 20, 0.3, options); + assertSplitAtFractionFails(splitSource, numItems, 1.0, options); + + // After reading 100 elements we will be approximately at position + // 0.99 * (endOffset - startOffset) hence trying to split at fraction 0.9 will be + // unsuccessful. + assertSplitAtFractionFails(splitSource, numItems, 0.9, options); + + // Following passes since we can always find a fraction that is extremely close to 1 such that + // the position suggested by the fraction will be larger than the position the reader is at + // after reading "items - 1" elements. + // This also passes for "numItemsToReadBeforeSplit = items" if the position at suggested + // fraction is larger than the position the reader is at after reading all "items" elements + // (i.e., the start position of the last element). This is true for most cases but will not + // be true if reader position is only one less than the end position. (i.e., the last element + // of the bundle start at the last byte that belongs to the bundle). 
+ assertSplitAtFractionSucceedsAndConsistent(splitSource, numItems - 1, 0.999, options); + } + } + + @Test + public void testSplitAtFractionExhaustiveSingleByte() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXMLWithAllFeaturesSingleByte.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class); + assertSplitAtFractionExhaustive(source, options); + } + + @Test + @Ignore( + "Multi-byte characters in XML are not supported because the parser " + + "currently does not correctly report byte offsets") + public void testSplitAtFractionExhaustiveMultiByte() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + File file = tempFolder.newFile("trainXMLSmall"); + Files.write(file.toPath(), trainXMLWithAllFeaturesMultiByte.getBytes(StandardCharsets.UTF_8)); + + XmlSource source = + XmlSource.from(file.toPath().toString()) + .withRootElement("දුම්රියන්") + .withRecordElement("දුම්රිය") + .withRecordClass(Train.class); + assertSplitAtFractionExhaustive(source, options); + } + + @Test + public void testReadXMLFilePattern() throws IOException { + List trains1 = generateRandomTrainList(20); + File file = createRandomTrainXML("temp1.xml", trains1); + List trains2 = generateRandomTrainList(10); + createRandomTrainXML("temp2.xml", trains2); + List trains3 = generateRandomTrainList(15); + createRandomTrainXML("temp3.xml", trains3); + generateRandomTrainList(8); + createRandomTrainXML("otherfile.xml", trains1); + + Pipeline p = TestPipeline.create(); + + XmlSource source = XmlSource.from(file.getParent() + "/" + + "temp*.xml") + .withRootElement("trains") + .withRecordElement("train") + .withRecordClass(Train.class) + .withMinBundleSize(1024); + PCollection output = p.apply(Read.from(source).named("ReadFileData")); + + List expectedResults = new ArrayList<>(); + expectedResults.addAll(trains1); + expectedResults.addAll(trains2); + expectedResults.addAll(trains3); + + DataflowAssert.that(output).containsInAnyOrder(expectedResults); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIOTest.java new file mode 100644 index 000000000000..0afac13e2962 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIOTest.java @@ -0,0 +1,688 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.io.bigtable; + +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSourcesEqualReferenceSource; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionFails; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verifyNotNull; +import static org.hamcrest.Matchers.hasSize; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; + +import com.google.bigtable.v1.Cell; +import com.google.bigtable.v1.Column; +import com.google.bigtable.v1.Family; +import com.google.bigtable.v1.Mutation; +import com.google.bigtable.v1.Mutation.SetCell; +import com.google.bigtable.v1.Row; +import com.google.bigtable.v1.RowFilter; +import com.google.bigtable.v1.SampleRowKeysResponse; +import com.google.cloud.bigtable.config.BigtableOptions; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.bigtable.BigtableIO.BigtableSource; +import com.google.cloud.dataflow.sdk.io.range.ByteKey; +import com.google.cloud.dataflow.sdk.io.range.ByteKeyRange; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Predicate; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nullable; + +/** + * Unit tests for {@link BigtableIO}. + */ +@RunWith(JUnit4.class) +public class BigtableIOTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + @Rule public ExpectedLogs logged = ExpectedLogs.none(BigtableIO.class); + + /** + * These tests requires a static instance of the {@link FakeBigtableService} because the writers + * go through a serialization step when executing the test and would not affect passed-in objects + * otherwise. 
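+ * A fresh FakeBigtableService is installed in setup() before every test, so table state does
+ * not leak between test cases.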
+ */ + private static FakeBigtableService service; + private static final BigtableOptions BIGTABLE_OPTIONS = + new BigtableOptions.Builder() + .setProjectId("project") + .setClusterId("cluster") + .setZoneId("zone") + .build(); + private static BigtableIO.Read defaultRead = + BigtableIO.read().withBigtableOptions(BIGTABLE_OPTIONS); + private static BigtableIO.Write defaultWrite = + BigtableIO.write().withBigtableOptions(BIGTABLE_OPTIONS); + private Coder>> bigtableCoder; + private static final TypeDescriptor>> BIGTABLE_WRITE_TYPE = + new TypeDescriptor>>() {}; + + @Before + public void setup() throws Exception { + service = new FakeBigtableService(); + defaultRead = defaultRead.withBigtableService(service); + defaultWrite = defaultWrite.withBigtableService(service); + bigtableCoder = TestPipeline.create().getCoderRegistry().getCoder(BIGTABLE_WRITE_TYPE); + } + + @Test + public void testReadBuildsCorrectly() { + BigtableIO.Read read = + BigtableIO.read().withBigtableOptions(BIGTABLE_OPTIONS).withTableId("table"); + assertEquals("project", read.getBigtableOptions().getProjectId()); + assertEquals("cluster", read.getBigtableOptions().getClusterId()); + assertEquals("zone", read.getBigtableOptions().getZoneId()); + assertEquals("table", read.getTableId()); + } + + @Test + public void testReadBuildsCorrectlyInDifferentOrder() { + BigtableIO.Read read = + BigtableIO.read().withTableId("table").withBigtableOptions(BIGTABLE_OPTIONS); + assertEquals("project", read.getBigtableOptions().getProjectId()); + assertEquals("cluster", read.getBigtableOptions().getClusterId()); + assertEquals("zone", read.getBigtableOptions().getZoneId()); + assertEquals("table", read.getTableId()); + } + + @Test + public void testWriteBuildsCorrectly() { + BigtableIO.Write write = + BigtableIO.write().withBigtableOptions(BIGTABLE_OPTIONS).withTableId("table"); + assertEquals("table", write.getTableId()); + assertEquals("project", write.getBigtableOptions().getProjectId()); + assertEquals("zone", write.getBigtableOptions().getZoneId()); + assertEquals("cluster", write.getBigtableOptions().getClusterId()); + } + + @Test + public void testWriteBuildsCorrectlyInDifferentOrder() { + BigtableIO.Write write = + BigtableIO.write().withTableId("table").withBigtableOptions(BIGTABLE_OPTIONS); + assertEquals("cluster", write.getBigtableOptions().getClusterId()); + assertEquals("project", write.getBigtableOptions().getProjectId()); + assertEquals("zone", write.getBigtableOptions().getZoneId()); + assertEquals("table", write.getTableId()); + } + + @Test + public void testWriteValidationFailsMissingTable() { + BigtableIO.Write write = BigtableIO.write().withBigtableOptions(BIGTABLE_OPTIONS); + + thrown.expect(IllegalArgumentException.class); + + write.validate(null); + } + + @Test + public void testWriteValidationFailsMissingOptions() { + BigtableIO.Write write = BigtableIO.write().withTableId("table"); + + thrown.expect(IllegalArgumentException.class); + + write.validate(null); + } + + /** Helper function to make a single row mutation to be written. */ + private static KV> makeWrite(String key, String value) { + ByteString rowKey = ByteString.copyFromUtf8(key); + Iterable mutations = + ImmutableList.of( + Mutation.newBuilder() + .setSetCell(SetCell.newBuilder().setValue(ByteString.copyFromUtf8(value))) + .build()); + return KV.of(rowKey, mutations); + } + + /** Helper function to make a single bad row mutation (no set cell). 
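+   * The resulting mutation carries no SetCell, so it is not a valid row mutation.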
*/ + private static KV> makeBadWrite(String key) { + Iterable mutations = ImmutableList.of(Mutation.newBuilder().build()); + return KV.of(ByteString.copyFromUtf8(key), mutations); + } + + /** Tests that when reading from a non-existent table, the read fails. */ + @Test + public void testReadingFailsTableDoesNotExist() throws Exception { + final String table = "TEST-TABLE"; + + BigtableIO.Read read = + BigtableIO.read() + .withBigtableOptions(BIGTABLE_OPTIONS) + .withTableId(table) + .withBigtableService(service); + + // Exception will be thrown by read.validate() when read is applied. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(String.format("Table %s does not exist", table)); + + TestPipeline.create().apply(read); + } + + /** Tests that when reading from an empty table, the read succeeds. */ + @Test + public void testReadingEmptyTable() throws Exception { + final String table = "TEST-EMPTY-TABLE"; + service.createTable(table); + + TestPipeline p = TestPipeline.create(); + PCollection rows = p.apply(defaultRead.withTableId(table)); + DataflowAssert.that(rows).empty(); + + p.run(); + logged.verifyInfo(String.format("Closing reader after reading 0 records.")); + } + + /** Tests reading all rows from a table. */ + @Test + public void testReading() throws Exception { + final String table = "TEST-MANY-ROWS-TABLE"; + final int numRows = 1001; + List testRows = makeTableData(table, numRows); + + TestPipeline p = TestPipeline.create(); + PCollection rows = p.apply(defaultRead.withTableId(table)); + DataflowAssert.that(rows).containsInAnyOrder(testRows); + + p.run(); + logged.verifyInfo(String.format("Closing reader after reading %d records.", numRows)); + } + + /** A {@link Predicate} that a {@link Row Row's} key matches the given regex. */ + private static class KeyMatchesRegex implements Predicate { + private final String regex; + + public KeyMatchesRegex(String regex) { + this.regex = regex; + } + + @Override + public boolean apply(@Nullable ByteString input) { + verifyNotNull(input, "input"); + return input.toStringUtf8().matches(regex); + } + } + + /** Tests reading all rows using a filter. */ + @Test + public void testReadingWithFilter() throws Exception { + final String table = "TEST-FILTER-TABLE"; + final int numRows = 1001; + List testRows = makeTableData(table, numRows); + String regex = ".*17.*"; + final KeyMatchesRegex keyPredicate = new KeyMatchesRegex(regex); + Iterable filteredRows = + Iterables.filter( + testRows, + new Predicate() { + @Override + public boolean apply(@Nullable Row input) { + verifyNotNull(input, "input"); + return keyPredicate.apply(input.getKey()); + } + }); + + RowFilter filter = + RowFilter.newBuilder().setRowKeyRegexFilter(ByteString.copyFromUtf8(regex)).build(); + + TestPipeline p = TestPipeline.create(); + PCollection rows = p.apply(defaultRead.withTableId(table).withRowFilter(filter)); + DataflowAssert.that(rows).containsInAnyOrder(filteredRows); + + p.run(); + } + + /** + * Tests dynamic work rebalancing exhaustively. + * + *
<p>
    Because this test runs so slowly, it is disabled by default. Re-run when changing the + * {@link BigtableIO.Read} implementation. + */ + @Ignore("Slow. Rerun when changing the implementation.") + @Test + public void testReadingSplitAtFractionExhaustive() throws Exception { + final String table = "TEST-FEW-ROWS-SPLIT-EXHAUSTIVE-TABLE"; + final int numRows = 10; + final int numSamples = 1; + final long bytesPerRow = 1L; + makeTableData(table, numRows); + service.setupSampleRowKeys(table, numSamples, bytesPerRow); + + BigtableSource source = + new BigtableSource(service, table, null, service.getTableRange(table), null); + assertSplitAtFractionExhaustive(source, null); + } + + /** + * Unit tests of splitAtFraction. + */ + @Test + public void testReadingSplitAtFraction() throws Exception { + final String table = "TEST-SPLIT-AT-FRACTION"; + final int numRows = 10; + final int numSamples = 1; + final long bytesPerRow = 1L; + makeTableData(table, numRows); + service.setupSampleRowKeys(table, numSamples, bytesPerRow); + + BigtableSource source = + new BigtableSource(service, table, null, service.getTableRange(table), null); + // With 0 items read, all split requests will fail. + assertSplitAtFractionFails(source, 0, 0.1, null /* options */); + assertSplitAtFractionFails(source, 0, 1.0, null /* options */); + // With 1 items read, all split requests past 1/10th will succeed. + assertSplitAtFractionSucceedsAndConsistent(source, 1, 0.333, null /* options */); + assertSplitAtFractionSucceedsAndConsistent(source, 1, 0.666, null /* options */); + // With 3 items read, all split requests past 3/10ths will succeed. + assertSplitAtFractionFails(source, 3, 0.2, null /* options */); + assertSplitAtFractionSucceedsAndConsistent(source, 3, 0.571, null /* options */); + assertSplitAtFractionSucceedsAndConsistent(source, 3, 0.9, null /* options */); + // With 6 items read, all split requests past 6/10ths will succeed. + assertSplitAtFractionFails(source, 6, 0.5, null /* options */); + assertSplitAtFractionSucceedsAndConsistent(source, 6, 0.7, null /* options */); + } + + /** Tests reading all rows from a split table. */ + @Test + public void testReadingWithSplits() throws Exception { + final String table = "TEST-MANY-ROWS-SPLITS-TABLE"; + final int numRows = 1500; + final int numSamples = 10; + final long bytesPerRow = 100L; + + // Set up test table data and sample row keys for size estimation and splitting. + makeTableData(table, numRows); + service.setupSampleRowKeys(table, numSamples, bytesPerRow); + + // Generate source and split it. + BigtableSource source = + new BigtableSource(service, table, null /*filter*/, ByteKeyRange.ALL_KEYS, null /*size*/); + List splits = + source.splitIntoBundles(numRows * bytesPerRow / numSamples, null /* options */); + + // Test num splits and split equality. + assertThat(splits, hasSize(numSamples)); + assertSourcesEqualReferenceSource(source, splits, null /* options */); + } + + /** Tests reading all rows from a sub-split table. */ + @Test + public void testReadingWithSubSplits() throws Exception { + final String table = "TEST-MANY-ROWS-SPLITS-TABLE"; + final int numRows = 1000; + final int numSamples = 10; + final int numSplits = 20; + final long bytesPerRow = 100L; + + // Set up test table data and sample row keys for size estimation and splitting. + makeTableData(table, numRows); + service.setupSampleRowKeys(table, numSamples, bytesPerRow); + + // Generate source and split it. 
+ BigtableSource source = + new BigtableSource(service, table, null /*filter*/, ByteKeyRange.ALL_KEYS, null /*size*/); + List splits = source.splitIntoBundles(numRows * bytesPerRow / numSplits, null); + + // Test num splits and split equality. + assertThat(splits, hasSize(numSplits)); + assertSourcesEqualReferenceSource(source, splits, null /* options */); + } + + /** Tests reading all rows from a sub-split table. */ + @Test + public void testReadingWithFilterAndSubSplits() throws Exception { + final String table = "TEST-FILTER-SUB-SPLITS"; + final int numRows = 1700; + final int numSamples = 10; + final int numSplits = 20; + final long bytesPerRow = 100L; + + // Set up test table data and sample row keys for size estimation and splitting. + makeTableData(table, numRows); + service.setupSampleRowKeys(table, numSamples, bytesPerRow); + + // Generate source and split it. + RowFilter filter = + RowFilter.newBuilder().setRowKeyRegexFilter(ByteString.copyFromUtf8(".*17.*")).build(); + BigtableSource source = + new BigtableSource(service, table, filter, ByteKeyRange.ALL_KEYS, null /*size*/); + List splits = source.splitIntoBundles(numRows * bytesPerRow / numSplits, null); + + // Test num splits and split equality. + assertThat(splits, hasSize(numSplits)); + assertSourcesEqualReferenceSource(source, splits, null /* options */); + } + + /** Tests that a record gets written to the service and messages are logged. */ + @Test + public void testWriting() throws Exception { + final String table = "table"; + final String key = "key"; + final String value = "value"; + + service.createTable(table); + + TestPipeline p = TestPipeline.create(); + p.apply("single row", Create.of(makeWrite(key, value)).withCoder(bigtableCoder)) + .apply("write", defaultWrite.withTableId(table)); + p.run(); + + logged.verifyInfo("Wrote 1 records"); + + assertEquals(1, service.tables.size()); + assertNotNull(service.getTable(table)); + Map rows = service.getTable(table); + assertEquals(1, rows.size()); + assertEquals(ByteString.copyFromUtf8(value), rows.get(ByteString.copyFromUtf8(key))); + } + + /** Tests that when writing to a non-existent table, the write fails. */ + @Test + public void testWritingFailsTableDoesNotExist() throws Exception { + final String table = "TEST-TABLE"; + + PCollection>> emptyInput = + TestPipeline.create().apply(Create.>>of()); + + // Exception will be thrown by write.validate() when write is applied. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(String.format("Table %s does not exist", table)); + + emptyInput.apply("write", defaultWrite.withTableId(table)); + } + + /** Tests that when writing an element fails, the write fails. */ + @Test + public void testWritingFailsBadElement() throws Exception { + final String table = "TEST-TABLE"; + final String key = "KEY"; + service.createTable(table); + + TestPipeline p = TestPipeline.create(); + p.apply(Create.of(makeBadWrite(key)).withCoder(bigtableCoder)) + .apply(defaultWrite.withTableId(table)); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(Matchers.instanceOf(IOException.class)); + thrown.expectMessage("At least 1 errors occurred writing to Bigtable. 
First 1 errors:"); + thrown.expectMessage("Error mutating row " + key + " with mutations []: cell value missing"); + p.run(); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + private static final String COLUMN_FAMILY_NAME = "family"; + private static final ByteString COLUMN_NAME = ByteString.copyFromUtf8("column"); + private static final Column TEST_COLUMN = Column.newBuilder().setQualifier(COLUMN_NAME).build(); + private static final Family TEST_FAMILY = Family.newBuilder().setName(COLUMN_FAMILY_NAME).build(); + + /** Helper function that builds a {@link Row} in a test table that could be returned by read. */ + private static Row makeRow(ByteString key, ByteString value) { + // Build the currentRow and return true. + Column.Builder newColumn = TEST_COLUMN.toBuilder().addCells(Cell.newBuilder().setValue(value)); + return Row.newBuilder() + .setKey(key) + .addFamilies(TEST_FAMILY.toBuilder().addColumns(newColumn)) + .build(); + } + + /** Helper function to create a table and return the rows that it created. */ + private static List makeTableData(String tableId, int numRows) { + service.createTable(tableId); + Map testData = service.getTable(tableId); + + List testRows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; ++i) { + ByteString key = ByteString.copyFromUtf8(String.format("key%09d", i)); + ByteString value = ByteString.copyFromUtf8(String.format("value%09d", i)); + testData.put(key, value); + testRows.add(makeRow(key, value)); + } + + return testRows; + } + + + /** + * A {@link BigtableService} implementation that stores tables and their contents in memory. + */ + private static class FakeBigtableService implements BigtableService { + private final Map> tables = new HashMap<>(); + private final Map> sampleRowKeys = new HashMap<>(); + + @Nullable + public SortedMap getTable(String tableId) { + return tables.get(tableId); + } + + public ByteKeyRange getTableRange(String tableId) { + verifyTableExists(tableId); + SortedMap data = tables.get(tableId); + return ByteKeyRange.of(ByteKey.of(data.firstKey()), ByteKey.of(data.lastKey())); + } + + public void createTable(String tableId) { + tables.put(tableId, new TreeMap(new ByteStringComparator())); + } + + @Override + public boolean tableExists(String tableId) { + return tables.containsKey(tableId); + } + + public void verifyTableExists(String tableId) { + checkArgument(tableExists(tableId), "Table %s does not exist", tableId); + } + + @Override + public FakeBigtableReader createReader(BigtableSource source) { + return new FakeBigtableReader(source); + } + + @Override + public FakeBigtableWriter openForWriting(String tableId) { + return new FakeBigtableWriter(tableId); + } + + @Override + public List getSampleRowKeys(BigtableSource source) { + List samples = sampleRowKeys.get(source.getTableId()); + checkArgument(samples != null, "No samples found for table %s", source.getTableId()); + return samples; + } + + /** Sets up the sample row keys for the specified table. 
*/ + void setupSampleRowKeys(String tableId, int numSamples, long bytesPerRow) { + verifyTableExists(tableId); + checkArgument(numSamples > 0, "Number of samples must be positive: %s", numSamples); + checkArgument(bytesPerRow > 0, "Bytes/Row must be positive: %s", bytesPerRow); + + ImmutableList.Builder<SampleRowKeysResponse> ret = ImmutableList.builder(); + SortedMap<ByteString, ByteString> rows = getTable(tableId); + int currentSample = 1; + int rowsSoFar = 0; + for (Map.Entry<ByteString, ByteString> entry : rows.entrySet()) { + if (((double) rowsSoFar) / rows.size() >= ((double) currentSample) / numSamples) { + // Add the sample with the total number of bytes in the table before this key. + ret.add( + SampleRowKeysResponse.newBuilder() + .setRowKey(entry.getKey()) + .setOffsetBytes(rowsSoFar * bytesPerRow) + .build()); + // Move on to the next sample. + currentSample++; + } + ++rowsSoFar; + } + + // Add the last sample indicating the end of the table, with all rows before it. + ret.add(SampleRowKeysResponse.newBuilder().setOffsetBytes(rows.size() * bytesPerRow).build()); + sampleRowKeys.put(tableId, ret.build()); + } + } + + /** + * A {@link BigtableService.Reader} implementation that reads from the static instance of + * {@link FakeBigtableService} stored in {@link #service}. + * + *
<p>
    This reader does not support {@link RowFilter} objects. + */ + private static class FakeBigtableReader implements BigtableService.Reader { + private final BigtableSource source; + private Iterator> rows; + private Row currentRow; + private Predicate filter; + + public FakeBigtableReader(BigtableSource source) { + this.source = source; + if (source.getRowFilter() == null) { + filter = Predicates.alwaysTrue(); + } else { + ByteString keyRegex = source.getRowFilter().getRowKeyRegexFilter(); + checkArgument(!keyRegex.isEmpty(), "Only RowKeyRegexFilter is supported"); + filter = new KeyMatchesRegex(keyRegex.toStringUtf8()); + } + service.verifyTableExists(source.getTableId()); + } + + @Override + public boolean start() { + rows = service.tables.get(source.getTableId()).entrySet().iterator(); + return advance(); + } + + @Override + public boolean advance() { + // Loop until we find a row in range, or reach the end of the iterator. + Map.Entry entry = null; + while (rows.hasNext()) { + entry = rows.next(); + if (!filter.apply(entry.getKey()) + || !source.getRange().containsKey(ByteKey.of(entry.getKey()))) { + // Does not match row filter or does not match source range. Skip. + entry = null; + continue; + } + // Found a row inside this source's key range, stop. + break; + } + + // Return false if no more rows. + if (entry == null) { + currentRow = null; + return false; + } + + // Set the current row and return true. + currentRow = makeRow(entry.getKey(), entry.getValue()); + return true; + } + + @Override + public Row getCurrentRow() { + if (currentRow == null) { + throw new NoSuchElementException(); + } + return currentRow; + } + + @Override + public void close() { + rows = null; + currentRow = null; + } + } + + /** + * A {@link BigtableService.Writer} implementation that writes to the static instance of + * {@link FakeBigtableService} stored in {@link #service}. + * + *
<p>
    This writer only supports {@link Mutation Mutations} that consist only of {@link SetCell} + * entries. The column family in the {@link SetCell} is ignored; only the value is used. + * + *
<p>
    When no {@link SetCell} is provided, the write will fail and this will be exposed via an + * exception on the returned {@link ListenableFuture}. + */ + private static class FakeBigtableWriter implements BigtableService.Writer { + private final String tableId; + + public FakeBigtableWriter(String tableId) { + this.tableId = tableId; + } + + @Override + public ListenableFuture writeRecord(KV> record) { + service.verifyTableExists(tableId); + Map table = service.getTable(tableId); + ByteString key = record.getKey(); + for (Mutation m : record.getValue()) { + SetCell cell = m.getSetCell(); + if (cell.getValue().isEmpty()) { + return Futures.immediateFailedCheckedFuture(new IOException("cell value missing")); + } + table.put(key, cell.getValue()); + } + return Futures.immediateFuture(Empty.getDefaultInstance()); + } + + @Override + public void close() {} + } + + /** A serializable comparator for ByteString. Used to make row samples. */ + private static final class ByteStringComparator implements Comparator, Serializable { + @Override + public int compare(ByteString o1, ByteString o2) { + return ByteKey.of(o1).compareTo(ByteKey.of(o2)); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeEstimateFractionTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeEstimateFractionTest.java new file mode 100644 index 000000000000..71928f400eaa --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeEstimateFractionTest.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.assertThat; + +import com.google.common.collect.ImmutableList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +/** + * A combinatorial test of {@link ByteKeyRange#estimateFractionForKey(ByteKey)}. 
+ */ +@RunWith(Parameterized.class) +public class ByteKeyRangeEstimateFractionTest { + private static final ByteKey[] TEST_KEYS = ByteKeyRangeTest.RANGE_TEST_KEYS; + + @Parameters(name = "{index}: i={0}, k={1}") + public static Iterable data() { + ImmutableList.Builder ret = ImmutableList.builder(); + for (int i = 0; i < TEST_KEYS.length; ++i) { + for (int k = i + 1; k < TEST_KEYS.length; ++k) { + ret.add(new Object[] {i, k}); + } + } + return ret.build(); + } + + @Parameter(0) + public int i; + + @Parameter(1) + public int k; + + @Test + public void testEstimateFractionForKey() { + double last = 0.0; + ByteKeyRange range = ByteKeyRange.of(TEST_KEYS[i], TEST_KEYS[k]); + for (int j = i; j < k; ++j) { + ByteKey key = TEST_KEYS[j]; + if (key.isEmpty()) { + // Cannot compute progress for unspecified key + continue; + } + double fraction = range.estimateFractionForKey(key); + assertThat(fraction, greaterThanOrEqualTo(last)); + last = fraction; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeInterpolateKeyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeInterpolateKeyTest.java new file mode 100644 index 000000000000..c41f9054553e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeInterpolateKeyTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.assertThat; + +import com.google.common.collect.ImmutableList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +/** + * Combinatorial tests for {@link ByteKeyRange#interpolateKey}, which also checks + * {@link ByteKeyRange#estimateFractionForKey} by converting the interpolated keys back to + * fractions. 
+ */ +@RunWith(Parameterized.class) +public class ByteKeyRangeInterpolateKeyTest { + private static final ByteKey[] TEST_KEYS = ByteKeyRangeTest.RANGE_TEST_KEYS; + + @Parameters(name = "{index}: {0}") + public static Iterable data() { + ImmutableList.Builder ret = ImmutableList.builder(); + for (int i = 0; i < TEST_KEYS.length; ++i) { + for (int j = i + 1; j < TEST_KEYS.length; ++j) { + ret.add(new Object[] {ByteKeyRange.of(TEST_KEYS[i], TEST_KEYS[j])}); + } + } + return ret.build(); + } + + @Parameter public ByteKeyRange range; + + @Test + public void testInterpolateKeyAndEstimateFraction() { + double delta = 0.0000001; + double[] testFractions = + new double[] {0.01, 0.1, 0.123, 0.2, 0.3, 0.45738, 0.5, 0.6, 0.7182, 0.8, 0.95, 0.97, 0.99}; + ByteKey last = range.getStartKey(); + for (double fraction : testFractions) { + String message = Double.toString(fraction); + try { + ByteKey key = range.interpolateKey(fraction); + assertThat(message, key, greaterThanOrEqualTo(last)); + assertThat(message, range.estimateFractionForKey(key), closeTo(fraction, delta)); + last = key; + } catch (IllegalStateException e) { + assertThat(message, e.getMessage(), containsString("near-empty ByteKeyRange")); + continue; + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTest.java new file mode 100644 index 000000000000..f5c5d67d7b9d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTest.java @@ -0,0 +1,396 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for {@link ByteKeyRange}. + */ +@RunWith(JUnit4.class) +public class ByteKeyRangeTest { + // A set of ranges for testing. 
+ private static final ByteKeyRange RANGE_1_10 = ByteKeyRange.of(ByteKey.of(1), ByteKey.of(10)); + private static final ByteKeyRange RANGE_5_10 = ByteKeyRange.of(ByteKey.of(5), ByteKey.of(10)); + private static final ByteKeyRange RANGE_5_50 = ByteKeyRange.of(ByteKey.of(5), ByteKey.of(50)); + private static final ByteKeyRange RANGE_10_50 = ByteKeyRange.of(ByteKey.of(10), ByteKey.of(50)); + private static final ByteKeyRange UP_TO_1 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(1)); + private static final ByteKeyRange UP_TO_5 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(5)); + private static final ByteKeyRange UP_TO_10 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(10)); + private static final ByteKeyRange UP_TO_50 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(50)); + private static final ByteKeyRange AFTER_1 = ByteKeyRange.of(ByteKey.of(1), ByteKey.EMPTY); + private static final ByteKeyRange AFTER_5 = ByteKeyRange.of(ByteKey.of(5), ByteKey.EMPTY); + private static final ByteKeyRange AFTER_10 = ByteKeyRange.of(ByteKey.of(10), ByteKey.EMPTY); + private static final ByteKeyRange[] TEST_RANGES = + new ByteKeyRange[] { + ByteKeyRange.ALL_KEYS, + RANGE_1_10, + RANGE_5_10, + RANGE_5_50, + RANGE_10_50, + UP_TO_1, + UP_TO_5, + UP_TO_10, + UP_TO_50, + AFTER_1, + AFTER_5, + AFTER_10, + }; + + static final ByteKey[] RANGE_TEST_KEYS = + ImmutableList.builder() + .addAll(Arrays.asList(ByteKeyTest.TEST_KEYS)) + .add(ByteKey.EMPTY) + .build() + .toArray(ByteKeyTest.TEST_KEYS); + + /** + * Tests that the two ranges do not overlap, passing each in as the first range in the comparison. + */ + private static void bidirectionalNonOverlap(ByteKeyRange left, ByteKeyRange right) { + bidirectionalOverlapHelper(left, right, false); + } + + /** + * Tests that the two ranges overlap, passing each in as the first range in the comparison. + */ + private static void bidirectionalOverlap(ByteKeyRange left, ByteKeyRange right) { + bidirectionalOverlapHelper(left, right, true); + } + + /** + * Helper function for tests with a good error message. + */ + private static void bidirectionalOverlapHelper( + ByteKeyRange left, ByteKeyRange right, boolean result) { + assertEquals(String.format("%s overlaps %s", left, right), result, left.overlaps(right)); + assertEquals(String.format("%s overlaps %s", right, left), result, right.overlaps(left)); + } + + /** + * Tests of {@link ByteKeyRange#overlaps(ByteKeyRange)} with cases that should return true. + */ + @Test + public void testOverlappingRanges() { + bidirectionalOverlap(ByteKeyRange.ALL_KEYS, ByteKeyRange.ALL_KEYS); + bidirectionalOverlap(ByteKeyRange.ALL_KEYS, RANGE_1_10); + bidirectionalOverlap(UP_TO_1, UP_TO_1); + bidirectionalOverlap(UP_TO_1, UP_TO_5); + bidirectionalOverlap(UP_TO_50, AFTER_10); + bidirectionalOverlap(UP_TO_50, RANGE_1_10); + bidirectionalOverlap(UP_TO_10, UP_TO_50); + bidirectionalOverlap(RANGE_1_10, RANGE_5_50); + bidirectionalOverlap(AFTER_1, AFTER_5); + bidirectionalOverlap(RANGE_5_10, RANGE_1_10); + bidirectionalOverlap(RANGE_5_10, RANGE_5_50); + } + + /** + * Tests of {@link ByteKeyRange#overlaps(ByteKeyRange)} with cases that should return false. + */ + @Test + public void testNonOverlappingRanges() { + bidirectionalNonOverlap(UP_TO_1, AFTER_1); + bidirectionalNonOverlap(UP_TO_1, AFTER_5); + bidirectionalNonOverlap(RANGE_5_10, RANGE_10_50); + } + + /** + * Verifies that all keys in the given list are strictly ordered by size. 
+ */ + private static void ensureOrderedKeys(List keys) { + for (int i = 0; i < keys.size() - 1; ++i) { + // This will throw if these two keys do not form a valid range. + ByteKeyRange.of(keys.get(i), keys.get(i + 1)); + // Also, a key is only allowed empty if it is the first key. + if (i > 0 && keys.get(i).isEmpty()) { + fail(String.format("Intermediate key %s/%s may not be empty", i, keys.size())); + } + } + } + + /** Tests for {@link ByteKeyRange#split(int)} with invalid inputs. */ + @Test + public void testRejectsInvalidSplit() { + try { + fail(String.format("%s.split(0) should fail: %s", RANGE_1_10, RANGE_1_10.split(0))); + } catch (IllegalArgumentException expected) { + // pass + } + + try { + fail(String.format("%s.split(-3) should fail: %s", RANGE_1_10, RANGE_1_10.split(-3))); + } catch (IllegalArgumentException expected) { + // pass + } + } + + /** Tests for {@link ByteKeyRange#split(int)} with weird inputs. */ + @Test + public void testSplitSpecialInputs() { + // Range split by 1 returns list of its keys. + assertEquals( + "Split 1 should return input", + ImmutableList.of(RANGE_1_10.getStartKey(), RANGE_1_10.getEndKey()), + RANGE_1_10.split(1)); + + // Unsplittable range returns list of its keys. + ByteKeyRange unsplittable = ByteKeyRange.of(ByteKey.of(), ByteKey.of(0, 0, 0, 0)); + assertEquals( + "Unsplittable should return input", + ImmutableList.of(unsplittable.getStartKey(), unsplittable.getEndKey()), + unsplittable.split(5)); + } + + /** Tests for {@link ByteKeyRange#split(int)}. */ + @Test + public void testSplitKeysCombinatorial() { + List sizes = ImmutableList.of(1, 2, 5, 10, 25, 32, 64); + for (int i = 0; i < RANGE_TEST_KEYS.length; ++i) { + for (int j = i + 1; j < RANGE_TEST_KEYS.length; ++j) { + ByteKeyRange range = ByteKeyRange.of(RANGE_TEST_KEYS[i], RANGE_TEST_KEYS[j]); + for (int s : sizes) { + List splits = range.split(s); + ensureOrderedKeys(splits); + assertThat("At least two entries in splits", splits.size(), greaterThanOrEqualTo(2)); + assertEquals("First split equals start of range", splits.get(0), RANGE_TEST_KEYS[i]); + assertEquals( + "Last split equals end of range", splits.get(splits.size() - 1), RANGE_TEST_KEYS[j]); + } + } + } + } + + /** Manual tests for {@link ByteKeyRange#estimateFractionForKey}. 
*/ + @Test + public void testEstimateFractionForKey() { + final double delta = 0.0000001; + + /* 0x80 is halfway between [] and [] */ + assertEquals(0.5, ByteKeyRange.ALL_KEYS.estimateFractionForKey(ByteKey.of(0x80)), delta); + + /* 0x80 is halfway between [00] and [] */ + ByteKeyRange after0 = ByteKeyRange.of(ByteKey.of(0), ByteKey.EMPTY); + assertEquals(0.5, after0.estimateFractionForKey(ByteKey.of(0x80)), delta); + + /* 0x80 is halfway between [0000] and [] */ + ByteKeyRange after00 = ByteKeyRange.of(ByteKey.of(0, 0), ByteKey.EMPTY); + assertEquals(0.5, after00.estimateFractionForKey(ByteKey.of(0x80)), delta); + + /* 0x7f is halfway between [] and [fe] */ + ByteKeyRange upToFE = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(0xfe)); + assertEquals(0.5, upToFE.estimateFractionForKey(ByteKey.of(0x7f)), delta); + + /* 0x40 is one-quarter of the way between [] and [] */ + assertEquals(0.25, ByteKeyRange.ALL_KEYS.estimateFractionForKey(ByteKey.of(0x40)), delta); + + /* 0x40 is one-half of the way between [] and [0x80] */ + ByteKeyRange upTo80 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(0x80)); + assertEquals(0.50, upTo80.estimateFractionForKey(ByteKey.of(0x40)), delta); + + /* 0x40 is one-half of the way between [0x30] and [0x50] */ + ByteKeyRange range30to50 = ByteKeyRange.of(ByteKey.of(0x30), ByteKey.of(0x50)); + assertEquals(0.50, range30to50.estimateFractionForKey(ByteKey.of(0x40)), delta); + + /* 0x40 is one-half of the way between [0x30, 0, 1] and [0x4f, 0xff, 0xff, 0, 0] */ + ByteKeyRange range31to4f = + ByteKeyRange.of(ByteKey.of(0x30, 0, 1), ByteKey.of(0x4f, 0xff, 0xff, 0, 0)); + assertEquals(0.50, range31to4f.estimateFractionForKey(ByteKey.of(0x40)), delta); + + /* Exact fractions from 0 to 47 for a prime range. */ + ByteKeyRange upTo47 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(47)); + for (int i = 0; i <= 47; ++i) { + assertEquals("i=" + i, i / 47.0, upTo47.estimateFractionForKey(ByteKey.of(i)), delta); + } + + /* Exact fractions from 0 to 83 for a prime range. */ + ByteKeyRange rangeFDECtoFDEC83 = + ByteKeyRange.of(ByteKey.of(0xfd, 0xec), ByteKey.of(0xfd, 0xec, 83)); + for (int i = 0; i <= 83; ++i) { + assertEquals( + "i=" + i, + i / 83.0, + rangeFDECtoFDEC83.estimateFractionForKey(ByteKey.of(0xfd, 0xec, i)), + delta); + } + } + + /** Manual tests for {@link ByteKeyRange#interpolateKey}. 
*/ + @Test + public void testInterpolateKey() { + /* 0x80 is halfway between [] and [] */ + assertEqualExceptPadding(ByteKey.of(0x80), ByteKeyRange.ALL_KEYS.interpolateKey(0.5)); + + /* 0x80 is halfway between [00] and [] */ + ByteKeyRange after0 = ByteKeyRange.of(ByteKey.of(0), ByteKey.EMPTY); + assertEqualExceptPadding(ByteKey.of(0x80), after0.interpolateKey(0.5)); + + /* 0x80 is halfway between [0000] and [] -- padding to longest key */ + ByteKeyRange after00 = ByteKeyRange.of(ByteKey.of(0, 0), ByteKey.EMPTY); + assertEqualExceptPadding(ByteKey.of(0x80), after00.interpolateKey(0.5)); + + /* 0x7f is halfway between [] and [fe] */ + ByteKeyRange upToFE = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(0xfe)); + assertEqualExceptPadding(ByteKey.of(0x7f), upToFE.interpolateKey(0.5)); + + /* 0x40 is one-quarter of the way between [] and [] */ + assertEqualExceptPadding(ByteKey.of(0x40), ByteKeyRange.ALL_KEYS.interpolateKey(0.25)); + + /* 0x40 is halfway between [] and [0x80] */ + ByteKeyRange upTo80 = ByteKeyRange.of(ByteKey.EMPTY, ByteKey.of(0x80)); + assertEqualExceptPadding(ByteKey.of(0x40), upTo80.interpolateKey(0.5)); + + /* 0x40 is halfway between [0x30] and [0x50] */ + ByteKeyRange range30to50 = ByteKeyRange.of(ByteKey.of(0x30), ByteKey.of(0x50)); + assertEqualExceptPadding(ByteKey.of(0x40), range30to50.interpolateKey(0.5)); + + /* 0x40 is halfway between [0x30, 0, 1] and [0x4f, 0xff, 0xff, 0, 0] */ + ByteKeyRange range31to4f = + ByteKeyRange.of(ByteKey.of(0x30, 0, 1), ByteKey.of(0x4f, 0xff, 0xff, 0, 0)); + assertEqualExceptPadding(ByteKey.of(0x40), range31to4f.interpolateKey(0.5)); + } + + /** Tests that {@link ByteKeyRange#interpolateKey} does not return the empty key. */ + @Test + public void testInterpolateKeyIsNotEmpty() { + String fmt = "Interpolating %s at fraction 0.0 should not return the empty key"; + for (ByteKeyRange range : TEST_RANGES) { + assertFalse(String.format(fmt, range), range.interpolateKey(0.0).isEmpty()); + } + } + + /** Test {@link ByteKeyRange} getters. */ + @Test + public void testKeyGetters() { + // [1,) + assertEquals(AFTER_1.getStartKey(), ByteKey.of(1)); + assertEquals(AFTER_1.getEndKey(), ByteKey.EMPTY); + // [1, 10) + assertEquals(RANGE_1_10.getStartKey(), ByteKey.of(1)); + assertEquals(RANGE_1_10.getEndKey(), ByteKey.of(10)); + // [, 10) + assertEquals(UP_TO_10.getStartKey(), ByteKey.EMPTY); + assertEquals(UP_TO_10.getEndKey(), ByteKey.of(10)); + } + + /** Test {@link ByteKeyRange#toString}. */ + @Test + public void testToString() { + assertEquals("ByteKeyRange{startKey=[], endKey=[0a]}", UP_TO_10.toString()); + } + + /** Test {@link ByteKeyRange#equals}. */ + @Test + public void testEquals() { + // Verify that the comparison gives the correct result for all values in both directions. + for (int i = 0; i < TEST_RANGES.length; ++i) { + for (int j = 0; j < TEST_RANGES.length; ++j) { + ByteKeyRange left = TEST_RANGES[i]; + ByteKeyRange right = TEST_RANGES[j]; + boolean eq = left.equals(right); + if (i == j) { + assertTrue(String.format("Expected that %s is equal to itself.", left), eq); + assertTrue( + String.format("Expected that %s is equal to a copy of itself.", left), + left.equals(ByteKeyRange.of(right.getStartKey(), right.getEndKey()))); + } else { + assertFalse(String.format("Expected that %s is not equal to %s", left, right), eq); + } + } + } + } + + /** Test that {@link ByteKeyRange#of} rejects invalid ranges. 
*/ + @Test + public void testRejectsInvalidRanges() { + ByteKey[] testKeys = ByteKeyTest.TEST_KEYS; + for (int i = 0; i < testKeys.length; ++i) { + for (int j = i; j < testKeys.length; ++j) { + if (testKeys[i].isEmpty() || testKeys[j].isEmpty()) { + continue; // these are valid ranges. + } + try { + ByteKeyRange range = ByteKeyRange.of(testKeys[j], testKeys[i]); + fail(String.format("Expected failure constructing %s", range)); + } catch (IllegalArgumentException expected) { + // pass + } + } + } + } + + /** Test {@link ByteKeyRange#hashCode}. */ + @Test + public void testHashCode() { + // Verify that the hashCode is equal when i==j, and usually not equal otherwise. + int collisions = 0; + for (int i = 0; i < TEST_RANGES.length; ++i) { + ByteKeyRange current = TEST_RANGES[i]; + int left = current.hashCode(); + int leftClone = ByteKeyRange.of(current.getStartKey(), current.getEndKey()).hashCode(); + assertEquals( + String.format("Expected same hash code for %s and a copy of itself", current), + left, + leftClone); + for (int j = i + 1; j < TEST_RANGES.length; ++j) { + int right = TEST_RANGES[j].hashCode(); + if (left == right) { + ++collisions; + } + } + } + int totalUnequalTests = TEST_RANGES.length * (TEST_RANGES.length - 1) / 2; + assertThat("Too many hash collisions", collisions, lessThan(totalUnequalTests / 2)); + } + + /** Asserts the two keys are equal except trailing zeros. */ + private static void assertEqualExceptPadding(ByteKey expected, ByteKey key) { + ByteString shortKey = expected.getValue(); + ByteString longKey = key.getValue(); + if (shortKey.size() > longKey.size()) { + shortKey = key.getValue(); + longKey = expected.getValue(); + } + for (int i = 0; i < shortKey.size(); ++i) { + if (shortKey.byteAt(i) != longKey.byteAt(i)) { + fail(String.format("Expected %s (up to trailing zeros), got %s", expected, key)); + } + } + for (int j = shortKey.size(); j < longKey.size(); ++j) { + if (longKey.byteAt(j) != 0) { + fail(String.format("Expected %s (up to trailing zeros), got %s", expected, key)); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTrackerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTrackerTest.java new file mode 100644 index 000000000000..234e6e9d9547 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyRangeTrackerTest.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.io.range; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ByteKeyRangeTracker}. 
*/ +@RunWith(JUnit4.class) +public class ByteKeyRangeTrackerTest { + private static final ByteKey START_KEY = ByteKey.of(0x12); + private static final ByteKey MIDDLE_KEY = ByteKey.of(0x23); + private static final ByteKey BEFORE_END_KEY = ByteKey.of(0x33); + private static final ByteKey END_KEY = ByteKey.of(0x34); + private static final double RANGE_SIZE = 0x34 - 0x12; + private static final ByteKeyRange RANGE = ByteKeyRange.of(START_KEY, END_KEY); + + /** Tests for {@link ByteKeyRangeTracker#toString}. */ + @Test + public void testToString() { + ByteKeyRangeTracker tracker = ByteKeyRangeTracker.of(RANGE); + String expected = String.format("ByteKeyRangeTracker{range=%s, position=null}", RANGE); + assertEquals(expected, tracker.toString()); + + tracker.tryReturnRecordAt(true, MIDDLE_KEY); + expected = String.format("ByteKeyRangeTracker{range=%s, position=%s}", RANGE, MIDDLE_KEY); + assertEquals(expected, tracker.toString()); + } + + /** Tests for {@link ByteKeyRangeTracker#of}. */ + @Test + public void testBuilding() { + ByteKeyRangeTracker tracker = ByteKeyRangeTracker.of(RANGE); + + assertEquals(START_KEY, tracker.getStartPosition()); + assertEquals(END_KEY, tracker.getStopPosition()); + } + + /** Tests for {@link ByteKeyRangeTracker#getFractionConsumed()}. */ + @Test + public void testGetFractionConsumed() { + ByteKeyRangeTracker tracker = ByteKeyRangeTracker.of(RANGE); + double delta = 0.00001; + + assertEquals(0.0, tracker.getFractionConsumed(), delta); + + tracker.tryReturnRecordAt(true, START_KEY); + assertEquals(0.0, tracker.getFractionConsumed(), delta); + + tracker.tryReturnRecordAt(true, MIDDLE_KEY); + assertEquals(0.5, tracker.getFractionConsumed(), delta); + + tracker.tryReturnRecordAt(true, BEFORE_END_KEY); + assertEquals(1 - 1 / RANGE_SIZE, tracker.getFractionConsumed(), delta); + } + + /** Tests for {@link ByteKeyRangeTracker#tryReturnRecordAt}. */ + @Test + public void testTryReturnRecordAt() { + ByteKeyRangeTracker tracker = ByteKeyRangeTracker.of(RANGE); + + // Should be able to emit at the same key twice, should that happen. + // Should be able to emit within range (in order, but system guarantees won't try out of order). + // Should not be able to emit past end of range. + + assertTrue(tracker.tryReturnRecordAt(true, START_KEY)); + assertTrue(tracker.tryReturnRecordAt(true, START_KEY)); + + assertTrue(tracker.tryReturnRecordAt(true, MIDDLE_KEY)); + assertTrue(tracker.tryReturnRecordAt(true, MIDDLE_KEY)); + + assertTrue(tracker.tryReturnRecordAt(true, BEFORE_END_KEY)); + + assertFalse(tracker.tryReturnRecordAt(true, END_KEY)); // after end + + assertTrue(tracker.tryReturnRecordAt(true, BEFORE_END_KEY)); // still succeeds + } + + /** Tests for {@link ByteKeyRangeTracker#trySplitAtPosition}. */ + @Test + public void testSplitAtPosition() { + ByteKeyRangeTracker tracker = ByteKeyRangeTracker.of(RANGE); + + // Unstarted, should not split. + assertFalse(tracker.trySplitAtPosition(MIDDLE_KEY)); + + // Start it, split it before the end. + assertTrue(tracker.tryReturnRecordAt(true, START_KEY)); + assertTrue(tracker.trySplitAtPosition(BEFORE_END_KEY)); + assertEquals(BEFORE_END_KEY, tracker.getStopPosition()); + + // Should not be able to split it after the end. + assertFalse(tracker.trySplitAtPosition(END_KEY)); + + // Should not be able to split after emitting. 
+ assertTrue(tracker.tryReturnRecordAt(true, MIDDLE_KEY)); + assertFalse(tracker.trySplitAtPosition(MIDDLE_KEY)); + assertTrue(tracker.tryReturnRecordAt(true, MIDDLE_KEY)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyTest.java new file mode 100644 index 000000000000..922ac5b80371 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/ByteKeyTest.java @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.io.range; + +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests of {@link ByteKey}. + */ +@RunWith(JUnit4.class) +public class ByteKeyTest { + /* A big list of byte[] keys, in ascending sorted order. */ + static final ByteKey[] TEST_KEYS = + new ByteKey[] { + ByteKey.EMPTY, + ByteKey.of(0), + ByteKey.of(0, 1), + ByteKey.of(0, 1, 1), + ByteKey.of(0, 1, 2), + ByteKey.of(0, 1, 2, 0xfe), + ByteKey.of(0, 1, 3, 0xfe), + ByteKey.of(0, 0xfe, 0xfe, 0xfe), + ByteKey.of(0, 0xfe, 0xfe, 0xff), + ByteKey.of(0, 0xfe, 0xff, 0), + ByteKey.of(0, 0xff, 0xff, 0), + ByteKey.of(0, 0xff, 0xff, 1), + ByteKey.of(0, 0xff, 0xff, 0xfe), + ByteKey.of(0, 0xff, 0xff, 0xff), + ByteKey.of(1), + ByteKey.of(1, 2), + ByteKey.of(1, 2, 3), + ByteKey.of(3), + ByteKey.of(0xdd), + ByteKey.of(0xfe), + ByteKey.of(0xfe, 0xfe), + ByteKey.of(0xfe, 0xff), + ByteKey.of(0xff), + ByteKey.of(0xff, 0), + ByteKey.of(0xff, 0xfe), + ByteKey.of(0xff, 0xff), + ByteKey.of(0xff, 0xff, 0xff), + ByteKey.of(0xff, 0xff, 0xff, 0xff), + }; + + /** + * Tests {@link ByteKey#compareTo(ByteKey)} using exhaustive testing within a large sorted list + * of keys. + */ + @Test + public void testCompareToExhaustive() { + // Verify that the comparison gives the correct result for all values in both directions. + for (int i = 0; i < TEST_KEYS.length; ++i) { + for (int j = 0; j < TEST_KEYS.length; ++j) { + ByteKey left = TEST_KEYS[i]; + ByteKey right = TEST_KEYS[j]; + int cmp = left.compareTo(right); + if (i < j && !(cmp < 0)) { + fail( + String.format( + "Expected that cmp(%s, %s) < 0, got %d [i=%d, j=%d]", left, right, cmp, i, j)); + } else if (i == j && !(cmp == 0)) { + fail( + String.format( + "Expected that cmp(%s, %s) == 0, got %d [i=%d, j=%d]", left, right, cmp, i, j)); + } else if (i > j && !(cmp > 0)) { + fail( + String.format( + "Expected that cmp(%s, %s) > 0, got %d [i=%d, j=%d]", left, right, cmp, i, j)); + } + } + } + } + + /** + * Tests {@link ByteKey#equals}. 
+ */ + @Test + public void testEquals() { + // Verify that the comparison gives the correct result for all values in both directions. + for (int i = 0; i < TEST_KEYS.length; ++i) { + for (int j = 0; j < TEST_KEYS.length; ++j) { + ByteKey left = TEST_KEYS[i]; + ByteKey right = TEST_KEYS[j]; + boolean eq = left.equals(right); + if (i == j) { + assertTrue(String.format("Expected that %s is equal to itself.", left), eq); + assertTrue( + String.format("Expected that %s is equal to a copy of itself.", left), + left.equals(ByteKey.of(right.getValue()))); + } else { + assertFalse(String.format("Expected that %s is not equal to %s", left, right), eq); + } + } + } + } + + /** + * Tests {@link ByteKey#hashCode}. + */ + @Test + public void testHashCode() { + // Verify that the hashCode is equal when i==j, and usually not equal otherwise. + int collisions = 0; + for (int i = 0; i < TEST_KEYS.length; ++i) { + int left = TEST_KEYS[i].hashCode(); + int leftClone = ByteKey.of(TEST_KEYS[i].getValue()).hashCode(); + assertEquals( + String.format("Expected same hash code for %s and a copy of itself", TEST_KEYS[i]), + left, + leftClone); + for (int j = i + 1; j < TEST_KEYS.length; ++j) { + int right = TEST_KEYS[j].hashCode(); + if (left == right) { + ++collisions; + } + } + } + int totalUnequalTests = TEST_KEYS.length * (TEST_KEYS.length - 1) / 2; + assertThat("Too many hash collisions", collisions, lessThan(totalUnequalTests / 2)); + } + + /** + * Tests {@link ByteKey#toString}. + */ + @Test + public void testToString() { + assertEquals("[]", ByteKey.EMPTY.toString()); + assertEquals("[00]", ByteKey.of(0).toString()); + assertEquals("[0000]", ByteKey.of(0x00, 0x00).toString()); + assertEquals( + "[0123456789abcdef]", + ByteKey.of(0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef).toString()); + } + + /** + * Tests {@link ByteKey#isEmpty}. + */ + @Test + public void testIsEmpty() { + assertTrue("[] is empty", ByteKey.EMPTY.isEmpty()); + assertFalse("[00]", ByteKey.of(0).isEmpty()); + } + + /** + * Tests {@link ByteKey#getBytes}. + */ + @Test + public void testGetBytes() { + assertTrue("[] equal after getBytes", Arrays.equals(new byte[] {}, ByteKey.EMPTY.getBytes())); + assertTrue( + "[00] equal after getBytes", Arrays.equals(new byte[] {0x00}, ByteKey.of(0x00).getBytes())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/OffsetRangeTrackerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/OffsetRangeTrackerTest.java new file mode 100644 index 000000000000..b1f1070ba062 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/range/OffsetRangeTrackerTest.java @@ -0,0 +1,186 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * icensed under the Apache icense, Version 2.0 (the "icense"); you may not + * use this file except in compliance with the icense. You may obtain a copy of + * the icense at + * + * http://www.apache.org/licenses/ICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the icense is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * icense for the specific language governing permissions and limitations under + * the icense. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.io.range; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link OffsetRangeTracker}. + */ +@RunWith(JUnit4.class) +public class OffsetRangeTrackerTest { + @Rule public final ExpectedException expected = ExpectedException.none(); + + @Test + public void testTryReturnRecordSimpleSparse() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(100, 200); + assertTrue(tracker.tryReturnRecordAt(true, 110)); + assertTrue(tracker.tryReturnRecordAt(true, 140)); + assertTrue(tracker.tryReturnRecordAt(true, 183)); + assertFalse(tracker.tryReturnRecordAt(true, 210)); + } + + @Test + public void testTryReturnRecordSimpleDense() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(3, 6); + assertTrue(tracker.tryReturnRecordAt(true, 3)); + assertTrue(tracker.tryReturnRecordAt(true, 4)); + assertTrue(tracker.tryReturnRecordAt(true, 5)); + assertFalse(tracker.tryReturnRecordAt(true, 6)); + } + + @Test + public void testTryReturnRecordContinuesUntilSplitPoint() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(9, 18); + // Return records with gaps of 2; every 3rd record is a split point. + assertTrue(tracker.tryReturnRecordAt(true, 10)); + assertTrue(tracker.tryReturnRecordAt(false, 12)); + assertTrue(tracker.tryReturnRecordAt(false, 14)); + assertTrue(tracker.tryReturnRecordAt(true, 16)); + // Out of range, but not a split point... + assertTrue(tracker.tryReturnRecordAt(false, 18)); + assertTrue(tracker.tryReturnRecordAt(false, 20)); + // Out of range AND a split point. + assertFalse(tracker.tryReturnRecordAt(true, 22)); + } + + @Test + public void testSplitAtOffsetFailsIfUnstarted() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(100, 200); + assertFalse(tracker.trySplitAtPosition(150)); + } + + @Test + public void testSplitAtOffset() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(100, 200); + assertTrue(tracker.tryReturnRecordAt(true, 110)); + // Example positions we shouldn't split at, when last record is [110, 130]: + assertFalse(tracker.trySplitAtPosition(109)); + assertFalse(tracker.trySplitAtPosition(110)); + assertFalse(tracker.trySplitAtPosition(200)); + assertFalse(tracker.trySplitAtPosition(210)); + // Example positions we *should* split at: + assertTrue(tracker.copy().trySplitAtPosition(111)); + assertTrue(tracker.copy().trySplitAtPosition(129)); + assertTrue(tracker.copy().trySplitAtPosition(130)); + assertTrue(tracker.copy().trySplitAtPosition(131)); + assertTrue(tracker.copy().trySplitAtPosition(150)); + assertTrue(tracker.copy().trySplitAtPosition(199)); + + // If we split at 170 and then at 150: + assertTrue(tracker.trySplitAtPosition(170)); + assertTrue(tracker.trySplitAtPosition(150)); + // Should be able to return a record starting before the new stop offset. + // Returning records starting at the same offset is ok. + assertTrue(tracker.copy().tryReturnRecordAt(true, 135)); + assertTrue(tracker.copy().tryReturnRecordAt(true, 135)); + // Should be able to return a record starting right before the new stop offset. 
+ assertTrue(tracker.copy().tryReturnRecordAt(true, 149)); + // Should not be able to return a record starting at or after the new stop offset + assertFalse(tracker.tryReturnRecordAt(true, 150)); + assertFalse(tracker.tryReturnRecordAt(true, 151)); + // Should accept non-splitpoint records starting after stop offset. + assertTrue(tracker.tryReturnRecordAt(false, 135)); + assertTrue(tracker.tryReturnRecordAt(false, 152)); + assertTrue(tracker.tryReturnRecordAt(false, 160)); + assertFalse(tracker.tryReturnRecordAt(true, 171)); + } + + @Test + public void testGetPositionForFractionDense() throws Exception { + // Represents positions 3, 4, 5. + OffsetRangeTracker tracker = new OffsetRangeTracker(3, 6); + // [3, 3) represents 0.0 of [3, 6) + assertEquals(3, tracker.getPositionForFractionConsumed(0.0)); + // [3, 4) represents up to 1/3 of [3, 6) + assertEquals(4, tracker.getPositionForFractionConsumed(1.0 / 6)); + assertEquals(4, tracker.getPositionForFractionConsumed(0.333)); + // [3, 5) represents up to 2/3 of [3, 6) + assertEquals(5, tracker.getPositionForFractionConsumed(0.334)); + assertEquals(5, tracker.getPositionForFractionConsumed(0.666)); + // any fraction consumed over 2/3 means the whole [3, 6) has been consumed. + assertEquals(6, tracker.getPositionForFractionConsumed(0.667)); + } + + @Test + public void testGetFractionConsumedDense() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(3, 6); + assertEquals(0, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(true, 3)); + assertEquals(1.0 / 3, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(true, 4)); + assertEquals(2.0 / 3, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(true, 5)); + assertEquals(1.0, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(false /* non-split-point */, 6)); + assertEquals(1.0, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(false /* non-split-point */, 7)); + assertEquals(1.0, tracker.getFractionConsumed(), 1e-6); + assertFalse(tracker.tryReturnRecordAt(true, 7)); + } + + @Test + public void testGetFractionConsumedSparse() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(100, 200); + assertEquals(0, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(true, 110)); + // Consumed positions through 110 = total 11 positions of 100. + assertEquals(0.11, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(true, 150)); + assertEquals(0.51, tracker.getFractionConsumed(), 1e-6); + assertTrue(tracker.tryReturnRecordAt(true, 195)); + assertEquals(0.96, tracker.getFractionConsumed(), 1e-6); + } + + @Test + public void testEverythingWithUnboundedRange() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(100, Long.MAX_VALUE); + assertTrue(tracker.tryReturnRecordAt(true, 150)); + assertTrue(tracker.tryReturnRecordAt(true, 250)); + assertEquals(0.0, tracker.getFractionConsumed(), 1e-6); + assertFalse(tracker.trySplitAtPosition(1000)); + try { + tracker.getPositionForFractionConsumed(0.5); + fail("getPositionForFractionConsumed should fail for an unbounded range"); + } catch (IllegalArgumentException e) { + // Expected. 
+ } + } + + @Test + public void testTryReturnFirstRecordNotSplitPoint() throws Exception { + expected.expect(IllegalStateException.class); + new OffsetRangeTracker(100, 200).tryReturnRecordAt(false, 120); + } + + @Test + public void testTryReturnRecordNonMonotonic() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(100, 200); + expected.expect(IllegalStateException.class); + tracker.tryReturnRecordAt(true, 120); + tracker.tryReturnRecordAt(true, 110); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/user.avsc b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/user.avsc new file mode 100644 index 000000000000..0cd9cee69027 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/user.avsc @@ -0,0 +1,10 @@ +{ + "namespace": "com.google.cloud.dataflow.sdk.io", + "type": "record", + "name": "AvroGeneratedUser", + "fields": [ + { "name": "name", "type": "string"}, + { "name": "favorite_number", "type": ["int", "null"]}, + { "name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptionsTest.java new file mode 100644 index 000000000000..9fea09705055 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineDebugOptionsTest.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.hamcrest.Matchers.hasEntry; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DataflowPipelineDebugOptions}. */ +@RunWith(JUnit4.class) +public class DataflowPipelineDebugOptionsTest { + @Test + public void testTransformNameMapping() throws Exception { + DataflowPipelineDebugOptions options = PipelineOptionsFactory + .fromArgs(new String[]{"--transformNameMapping={\"a\":\"b\",\"foo\":\"\",\"bar\":\"baz\"}"}) + .as(DataflowPipelineDebugOptions.class); + assertEquals(3, options.getTransformNameMapping().size()); + assertThat(options.getTransformNameMapping(), hasEntry("a", "b")); + assertThat(options.getTransformNameMapping(), hasEntry("foo", "")); + assertThat(options.getTransformNameMapping(), hasEntry("bar", "baz")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptionsTest.java new file mode 100644 index 000000000000..0207ba4b07c2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowPipelineOptionsTest.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.testing.ResetDateTimeProvider; +import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DataflowPipelineOptions}. */ +@RunWith(JUnit4.class) +public class DataflowPipelineOptionsTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + @Rule public ResetDateTimeProvider resetDateTimeProviderRule = new ResetDateTimeProvider(); + + @Test + public void testJobNameIsSet() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setJobName("TestJobName"); + assertEquals("TestJobName", options.getJobName()); + } + + @Test + public void testUserNameIsNotSet() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().remove("user.name"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("TestApplication"); + assertEquals("testapplication--1208190706", options.getJobName()); + assertTrue(options.getJobName().length() <= 40); + } + + @Test + public void testAppNameAndUserNameAreLong() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "abcdeabcdeabcdeabcdeabcdeabcde"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("1234567890123456789012345678901234567890"); + assertEquals( + "a234567890123456789012345678901234567890-abcdeabcdeabcdeabcdeabcdeabcde-1208190706", + options.getJobName()); + } + + @Test + public void testAppNameIsLong() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "abcde"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("1234567890123456789012345678901234567890"); + assertEquals("a234567890123456789012345678901234567890-abcde-1208190706", options.getJobName()); + } + + @Test + public void testUserNameIsLong() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "abcdeabcdeabcdeabcdeabcdeabcde"); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("1234567890"); + assertEquals("a234567890-abcdeabcdeabcdeabcdeabcdeabcde-1208190706", options.getJobName()); + } + + @Test + public void testUtf8UserNameAndApplicationNameIsNormalized() { + resetDateTimeProviderRule.setDateTimeFixed("2014-12-08T19:07:06.698Z"); + System.getProperties().put("user.name", "ði ıntəˈnæʃənəl "); + 
DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setAppName("fəˈnɛtık əsoʊsiˈeıʃn"); + assertEquals("f00n0t0k00so0si0e00n-0i00nt00n000n0l0-1208190706", options.getJobName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowProfilingOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowProfilingOptionsTest.java new file mode 100644 index 000000000000..8b89abd2b1d3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowProfilingOptionsTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.hamcrest.Matchers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link DataflowProfilingOptions}. + */ +@RunWith(JUnit4.class) +public class DataflowProfilingOptionsTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + public void testOptionsObject() throws Exception { + DataflowPipelineOptions options = PipelineOptionsFactory.fromArgs(new String[] { + "--enableProfilingAgent", "--profilingAgentConfiguration={\"interval\": 21}"}) + .as(DataflowPipelineOptions.class); + assertTrue(options.getEnableProfilingAgent()); + + String json = MAPPER.writeValueAsString(options); + assertThat(json, Matchers.containsString( + "\"profilingAgentConfiguration\":{\"interval\":21}")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerLoggingOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerLoggingOptionsTest.java new file mode 100644 index 000000000000..82ec48a5f5ba --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/DataflowWorkerLoggingOptionsTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.options; + +import static com.google.cloud.dataflow.sdk.options.DataflowWorkerLoggingOptions.Level.WARN; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.options.DataflowWorkerLoggingOptions.WorkerLogLevelOverrides; +import com.google.common.collect.ImmutableMap; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DataflowWorkerLoggingOptions}. */ +@RunWith(JUnit4.class) +public class DataflowWorkerLoggingOptionsTest { + private static final ObjectMapper MAPPER = new ObjectMapper(); + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testWorkerLogLevelOverrideWithInvalidLogLevel() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Unsupported log level"); + WorkerLogLevelOverrides.from(ImmutableMap.of("Name", "FakeLevel")); + } + + @Test + public void testWorkerLogLevelOverrideForClass() throws Exception { + assertEquals("{\"org.junit.Test\":\"WARN\"}", + MAPPER.writeValueAsString( + new WorkerLogLevelOverrides().addOverrideForClass(Test.class, WARN))); + } + + @Test + public void testWorkerLogLevelOverrideForPackage() throws Exception { + assertEquals("{\"org.junit\":\"WARN\"}", + MAPPER.writeValueAsString( + new WorkerLogLevelOverrides().addOverrideForPackage(Test.class.getPackage(), WARN))); + } + + @Test + public void testWorkerLogLevelOverrideForName() throws Exception { + assertEquals("{\"A\":\"WARN\"}", + MAPPER.writeValueAsString( + new WorkerLogLevelOverrides().addOverrideForName("A", WARN))); + } + + @Test + public void testSerializationAndDeserializationOf() throws Exception { + String testValue = "{\"A\":\"WARN\"}"; + assertEquals(testValue, + MAPPER.writeValueAsString( + MAPPER.readValue(testValue, WorkerLogLevelOverrides.class))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/GcpOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/GcpOptionsTest.java new file mode 100644 index 000000000000..40024d082b9b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/GcpOptionsTest.java @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */
+package com.google.cloud.dataflow.sdk.options;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+import com.google.cloud.dataflow.sdk.options.GcpOptions.DefaultProjectFactory;
+import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.io.Files;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.rules.TestRule;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+
+/** Tests for {@link GcpOptions}. */
+@RunWith(JUnit4.class)
+public class GcpOptionsTest {
+  @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties();
+  @Rule public TemporaryFolder tmpFolder = new TemporaryFolder();
+
+  @Test
+  public void testGetProjectFromCloudSdkConfigEnv() throws Exception {
+    Map<String, String> environment =
+        ImmutableMap.of("CLOUDSDK_CONFIG", tmpFolder.getRoot().getAbsolutePath());
+    assertEquals("test-project",
+        runGetProjectTest(tmpFolder.newFile("properties"), environment));
+  }
+
+  @Test
+  public void testGetProjectFromAppDataEnv() throws Exception {
+    Map<String, String> environment =
+        ImmutableMap.of("APPDATA", tmpFolder.getRoot().getAbsolutePath());
+    System.setProperty("os.name", "windows");
+    assertEquals("test-project",
+        runGetProjectTest(new File(tmpFolder.newFolder("gcloud"), "properties"),
+            environment));
+  }
+
+  @Test
+  public void testGetProjectFromUserHomeEnvOld() throws Exception {
+    Map<String, String> environment = ImmutableMap.of();
+    System.setProperty("user.home", tmpFolder.getRoot().getAbsolutePath());
+    assertEquals("test-project",
+        runGetProjectTest(
+            new File(tmpFolder.newFolder(".config", "gcloud"), "properties"),
+            environment));
+  }
+
+  @Test
+  public void testGetProjectFromUserHomeEnv() throws Exception {
+    Map<String, String> environment = ImmutableMap.of();
+    System.setProperty("user.home", tmpFolder.getRoot().getAbsolutePath());
+    assertEquals("test-project",
+        runGetProjectTest(
+            new File(tmpFolder.newFolder(".config", "gcloud", "configurations"), "config_default"),
+            environment));
+  }
+
+  @Test
+  public void testGetProjectFromUserHomeOldAndNewPrefersNew() throws Exception {
+    Map<String, String> environment = ImmutableMap.of();
+    System.setProperty("user.home", tmpFolder.getRoot().getAbsolutePath());
+    makePropertiesFileWithProject(new File(tmpFolder.newFolder(".config", "gcloud"), "properties"),
+        "old-project");
+    assertEquals("test-project",
+        runGetProjectTest(
+            new File(tmpFolder.newFolder(".config", "gcloud", "configurations"), "config_default"),
+            environment));
+  }
+
+  @Test
+  public void testUnableToGetDefaultProject() throws Exception {
+    System.setProperty("user.home", tmpFolder.getRoot().getAbsolutePath());
+    DefaultProjectFactory projectFactory = spy(new DefaultProjectFactory());
+    when(projectFactory.getEnvironment()).thenReturn(ImmutableMap.<String, String>of());
+    assertNull(projectFactory.create(PipelineOptionsFactory.create()));
+  }
+
+  private static void makePropertiesFileWithProject(File path, String projectId)
+      throws IOException {
+    String properties = String.format("[core]%n"
+        + "account = test-account@google.com%n"
+        + "project = %s%n"
+        + "%n"
+        + "[dataflow]%n"
+        + "magic = true%n", projectId);
+    Files.write(properties, path, StandardCharsets.UTF_8);
+  }
+
+  private static String runGetProjectTest(File path, Map<String, String> environment)
+      throws Exception {
+    makePropertiesFileWithProject(path, "test-project");
+    DefaultProjectFactory projectFactory = spy(new DefaultProjectFactory());
+    when(projectFactory.getEnvironment()).thenReturn(environment);
+    return projectFactory.create(PipelineOptionsFactory.create());
+  }
+}
+
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/GoogleApiDebugOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/GoogleApiDebugOptionsTest.java
new file mode 100644
index 000000000000..3a16cf5dee25
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/GoogleApiDebugOptionsTest.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.options;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+
+import com.google.api.services.bigquery.Bigquery.Datasets.Delete;
+import com.google.api.services.dataflow.Dataflow.Projects.Jobs.Create;
+import com.google.api.services.dataflow.Dataflow.Projects.Jobs.Get;
+import com.google.cloud.dataflow.sdk.options.GoogleApiDebugOptions.GoogleApiTracer;
+import com.google.cloud.dataflow.sdk.util.TestCredential;
+import com.google.cloud.dataflow.sdk.util.Transport;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link GoogleApiDebugOptions}.
*/ +@RunWith(JUnit4.class) +public class GoogleApiDebugOptionsTest { + @Test + public void testWhenTracingMatches() throws Exception { + String[] args = + new String[] {"--googleApiTrace={\"Projects.Jobs.Get\":\"GetTraceDestination\"}"}; + DataflowPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + + assertNotNull(options.getGoogleApiTrace()); + + Get request = + options.getDataflowClient().projects().jobs().get("testProjectId", "testJobId"); + assertEquals("GetTraceDestination", request.get("$trace")); + } + + @Test + public void testWhenTracingDoesNotMatch() throws Exception { + String[] args = new String[] {"--googleApiTrace={\"Projects.Jobs.Create\":\"testToken\"}"}; + DataflowPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + + assertNotNull(options.getGoogleApiTrace()); + + Get request = + options.getDataflowClient().projects().jobs().get("testProjectId", "testJobId"); + assertNull(request.get("$trace")); + } + + @Test + public void testWithMultipleTraces() throws Exception { + String[] args = new String[] { + "--googleApiTrace={\"Projects.Jobs.Create\":\"CreateTraceDestination\"," + + "\"Projects.Jobs.Get\":\"GetTraceDestination\"}"}; + DataflowPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + + assertNotNull(options.getGoogleApiTrace()); + + Get getRequest = + options.getDataflowClient().projects().jobs().get("testProjectId", "testJobId"); + assertEquals("GetTraceDestination", getRequest.get("$trace")); + + Create createRequest = + options.getDataflowClient().projects().jobs().create("testProjectId", null); + assertEquals("CreateTraceDestination", createRequest.get("$trace")); + } + + @Test + public void testMatchingAllDataflowCalls() throws Exception { + String[] args = new String[] {"--googleApiTrace={\"Dataflow\":\"TraceDestination\"}"}; + DataflowPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + + assertNotNull(options.getGoogleApiTrace()); + + Get getRequest = + options.getDataflowClient().projects().jobs().get("testProjectId", "testJobId"); + assertEquals("TraceDestination", getRequest.get("$trace")); + + Create createRequest = + options.getDataflowClient().projects().jobs().create("testProjectId", null); + assertEquals("TraceDestination", createRequest.get("$trace")); + } + + @Test + public void testMatchingAgainstClient() throws Exception { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + options.setGoogleApiTrace(new GoogleApiTracer().addTraceFor( + Transport.newDataflowClient(options).build(), "TraceDestination")); + + Get getRequest = + options.getDataflowClient().projects().jobs().get("testProjectId", "testJobId"); + assertEquals("TraceDestination", getRequest.get("$trace")); + + Delete deleteRequest = Transport.newBigQueryClient(options).build().datasets() + .delete("testProjectId", "testDatasetId"); + assertNull(deleteRequest.get("$trace")); + } + + @Test + public void testMatchingAgainstRequestType() throws Exception { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + 
options.setGoogleApiTrace(new GoogleApiTracer().addTraceFor( + Transport.newDataflowClient(options).build().projects().jobs() + .get("aProjectId", "aJobId"), "TraceDestination")); + + Get getRequest = + options.getDataflowClient().projects().jobs().get("testProjectId", "testJobId"); + assertEquals("TraceDestination", getRequest.get("$trace")); + + Create createRequest = + options.getDataflowClient().projects().jobs().create("testProjectId", null); + assertNull(createRequest.get("$trace")); + } + + @Test + public void testDeserializationAndSerializationOfGoogleApiTracer() throws Exception { + String serializedValue = "{\"Api\":\"Token\"}"; + ObjectMapper objectMapper = new ObjectMapper(); + assertEquals(serializedValue, + objectMapper.writeValueAsString( + objectMapper.readValue(serializedValue, GoogleApiTracer.class))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java new file mode 100644 index 000000000000..e687f2798946 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java @@ -0,0 +1,1101 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.util.List; +import java.util.Map; + +/** Tests for {@link PipelineOptionsFactory}. 
*/ +@RunWith(JUnit4.class) +public class PipelineOptionsFactoryTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(PipelineOptionsFactory.class); + + @Test + public void testAutomaticRegistrationOfPipelineOptions() { + assertTrue(PipelineOptionsFactory.getRegisteredOptions().contains(DirectPipelineOptions.class)); + } + + @Test + public void testAutomaticRegistrationOfRunners() { + assertEquals(DirectPipelineRunner.class, + PipelineOptionsFactory.getRegisteredRunners().get("DirectPipelineRunner")); + } + + @Test + public void testCreationFromSystemProperties() throws Exception { + System.getProperties().putAll(ImmutableMap + .builder() + .put("worker_id", "test_worker_id") + .put("job_id", "test_job_id") + // Set a non-default value for testing + .put("sdk_pipeline_options", "{\"options\":{\"numWorkers\":999}}") + .build()); + + @SuppressWarnings("deprecation") // testing deprecated functionality + DataflowWorkerHarnessOptions options = PipelineOptionsFactory.createFromSystemProperties(); + assertEquals("test_worker_id", options.getWorkerId()); + assertEquals("test_job_id", options.getJobId()); + assertEquals(999, options.getNumWorkers()); + } + + @Test + public void testAppNameIsSet() { + ApplicationNameOptions options = PipelineOptionsFactory.as(ApplicationNameOptions.class); + assertEquals(PipelineOptionsFactoryTest.class.getSimpleName(), options.getAppName()); + } + + /** A simple test interface. */ + public static interface TestPipelineOptions extends PipelineOptions { + String getTestPipelineOption(); + void setTestPipelineOption(String value); + } + + @Test + public void testAppNameIsSetWhenUsingAs() { + TestPipelineOptions options = PipelineOptionsFactory.as(TestPipelineOptions.class); + assertEquals(PipelineOptionsFactoryTest.class.getSimpleName(), + options.as(ApplicationNameOptions.class).getAppName()); + } + + @Test + public void testManualRegistration() { + assertFalse(PipelineOptionsFactory.getRegisteredOptions().contains(TestPipelineOptions.class)); + PipelineOptionsFactory.register(TestPipelineOptions.class); + assertTrue(PipelineOptionsFactory.getRegisteredOptions().contains(TestPipelineOptions.class)); + } + + @Test + public void testDefaultRegistration() { + assertTrue(PipelineOptionsFactory.getRegisteredOptions().contains(PipelineOptions.class)); + } + + /** A test interface missing a getter. */ + public static interface MissingGetter extends PipelineOptions { + void setObject(Object value); + } + + @Test + public void testMissingGetterThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected getter for property [object] of type [java.lang.Object] on " + + "[com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingGetter]."); + + PipelineOptionsFactory.as(MissingGetter.class); + } + + /** A test interface missing multiple getters. */ + public static interface MissingMultipleGetters extends MissingGetter { + void setOtherObject(Object value); + } + + @Test + public void testMultipleMissingGettersThrows() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "missing property methods on [com.google.cloud.dataflow.sdk.options." 
+ + "PipelineOptionsFactoryTest$MissingMultipleGetters]"); + expectedException.expectMessage("getter for property [object] of type [java.lang.Object]"); + expectedException.expectMessage("getter for property [otherObject] of type [java.lang.Object]"); + + PipelineOptionsFactory.as(MissingMultipleGetters.class); + } + + /** A test interface missing a setter. */ + public static interface MissingSetter extends PipelineOptions { + Object getObject(); + } + + @Test + public void testMissingSetterThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected setter for property [object] of type [java.lang.Object] on " + + "[com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingSetter]."); + + PipelineOptionsFactory.as(MissingSetter.class); + } + + /** A test interface missing multiple setters. */ + public static interface MissingMultipleSetters extends MissingSetter { + Object getOtherObject(); + } + + @Test + public void testMissingMultipleSettersThrows() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "missing property methods on [com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MissingMultipleSetters]"); + expectedException.expectMessage("setter for property [object] of type [java.lang.Object]"); + expectedException.expectMessage("setter for property [otherObject] of type [java.lang.Object]"); + + PipelineOptionsFactory.as(MissingMultipleSetters.class); + } + + /** A test interface missing a setter and a getter. */ + public static interface MissingGettersAndSetters extends MissingGetter { + Object getOtherObject(); + } + + @Test + public void testMissingGettersAndSettersThrows() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "missing property methods on [com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MissingGettersAndSetters]"); + expectedException.expectMessage("getter for property [object] of type [java.lang.Object]"); + expectedException.expectMessage("setter for property [otherObject] of type [java.lang.Object]"); + + PipelineOptionsFactory.as(MissingGettersAndSetters.class); + } + + /** A test interface with a type mismatch between the getter and setter. */ + public static interface GetterSetterTypeMismatch extends PipelineOptions { + boolean getValue(); + void setValue(int value); + } + + @Test + public void testGetterSetterTypeMismatchThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Type mismatch between getter and setter methods for property [value]. Getter is of type " + + "[boolean] whereas setter is of type [int]."); + + PipelineOptionsFactory.as(GetterSetterTypeMismatch.class); + } + + /** A test interface with multiple type mismatches between getters and setters. 
*/ + public static interface MultiGetterSetterTypeMismatch extends GetterSetterTypeMismatch { + long getOther(); + void setOther(String other); + } + + @Test + public void testMultiGetterSetterTypeMismatchThrows() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Type mismatches between getters and setters detected:"); + expectedException.expectMessage("Property [value]: Getter is of type " + + "[boolean] whereas setter is of type [int]."); + expectedException.expectMessage("Property [other]: Getter is of type [long] " + + "whereas setter is of type [java.lang.String]."); + + PipelineOptionsFactory.as(MultiGetterSetterTypeMismatch.class); + } + + /** A test interface representing a composite interface. */ + public static interface CombinedObject extends MissingGetter, MissingSetter { + } + + @Test + public void testHavingSettersGettersFromSeparateInterfacesIsValid() { + PipelineOptionsFactory.as(CombinedObject.class); + } + + /** A test interface that contains a non-bean style method. */ + public static interface ExtraneousMethod extends PipelineOptions { + public String extraneousMethod(int value, String otherValue); + } + + @Test + public void testHavingExtraneousMethodThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Methods [extraneousMethod(int, String)] on " + + "[com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$ExtraneousMethod] " + + "do not conform to being bean properties."); + + PipelineOptionsFactory.as(ExtraneousMethod.class); + } + + /** A test interface that has a conflicting return type with its parent. */ + public static interface ReturnTypeConflict extends CombinedObject { + @Override + String getObject(); + void setObject(String value); + } + + @Test + public void testReturnTypeConflictThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Method [getObject] has multiple definitions [public abstract java.lang.Object " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingSetter" + + ".getObject(), public abstract java.lang.String " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$ReturnTypeConflict" + + ".getObject()] with different return types for [" + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$ReturnTypeConflict]."); + PipelineOptionsFactory.as(ReturnTypeConflict.class); + } + + /** An interface to provide multiple methods with return type conflicts. */ + public static interface MultiReturnTypeConflictBase extends CombinedObject { + Object getOther(); + void setOther(Object object); + } + + /** A test interface that has multiple conflicting return types with its parent. */ + public static interface MultiReturnTypeConflict extends MultiReturnTypeConflictBase { + @Override + String getObject(); + void setObject(String value); + + @Override + Long getOther(); + void setOther(Long other); + } + + @Test + public void testMultipleReturnTypeConflictsThrows() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("[com.google.cloud.dataflow.sdk.options." 
+ + "PipelineOptionsFactoryTest$MultiReturnTypeConflict]"); + expectedException.expectMessage( + "Methods with multiple definitions with different return types"); + expectedException.expectMessage("Method [getObject] has multiple definitions"); + expectedException.expectMessage("public abstract java.lang.Object " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$" + + "MissingSetter.getObject()"); + expectedException.expectMessage( + "public abstract java.lang.String com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultiReturnTypeConflict.getObject()"); + expectedException.expectMessage("Method [getOther] has multiple definitions"); + expectedException.expectMessage("public abstract java.lang.Object " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$" + + "MultiReturnTypeConflictBase.getOther()"); + expectedException.expectMessage( + "public abstract java.lang.Long com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultiReturnTypeConflict.getOther()"); + + PipelineOptionsFactory.as(MultiReturnTypeConflict.class); + } + + /** Test interface that has {@link JsonIgnore @JsonIgnore} on a setter for a property. */ + public static interface SetterWithJsonIgnore extends PipelineOptions { + String getValue(); + @JsonIgnore + void setValue(String value); + } + + @Test + public void testSetterAnnotatedWithJsonIgnore() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected setter for property [value] to not be marked with @JsonIgnore on [com." + + "google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$SetterWithJsonIgnore]"); + PipelineOptionsFactory.as(SetterWithJsonIgnore.class); + } + + /** Test interface that has {@link JsonIgnore @JsonIgnore} on multiple setters. */ + public static interface MultiSetterWithJsonIgnore extends SetterWithJsonIgnore { + Integer getOther(); + @JsonIgnore + void setOther(Integer other); + } + + @Test + public void testMultipleSettersAnnotatedWithJsonIgnore() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Found setters marked with @JsonIgnore:"); + expectedException.expectMessage( + "property [other] should not be marked with @JsonIgnore on [com" + + ".google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultiSetterWithJsonIgnore]"); + expectedException.expectMessage( + "property [value] should not be marked with @JsonIgnore on [com." + + "google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$SetterWithJsonIgnore]"); + PipelineOptionsFactory.as(MultiSetterWithJsonIgnore.class); + } + + /** + * This class is has a conflicting field with {@link CombinedObject} that doesn't have + * {@link JsonIgnore @JsonIgnore}. + */ + public static interface GetterWithJsonIgnore extends PipelineOptions { + @JsonIgnore + Object getObject(); + void setObject(Object value); + } + + @Test + public void testNotAllGettersAnnotatedWithJsonIgnore() throws Exception { + // Initial construction is valid. 
+ GetterWithJsonIgnore options = PipelineOptionsFactory.as(GetterWithJsonIgnore.class); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Expected getter for property [object] to be marked with @JsonIgnore on all [" + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$GetterWithJsonIgnore, " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MissingSetter], " + + "found only on [com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$GetterWithJsonIgnore]"); + + // When we attempt to convert, we should error at this moment. + options.as(CombinedObject.class); + } + + private static interface MultiGetters extends PipelineOptions { + Object getObject(); + void setObject(Object value); + + @JsonIgnore + Integer getOther(); + void setOther(Integer value); + + Void getConsistent(); + void setConsistent(Void consistent); + } + + private static interface MultipleGettersWithInconsistentJsonIgnore extends PipelineOptions { + @JsonIgnore + Object getObject(); + void setObject(Object value); + + Integer getOther(); + void setOther(Integer value); + + Void getConsistent(); + void setConsistent(Void consistent); + } + + @Test + public void testMultipleGettersWithInconsistentJsonIgnore() { + // Initial construction is valid. + MultiGetters options = PipelineOptionsFactory.as(MultiGetters.class); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Property getters are inconsistently marked with @JsonIgnore:"); + expectedException.expectMessage( + "property [object] to be marked on all"); + expectedException.expectMessage("found only on [com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultiGetters]"); + expectedException.expectMessage( + "property [other] to be marked on all"); + expectedException.expectMessage("found only on [com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultipleGettersWithInconsistentJsonIgnore]"); + + expectedException.expectMessage(Matchers.anyOf( + containsString(java.util.Arrays.toString(new String[] + {"com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultipleGettersWithInconsistentJsonIgnore", + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MultiGetters"})), + containsString(java.util.Arrays.toString(new String[] + {"com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryTest$MultiGetters", + "com.google.cloud.dataflow.sdk.options." + + "PipelineOptionsFactoryTest$MultipleGettersWithInconsistentJsonIgnore"})))); + expectedException.expectMessage(not(containsString("property [consistent]"))); + + // When we attempt to convert, we should error immediately + options.as(MultipleGettersWithInconsistentJsonIgnore.class); + } + + @Test + public void testAppNameIsNotOverriddenWhenPassedInViaCommandLine() { + ApplicationNameOptions options = PipelineOptionsFactory + .fromArgs(new String[]{ "--appName=testAppName" }) + .as(ApplicationNameOptions.class); + assertEquals("testAppName", options.getAppName()); + } + + @Test + public void testPropertyIsSetOnRegisteredPipelineOptionNotPartOfOriginalInterface() { + PipelineOptions options = PipelineOptionsFactory + .fromArgs(new String[]{ "--project=testProject" }) + .create(); + assertEquals("testProject", options.as(GcpOptions.class).getProject()); + } + + /** A test interface containing all the primitives. 
*/ + public static interface Primitives extends PipelineOptions { + boolean getBoolean(); + void setBoolean(boolean value); + char getChar(); + void setChar(char value); + byte getByte(); + void setByte(byte value); + short getShort(); + void setShort(short value); + int getInt(); + void setInt(int value); + long getLong(); + void setLong(long value); + float getFloat(); + void setFloat(float value); + double getDouble(); + void setDouble(double value); + } + + @Test + public void testPrimitives() { + String[] args = new String[] { + "--boolean=true", + "--char=d", + "--byte=12", + "--short=300", + "--int=100000", + "--long=123890123890", + "--float=55.5", + "--double=12.3"}; + + Primitives options = PipelineOptionsFactory.fromArgs(args).as(Primitives.class); + assertTrue(options.getBoolean()); + assertEquals('d', options.getChar()); + assertEquals((byte) 12, options.getByte()); + assertEquals((short) 300, options.getShort()); + assertEquals(100000, options.getInt()); + assertEquals(123890123890L, options.getLong()); + assertEquals(55.5f, options.getFloat(), 0.0f); + assertEquals(12.3, options.getDouble(), 0.0); + } + + @Test + public void testBooleanShorthandArgument() { + String[] args = new String[] {"--boolean"}; + + Primitives options = PipelineOptionsFactory.fromArgs(args).as(Primitives.class); + assertTrue(options.getBoolean()); + } + + @Test + public void testEmptyValueNotAllowed() { + String[] args = new String[] { + "--byte="}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Empty argument value is only allowed for String, String Array, and Collection"); + PipelineOptionsFactory.fromArgs(args).as(Primitives.class); + } + + /** Enum used for testing PipelineOptions CLI parsing. */ + public enum TestEnum { + Value, Value2 + } + + /** A test interface containing all supported objects. 
*/ + public static interface Objects extends PipelineOptions { + Boolean getBoolean(); + void setBoolean(Boolean value); + Character getChar(); + void setChar(Character value); + Byte getByte(); + void setByte(Byte value); + Short getShort(); + void setShort(Short value); + Integer getInt(); + void setInt(Integer value); + Long getLong(); + void setLong(Long value); + Float getFloat(); + void setFloat(Float value); + Double getDouble(); + void setDouble(Double value); + String getString(); + void setString(String value); + String getEmptyString(); + void setEmptyString(String value); + Class getClassValue(); + void setClassValue(Class value); + TestEnum getEnum(); + void setEnum(TestEnum value); + } + + @Test + public void testObjects() { + String[] args = new String[] { + "--boolean=true", + "--char=d", + "--byte=12", + "--short=300", + "--int=100000", + "--long=123890123890", + "--float=55.5", + "--double=12.3", + "--string=stringValue", + "--emptyString=", + "--classValue=" + PipelineOptionsFactoryTest.class.getName(), + "--enum=" + TestEnum.Value}; + + Objects options = PipelineOptionsFactory.fromArgs(args).as(Objects.class); + assertTrue(options.getBoolean()); + assertEquals(Character.valueOf('d'), options.getChar()); + assertEquals(Byte.valueOf((byte) 12), options.getByte()); + assertEquals(Short.valueOf((short) 300), options.getShort()); + assertEquals(Integer.valueOf(100000), options.getInt()); + assertEquals(Long.valueOf(123890123890L), options.getLong()); + assertEquals(Float.valueOf(55.5f), options.getFloat(), 0.0f); + assertEquals(Double.valueOf(12.3), options.getDouble(), 0.0); + assertEquals("stringValue", options.getString()); + assertTrue(options.getEmptyString().isEmpty()); + assertEquals(PipelineOptionsFactoryTest.class, options.getClassValue()); + assertEquals(TestEnum.Value, options.getEnum()); + } + + /** A test class for verifying JSON -> Object conversion. */ + public static class ComplexType { + String value; + String value2; + public ComplexType(@JsonProperty("key") String value, @JsonProperty("key2") String value2) { + this.value = value; + this.value2 = value2; + } + } + + /** A test interface for verifying JSON -> complex type conversion. */ + interface ComplexTypes extends PipelineOptions { + Map getMap(); + void setMap(Map value); + + ComplexType getObject(); + void setObject(ComplexType value); + } + + @Test + public void testComplexTypes() { + String[] args = new String[] { + "--map={\"key\":\"value\",\"key2\":\"value2\"}", + "--object={\"key\":\"value\",\"key2\":\"value2\"}"}; + ComplexTypes options = PipelineOptionsFactory.fromArgs(args).as(ComplexTypes.class); + assertEquals(ImmutableMap.of("key", "value", "key2", "value2"), options.getMap()); + assertEquals("value", options.getObject().value); + assertEquals("value2", options.getObject().value2); + } + + @Test + public void testMissingArgument() { + String[] args = new String[] {}; + + Objects options = PipelineOptionsFactory.fromArgs(args).as(Objects.class); + assertNull(options.getString()); + } + + /** A test interface containing all supported array return types. 
*/ + public static interface Arrays extends PipelineOptions { + boolean[] getBoolean(); + void setBoolean(boolean[] value); + char[] getChar(); + void setChar(char[] value); + short[] getShort(); + void setShort(short[] value); + int[] getInt(); + void setInt(int[] value); + long[] getLong(); + void setLong(long[] value); + float[] getFloat(); + void setFloat(float[] value); + double[] getDouble(); + void setDouble(double[] value); + String[] getString(); + void setString(String[] value); + Class[] getClassValue(); + void setClassValue(Class[] value); + TestEnum[] getEnum(); + void setEnum(TestEnum[] value); + } + + @Test + @SuppressWarnings("rawtypes") + public void testArrays() { + String[] args = new String[] { + "--boolean=true", + "--boolean=true", + "--boolean=false", + "--char=d", + "--char=e", + "--char=f", + "--short=300", + "--short=301", + "--short=302", + "--int=100000", + "--int=100001", + "--int=100002", + "--long=123890123890", + "--long=123890123891", + "--long=123890123892", + "--float=55.5", + "--float=55.6", + "--float=55.7", + "--double=12.3", + "--double=12.4", + "--double=12.5", + "--string=stringValue1", + "--string=stringValue2", + "--string=stringValue3", + "--classValue=" + PipelineOptionsFactory.class.getName(), + "--classValue=" + PipelineOptionsFactoryTest.class.getName(), + "--enum=" + TestEnum.Value, + "--enum=" + TestEnum.Value2}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + boolean[] bools = options.getBoolean(); + assertTrue(bools[0] && bools[1] && !bools[2]); + assertArrayEquals(new char[] {'d', 'e', 'f'}, options.getChar()); + assertArrayEquals(new short[] {300, 301, 302}, options.getShort()); + assertArrayEquals(new int[] {100000, 100001, 100002}, options.getInt()); + assertArrayEquals(new long[] {123890123890L, 123890123891L, 123890123892L}, options.getLong()); + assertArrayEquals(new float[] {55.5f, 55.6f, 55.7f}, options.getFloat(), 0.0f); + assertArrayEquals(new double[] {12.3, 12.4, 12.5}, options.getDouble(), 0.0); + assertArrayEquals(new String[] {"stringValue1", "stringValue2", "stringValue3"}, + options.getString()); + assertArrayEquals(new Class[] {PipelineOptionsFactory.class, + PipelineOptionsFactoryTest.class}, + options.getClassValue()); + assertArrayEquals(new TestEnum[] {TestEnum.Value, TestEnum.Value2}, options.getEnum()); + } + + @Test + @SuppressWarnings("rawtypes") + public void testEmptyInStringArrays() { + String[] args = new String[] { + "--string=", + "--string=", + "--string="}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + assertArrayEquals(new String[] {"", "", ""}, + options.getString()); + } + + @Test + @SuppressWarnings("rawtypes") + public void testEmptyInStringArraysWithCommaList() { + String[] args = new String[] { + "--string=a,,b"}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + assertArrayEquals(new String[] {"a", "", "b"}, + options.getString()); + } + + @Test + public void testEmptyInNonStringArrays() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Empty argument value is only allowed for String, String Array, and Collection"); + + String[] args = new String[] { + "--boolean=true", + "--boolean=", + "--boolean=false"}; + + PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + } + + @Test + public void testEmptyInNonStringArraysWithCommaList() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Empty argument value is 
only allowed for String, String Array, and Collection"); + + String[] args = new String[] { + "--int=1,,9"}; + PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + } + + @Test + public void testOutOfOrderArrays() { + String[] args = new String[] { + "--char=d", + "--boolean=true", + "--boolean=true", + "--char=e", + "--char=f", + "--boolean=false"}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + boolean[] bools = options.getBoolean(); + assertTrue(bools[0] && bools[1] && !bools[2]); + assertArrayEquals(new char[] {'d', 'e', 'f'}, options.getChar()); + } + + /** A test interface containing all supported List return types. */ + public static interface Lists extends PipelineOptions { + List getString(); + void setString(List value); + } + + @Test + public void testList() { + String[] args = + new String[] {"--string=stringValue1", "--string=stringValue2", "--string=stringValue3"}; + + Lists options = PipelineOptionsFactory.fromArgs(args).as(Lists.class); + assertEquals(ImmutableList.of("stringValue1", "stringValue2", "stringValue3"), + options.getString()); + } + + @Test + public void testListShorthand() { + String[] args = new String[] {"--string=stringValue1,stringValue2,stringValue3"}; + + Lists options = PipelineOptionsFactory.fromArgs(args).as(Lists.class); + assertEquals(ImmutableList.of("stringValue1", "stringValue2", "stringValue3"), + options.getString()); + } + + @Test + public void testMixedShorthandAndLongStyleList() { + String[] args = new String[] { + "--char=d", + "--char=e", + "--char=f", + "--char=g,h,i", + "--char=j", + "--char=k", + "--char=l", + "--char=m,n,o"}; + + Arrays options = PipelineOptionsFactory.fromArgs(args).as(Arrays.class); + assertArrayEquals(new char[] {'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o'}, + options.getChar()); + } + + @Test + public void testSetASingularAttributeUsingAListThrowsAnError() { + String[] args = new String[] { + "--diskSizeGb=100", + "--diskSizeGb=200"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("expected one element but was"); + PipelineOptionsFactory.fromArgs(args).create(); + } + + @Test + public void testSetASingularAttributeUsingAListIsIgnoredWithoutStrictParsing() { + String[] args = new String[] { + "--diskSizeGb=100", + "--diskSizeGb=200"}; + PipelineOptionsFactory.fromArgs(args).withoutStrictParsing().create(); + expectedLogs.verifyWarn("Strict parsing is disabled, ignoring option"); + } + + @Test + public void testSettingRunner() { + String[] args = new String[] {"--runner=BlockingDataflowPipelineRunner"}; + + PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(BlockingDataflowPipelineRunner.class, options.getRunner()); + } + + @Test + public void testSettingUnknownRunner() { + String[] args = new String[] {"--runner=UnknownRunner"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Unknown 'runner' specified 'UnknownRunner', supported " + + "pipeline runners [BlockingDataflowPipelineRunner, DataflowPipelineRunner, " + + "DirectPipelineRunner]"); + PipelineOptionsFactory.fromArgs(args).create(); + } + + @Test + public void testUsingArgumentWithUnknownPropertyIsNotAllowed() { + String[] args = new String[] {"--unknownProperty=value"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("missing a property named 'unknownProperty'"); + PipelineOptionsFactory.fromArgs(args).create(); + } + + interface 
SuggestedOptions extends PipelineOptions {
+    String getAbc();
+    void setAbc(String value);
+
+    String getAbcdefg();
+    void setAbcdefg(String value);
+  }
+
+  @Test
+  public void testUsingArgumentWithMisspelledPropertyGivesASuggestion() {
+    String[] args = new String[] {"--ab=value"};
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("missing a property named 'ab'. Did you mean 'abc'?");
+    PipelineOptionsFactory.fromArgs(args).as(SuggestedOptions.class);
+  }
+
+  @Test
+  public void testUsingArgumentWithMisspelledPropertyGivesMultipleSuggestions() {
+    String[] args = new String[] {"--abcde=value"};
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage(
+        "missing a property named 'abcde'. Did you mean one of [abc, abcdefg]?");
+    PipelineOptionsFactory.fromArgs(args).as(SuggestedOptions.class);
+  }
+
+  @Test
+  public void testUsingArgumentWithUnknownPropertyIsIgnoredWithoutStrictParsing() {
+    String[] args = new String[] {"--unknownProperty=value"};
+    PipelineOptionsFactory.fromArgs(args).withoutStrictParsing().create();
+    expectedLogs.verifyWarn("missing a property named 'unknownProperty'");
+  }
+
+  @Test
+  public void testUsingArgumentStartingWithIllegalCharacterIsNotAllowed() {
+    String[] args = new String[] {" --diskSizeGb=100"};
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("Argument ' --diskSizeGb=100' does not begin with '--'");
+    PipelineOptionsFactory.fromArgs(args).create();
+  }
+
+  @Test
+  public void testUsingArgumentStartingWithIllegalCharacterIsIgnoredWithoutStrictParsing() {
+    String[] args = new String[] {" --diskSizeGb=100"};
+    PipelineOptionsFactory.fromArgs(args).withoutStrictParsing().create();
+    expectedLogs.verifyWarn("Strict parsing is disabled, ignoring option");
+  }
+
+  @Test
+  public void testEmptyArgumentIsIgnored() {
+    String[] args = new String[] {"", "--diskSizeGb=100", "", "", "--runner=DirectPipelineRunner"};
+    PipelineOptionsFactory.fromArgs(args).create();
+  }
+
+  @Test
+  public void testNullArgumentIsIgnored() {
+    String[] args = new String[] {"--diskSizeGb=100", null, null, "--runner=DirectPipelineRunner"};
+    PipelineOptionsFactory.fromArgs(args).create();
+  }
+
+  @Test
+  public void testUsingArgumentWithInvalidNameIsNotAllowed() {
+    String[] args = new String[] {"--=100"};
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("Argument '--=100' starts with '--='");
+    PipelineOptionsFactory.fromArgs(args).create();
+  }
+
+  @Test
+  public void testUsingArgumentWithInvalidNameIsIgnoredWithoutStrictParsing() {
+    String[] args = new String[] {"--=100"};
+    PipelineOptionsFactory.fromArgs(args).withoutStrictParsing().create();
+    expectedLogs.verifyWarn("Strict parsing is disabled, ignoring option");
+  }
+
+  @Test
+  public void testWhenNoHelpIsRequested() {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ListMultimap<String, String> arguments = ArrayListMultimap.create();
+    assertFalse(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded(
+        arguments, new PrintStream(baos), false /* exit */));
+    String output = new String(baos.toByteArray());
+    assertEquals("", output);
+  }
+
+  @Test
+  public void testDefaultHelpAsArgument() {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ListMultimap<String, String> arguments = ArrayListMultimap.create();
+    arguments.put("help", "true");
+    assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded(
+        arguments, new PrintStream(baos), false /* exit */));
+    String output = new String(baos.toByteArray());
+    assertThat(output, containsString("The set of registered options are:"));
+    assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions"));
+    assertThat(output, containsString("Use --help=<OptionsName> for detailed help."));
+  }
+
+  @Test
+  public void testSpecificHelpAsArgument() {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ListMultimap<String, String> arguments = ArrayListMultimap.create();
+    arguments.put("help", "com.google.cloud.dataflow.sdk.options.PipelineOptions");
+    assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded(
+        arguments, new PrintStream(baos), false /* exit */));
+    String output = new String(baos.toByteArray());
+    assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions"));
+    assertThat(output, containsString("--runner"));
+    assertThat(output, containsString("Default: DirectPipelineRunner"));
+    assertThat(output,
+        containsString("The pipeline runner that will be used to execute the pipeline."));
+  }
+
+  @Test
+  public void testSpecificHelpAsArgumentWithSimpleClassName() {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ListMultimap<String, String> arguments = ArrayListMultimap.create();
+    arguments.put("help", "PipelineOptions");
+    assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded(
+        arguments, new PrintStream(baos), false /* exit */));
+    String output = new String(baos.toByteArray());
+    assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions"));
+    assertThat(output, containsString("--runner"));
+    assertThat(output, containsString("Default: DirectPipelineRunner"));
+    assertThat(output,
+        containsString("The pipeline runner that will be used to execute the pipeline."));
+  }
+
+  @Test
+  public void testSpecificHelpAsArgumentWithClassNameSuffix() {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ListMultimap<String, String> arguments = ArrayListMultimap.create();
+    arguments.put("help", "options.PipelineOptions");
+    assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded(
+        arguments, new PrintStream(baos), false /* exit */));
+    String output = new String(baos.toByteArray());
+    assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions"));
+    assertThat(output, containsString("--runner"));
+    assertThat(output, containsString("Default: DirectPipelineRunner"));
+    assertThat(output,
+        containsString("The pipeline runner that will be used to execute the pipeline."));
+  }
+
+  /** Used for a name collision test with the other DataflowPipelineOptions. */
+  private interface DataflowPipelineOptions extends PipelineOptions {
+  }
+
+  @Test
+  public void testShortnameSpecificHelpHasMultipleMatches() {
+    PipelineOptionsFactory.register(DataflowPipelineOptions.class);
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ListMultimap<String, String> arguments = ArrayListMultimap.create();
+    arguments.put("help", "DataflowPipelineOptions");
+    assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded(
+        arguments, new PrintStream(baos), false /* exit */));
+    String output = new String(baos.toByteArray());
+    assertThat(output, containsString("Multiple matches found for DataflowPipelineOptions"));
+    assertThat(output, containsString("com.google.cloud.dataflow.sdk.options."
+ + "PipelineOptionsFactoryTest$DataflowPipelineOptions")); + assertThat(output, containsString("The set of registered options are:")); + assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions")); + } + + @Test + public void testHelpWithOptionThatOutputsValidEnumTypes() { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ListMultimap arguments = ArrayListMultimap.create(); + arguments.put("help", "com.google.cloud.dataflow.sdk.options.DataflowWorkerLoggingOptions"); + assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded( + arguments, new PrintStream(baos), false /* exit */)); + String output = new String(baos.toByteArray()); + assertThat(output, containsString("")); + } + + @Test + public void testHelpWithBadOptionNameAsArgument() { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ListMultimap arguments = ArrayListMultimap.create(); + arguments.put("help", "com.google.cloud.dataflow.sdk.Pipeline"); + assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded( + arguments, new PrintStream(baos), false /* exit */)); + String output = new String(baos.toByteArray()); + assertThat(output, + containsString("Unable to find option com.google.cloud.dataflow.sdk.Pipeline")); + assertThat(output, containsString("The set of registered options are:")); + assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions")); + } + + @Test + public void testHelpWithHiddenMethodAndInterface() { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ListMultimap arguments = ArrayListMultimap.create(); + arguments.put("help", "com.google.cloud.dataflow.sdk.option.DataflowPipelineOptions"); + assertTrue(PipelineOptionsFactory.printHelpUsageAndExitIfNeeded( + arguments, new PrintStream(baos), false /* exit */)); + String output = new String(baos.toByteArray()); + // A hidden interface. + assertThat(output, not( + containsString("com.google.cloud.dataflow.sdk.options.DataflowPipelineDebugOptions"))); + // A hidden option. 
+ assertThat(output, not(containsString("--gcpCredential"))); + } + + @Test + public void testProgrammaticPrintHelp() { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + PipelineOptionsFactory.printHelp(new PrintStream(baos)); + String output = new String(baos.toByteArray()); + assertThat(output, containsString("The set of registered options are:")); + assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions")); + } + + @Test + public void testProgrammaticPrintHelpForSpecificType() { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + PipelineOptionsFactory.printHelp(new PrintStream(baos), PipelineOptions.class); + String output = new String(baos.toByteArray()); + assertThat(output, containsString("com.google.cloud.dataflow.sdk.options.PipelineOptions")); + assertThat(output, containsString("--runner")); + assertThat(output, containsString("Default: DirectPipelineRunner")); + assertThat(output, + containsString("The pipeline runner that will be used to execute the pipeline.")); + } + + @Test + public void testFindProperClassLoaderIfContextClassLoaderIsNull() throws InterruptedException { + final ClassLoader[] classLoader = new ClassLoader[1]; + Thread thread = new Thread(new Runnable() { + + @Override + public void run() { + classLoader[0] = PipelineOptionsFactory.findClassLoader(); + } + }); + thread.setContextClassLoader(null); + thread.start(); + thread.join(); + assertEquals(PipelineOptionsFactory.class.getClassLoader(), classLoader[0]); + } + + @Test + public void testFindProperClassLoaderIfContextClassLoaderIsAvailable() + throws InterruptedException { + final ClassLoader[] classLoader = new ClassLoader[1]; + Thread thread = new Thread(new Runnable() { + + @Override + public void run() { + classLoader[0] = PipelineOptionsFactory.findClassLoader(); + } + }); + ClassLoader cl = new ClassLoader() {}; + thread.setContextClassLoader(cl); + thread.start(); + thread.join(); + assertEquals(cl, classLoader[0]); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsTest.java new file mode 100644 index 000000000000..98e83980e2b9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsTest.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.List; +import java.util.Set; + +/** Unit tests for {@link PipelineOptions}. */ +@RunWith(JUnit4.class) +public class PipelineOptionsTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + /** Interfaces used for testing that {@link PipelineOptions#as(Class)} functions. */ + private static interface DerivedTestOptions extends BaseTestOptions { + int getDerivedValue(); + void setDerivedValue(int derivedValue); + + @Override + @JsonIgnore + Set getIgnoredValue(); + @Override + void setIgnoredValue(Set ignoredValue); + } + + private static interface ConflictedTestOptions extends BaseTestOptions { + String getDerivedValue(); + void setDerivedValue(String derivedValue); + + @Override + @JsonIgnore + Set getIgnoredValue(); + @Override + void setIgnoredValue(Set ignoredValue); + } + + private static interface BaseTestOptions extends PipelineOptions { + List getBaseValue(); + void setBaseValue(List baseValue); + + @JsonIgnore + Set getIgnoredValue(); + void setIgnoredValue(Set ignoredValue); + } + + @Test + public void testDynamicAs() { + BaseTestOptions options = PipelineOptionsFactory.create().as(BaseTestOptions.class); + assertNotNull(options); + } + + @Test + public void testDefaultRunnerIsSet() { + assertEquals(DirectPipelineRunner.class, PipelineOptionsFactory.create().getRunner()); + } + + @Test + public void testCloneAs() throws IOException { + DerivedTestOptions options = PipelineOptionsFactory.create().as(DerivedTestOptions.class); + options.setBaseValue(Lists.newArrayList()); + options.setIgnoredValue(Sets.newHashSet()); + options.getIgnoredValue().add("ignoredString"); + options.setDerivedValue(0); + + BaseTestOptions clonedOptions = options.cloneAs(BaseTestOptions.class); + assertNotSame(clonedOptions, options); + assertNotSame(clonedOptions.getBaseValue(), options.getBaseValue()); + + clonedOptions.getBaseValue().add(true); + assertFalse(clonedOptions.getBaseValue().isEmpty()); + assertTrue(options.getBaseValue().isEmpty()); + + assertNull(clonedOptions.getIgnoredValue()); + + ObjectMapper mapper = new ObjectMapper(); + mapper.readValue(mapper.writeValueAsBytes(clonedOptions), PipelineOptions.class); + } + + @Test + public void testCloneAsConflicted() throws Exception { + DerivedTestOptions options = PipelineOptionsFactory.create().as(DerivedTestOptions.class); + options.setBaseValue(Lists.newArrayList()); + options.setIgnoredValue(Sets.newHashSet()); + options.getIgnoredValue().add("ignoredString"); + options.setDerivedValue(0); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("incompatible return types"); + options.cloneAs(ConflictedTestOptions.class); + } +} + diff --git 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidatorTest.java new file mode 100644 index 000000000000..d6a3c189af41 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsValidatorTest.java @@ -0,0 +1,310 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.options; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PipelineOptionsValidator}. */ +@RunWith(JUnit4.class) +public class PipelineOptionsValidatorTest { + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + /** A test interface with an {@link Validation.Required} annotation. */ + public static interface Required extends PipelineOptions { + @Validation.Required + @Description("Fake Description") + String getObject(); + void setObject(String value); + } + + @Test + public void testWhenRequiredOptionIsSet() { + Required required = PipelineOptionsFactory.as(Required.class); + required.setObject("blah"); + PipelineOptionsValidator.validate(Required.class, required); + } + + @Test + public void testWhenRequiredOptionIsSetAndCleared() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject(), \"Fake Description\"]."); + + Required required = PipelineOptionsFactory.as(Required.class); + required.setObject("blah"); + required.setObject(null); + PipelineOptionsValidator.validate(Required.class, required); + } + + @Test + public void testWhenRequiredOptionIsNeverSet() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject(), \"Fake Description\"]."); + + Required required = PipelineOptionsFactory.as(Required.class); + PipelineOptionsValidator.validate(Required.class, required); + } + + @Test + public void testWhenRequiredOptionIsNeverSetOnSuperInterface() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject(), \"Fake Description\"]."); + + PipelineOptions options = PipelineOptionsFactory.create(); + PipelineOptionsValidator.validate(Required.class, options); + } + + /** A test interface that overrides the parent's method. 
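   * An inherited {@link Validation.Required} annotation still applies when the parent
   * type is used for validation; a minimal sketch (not part of the original test code):
   * <pre>{@code
   * SubClassValidation options = PipelineOptionsFactory.as(SubClassValidation.class);
   * // Throws IllegalArgumentException: getObject() is still required via the parent interface.
   * PipelineOptionsValidator.validate(Required.class, options);
   * }</pre>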
*/ + public static interface SubClassValidation extends Required { + @Override + String getObject(); + @Override + void setObject(String value); + } + + @Test + public void testValidationOnOverriddenMethods() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[public abstract java.lang.String com.google.cloud.dataflow." + + "sdk.options.PipelineOptionsValidatorTest$Required.getObject(), \"Fake Description\"]."); + + SubClassValidation required = PipelineOptionsFactory.as(SubClassValidation.class); + PipelineOptionsValidator.validate(Required.class, required); + } + + /** A test interface with a required group. */ + public static interface GroupRequired extends PipelineOptions { + @Validation.Required(groups = {"ham"}) + String getFoo(); + void setFoo(String foo); + + @Validation.Required(groups = {"ham"}) + String getBar(); + void setBar(String bar); + } + + @Test + public void testWhenOneOfRequiredGroupIsSetIsValid() { + GroupRequired groupRequired = PipelineOptionsFactory.as(GroupRequired.class); + groupRequired.setFoo("foo"); + groupRequired.setBar(null); + + PipelineOptionsValidator.validate(GroupRequired.class, groupRequired); + + // Symmetric + groupRequired.setFoo(null); + groupRequired.setBar("bar"); + PipelineOptionsValidator.validate(GroupRequired.class, groupRequired); + } + + @Test + public void testWhenNoneOfRequiredGroupIsSetThrowsException() { + GroupRequired groupRequired = PipelineOptionsFactory.as(GroupRequired.class); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for group [ham]"); + expectedException.expectMessage("properties"); + expectedException.expectMessage("getFoo"); + expectedException.expectMessage("getBar"); + + PipelineOptionsValidator.validate(GroupRequired.class, groupRequired); + } + + /** A test interface with a member in multiple required groups. 
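   * At validation time a group is satisfied as soon as any one of its member properties
   * is non-null, and a single property can satisfy every group it belongs to; a minimal
   * sketch (not part of the original test code):
   * <pre>{@code
   * MultiGroupRequired options = PipelineOptionsFactory.as(MultiGroupRequired.class);
   * options.setFoo("value");  // getFoo() is in both the "spam" and the "ham" group.
   * PipelineOptionsValidator.validate(MultiGroupRequired.class, options);  // passes
   * }</pre>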
*/ + public static interface MultiGroupRequired extends PipelineOptions { + @Validation.Required(groups = {"spam", "ham"}) + String getFoo(); + void setFoo(String foo); + + @Validation.Required(groups = {"spam"}) + String getBar(); + void setBar(String bar); + + @Validation.Required(groups = {"ham"}) + String getBaz(); + void setBaz(String baz); + } + + @Test + public void testWhenOneOfMultipleRequiredGroupsIsSetIsValid() { + MultiGroupRequired multiGroupRequired = PipelineOptionsFactory.as(MultiGroupRequired.class); + + multiGroupRequired.setFoo("eggs"); + + PipelineOptionsValidator.validate(MultiGroupRequired.class, multiGroupRequired); + } + + private static interface LeftOptions extends PipelineOptions { + @Validation.Required(groups = {"left"}) + String getFoo(); + void setFoo(String foo); + + @Validation.Required(groups = {"left"}) + String getLeft(); + void setLeft(String left); + + @Validation.Required(groups = {"both"}) + String getBoth(); + void setBoth(String both); + } + + private static interface RightOptions extends PipelineOptions { + @Validation.Required(groups = {"right"}) + String getFoo(); + void setFoo(String foo); + + @Validation.Required(groups = {"right"}) + String getRight(); + void setRight(String right); + + @Validation.Required(groups = {"both"}) + String getBoth(); + void setBoth(String both); + } + + private static interface JoinedOptions extends LeftOptions, RightOptions {} + + @Test + public void testWhenOptionIsDefinedInMultipleSuperInterfacesAndIsNotPresentFailsRequirement() { + RightOptions rightOptions = PipelineOptionsFactory.as(RightOptions.class); + rightOptions.setBoth("foo"); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for group"); + expectedException.expectMessage("getFoo"); + + PipelineOptionsValidator.validate(JoinedOptions.class, rightOptions); + } + + @Test + public void testWhenOptionIsDefinedInMultipleSuperInterfacesMeetsGroupRequirement() { + RightOptions rightOpts = PipelineOptionsFactory.as(RightOptions.class); + rightOpts.setFoo("true"); + rightOpts.setBoth("bar"); + + LeftOptions leftOpts = PipelineOptionsFactory.as(LeftOptions.class); + leftOpts.setFoo("Untrue"); + leftOpts.setBoth("Raise the"); + + PipelineOptionsValidator.validate(JoinedOptions.class, rightOpts); + PipelineOptionsValidator.validate(JoinedOptions.class, leftOpts); + } + + @Test + public void testWhenOptionIsDefinedOnOtherOptionsClassMeetsGroupRequirement() { + RightOptions rightOpts = PipelineOptionsFactory.as(RightOptions.class); + rightOpts.setFoo("true"); + rightOpts.setBoth("bar"); + + LeftOptions leftOpts = PipelineOptionsFactory.as(LeftOptions.class); + leftOpts.setFoo("Untrue"); + leftOpts.setBoth("Raise the"); + + PipelineOptionsValidator.validate(RightOptions.class, leftOpts); + PipelineOptionsValidator.validate(LeftOptions.class, rightOpts); + } + + @Test + public void testWhenOptionIsDefinedOnMultipleInterfacesOnlyListedOnceWhenNotPresent() { + JoinedOptions options = PipelineOptionsFactory.as(JoinedOptions.class); + options.setFoo("Hello"); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("required value for group [both]"); + expectedException.expectMessage("properties [getBoth()]"); + + PipelineOptionsValidator.validate(JoinedOptions.class, options); + } + + private static interface SuperOptions extends PipelineOptions { + @Validation.Required(groups = {"super"}) + String getFoo(); + void setFoo(String foo); + + 
@Validation.Required(groups = {"sub"}) + String getBar(); + void setBar(String bar); + + @Validation.Required(groups = {"otherSuper"}) + String getSuperclassObj(); + void setSuperclassObj(String sup); + } + + private static interface SubOptions extends SuperOptions { + @Override + @Validation.Required(groups = {"sub"}) + String getFoo(); + @Override + void setFoo(String foo); + + @Override + String getSuperclassObj(); + @Override + void setSuperclassObj(String sup); + } + + @Test + public void testSuperInterfaceRequiredOptionsAlsoRequiredInSubInterface() { + SubOptions subOpts = PipelineOptionsFactory.as(SubOptions.class); + subOpts.setFoo("Bar"); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("otherSuper"); + expectedException.expectMessage("Missing required value"); + expectedException.expectMessage("getSuperclassObj"); + + PipelineOptionsValidator.validate(SubOptions.class, subOpts); + } + + @Test + public void + testSuperInterfaceGroupIsInAdditionToSubInterfaceGroupOnlyWhenValidatingSuperInterface() { + SubOptions opts = PipelineOptionsFactory.as(SubOptions.class); + opts.setFoo("Foo"); + opts.setSuperclassObj("Hello world"); + + // Valid SubOptions, but invalid SuperOptions + PipelineOptionsValidator.validate(SubOptions.class, opts); + + expectedException.expectMessage("sub"); + expectedException.expectMessage("Missing required value"); + expectedException.expectMessage("getBar"); + PipelineOptionsValidator.validate(SuperOptions.class, opts); + } + + @Test + public void testSuperInterfaceRequiredOptionsSatisfiedBySubInterface() { + SubOptions subOpts = PipelineOptionsFactory.as(SubOptions.class); + subOpts.setFoo("bar"); + subOpts.setBar("bar"); + subOpts.setSuperclassObj("SuperDuper"); + + PipelineOptionsValidator.validate(SubOptions.class, subOpts); + PipelineOptionsValidator.validate(SuperOptions.class, subOpts); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandlerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandlerTest.java new file mode 100644 index 000000000000..7eb5822b436e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/ProxyInvocationHandlerTest.java @@ -0,0 +1,691 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.options; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** Tests for {@link ProxyInvocationHandler}. */ +@RunWith(JUnit4.class) +public class ProxyInvocationHandlerTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + /** A test interface with some primitives and objects. */ + public static interface Simple extends PipelineOptions { + boolean isOptionEnabled(); + void setOptionEnabled(boolean value); + int getPrimitive(); + void setPrimitive(int value); + String getString(); + void setString(String value); + } + + @Test + public void testPropertySettingAndGetting() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + proxy.setString("OBJECT"); + proxy.setOptionEnabled(true); + proxy.setPrimitive(4); + assertEquals("OBJECT", proxy.getString()); + assertTrue(proxy.isOptionEnabled()); + assertEquals(4, proxy.getPrimitive()); + } + + /** A test interface containing all the JLS default values. */ + public static interface JLSDefaults extends PipelineOptions { + boolean getBoolean(); + void setBoolean(boolean value); + char getChar(); + void setChar(char value); + byte getByte(); + void setByte(byte value); + short getShort(); + void setShort(short value); + int getInt(); + void setInt(int value); + long getLong(); + void setLong(long value); + float getFloat(); + void setFloat(float value); + double getDouble(); + void setDouble(double value); + Object getObject(); + void setObject(Object value); + } + + @Test + public void testGettingJLSDefaults() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + JLSDefaults proxy = handler.as(JLSDefaults.class); + assertFalse(proxy.getBoolean()); + assertEquals('\0', proxy.getChar()); + assertEquals((byte) 0, proxy.getByte()); + assertEquals((short) 0, proxy.getShort()); + assertEquals(0, proxy.getInt()); + assertEquals(0L, proxy.getLong()); + assertEquals(0f, proxy.getFloat(), 0f); + assertEquals(0d, proxy.getDouble(), 0d); + assertNull(proxy.getObject()); + } + + /** A {@link DefaultValueFactory} that is used for testing. */ + public static class TestOptionFactory implements DefaultValueFactory { + @Override + public String create(PipelineOptions options) { + return "testOptionFactory"; + } + } + + /** A test enum for testing {@link Default.Enum @Default.Enum}. 
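   * The string given to {@link Default.Enum @Default.Enum} is matched against the enum
   * constant's name rather than its {@code toString()} value, which is why this constant
   * deliberately uses two different spellings; an illustrative property declaration
   * (mirroring the one in {@code DefaultAnnotations} below):
   * <pre>{@code
   * @Default.Enum("MyEnum")   // resolves to EnumType.MyEnum, whose toString() is "MyTestEnum"
   * EnumType getEnum();
   * void setEnum(EnumType value);
   * }</pre>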
*/ + public enum EnumType { + MyEnum("MyTestEnum"); + + private final String value; + private EnumType(String value) { + this.value = value; + } + + @Override + public String toString() { + return value; + } + } + + /** A test interface containing all the {@link Default} annotations. */ + public static interface DefaultAnnotations extends PipelineOptions { + @Default.Boolean(true) + boolean getBoolean(); + void setBoolean(boolean value); + @Default.Character('a') + char getChar(); + void setChar(char value); + @Default.Byte((byte) 4) + byte getByte(); + void setByte(byte value); + @Default.Short((short) 5) + short getShort(); + void setShort(short value); + @Default.Integer(6) + int getInt(); + void setInt(int value); + @Default.Long(7L) + long getLong(); + void setLong(long value); + @Default.Float(8f) + float getFloat(); + void setFloat(float value); + @Default.Double(9d) + double getDouble(); + void setDouble(double value); + @Default.String("testString") + String getString(); + void setString(String value); + @Default.Class(DefaultAnnotations.class) + Class getClassOption(); + void setClassOption(Class value); + @Default.Enum("MyEnum") + EnumType getEnum(); + void setEnum(EnumType value); + @Default.InstanceFactory(TestOptionFactory.class) + String getComplex(); + void setComplex(String value); + } + + @Test + public void testAnnotationDefaults() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + DefaultAnnotations proxy = handler.as(DefaultAnnotations.class); + assertTrue(proxy.getBoolean()); + assertEquals('a', proxy.getChar()); + assertEquals((byte) 4, proxy.getByte()); + assertEquals((short) 5, proxy.getShort()); + assertEquals(6, proxy.getInt()); + assertEquals(7, proxy.getLong()); + assertEquals(8f, proxy.getFloat(), 0f); + assertEquals(9d, proxy.getDouble(), 0d); + assertEquals("testString", proxy.getString()); + assertEquals(DefaultAnnotations.class, proxy.getClassOption()); + assertEquals(EnumType.MyEnum, proxy.getEnum()); + assertEquals("testOptionFactory", proxy.getComplex()); + } + + @Test + public void testEquals() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + JLSDefaults sameAsProxy = proxy.as(JLSDefaults.class); + ProxyInvocationHandler handler2 = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy2 = handler2.as(Simple.class); + JLSDefaults sameAsProxy2 = proxy2.as(JLSDefaults.class); + assertTrue(handler.equals(proxy)); + assertTrue(proxy.equals(proxy)); + assertTrue(proxy.equals(sameAsProxy)); + assertFalse(handler.equals(handler2)); + assertFalse(proxy.equals(proxy2)); + assertFalse(proxy.equals(sameAsProxy2)); + } + + @Test + public void testHashCode() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + JLSDefaults sameAsProxy = proxy.as(JLSDefaults.class); + + ProxyInvocationHandler handler2 = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy2 = handler2.as(Simple.class); + JLSDefaults sameAsProxy2 = proxy2.as(JLSDefaults.class); + + // Hashcode comparisons below depend on random numbers, so could fail if seed changes. 
+ assertTrue(handler.hashCode() == proxy.hashCode()); + assertTrue(proxy.hashCode() == sameAsProxy.hashCode()); + assertFalse(handler.hashCode() == handler2.hashCode()); + assertFalse(proxy.hashCode() == proxy2.hashCode()); + assertFalse(proxy.hashCode() == sameAsProxy2.hashCode()); + } + + @Test + public void testToString() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + proxy.setString("stringValue"); + DefaultAnnotations proxy2 = proxy.as(DefaultAnnotations.class); + proxy2.setLong(57L); + assertEquals("Current Settings:\n" + + " long: 57\n" + + " string: stringValue\n", + proxy.toString()); + } + + @Test + public void testToStringAfterDeserializationContainsJsonEntries() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + proxy.setString("stringValue"); + DefaultAnnotations proxy2 = proxy.as(DefaultAnnotations.class); + proxy2.setLong(57L); + assertEquals("Current Settings:\n" + + " long: 57\n" + + " string: \"stringValue\"\n", + serializeDeserialize(PipelineOptions.class, proxy2).toString()); + } + + @Test + public void testToStringAfterDeserializationContainsOverriddenEntries() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + Simple proxy = handler.as(Simple.class); + proxy.setString("stringValue"); + DefaultAnnotations proxy2 = proxy.as(DefaultAnnotations.class); + proxy2.setLong(57L); + Simple deserializedOptions = serializeDeserialize(Simple.class, proxy2); + deserializedOptions.setString("overriddenValue"); + assertEquals("Current Settings:\n" + + " long: 57\n" + + " string: overriddenValue\n", + deserializedOptions.toString()); + } + + /** A test interface containing an unknown method. */ + public static interface UnknownMethod { + void unknownMethod(); + } + + @Test + public void testInvokeWithUnknownMethod() throws Exception { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Unknown method [public abstract void com.google.cloud." + + "dataflow.sdk.options.ProxyInvocationHandlerTest$UnknownMethod.unknownMethod()] " + + "invoked with args [null]."); + + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + handler.invoke(handler, UnknownMethod.class.getMethod("unknownMethod"), null); + } + + /** A test interface that extends another interface. 
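   * Property values live in the shared {@link ProxyInvocationHandler} rather than in any
   * particular interface view, so a value written through one view is visible through every
   * other view reached via {@code as()}; a minimal sketch (not original test code):
   * <pre>{@code
   * SubClass extended = PipelineOptionsFactory.as(SubClass.class);
   * extended.setString("shared");
   * assertEquals("shared", extended.as(Simple.class).getString());
   * }</pre>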
*/ + public static interface SubClass extends Simple { + String getExtended(); + void setExtended(String value); + } + + @Test + public void testSubClassStoresSuperInterfaceValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setString("parentValue"); + assertEquals("parentValue", extended.getString()); + } + + @Test + public void testUpCastRetainsSuperInterfaceValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setString("parentValue"); + Simple simple = extended.as(Simple.class); + assertEquals("parentValue", simple.getString()); + } + + @Test + public void testUpCastRetainsSubClassValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setExtended("subClassValue"); + SubClass extended2 = extended.as(Simple.class).as(SubClass.class); + assertEquals("subClassValue", extended2.getExtended()); + } + + /** A test interface that is a sibling to {@link SubClass}. */ + public static interface Sibling extends Simple { + String getSibling(); + void setSibling(String value); + } + + @Test + public void testAsSiblingRetainsSuperInterfaceValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SubClass extended = handler.as(SubClass.class); + + extended.setString("parentValue"); + Sibling sibling = extended.as(Sibling.class); + assertEquals("parentValue", sibling.getString()); + } + + /** A test interface that has the same methods as the parent. */ + public static interface MethodConflict extends Simple { + @Override + String getString(); + @Override + void setString(String value); + } + + @Test + public void testMethodConflictProvidesSameValue() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + MethodConflict methodConflict = handler.as(MethodConflict.class); + + methodConflict.setString("conflictValue"); + assertEquals("conflictValue", methodConflict.getString()); + assertEquals("conflictValue", methodConflict.as(Simple.class).getString()); + } + + /** A test interface that has the same methods as its parent and grandparent. 
*/ + public static interface DeepMethodConflict extends MethodConflict { + @Override + String getString(); + @Override + void setString(String value); + @Override + int getPrimitive(); + @Override + void setPrimitive(int value); + } + + @Test + public void testDeepMethodConflictProvidesSameValue() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + DeepMethodConflict deepMethodConflict = handler.as(DeepMethodConflict.class); + + // Tests overriding an already overridden method + deepMethodConflict.setString("conflictValue"); + assertEquals("conflictValue", deepMethodConflict.getString()); + assertEquals("conflictValue", deepMethodConflict.as(MethodConflict.class).getString()); + assertEquals("conflictValue", deepMethodConflict.as(Simple.class).getString()); + + // Tests overriding a method from an ancestor class + deepMethodConflict.setPrimitive(5); + assertEquals(5, deepMethodConflict.getPrimitive()); + assertEquals(5, deepMethodConflict.as(MethodConflict.class).getPrimitive()); + assertEquals(5, deepMethodConflict.as(Simple.class).getPrimitive()); + } + + /** A test interface that shares the same methods as {@link Sibling}. */ + public static interface SimpleSibling extends PipelineOptions { + String getString(); + void setString(String value); + } + + @Test + public void testDisjointSiblingsShareValues() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SimpleSibling proxy = handler.as(SimpleSibling.class); + proxy.setString("siblingValue"); + assertEquals("siblingValue", proxy.getString()); + assertEquals("siblingValue", proxy.as(Simple.class).getString()); + } + + /** A test interface that joins two sibling interfaces that have conflicting methods. */ + public static interface SiblingMethodConflict extends Simple, SimpleSibling { + } + + @Test + public void testSiblingMethodConflict() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + SiblingMethodConflict siblingMethodConflict = handler.as(SiblingMethodConflict.class); + siblingMethodConflict.setString("siblingValue"); + assertEquals("siblingValue", siblingMethodConflict.getString()); + assertEquals("siblingValue", siblingMethodConflict.as(Simple.class).getString()); + assertEquals("siblingValue", siblingMethodConflict.as(SimpleSibling.class).getString()); + } + + /** A test interface that has only the getter and only a setter overriden. 
*/ + public static interface PartialMethodConflict extends Simple { + @Override + String getString(); + @Override + void setPrimitive(int value); + } + + @Test + public void testPartialMethodConflictProvidesSameValue() throws Exception { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + PartialMethodConflict partialMethodConflict = handler.as(PartialMethodConflict.class); + + // Tests overriding a getter property that is only partially bound + partialMethodConflict.setString("conflictValue"); + assertEquals("conflictValue", partialMethodConflict.getString()); + assertEquals("conflictValue", partialMethodConflict.as(Simple.class).getString()); + + // Tests overriding a setter property that is only partially bound + partialMethodConflict.setPrimitive(5); + assertEquals(5, partialMethodConflict.getPrimitive()); + assertEquals(5, partialMethodConflict.as(Simple.class).getPrimitive()); + } + + @Test + public void testJsonConversionForDefault() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + assertNotNull(serializeDeserialize(PipelineOptions.class, options)); + } + + /** Test interface for JSON conversion of simple types. */ + private static interface SimpleTypes extends PipelineOptions { + int getInteger(); + void setInteger(int value); + String getString(); + void setString(String value); + } + + @Test + public void testJsonConversionForSimpleTypes() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setString("TestValue"); + options.setInteger(5); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, options); + assertEquals(5, options2.getInteger()); + assertEquals("TestValue", options2.getString()); + } + + @Test + public void testJsonConversionOfAJsonConvertedType() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setString("TestValue"); + options.setInteger(5); + // It is important here that our first serialization goes to our most basic + // type so that we handle the case when we don't know the types of certain + // properties because the intermediate instance of PipelineOptions never + // saw their interface. + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, + serializeDeserialize(PipelineOptions.class, options)); + assertEquals(5, options2.getInteger()); + assertEquals("TestValue", options2.getString()); + } + + @Test + public void testJsonConversionForPartiallySerializedValues() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setInteger(5); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, options); + options2.setString("TestValue"); + SimpleTypes options3 = serializeDeserialize(SimpleTypes.class, options2); + assertEquals(5, options3.getInteger()); + assertEquals("TestValue", options3.getString()); + } + + @Test + public void testJsonConversionForOverriddenSerializedValues() throws Exception { + SimpleTypes options = PipelineOptionsFactory.as(SimpleTypes.class); + options.setInteger(-5); + options.setString("NeedsToBeOverridden"); + SimpleTypes options2 = serializeDeserialize(SimpleTypes.class, options); + options2.setInteger(5); + options2.setString("TestValue"); + SimpleTypes options3 = serializeDeserialize(SimpleTypes.class, options2); + assertEquals(5, options3.getInteger()); + assertEquals("TestValue", options3.getString()); + } + + /** Test interface for JSON conversion of container types. 
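   * The JSON round trip used throughout these tests is plain Jackson against the
   * {@link PipelineOptions} base type, followed by a cast back via {@code as()}; roughly
   * (an illustrative sketch of the {@code serializeDeserialize} helper defined below):
   * <pre>{@code
   * ObjectMapper mapper = new ObjectMapper();
   * String json = mapper.writeValueAsString(options);
   * ContainerTypes restored =
   *     mapper.readValue(json, PipelineOptions.class).as(ContainerTypes.class);
   * }</pre>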
*/ + private static interface ContainerTypes extends PipelineOptions { + List getList(); + void setList(List values); + Map getMap(); + void setMap(Map values); + Set getSet(); + void setSet(Set values); + } + + @Test + public void testJsonConversionForContainerTypes() throws Exception { + List list = ImmutableList.of("a", "b", "c"); + Map map = ImmutableMap.of("d", "x", "e", "y", "f", "z"); + Set set = ImmutableSet.of("g", "h", "i"); + ContainerTypes options = PipelineOptionsFactory.as(ContainerTypes.class); + options.setList(list); + options.setMap(map); + options.setSet(set); + ContainerTypes options2 = serializeDeserialize(ContainerTypes.class, options); + assertEquals(list, options2.getList()); + assertEquals(map, options2.getMap()); + assertEquals(set, options2.getSet()); + } + + /** Test interface for conversion of inner types. */ + private static class InnerType { + public double doubleField; + + static InnerType of(double value) { + InnerType rval = new InnerType(); + rval.doubleField = value; + return rval; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + return obj != null + && getClass().equals(obj.getClass()) + && Objects.equals(doubleField, ((InnerType) obj).doubleField); + } + } + + /** Test interface for conversion of generics and inner types. */ + private static class ComplexType { + public String stringField; + public Integer intField; + public List genericType; + public InnerType innerType; + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + return obj != null + && getClass().equals(obj.getClass()) + && Objects.equals(stringField, ((ComplexType) obj).stringField) + && Objects.equals(intField, ((ComplexType) obj).intField) + && Objects.equals(genericType, ((ComplexType) obj).genericType) + && Objects.equals(innerType, ((ComplexType) obj).innerType); + } + } + + private static interface ComplexTypes extends PipelineOptions { + ComplexType getComplexType(); + void setComplexType(ComplexType value); + } + + @Test + public void testJsonConversionForComplexType() throws Exception { + ComplexType complexType = new ComplexType(); + complexType.stringField = "stringField"; + complexType.intField = 12; + complexType.innerType = InnerType.of(12); + complexType.genericType = ImmutableList.of(InnerType.of(16234), InnerType.of(24)); + + ComplexTypes options = PipelineOptionsFactory.as(ComplexTypes.class); + options.setComplexType(complexType); + ComplexTypes options2 = serializeDeserialize(ComplexTypes.class, options); + assertEquals(complexType, options2.getComplexType()); + } + + /** Test interface for testing ignored properties during serialization. */ + private static interface IgnoredProperty extends PipelineOptions { + @JsonIgnore + String getValue(); + void setValue(String value); + } + + @Test + public void testJsonConversionOfIgnoredProperty() throws Exception { + IgnoredProperty options = PipelineOptionsFactory.as(IgnoredProperty.class); + options.setValue("TestValue"); + + IgnoredProperty options2 = serializeDeserialize(IgnoredProperty.class, options); + assertNull(options2.getValue()); + } + + /** Test class that is not serializable by Jackson. */ + public static class NotSerializable { + private String value; + public NotSerializable(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** Test interface containing a class that is not serializable by Jackson. 
*/ + private static interface NotSerializableProperty extends PipelineOptions { + NotSerializable getValue(); + void setValue(NotSerializable value); + } + + @Test + public void testJsonConversionOfNotSerializableProperty() throws Exception { + NotSerializableProperty options = PipelineOptionsFactory.as(NotSerializableProperty.class); + options.setValue(new NotSerializable("TestString")); + + expectedException.expect(JsonMappingException.class); + expectedException.expectMessage("Failed to serialize and deserialize property 'value'"); + serializeDeserialize(NotSerializableProperty.class, options); + } + + /** + * Test interface that has {@link JsonIgnore @JsonIgnore} on a property that Jackson + * can't serialize. + */ + private static interface IgnoredNotSerializableProperty extends PipelineOptions { + @JsonIgnore + NotSerializable getValue(); + void setValue(NotSerializable value); + } + + @Test + public void testJsonConversionOfIgnoredNotSerializableProperty() throws Exception { + IgnoredNotSerializableProperty options = + PipelineOptionsFactory.as(IgnoredNotSerializableProperty.class); + options.setValue(new NotSerializable("TestString")); + + IgnoredNotSerializableProperty options2 = + serializeDeserialize(IgnoredNotSerializableProperty.class, options); + assertNull(options2.getValue()); + } + + /** Test class that is only serializable by Jackson with the added metadata. */ + public static class SerializableWithMetadata { + private String value; + public SerializableWithMetadata(@JsonProperty("value") String value) { + this.value = value; + } + + @JsonProperty("value") + public String getValue() { + return value; + } + } + + /** + * Test interface containing a property that is serializable by Jackson only with + * the additional metadata. + */ + private static interface SerializableWithMetadataProperty extends PipelineOptions { + SerializableWithMetadata getValue(); + void setValue(SerializableWithMetadata value); + } + + @Test + public void testJsonConversionOfSerializableWithMetadataProperty() throws Exception { + SerializableWithMetadataProperty options = + PipelineOptionsFactory.as(SerializableWithMetadataProperty.class); + options.setValue(new SerializableWithMetadata("TestString")); + + SerializableWithMetadataProperty options2 = + serializeDeserialize(SerializableWithMetadataProperty.class, options); + assertEquals("TestString", options2.getValue().getValue()); + } + + private T serializeDeserialize(Class kls, PipelineOptions options) + throws Exception { + ObjectMapper mapper = new ObjectMapper(); + String value = mapper.writeValueAsString(options); + return mapper.readValue(value, PipelineOptions.class).as(kls); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/AggregatorPipelineExtractorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/AggregatorPipelineExtractorTest.java new file mode 100644 index 000000000000..04c7edef944e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/AggregatorPipelineExtractorTest.java @@ -0,0 +1,228 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * Tests for {@link AggregatorPipelineExtractor}. 
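 * <p>The pipeline is mocked rather than built: each test stubs
 * {@code Pipeline#traverseTopologically} so that it replays a fixed list of
 * {@link TransformTreeNode}s to the visitor, roughly (illustrative sketch):
 * <pre>{@code
 * doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode)))
 *     .when(p)
 *     .traverseTopologically(Mockito.any(PipelineVisitor.class));
 * }</pre>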
+ */ +@RunWith(JUnit4.class) +public class AggregatorPipelineExtractorTest { + @Mock + private Pipeline p; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithParDoBoundExtractsSteps() { + @SuppressWarnings("rawtypes") + ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); + AggregatorProvidingDoFn fn = new AggregatorProvidingDoFn<>(); + when(bound.getFn()).thenReturn(fn); + + Aggregator aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); + Aggregator aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn()); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map, Collection>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals(ImmutableSet.>of(bound), aggregatorSteps.get(aggregatorOne)); + assertEquals(ImmutableSet.>of(bound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(aggregatorSteps.size(), 2); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithParDoBoundMultiExtractsSteps() { + @SuppressWarnings("rawtypes") + ParDo.BoundMulti bound = mock(ParDo.BoundMulti.class, "BoundMulti"); + AggregatorProvidingDoFn fn = new AggregatorProvidingDoFn<>(); + when(bound.getFn()).thenReturn(fn); + + Aggregator aggregatorOne = fn.addAggregator(new Max.MaxLongFn()); + Aggregator aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn()); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map, Collection>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals(ImmutableSet.>of(bound), aggregatorSteps.get(aggregatorOne)); + assertEquals(ImmutableSet.>of(bound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(2, aggregatorSteps.size()); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() { + @SuppressWarnings("rawtypes") + ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); + @SuppressWarnings("rawtypes") + ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound"); + AggregatorProvidingDoFn fn = new AggregatorProvidingDoFn<>(); + when(bound.getFn()).thenReturn(fn); + when(otherBound.getFn()).thenReturn(fn); + + Aggregator aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); + Aggregator aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn()); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + TransformTreeNode otherTransformNode = mock(TransformTreeNode.class); + when(otherTransformNode.getTransform()).thenReturn(otherBound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map, Collection>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals( + 
ImmutableSet.>of(bound, otherBound), aggregatorSteps.get(aggregatorOne)); + assertEquals( + ImmutableSet.>of(bound, otherBound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(2, aggregatorSteps.size()); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithDifferentStepsAddsSteps() { + @SuppressWarnings("rawtypes") + ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); + + AggregatorProvidingDoFn fn = new AggregatorProvidingDoFn<>(); + Aggregator aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); + + when(bound.getFn()).thenReturn(fn); + + @SuppressWarnings("rawtypes") + ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound"); + + AggregatorProvidingDoFn otherFn = new AggregatorProvidingDoFn<>(); + Aggregator aggregatorTwo = otherFn.addAggregator(new Sum.SumDoubleFn()); + + when(otherBound.getFn()).thenReturn(otherFn); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + TransformTreeNode otherTransformNode = mock(TransformTreeNode.class); + when(otherTransformNode.getTransform()).thenReturn(otherBound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map, Collection>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals(ImmutableSet.>of(bound), aggregatorSteps.get(aggregatorOne)); + assertEquals(ImmutableSet.>of(otherBound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(2, aggregatorSteps.size()); + } + + private static class VisitNodesAnswer implements Answer { + private final List nodes; + + public VisitNodesAnswer(List nodes) { + this.nodes = nodes; + } + + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0]; + for (TransformTreeNode node : nodes) { + visitor.visitTransform(node); + } + return null; + } + } + + private static class AggregatorProvidingDoFn extends DoFn { + public Aggregator addAggregator( + CombineFn combiner) { + return createAggregator(randomName(), combiner); + } + + private String randomName() { + return UUID.randomUUID().toString(); + } + + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + fail(); + } + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunnerTest.java new file mode 100644 index 000000000000..5b1748c32560 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/BlockingDataflowPipelineRunnerTest.java @@ -0,0 +1,301 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.testing.TestDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.NoopPathValidator; +import com.google.cloud.dataflow.sdk.util.TestCredential; + +import org.hamcrest.Description; +import org.hamcrest.Factory; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.concurrent.TimeUnit; + +/** + * Tests for BlockingDataflowPipelineRunner. + */ +@RunWith(JUnit4.class) +public class BlockingDataflowPipelineRunnerTest { + + @Rule + public ExpectedLogs expectedLogs = ExpectedLogs.none(BlockingDataflowPipelineRunner.class); + + @Rule + public ExpectedException expectedThrown = ExpectedException.none(); + + /** + * A {@link Matcher} for a {@link DataflowJobException} that applies an underlying {@link Matcher} + * to the {@link DataflowPipelineJob} returned by {@link DataflowJobException#getJob()}. + */ + private static class DataflowJobExceptionMatcher + extends TypeSafeMatcher { + + private final Matcher matcher; + + public DataflowJobExceptionMatcher(Matcher matcher) { + this.matcher = matcher; + } + + @Override + public boolean matchesSafely(T ex) { + return matcher.matches(ex.getJob()); + } + + @Override + protected void describeMismatchSafely(T item, Description description) { + description.appendText("job "); + matcher.describeMismatch(item.getMessage(), description); + } + + @Override + public void describeTo(Description description) { + description.appendText("exception with job matching "); + description.appendDescriptionOf(matcher); + } + + @Factory + public static Matcher expectJob( + Matcher matcher) { + return new DataflowJobExceptionMatcher(matcher); + } + } + + /** + * A {@link Matcher} for a {@link DataflowPipelineJob} that applies an underlying {@link Matcher} + * to the return value of {@link DataflowPipelineJob#getJobId()}. 
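   * Combined with {@link DataflowJobExceptionMatcher}, this lets a test pin a failure to a
   * specific job; an illustrative expectation (the job id here is invented for the example):
   * <pre>{@code
   * expectedThrown.expect(DataflowJobExceptionMatcher.expectJob(
   *     JobIdMatcher.expectJobId("example-job-id")));
   * }</pre>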
+ */ + private static class JobIdMatcher extends TypeSafeMatcher { + + private final Matcher matcher; + + public JobIdMatcher(Matcher matcher) { + this.matcher = matcher; + } + + @Override + public boolean matchesSafely(T job) { + return matcher.matches(job.getJobId()); + } + + @Override + protected void describeMismatchSafely(T item, Description description) { + description.appendText("jobId "); + matcher.describeMismatch(item.getJobId(), description); + } + + @Override + public void describeTo(Description description) { + description.appendText("job with jobId "); + description.appendDescriptionOf(matcher); + } + + @Factory + public static Matcher expectJobId(final String jobId) { + return new JobIdMatcher(equalTo(jobId)); + } + } + + /** + * A {@link Matcher} for a {@link DataflowJobUpdatedException} that applies an underlying + * {@link Matcher} to the {@link DataflowPipelineJob} returned by + * {@link DataflowJobUpdatedException#getReplacedByJob()}. + */ + private static class ReplacedByJobMatcher + extends TypeSafeMatcher { + + private final Matcher matcher; + + public ReplacedByJobMatcher(Matcher matcher) { + this.matcher = matcher; + } + + @Override + public boolean matchesSafely(T ex) { + return matcher.matches(ex.getReplacedByJob()); + } + + @Override + protected void describeMismatchSafely(T item, Description description) { + description.appendText("job "); + matcher.describeMismatch(item.getMessage(), description); + } + + @Override + public void describeTo(Description description) { + description.appendText("exception with replacedByJob() "); + description.appendDescriptionOf(matcher); + } + + @Factory + public static Matcher expectReplacedBy( + Matcher matcher) { + return new ReplacedByJobMatcher(matcher); + } + } + + /** + * Creates a mocked {@link DataflowPipelineJob} with the given {@code projectId} and {@code jobId} + * that will immediately terminate in the provided {@code terminalState}. + * + *

+   * <p>The return value may be further mocked.
+   */
+  private DataflowPipelineJob createMockJob(
+      String projectId, String jobId, State terminalState) throws Exception {
+    DataflowPipelineJob mockJob = mock(DataflowPipelineJob.class);
+    when(mockJob.getProjectId()).thenReturn(projectId);
+    when(mockJob.getJobId()).thenReturn(jobId);
+    when(mockJob.waitToFinish(
+        anyLong(), isA(TimeUnit.class), isA(MonitoringUtil.JobMessagesHandler.class)))
+        .thenReturn(terminalState);
+    return mockJob;
+  }
+
+  /**
+   * Returns a {@link BlockingDataflowPipelineRunner} that will return the provided job when run.
+   * Some {@link PipelineOptions}, such as the project ID, are extracted from the job.
+   */
+  private BlockingDataflowPipelineRunner createMockRunner(DataflowPipelineJob job)
+      throws Exception {
+    DataflowPipelineRunner mockRunner = mock(DataflowPipelineRunner.class);
+    TestDataflowPipelineOptions options =
+        PipelineOptionsFactory.as(TestDataflowPipelineOptions.class);
+    options.setProject(job.getProjectId());
+
+    when(mockRunner.run(isA(Pipeline.class))).thenReturn(job);
+
+    return new BlockingDataflowPipelineRunner(mockRunner, options);
+  }
+
+  /**
+   * Tests that the {@link BlockingDataflowPipelineRunner} returns normally when a job terminates
+   * in the {@link State#DONE DONE} state.
+   */
+  @Test
+  public void testJobDoneComplete() throws Exception {
+    createMockRunner(createMockJob("testJobDone-projectId", "testJobDone-jobId", State.DONE))
+        .run(DirectPipeline.createForTest());
+    expectedLogs.verifyInfo("Job finished with status DONE");
+  }
+
+  /**
+   * Tests that the {@link BlockingDataflowPipelineRunner} throws the appropriate exception
+   * when a job terminates in the {@link State#FAILED FAILED} state.
+   */
+  @Test
+  public void testFailedJobThrowsException() throws Exception {
+    expectedThrown.expect(DataflowJobExecutionException.class);
+    expectedThrown.expect(DataflowJobExceptionMatcher.expectJob(
+        JobIdMatcher.expectJobId("testFailedJob-jobId")));
+    createMockRunner(createMockJob("testFailedJob-projectId", "testFailedJob-jobId", State.FAILED))
+        .run(DirectPipeline.createForTest());
+  }
+
+  /**
+   * Tests that the {@link BlockingDataflowPipelineRunner} throws the appropriate exception
+   * when a job terminates in the {@link State#CANCELLED CANCELLED} state.
+   */
+  @Test
+  public void testCancelledJobThrowsException() throws Exception {
+    expectedThrown.expect(DataflowJobCancelledException.class);
+    expectedThrown.expect(DataflowJobExceptionMatcher.expectJob(
+        JobIdMatcher.expectJobId("testCancelledJob-jobId")));
+    createMockRunner(
+        createMockJob("testCancelledJob-projectId", "testCancelledJob-jobId", State.CANCELLED))
+        .run(DirectPipeline.createForTest());
+  }
+
+  /**
+   * Tests that the {@link BlockingDataflowPipelineRunner} throws the appropriate exception
+   * when a job terminates in the {@link State#UPDATED UPDATED} state.
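+   *
+   * <p>Both the original job id and the id of the replacement job are checked, along the lines of:
+   * <pre>{@code
+   * expectedThrown.expect(ReplacedByJobMatcher.expectReplacedBy(
+   *     JobIdMatcher.expectJobId("testUpdatedJob-replacedByJobId")));
+   * }</pre>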
+ */ + @Test + public void testUpdatedJobThrowsException() throws Exception { + expectedThrown.expect(DataflowJobUpdatedException.class); + expectedThrown.expect(DataflowJobExceptionMatcher.expectJob( + JobIdMatcher.expectJobId("testUpdatedJob-jobId"))); + expectedThrown.expect(ReplacedByJobMatcher.expectReplacedBy( + JobIdMatcher.expectJobId("testUpdatedJob-replacedByJobId"))); + DataflowPipelineJob job = + createMockJob("testUpdatedJob-projectId", "testUpdatedJob-jobId", State.UPDATED); + DataflowPipelineJob replacedByJob = + createMockJob("testUpdatedJob-projectId", "testUpdatedJob-replacedByJobId", State.DONE); + when(job.getReplacedByJob()).thenReturn(replacedByJob); + createMockRunner(job).run(DirectPipeline.createForTest()); + } + + /** + * Tests that the {@link BlockingDataflowPipelineRunner} throws the appropriate exception + * when a job terminates in the {@link State#UNKNOWN UNKNOWN} state, indicating that the + * Dataflow service returned a state that the SDK is unfamiliar with (possibly because it + * is an old SDK relative the service). + */ + @Test + public void testUnknownJobThrowsException() throws Exception { + expectedThrown.expect(IllegalStateException.class); + createMockRunner( + createMockJob("testUnknownJob-projectId", "testUnknownJob-jobId", State.UNKNOWN)) + .run(DirectPipeline.createForTest()); + } + + /** + * Tests that the {@link BlockingDataflowPipelineRunner} throws the appropriate exception + * when a job returns a {@code null} state, indicating that it failed to contact the service, + * including all of its built-in resilience logic. + */ + @Test + public void testNullJobThrowsException() throws Exception { + expectedThrown.expect(DataflowServiceException.class); + expectedThrown.expect(DataflowJobExceptionMatcher.expectJob( + JobIdMatcher.expectJobId("testNullJob-jobId"))); + createMockRunner( + createMockJob("testNullJob-projectId", "testNullJob-jobId", null)) + .run(DirectPipeline.createForTest()); + } + + @Test + public void testToString() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setJobName("TestJobName"); + options.setProject("test-project"); + options.setTempLocation("gs://test/temp/location"); + options.setGcpCredential(new TestCredential()); + options.setPathValidatorClass(NoopPathValidator.class); + assertEquals("BlockingDataflowPipelineRunner#TestJobName", + BlockingDataflowPipelineRunner.fromOptions(options).toString()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJobTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJobTest.java new file mode 100644 index 000000000000..261cecb9620b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineJobTest.java @@ -0,0 +1,603 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.Dataflow.Projects.Jobs.Get; +import com.google.api.services.dataflow.Dataflow.Projects.Jobs.GetMetrics; +import com.google.api.services.dataflow.Dataflow.Projects.Jobs.Messages; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.JobMetrics; +import com.google.api.services.dataflow.model.MetricStructuredName; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.runners.dataflow.DataflowAggregatorTransforms; +import com.google.cloud.dataflow.sdk.testing.FastNanoClockAndSleeper; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.AttemptBoundedExponentialBackOff; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSetMultimap; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.math.BigDecimal; +import java.net.SocketTimeoutException; +import java.util.concurrent.TimeUnit; + +/** + * Tests for DataflowPipelineJob. + */ +@RunWith(JUnit4.class) +public class DataflowPipelineJobTest { + private static final String PROJECT_ID = "someProject"; + private static final String JOB_ID = "1234"; + + @Mock + private Dataflow mockWorkflowClient; + @Mock + private Dataflow.Projects mockProjects; + @Mock + private Dataflow.Projects.Jobs mockJobs; + @Rule + public FastNanoClockAndSleeper fastClock = new FastNanoClockAndSleeper(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + when(mockWorkflowClient.projects()).thenReturn(mockProjects); + when(mockProjects.jobs()).thenReturn(mockJobs); + } + + /** + * Validates that a given time is valid for the total time slept by a + * AttemptBoundedExponentialBackOff given the number of retries and + * an initial polling interval. + * + * @param pollingIntervalMillis The initial polling interval given. 
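+   *                              The i-th backoff interval is nominally
+   *                              {@code pollingIntervalMillis * DEFAULT_MULTIPLIER^(i - 1)} and may
+   *                              be randomized by up to {@code DEFAULT_RANDOMIZATION_FACTOR} of
+   *                              that value in either direction; the total time slept is checked
+   *                              against the sums of the low and high ends of those ranges.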
+ * @param attempts The number of attempts made + * @param timeSleptMillis The amount of time slept by the clock. This is checked + * against the valid interval. + */ + void checkValidInterval(long pollingIntervalMillis, int attempts, long timeSleptMillis) { + long highSum = 0; + long lowSum = 0; + for (int i = 1; i < attempts; i++) { + double currentInterval = + pollingIntervalMillis + * Math.pow(AttemptBoundedExponentialBackOff.DEFAULT_MULTIPLIER, i - 1); + double offset = + AttemptBoundedExponentialBackOff.DEFAULT_RANDOMIZATION_FACTOR * currentInterval; + highSum += Math.round(currentInterval + offset); + lowSum += Math.round(currentInterval - offset); + } + assertThat(timeSleptMillis, allOf(greaterThanOrEqualTo(lowSum), lessThanOrEqualTo(highSum))); + } + + @Test + public void testWaitToFinishMessagesFail() throws Exception { + Dataflow.Projects.Jobs.Get statusRequest = mock(Dataflow.Projects.Jobs.Get.class); + + Job statusResponse = new Job(); + statusResponse.setCurrentState("JOB_STATE_" + State.DONE.name()); + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(statusRequest); + when(statusRequest.execute()).thenReturn(statusResponse); + + MonitoringUtil.JobMessagesHandler jobHandler = mock(MonitoringUtil.JobMessagesHandler.class); + Dataflow.Projects.Jobs.Messages mockMessages = + mock(Dataflow.Projects.Jobs.Messages.class); + Messages.List listRequest = mock(Dataflow.Projects.Jobs.Messages.List.class); + when(mockJobs.messages()).thenReturn(mockMessages); + when(mockMessages.list(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(listRequest); + when(listRequest.execute()).thenThrow(SocketTimeoutException.class); + DataflowAggregatorTransforms dataflowAggregatorTransforms = + mock(DataflowAggregatorTransforms.class); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient, dataflowAggregatorTransforms); + + State state = job.waitToFinish(5, TimeUnit.MINUTES, jobHandler, fastClock, fastClock); + assertEquals(null, state); + } + + public State mockWaitToFinishInState(State state) throws Exception { + Dataflow.Projects.Jobs.Get statusRequest = mock(Dataflow.Projects.Jobs.Get.class); + + Job statusResponse = new Job(); + statusResponse.setCurrentState("JOB_STATE_" + state.name()); + + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(statusRequest); + when(statusRequest.execute()).thenReturn(statusResponse); + DataflowAggregatorTransforms dataflowAggregatorTransforms = + mock(DataflowAggregatorTransforms.class); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient, dataflowAggregatorTransforms); + + return job.waitToFinish(1, TimeUnit.MINUTES, null, fastClock, fastClock); + } + + /** + * Tests that the {@link DataflowPipelineJob} understands that the {@link State#DONE DONE} + * state is terminal. + */ + @Test + public void testWaitToFinishDone() throws Exception { + assertEquals(State.DONE, mockWaitToFinishInState(State.DONE)); + } + + /** + * Tests that the {@link DataflowPipelineJob} understands that the {@link State#FAILED FAILED} + * state is terminal. + */ + @Test + public void testWaitToFinishFailed() throws Exception { + assertEquals(State.FAILED, mockWaitToFinishInState(State.FAILED)); + } + + /** + * Tests that the {@link DataflowPipelineJob} understands that the {@link State#FAILED FAILED} + * state is terminal. 
+ */ + @Test + public void testWaitToFinishCancelled() throws Exception { + assertEquals(State.CANCELLED, mockWaitToFinishInState(State.CANCELLED)); + } + + /** + * Tests that the {@link DataflowPipelineJob} understands that the {@link State#FAILED FAILED} + * state is terminal. + */ + @Test + public void testWaitToFinishUpdated() throws Exception { + assertEquals(State.UPDATED, mockWaitToFinishInState(State.UPDATED)); + } + + @Test + public void testWaitToFinishFail() throws Exception { + Dataflow.Projects.Jobs.Get statusRequest = mock(Dataflow.Projects.Jobs.Get.class); + + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(statusRequest); + when(statusRequest.execute()).thenThrow(IOException.class); + DataflowAggregatorTransforms dataflowAggregatorTransforms = + mock(DataflowAggregatorTransforms.class); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient, dataflowAggregatorTransforms); + + long startTime = fastClock.nanoTime(); + State state = job.waitToFinish(5, TimeUnit.MINUTES, null, fastClock, fastClock); + assertEquals(null, state); + long timeDiff = TimeUnit.NANOSECONDS.toMillis(fastClock.nanoTime() - startTime); + checkValidInterval(DataflowPipelineJob.MESSAGES_POLLING_INTERVAL, + DataflowPipelineJob.MESSAGES_POLLING_ATTEMPTS, timeDiff); + } + + @Test + public void testWaitToFinishTimeFail() throws Exception { + Dataflow.Projects.Jobs.Get statusRequest = mock(Dataflow.Projects.Jobs.Get.class); + + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(statusRequest); + when(statusRequest.execute()).thenThrow(IOException.class); + DataflowAggregatorTransforms dataflowAggregatorTransforms = + mock(DataflowAggregatorTransforms.class); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient, dataflowAggregatorTransforms); + long startTime = fastClock.nanoTime(); + State state = job.waitToFinish(4, TimeUnit.MILLISECONDS, null, fastClock, fastClock); + assertEquals(null, state); + long timeDiff = TimeUnit.NANOSECONDS.toMillis(fastClock.nanoTime() - startTime); + // Should only sleep for the 4 ms remaining. 
+ assertEquals(timeDiff, 4L); + } + + @Test + public void testGetStateReturnsServiceState() throws Exception { + Dataflow.Projects.Jobs.Get statusRequest = mock(Dataflow.Projects.Jobs.Get.class); + + Job statusResponse = new Job(); + statusResponse.setCurrentState("JOB_STATE_" + State.RUNNING.name()); + + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(statusRequest); + when(statusRequest.execute()).thenReturn(statusResponse); + + DataflowAggregatorTransforms dataflowAggregatorTransforms = + mock(DataflowAggregatorTransforms.class); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient, dataflowAggregatorTransforms); + + assertEquals( + State.RUNNING, + job.getStateWithRetries(DataflowPipelineJob.STATUS_POLLING_ATTEMPTS, fastClock)); + } + + @Test + public void testGetStateWithExceptionReturnsUnknown() throws Exception { + Dataflow.Projects.Jobs.Get statusRequest = mock(Dataflow.Projects.Jobs.Get.class); + + when(mockJobs.get(eq(PROJECT_ID), eq(JOB_ID))).thenReturn(statusRequest); + when(statusRequest.execute()).thenThrow(IOException.class); + DataflowAggregatorTransforms dataflowAggregatorTransforms = + mock(DataflowAggregatorTransforms.class); + + DataflowPipelineJob job = new DataflowPipelineJob( + PROJECT_ID, JOB_ID, mockWorkflowClient, dataflowAggregatorTransforms); + + long startTime = fastClock.nanoTime(); + assertEquals( + State.UNKNOWN, + job.getStateWithRetries(DataflowPipelineJob.STATUS_POLLING_ATTEMPTS, fastClock)); + long timeDiff = TimeUnit.NANOSECONDS.toMillis(fastClock.nanoTime() - startTime); + checkValidInterval(DataflowPipelineJob.STATUS_POLLING_INTERVAL, + DataflowPipelineJob.STATUS_POLLING_ATTEMPTS, timeDiff); + } + + @Test + public void testGetAggregatorValuesWithNoMetricUpdatesReturnsEmptyValue() + throws IOException, AggregatorRetrievalException { + Aggregator aggregator = mock(Aggregator.class); + @SuppressWarnings("unchecked") + PTransform pTransform = mock(PTransform.class); + String stepName = "s1"; + String fullName = "Foo/Bar/Baz"; + AppliedPTransform appliedTransform = appliedPTransform(fullName, pTransform); + + DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of(aggregator, pTransform).asMap(), + ImmutableMap., String>of(appliedTransform, stepName)); + + GetMetrics getMetrics = mock(GetMetrics.class); + when(mockJobs.getMetrics(PROJECT_ID, JOB_ID)).thenReturn(getMetrics); + JobMetrics jobMetrics = new JobMetrics(); + when(getMetrics.execute()).thenReturn(jobMetrics); + + jobMetrics.setMetrics(ImmutableList.of()); + + Get getState = mock(Get.class); + when(mockJobs.get(PROJECT_ID, JOB_ID)).thenReturn(getState); + Job modelJob = new Job(); + when(getState.execute()).thenReturn(modelJob); + modelJob.setCurrentState(State.RUNNING.toString()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + AggregatorValues values = job.getAggregatorValues(aggregator); + + assertThat(values.getValues(), empty()); + } + + @Test + public void testGetAggregatorValuesWithNullMetricUpdatesReturnsEmptyValue() + throws IOException, AggregatorRetrievalException { + Aggregator aggregator = mock(Aggregator.class); + @SuppressWarnings("unchecked") + PTransform pTransform = mock(PTransform.class); + String stepName = "s1"; + String fullName = "Foo/Bar/Baz"; + AppliedPTransform appliedTransform = appliedPTransform(fullName, pTransform); + + DataflowAggregatorTransforms aggregatorTransforms = new 
DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of(aggregator, pTransform).asMap(), + ImmutableMap., String>of(appliedTransform, stepName)); + + GetMetrics getMetrics = mock(GetMetrics.class); + when(mockJobs.getMetrics(PROJECT_ID, JOB_ID)).thenReturn(getMetrics); + JobMetrics jobMetrics = new JobMetrics(); + when(getMetrics.execute()).thenReturn(jobMetrics); + + jobMetrics.setMetrics(null); + + Get getState = mock(Get.class); + when(mockJobs.get(PROJECT_ID, JOB_ID)).thenReturn(getState); + Job modelJob = new Job(); + when(getState.execute()).thenReturn(modelJob); + modelJob.setCurrentState(State.RUNNING.toString()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + AggregatorValues values = job.getAggregatorValues(aggregator); + + assertThat(values.getValues(), empty()); + } + + @Test + public void testGetAggregatorValuesWithSingleMetricUpdateReturnsSingletonCollection() + throws IOException, AggregatorRetrievalException { + CombineFn combineFn = new Sum.SumLongFn(); + String aggregatorName = "agg"; + Aggregator aggregator = new TestAggregator<>(combineFn, aggregatorName); + @SuppressWarnings("unchecked") + PTransform pTransform = mock(PTransform.class); + String stepName = "s1"; + String fullName = "Foo/Bar/Baz"; + AppliedPTransform appliedTransform = appliedPTransform(fullName, pTransform); + + DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of(aggregator, pTransform).asMap(), + ImmutableMap., String>of(appliedTransform, stepName)); + + GetMetrics getMetrics = mock(GetMetrics.class); + when(mockJobs.getMetrics(PROJECT_ID, JOB_ID)).thenReturn(getMetrics); + JobMetrics jobMetrics = new JobMetrics(); + when(getMetrics.execute()).thenReturn(jobMetrics); + + MetricUpdate update = new MetricUpdate(); + long stepValue = 1234L; + update.setScalar(new BigDecimal(stepValue)); + + MetricStructuredName structuredName = new MetricStructuredName(); + structuredName.setName(aggregatorName); + structuredName.setContext(ImmutableMap.of("step", stepName)); + update.setName(structuredName); + + jobMetrics.setMetrics(ImmutableList.of(update)); + + Get getState = mock(Get.class); + when(mockJobs.get(PROJECT_ID, JOB_ID)).thenReturn(getState); + Job modelJob = new Job(); + when(getState.execute()).thenReturn(modelJob); + modelJob.setCurrentState(State.RUNNING.toString()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + AggregatorValues values = job.getAggregatorValues(aggregator); + + assertThat(values.getValuesAtSteps(), hasEntry(fullName, stepValue)); + assertThat(values.getValuesAtSteps().size(), equalTo(1)); + assertThat(values.getValues(), contains(stepValue)); + assertThat(values.getTotalValue(combineFn), equalTo(Long.valueOf(stepValue))); + } + + @Test + public void testGetAggregatorValuesWithMultipleMetricUpdatesReturnsCollection() + throws IOException, AggregatorRetrievalException { + CombineFn combineFn = new Sum.SumLongFn(); + String aggregatorName = "agg"; + Aggregator aggregator = new TestAggregator<>(combineFn, aggregatorName); + + @SuppressWarnings("unchecked") + PTransform pTransform = mock(PTransform.class); + String stepName = "s1"; + String fullName = "Foo/Bar/Baz"; + AppliedPTransform appliedTransform = appliedPTransform(fullName, pTransform); + + @SuppressWarnings("unchecked") + PTransform otherTransform = mock(PTransform.class); + String otherStepName 
= "s88"; + String otherFullName = "Spam/Ham/Eggs"; + AppliedPTransform otherAppliedTransform = + appliedPTransform(otherFullName, otherTransform); + + DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of( + aggregator, pTransform, aggregator, otherTransform).asMap(), + ImmutableMap., String>of( + appliedTransform, stepName, otherAppliedTransform, otherStepName)); + + GetMetrics getMetrics = mock(GetMetrics.class); + when(mockJobs.getMetrics(PROJECT_ID, JOB_ID)).thenReturn(getMetrics); + JobMetrics jobMetrics = new JobMetrics(); + when(getMetrics.execute()).thenReturn(jobMetrics); + + MetricUpdate updateOne = new MetricUpdate(); + long stepValue = 1234L; + updateOne.setScalar(new BigDecimal(stepValue)); + + MetricStructuredName structuredNameOne = new MetricStructuredName(); + structuredNameOne.setName(aggregatorName); + structuredNameOne.setContext(ImmutableMap.of("step", stepName)); + updateOne.setName(structuredNameOne); + + MetricUpdate updateTwo = new MetricUpdate(); + long stepValueTwo = 1024L; + updateTwo.setScalar(new BigDecimal(stepValueTwo)); + + MetricStructuredName structuredNameTwo = new MetricStructuredName(); + structuredNameTwo.setName(aggregatorName); + structuredNameTwo.setContext(ImmutableMap.of("step", otherStepName)); + updateTwo.setName(structuredNameTwo); + + jobMetrics.setMetrics(ImmutableList.of(updateOne, updateTwo)); + + Get getState = mock(Get.class); + when(mockJobs.get(PROJECT_ID, JOB_ID)).thenReturn(getState); + Job modelJob = new Job(); + when(getState.execute()).thenReturn(modelJob); + modelJob.setCurrentState(State.RUNNING.toString()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + AggregatorValues values = job.getAggregatorValues(aggregator); + + assertThat(values.getValuesAtSteps(), hasEntry(fullName, stepValue)); + assertThat(values.getValuesAtSteps(), hasEntry(otherFullName, stepValueTwo)); + assertThat(values.getValuesAtSteps().size(), equalTo(2)); + assertThat(values.getValues(), containsInAnyOrder(stepValue, stepValueTwo)); + assertThat(values.getTotalValue(combineFn), equalTo(Long.valueOf(stepValue + stepValueTwo))); + } + + @Test + public void testGetAggregatorValuesWithUnrelatedMetricUpdateIgnoresUpdate() + throws IOException, AggregatorRetrievalException { + CombineFn combineFn = new Sum.SumLongFn(); + String aggregatorName = "agg"; + Aggregator aggregator = new TestAggregator<>(combineFn, aggregatorName); + @SuppressWarnings("unchecked") + PTransform pTransform = mock(PTransform.class); + String stepName = "s1"; + String fullName = "Foo/Bar/Baz"; + AppliedPTransform appliedTransform = appliedPTransform(fullName, pTransform); + + DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of(aggregator, pTransform).asMap(), + ImmutableMap., String>of(appliedTransform, stepName)); + + GetMetrics getMetrics = mock(GetMetrics.class); + when(mockJobs.getMetrics(PROJECT_ID, JOB_ID)).thenReturn(getMetrics); + JobMetrics jobMetrics = new JobMetrics(); + when(getMetrics.execute()).thenReturn(jobMetrics); + + MetricUpdate ignoredUpdate = new MetricUpdate(); + ignoredUpdate.setScalar(null); + + MetricStructuredName ignoredName = new MetricStructuredName(); + ignoredName.setName("ignoredAggregator.elementCount.out0"); + ignoredName.setContext(null); + ignoredUpdate.setName(ignoredName); + + jobMetrics.setMetrics(ImmutableList.of(ignoredUpdate)); + + 
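+    // The update above carries neither the aggregator's name nor a "step" context, so it should
+    // not be attributed to the aggregator under test.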
Get getState = mock(Get.class); + when(mockJobs.get(PROJECT_ID, JOB_ID)).thenReturn(getState); + Job modelJob = new Job(); + when(getState.execute()).thenReturn(modelJob); + modelJob.setCurrentState(State.RUNNING.toString()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + AggregatorValues values = job.getAggregatorValues(aggregator); + + assertThat(values.getValuesAtSteps().entrySet(), empty()); + assertThat(values.getValues(), empty()); + } + + @Test + public void testGetAggregatorValuesWithUnusedAggregatorThrowsException() + throws AggregatorRetrievalException { + Aggregator aggregator = mock(Aggregator.class); + + DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of().asMap(), + ImmutableMap., String>of()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("not used in this pipeline"); + + job.getAggregatorValues(aggregator); + } + + @Test + public void testGetAggregatorValuesWhenClientThrowsExceptionThrowsAggregatorRetrievalException() + throws IOException, AggregatorRetrievalException { + CombineFn combineFn = new Sum.SumLongFn(); + String aggregatorName = "agg"; + Aggregator aggregator = new TestAggregator<>(combineFn, aggregatorName); + @SuppressWarnings("unchecked") + PTransform pTransform = mock(PTransform.class); + String stepName = "s1"; + String fullName = "Foo/Bar/Baz"; + AppliedPTransform appliedTransform = appliedPTransform(fullName, pTransform); + + DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms( + ImmutableSetMultimap., PTransform>of(aggregator, pTransform).asMap(), + ImmutableMap., String>of(appliedTransform, stepName)); + + GetMetrics getMetrics = mock(GetMetrics.class); + when(mockJobs.getMetrics(PROJECT_ID, JOB_ID)).thenReturn(getMetrics); + IOException cause = new IOException(); + when(getMetrics.execute()).thenThrow(cause); + + Get getState = mock(Get.class); + when(mockJobs.get(PROJECT_ID, JOB_ID)).thenReturn(getState); + Job modelJob = new Job(); + when(getState.execute()).thenReturn(modelJob); + modelJob.setCurrentState(State.RUNNING.toString()); + + DataflowPipelineJob job = + new DataflowPipelineJob(PROJECT_ID, JOB_ID, mockWorkflowClient, aggregatorTransforms); + + thrown.expect(AggregatorRetrievalException.class); + thrown.expectCause(is(cause)); + thrown.expectMessage(aggregator.toString()); + thrown.expectMessage("when retrieving Aggregator values for"); + + job.getAggregatorValues(aggregator); + } + + private static class TestAggregator implements Aggregator { + private final CombineFn combineFn; + private final String name; + + public TestAggregator(CombineFn combineFn, String name) { + this.combineFn = combineFn; + this.name = name; + } + + @Override + public void addValue(InT value) { + throw new AssertionError(); + } + + @Override + public String getName() { + return name; + } + + @Override + public CombineFn getCombineFn() { + return combineFn; + } + } + + private AppliedPTransform appliedPTransform( + String fullName, PTransform transform) { + return AppliedPTransform.of(fullName, mock(PInput.class), mock(POutput.class), transform); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRegistrarTest.java 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRegistrarTest.java new file mode 100644 index 000000000000..d5125cf5ce53 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRegistrarTest.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.options.BlockingDataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ServiceLoader; + +/** Tests for {@link DataflowPipelineRegistrar}. */ +@RunWith(JUnit4.class) +public class DataflowPipelineRegistrarTest { + @Test + public void testCorrectOptionsAreReturned() { + assertEquals(ImmutableList.of(DataflowPipelineOptions.class, + BlockingDataflowPipelineOptions.class), + new DataflowPipelineRegistrar.Options().getPipelineOptions()); + } + + @Test + public void testCorrectRunnersAreReturned() { + assertEquals(ImmutableList.of(DataflowPipelineRunner.class, + BlockingDataflowPipelineRunner.class), + new DataflowPipelineRegistrar.Runner().getPipelineRunners()); + } + + @Test + public void testServiceLoaderForOptions() { + for (PipelineOptionsRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class).iterator())) { + if (registrar instanceof DataflowPipelineRegistrar.Options) { + return; + } + } + fail("Expected to find " + DataflowPipelineRegistrar.Options.class); + } + + @Test + public void testServiceLoaderForRunner() { + for (PipelineRunnerRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class).iterator())) { + if (registrar instanceof DataflowPipelineRegistrar.Runner) { + return; + } + } + fail("Expected to find " + DataflowPipelineRegistrar.Runner.class); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java new file mode 100644 index 000000000000..c5f2d3fe71fd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java @@ -0,0 +1,1370 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.ListJobsResponse; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.AvroSource; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineDebugOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions.CheckEnabled; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.BatchViewAsList; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.BatchViewAsMap; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.BatchViewAsMultimap; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.TransformedMap; +import com.google.cloud.dataflow.sdk.runners.dataflow.TestCountingSource; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.IsmRecord; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.IsmRecordCoder; +import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.MetadataKeyCoder; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import 
com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.cloud.dataflow.sdk.util.DataflowReleaseInfo; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.NoopPathValidator; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import org.hamcrest.Description; +import org.hamcrest.Matchers; +import org.hamcrest.TypeSafeMatcher; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.File; +import java.io.IOException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.channels.FileChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * Tests for DataflowPipelineRunner. + */ +@RunWith(JUnit4.class) +public class DataflowPipelineRunnerTest { + + private static final String PROJECT_ID = "some-project"; + + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule + public ExpectedException thrown = ExpectedException.none(); + + // Asserts that the given Job has all expected fields set. 
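+  // In particular, the job id and current state are assigned by the service, so they must still
+  // be unset on the Job that the runner submits.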
+ private static void assertValidJob(Job job) { + assertNull(job.getId()); + assertNull(job.getCurrentState()); + } + + private DataflowPipeline buildDataflowPipeline(DataflowPipelineOptions options) { + options.setStableUniqueNames(CheckEnabled.ERROR); + DataflowPipeline p = DataflowPipeline.create(options); + + p.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/object")) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/object")); + + return p; + } + + private static Dataflow buildMockDataflow( + final ArgumentCaptor jobCaptor) throws IOException { + Dataflow mockDataflowClient = mock(Dataflow.class); + Dataflow.Projects mockProjects = mock(Dataflow.Projects.class); + Dataflow.Projects.Jobs mockJobs = mock(Dataflow.Projects.Jobs.class); + Dataflow.Projects.Jobs.Create mockRequest = + mock(Dataflow.Projects.Jobs.Create.class); + Dataflow.Projects.Jobs.List mockList = mock(Dataflow.Projects.Jobs.List.class); + + when(mockDataflowClient.projects()).thenReturn(mockProjects); + when(mockProjects.jobs()).thenReturn(mockJobs); + when(mockJobs.create(eq(PROJECT_ID), jobCaptor.capture())) + .thenReturn(mockRequest); + when(mockJobs.list(eq(PROJECT_ID))).thenReturn(mockList); + when(mockList.setPageToken(anyString())).thenReturn(mockList); + when(mockList.execute()) + .thenReturn(new ListJobsResponse().setJobs( + Arrays.asList(new Job() + .setName("oldJobName") + .setId("oldJobId") + .setCurrentState("JOB_STATE_RUNNING")))); + + Job resultJob = new Job(); + resultJob.setId("newid"); + when(mockRequest.execute()).thenReturn(resultJob); + return mockDataflowClient; + } + + private GcsUtil buildMockGcsUtil(boolean bucketExists) throws IOException { + GcsUtil mockGcsUtil = mock(GcsUtil.class); + when(mockGcsUtil.create(any(GcsPath.class), anyString())) + .then(new Answer() { + @Override + public SeekableByteChannel answer(InvocationOnMock invocation) throws Throwable { + return FileChannel.open( + Files.createTempFile("channel-", ".tmp"), + StandardOpenOption.CREATE, StandardOpenOption.DELETE_ON_CLOSE); + } + }); + + when(mockGcsUtil.isGcsPatternSupported(anyString())).thenReturn(true); + when(mockGcsUtil.expand(any(GcsPath.class))).then(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return ImmutableList.of((GcsPath) invocation.getArguments()[0]); + } + }); + when(mockGcsUtil.bucketExists(any(GcsPath.class))).thenReturn(bucketExists); + return mockGcsUtil; + } + + private DataflowPipelineOptions buildPipelineOptions() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + return buildPipelineOptions(jobCaptor); + } + + private DataflowPipelineOptions buildPipelineOptions( + ArgumentCaptor jobCaptor) throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setProject(PROJECT_ID); + options.setTempLocation("gs://somebucket/some/path"); + // Set FILES_PROPERTY to empty to prevent a default value calculated from classpath. 
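+    // An explicit empty list also keeps the runner from scanning the test classpath for files to
+    // stage (the default behavior exercised by runWithDefaultFilesToStage below).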
+ options.setFilesToStage(new LinkedList()); + options.setDataflowClient(buildMockDataflow(jobCaptor)); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + options.setGcpCredential(new TestCredential()); + return options; + } + + @Test + public void testRun() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + DataflowPipeline p = buildDataflowPipeline(options); + DataflowPipelineJob job = p.run(); + assertEquals("newid", job.getJobId()); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testRunReturnDifferentRequestId() throws IOException { + DataflowPipelineOptions options = buildPipelineOptions(); + Dataflow mockDataflowClient = options.getDataflowClient(); + Dataflow.Projects.Jobs.Create mockRequest = mock(Dataflow.Projects.Jobs.Create.class); + when(mockDataflowClient.projects().jobs().create(eq(PROJECT_ID), any(Job.class))) + .thenReturn(mockRequest); + Job resultJob = new Job(); + resultJob.setId("newid"); + // Return a different request id. + resultJob.setClientRequestId("different_request_id"); + when(mockRequest.execute()).thenReturn(resultJob); + + DataflowPipeline p = buildDataflowPipeline(options); + try { + p.run(); + fail("Expected DataflowJobAlreadyExistsException"); + } catch (DataflowJobAlreadyExistsException expected) { + assertThat(expected.getMessage(), + containsString("If you want to submit a second job, try again by setting a " + + "different name using --jobName.")); + assertEquals(expected.getJob().getJobId(), resultJob.getId()); + } + } + + @Test + public void testUpdate() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setUpdate(true); + options.setJobName("oldJobName"); + DataflowPipeline p = buildDataflowPipeline(options); + DataflowPipelineJob job = p.run(); + assertEquals("newid", job.getJobId()); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testUpdateNonExistentPipeline() throws IOException { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Could not find running job named badJobName"); + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setUpdate(true); + options.setJobName("badJobName"); + DataflowPipeline p = buildDataflowPipeline(options); + p.run(); + } + + @Test + public void testUpdateAlreadyUpdatedPipeline() throws IOException { + DataflowPipelineOptions options = buildPipelineOptions(); + options.setUpdate(true); + options.setJobName("oldJobName"); + Dataflow mockDataflowClient = options.getDataflowClient(); + Dataflow.Projects.Jobs.Create mockRequest = mock(Dataflow.Projects.Jobs.Create.class); + when(mockDataflowClient.projects().jobs().create(eq(PROJECT_ID), any(Job.class))) + .thenReturn(mockRequest); + final Job resultJob = new Job(); + resultJob.setId("newid"); + // Return a different request id. 
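+    // A client_request_id that differs from the one generated for this submission indicates the
+    // job name already resolved to another job, which the runner reports as an exception.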
+ resultJob.setClientRequestId("different_request_id"); + when(mockRequest.execute()).thenReturn(resultJob); + + DataflowPipeline p = buildDataflowPipeline(options); + + thrown.expect(DataflowJobAlreadyUpdatedException.class); + thrown.expect(new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText("Expected job ID: " + resultJob.getId()); + } + + @Override + protected boolean matchesSafely(DataflowJobAlreadyUpdatedException item) { + return resultJob.getId().equals(item.getJob().getJobId()); + } + }); + thrown.expectMessage("The job named oldjobname with id: oldJobId has already been updated " + + "into job id: newid and cannot be updated again."); + p.run(); + } + + @Test + public void testRunWithFiles() throws IOException { + // Test that the function DataflowPipelineRunner.stageFiles works as + // expected. + GcsUtil mockGcsUtil = buildMockGcsUtil(true /* bucket exists */); + final String gcsStaging = "gs://somebucket/some/path"; + final String gcsTemp = "gs://somebucket/some/temp/path"; + final String cloudDataflowDataset = "somedataset"; + + // Create some temporary files. + File temp1 = File.createTempFile("DataflowPipelineRunnerTest", "txt"); + temp1.deleteOnExit(); + File temp2 = File.createTempFile("DataflowPipelineRunnerTest2", "txt"); + temp2.deleteOnExit(); + + String overridePackageName = "alias.txt"; + + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setFilesToStage(ImmutableList.of( + temp1.getAbsolutePath(), + overridePackageName + "=" + temp2.getAbsolutePath())); + options.setStagingLocation(gcsStaging); + options.setTempLocation(gcsTemp); + options.setTempDatasetId(cloudDataflowDataset); + options.setProject(PROJECT_ID); + options.setJobName("job"); + options.setDataflowClient(buildMockDataflow(jobCaptor)); + options.setGcsUtil(mockGcsUtil); + options.setGcpCredential(new TestCredential()); + + DataflowPipeline p = buildDataflowPipeline(options); + + DataflowPipelineJob job = p.run(); + assertEquals("newid", job.getJobId()); + + Job workflowJob = jobCaptor.getValue(); + assertValidJob(workflowJob); + + assertEquals( + 2, + workflowJob.getEnvironment().getWorkerPools().get(0).getPackages().size()); + DataflowPackage workflowPackage1 = + workflowJob.getEnvironment().getWorkerPools().get(0).getPackages().get(0); + assertThat(workflowPackage1.getName(), startsWith(temp1.getName())); + DataflowPackage workflowPackage2 = + workflowJob.getEnvironment().getWorkerPools().get(0).getPackages().get(1); + assertEquals(overridePackageName, workflowPackage2.getName()); + + assertEquals( + "storage.googleapis.com/somebucket/some/temp/path", + workflowJob.getEnvironment().getTempStoragePrefix()); + assertEquals( + cloudDataflowDataset, + workflowJob.getEnvironment().getDataset()); + assertEquals( + DataflowReleaseInfo.getReleaseInfo().getName(), + workflowJob.getEnvironment().getUserAgent().get("name")); + assertEquals( + DataflowReleaseInfo.getReleaseInfo().getVersion(), + workflowJob.getEnvironment().getUserAgent().get("version")); + } + + @Test + public void runWithDefaultFilesToStage() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + options.setFilesToStage(null); + DataflowPipelineRunner.fromOptions(options); + assertTrue(!options.getFilesToStage().isEmpty()); + } + + @Test + public void detectClassPathResourceWithFileResources() throws Exception { + File file = 
tmpFolder.newFile("file"); + File file2 = tmpFolder.newFile("file2"); + URLClassLoader classLoader = new URLClassLoader(new URL[]{ + file.toURI().toURL(), + file2.toURI().toURL() + }); + + assertEquals(ImmutableList.of(file.getAbsolutePath(), file2.getAbsolutePath()), + DataflowPipelineRunner.detectClassPathResourcesToStage(classLoader)); + } + + @Test + public void detectClassPathResourcesWithUnsupportedClassLoader() { + ClassLoader mockClassLoader = Mockito.mock(ClassLoader.class); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unable to use ClassLoader to detect classpath elements."); + + DataflowPipelineRunner.detectClassPathResourcesToStage(mockClassLoader); + } + + @Test + public void detectClassPathResourceWithNonFileResources() throws Exception { + String url = "http://www.google.com/all-the-secrets.jar"; + URLClassLoader classLoader = new URLClassLoader(new URL[]{ + new URL(url) + }); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unable to convert url (" + url + ") to file."); + + DataflowPipelineRunner.detectClassPathResourcesToStage(classLoader); + } + + @Test + public void testGcsStagingLocationInitialization() throws Exception { + // Test that the staging location is initialized correctly. + String gcsTemp = "gs://somebucket/some/temp/path"; + + // Set temp location (required), and check that staging location is set. + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setTempLocation(gcsTemp); + options.setProject(PROJECT_ID); + options.setGcpCredential(new TestCredential()); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + + DataflowPipelineRunner.fromOptions(options); + + assertNotNull(options.getStagingLocation()); + } + + @Test + public void testNonGcsFilePathInReadFailure() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); + p.apply(TextIO.Read.named("ReadMyNonGcsFile").from(tmpFolder.newFile().getPath())); + + thrown.expectCause(Matchers.allOf( + instanceOf(IllegalArgumentException.class), + ThrowableMessageMatcher.hasMessage( + containsString("expected a valid 'gs://' path but was given")))); + p.run(); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testNonGcsFilePathInWriteFailure() throws IOException { + Pipeline p = buildDataflowPipeline(buildPipelineOptions()); + PCollection pc = p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString("expected a valid 'gs://' path but was given")); + pc.apply(TextIO.Write.named("WriteMyNonGcsFile").to("/tmp/file")); + } + + @Test + public void testMultiSlashGcsFileReadPath() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); + p.apply(TextIO.Read.named("ReadInvalidGcsFile") + .from("gs://bucket/tmp//file")); + + thrown.expectCause(Matchers.allOf( + instanceOf(IllegalArgumentException.class), + ThrowableMessageMatcher.hasMessage(containsString("consecutive slashes")))); + p.run(); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testMultiSlashGcsFileWritePath() throws IOException { + Pipeline p = buildDataflowPipeline(buildPipelineOptions()); + PCollection pc = p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")); + + 
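+    // The invalid output path is rejected when the write transform is applied, before any job
+    // submission is attempted.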
thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("consecutive slashes"); + pc.apply(TextIO.Write.named("WriteInvalidGcsFile").to("gs://bucket/tmp//file")); + } + + @Test + public void testInvalidTempLocation() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setTempLocation("file://temp/location"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString("expected a valid 'gs://' path but was given")); + DataflowPipelineRunner.fromOptions(options); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testInvalidStagingLocation() throws IOException { + DataflowPipelineOptions options = buildPipelineOptions(); + options.setStagingLocation("file://my/staging/location"); + try { + DataflowPipelineRunner.fromOptions(options); + fail("fromOptions should have failed"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), containsString("expected a valid 'gs://' path but was given")); + } + options.setStagingLocation("my/staging/location"); + try { + DataflowPipelineRunner.fromOptions(options); + fail("fromOptions should have failed"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), containsString("expected a valid 'gs://' path but was given")); + } + } + + @Test + public void testNonExistentTempLocation() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + GcsUtil mockGcsUtil = buildMockGcsUtil(false /* bucket exists */); + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setGcsUtil(mockGcsUtil); + options.setTempLocation("gs://non-existent-bucket/location"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString( + "Output path does not exist or is not writeable: gs://non-existent-bucket/location")); + DataflowPipelineRunner.fromOptions(options); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testNonExistentStagingLocation() throws IOException { + ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + GcsUtil mockGcsUtil = buildMockGcsUtil(false /* bucket exists */); + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + options.setGcsUtil(mockGcsUtil); + options.setStagingLocation("gs://non-existent-bucket/location"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(containsString( + "Output path does not exist or is not writeable: gs://non-existent-bucket/location")); + DataflowPipelineRunner.fromOptions(options); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testNoProjectFails() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + + options.setRunner(DataflowPipelineRunner.class); + // Explicitly set to null to prevent the default instance factory from reading credentials + // from a user's environment, causing this test to fail. 
+ options.setProject(null); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Project id"); + thrown.expectMessage("when running a Dataflow in the cloud"); + + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testProjectId() throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("foo-12345"); + + options.setStagingLocation("gs://spam/ham/eggs"); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + options.setGcpCredential(new TestCredential()); + + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testProjectPrefix() throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("google.com:some-project-12345"); + + options.setStagingLocation("gs://spam/ham/eggs"); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + options.setGcpCredential(new TestCredential()); + + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testProjectNumber() throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("12345"); + + options.setStagingLocation("gs://spam/ham/eggs"); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Project ID"); + thrown.expectMessage("project number"); + + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testProjectDescription() throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("some project"); + + options.setStagingLocation("gs://spam/ham/eggs"); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Project ID"); + thrown.expectMessage("project description"); + + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testInvalidNumberOfWorkerHarnessThreads() throws IOException { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("foo-12345"); + + options.setStagingLocation("gs://spam/ham/eggs"); + options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */)); + + options.as(DataflowPipelineDebugOptions.class).setNumberOfWorkerHarnessThreads(-1); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Number of worker harness threads"); + thrown.expectMessage("Please make sure the value is non-negative."); + + DataflowPipelineRunner.fromOptions(options); + } + + @Test + public void testNoStagingLocationAndNoTempLocationFails() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("foo-project"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Missing required value for group"); + thrown.expectMessage(DataflowPipelineOptions.DATAFLOW_STORAGE_LOCATION); + thrown.expectMessage("getStagingLocation"); + thrown.expectMessage("getTempLocation"); + + 
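+    // fromOptions performs this validation eagerly, so no pipeline needs to be constructed.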
    DataflowPipelineRunner.fromOptions(options);
+  }
+
+  @Test
+  public void testStagingLocationAndNoTempLocationSucceeds() throws Exception {
+    DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class);
+    options.setRunner(DataflowPipelineRunner.class);
+    options.setGcpCredential(new TestCredential());
+    options.setProject("foo-project");
+    options.setStagingLocation("gs://spam/ham/eggs");
+    options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */));
+
+    DataflowPipelineRunner.fromOptions(options);
+  }
+
+  @Test
+  public void testTempLocationAndNoStagingLocationSucceeds() throws Exception {
+    DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class);
+    options.setRunner(DataflowPipelineRunner.class);
+    options.setGcpCredential(new TestCredential());
+    options.setProject("foo-project");
+    options.setTempLocation("gs://spam/ham/eggs");
+    options.setGcsUtil(buildMockGcsUtil(true /* bucket exists */));
+
+    DataflowPipelineRunner.fromOptions(options);
+  }
+
+  @Test
+  public void testInvalidJobName() throws IOException {
+    List<String> invalidNames = Arrays.asList(
+        "invalid_name",
+        "0invalid",
+        "invalid-");
+    List<String> expectedReason = Arrays.asList(
+        "JobName invalid",
+        "JobName invalid",
+        "JobName invalid");
+
+    for (int i = 0; i < invalidNames.size(); ++i) {
+      DataflowPipelineOptions options = buildPipelineOptions();
+      options.setJobName(invalidNames.get(i));
+
+      try {
+        DataflowPipelineRunner.fromOptions(options);
+        fail("Expected IllegalArgumentException for jobName "
+            + options.getJobName());
+      } catch (IllegalArgumentException e) {
+        assertThat(e.getMessage(),
+            containsString(expectedReason.get(i)));
+      }
+    }
+  }
+
+  @Test
+  public void testValidJobName() throws IOException {
+    List<String> names = Arrays.asList("ok", "Ok", "A-Ok", "ok-123",
+        "this-one-is-fairly-long-01234567890123456789");
+
+    for (String name : names) {
+      DataflowPipelineOptions options = buildPipelineOptions();
+      options.setJobName(name);
+
+      DataflowPipelineRunner runner = DataflowPipelineRunner
+          .fromOptions(options);
+      assertNotNull(runner);
+    }
+  }
+
+  /**
+   * A fake PTransform for testing.
+   */
+  public static class TestTransform
+      extends PTransform<PCollection<Integer>, PCollection<Integer>> {
+    public boolean translated = false;
+
+    @Override
+    public PCollection<Integer> apply(PCollection<Integer> input) {
+      return PCollection.createPrimitiveOutputInternal(
+          input.getPipeline(),
+          WindowingStrategy.globalDefault(),
+          input.isBounded());
+    }
+
+    @Override
+    protected Coder<Integer> getDefaultOutputCoder(PCollection<Integer> input) {
+      return input.getCoder();
+    }
+  }
+
+  @Test
+  public void testTransformTranslatorMissing() throws IOException {
+    // Test that we throw if we don't provide a translation.
+ ArgumentCaptor jobCaptor = ArgumentCaptor.forClass(Job.class); + + DataflowPipelineOptions options = buildPipelineOptions(jobCaptor); + DataflowPipeline p = DataflowPipeline.create(options); + + p.apply(Create.of(Arrays.asList(1, 2, 3))) + .apply(new TestTransform()); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage(Matchers.containsString("no translator registered")); + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()); + assertValidJob(jobCaptor.getValue()); + } + + @Test + public void testTransformTranslator() throws IOException { + // Test that we can provide a custom translation + DataflowPipelineOptions options = buildPipelineOptions(); + DataflowPipeline p = DataflowPipeline.create(options); + TestTransform transform = new TestTransform(); + + p.apply(Create.of(Arrays.asList(1, 2, 3)).withCoder(BigEndianIntegerCoder.of())) + .apply(transform); + + DataflowPipelineTranslator translator = DataflowPipelineRunner + .fromOptions(options).getTranslator(); + + DataflowPipelineTranslator.registerTransformTranslator( + TestTransform.class, + new DataflowPipelineTranslator.TransformTranslator() { + @SuppressWarnings("unchecked") + @Override + public void translate( + TestTransform transform, + DataflowPipelineTranslator.TranslationContext context) { + transform.translated = true; + + // Note: This is about the minimum needed to fake out a + // translation. This obviously isn't a real translation. + context.addStep(transform, "TestTranslate"); + context.addOutput("output", context.getOutput(transform)); + } + }); + + translator.translate( + p, p.getRunner(), Collections.emptyList()); + assertTrue(transform.translated); + } + + /** Records all the composite transforms visited within the Pipeline. 
   */
+  private static class CompositeTransformRecorder implements PipelineVisitor {
+    private List<PTransform<?, ?>> transforms = new ArrayList<>();
+
+    @Override
+    public void enterCompositeTransform(TransformTreeNode node) {
+      if (node.getTransform() != null) {
+        transforms.add(node.getTransform());
+      }
+    }
+
+    @Override
+    public void leaveCompositeTransform(TransformTreeNode node) {
+    }
+
+    @Override
+    public void visitTransform(TransformTreeNode node) {
+    }
+
+    @Override
+    public void visitValue(PValue value, TransformTreeNode producer) {
+    }
+
+    public List<PTransform<?, ?>> getCompositeTransforms() {
+      return transforms;
+    }
+  }
+
+  @Test
+  public void testApplyIsScopedToExactClass() throws IOException {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    DataflowPipeline p = DataflowPipeline.create(options);
+
+    Create.TimestampedValues<String> transform =
+        Create.timestamped(Arrays.asList(TimestampedValue.of("TestString", Instant.now())));
+    p.apply(transform);
+
+    CompositeTransformRecorder recorder = new CompositeTransformRecorder();
+    p.traverseTopologically(recorder);
+
+    assertThat("Expected to have seen CreateTimestamped composite transform.",
+        recorder.getCompositeTransforms(),
+        Matchers.<PTransform<?, ?>>contains(transform));
+  }
+
+  @Test
+  public void testToString() {
+    DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class);
+    options.setJobName("TestJobName");
+    options.setProject("test-project");
+    options.setTempLocation("gs://test/temp/location");
+    options.setGcpCredential(new TestCredential());
+    options.setPathValidatorClass(NoopPathValidator.class);
+    assertEquals("DataflowPipelineRunner#TestJobName",
+        DataflowPipelineRunner.fromOptions(options).toString());
+  }
+
+  private static PipelineOptions makeOptions(boolean streaming) {
+    DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class);
+    options.setRunner(DataflowPipelineRunner.class);
+    options.setStreaming(streaming);
+    options.setJobName("TestJobName");
+    options.setProject("test-project");
+    options.setTempLocation("gs://test/temp/location");
+    options.setGcpCredential(new TestCredential());
+    options.setPathValidatorClass(NoopPathValidator.class);
+    return options;
+  }
+
+  private void testUnsupportedSource(PTransform<PInput, ?> source, String name, boolean streaming)
+      throws Exception {
+    String mode = streaming ?
"streaming" : "batch"; + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage( + "The DataflowPipelineRunner in " + mode + " mode does not support " + name); + + Pipeline p = Pipeline.create(makeOptions(streaming)); + p.apply(source); + p.run(); + } + + @Test + public void testBoundedSourceUnsupportedInStreaming() throws Exception { + testUnsupportedSource( + AvroSource.readFromFileWithClass("foo", String.class), "Read.Bounded", true); + } + + @Test + public void testBigQueryIOSourceUnsupportedInStreaming() throws Exception { + testUnsupportedSource( + BigQueryIO.Read.from("project:bar.baz").withoutValidation(), "BigQueryIO.Read", true); + } + + @Test + public void testAvroIOSourceUnsupportedInStreaming() throws Exception { + testUnsupportedSource( + AvroIO.Read.from("foo"), "AvroIO.Read", true); + } + + @Test + public void testTextIOSourceUnsupportedInStreaming() throws Exception { + testUnsupportedSource(TextIO.Read.from("foo"), "TextIO.Read", true); + } + + @Test + public void testReadBoundedSourceUnsupportedInStreaming() throws Exception { + testUnsupportedSource(Read.from(AvroSource.from("/tmp/test")), "Read.Bounded", true); + } + + @Test + public void testReadUnboundedUnsupportedInBatch() throws Exception { + testUnsupportedSource(Read.from(new TestCountingSource(1)), "Read.Unbounded", false); + } + + private void testUnsupportedSink( + PTransform, PDone> sink, String name, boolean streaming) + throws Exception { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage( + "The DataflowPipelineRunner in streaming mode does not support " + name); + + Pipeline p = Pipeline.create(makeOptions(streaming)); + p.apply(Create.of("foo")).apply(sink); + p.run(); + } + + @Test + public void testAvroIOSinkUnsupportedInStreaming() throws Exception { + testUnsupportedSink(AvroIO.Write.to("foo").withSchema(String.class), "AvroIO.Write", true); + } + + @Test + public void testTextIOSinkUnsupportedInStreaming() throws Exception { + testUnsupportedSink(TextIO.Write.to("foo"), "TextIO.Write", true); + } + + @Test + public void testBatchViewAsListToIsmRecordForGlobalWindow() throws Exception { + DoFnTester>> doFnTester = + DoFnTester.of(new BatchViewAsList.ToIsmRecordForGlobalWindowDoFn()); + + // The order of the output elements is important relative to processing order + assertThat(doFnTester.processBatch(ImmutableList.of("a", "b", "c")), contains( + IsmRecord.of(ImmutableList.of(GlobalWindow.INSTANCE, 0L), valueInGlobalWindow("a")), + IsmRecord.of(ImmutableList.of(GlobalWindow.INSTANCE, 1L), valueInGlobalWindow("b")), + IsmRecord.of(ImmutableList.of(GlobalWindow.INSTANCE, 2L), valueInGlobalWindow("c")))); + } + + @Test + public void testBatchViewAsListToIsmRecordForNonGlobalWindow() throws Exception { + DoFnTester>>>, + IsmRecord>> doFnTester = + DoFnTester.of( + new BatchViewAsList.ToIsmRecordForNonGlobalWindowDoFn( + IntervalWindow.getCoder())); + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + IntervalWindow windowB = new IntervalWindow(new Instant(10), new Instant(20)); + IntervalWindow windowC = new IntervalWindow(new Instant(20), new Instant(30)); + + Iterable>>>> inputElements = + ImmutableList.of( + KV.of(1, (Iterable>>) ImmutableList.of( + KV.of( + windowA, WindowedValue.of(110L, new Instant(1), windowA, PaneInfo.NO_FIRING)), + KV.of( + windowA, WindowedValue.of(111L, new Instant(3), windowA, PaneInfo.NO_FIRING)), + KV.of( + windowA, WindowedValue.of(112L, new Instant(4), windowA, PaneInfo.NO_FIRING)), + KV.of( + 
windowB, WindowedValue.of(120L, new Instant(12), windowB, PaneInfo.NO_FIRING)), + KV.of( + windowB, WindowedValue.of(121L, new Instant(14), windowB, PaneInfo.NO_FIRING)) + )), + KV.of(2, (Iterable>>) ImmutableList.of( + KV.of( + windowC, WindowedValue.of(210L, new Instant(25), windowC, PaneInfo.NO_FIRING)) + ))); + + // The order of the output elements is important relative to processing order + assertThat(doFnTester.processBatch(inputElements), contains( + IsmRecord.of(ImmutableList.of(windowA, 0L), + WindowedValue.of(110L, new Instant(1), windowA, PaneInfo.NO_FIRING)), + IsmRecord.of(ImmutableList.of(windowA, 1L), + WindowedValue.of(111L, new Instant(3), windowA, PaneInfo.NO_FIRING)), + IsmRecord.of(ImmutableList.of(windowA, 2L), + WindowedValue.of(112L, new Instant(4), windowA, PaneInfo.NO_FIRING)), + IsmRecord.of(ImmutableList.of(windowB, 0L), + WindowedValue.of(120L, new Instant(12), windowB, PaneInfo.NO_FIRING)), + IsmRecord.of(ImmutableList.of(windowB, 1L), + WindowedValue.of(121L, new Instant(14), windowB, PaneInfo.NO_FIRING)), + IsmRecord.of(ImmutableList.of(windowC, 0L), + WindowedValue.of(210L, new Instant(25), windowC, PaneInfo.NO_FIRING)))); + } + + @Test + public void testToIsmRecordForMapLikeDoFn() throws Exception { + TupleTag>> outputForSizeTag = new TupleTag<>(); + TupleTag>> outputForEntrySetTag = new TupleTag<>(); + + Coder keyCoder = VarLongCoder.of(); + Coder windowCoder = IntervalWindow.getCoder(); + + IsmRecordCoder> ismCoder = IsmRecordCoder.of( + 1, + 2, + ImmutableList.>of( + MetadataKeyCoder.of(keyCoder), + IntervalWindow.getCoder(), + BigEndianLongCoder.of()), + FullWindowedValueCoder.of(VarLongCoder.of(), windowCoder)); + + DoFnTester, WindowedValue>>>, + IsmRecord>> doFnTester = + DoFnTester.of(new BatchViewAsMultimap.ToIsmRecordForMapLikeDoFn( + outputForSizeTag, + outputForEntrySetTag, + windowCoder, + keyCoder, + ismCoder, + false /* unique keys */)); + doFnTester.setSideOutputTags(TupleTagList.of( + ImmutableList.>of(outputForSizeTag, outputForEntrySetTag))); + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + IntervalWindow windowB = new IntervalWindow(new Instant(10), new Instant(20)); + IntervalWindow windowC = new IntervalWindow(new Instant(20), new Instant(30)); + + Iterable, WindowedValue>>>> inputElements = + ImmutableList.of( + KV.of(1, (Iterable, WindowedValue>>) ImmutableList.of( + KV.of(KV.of(1L, windowA), + WindowedValue.of(110L, new Instant(1), windowA, PaneInfo.NO_FIRING)), + // same window same key as to previous + KV.of(KV.of(1L, windowA), + WindowedValue.of(111L, new Instant(2), windowA, PaneInfo.NO_FIRING)), + // same window different key as to previous + KV.of(KV.of(2L, windowA), + WindowedValue.of(120L, new Instant(3), windowA, PaneInfo.NO_FIRING)), + // different window same key as to previous + KV.of(KV.of(2L, windowB), + WindowedValue.of(210L, new Instant(11), windowB, PaneInfo.NO_FIRING)), + // different window and different key as to previous + KV.of(KV.of(3L, windowB), + WindowedValue.of(220L, new Instant(12), windowB, PaneInfo.NO_FIRING)))), + KV.of(2, (Iterable, WindowedValue>>) ImmutableList.of( + // different shard + KV.of(KV.of(4L, windowC), + WindowedValue.of(330L, new Instant(21), windowC, PaneInfo.NO_FIRING))))); + + // The order of the output elements is important relative to processing order + assertThat(doFnTester.processBatch(inputElements), contains( + IsmRecord.of( + ImmutableList.of(1L, windowA, 0L), + WindowedValue.of(110L, new Instant(1), windowA, PaneInfo.NO_FIRING)), + 
IsmRecord.of( + ImmutableList.of(1L, windowA, 1L), + WindowedValue.of(111L, new Instant(2), windowA, PaneInfo.NO_FIRING)), + IsmRecord.of( + ImmutableList.of(2L, windowA, 0L), + WindowedValue.of(120L, new Instant(3), windowA, PaneInfo.NO_FIRING)), + IsmRecord.of( + ImmutableList.of(2L, windowB, 0L), + WindowedValue.of(210L, new Instant(11), windowB, PaneInfo.NO_FIRING)), + IsmRecord.of( + ImmutableList.of(3L, windowB, 0L), + WindowedValue.of(220L, new Instant(12), windowB, PaneInfo.NO_FIRING)), + IsmRecord.of( + ImmutableList.of(4L, windowC, 0L), + WindowedValue.of(330L, new Instant(21), windowC, PaneInfo.NO_FIRING)))); + + // Verify the number of unique keys per window. + assertThat(doFnTester.takeSideOutputElements(outputForSizeTag), contains( + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowA)), + KV.of(windowA, 2L)), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowB)), + KV.of(windowB, 2L)), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowC)), + KV.of(windowC, 1L)) + )); + + // Verify the output for the unique keys. + assertThat(doFnTester.takeSideOutputElements(outputForEntrySetTag), contains( + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowA)), + KV.of(windowA, 1L)), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowA)), + KV.of(windowA, 2L)), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowB)), + KV.of(windowB, 2L)), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowB)), + KV.of(windowB, 3L)), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowC)), + KV.of(windowC, 4L)) + )); + } + + @Test + public void testToIsmRecordForMapLikeDoFnWithoutUniqueKeysThrowsException() throws Exception { + TupleTag>> outputForSizeTag = new TupleTag<>(); + TupleTag>> outputForEntrySetTag = new TupleTag<>(); + + Coder keyCoder = VarLongCoder.of(); + Coder windowCoder = IntervalWindow.getCoder(); + + IsmRecordCoder> ismCoder = IsmRecordCoder.of( + 1, + 2, + ImmutableList.>of( + MetadataKeyCoder.of(keyCoder), + IntervalWindow.getCoder(), + BigEndianLongCoder.of()), + FullWindowedValueCoder.of(VarLongCoder.of(), windowCoder)); + + DoFnTester, WindowedValue>>>, + IsmRecord>> doFnTester = + DoFnTester.of(new BatchViewAsMultimap.ToIsmRecordForMapLikeDoFn( + outputForSizeTag, + outputForEntrySetTag, + windowCoder, + keyCoder, + ismCoder, + true /* unique keys */)); + doFnTester.setSideOutputTags(TupleTagList.of( + ImmutableList.>of(outputForSizeTag, outputForEntrySetTag))); + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + + Iterable, WindowedValue>>>> inputElements = + ImmutableList.of( + KV.of(1, (Iterable, WindowedValue>>) ImmutableList.of( + KV.of(KV.of(1L, windowA), + WindowedValue.of(110L, new Instant(1), windowA, PaneInfo.NO_FIRING)), + // same window same key as to previous + KV.of(KV.of(1L, windowA), + WindowedValue.of(111L, new Instant(2), windowA, PaneInfo.NO_FIRING))))); + + try { + doFnTester.processBatch(inputElements); + fail("Expected UserCodeException"); + } catch (UserCodeException e) { + assertTrue(e.getCause() instanceof IllegalStateException); + IllegalStateException rootCause = (IllegalStateException) e.getCause(); + assertThat(rootCause.getMessage(), containsString("Unique keys are expected but found key")); + } + } + + @Test + public void testToIsmMetadataRecordForSizeDoFn() throws Exception { + TupleTag>> outputForSizeTag = new TupleTag<>(); + TupleTag>> 
outputForEntrySetTag = new TupleTag<>(); + + Coder keyCoder = VarLongCoder.of(); + Coder windowCoder = IntervalWindow.getCoder(); + + IsmRecordCoder> ismCoder = IsmRecordCoder.of( + 1, + 2, + ImmutableList.>of( + MetadataKeyCoder.of(keyCoder), + IntervalWindow.getCoder(), + BigEndianLongCoder.of()), + FullWindowedValueCoder.of(VarLongCoder.of(), windowCoder)); + + DoFnTester>>, + IsmRecord>> doFnTester = DoFnTester.of( + new BatchViewAsMultimap.ToIsmMetadataRecordForSizeDoFn( + windowCoder)); + doFnTester.setSideOutputTags(TupleTagList.of( + ImmutableList.>of(outputForSizeTag, outputForEntrySetTag))); + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + IntervalWindow windowB = new IntervalWindow(new Instant(10), new Instant(20)); + IntervalWindow windowC = new IntervalWindow(new Instant(20), new Instant(30)); + + Iterable>>> inputElements = + ImmutableList.of( + KV.of(1, + (Iterable>) ImmutableList.of( + KV.of(windowA, 2L), + KV.of(windowA, 3L), + KV.of(windowB, 7L))), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowB)), + (Iterable>) ImmutableList.of( + KV.of(windowC, 9L)))); + + // The order of the output elements is important relative to processing order + assertThat(doFnTester.processBatch(inputElements), contains( + IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), windowA, 0L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 5L)), + IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), windowB, 0L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 7L)), + IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), windowC, 0L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 9L)) + )); + } + + @Test + public void testToIsmMetadataRecordForKeyDoFn() throws Exception { + TupleTag>> outputForSizeTag = new TupleTag<>(); + TupleTag>> outputForEntrySetTag = new TupleTag<>(); + + Coder keyCoder = VarLongCoder.of(); + Coder windowCoder = IntervalWindow.getCoder(); + + IsmRecordCoder> ismCoder = IsmRecordCoder.of( + 1, + 2, + ImmutableList.>of( + MetadataKeyCoder.of(keyCoder), + IntervalWindow.getCoder(), + BigEndianLongCoder.of()), + FullWindowedValueCoder.of(VarLongCoder.of(), windowCoder)); + + DoFnTester>>, + IsmRecord>> doFnTester = DoFnTester.of( + new BatchViewAsMultimap.ToIsmMetadataRecordForKeyDoFn( + keyCoder, windowCoder)); + doFnTester.setSideOutputTags(TupleTagList.of( + ImmutableList.>of(outputForSizeTag, outputForEntrySetTag))); + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + IntervalWindow windowB = new IntervalWindow(new Instant(10), new Instant(20)); + IntervalWindow windowC = new IntervalWindow(new Instant(20), new Instant(30)); + + Iterable>>> inputElements = + ImmutableList.of( + KV.of(1, + (Iterable>) ImmutableList.of( + KV.of(windowA, 2L), + // same window as previous + KV.of(windowA, 3L), + // different window as previous + KV.of(windowB, 3L))), + KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), windowB)), + (Iterable>) ImmutableList.of( + KV.of(windowC, 3L)))); + + // The order of the output elements is important relative to processing order + assertThat(doFnTester.processBatch(inputElements), contains( + IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), windowA, 1L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 2L)), + IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), windowA, 2L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 3L)), + IsmRecord.>meta( + 
ImmutableList.of(IsmFormat.getMetadataKey(), windowB, 1L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 3L)), + IsmRecord.>meta( + ImmutableList.of(IsmFormat.getMetadataKey(), windowC, 1L), + CoderUtils.encodeToByteArray(VarLongCoder.of(), 3L)) + )); + } + + @Test + public void testToMapDoFn() throws Exception { + Coder windowCoder = IntervalWindow.getCoder(); + + DoFnTester>>>>, + IsmRecord, + Long>>>> doFnTester = + DoFnTester.of(new BatchViewAsMap.ToMapDoFn(windowCoder)); + + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + IntervalWindow windowB = new IntervalWindow(new Instant(10), new Instant(20)); + IntervalWindow windowC = new IntervalWindow(new Instant(20), new Instant(30)); + + Iterable>>>>> inputElements = + ImmutableList.of( + KV.of(1, + (Iterable>>>) ImmutableList.of( + KV.of(windowA, WindowedValue.of( + KV.of(1L, 11L), new Instant(3), windowA, PaneInfo.NO_FIRING)), + KV.of(windowA, WindowedValue.of( + KV.of(2L, 21L), new Instant(7), windowA, PaneInfo.NO_FIRING)), + KV.of(windowB, WindowedValue.of( + KV.of(2L, 21L), new Instant(13), windowB, PaneInfo.NO_FIRING)), + KV.of(windowB, WindowedValue.of( + KV.of(3L, 31L), new Instant(15), windowB, PaneInfo.NO_FIRING)))), + KV.of(2, + (Iterable>>>) ImmutableList.of( + KV.of(windowC, WindowedValue.of( + KV.of(4L, 41L), new Instant(25), windowC, PaneInfo.NO_FIRING))))); + + // The order of the output elements is important relative to processing order + List, + Long>>>> output = + doFnTester.processBatch(inputElements); + assertEquals(3, output.size()); + Map outputMap; + + outputMap = output.get(0).getValue().getValue(); + assertEquals(2, outputMap.size()); + assertEquals(ImmutableMap.of(1L, 11L, 2L, 21L), outputMap); + + outputMap = output.get(1).getValue().getValue(); + assertEquals(2, outputMap.size()); + assertEquals(ImmutableMap.of(2L, 21L, 3L, 31L), outputMap); + + outputMap = output.get(2).getValue().getValue(); + assertEquals(1, outputMap.size()); + assertEquals(ImmutableMap.of(4L, 41L), outputMap); + } + + @Test + public void testToMultimapDoFn() throws Exception { + Coder windowCoder = IntervalWindow.getCoder(); + + DoFnTester>>>>, + IsmRecord>, + Iterable>>>> doFnTester = + DoFnTester.of( + new BatchViewAsMultimap.ToMultimapDoFn(windowCoder)); + + + IntervalWindow windowA = new IntervalWindow(new Instant(0), new Instant(10)); + IntervalWindow windowB = new IntervalWindow(new Instant(10), new Instant(20)); + IntervalWindow windowC = new IntervalWindow(new Instant(20), new Instant(30)); + + Iterable>>>>> inputElements = + ImmutableList.of( + KV.of(1, + (Iterable>>>) ImmutableList.of( + KV.of(windowA, WindowedValue.of( + KV.of(1L, 11L), new Instant(3), windowA, PaneInfo.NO_FIRING)), + KV.of(windowA, WindowedValue.of( + KV.of(1L, 12L), new Instant(5), windowA, PaneInfo.NO_FIRING)), + KV.of(windowA, WindowedValue.of( + KV.of(2L, 21L), new Instant(7), windowA, PaneInfo.NO_FIRING)), + KV.of(windowB, WindowedValue.of( + KV.of(2L, 21L), new Instant(13), windowB, PaneInfo.NO_FIRING)), + KV.of(windowB, WindowedValue.of( + KV.of(3L, 31L), new Instant(15), windowB, PaneInfo.NO_FIRING)))), + KV.of(2, + (Iterable>>>) ImmutableList.of( + KV.of(windowC, WindowedValue.of( + KV.of(4L, 41L), new Instant(25), windowC, PaneInfo.NO_FIRING))))); + + // The order of the output elements is important relative to processing order + List>, + Iterable>>>> output = + doFnTester.processBatch(inputElements); + assertEquals(3, output.size()); + Map> outputMap; + + outputMap = output.get(0).getValue().getValue(); + 
assertEquals(2, outputMap.size()); + assertThat(outputMap.get(1L), containsInAnyOrder(11L, 12L)); + assertThat(outputMap.get(2L), containsInAnyOrder(21L)); + + outputMap = output.get(1).getValue().getValue(); + assertEquals(2, outputMap.size()); + assertThat(outputMap.get(2L), containsInAnyOrder(21L)); + assertThat(outputMap.get(3L), containsInAnyOrder(31L)); + + outputMap = output.get(2).getValue().getValue(); + assertEquals(1, outputMap.size()); + assertThat(outputMap.get(4L), containsInAnyOrder(41L)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTest.java new file mode 100644 index 000000000000..c34627cd23b2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTest.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.NoopPathValidator; +import com.google.cloud.dataflow.sdk.util.TestCredential; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DataflowPipeline}. */ +@RunWith(JUnit4.class) +public class DataflowPipelineTest { + @Test + public void testToString() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setJobName("TestJobName"); + options.setProject("project-id"); + options.setTempLocation("gs://test/temp/location"); + options.setGcpCredential(new TestCredential()); + options.setPathValidatorClass(NoopPathValidator.class); + assertEquals("DataflowPipeline#TestJobName", + DataflowPipeline.create(options).toString()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java new file mode 100644 index 000000000000..72090a0866a6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java @@ -0,0 +1,765 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static com.google.cloud.dataflow.sdk.util.Structs.addObject; +import static com.google.cloud.dataflow.sdk.util.Structs.getDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.api.services.dataflow.model.Job; +import com.google.api.services.dataflow.model.Step; +import com.google.api.services.dataflow.model.WorkerPool; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineWorkerPoolOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.OutputReference; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.Structs; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatcher; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * Tests for DataflowPipelineTranslator. 
+ */ +@RunWith(JUnit4.class) +public class DataflowPipelineTranslatorTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + // A Custom Mockito matcher for an initial Job that checks that all + // expected fields are set. + private static class IsValidCreateRequest extends ArgumentMatcher { + @Override + public boolean matches(Object o) { + Job job = (Job) o; + return job.getId() == null + && job.getProjectId() == null + && job.getName() != null + && job.getType() != null + && job.getEnvironment() != null + && job.getSteps() != null + && job.getCurrentState() == null + && job.getCurrentStateTime() == null + && job.getExecutionInfo() == null + && job.getCreateTime() == null; + } + } + + private DataflowPipeline buildPipeline(DataflowPipelineOptions options) { + DataflowPipeline p = DataflowPipeline.create(options); + + p.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/object")) + .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/object")); + + return p; + } + + private static Dataflow buildMockDataflow( + ArgumentMatcher jobMatcher) throws IOException { + Dataflow mockDataflowClient = mock(Dataflow.class); + Dataflow.Projects mockProjects = mock(Dataflow.Projects.class); + Dataflow.Projects.Jobs mockJobs = mock(Dataflow.Projects.Jobs.class); + Dataflow.Projects.Jobs.Create mockRequest = mock( + Dataflow.Projects.Jobs.Create.class); + + when(mockDataflowClient.projects()).thenReturn(mockProjects); + when(mockProjects.jobs()).thenReturn(mockJobs); + when(mockJobs.create(eq("someProject"), argThat(jobMatcher))) + .thenReturn(mockRequest); + + Job resultJob = new Job(); + resultJob.setId("newid"); + when(mockRequest.execute()).thenReturn(resultJob); + return mockDataflowClient; + } + + private static DataflowPipelineOptions buildPipelineOptions() throws IOException { + GcsUtil mockGcsUtil = mock(GcsUtil.class); + when(mockGcsUtil.expand(any(GcsPath.class))).then(new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return ImmutableList.of((GcsPath) invocation.getArguments()[0]); + } + }); + when(mockGcsUtil.bucketExists(any(GcsPath.class))).thenReturn(true); + when(mockGcsUtil.isGcsPatternSupported(anyString())).thenCallRealMethod(); + + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + options.setJobName("some-job-name"); + options.setProject("some-project"); + options.setTempLocation(GcsPath.fromComponents("somebucket", "some/path").toString()); + options.setFilesToStage(new LinkedList()); + options.setDataflowClient(buildMockDataflow(new IsValidCreateRequest())); + options.setGcsUtil(mockGcsUtil); + return options; + } + + @Test + public void testSettingOfSdkPipelineOptions() throws IOException { + DataflowPipelineOptions options = buildPipelineOptions(); + options.setRunner(DataflowPipelineRunner.class); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + // Note that the contents of this materialized map may be changed by the act of reading an + // option, which will cause the default to get materialized whereas it would otherwise be + // left absent. It is permissible to simply alter this test to reflect current behavior. 
+ Map settings = new HashMap<>(); + settings.put("appName", "DataflowPipelineTranslatorTest"); + settings.put("project", "some-project"); + settings.put("pathValidatorClass", "com.google.cloud.dataflow.sdk.util.DataflowPathValidator"); + settings.put("runner", "com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner"); + settings.put("jobName", "some-job-name"); + settings.put("tempLocation", "gs://somebucket/some/path"); + settings.put("stagingLocation", "gs://somebucket/some/path/staging"); + settings.put("stableUniqueNames", "WARNING"); + settings.put("streaming", false); + settings.put("numberOfWorkerHarnessThreads", 0); + settings.put("experiments", null); + + assertEquals(ImmutableMap.of("options", settings), + job.getEnvironment().getSdkPipelineOptions()); + } + + @Test + public void testNetworkConfig() throws IOException { + final String testNetwork = "test-network"; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setNetwork(testNetwork); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + assertEquals(testNetwork, + job.getEnvironment().getWorkerPools().get(0).getNetwork()); + } + + @Test + public void testNetworkConfigMissing() throws IOException { + DataflowPipelineOptions options = buildPipelineOptions(); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + assertNull(job.getEnvironment().getWorkerPools().get(0).getNetwork()); + } + + @Test + public void testScalingAlgorithmMissing() throws IOException { + DataflowPipelineOptions options = buildPipelineOptions(); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + // Autoscaling settings are always set. 
+ assertNull( + job + .getEnvironment() + .getWorkerPools() + .get(0) + .getAutoscalingSettings() + .getAlgorithm()); + assertEquals( + 0, + job + .getEnvironment() + .getWorkerPools() + .get(0) + .getAutoscalingSettings() + .getMaxNumWorkers() + .intValue()); + } + + @Test + public void testScalingAlgorithmNone() throws IOException { + final DataflowPipelineWorkerPoolOptions.AutoscalingAlgorithmType noScaling = + DataflowPipelineWorkerPoolOptions.AutoscalingAlgorithmType.NONE; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setAutoscalingAlgorithm(noScaling); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + assertEquals( + "AUTOSCALING_ALGORITHM_NONE", + job + .getEnvironment() + .getWorkerPools() + .get(0) + .getAutoscalingSettings() + .getAlgorithm()); + assertEquals( + 0, + job + .getEnvironment() + .getWorkerPools() + .get(0) + .getAutoscalingSettings() + .getMaxNumWorkers() + .intValue()); + } + + @Test + public void testMaxNumWorkersIsPassedWhenNoAlgorithmIsSet() throws IOException { + final DataflowPipelineWorkerPoolOptions.AutoscalingAlgorithmType noScaling = null; + DataflowPipelineOptions options = buildPipelineOptions(); + options.setMaxNumWorkers(42); + options.setAutoscalingAlgorithm(noScaling); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + assertNull( + job + .getEnvironment() + .getWorkerPools() + .get(0) + .getAutoscalingSettings() + .getAlgorithm()); + assertEquals( + 42, + job + .getEnvironment() + .getWorkerPools() + .get(0) + .getAutoscalingSettings() + .getMaxNumWorkers() + .intValue()); + } + + @Test + public void testZoneConfig() throws IOException { + final String testZone = "test-zone-1"; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setZone(testZone); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + assertEquals(testZone, + job.getEnvironment().getWorkerPools().get(0).getZone()); + } + + @Test + public void testWorkerMachineTypeConfig() throws IOException { + final String testMachineType = "test-machine-type"; + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setWorkerMachineType(testMachineType); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + + WorkerPool workerPool = job.getEnvironment().getWorkerPools().get(0); + assertEquals(testMachineType, workerPool.getMachineType()); + } + + @Test + public void testDiskSizeGbConfig() throws IOException { + final Integer diskSizeGb = 1234; + + DataflowPipelineOptions options = buildPipelineOptions(); + 
options.setDiskSizeGb(diskSizeGb); + + DataflowPipeline p = buildPipeline(options); + p.traverseTopologically(new RecordingPipelineVisitor()); + Job job = + DataflowPipelineTranslator.fromOptions(options) + .translate(p, p.getRunner(), Collections.emptyList()) + .getJob(); + + assertEquals(1, job.getEnvironment().getWorkerPools().size()); + assertEquals(diskSizeGb, + job.getEnvironment().getWorkerPools().get(0).getDiskSizeGb()); + } + + @Test + public void testPredefinedAddStep() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + DataflowPipelineTranslator.registerTransformTranslator( + EmbeddedTransform.class, new EmbeddedTranslator()); + + // Create a predefined step using another pipeline + Step predefinedStep = createPredefinedStep(); + + // Create a pipeline that the predefined step will be embedded into + DataflowPipeline pipeline = DataflowPipeline.create(options); + pipeline.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/in")) + .apply(ParDo.of(new NoOpFn())) + .apply(new EmbeddedTransform(predefinedStep.clone())) + .apply(ParDo.of(new NoOpFn())); + Job job = translator.translate( + pipeline, pipeline.getRunner(), Collections.emptyList()).getJob(); + + List steps = job.getSteps(); + assertEquals(4, steps.size()); + + // The input to the embedded step should match the output of the step before + Map step1Out = getOutputPortReference(steps.get(1)); + Map step2In = getDictionary( + steps.get(2).getProperties(), PropertyNames.PARALLEL_INPUT); + assertEquals(step1Out, step2In); + + // The output from the embedded step should match the input of the step after + Map step2Out = getOutputPortReference(steps.get(2)); + Map step3In = getDictionary( + steps.get(3).getProperties(), PropertyNames.PARALLEL_INPUT); + assertEquals(step2Out, step3In); + + // The step should not have been modified other than remapping the input + Step predefinedStepClone = predefinedStep.clone(); + Step embeddedStepClone = steps.get(2).clone(); + predefinedStepClone.getProperties().remove(PropertyNames.PARALLEL_INPUT); + embeddedStepClone.getProperties().remove(PropertyNames.PARALLEL_INPUT); + assertEquals(predefinedStepClone, embeddedStepClone); + } + + /** + * Construct a OutputReference for the output of the step. + */ + private static OutputReference getOutputPortReference(Step step) throws Exception { + // TODO: This should be done via a Structs accessor. + @SuppressWarnings("unchecked") + List> output = + (List>) step.getProperties().get(PropertyNames.OUTPUT_INFO); + String outputTagId = getString(Iterables.getOnlyElement(output), PropertyNames.OUTPUT_NAME); + return new OutputReference(step.getName(), outputTagId); + } + + /** + * Returns a Step for a DoFn by creating and translating a pipeline. 
+   */
+  private static Step createPredefinedStep() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options);
+    DataflowPipeline pipeline = DataflowPipeline.create(options);
+    String stepName = "DoFn1";
+    pipeline.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/in"))
+        .apply(ParDo.of(new NoOpFn()).named(stepName))
+        .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/out"));
+    Job job = translator.translate(
+        pipeline, pipeline.getRunner(), Collections.<DataflowPackage>emptyList()).getJob();
+
+    assertEquals(13, job.getSteps().size());
+    Step step = job.getSteps().get(1);
+    assertEquals(stepName, getString(step.getProperties(), PropertyNames.USER_NAME));
+    return step;
+  }
+
+  private static class NoOpFn extends DoFn<String, String> {
+    @Override public void processElement(ProcessContext c) throws Exception {
+      c.output(c.element());
+    }
+  }
+
+  /**
+   * A placeholder transform that will be used to substitute a predefined Step.
+   */
+  private static class EmbeddedTransform
+      extends PTransform<PCollection<String>, PCollection<String>> {
+    private final Step step;
+
+    public EmbeddedTransform(Step step) {
+      this.step = step;
+    }
+
+    @Override
+    public PCollection<String> apply(PCollection<String> input) {
+      return PCollection.createPrimitiveOutputInternal(
+          input.getPipeline(),
+          WindowingStrategy.globalDefault(),
+          input.isBounded());
+    }
+
+    @Override
+    protected Coder<String> getDefaultOutputCoder() {
+      return StringUtf8Coder.of();
+    }
+  }
+
+  /**
+   * A TransformTranslator that adds the predefined Step using
+   * {@link TranslationContext#addStep} and remaps the input port reference.
+   */
+  private static class EmbeddedTranslator
+      implements DataflowPipelineTranslator.TransformTranslator<EmbeddedTransform> {
+    @Override public void translate(EmbeddedTransform transform, TranslationContext context) {
+      addObject(transform.step.getProperties(), PropertyNames.PARALLEL_INPUT,
+          context.asOutputReference(context.getInput(transform)));
+      context.addStep(transform, transform.step);
+    }
+  }
+
+  /**
+   * A composite transform that returns an output that is unrelated to
+   * the input.
+   */
+  private static class UnrelatedOutputCreator
+      extends PTransform<PCollection<Integer>, PCollection<Integer>> {
+
+    @Override
+    public PCollection<Integer> apply(PCollection<Integer> input) {
+      // Apply an operation so that this is a composite transform.
+      input.apply(Count.perElement());
+
+      // Return a value unrelated to the input.
+      return input.getPipeline().apply(Create.of(1, 2, 3, 4));
+    }
+
+    @Override
+    protected Coder<Integer> getDefaultOutputCoder() {
+      return VarIntCoder.of();
+    }
+  }
+
+  /**
+   * A composite transform that returns an output that is unbound.
+   */
+  private static class UnboundOutputCreator
+      extends PTransform<PCollection<Integer>, PDone> {
+
+    @Override
+    public PDone apply(PCollection<Integer> input) {
+      // Apply an operation so that this is a composite transform.
+      input.apply(Count.perElement());
+
+      return PDone.in(input.getPipeline());
+    }
+
+    @Override
+    protected Coder<Void> getDefaultOutputCoder() {
+      return VoidCoder.of();
+    }
+  }
+
+  /**
+   * A composite transform that returns a partially bound output.
+   *

+   * <p>This is not allowed and will result in a failure.
+   */
+  private static class PartiallyBoundOutputCreator
+      extends PTransform<PCollection<Integer>, PCollectionTuple> {
+
+    public final TupleTag<Integer> sumTag = new TupleTag<>("sum");
+    public final TupleTag<Void> doneTag = new TupleTag<>("done");
+
+    @Override
+    public PCollectionTuple apply(PCollection<Integer> input) {
+      PCollection<Integer> sum = input.apply(Sum.integersGlobally());
+
+      // Fails here when attempting to construct a tuple with an unbound object.
+      return PCollectionTuple.of(sumTag, sum)
+          .and(doneTag, PCollection.createPrimitiveOutputInternal(
+              input.getPipeline(),
+              WindowingStrategy.globalDefault(),
+              input.isBounded()));
+    }
+  }
+
+  @Test
+  public void testMultiGraphPipelineSerialization() throws IOException {
+    DataflowPipeline p = DataflowPipeline.create(buildPipelineOptions());
+
+    PCollection<Integer> input = p.begin()
+        .apply(Create.of(1, 2, 3));
+
+    input.apply(new UnrelatedOutputCreator());
+    input.apply(new UnboundOutputCreator());
+
+    DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(
+        PipelineOptionsFactory.as(DataflowPipelineOptions.class));
+
+    // Check that translation doesn't fail.
+    t.translate(p, p.getRunner(), Collections.<DataflowPackage>emptyList());
+  }
+
+  @Test
+  public void testPartiallyBoundFailure() throws IOException {
+    Pipeline p = DataflowPipeline.create(buildPipelineOptions());
+
+    PCollection<Integer> input = p.begin()
+        .apply(Create.of(1, 2, 3));
+
+    thrown.expect(IllegalStateException.class);
+    input.apply(new PartiallyBoundOutputCreator());
+
+    Assert.fail("Failure expected from use of partially bound output");
+  }
+
+  /**
+   * This tests a few corner cases that should not crash.
+   */
+  @Test
+  public void testGoodWildcards() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    DataflowPipeline pipeline = DataflowPipeline.create(options);
+    DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(options);
+
+    applyRead(pipeline, "gs://bucket/foo");
+    applyRead(pipeline, "gs://bucket/foo/");
+    applyRead(pipeline, "gs://bucket/foo/*");
+    applyRead(pipeline, "gs://bucket/foo/?");
+    applyRead(pipeline, "gs://bucket/foo/[0-9]");
+    applyRead(pipeline, "gs://bucket/foo/*baz*");
+    applyRead(pipeline, "gs://bucket/foo/*baz?");
+    applyRead(pipeline, "gs://bucket/foo/[0-9]baz?");
+    applyRead(pipeline, "gs://bucket/foo/baz/*");
+    applyRead(pipeline, "gs://bucket/foo/baz/*wonka*");
+    applyRead(pipeline, "gs://bucket/foo/*baz/wonka*");
+    applyRead(pipeline, "gs://bucket/foo*/baz");
+    applyRead(pipeline, "gs://bucket/foo?/baz");
+    applyRead(pipeline, "gs://bucket/foo[0-9]/baz");
+
+    // Check that translation doesn't fail.
+    t.translate(pipeline, pipeline.getRunner(), Collections.<DataflowPackage>emptyList());
+  }
+
+  private void applyRead(Pipeline pipeline, String path) {
+    pipeline.apply("Read(" + path + ")", TextIO.Read.from(path));
+  }
+
+  /**
+   * Recursive wildcards are not supported.
+   * This tests "**".
+   */
+  @Test
+  public void testBadWildcardRecursive() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    DataflowPipeline pipeline = DataflowPipeline.create(options);
+    DataflowPipelineTranslator t = DataflowPipelineTranslator.fromOptions(options);
+
+    pipeline.apply(TextIO.Read.from("gs://bucket/foo**/baz"));
+
+    // Check that translation does fail.
+ thrown.expectCause(Matchers.allOf( + instanceOf(IllegalArgumentException.class), + ThrowableMessageMatcher.hasMessage(containsString("Unsupported wildcard usage")))); + t.translate(pipeline, pipeline.getRunner(), Collections.emptyList()); + } + + @Test + public void testToSingletonTranslation() throws Exception { + // A "change detector" test that makes sure the translation + // of getting a PCollectionView does not change + // in bad ways during refactor + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setExperiments(ImmutableList.of("disable_ism_side_input")); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + DataflowPipeline pipeline = DataflowPipeline.create(options); + pipeline.apply(Create.of(1)) + .apply(View.asSingleton()); + Job job = translator.translate( + pipeline, pipeline.getRunner(), Collections.emptyList()).getJob(); + + List steps = job.getSteps(); + assertEquals(2, steps.size()); + + Step createStep = steps.get(0); + assertEquals("CreateCollection", createStep.getKind()); + + Step collectionToSingletonStep = steps.get(1); + assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); + + } + + @Test + public void testToIterableTranslation() throws Exception { + // A "change detector" test that makes sure the translation + // of getting a PCollectionView> does not change + // in bad ways during refactor + + DataflowPipelineOptions options = buildPipelineOptions(); + options.setExperiments(ImmutableList.of("disable_ism_side_input")); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + DataflowPipeline pipeline = DataflowPipeline.create(options); + pipeline.apply(Create.of(1, 2, 3)) + .apply(View.asIterable()); + Job job = translator.translate( + pipeline, pipeline.getRunner(), Collections.emptyList()).getJob(); + + List steps = job.getSteps(); + assertEquals(2, steps.size()); + + Step createStep = steps.get(0); + assertEquals("CreateCollection", createStep.getKind()); + + Step collectionToSingletonStep = steps.get(1); + assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); + } + + @Test + public void testToSingletonTranslationWithIsmSideInput() throws Exception { + // A "change detector" test that makes sure the translation + // of getting a PCollectionView does not change + // in bad ways during refactor + + DataflowPipelineOptions options = buildPipelineOptions(); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + DataflowPipeline pipeline = DataflowPipeline.create(options); + pipeline.apply(Create.of(1)) + .apply(View.asSingleton()); + Job job = translator.translate( + pipeline, pipeline.getRunner(), Collections.emptyList()).getJob(); + + List steps = job.getSteps(); + assertEquals(5, steps.size()); + + @SuppressWarnings("unchecked") + List> toIsmRecordOutputs = + (List>) steps.get(3).getProperties().get(PropertyNames.OUTPUT_INFO); + assertTrue( + Structs.getBoolean(Iterables.getOnlyElement(toIsmRecordOutputs), "use_indexed_format")); + + Step collectionToSingletonStep = steps.get(4); + assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); + } + + @Test + public void testToIterableTranslationWithIsmSideInput() throws Exception { + // A "change detector" test that makes sure the translation + // of getting a PCollectionView> does not change + // in bad ways during refactor + + DataflowPipelineOptions options = buildPipelineOptions(); + 
DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + DataflowPipeline pipeline = DataflowPipeline.create(options); + pipeline.apply(Create.of(1, 2, 3)) + .apply(View.asIterable()); + Job job = translator.translate( + pipeline, pipeline.getRunner(), Collections.emptyList()).getJob(); + + List steps = job.getSteps(); + assertEquals(3, steps.size()); + + @SuppressWarnings("unchecked") + List> toIsmRecordOutputs = + (List>) steps.get(1).getProperties().get(PropertyNames.OUTPUT_INFO); + assertTrue( + Structs.getBoolean(Iterables.getOnlyElement(toIsmRecordOutputs), "use_indexed_format")); + + + Step collectionToSingletonStep = steps.get(2); + assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRegistrarTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRegistrarTest.java new file mode 100644 index 000000000000..e7b24f0b3006 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRegistrarTest.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ServiceLoader; + +/** Tests for {@link DirectPipelineRegistrar}. 
*/ +@RunWith(JUnit4.class) +public class DirectPipelineRegistrarTest { + @Test + public void testCorrectOptionsAreReturned() { + assertEquals(ImmutableList.of(DirectPipelineOptions.class), + new DirectPipelineRegistrar.Options().getPipelineOptions()); + } + + @Test + public void testCorrectRunnersAreReturned() { + assertEquals(ImmutableList.of(DirectPipelineRunner.class), + new DirectPipelineRegistrar.Runner().getPipelineRunners()); + } + + @Test + public void testServiceLoaderForOptions() { + for (PipelineOptionsRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class).iterator())) { + if (registrar instanceof DirectPipelineRegistrar.Options) { + return; + } + } + fail("Expected to find " + DirectPipelineRegistrar.Options.class); + } + + @Test + public void testServiceLoaderForRunner() { + for (PipelineRunnerRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class).iterator())) { + if (registrar instanceof DirectPipelineRegistrar.Runner) { + return; + } + } + fail("Expected to find " + DirectPipelineRegistrar.Runner.class); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java new file mode 100644 index 000000000000..6524e144f8da --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java @@ -0,0 +1,210 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.ShardNameTemplate; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.common.collect.Iterables; +import com.google.common.io.Files; + +import org.apache.avro.file.DataFileReader; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +/** Tests for {@link DirectPipelineRunner}. */ +@RunWith(JUnit4.class) +public class DirectPipelineRunnerTest implements Serializable { + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testToString() { + PipelineOptions options = PipelineOptionsFactory.create(); + DirectPipelineRunner runner = DirectPipelineRunner.fromOptions(options); + assertEquals("DirectPipelineRunner#" + runner.hashCode(), + runner.toString()); + } + + /** A {@link Coder} that fails during decoding. */ + private static class CrashingCoder extends AtomicCoder { + @Override + public void encode(T value, OutputStream stream, Context context) throws CoderException { + throw new CoderException("Called CrashingCoder.encode"); + } + + @Override + public T decode( + InputStream inStream, com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException { + throw new CoderException("Called CrashingCoder.decode"); + } + } + + /** A {@link DoFn} that outputs {@code 'hello'}. 
*/ + private static class HelloDoFn extends DoFn { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + c.output("hello"); + } + } + + @Test + public void testCoderException() { + DirectPipeline pipeline = DirectPipeline.createForTest(); + + pipeline + .apply("CreateTestData", Create.of(42)) + .apply("CrashDuringCoding", ParDo.of(new HelloDoFn())) + .setCoder(new CrashingCoder()); + + expectedException.expect(RuntimeException.class); + expectedException.expectCause(isA(CoderException.class)); + pipeline.run(); + } + + @Test + public void testDirectPipelineOptions() { + DirectPipelineOptions options = PipelineOptionsFactory.create().as(DirectPipelineOptions.class); + assertNull(options.getDirectPipelineRunnerRandomSeed()); + } + + @Test + public void testTextIOWriteWithDefaultShardingStrategy() throws Exception { + String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "output"); + Pipeline p = DirectPipeline.createForTest(); + String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" }; + p.apply(Create.of(expectedElements)) + .apply(TextIO.Write.to(prefix).withSuffix("txt")); + p.run(); + + String filename = + IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".txt", 0, 1); + List fileContents = + Files.readLines(new File(filename), StandardCharsets.UTF_8); + // Ensure that each file got at least one record + assertFalse(fileContents.isEmpty()); + + assertThat(fileContents, containsInAnyOrder(expectedElements)); + } + + @Test + public void testTextIOWriteWithLimitedNumberOfShards() throws Exception { + final int numShards = 3; + String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "shardedOutput"); + Pipeline p = DirectPipeline.createForTest(); + String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" }; + p.apply(Create.of(expectedElements)) + .apply(TextIO.Write.to(prefix).withNumShards(numShards).withSuffix("txt")); + p.run(); + + List allContents = new ArrayList<>(); + for (int i = 0; i < numShards; ++i) { + String shardFileName = + IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".txt", i, 3); + List shardFileContents = + Files.readLines(new File(shardFileName), StandardCharsets.UTF_8); + + // Ensure that each file got at least one record + assertFalse(shardFileContents.isEmpty()); + + allContents.addAll(shardFileContents); + } + + assertThat(allContents, containsInAnyOrder(expectedElements)); + } + + @Test + public void testAvroIOWriteWithDefaultShardingStrategy() throws Exception { + String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "output"); + Pipeline p = DirectPipeline.createForTest(); + String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" }; + p.apply(Create.of(expectedElements)) + .apply(AvroIO.Write.withSchema(String.class).to(prefix).withSuffix(".avro")); + p.run(); + + String filename = + IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".avro", 0, 1); + List fileContents = new ArrayList<>(); + Iterables.addAll(fileContents, DataFileReader.openReader( + new File(filename), AvroCoder.of(String.class).createDatumReader())); + + // Ensure that each file got at least one record + assertFalse(fileContents.isEmpty()); + + assertThat(fileContents, containsInAnyOrder(expectedElements)); + } + + @Test + public void testAvroIOWriteWithLimitedNumberOfShards() throws Exception { + final int numShards = 3; + String prefix = 
IOChannelUtils.resolve(Files.createTempDir().toString(), "shardedOutput"); + Pipeline p = DirectPipeline.createForTest(); + String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" }; + p.apply(Create.of(expectedElements)) + .apply(AvroIO.Write.withSchema(String.class).to(prefix) + .withNumShards(numShards).withSuffix(".avro")); + p.run(); + + List allContents = new ArrayList<>(); + for (int i = 0; i < numShards; ++i) { + String shardFileName = + IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".avro", i, 3); + List shardFileContents = new ArrayList<>(); + Iterables.addAll(shardFileContents, DataFileReader.openReader( + new File(shardFileName), AvroCoder.of(String.class).createDatumReader())); + + // Ensure that each file got at least one record + assertFalse(shardFileContents.isEmpty()); + + allContents.addAll(shardFileContents); + } + + assertThat(allContents, containsInAnyOrder(expectedElements)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineTest.java new file mode 100644 index 000000000000..ed4542fa13e5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineTest.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DirectPipeline}. */ +@RunWith(JUnit4.class) +public class DirectPipelineTest { + @Test + public void testToString() { + DirectPipeline pipeline = DirectPipeline.createForTest(); + assertEquals("DirectPipeline#" + pipeline.hashCode(), + pipeline.toString()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerTest.java new file mode 100644 index 000000000000..4b1f786dc6ca --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/PipelineRunnerTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.junit.Assert.assertTrue; + +import com.google.api.services.dataflow.Dataflow; +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.TestCredential; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Tests for DataflowPipelineRunner. + */ +@RunWith(JUnit4.class) +public class PipelineRunnerTest { + + @Mock private Dataflow mockDataflow; + @Mock private GcsUtil mockGcsUtil; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testLongName() { + // Check we can create a pipeline runner using the full class name. + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setAppName("test"); + options.setProject("test"); + options.setGcsUtil(mockGcsUtil); + options.setRunner(DirectPipelineRunner.class); + options.setGcpCredential(new TestCredential()); + PipelineRunner runner = PipelineRunner.fromOptions(options); + assertTrue(runner instanceof DirectPipelineRunner); + } + + @Test + public void testShortName() { + // Check we can create a pipeline runner using the short class name. + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setAppName("test"); + options.setProject("test"); + options.setGcsUtil(mockGcsUtil); + options.setRunner(DirectPipelineRunner.class); + options.setGcpCredential(new TestCredential()); + PipelineRunner runner = PipelineRunner.fromOptions(options); + assertTrue(runner instanceof DirectPipelineRunner); + } + + @Test + public void testAppNameDefault() { + ApplicationNameOptions options = PipelineOptionsFactory.as(ApplicationNameOptions.class); + Assert.assertEquals(PipelineRunnerTest.class.getSimpleName(), + options.getAppName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java new file mode 100644 index 000000000000..68e1db1a52f2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java @@ -0,0 +1,194 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.io.Write; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.Sample; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PValue; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.util.Arrays; +import java.util.EnumSet; + +/** + * Tests for {@link TransformTreeNode} and {@link TransformHierarchy}. + */ +@RunWith(JUnit4.class) +public class TransformTreeTest { + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + enum TransformsSeen { + READ, + WRITE, + SAMPLE_ANY + } + + /** + * INVALID TRANSFORM, DO NOT COPY. + * + *
    This is an invalid composite transform that returns unbound outputs. + * This should never happen, and is here to test that it is properly rejected. + */ + private static class InvalidCompositeTransform + extends PTransform> { + + @Override + public PCollectionList apply(PBegin b) { + // Composite transform: apply delegates to other transformations, + // here a Create transform. + PCollection result = b.apply(Create.of("hello", "world")); + + // Issue below: PCollection.createPrimitiveOutput should not be used + // from within a composite transform. + return PCollectionList.of( + Arrays.asList(result, PCollection.createPrimitiveOutputInternal( + b.getPipeline(), + WindowingStrategy.globalDefault(), + result.isBounded()))); + } + } + + /** + * A composite transform that returns an output that is unbound. + */ + private static class UnboundOutputCreator + extends PTransform, PDone> { + + @Override + public PDone apply(PCollection input) { + // Apply an operation so that this is a composite transform. + input.apply(Count.perElement()); + + return PDone.in(input.getPipeline()); + } + + @Override + protected Coder getDefaultOutputCoder() { + return VoidCoder.of(); + } + } + + // Builds a pipeline containing a composite operation (Pick), then + // visits the nodes and verifies that the hierarchy was captured. + @Test + public void testCompositeCapture() throws Exception { + File inputFile = tmpFolder.newFile(); + File outputFile = tmpFolder.newFile(); + + Pipeline p = DirectPipeline.createForTest(); + + p.apply(TextIO.Read.named("ReadMyFile").from(inputFile.getPath())) + .apply(Sample.any(10)) + .apply(TextIO.Write.named("WriteMyFile").to(outputFile.getPath())); + + final EnumSet visited = + EnumSet.noneOf(TransformsSeen.class); + final EnumSet left = + EnumSet.noneOf(TransformsSeen.class); + + p.traverseTopologically(new Pipeline.PipelineVisitor() { + @Override + public void enterCompositeTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + if (transform instanceof Sample.SampleAny) { + assertTrue(visited.add(TransformsSeen.SAMPLE_ANY)); + assertNotNull(node.getEnclosingNode()); + assertTrue(node.isCompositeNode()); + } else if (transform instanceof Write.Bound) { + assertTrue(visited.add(TransformsSeen.WRITE)); + assertNotNull(node.getEnclosingNode()); + assertTrue(node.isCompositeNode()); + } + assertThat(transform, not(instanceOf(Read.Bounded.class))); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + if (transform instanceof Sample.SampleAny) { + assertTrue(left.add(TransformsSeen.SAMPLE_ANY)); + } + } + + @Override + public void visitTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + // Pick is a composite, should not be visited here. 
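+        // Only primitive transforms, such as the Read.Bounded inside TextIO.Read, reach
+        // visitTransform; the composites, including Sample.any and TextIO.Write, are visited
+        // by enterCompositeTransform above.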
+ assertThat(transform, not(instanceOf(Sample.SampleAny.class))); + assertThat(transform, not(instanceOf(Write.Bound.class))); + if (transform instanceof Read.Bounded) { + assertTrue(visited.add(TransformsSeen.READ)); + } + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + } + }); + + assertTrue(visited.equals(EnumSet.allOf(TransformsSeen.class))); + assertTrue(left.equals(EnumSet.of(TransformsSeen.SAMPLE_ANY))); + } + + @Test(expected = IllegalStateException.class) + public void testOutputChecking() throws Exception { + Pipeline p = DirectPipeline.createForTest(); + + p.apply(new InvalidCompositeTransform()); + + p.traverseTopologically(new RecordingPipelineVisitor()); + fail("traversal should have failed with an IllegalStateException"); + } + + @Test + public void testMultiGraphSetup() { + Pipeline p = DirectPipeline.createForTest(); + + PCollection input = p.begin() + .apply(Create.of(1, 2, 3)); + + input.apply(new UnboundOutputCreator()); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/CustomSourcesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/CustomSourcesTest.java new file mode 100644 index 000000000000..e09251b6a977 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/CustomSourcesTest.java @@ -0,0 +1,273 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; +import static com.google.cloud.dataflow.sdk.testing.SourceTestUtils.readFromSource; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Sample; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Preconditions; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for {@link CustomSources}. 
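+ *
+ * <p>The {@code TestIO} source defined below produces the integers in {@code [from, to)},
+ * splits itself evenly across {@code getNumWorkers()} in {@code splitIntoBundles}, and
+ * supports dynamic work rebalancing via {@code splitAtFraction} and {@code getFractionConsumed}.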
+ */ +@RunWith(JUnit4.class) +public class CustomSourcesTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + @Rule public ExpectedLogs logged = ExpectedLogs.none(CustomSources.class); + + static class TestIO { + public static Read fromRange(int from, int to) { + return new Read(from, to, false); + } + + static class Read extends BoundedSource { + final int from; + final int to; + final boolean produceTimestamps; + + Read(int from, int to, boolean produceTimestamps) { + this.from = from; + this.to = to; + this.produceTimestamps = produceTimestamps; + } + + public Read withTimestampsMillis() { + return new Read(from, to, true); + } + + @Override + public List splitIntoBundles(long desiredBundleSizeBytes, PipelineOptions options) + throws Exception { + List res = new ArrayList<>(); + DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); + float step = 1.0f * (to - from) / dataflowOptions.getNumWorkers(); + for (int i = 0; i < dataflowOptions.getNumWorkers(); ++i) { + res.add(new Read( + Math.round(from + i * step), Math.round(from + (i + 1) * step), + produceTimestamps)); + } + return res; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 8 * (to - from); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return true; + } + + @Override + public BoundedReader createReader(PipelineOptions options) throws IOException { + return new RangeReader(this); + } + + @Override + public void validate() {} + + @Override + public String toString() { + return "[" + from + ", " + to + ")"; + } + + @Override + public Coder getDefaultOutputCoder() { + return BigEndianIntegerCoder.of(); + } + + private static class RangeReader extends BoundedReader { + // To verify that BasicSerializableSourceFormat calls our methods according to protocol. + enum State { + UNSTARTED, + STARTED, + FINISHED + } + private Read source; + private int current = -1; + private State state = State.UNSTARTED; + + public RangeReader(Read source) { + this.source = source; + } + + @Override + public boolean start() throws IOException { + Preconditions.checkState(state == State.UNSTARTED); + state = State.STARTED; + current = source.from; + return (current < source.to); + } + + @Override + public boolean advance() throws IOException { + Preconditions.checkState(state == State.STARTED); + if (current == source.to - 1) { + state = State.FINISHED; + return false; + } + current++; + return true; + } + + @Override + public Integer getCurrent() { + Preconditions.checkState(state == State.STARTED); + return current; + } + + @Override + public Instant getCurrentTimestamp() { + return source.produceTimestamps + ? 
new Instant(current /* as millis */) : BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + @Override + public void close() throws IOException { + Preconditions.checkState(state == State.STARTED || state == State.FINISHED); + state = State.FINISHED; + } + + @Override + public Read getCurrentSource() { + return source; + } + + @Override + public Read splitAtFraction(double fraction) { + int proposedIndex = (int) (source.from + fraction * (source.to - source.from)); + if (proposedIndex <= current) { + return null; + } + Read primary = new Read(source.from, proposedIndex, source.produceTimestamps); + Read residual = new Read(proposedIndex, source.to, source.produceTimestamps); + this.source = primary; + return residual; + } + + @Override + public Double getFractionConsumed() { + return (current == -1) + ? 0.0 + : (1.0 * (1 + current - source.from) / (source.to - source.from)); + } + } + } + } + + @Test + public void testDirectPipelineWithoutTimestamps() throws Exception { + Pipeline p = TestPipeline.create(); + PCollection sum = p + .apply(Read.from(TestIO.fromRange(10, 20))) + .apply(Sum.integersGlobally()) + .apply(Sample.any(1)); + + DataflowAssert.thatSingleton(sum).isEqualTo(145); + p.run(); + } + + @Test + public void testDirectPipelineWithTimestamps() throws Exception { + Pipeline p = TestPipeline.create(); + PCollection sums = + p.apply(Read.from(TestIO.fromRange(10, 20).withTimestampsMillis())) + .apply(Window.into(FixedWindows.of(Duration.millis(3)))) + .apply(Sum.integersGlobally().withoutDefaults()); + // Should group into [10 11] [12 13 14] [15 16 17] [18 19]. + DataflowAssert.that(sums).containsInAnyOrder(21, 37, 39, 48); + p.run(); + } + + @Test + public void testRangeProgressAndSplitAtFraction() throws Exception { + // Show basic usage of getFractionConsumed and splitAtFraction. + // This test only tests TestIO itself, not BasicSerializableSourceFormat. + + DataflowPipelineOptions options = + PipelineOptionsFactory.create().as(DataflowPipelineOptions.class); + TestIO.Read source = TestIO.fromRange(10, 20); + try (BoundedSource.BoundedReader reader = source.createReader(options)) { + assertEquals(0, reader.getFractionConsumed().intValue()); + assertTrue(reader.start()); + assertEquals(0.1, reader.getFractionConsumed(), 1e-6); + assertTrue(reader.advance()); + assertEquals(0.2, reader.getFractionConsumed(), 1e-6); + // Already past 0.0 and 0.1. + assertNull(reader.splitAtFraction(0.0)); + assertNull(reader.splitAtFraction(0.1)); + + { + TestIO.Read residual = (TestIO.Read) reader.splitAtFraction(0.5); + assertNotNull(residual); + TestIO.Read primary = (TestIO.Read) reader.getCurrentSource(); + assertThat(readFromSource(primary, options), contains(10, 11, 12, 13, 14)); + assertThat(readFromSource(residual, options), contains(15, 16, 17, 18, 19)); + } + + // Range is now [10, 15) and we are at 12. + { + TestIO.Read residual = (TestIO.Read) reader.splitAtFraction(0.8); // give up 14. 
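+        // With the current source range at [10, 15), the proposed split index is
+        // (int) (10 + 0.8 * (15 - 10)) = 14, so the primary keeps [10, 14) and the residual
+        // covers only [14, 15).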
+ assertNotNull(residual); + TestIO.Read primary = (TestIO.Read) reader.getCurrentSource(); + assertThat(readFromSource(primary, options), contains(10, 11, 12, 13)); + assertThat(readFromSource(residual, options), contains(14)); + } + + assertTrue(reader.advance()); + assertEquals(12, reader.getCurrent().intValue()); + assertTrue(reader.advance()); + assertEquals(13, reader.getCurrent().intValue()); + assertFalse(reader.advance()); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java new file mode 100644 index 000000000000..181ddcae5bcc --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java @@ -0,0 +1,212 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.encodeToByteArray; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.DelegateCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.values.KV; + +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +import javax.annotation.Nullable; + +/** + * An unbounded source for testing the unbounded sources framework code. + * + *
    Each split of this sources produces records of the form KV(split_id, i), + * where i counts up from 0. Each record has a timestamp of i, and the watermark + * accurately tracks these timestamps. The reader will occasionally return false + * from {@code advance}, in order to simulate a source where not all the data is + * available immediately. + */ +public class TestCountingSource + extends UnboundedSource, TestCountingSource.CounterMark> { + private static List finalizeTracker; + private final int numMessagesPerShard; + private final int shardNumber; + private final boolean dedup; + + public static void setFinalizeTracker(List finalizeTracker) { + TestCountingSource.finalizeTracker = finalizeTracker; + } + + public TestCountingSource(int numMessagesPerShard) { + this(numMessagesPerShard, 0, false); + } + + public TestCountingSource withDedup() { + return new TestCountingSource(numMessagesPerShard, shardNumber, true); + } + + private TestCountingSource withShardNumber(int shardNumber) { + return new TestCountingSource(numMessagesPerShard, shardNumber, dedup); + } + + private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup) { + this.numMessagesPerShard = numMessagesPerShard; + this.shardNumber = shardNumber; + this.dedup = dedup; + } + + public int getShardNumber() { + return shardNumber; + } + + @Override + public List generateInitialSplits( + int desiredNumSplits, PipelineOptions options) { + List splits = new ArrayList<>(); + for (int i = 0; i < desiredNumSplits; i++) { + splits.add(withShardNumber(i)); + } + return splits; + } + + class CounterMark implements UnboundedSource.CheckpointMark { + int current; + + public CounterMark(int current) { + this.current = current; + } + + @Override + public void finalizeCheckpoint() { + if (finalizeTracker != null) { + finalizeTracker.add(current); + } + } + } + + @Override + public Coder getCheckpointMarkCoder() { + return DelegateCoder.of( + VarIntCoder.of(), + new DelegateCoder.CodingFunction() { + @Override + public Integer apply(CounterMark input) { + return input.current; + } + }, + new DelegateCoder.CodingFunction() { + @Override + public CounterMark apply(Integer input) { + return new CounterMark(input); + } + }); + } + + @Override + public boolean requiresDeduping() { + return dedup; + } + + private class CountingSourceReader extends UnboundedReader> { + private int current; + + public CountingSourceReader(int startingPoint) { + this.current = startingPoint; + } + + @Override + public boolean start() { + return true; + } + + @Override + public boolean advance() { + if (current < numMessagesPerShard - 1) { + // If testing dedup, occasionally insert a duplicate value; + if (dedup && ThreadLocalRandom.current().nextInt(5) == 0) { + return true; + } + current++; + return true; + } else { + return false; + } + } + + @Override + public KV getCurrent() { + return KV.of(shardNumber, current); + } + + @Override + public Instant getCurrentTimestamp() { + return new Instant(current); + } + + @Override + public byte[] getCurrentRecordId() { + try { + return encodeToByteArray(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), getCurrent()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() {} + + @Override + public TestCountingSource getCurrentSource() { + return TestCountingSource.this; + } + + @Override + public Instant getWatermark() { + // The watermark is a promise about future elements, and the timestamps of elements are + // strictly increasing for this source. 
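+      // The most recently returned element has timestamp 'current' and every later element
+      // will have a larger timestamp, so current + 1 is a safe watermark.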
+ return new Instant(current + 1); + } + + @Override + public CheckpointMark getCheckpointMark() { + return new CounterMark(current); + } + + @Override + public long getSplitBacklogBytes() { + return 7L; + } + } + + @Override + public CountingSourceReader createReader( + PipelineOptions options, @Nullable CounterMark checkpointMark) { + return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : 0); + } + + @Override + public void validate() {} + + @Override + public Coder> getDefaultOutputCoder() { + return KvCoder.of(VarIntCoder.of(), VarIntCoder.of()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java new file mode 100644 index 000000000000..9f22fbbe9e51 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.CountingSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link BoundedReadEvaluatorFactory}. 
+ */ +@RunWith(JUnit4.class) +public class BoundedReadEvaluatorFactoryTest { + private BoundedSource source; + private PCollection longs; + private TransformEvaluatorFactory factory; + private InProcessEvaluationContext context; + + @Before + public void setup() { + source = CountingSource.upTo(10L); + TestPipeline p = TestPipeline.create(); + longs = p.apply(Read.from(source)); + + factory = new BoundedReadEvaluatorFactory(); + context = mock(InProcessEvaluationContext.class); + } + + @Test + public void boundedSourceInMemoryTransformEvaluatorProducesElements() throws Exception { + UncommittedBundle output = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(longs)).thenReturn(output); + + TransformEvaluator evaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + InProcessTransformResult result = evaluator.finishBundle(); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + assertThat( + output.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(), + containsInAnyOrder( + gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); + } + + /** + * Demonstrate that acquiring multiple {@link TransformEvaluator TransformEvaluators} for the same + * {@link Bounded Read.Bounded} application with the same evaluation context only produces the + * elements once. + */ + @Test + public void boundedSourceInMemoryTransformEvaluatorAfterFinishIsEmpty() throws Exception { + UncommittedBundle output = + InProcessBundle.unkeyed(longs); + when(context.createRootBundle(longs)).thenReturn(output); + + TransformEvaluator evaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + InProcessTransformResult result = evaluator.finishBundle(); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + Iterable> outputElements = + output.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(); + assertThat( + outputElements, + containsInAnyOrder( + gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); + + UncommittedBundle secondOutput = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(longs)).thenReturn(secondOutput); + TransformEvaluator secondEvaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + InProcessTransformResult secondResult = secondEvaluator.finishBundle(); + assertThat(secondResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + assertThat(secondResult.getOutputBundles(), emptyIterable()); + assertThat( + secondOutput.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(), emptyIterable()); + assertThat( + outputElements, + containsInAnyOrder( + gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); + } + + /** + * Demonstrates that acquiring multiple evaluators from the factory are independent, but + * the elements in the source are only produced once. + */ + @Test + public void boundedSourceEvaluatorSimultaneousEvaluations() throws Exception { + UncommittedBundle output = InProcessBundle.unkeyed(longs); + UncommittedBundle secondOutput = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(longs)).thenReturn(output).thenReturn(secondOutput); + + // create both evaluators before finishing either. 
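+    // Only one of the two evaluators is expected to emit the source's elements; the other
+    // should finish with an empty output bundle and the minimum watermark hold.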
+ TransformEvaluator evaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + TransformEvaluator secondEvaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + + InProcessTransformResult secondResult = secondEvaluator.finishBundle(); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + Iterable> outputElements = + output.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(); + + assertThat( + outputElements, + containsInAnyOrder( + gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); + assertThat(secondResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + assertThat(secondResult.getOutputBundles(), emptyIterable()); + assertThat( + secondOutput.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(), emptyIterable()); + assertThat( + outputElements, + containsInAnyOrder( + gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); + } + + private static WindowedValue gw(Long elem) { + return WindowedValue.valueInGlobalWindow(elem); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java new file mode 100644 index 000000000000..bf25970affc1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link FlattenEvaluatorFactory}. 
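+ *
+ * <p>Flatten is evaluated one input bundle at a time: each input {@code PCollection} in the
+ * {@code PCollectionList} gets its own evaluator, and the elements of each input bundle are
+ * forwarded unchanged into a separate output bundle of the flattened collection.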
+ */ +@RunWith(JUnit4.class) +public class FlattenEvaluatorFactoryTest { + @Test + public void testFlattenInMemoryEvaluator() throws Exception { + TestPipeline p = TestPipeline.create(); + PCollection left = p.apply("left", Create.of(1, 2, 4)); + PCollection right = p.apply("right", Create.of(-1, 2, -4)); + PCollectionList list = PCollectionList.of(left).and(right); + + PCollection flattened = list.apply(Flatten.pCollections()); + + CommittedBundle leftBundle = InProcessBundle.unkeyed(left).commit(Instant.now()); + CommittedBundle rightBundle = InProcessBundle.unkeyed(right).commit(Instant.now()); + + InProcessEvaluationContext context = mock(InProcessEvaluationContext.class); + + UncommittedBundle flattenedLeftBundle = InProcessBundle.unkeyed(flattened); + UncommittedBundle flattenedRightBundle = InProcessBundle.unkeyed(flattened); + + when(context.createBundle(leftBundle, flattened)).thenReturn(flattenedLeftBundle); + when(context.createBundle(rightBundle, flattened)).thenReturn(flattenedRightBundle); + + FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory(); + TransformEvaluator leftSideEvaluator = + factory.forApplication(flattened.getProducingTransformInternal(), leftBundle, context); + TransformEvaluator rightSideEvaluator = + factory.forApplication( + flattened.getProducingTransformInternal(), + rightBundle, + context); + + leftSideEvaluator.processElement(WindowedValue.valueInGlobalWindow(1)); + rightSideEvaluator.processElement(WindowedValue.valueInGlobalWindow(-1)); + leftSideEvaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1024))); + leftSideEvaluator.processElement(WindowedValue.valueInEmptyWindows(4, PaneInfo.NO_FIRING)); + rightSideEvaluator.processElement( + WindowedValue.valueInEmptyWindows(2, PaneInfo.ON_TIME_AND_ONLY_FIRING)); + rightSideEvaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow(-4, new Instant(-4096))); + + InProcessTransformResult rightSideResult = rightSideEvaluator.finishBundle(); + InProcessTransformResult leftSideResult = leftSideEvaluator.finishBundle(); + + assertThat( + rightSideResult.getOutputBundles(), + Matchers.>contains(flattenedRightBundle)); + assertThat( + rightSideResult.getTransform(), + Matchers.>equalTo(flattened.getProducingTransformInternal())); + assertThat( + leftSideResult.getOutputBundles(), + Matchers.>contains(flattenedLeftBundle)); + assertThat( + leftSideResult.getTransform(), + Matchers.>equalTo(flattened.getProducingTransformInternal())); + + assertThat( + flattenedLeftBundle.commit(Instant.now()).getElements(), + containsInAnyOrder( + WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1024)), + WindowedValue.valueInEmptyWindows(4, PaneInfo.NO_FIRING), + WindowedValue.valueInGlobalWindow(1))); + assertThat( + flattenedRightBundle.commit(Instant.now()).getElements(), + containsInAnyOrder( + WindowedValue.valueInEmptyWindows(2, PaneInfo.ON_TIME_AND_ONLY_FIRING), + WindowedValue.timestampedValueInGlobalWindow(-4, new Instant(-4096)), + WindowedValue.valueInGlobalWindow(-1))); + } + + @Test + public void testFlattenInMemoryEvaluatorWithEmptyPCollectionList() throws Exception { + TestPipeline p = TestPipeline.create(); + PCollectionList list = PCollectionList.empty(p); + + PCollection flattened = list.apply(Flatten.pCollections()); + + InProcessEvaluationContext context = mock(InProcessEvaluationContext.class); + + FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory(); + TransformEvaluator emptyEvaluator = + 
factory.forApplication(flattened.getProducingTransformInternal(), null, context); + + InProcessTransformResult leftSideResult = emptyEvaluator.finishBundle(); + + assertThat(leftSideResult.getOutputBundles(), emptyIterable()); + assertThat( + leftSideResult.getTransform(), + Matchers.>equalTo(flattened.getProducingTransformInternal())); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ForwardingPTransformTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ForwardingPTransformTest.java new file mode 100644 index 000000000000..2e283f50b88d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ForwardingPTransformTest.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link ForwardingPTransform}. 
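+ *
+ * <p>The tests below build a {@code ForwardingPTransform} as an anonymous subclass that
+ * overrides only {@code delegate()}, backed by a Mockito mock, and verify that {@code apply},
+ * {@code getName}, {@code validate}, and {@code getDefaultOutputCoder} all forward to the
+ * delegate.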
+ */ +@RunWith(JUnit4.class) +public class ForwardingPTransformTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Mock private PTransform, PCollection> delegate; + + private ForwardingPTransform, PCollection> forwarding; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + forwarding = + new ForwardingPTransform, PCollection>() { + @Override + protected PTransform, PCollection> delegate() { + return delegate; + } + }; + } + + @Test + public void applyDelegates() { + @SuppressWarnings("unchecked") + PCollection collection = mock(PCollection.class); + @SuppressWarnings("unchecked") + PCollection output = mock(PCollection.class); + when(delegate.apply(collection)).thenReturn(output); + PCollection result = forwarding.apply(collection); + assertThat(result, equalTo(output)); + } + + @Test + public void getNameDelegates() { + String name = "My_forwardingptransform-name;for!thisTest"; + when(delegate.getName()).thenReturn(name); + assertThat(forwarding.getName(), equalTo(name)); + } + + @Test + public void validateDelegates() { + @SuppressWarnings("unchecked") + PCollection input = mock(PCollection.class); + doThrow(RuntimeException.class).when(delegate).validate(input); + + thrown.expect(RuntimeException.class); + forwarding.validate(input); + } + + @Test + public void getDefaultOutputCoderDelegates() throws Exception { + @SuppressWarnings("unchecked") + PCollection input = mock(PCollection.class); + @SuppressWarnings("unchecked") + PCollection output = mock(PCollection.class); + @SuppressWarnings("unchecked") + Coder outputCoder = mock(Coder.class); + + when(delegate.getDefaultOutputCoder(input, output)).thenReturn(outputCoder); + assertThat(forwarding.getDefaultOutputCoder(input, output), equalTo(outputCoder)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java new file mode 100644 index 000000000000..5c9e824afe41 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItem; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItems; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multiset; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link GroupByKeyEvaluatorFactory}. + */ +@RunWith(JUnit4.class) +public class GroupByKeyEvaluatorFactoryTest { + @Test + public void testInMemoryEvaluator() throws Exception { + TestPipeline p = TestPipeline.create(); + KV firstFoo = KV.of("foo", -1); + KV secondFoo = KV.of("foo", 1); + KV thirdFoo = KV.of("foo", 3); + KV firstBar = KV.of("bar", 22); + KV secondBar = KV.of("bar", 12); + KV firstBaz = KV.of("baz", Integer.MAX_VALUE); + PCollection> values = + p.apply(Create.of(firstFoo, firstBar, secondFoo, firstBaz, secondBar, thirdFoo)); + PCollection>> kvs = + values.apply(new GroupByKey.ReifyTimestampsAndWindows()); + PCollection> groupedKvs = + kvs.apply(new GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly()); + + CommittedBundle>> inputBundle = + InProcessBundle.unkeyed(kvs).commit(Instant.now()); + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + + UncommittedBundle> fooBundle = + InProcessBundle.keyed(groupedKvs, "foo"); + UncommittedBundle> barBundle = + InProcessBundle.keyed(groupedKvs, "bar"); + UncommittedBundle> bazBundle = + InProcessBundle.keyed(groupedKvs, "baz"); + + when(evaluationContext.createKeyedBundle(inputBundle, "foo", groupedKvs)).thenReturn(fooBundle); + when(evaluationContext.createKeyedBundle(inputBundle, "bar", groupedKvs)).thenReturn(barBundle); + when(evaluationContext.createKeyedBundle(inputBundle, "baz", groupedKvs)).thenReturn(bazBundle); + + // The input to a GroupByKey is assumed to be a KvCoder + @SuppressWarnings("unchecked") + Coder keyCoder = + ((KvCoder>) kvs.getCoder()).getKeyCoder(); + TransformEvaluator>> evaluator = + new GroupByKeyEvaluatorFactory() + .forApplication( + groupedKvs.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstFoo))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(secondFoo))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(thirdFoo))); + 
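+    // Elements for the remaining keys; after finishBundle(), each key's values should appear
+    // in the keyed bundle registered for it above.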
evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstBar))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(secondBar))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstBaz))); + + evaluator.finishBundle(); + + assertThat( + fooBundle.commit(Instant.now()).getElements(), + contains( + new KeyedWorkItemMatcher( + KeyedWorkItems.elementsWorkItem( + "foo", + ImmutableSet.of( + WindowedValue.valueInGlobalWindow(-1), + WindowedValue.valueInGlobalWindow(1), + WindowedValue.valueInGlobalWindow(3))), + keyCoder))); + assertThat( + barBundle.commit(Instant.now()).getElements(), + contains( + new KeyedWorkItemMatcher( + KeyedWorkItems.elementsWorkItem( + "bar", + ImmutableSet.of( + WindowedValue.valueInGlobalWindow(12), + WindowedValue.valueInGlobalWindow(22))), + keyCoder))); + assertThat( + bazBundle.commit(Instant.now()).getElements(), + contains( + new KeyedWorkItemMatcher( + KeyedWorkItems.elementsWorkItem( + "baz", + ImmutableSet.of(WindowedValue.valueInGlobalWindow(Integer.MAX_VALUE))), + keyCoder))); + } + + private KV> gwValue(KV kv) { + return KV.of(kv.getKey(), WindowedValue.valueInGlobalWindow(kv.getValue())); + } + + private static class KeyedWorkItemMatcher + extends BaseMatcher>> { + private final KeyedWorkItem myWorkItem; + private final Coder keyCoder; + + public KeyedWorkItemMatcher(KeyedWorkItem myWorkItem, Coder keyCoder) { + this.myWorkItem = myWorkItem; + this.keyCoder = keyCoder; + } + + @Override + public boolean matches(Object item) { + if (item == null || !(item instanceof WindowedValue)) { + return false; + } + WindowedValue> that = (WindowedValue>) item; + Multiset> myValues = HashMultiset.create(); + Multiset> thatValues = HashMultiset.create(); + for (WindowedValue value : myWorkItem.elementsIterable()) { + myValues.add(value); + } + for (WindowedValue value : that.getValue().elementsIterable()) { + thatValues.add(value); + } + try { + return myValues.equals(thatValues) + && keyCoder + .structuralValue(myWorkItem.key()) + .equals(keyCoder.structuralValue(that.getValue().key())); + } catch (Exception e) { + return false; + } + } + + @Override + public void describeTo(Description description) { + description + .appendText("KeyedWorkItem containing key ") + .appendValue(myWorkItem.key()) + .appendText(" and values ") + .appendValueList("[", ", ", "]", myWorkItem.elementsIterable()); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java new file mode 100644 index 000000000000..24251522d824 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java @@ -0,0 +1,1099 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate.TimerUpdateBuilder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TransformWatermarks; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.joda.time.Instant; +import org.joda.time.ReadableInstant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Tests for {@link InMemoryWatermarkManager}. 
+ */ +@RunWith(JUnit4.class) +public class InMemoryWatermarkManagerTest implements Serializable { + private transient MockClock clock; + + private transient PCollection createdInts; + + private transient PCollection filtered; + private transient PCollection filteredTimesTwo; + private transient PCollection> keyed; + + private transient PCollection intsToFlatten; + private transient PCollection flattened; + + private transient InMemoryWatermarkManager manager; + + @Before + public void setup() { + TestPipeline p = TestPipeline.create(); + + createdInts = p.apply("createdInts", Create.of(1, 2, 3)); + + filtered = createdInts.apply("filtered", Filter.greaterThan(1)); + filteredTimesTwo = filtered.apply("timesTwo", ParDo.of(new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + c.output(c.element() * 2); + } + })); + + keyed = createdInts.apply("keyed", WithKeys.of("MyKey")); + + intsToFlatten = p.apply("intsToFlatten", Create.of(-1, 256, 65535)); + PCollectionList preFlatten = PCollectionList.of(createdInts).and(intsToFlatten); + flattened = preFlatten.apply("flattened", Flatten.pCollections()); + + Collection> rootTransforms = + ImmutableList.>of( + createdInts.getProducingTransformInternal(), + intsToFlatten.getProducingTransformInternal()); + + Map>> consumers = new HashMap<>(); + consumers.put( + createdInts, + ImmutableList.>of(filtered.getProducingTransformInternal(), + keyed.getProducingTransformInternal(), flattened.getProducingTransformInternal())); + consumers.put( + filtered, + Collections.>singleton( + filteredTimesTwo.getProducingTransformInternal())); + consumers.put(filteredTimesTwo, Collections.>emptyList()); + consumers.put(keyed, Collections.>emptyList()); + + consumers.put( + intsToFlatten, + Collections.>singleton( + flattened.getProducingTransformInternal())); + consumers.put(flattened, Collections.>emptyList()); + + clock = MockClock.fromInstant(new Instant(1000)); + + manager = InMemoryWatermarkManager.create(clock, rootTransforms, consumers); + } + + /** + * Demonstrates that getWatermark, when called on an {@link AppliedPTransform} that has not + * processed any elements, returns the {@link BoundedWindow#TIMESTAMP_MIN_VALUE}. + */ + @Test + public void getWatermarkForUntouchedTransform() { + TransformWatermarks watermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + + assertThat(watermarks.getInputWatermark(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + assertThat(watermarks.getOutputWatermark(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + } + + /** + * Demonstrates that getWatermark for a transform that consumes no input uses the Watermark + * Hold value provided to it as the output watermark. + */ + @Test + public void getWatermarkForUpdatedSourceTransform() { + CommittedBundle output = globallyWindowedBundle(createdInts, 1); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(output), new Instant(8000L)); + TransformWatermarks updatedSourceWatermark = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + + assertThat(updatedSourceWatermark.getOutputWatermark(), equalTo(new Instant(8000L))); + } + + /** + * Demonstrates that getWatermark for a transform that takes multiple inputs is held to the + * minimum watermark across all of its inputs. 
+ */ + @Test + public void getWatermarkForMultiInputTransform() { + CommittedBundle secondPcollectionBundle = globallyWindowedBundle(intsToFlatten, -1); + + manager.updateWatermarks(null, intsToFlatten.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(secondPcollectionBundle), + BoundedWindow.TIMESTAMP_MAX_VALUE); + + // We didn't do anything for the first source, so we shouldn't have progressed the watermark + TransformWatermarks firstSourceWatermark = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat( + firstSourceWatermark.getOutputWatermark(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + + // the Second Source output all of the elements so it should be done (with a watermark at the + // end of time). + TransformWatermarks secondSourceWatermark = + manager.getWatermarks(intsToFlatten.getProducingTransformInternal()); + assertThat( + secondSourceWatermark.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + + // We haven't consumed anything yet, so our watermark should be at the beginning of time + TransformWatermarks transformWatermark = + manager.getWatermarks(flattened.getProducingTransformInternal()); + assertThat( + transformWatermark.getInputWatermark(), not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + assertThat( + transformWatermark.getOutputWatermark(), not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + + CommittedBundle flattenedBundleSecondCreate = globallyWindowedBundle(flattened, -1); + // We have finished processing the bundle from the second PCollection, but we haven't consumed + // anything from the first PCollection yet; so our watermark shouldn't advance + manager.updateWatermarks(secondPcollectionBundle, flattened.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(flattenedBundleSecondCreate), + null); + TransformWatermarks transformAfterProcessing = + manager.getWatermarks(flattened.getProducingTransformInternal()); + manager.updateWatermarks(secondPcollectionBundle, flattened.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(flattenedBundleSecondCreate), + null); + assertThat( + transformAfterProcessing.getInputWatermark(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + assertThat( + transformAfterProcessing.getOutputWatermark(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + + Instant firstCollectionTimestamp = new Instant(10000); + CommittedBundle firstPcollectionBundle = + timestampedBundle(createdInts, TimestampedValue.of(5, firstCollectionTimestamp)); + // the source is done, but elements are still buffered. 
The source output watermark should be + // past the end of the global window + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(firstPcollectionBundle), + new Instant(Long.MAX_VALUE)); + TransformWatermarks firstSourceWatermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat( + firstSourceWatermarks.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + + // We still haven't consumed any of the first source's input, so the watermark should still not + // progress + TransformWatermarks flattenAfterSourcesProduced = + manager.getWatermarks(flattened.getProducingTransformInternal()); + assertThat( + flattenAfterSourcesProduced.getInputWatermark(), not(laterThan(firstCollectionTimestamp))); + assertThat( + flattenAfterSourcesProduced.getOutputWatermark(), not(laterThan(firstCollectionTimestamp))); + + // We have buffered inputs, but since the PCollection has all of the elements (has a WM past the + // end of the global window), we should have a watermark equal to the min among buffered + // elements + TransformWatermarks withBufferedElements = + manager.getWatermarks(flattened.getProducingTransformInternal()); + assertThat(withBufferedElements.getInputWatermark(), equalTo(firstCollectionTimestamp)); + assertThat(withBufferedElements.getOutputWatermark(), equalTo(firstCollectionTimestamp)); + + CommittedBundle completedFlattenBundle = + InProcessBundle.unkeyed(flattened).commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + manager.updateWatermarks(firstPcollectionBundle, flattened.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(completedFlattenBundle), + null); + TransformWatermarks afterConsumingAllInput = + manager.getWatermarks(flattened.getProducingTransformInternal()); + assertThat( + afterConsumingAllInput.getInputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat( + afterConsumingAllInput.getOutputWatermark(), + not(laterThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + } + + /** + * Demonstrates that pending elements are independent among + * {@link AppliedPTransform AppliedPTransforms} that consume the same input {@link PCollection}. 
+ */ + @Test + public void getWatermarkForMultiConsumedCollection() { + CommittedBundle createdBundle = timestampedBundle(createdInts, + TimestampedValue.of(1, new Instant(1_000_000L)), TimestampedValue.of(2, new Instant(1234L)), + TimestampedValue.of(3, new Instant(-1000L))); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createdBundle), new Instant(Long.MAX_VALUE)); + TransformWatermarks createdAfterProducing = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat( + createdAfterProducing.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + + CommittedBundle> keyBundle = + timestampedBundle(keyed, TimestampedValue.of(KV.of("MyKey", 1), new Instant(1_000_000L)), + TimestampedValue.of(KV.of("MyKey", 2), new Instant(1234L)), + TimestampedValue.of(KV.of("MyKey", 3), new Instant(-1000L))); + manager.updateWatermarks(createdBundle, keyed.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(keyBundle), null); + TransformWatermarks keyedWatermarks = + manager.getWatermarks(keyed.getProducingTransformInternal()); + assertThat( + keyedWatermarks.getInputWatermark(), not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat( + keyedWatermarks.getOutputWatermark(), not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + + TransformWatermarks filteredWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat(filteredWatermarks.getInputWatermark(), not(laterThan(new Instant(-1000L)))); + assertThat(filteredWatermarks.getOutputWatermark(), not(laterThan(new Instant(-1000L)))); + + CommittedBundle filteredBundle = + timestampedBundle(filtered, TimestampedValue.of(2, new Instant(1234L))); + manager.updateWatermarks(createdBundle, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(filteredBundle), null); + TransformWatermarks filteredProcessedWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat( + filteredProcessedWatermarks.getInputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat( + filteredProcessedWatermarks.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + } + + /** + * Demonstrates that the watermark of an {@link AppliedPTransform} is held to the provided + * watermark hold. 
+ */ + @Test + public void updateWatermarkWithWatermarkHolds() { + CommittedBundle createdBundle = timestampedBundle(createdInts, + TimestampedValue.of(1, new Instant(1_000_000L)), TimestampedValue.of(2, new Instant(1234L)), + TimestampedValue.of(3, new Instant(-1000L))); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createdBundle), new Instant(Long.MAX_VALUE)); + + CommittedBundle> keyBundle = + timestampedBundle(keyed, TimestampedValue.of(KV.of("MyKey", 1), new Instant(1_000_000L)), + TimestampedValue.of(KV.of("MyKey", 2), new Instant(1234L)), + TimestampedValue.of(KV.of("MyKey", 3), new Instant(-1000L))); + manager.updateWatermarks(createdBundle, keyed.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(keyBundle), + new Instant(500L)); + TransformWatermarks keyedWatermarks = + manager.getWatermarks(keyed.getProducingTransformInternal()); + assertThat( + keyedWatermarks.getInputWatermark(), not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat(keyedWatermarks.getOutputWatermark(), not(laterThan(new Instant(500L)))); + } + + /** + * Demonstrates that the watermark of an {@link AppliedPTransform} is held to the provided + * watermark hold. + */ + @Test + public void updateWatermarkWithKeyedWatermarkHolds() { + CommittedBundle firstKeyBundle = + InProcessBundle.keyed(createdInts, "Odd") + .add(WindowedValue.timestampedValueInGlobalWindow(1, new Instant(1_000_000L))) + .add(WindowedValue.timestampedValueInGlobalWindow(3, new Instant(-1000L))) + .commit(clock.now()); + + CommittedBundle secondKeyBundle = + InProcessBundle.keyed(createdInts, "Even") + .add(WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1234L))) + .commit(clock.now()); + + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + ImmutableList.of(firstKeyBundle, secondKeyBundle), BoundedWindow.TIMESTAMP_MAX_VALUE); + + manager.updateWatermarks(firstKeyBundle, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>emptyList(), new Instant(-1000L)); + manager.updateWatermarks(secondKeyBundle, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>emptyList(), new Instant(1234L)); + + TransformWatermarks filteredWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat( + filteredWatermarks.getInputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat(filteredWatermarks.getOutputWatermark(), not(laterThan(new Instant(-1000L)))); + + CommittedBundle fauxFirstKeyTimerBundle = + InProcessBundle.keyed(createdInts, "Odd").commit(clock.now()); + manager.updateWatermarks(fauxFirstKeyTimerBundle, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>emptyList(), + BoundedWindow.TIMESTAMP_MAX_VALUE); + + assertThat(filteredWatermarks.getOutputWatermark(), equalTo(new Instant(1234L))); + + CommittedBundle fauxSecondKeyTimerBundle = + InProcessBundle.keyed(createdInts, "Even").commit(clock.now()); + manager.updateWatermarks(fauxSecondKeyTimerBundle, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>emptyList(), new Instant(5678L)); + assertThat(filteredWatermarks.getOutputWatermark(), equalTo(new Instant(5678L))); + + manager.updateWatermarks(fauxSecondKeyTimerBundle, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>emptyList(), + BoundedWindow.TIMESTAMP_MAX_VALUE); + assertThat( + 
filteredWatermarks.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + } + + /** + * Demonstrates that updated output watermarks are monotonic in the presence of late data, when + * called on an {@link AppliedPTransform} that consumes no input. + */ + @Test + public void updateOutputWatermarkShouldBeMonotonic() { + CommittedBundle firstInput = + InProcessBundle.unkeyed(createdInts).commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(firstInput), new Instant(0L)); + TransformWatermarks firstWatermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(firstWatermarks.getOutputWatermark(), equalTo(new Instant(0L))); + + CommittedBundle secondInput = + InProcessBundle.unkeyed(createdInts).commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(secondInput), new Instant(-250L)); + TransformWatermarks secondWatermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(secondWatermarks.getOutputWatermark(), not(earlierThan(new Instant(0L)))); + } + + /** + * Demonstrates that updated output watermarks are monotonic in the presence of watermark holds + * that become earlier than a previous watermark hold. + */ + @Test + public void updateWatermarkWithHoldsShouldBeMonotonic() { + CommittedBundle createdBundle = timestampedBundle(createdInts, + TimestampedValue.of(1, new Instant(1_000_000L)), TimestampedValue.of(2, new Instant(1234L)), + TimestampedValue.of(3, new Instant(-1000L))); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createdBundle), new Instant(Long.MAX_VALUE)); + + CommittedBundle> keyBundle = + timestampedBundle(keyed, TimestampedValue.of(KV.of("MyKey", 1), new Instant(1_000_000L)), + TimestampedValue.of(KV.of("MyKey", 2), new Instant(1234L)), + TimestampedValue.of(KV.of("MyKey", 3), new Instant(-1000L))); + manager.updateWatermarks(createdBundle, keyed.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(keyBundle), + new Instant(500L)); + TransformWatermarks keyedWatermarks = + manager.getWatermarks(keyed.getProducingTransformInternal()); + assertThat( + keyedWatermarks.getInputWatermark(), not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat(keyedWatermarks.getOutputWatermark(), not(laterThan(new Instant(500L)))); + Instant oldOutputWatermark = keyedWatermarks.getOutputWatermark(); + + TransformWatermarks updatedWatermarks = + manager.getWatermarks(keyed.getProducingTransformInternal()); + assertThat( + updatedWatermarks.getInputWatermark(), not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + // We added a hold prior to the old watermark; we shouldn't progress (due to the earlier hold) + // but the watermark is monotonic and should not backslide to the new, earlier hold + assertThat(updatedWatermarks.getOutputWatermark(), equalTo(oldOutputWatermark)); + } + + /** + * Demonstrates that updateWatermarks in the presence of late data is monotonic. 
+ */ + @Test + public void updateWatermarkWithLateData() { + Instant sourceWatermark = new Instant(1_000_000L); + CommittedBundle createdBundle = timestampedBundle(createdInts, + TimestampedValue.of(1, sourceWatermark), TimestampedValue.of(2, new Instant(1234L))); + + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createdBundle), sourceWatermark); + + CommittedBundle> keyBundle = + timestampedBundle(keyed, TimestampedValue.of(KV.of("MyKey", 1), sourceWatermark), + TimestampedValue.of(KV.of("MyKey", 2), new Instant(1234L))); + + // Finish processing the on-time data. The watermarks should progress to be equal to the source + manager.updateWatermarks(createdBundle, keyed.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(keyBundle), null); + TransformWatermarks onTimeWatermarks = + manager.getWatermarks(keyed.getProducingTransformInternal()); + assertThat(onTimeWatermarks.getInputWatermark(), equalTo(sourceWatermark)); + assertThat(onTimeWatermarks.getOutputWatermark(), equalTo(sourceWatermark)); + + CommittedBundle lateDataBundle = + timestampedBundle(createdInts, TimestampedValue.of(3, new Instant(-1000L))); + // the late data arrives in a downstream PCollection after its watermark has advanced past it; + // we don't advance the watermark past the current watermark until we've consumed the late data + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(lateDataBundle), new Instant(2_000_000L)); + TransformWatermarks bufferedLateWm = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(bufferedLateWm.getOutputWatermark(), equalTo(new Instant(2_000_000L))); + + // The input watermark should be held to its previous value (not advanced due to late data; not + // moved backwards in the presence of watermarks due to monotonicity). + TransformWatermarks lateDataBufferedWatermark = + manager.getWatermarks(keyed.getProducingTransformInternal()); + assertThat(lateDataBufferedWatermark.getInputWatermark(), not(earlierThan(sourceWatermark))); + assertThat(lateDataBufferedWatermark.getOutputWatermark(), not(earlierThan(sourceWatermark))); + + CommittedBundle> lateKeyedBundle = + timestampedBundle(keyed, TimestampedValue.of(KV.of("MyKey", 3), new Instant(-1000L))); + manager.updateWatermarks(lateDataBundle, keyed.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(lateKeyedBundle), null); + } + + /** + * Demonstrates that after watermarks of an upstream transform are updated, but no output has been + * produced, the watermarks of a downstream process are advanced. 
+ */ + @Test + public void getWatermarksAfterOnlyEmptyOutput() { + CommittedBundle emptyCreateOutput = globallyWindowedBundle(createdInts); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(emptyCreateOutput), + BoundedWindow.TIMESTAMP_MAX_VALUE); + TransformWatermarks updatedSourceWatermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + + assertThat( + updatedSourceWatermarks.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + + TransformWatermarks finishedFilterWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat( + finishedFilterWatermarks.getInputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat( + finishedFilterWatermarks.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + } + + /** + * Demonstrates that after watermarks of an upstream transform are updated, but no output has been + * produced, and the downstream transform has a watermark hold, the watermark is held to the hold. + */ + @Test + public void getWatermarksAfterHoldAndEmptyOutput() { + CommittedBundle firstCreateOutput = globallyWindowedBundle(createdInts, 1, 2); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(firstCreateOutput), new Instant(12_000L)); + + CommittedBundle firstFilterOutput = globallyWindowedBundle(filtered); + manager.updateWatermarks(firstCreateOutput, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(firstFilterOutput), + new Instant(10_000L)); + TransformWatermarks firstFilterWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat(firstFilterWatermarks.getInputWatermark(), not(earlierThan(new Instant(12_000L)))); + assertThat(firstFilterWatermarks.getOutputWatermark(), not(laterThan(new Instant(10_000L)))); + + CommittedBundle emptyCreateOutput = globallyWindowedBundle(createdInts); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(emptyCreateOutput), + BoundedWindow.TIMESTAMP_MAX_VALUE); + TransformWatermarks updatedSourceWatermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + + assertThat( + updatedSourceWatermarks.getOutputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + + TransformWatermarks finishedFilterWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat( + finishedFilterWatermarks.getInputWatermark(), + not(earlierThan(BoundedWindow.TIMESTAMP_MAX_VALUE))); + assertThat(finishedFilterWatermarks.getOutputWatermark(), not(laterThan(new Instant(10_000L)))); + } + + @Test + public void getSynchronizedProcessingTimeInputWatermarksHeldToPendingBundles() { + TransformWatermarks watermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(watermarks.getSynchronizedProcessingInputTime(), equalTo(clock.now())); + assertThat( + watermarks.getSynchronizedProcessingOutputTime(), + equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + + TransformWatermarks filteredWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + // Non-root processing watermarks don't progress until data has been processed + assertThat( + filteredWatermarks.getSynchronizedProcessingInputTime(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); 
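+ // The synchronized processing output time is likewise held until the transform has processed data.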
+ assertThat( + filteredWatermarks.getSynchronizedProcessingOutputTime(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + + CommittedBundle createOutput = + InProcessBundle.unkeyed(createdInts).commit(new Instant(1250L)); + + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createOutput), BoundedWindow.TIMESTAMP_MAX_VALUE); + TransformWatermarks createAfterUpdate = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(createAfterUpdate.getSynchronizedProcessingInputTime(), equalTo(clock.now())); + assertThat(createAfterUpdate.getSynchronizedProcessingOutputTime(), equalTo(clock.now())); + + TransformWatermarks filterAfterProduced = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat( + filterAfterProduced.getSynchronizedProcessingInputTime(), not(laterThan(clock.now()))); + assertThat( + filterAfterProduced.getSynchronizedProcessingOutputTime(), not(laterThan(clock.now()))); + + clock.set(new Instant(1500L)); + assertThat(createAfterUpdate.getSynchronizedProcessingInputTime(), equalTo(clock.now())); + assertThat(createAfterUpdate.getSynchronizedProcessingOutputTime(), equalTo(clock.now())); + assertThat( + filterAfterProduced.getSynchronizedProcessingInputTime(), + not(laterThan(new Instant(1250L)))); + assertThat( + filterAfterProduced.getSynchronizedProcessingOutputTime(), + not(laterThan(new Instant(1250L)))); + + CommittedBundle filterOutputBundle = + InProcessBundle.unkeyed(intsToFlatten).commit(new Instant(1250L)); + manager.updateWatermarks(createOutput, filtered.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>singleton(filterOutputBundle), + BoundedWindow.TIMESTAMP_MAX_VALUE); + TransformWatermarks filterAfterConsumed = + manager.getWatermarks(filtered.getProducingTransformInternal()); + assertThat( + filterAfterConsumed.getSynchronizedProcessingInputTime(), + not(laterThan(createAfterUpdate.getSynchronizedProcessingOutputTime()))); + assertThat( + filterAfterConsumed.getSynchronizedProcessingOutputTime(), + not(laterThan(filterAfterConsumed.getSynchronizedProcessingInputTime()))); + } + + /** + * Demonstrates that the Synchronized Processing Time output watermark cannot progress past + * pending timers in the same set. This propagates to all downstream SynchronizedProcessingTimes. + * + *
<p>
    Also demonstrate that the result is monotonic. + */ + // @Test + public void getSynchronizedProcessingTimeOutputHeldToPendingTimers() { + CommittedBundle createdBundle = globallyWindowedBundle(createdInts, 1, 2, 4, 8); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createdBundle), new Instant(1248L)); + + TransformWatermarks filteredWms = + manager.getWatermarks(filtered.getProducingTransformInternal()); + TransformWatermarks filteredDoubledWms = + manager.getWatermarks(filteredTimesTwo.getProducingTransformInternal()); + Instant initialFilteredWm = filteredWms.getSynchronizedProcessingOutputTime(); + Instant initialFilteredDoubledWm = filteredDoubledWms.getSynchronizedProcessingOutputTime(); + + CommittedBundle filteredBundle = globallyWindowedBundle(filtered, 2, 8); + TimerData pastTimer = + TimerData.of(StateNamespaces.global(), new Instant(250L), TimeDomain.PROCESSING_TIME); + TimerData futureTimer = + TimerData.of(StateNamespaces.global(), new Instant(4096L), TimeDomain.PROCESSING_TIME); + TimerUpdate timers = + TimerUpdate.builder("key").setTimer(pastTimer).setTimer(futureTimer).build(); + manager.updateWatermarks(createdBundle, filtered.getProducingTransformInternal(), timers, + Collections.>singleton(filteredBundle), + BoundedWindow.TIMESTAMP_MAX_VALUE); + Instant startTime = clock.now(); + clock.set(startTime.plus(250L)); + // We're held based on the past timer + assertThat(filteredWms.getSynchronizedProcessingOutputTime(), not(laterThan(startTime))); + assertThat(filteredDoubledWms.getSynchronizedProcessingOutputTime(), not(laterThan(startTime))); + // And we're monotonic + assertThat( + filteredWms.getSynchronizedProcessingOutputTime(), not(earlierThan(initialFilteredWm))); + assertThat( + filteredDoubledWms.getSynchronizedProcessingOutputTime(), + not(earlierThan(initialFilteredDoubledWm))); + + Map, Map> firedTimers = + manager.extractFiredTimers(); + assertThat( + firedTimers.get(filtered.getProducingTransformInternal()) + .get("key") + .getTimers(TimeDomain.PROCESSING_TIME), + contains(pastTimer)); + // Our timer has fired, but has not been completed, so it holds our synchronized processing WM + assertThat(filteredWms.getSynchronizedProcessingOutputTime(), not(laterThan(startTime))); + assertThat(filteredDoubledWms.getSynchronizedProcessingOutputTime(), not(laterThan(startTime))); + + CommittedBundle filteredTimerBundle = + InProcessBundle.keyed(filtered, "key").commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + CommittedBundle filteredTimerResult = + InProcessBundle.keyed(filteredTimesTwo, "key") + .commit(filteredWms.getSynchronizedProcessingOutputTime()); + // Complete the processing time timer + manager.updateWatermarks(filteredTimerBundle, filtered.getProducingTransformInternal(), + TimerUpdate.builder("key") + .withCompletedTimers(Collections.singleton(pastTimer)) + .build(), + Collections.>singleton(filteredTimerResult), + BoundedWindow.TIMESTAMP_MAX_VALUE); + + clock.set(startTime.plus(500L)); + assertThat(filteredWms.getSynchronizedProcessingOutputTime(), not(laterThan(clock.now()))); + // filtered should be held to the time at which the filteredTimerResult fired + assertThat( + filteredDoubledWms.getSynchronizedProcessingOutputTime(), + not(earlierThan(filteredTimerResult.getSynchronizedProcessingOutputWatermark()))); + + manager.updateWatermarks(filteredTimerResult, filteredTimesTwo.getProducingTransformInternal(), + TimerUpdate.empty(), Collections.>emptyList(), + 
BoundedWindow.TIMESTAMP_MAX_VALUE); + assertThat(filteredDoubledWms.getSynchronizedProcessingOutputTime(), equalTo(clock.now())); + + clock.set(new Instant(Long.MAX_VALUE)); + assertThat(filteredWms.getSynchronizedProcessingOutputTime(), equalTo(new Instant(4096))); + assertThat( + filteredDoubledWms.getSynchronizedProcessingOutputTime(), equalTo(new Instant(4096))); + } + + /** + * Demonstrates that if any earlier processing holds appear in the synchronized processing time + * output hold the result is monotonic. + */ + @Test + public void getSynchronizedProcessingTimeOutputTimeIsMonotonic() { + Instant startTime = clock.now(); + TransformWatermarks watermarks = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(watermarks.getSynchronizedProcessingInputTime(), equalTo(startTime)); + + TransformWatermarks filteredWatermarks = + manager.getWatermarks(filtered.getProducingTransformInternal()); + // Non-root processing watermarks don't progress until data has been processed + assertThat( + filteredWatermarks.getSynchronizedProcessingInputTime(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + assertThat( + filteredWatermarks.getSynchronizedProcessingOutputTime(), + not(laterThan(BoundedWindow.TIMESTAMP_MIN_VALUE))); + + CommittedBundle createOutput = + InProcessBundle.unkeyed(createdInts).commit(new Instant(1250L)); + + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createOutput), BoundedWindow.TIMESTAMP_MAX_VALUE); + TransformWatermarks createAfterUpdate = + manager.getWatermarks(createdInts.getProducingTransformInternal()); + assertThat(createAfterUpdate.getSynchronizedProcessingInputTime(), not(laterThan(clock.now()))); + assertThat( + createAfterUpdate.getSynchronizedProcessingOutputTime(), not(laterThan(clock.now()))); + + CommittedBundle createSecondOutput = + InProcessBundle.unkeyed(createdInts).commit(new Instant(750L)); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(createSecondOutput), + BoundedWindow.TIMESTAMP_MAX_VALUE); + + assertThat(createAfterUpdate.getSynchronizedProcessingOutputTime(), equalTo(clock.now())); + } + + @Test + public void synchronizedProcessingInputTimeIsHeldToUpstreamProcessingTimeTimers() { + CommittedBundle created = globallyWindowedBundle(createdInts, 1, 2, 3); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(created), new Instant(40_900L)); + + CommittedBundle filteredBundle = globallyWindowedBundle(filtered, 2, 4); + Instant upstreamHold = new Instant(2048L); + TimerData upstreamProcessingTimer = + TimerData.of(StateNamespaces.global(), upstreamHold, TimeDomain.PROCESSING_TIME); + manager.updateWatermarks(created, filtered.getProducingTransformInternal(), + TimerUpdate.builder("key").setTimer(upstreamProcessingTimer).build(), + Collections.>singleton(filteredBundle), + BoundedWindow.TIMESTAMP_MAX_VALUE); + + TransformWatermarks downstreamWms = + manager.getWatermarks(filteredTimesTwo.getProducingTransformInternal()); + assertThat(downstreamWms.getSynchronizedProcessingInputTime(), equalTo(clock.now())); + + clock.set(BoundedWindow.TIMESTAMP_MAX_VALUE); + assertThat(downstreamWms.getSynchronizedProcessingInputTime(), equalTo(upstreamHold)); + + manager.extractFiredTimers(); + // Pending processing time timers that have been fired but aren't completed hold the + // synchronized processing time + 
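+ // at the upstream hold, even though the clock has already advanced to the end of time.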
assertThat(downstreamWms.getSynchronizedProcessingInputTime(), equalTo(upstreamHold)); + + CommittedBundle otherCreated = globallyWindowedBundle(createdInts, 4, 8, 12); + manager.updateWatermarks(otherCreated, filtered.getProducingTransformInternal(), + TimerUpdate.builder("key") + .withCompletedTimers(Collections.singleton(upstreamProcessingTimer)) + .build(), + Collections.>emptyList(), BoundedWindow.TIMESTAMP_MAX_VALUE); + + assertThat(downstreamWms.getSynchronizedProcessingInputTime(), not(earlierThan(clock.now()))); + } + + @Test + public void synchronizedProcessingInputTimeIsHeldToPendingBundleTimes() { + CommittedBundle created = globallyWindowedBundle(createdInts, 1, 2, 3); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(created), new Instant(29_919_235L)); + + Instant upstreamHold = new Instant(2048L); + CommittedBundle filteredBundle = + InProcessBundle.keyed(filtered, "key").commit(upstreamHold); + manager.updateWatermarks(created, filtered.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>singleton(filteredBundle), + BoundedWindow.TIMESTAMP_MAX_VALUE); + + TransformWatermarks downstreamWms = + manager.getWatermarks(filteredTimesTwo.getProducingTransformInternal()); + assertThat(downstreamWms.getSynchronizedProcessingInputTime(), equalTo(clock.now())); + + clock.set(BoundedWindow.TIMESTAMP_MAX_VALUE); + assertThat(downstreamWms.getSynchronizedProcessingInputTime(), equalTo(upstreamHold)); + } + + @Test + public void extractFiredTimersReturnsFiredEventTimeTimers() { + Map, Map> initialTimers = + manager.extractFiredTimers(); + // Watermarks haven't advanced + assertThat(initialTimers.entrySet(), emptyIterable()); + + // Advance WM of keyed past the first timer, but ahead of the second and third + CommittedBundle createdBundle = globallyWindowedBundle(filtered); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.singleton(createdBundle), new Instant(1500L)); + + TimerData earliestTimer = + TimerData.of(StateNamespaces.global(), new Instant(1000), TimeDomain.EVENT_TIME); + TimerData middleTimer = + TimerData.of(StateNamespaces.global(), new Instant(5000L), TimeDomain.EVENT_TIME); + TimerData lastTimer = + TimerData.of(StateNamespaces.global(), new Instant(10000L), TimeDomain.EVENT_TIME); + Object key = new Object(); + TimerUpdate update = + TimerUpdate.builder(key) + .setTimer(earliestTimer) + .setTimer(middleTimer) + .setTimer(lastTimer) + .build(); + + manager.updateWatermarks(createdBundle, filtered.getProducingTransformInternal(), update, + Collections.>singleton(globallyWindowedBundle(intsToFlatten)), + new Instant(1000L)); + + Map, Map> firstTransformFiredTimers = + manager.extractFiredTimers(); + assertThat( + firstTransformFiredTimers.get(filtered.getProducingTransformInternal()), not(nullValue())); + Map firstFilteredTimers = + firstTransformFiredTimers.get(filtered.getProducingTransformInternal()); + assertThat(firstFilteredTimers.get(key), not(nullValue())); + FiredTimers firstFired = firstFilteredTimers.get(key); + assertThat(firstFired.getTimers(TimeDomain.EVENT_TIME), contains(earliestTimer)); + + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>emptyList(), new Instant(50_000L)); + Map, Map> secondTransformFiredTimers = + manager.extractFiredTimers(); + assertThat( + secondTransformFiredTimers.get(filtered.getProducingTransformInternal()), 
not(nullValue())); + Map secondFilteredTimers = + secondTransformFiredTimers.get(filtered.getProducingTransformInternal()); + assertThat(secondFilteredTimers.get(key), not(nullValue())); + FiredTimers secondFired = secondFilteredTimers.get(key); + // Contains, in order, middleTimer and then lastTimer + assertThat(secondFired.getTimers(TimeDomain.EVENT_TIME), contains(middleTimer, lastTimer)); + } + + @Test + public void extractFiredTimersReturnsFiredProcessingTimeTimers() { + Map, Map> initialTimers = + manager.extractFiredTimers(); + // Watermarks haven't advanced + assertThat(initialTimers.entrySet(), emptyIterable()); + + // Advance WM of keyed past the first timer, but ahead of the second and third + CommittedBundle createdBundle = globallyWindowedBundle(filtered); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.singleton(createdBundle), new Instant(1500L)); + + TimerData earliestTimer = + TimerData.of(StateNamespaces.global(), new Instant(999L), TimeDomain.PROCESSING_TIME); + TimerData middleTimer = + TimerData.of(StateNamespaces.global(), new Instant(5000L), TimeDomain.PROCESSING_TIME); + TimerData lastTimer = + TimerData.of(StateNamespaces.global(), new Instant(10000L), TimeDomain.PROCESSING_TIME); + Object key = new Object(); + TimerUpdate update = + TimerUpdate.builder(key) + .setTimer(lastTimer) + .setTimer(earliestTimer) + .setTimer(middleTimer) + .build(); + + manager.updateWatermarks(createdBundle, filtered.getProducingTransformInternal(), update, + Collections.>singleton(globallyWindowedBundle(intsToFlatten)), + new Instant(1000L)); + + Map, Map> firstTransformFiredTimers = + manager.extractFiredTimers(); + assertThat( + firstTransformFiredTimers.get(filtered.getProducingTransformInternal()), not(nullValue())); + Map firstFilteredTimers = + firstTransformFiredTimers.get(filtered.getProducingTransformInternal()); + assertThat(firstFilteredTimers.get(key), not(nullValue())); + FiredTimers firstFired = firstFilteredTimers.get(key); + assertThat(firstFired.getTimers(TimeDomain.PROCESSING_TIME), contains(earliestTimer)); + + clock.set(new Instant(50_000L)); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>emptyList(), new Instant(50_000L)); + Map, Map> secondTransformFiredTimers = + manager.extractFiredTimers(); + assertThat( + secondTransformFiredTimers.get(filtered.getProducingTransformInternal()), not(nullValue())); + Map secondFilteredTimers = + secondTransformFiredTimers.get(filtered.getProducingTransformInternal()); + assertThat(secondFilteredTimers.get(key), not(nullValue())); + FiredTimers secondFired = secondFilteredTimers.get(key); + // Contains, in order, middleTimer and then lastTimer + assertThat(secondFired.getTimers(TimeDomain.PROCESSING_TIME), contains(middleTimer, lastTimer)); + } + + @Test + public void extractFiredTimersReturnsFiredSynchronizedProcessingTimeTimers() { + Map, Map> initialTimers = + manager.extractFiredTimers(); + // Watermarks haven't advanced + assertThat(initialTimers.entrySet(), emptyIterable()); + + // Advance WM of keyed past the first timer, but ahead of the second and third + CommittedBundle createdBundle = globallyWindowedBundle(filtered); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.singleton(createdBundle), new Instant(1500L)); + + TimerData earliestTimer = TimerData.of( + StateNamespaces.global(), new Instant(999L), 
TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + TimerData middleTimer = TimerData.of( + StateNamespaces.global(), new Instant(5000L), TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + TimerData lastTimer = TimerData.of( + StateNamespaces.global(), new Instant(10000L), TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + Object key = new Object(); + TimerUpdate update = + TimerUpdate.builder(key) + .setTimer(lastTimer) + .setTimer(earliestTimer) + .setTimer(middleTimer) + .build(); + + manager.updateWatermarks(createdBundle, filtered.getProducingTransformInternal(), update, + Collections.>singleton(globallyWindowedBundle(intsToFlatten)), + new Instant(1000L)); + + Map, Map> firstTransformFiredTimers = + manager.extractFiredTimers(); + assertThat( + firstTransformFiredTimers.get(filtered.getProducingTransformInternal()), not(nullValue())); + Map firstFilteredTimers = + firstTransformFiredTimers.get(filtered.getProducingTransformInternal()); + assertThat(firstFilteredTimers.get(key), not(nullValue())); + FiredTimers firstFired = firstFilteredTimers.get(key); + assertThat( + firstFired.getTimers(TimeDomain.SYNCHRONIZED_PROCESSING_TIME), contains(earliestTimer)); + + clock.set(new Instant(50_000L)); + manager.updateWatermarks(null, createdInts.getProducingTransformInternal(), TimerUpdate.empty(), + Collections.>emptyList(), new Instant(50_000L)); + Map, Map> secondTransformFiredTimers = + manager.extractFiredTimers(); + assertThat( + secondTransformFiredTimers.get(filtered.getProducingTransformInternal()), not(nullValue())); + Map secondFilteredTimers = + secondTransformFiredTimers.get(filtered.getProducingTransformInternal()); + assertThat(secondFilteredTimers.get(key), not(nullValue())); + FiredTimers secondFired = secondFilteredTimers.get(key); + // Contains, in order, middleTimer and then lastTimer + assertThat( + secondFired.getTimers(TimeDomain.SYNCHRONIZED_PROCESSING_TIME), + contains(middleTimer, lastTimer)); + } + + @Test + public void timerUpdateBuilderBuildAddsAllAddedTimers() { + TimerData set = TimerData.of(StateNamespaces.global(), new Instant(10L), TimeDomain.EVENT_TIME); + TimerData deleted = + TimerData.of(StateNamespaces.global(), new Instant(24L), TimeDomain.PROCESSING_TIME); + TimerData completedOne = TimerData.of( + StateNamespaces.global(), new Instant(1024L), TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + TimerData completedTwo = + TimerData.of(StateNamespaces.global(), new Instant(2048L), TimeDomain.EVENT_TIME); + + TimerUpdate update = + TimerUpdate.builder("foo") + .withCompletedTimers(ImmutableList.of(completedOne, completedTwo)) + .setTimer(set) + .deletedTimer(deleted) + .build(); + + assertThat(update.getCompletedTimers(), containsInAnyOrder(completedOne, completedTwo)); + assertThat(update.getSetTimers(), contains(set)); + assertThat(update.getDeletedTimers(), contains(deleted)); + } + + @Test + public void timerUpdateBuilderWithSetThenDeleteHasOnlyDeleted() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = builder.setTimer(timer).deletedTimer(timer).build(); + + assertThat(built.getSetTimers(), emptyIterable()); + assertThat(built.getDeletedTimers(), contains(timer)); + } + + @Test + public void timerUpdateBuilderWithDeleteThenSetHasOnlySet() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = 
builder.deletedTimer(timer).setTimer(timer).build(); + + assertThat(built.getSetTimers(), contains(timer)); + assertThat(built.getDeletedTimers(), emptyIterable()); + } + + @Test + public void timerUpdateBuilderWithSetAfterBuildNotAddedToBuilt() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = builder.build(); + builder.setTimer(timer); + assertThat(built.getSetTimers(), emptyIterable()); + builder.build(); + assertThat(built.getSetTimers(), emptyIterable()); + } + + @Test + public void timerUpdateBuilderWithDeleteAfterBuildNotAddedToBuilt() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = builder.build(); + builder.deletedTimer(timer); + assertThat(built.getDeletedTimers(), emptyIterable()); + builder.build(); + assertThat(built.getDeletedTimers(), emptyIterable()); + } + + @Test + public void timerUpdateBuilderWithCompletedAfterBuildNotAddedToBuilt() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = builder.build(); + builder.withCompletedTimers(ImmutableList.of(timer)); + assertThat(built.getCompletedTimers(), emptyIterable()); + builder.build(); + assertThat(built.getCompletedTimers(), emptyIterable()); + } + + private static Matcher earlierThan(final Instant laterInstant) { + return new BaseMatcher() { + @Override + public boolean matches(Object item) { + ReadableInstant instant = (ReadableInstant) item; + return instant.isBefore(laterInstant); + } + + @Override + public void describeTo(Description description) { + description.appendText("earlier than ").appendValue(laterInstant); + } + }; + } + + private static Matcher laterThan(final Instant shouldBeEarlier) { + return new BaseMatcher() { + @Override + public boolean matches(Object item) { + ReadableInstant instant = (ReadableInstant) item; + return instant.isAfter(shouldBeEarlier); + } + + @Override + public void describeTo(Description description) { + description.appendText("later than ").appendValue(shouldBeEarlier); + } + }; + } + + @SafeVarargs + private final CommittedBundle timestampedBundle( + PCollection pc, TimestampedValue... values) { + UncommittedBundle bundle = InProcessBundle.unkeyed(pc); + for (TimestampedValue value : values) { + bundle.add( + WindowedValue.timestampedValueInGlobalWindow(value.getValue(), value.getTimestamp())); + } + return bundle.commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + } + + @SafeVarargs + private final CommittedBundle globallyWindowedBundle(PCollection pc, T... values) { + UncommittedBundle bundle = InProcessBundle.unkeyed(pc); + for (T value : values) { + bundle.add(WindowedValue.valueInGlobalWindow(value)); + } + return bundle.commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundleTest.java new file mode 100644 index 000000000000..dcba86bc252c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundleTest.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; + +/** + * Tests for {@link InProcessBundle}. + */ +@RunWith(JUnit4.class) +public class InProcessBundleTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void unkeyedShouldCreateWithNullKey() { + PCollection pcollection = TestPipeline.create().apply(Create.of(1)); + + InProcessBundle inFlightBundle = InProcessBundle.unkeyed(pcollection); + + CommittedBundle bundle = inFlightBundle.commit(Instant.now()); + + assertThat(bundle.isKeyed(), is(false)); + assertThat(bundle.getKey(), nullValue()); + } + + private void keyedCreateBundle(Object key) { + PCollection pcollection = TestPipeline.create().apply(Create.of(1)); + + InProcessBundle inFlightBundle = InProcessBundle.keyed(pcollection, key); + + CommittedBundle bundle = inFlightBundle.commit(Instant.now()); + assertThat(bundle.isKeyed(), is(true)); + assertThat(bundle.getKey(), equalTo(key)); + } + + @Test + public void keyedWithNullKeyShouldCreateKeyedBundle() { + keyedCreateBundle(null); + } + + @Test + public void keyedWithKeyShouldCreateKeyedBundle() { + keyedCreateBundle(new Object()); + } + + private void afterCommitGetElementsShouldHaveAddedElements(Iterable> elems) { + PCollection pcollection = TestPipeline.create().apply(Create.of()); + + InProcessBundle bundle = InProcessBundle.unkeyed(pcollection); + Collection>> expectations = new ArrayList<>(); + for (WindowedValue elem : elems) { + bundle.add(elem); + expectations.add(equalTo(elem)); + } + Matcher>> containsMatcher = + Matchers.>containsInAnyOrder(expectations); + assertThat(bundle.commit(Instant.now()).getElements(), containsMatcher); + } + + @Test + public void getElementsBeforeAddShouldReturnEmptyIterable() { + afterCommitGetElementsShouldHaveAddedElements(Collections.>emptyList()); + } + + @Test + public void getElementsAfterAddShouldReturnAddedElements() { + WindowedValue firstValue = WindowedValue.valueInGlobalWindow(1); + WindowedValue secondValue = + 
WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1000L)); + + afterCommitGetElementsShouldHaveAddedElements(ImmutableList.of(firstValue, secondValue)); + } + + @Test + public void addAfterCommitShouldThrowException() { + PCollection pcollection = TestPipeline.create().apply(Create.of()); + + InProcessBundle bundle = InProcessBundle.unkeyed(pcollection); + bundle.add(WindowedValue.valueInGlobalWindow(1)); + CommittedBundle firstCommit = bundle.commit(Instant.now()); + assertThat(firstCommit.getElements(), containsInAnyOrder(WindowedValue.valueInGlobalWindow(1))); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("3"); + thrown.expectMessage("committed"); + + bundle.add(WindowedValue.valueInGlobalWindow(3)); + } + + @Test + public void commitAfterCommitShouldThrowException() { + PCollection pcollection = TestPipeline.create().apply(Create.of()); + + InProcessBundle bundle = InProcessBundle.unkeyed(pcollection); + bundle.add(WindowedValue.valueInGlobalWindow(1)); + CommittedBundle firstCommit = bundle.commit(Instant.now()); + assertThat(firstCommit.getElements(), containsInAnyOrder(WindowedValue.valueInGlobalWindow(1))); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("committed"); + + bundle.commit(Instant.now()); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessCreateTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessCreateTest.java new file mode 100644 index 000000000000..4db014e3ee26 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessCreateTest.java @@ -0,0 +1,199 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.NullableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessCreate.InMemorySource; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.SourceTestUtils; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Collections; +import java.util.List; + +/** + * Tests for {@link InProcessCreate}. + */ +@RunWith(JUnit4.class) +public class InProcessCreateTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testConvertsCreate() { + TestPipeline p = TestPipeline.create(); + Create.Values og = Create.of(1, 2, 3); + + InProcessCreate converted = InProcessCreate.from(og); + + DataflowAssert.that(p.apply(converted)).containsInAnyOrder(2, 1, 3); + } + + @Test + public void testConvertsCreateWithNullElements() { + Create.Values og = + Create.of("foo", null, "spam", "ham", null, "eggs") + .withCoder(NullableCoder.of(StringUtf8Coder.of())); + + InProcessCreate converted = InProcessCreate.from(og); + TestPipeline p = TestPipeline.create(); + + DataflowAssert.that(p.apply(converted)) + .containsInAnyOrder(null, "foo", null, "spam", "ham", "eggs"); + } + + static class Record implements Serializable {} + + static class Record2 extends Record {} + + @Test + public void testThrowsIllegalArgumentWhenCannotInferCoder() { + Create.Values og = Create.of(new Record(), new Record2()); + InProcessCreate converted = InProcessCreate.from(og); + + Pipeline p = TestPipeline.create(); + + // Create won't infer a default coder in this case. + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(Matchers.containsString("Unable to infer a coder")); + + PCollection c = p.apply(converted); + p.run(); + + fail("Unexpectedly Inferred Coder " + c.getCoder()); + } + + /** + * An unserializable class to demonstrate encoding of elements. 
+ */ + private static class UnserializableRecord { + private final String myString; + + private UnserializableRecord(String myString) { + this.myString = myString; + } + + @Override + public int hashCode() { + return myString.hashCode(); + } + + @Override + public boolean equals(Object o) { + return myString.equals(((UnserializableRecord) o).myString); + } + + static class UnserializableRecordCoder extends StandardCoder { + private final Coder stringCoder = StringUtf8Coder.of(); + + @Override + public void encode( + UnserializableRecord value, + OutputStream outStream, + com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException { + stringCoder.encode(value.myString, outStream, context.nested()); + } + + @Override + public UnserializableRecord decode( + InputStream inStream, com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new UnserializableRecord(stringCoder.decode(inStream, context.nested())); + } + + @Override + public List> getCoderArguments() { + return Collections.emptyList(); + } + + @Override + public void verifyDeterministic() throws Coder.NonDeterministicException { + stringCoder.verifyDeterministic(); + } + } + } + + @Test + public void testSerializableOnUnserializableElements() throws Exception { + List elements = + ImmutableList.of( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + InMemorySource source = + new InMemorySource<>(elements, new UnserializableRecord.UnserializableRecordCoder()); + SerializableUtils.ensureSerializable(source); + } + + @Test + public void testSplitIntoBundles() throws Exception { + InProcessCreate.InMemorySource source = + new InMemorySource<>(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), BigEndianIntegerCoder.of()); + PipelineOptions options = PipelineOptionsFactory.create(); + List> splitSources = source.splitIntoBundles(12, options); + assertThat(splitSources, hasSize(3)); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splitSources, options); + } + + @Test + public void testDoesNotProduceSortedKeys() throws Exception { + InProcessCreate.InMemorySource source = + new InMemorySource<>(ImmutableList.of("spam", "ham", "eggs"), StringUtf8Coder.of()); + assertThat(source.producesSortedKeys(PipelineOptionsFactory.create()), is(false)); + } + + @Test + public void testGetDefaultOutputCoderReturnsConstructorCoder() throws Exception { + Coder coder = VarIntCoder.of(); + InProcessCreate.InMemorySource source = + new InMemorySource<>(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), coder); + + Coder defaultCoder = source.getDefaultOutputCoder(); + assertThat(defaultCoder, equalTo(coder)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java new file mode 100644 index 000000000000..4cfe78293687 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java @@ -0,0 +1,356 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.doAnswer; + +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Mean; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing; +import com.google.cloud.dataflow.sdk.util.PCollectionViews; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * Tests for {@link InProcessSideInputContainer}. + */ +@RunWith(JUnit4.class) +public class InProcessSideInputContainerTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock + private InProcessEvaluationContext context; + + private TestPipeline pipeline; + + private InProcessSideInputContainer container; + + private PCollectionView> mapView; + private PCollectionView singletonView; + + // Not present in container. 
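+  // (setup() still registers this view with the container; getOnReaderForViewNotInReaderFails
+  // exercises a reader created via withViews() that omits it.)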
+ private PCollectionView> iterableView; + + private BoundedWindow firstWindow = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(789541L); + } + + @Override + public String toString() { + return "firstWindow"; + } + }; + + private BoundedWindow secondWindow = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(14564786L); + } + + @Override + public String toString() { + return "secondWindow"; + } + }; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + pipeline = TestPipeline.create(); + + PCollection create = + pipeline.apply("forBaseCollection", Create.of(1, 2, 3, 4)); + + mapView = + create.apply("forKeyTypes", WithKeys.of("foo")) + .apply("asMapView", View.asMap()); + + singletonView = + create.apply("forCombinedTypes", Mean.globally()) + .apply("asDoubleView", View.asSingleton()); + + iterableView = create.apply("asIterableView", View.asIterable()); + + container = InProcessSideInputContainer.create( + context, ImmutableList.of(iterableView, mapView, singletonView)); + } + + @Test + public void getAfterWriteReturnsPaneInWindow() throws Exception { + WindowedValue> one = WindowedValue.of( + KV.of("one", 1), new Instant(1L), firstWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING); + WindowedValue> two = WindowedValue.of( + KV.of("two", 2), new Instant(20L), firstWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING); + container.write(mapView, ImmutableList.>of(one, two)); + + Map viewContents = + container.withViews(ImmutableList.>of(mapView)) + .get(mapView, firstWindow); + assertThat(viewContents, hasEntry("one", 1)); + assertThat(viewContents, hasEntry("two", 2)); + assertThat(viewContents.size(), is(2)); + } + + @Test + public void getReturnsLatestPaneInWindow() throws Exception { + WindowedValue> one = WindowedValue.of(KV.of("one", 1), new Instant(1L), + secondWindow, PaneInfo.createPane(true, false, Timing.EARLY)); + WindowedValue> two = WindowedValue.of(KV.of("two", 2), new Instant(20L), + secondWindow, PaneInfo.createPane(true, false, Timing.EARLY)); + container.write(mapView, ImmutableList.>of(one, two)); + + Map viewContents = + container.withViews(ImmutableList.>of(mapView)) + .get(mapView, secondWindow); + assertThat(viewContents, hasEntry("one", 1)); + assertThat(viewContents, hasEntry("two", 2)); + assertThat(viewContents.size(), is(2)); + + WindowedValue> three = WindowedValue.of(KV.of("three", 3), + new Instant(300L), secondWindow, PaneInfo.createPane(false, false, Timing.EARLY, 1, -1)); + container.write(mapView, ImmutableList.>of(three)); + + Map overwrittenViewContents = + container.withViews(ImmutableList.>of(mapView)) + .get(mapView, secondWindow); + assertThat(overwrittenViewContents, hasEntry("three", 3)); + assertThat(overwrittenViewContents.size(), is(1)); + } + + /** + * Demonstrates that calling get() on a window that currently has no data does not return until + * there is data in the pane. 
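+   * The get() call is issued on a separate thread via getFutureOfView, so the test can assert
+   * that the future is not done before any value is written and completes once a pane arrives.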
+ */ + @Test + public void getBlocksUntilPaneAvailable() throws Exception { + BoundedWindow window = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(1024L); + } + }; + Future singletonFuture = + getFutureOfView(container.withViews(ImmutableList.>of(singletonView)), + singletonView, window); + + WindowedValue singletonValue = + WindowedValue.of(4.75, new Instant(475L), window, PaneInfo.ON_TIME_AND_ONLY_FIRING); + + assertThat(singletonFuture.isDone(), is(false)); + container.write(singletonView, ImmutableList.>of(singletonValue)); + assertThat(singletonFuture.get(), equalTo(4.75)); + } + + @Test + public void withPCollectionViewsWithPutInOriginalReturnsContents() throws Exception { + BoundedWindow window = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(1024L); + } + }; + SideInputReader newReader = + container.withViews(ImmutableList.>of(singletonView)); + Future singletonFuture = getFutureOfView(newReader, singletonView, window); + + WindowedValue singletonValue = + WindowedValue.of(24.125, new Instant(475L), window, PaneInfo.ON_TIME_AND_ONLY_FIRING); + + assertThat(singletonFuture.isDone(), is(false)); + container.write(singletonView, ImmutableList.>of(singletonValue)); + assertThat(singletonFuture.get(), equalTo(24.125)); + } + + @Test + public void withPCollectionViewsErrorsForContainsNotInViews() { + PCollectionView>> newView = PCollectionViews.multimapView(pipeline, + WindowingStrategy.globalDefault(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("with unknown views " + ImmutableList.of(newView).toString()); + + container.withViews(ImmutableList.>of(newView)); + } + + @Test + public void withViewsForViewNotInContainerFails() { + PCollectionView>> newView = PCollectionViews.multimapView(pipeline, + WindowingStrategy.globalDefault(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("unknown views"); + thrown.expectMessage(newView.toString()); + + container.withViews(ImmutableList.>of(newView)); + } + + @Test + public void getOnReaderForViewNotInReaderFails() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("unknown view: " + iterableView.toString()); + + container.withViews(ImmutableList.>of(mapView)) + .get(iterableView, GlobalWindow.INSTANCE); + } + + @Test + public void writeForMultipleElementsInDifferentWindowsSucceeds() throws Exception { + WindowedValue firstWindowedValue = WindowedValue.of(2.875, + firstWindow.maxTimestamp().minus(200L), firstWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING); + WindowedValue secondWindowedValue = + WindowedValue.of(4.125, secondWindow.maxTimestamp().minus(2_000_000L), secondWindow, + PaneInfo.ON_TIME_AND_ONLY_FIRING); + container.write(singletonView, ImmutableList.of(firstWindowedValue, secondWindowedValue)); + assertThat( + container.withViews(ImmutableList.>of(singletonView)) + .get(singletonView, firstWindow), + equalTo(2.875)); + assertThat( + container.withViews(ImmutableList.>of(singletonView)) + .get(singletonView, secondWindow), + equalTo(4.125)); + } + + @Test + public void writeForMultipleIdenticalElementsInSameWindowSucceeds() throws Exception { + WindowedValue firstValue = WindowedValue.of( + 44, firstWindow.maxTimestamp().minus(200L), firstWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING); + WindowedValue secondValue = WindowedValue.of( + 44, 
firstWindow.maxTimestamp().minus(200L), firstWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING); + + container.write(iterableView, ImmutableList.of(firstValue, secondValue)); + + assertThat( + container.withViews(ImmutableList.>of(iterableView)) + .get(iterableView, firstWindow), + contains(44, 44)); + } + + @Test + public void writeForElementInMultipleWindowsSucceeds() throws Exception { + WindowedValue multiWindowedValue = + WindowedValue.of(2.875, firstWindow.maxTimestamp().minus(200L), + ImmutableList.of(firstWindow, secondWindow), PaneInfo.ON_TIME_AND_ONLY_FIRING); + container.write(singletonView, ImmutableList.of(multiWindowedValue)); + assertThat( + container.withViews(ImmutableList.>of(singletonView)) + .get(singletonView, firstWindow), + equalTo(2.875)); + assertThat( + container.withViews(ImmutableList.>of(singletonView)) + .get(singletonView, secondWindow), + equalTo(2.875)); + } + + @Test + public void finishDoesNotOverwriteWrittenElements() throws Exception { + WindowedValue> one = WindowedValue.of(KV.of("one", 1), new Instant(1L), + secondWindow, PaneInfo.createPane(true, false, Timing.EARLY)); + WindowedValue> two = WindowedValue.of(KV.of("two", 2), new Instant(20L), + secondWindow, PaneInfo.createPane(true, false, Timing.EARLY)); + container.write(mapView, ImmutableList.>of(one, two)); + + immediatelyInvokeCallback(mapView, secondWindow); + + Map viewContents = + container.withViews(ImmutableList.>of(mapView)) + .get(mapView, secondWindow); + + assertThat(viewContents, hasEntry("one", 1)); + assertThat(viewContents, hasEntry("two", 2)); + assertThat(viewContents.size(), is(2)); + } + + @Test + public void finishOnPendingViewsSetsEmptyElements() throws Exception { + immediatelyInvokeCallback(mapView, secondWindow); + Future> mapFuture = getFutureOfView( + container.withViews(ImmutableList.>of(mapView)), mapView, secondWindow); + + assertThat(mapFuture.get().isEmpty(), is(true)); + } + + /** + * When a callAfterWindowCloses with the specified view's producing transform, window, and + * windowing strategy is invoked, immediately execute the callback. + */ + private void immediatelyInvokeCallback(PCollectionView view, BoundedWindow window) { + doAnswer( + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + Object callback = invocation.getArguments()[3]; + Runnable callbackRunnable = (Runnable) callback; + callbackRunnable.run(); + return null; + } + }) + .when(context) + .callAfterOutputMustHaveBeenProduced(Mockito.eq(view), Mockito.eq(window), + Mockito.eq(view.getWindowingStrategyInternal()), Mockito.any(Runnable.class)); + } + + private Future getFutureOfView(final SideInputReader myReader, + final PCollectionView view, final BoundedWindow window) { + Callable callable = new Callable() { + @Override + public ValueT call() throws Exception { + return myReader.get(view, window); + } + }; + return Executors.newSingleThreadExecutor().submit(callable); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTimerInternalsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTimerInternalsTest.java new file mode 100644 index 000000000000..435a5ba9e3bd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessTimerInternalsTest.java @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate.TimerUpdateBuilder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TransformWatermarks; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link InProcessTimerInternals}. + */ +@RunWith(JUnit4.class) +public class InProcessTimerInternalsTest { + private MockClock clock; + @Mock private TransformWatermarks watermarks; + + private TimerUpdateBuilder timerUpdateBuilder; + + private InProcessTimerInternals internals; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + clock = MockClock.fromInstant(new Instant(0)); + + timerUpdateBuilder = TimerUpdate.builder(1234); + + internals = InProcessTimerInternals.create(clock, watermarks, timerUpdateBuilder); + } + + @Test + public void setTimerAddsToBuilder() { + TimerData eventTimer = + TimerData.of(StateNamespaces.global(), new Instant(20145L), TimeDomain.EVENT_TIME); + TimerData processingTimer = + TimerData.of(StateNamespaces.global(), new Instant(125555555L), TimeDomain.PROCESSING_TIME); + TimerData synchronizedProcessingTimer = + TimerData.of( + StateNamespaces.global(), + new Instant(98745632189L), + TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + internals.setTimer(eventTimer); + internals.setTimer(processingTimer); + internals.setTimer(synchronizedProcessingTimer); + + assertThat( + internals.getTimerUpdate().getSetTimers(), + containsInAnyOrder(eventTimer, synchronizedProcessingTimer, processingTimer)); + } + + @Test + public void deleteTimerDeletesOnBuilder() { + TimerData eventTimer = + TimerData.of(StateNamespaces.global(), new Instant(20145L), TimeDomain.EVENT_TIME); + TimerData processingTimer = + TimerData.of(StateNamespaces.global(), new Instant(125555555L), TimeDomain.PROCESSING_TIME); + TimerData synchronizedProcessingTimer = + TimerData.of( + StateNamespaces.global(), + new Instant(98745632189L), + TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + internals.deleteTimer(eventTimer); + internals.deleteTimer(processingTimer); + internals.deleteTimer(synchronizedProcessingTimer); + + assertThat( + internals.getTimerUpdate().getDeletedTimers(), + containsInAnyOrder(eventTimer, synchronizedProcessingTimer, processingTimer)); + } + + @Test + public void getProcessingTimeIsClockNow() { + assertThat(internals.currentProcessingTime(), 
equalTo(clock.now())); + Instant oldProcessingTime = internals.currentProcessingTime(); + + clock.advance(Duration.standardHours(12)); + + assertThat(internals.currentProcessingTime(), equalTo(clock.now())); + assertThat( + internals.currentProcessingTime(), + equalTo(oldProcessingTime.plus(Duration.standardHours(12)))); + } + + @Test + public void getSynchronizedProcessingTimeIsWatermarkSynchronizedInputTime() { + when(watermarks.getSynchronizedProcessingInputTime()).thenReturn(new Instant(12345L)); + assertThat(internals.currentSynchronizedProcessingTime(), equalTo(new Instant(12345L))); + } + + @Test + public void getInputWatermarkTimeUsesWatermarkTime() { + when(watermarks.getInputWatermark()).thenReturn(new Instant(8765L)); + assertThat(internals.currentInputWatermarkTime(), equalTo(new Instant(8765L))); + } + + @Test + public void getOutputWatermarkTimeUsesWatermarkTime() { + when(watermarks.getOutputWatermark()).thenReturn(new Instant(25525L)); + assertThat(internals.currentOutputWatermarkTime(), equalTo(new Instant(25525L))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/MockClock.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/MockClock.java new file mode 100644 index 000000000000..d69660b399c0 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/MockClock.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkArgument; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** + * A clock that returns a constant value for now which can be set with calls to + * {@link #set(Instant)}. + * + *

    For uses of the {@link Clock} interface in unit tests. + */ +public class MockClock implements Clock { + + private Instant now; + + public static MockClock fromInstant(Instant initial) { + return new MockClock(initial); + } + + private MockClock(Instant initialNow) { + this.now = initialNow; + } + + public void set(Instant newNow) { + checkArgument(!newNow.isBefore(now), "Cannot move MockClock backwards in time from %s to %s", + now, newNow); + this.now = newNow; + } + + public void advance(Duration duration) { + checkArgument( + duration.getMillis() > 0, + "Cannot move MockClock backwards in time by duration %s", + duration); + set(now.plus(duration)); + } + + @Override + public Instant now() { + return now; + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java new file mode 100644 index 000000000000..033f9de204d3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.ParDo.BoundMulti; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import 
com.google.cloud.dataflow.sdk.util.state.BagState; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.util.state.WatermarkHoldState; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for {@link ParDoMultiEvaluatorFactory}. + */ +@RunWith(JUnit4.class) +public class ParDoMultiEvaluatorFactoryTest implements Serializable { + @Test + public void testParDoMultiInMemoryTransformEvaluator() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + + TupleTag> mainOutputTag = new TupleTag>() {}; + final TupleTag elementTag = new TupleTag<>(); + final TupleTag lengthTag = new TupleTag<>(); + + BoundMulti> pardo = + ParDo.of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), c.element().length())); + c.sideOutput(elementTag, c.element()); + c.sideOutput(lengthTag, c.element().length()); + } + }).withOutputTags(mainOutputTag, TupleTagList.of(elementTag).and(lengthTag)); + PCollectionTuple outputTuple = input.apply(pardo); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + PCollection> mainOutput = outputTuple.get(mainOutputTag); + PCollection elementOutput = outputTuple.get(elementTag); + PCollection lengthOutput = outputTuple.get(lengthTag); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle> mainOutputBundle = InProcessBundle.unkeyed(mainOutput); + UncommittedBundle elementOutputBundle = InProcessBundle.unkeyed(elementOutput); + UncommittedBundle lengthOutputBundle = InProcessBundle.unkeyed(lengthOutput); + + when(evaluationContext.createBundle(inputBundle, mainOutput)).thenReturn(mainOutputBundle); + when(evaluationContext.createBundle(inputBundle, elementOutput)) + .thenReturn(elementOutputBundle); + when(evaluationContext.createBundle(inputBundle, lengthOutput)).thenReturn(lengthOutputBundle); + + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, null, null, null); + when(evaluationContext.getExecutionContext(mainOutput.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + com.google.cloud.dataflow.sdk.runners.inprocess.TransformEvaluator evaluator = + new ParDoMultiEvaluatorFactory().forApplication( + mainOutput.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + InProcessTransformResult result = 
evaluator.finishBundle(); + assertThat( + result.getOutputBundles(), + Matchers.>containsInAnyOrder( + lengthOutputBundle, mainOutputBundle, elementOutputBundle)); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + assertThat(result.getCounters(), equalTo(counters)); + + assertThat( + mainOutputBundle.commit(Instant.now()).getElements(), + Matchers.>>containsInAnyOrder( + WindowedValue.valueInGlobalWindow(KV.of("foo", 3)), + WindowedValue.timestampedValueInGlobalWindow(KV.of("bara", 4), new Instant(1000)), + WindowedValue.valueInGlobalWindow( + KV.of("bazam", 5), PaneInfo.ON_TIME_AND_ONLY_FIRING))); + assertThat( + elementOutputBundle.commit(Instant.now()).getElements(), + Matchers.>containsInAnyOrder( + WindowedValue.valueInGlobalWindow("foo"), + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000)), + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING))); + assertThat( + lengthOutputBundle.commit(Instant.now()).getElements(), + Matchers.>containsInAnyOrder( + WindowedValue.valueInGlobalWindow(3), + WindowedValue.timestampedValueInGlobalWindow(4, new Instant(1000)), + WindowedValue.valueInGlobalWindow(5, PaneInfo.ON_TIME_AND_ONLY_FIRING))); + } + + @Test + public void testParDoMultiUndeclaredSideOutput() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + + TupleTag> mainOutputTag = new TupleTag>() {}; + final TupleTag elementTag = new TupleTag<>(); + final TupleTag lengthTag = new TupleTag<>(); + + BoundMulti> pardo = + ParDo.of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), c.element().length())); + c.sideOutput(elementTag, c.element()); + c.sideOutput(lengthTag, c.element().length()); + } + }).withOutputTags(mainOutputTag, TupleTagList.of(elementTag)); + PCollectionTuple outputTuple = input.apply(pardo); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + PCollection> mainOutput = outputTuple.get(mainOutputTag); + PCollection elementOutput = outputTuple.get(elementTag); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle> mainOutputBundle = InProcessBundle.unkeyed(mainOutput); + UncommittedBundle elementOutputBundle = InProcessBundle.unkeyed(elementOutput); + + when(evaluationContext.createBundle(inputBundle, mainOutput)).thenReturn(mainOutputBundle); + when(evaluationContext.createBundle(inputBundle, elementOutput)) + .thenReturn(elementOutputBundle); + + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, null, null, null); + when(evaluationContext.getExecutionContext(mainOutput.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + com.google.cloud.dataflow.sdk.runners.inprocess.TransformEvaluator evaluator = + new ParDoMultiEvaluatorFactory().forApplication( + mainOutput.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + 
result.getOutputBundles(), + Matchers.>containsInAnyOrder( + mainOutputBundle, elementOutputBundle)); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + assertThat(result.getCounters(), equalTo(counters)); + + assertThat( + mainOutputBundle.commit(Instant.now()).getElements(), + Matchers.>>containsInAnyOrder( + WindowedValue.valueInGlobalWindow(KV.of("foo", 3)), + WindowedValue.timestampedValueInGlobalWindow(KV.of("bara", 4), new Instant(1000)), + WindowedValue.valueInGlobalWindow( + KV.of("bazam", 5), PaneInfo.ON_TIME_AND_ONLY_FIRING))); + assertThat( + elementOutputBundle.commit(Instant.now()).getElements(), + Matchers.>containsInAnyOrder( + WindowedValue.valueInGlobalWindow("foo"), + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000)), + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING))); + } + + @Test + public void finishBundleWithStatePutsStateInResult() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + + TupleTag> mainOutputTag = new TupleTag>() {}; + final TupleTag elementTag = new TupleTag<>(); + + final StateTag> watermarkTag = + StateTags.watermarkStateInternal("myId", OutputTimeFns.outputAtEndOfWindow()); + final StateTag> bagTag = StateTags.bag("myBag", StringUtf8Coder.of()); + final StateNamespace windowNs = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE); + BoundMulti> pardo = + ParDo.of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.windowingInternals() + .stateInternals() + .state(StateNamespaces.global(), watermarkTag) + .add(new Instant(20202L + c.element().length())); + c.windowingInternals() + .stateInternals() + .state( + StateNamespaces.window( + GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE), + bagTag) + .add(c.element()); + } + }) + .withOutputTags(mainOutputTag, TupleTagList.of(elementTag)); + PCollectionTuple outputTuple = input.apply(pardo); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + PCollection> mainOutput = outputTuple.get(mainOutputTag); + PCollection elementOutput = outputTuple.get(elementTag); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle> mainOutputBundle = InProcessBundle.unkeyed(mainOutput); + UncommittedBundle elementOutputBundle = InProcessBundle.unkeyed(elementOutput); + + when(evaluationContext.createBundle(inputBundle, mainOutput)).thenReturn(mainOutputBundle); + when(evaluationContext.createBundle(inputBundle, elementOutput)) + .thenReturn(elementOutputBundle); + + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, "myKey", null, null); + when(evaluationContext.getExecutionContext(mainOutput.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + com.google.cloud.dataflow.sdk.runners.inprocess.TransformEvaluator evaluator = + new ParDoMultiEvaluatorFactory().forApplication( + mainOutput.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + 
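+    // finishBundle() should expose the state written by the DoFn through the transform result,
+    // along with the watermark hold and bag contents asserted below.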
InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + result.getOutputBundles(), + Matchers.>containsInAnyOrder(mainOutputBundle, elementOutputBundle)); + assertThat(result.getWatermarkHold(), equalTo(new Instant(20205L))); + assertThat(result.getState(), not(nullValue())); + assertThat( + result.getState().state(StateNamespaces.global(), watermarkTag).read(), + equalTo(new Instant(20205L))); + assertThat( + result.getState().state(windowNs, bagTag).read(), + containsInAnyOrder("foo", "bara", "bazam")); + } + + @Test + public void finishBundleWithStateAndTimersPutsTimersInResult() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + + TupleTag> mainOutputTag = new TupleTag>() {}; + final TupleTag elementTag = new TupleTag<>(); + + final TimerData addedTimer = + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow( + new Instant(0).plus(Duration.standardMinutes(5)), + new Instant(1) + .plus(Duration.standardMinutes(5)) + .plus(Duration.standardHours(1)))), + new Instant(54541L), + TimeDomain.EVENT_TIME); + final TimerData deletedTimer = + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow(new Instant(0), new Instant(0).plus(Duration.standardHours(1)))), + new Instant(3400000), + TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + + BoundMulti> pardo = + ParDo.of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.windowingInternals().stateInternals(); + c.windowingInternals() + .timerInternals() + .setTimer( + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow( + new Instant(0).plus(Duration.standardMinutes(5)), + new Instant(1) + .plus(Duration.standardMinutes(5)) + .plus(Duration.standardHours(1)))), + new Instant(54541L), + TimeDomain.EVENT_TIME)); + c.windowingInternals() + .timerInternals() + .deleteTimer( + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow( + new Instant(0), + new Instant(0).plus(Duration.standardHours(1)))), + new Instant(3400000), + TimeDomain.SYNCHRONIZED_PROCESSING_TIME)); + } + }) + .withOutputTags(mainOutputTag, TupleTagList.of(elementTag)); + PCollectionTuple outputTuple = input.apply(pardo); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + PCollection> mainOutput = outputTuple.get(mainOutputTag); + PCollection elementOutput = outputTuple.get(elementTag); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle> mainOutputBundle = InProcessBundle.unkeyed(mainOutput); + UncommittedBundle elementOutputBundle = InProcessBundle.unkeyed(elementOutput); + + when(evaluationContext.createBundle(inputBundle, mainOutput)).thenReturn(mainOutputBundle); + when(evaluationContext.createBundle(inputBundle, elementOutput)) + .thenReturn(elementOutputBundle); + + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, "myKey", null, null); + when(evaluationContext.getExecutionContext(mainOutput.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + com.google.cloud.dataflow.sdk.runners.inprocess.TransformEvaluator evaluator = + new ParDoMultiEvaluatorFactory().forApplication( + mainOutput.getProducingTransformInternal(), inputBundle, evaluationContext); + 
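+    // Each processed element sets addedTimer and deletes deletedTimer, so the TimerUpdate
+    // produced by finishBundle() should record three identical set and three identical delete calls.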
+ evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + result.getTimerUpdate(), + equalTo( + TimerUpdate.builder("myKey") + .setTimer(addedTimer) + .setTimer(addedTimer) + .setTimer(addedTimer) + .deletedTimer(deletedTimer) + .deletedTimer(deletedTimer) + .deletedTimer(deletedTimer) + .build())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java new file mode 100644 index 000000000000..ae599bab62bc --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java @@ -0,0 +1,311 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.BagState; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import 
com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.util.state.WatermarkHoldState; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for {@link ParDoSingleEvaluatorFactory}. + */ +@RunWith(JUnit4.class) +public class ParDoSingleEvaluatorFactoryTest implements Serializable { + @Test + public void testParDoInMemoryTransformEvaluator() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + PCollection collection = input.apply(ParDo.of(new DoFn() { + @Override public void processElement(ProcessContext c) { + c.output(c.element().length()); + } + })); + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle outputBundle = + InProcessBundle.unkeyed(collection); + when(evaluationContext.createBundle(inputBundle, collection)).thenReturn(outputBundle); + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, null, null, null); + when(evaluationContext.getExecutionContext(collection.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + com.google.cloud.dataflow.sdk.runners.inprocess.TransformEvaluator evaluator = + new ParDoSingleEvaluatorFactory().forApplication( + collection.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat(result.getOutputBundles(), Matchers.>contains(outputBundle)); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + assertThat(result.getCounters(), equalTo(counters)); + + assertThat( + outputBundle.commit(Instant.now()).getElements(), + Matchers.>containsInAnyOrder( + WindowedValue.valueInGlobalWindow(3), + WindowedValue.timestampedValueInGlobalWindow(4, new Instant(1000)), + WindowedValue.valueInGlobalWindow(5, PaneInfo.ON_TIME_AND_ONLY_FIRING))); + } + + @Test + public void testSideOutputToUndeclaredSideOutputSucceeds() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + final TupleTag sideOutputTag = new TupleTag() {}; + PCollection collection = input.apply(ParDo.of(new DoFn() { + @Override public void processElement(ProcessContext c) { + c.sideOutput(sideOutputTag, c.element().length()); + } + })); + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle outputBundle = + InProcessBundle.unkeyed(collection); + 
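+    // The DoFn writes only to an undeclared side output tag; the evaluator should still finish
+    // the bundle successfully instead of failing on the undeclared tag.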
when(evaluationContext.createBundle(inputBundle, collection)).thenReturn(outputBundle); + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, null, null, null); + when(evaluationContext.getExecutionContext(collection.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + TransformEvaluator evaluator = + new ParDoSingleEvaluatorFactory().forApplication( + collection.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + result.getOutputBundles(), Matchers.>containsInAnyOrder(outputBundle)); + assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE)); + assertThat(result.getCounters(), equalTo(counters)); + } + + @Test + public void finishBundleWithStatePutsStateInResult() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + + final StateTag> watermarkTag = + StateTags.watermarkStateInternal("myId", OutputTimeFns.outputAtEarliestInputTimestamp()); + final StateTag> bagTag = StateTags.bag("myBag", StringUtf8Coder.of()); + final StateNamespace windowNs = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE); + ParDo.Bound> pardo = + ParDo.of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.windowingInternals() + .stateInternals() + .state(StateNamespaces.global(), watermarkTag) + .add(new Instant(124443L - c.element().length())); + c.windowingInternals() + .stateInternals() + .state( + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE), + bagTag) + .add(c.element()); + } + }); + PCollection> mainOutput = input.apply(pardo); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle> mainOutputBundle = InProcessBundle.unkeyed(mainOutput); + + when(evaluationContext.createBundle(inputBundle, mainOutput)).thenReturn(mainOutputBundle); + + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, "myKey", null, null); + when(evaluationContext.getExecutionContext(mainOutput.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + com.google.cloud.dataflow.sdk.runners.inprocess.TransformEvaluator evaluator = + new ParDoSingleEvaluatorFactory() + .forApplication( + mainOutput.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + evaluator.processElement( + WindowedValue.timestampedValueInGlobalWindow("bara", new Instant(1000))); + evaluator.processElement( + WindowedValue.valueInGlobalWindow("bazam", PaneInfo.ON_TIME_AND_ONLY_FIRING)); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat(result.getWatermarkHold(), equalTo(new Instant(124438L))); + assertThat(result.getState(), not(nullValue())); + 
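+    // The DoFn held the watermark at (124443 - element length); with outputAtEarliestInputTimestamp()
+    // the earliest hold, 124443 - 5 = 124438 (from "bazam"), is what the result should report.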
assertThat( + result.getState().state(StateNamespaces.global(), watermarkTag).read(), + equalTo(new Instant(124438L))); + assertThat( + result.getState().state(windowNs, bagTag).read(), + containsInAnyOrder("foo", "bara", "bazam")); + } + + @Test + public void finishBundleWithStateAndTimersPutsTimersInResult() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bara", "bazam")); + + final TimerData addedTimer = + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow( + new Instant(0).plus(Duration.standardMinutes(5)), + new Instant(1) + .plus(Duration.standardMinutes(5)) + .plus(Duration.standardHours(1)))), + new Instant(54541L), + TimeDomain.EVENT_TIME); + final TimerData deletedTimer = + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow(new Instant(0), new Instant(0).plus(Duration.standardHours(1)))), + new Instant(3400000), + TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + + ParDo.Bound> pardo = + ParDo.of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.windowingInternals().stateInternals(); + c.windowingInternals() + .timerInternals() + .setTimer( + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow( + new Instant(0).plus(Duration.standardMinutes(5)), + new Instant(1) + .plus(Duration.standardMinutes(5)) + .plus(Duration.standardHours(1)))), + new Instant(54541L), + TimeDomain.EVENT_TIME)); + c.windowingInternals() + .timerInternals() + .deleteTimer( + TimerData.of( + StateNamespaces.window( + IntervalWindow.getCoder(), + new IntervalWindow( + new Instant(0), + new Instant(0).plus(Duration.standardHours(1)))), + new Instant(3400000), + TimeDomain.SYNCHRONIZED_PROCESSING_TIME)); + } + }); + PCollection> mainOutput = input.apply(pardo); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + UncommittedBundle> mainOutputBundle = InProcessBundle.unkeyed(mainOutput); + + when(evaluationContext.createBundle(inputBundle, mainOutput)).thenReturn(mainOutputBundle); + + InProcessExecutionContext executionContext = + new InProcessExecutionContext(null, "myKey", null, null); + when(evaluationContext.getExecutionContext(mainOutput.getProducingTransformInternal(), null)) + .thenReturn(executionContext); + CounterSet counters = new CounterSet(); + when(evaluationContext.createCounterSet()).thenReturn(counters); + + TransformEvaluator evaluator = + new ParDoSingleEvaluatorFactory() + .forApplication( + mainOutput.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInGlobalWindow("foo")); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + result.getTimerUpdate(), + equalTo( + TimerUpdate.builder("myKey") + .setTimer(addedTimer) + .deletedTimer(deletedTimer) + .build())); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java new file mode 100644 index 000000000000..f139c5648e95 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.io.CountingSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Matchers; +import org.joda.time.DateTime; +import org.joda.time.Instant; +import org.joda.time.ReadableInstant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +/** + * Tests for {@link UnboundedReadEvaluatorFactory}. + */ +@RunWith(JUnit4.class) +public class UnboundedReadEvaluatorFactoryTest { + private PCollection longs; + private TransformEvaluatorFactory factory; + private InProcessEvaluationContext context; + private UncommittedBundle output; + + @Before + public void setup() { + UnboundedSource source = + CountingSource.unboundedWithTimestampFn(new LongToInstantFn()); + TestPipeline p = TestPipeline.create(); + longs = p.apply(Read.from(source)); + + factory = new UnboundedReadEvaluatorFactory(); + context = mock(InProcessEvaluationContext.class); + output = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(longs)).thenReturn(output); + } + + @Test + public void unboundedSourceInMemoryTransformEvaluatorProducesElements() throws Exception { + TransformEvaluator evaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + result.getWatermarkHold(), Matchers.lessThan(DateTime.now().toInstant())); + assertThat( + output.commit(Instant.now()).getElements(), + containsInAnyOrder( + tgw(1L), tgw(2L), tgw(4L), tgw(8L), tgw(9L), tgw(7L), tgw(6L), tgw(5L), tgw(3L), + tgw(0L))); + } + + /** + * Demonstrate that multiple sequential creations will produce additional elements if the source + * can provide them. 
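+   * The second evaluator resumes where the first left off, so it emits the next ten values
+   * (10 through 19) rather than repeating 0 through 9.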
+ */ + @Test + public void unboundedSourceInMemoryTransformEvaluatorMultipleSequentialCalls() throws Exception { + TransformEvaluator evaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + + InProcessTransformResult result = evaluator.finishBundle(); + assertThat( + result.getWatermarkHold(), Matchers.lessThan(DateTime.now().toInstant())); + assertThat( + output.commit(Instant.now()).getElements(), + containsInAnyOrder( + tgw(1L), tgw(2L), tgw(4L), tgw(8L), tgw(9L), tgw(7L), tgw(6L), tgw(5L), tgw(3L), + tgw(0L))); + + UncommittedBundle secondOutput = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(longs)).thenReturn(secondOutput); + TransformEvaluator secondEvaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + InProcessTransformResult secondResult = secondEvaluator.finishBundle(); + assertThat( + secondResult.getWatermarkHold(), + Matchers.lessThan(DateTime.now().toInstant())); + assertThat( + secondOutput.commit(Instant.now()).getElements(), + containsInAnyOrder(tgw(11L), tgw(12L), tgw(14L), tgw(18L), tgw(19L), tgw(17L), tgw(16L), + tgw(15L), tgw(13L), tgw(10L))); + } + + // TODO: Once the source is split into multiple sources before evaluating, this test will have to + // be updated. + /** + * Demonstrate that only a single unfinished instance of TransformEvaluator can be created at a + * time, with other calls returning an empty evaluator. + */ + @Test + public void unboundedSourceWithMultipleSimultaneousEvaluatorsIndependent() throws Exception { + UncommittedBundle secondOutput = InProcessBundle.unkeyed(longs); + + TransformEvaluator evaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + + TransformEvaluator secondEvaluator = + factory.forApplication(longs.getProducingTransformInternal(), null, context); + + InProcessTransformResult secondResult = secondEvaluator.finishBundle(); + InProcessTransformResult result = evaluator.finishBundle(); + + assertThat( + result.getWatermarkHold(), Matchers.lessThan(DateTime.now().toInstant())); + assertThat( + output.commit(Instant.now()).getElements(), + containsInAnyOrder( + tgw(1L), tgw(2L), tgw(4L), tgw(8L), tgw(9L), tgw(7L), tgw(6L), tgw(5L), tgw(3L), + tgw(0L))); + + assertThat(secondResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + assertThat(secondOutput.commit(Instant.now()).getElements(), emptyIterable()); + } + + /** + * A terse alias for producing timestamped longs in the {@link GlobalWindow}, where + * the timestamp is the epoch offset by the value of the element. + */ + private static WindowedValue tgw(Long elem) { + return WindowedValue.timestampedValueInGlobalWindow(elem, new Instant(elem)); + } + + private static class LongToInstantFn implements SerializableFunction { + @Override + public Instant apply(Long input) { + return new Instant(input); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java new file mode 100644 index 000000000000..2f5cd0fb888a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.Values; +import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.util.PCollectionViews; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ViewEvaluatorFactory}. 
+ */ +@RunWith(JUnit4.class) +public class ViewEvaluatorFactoryTest { + @Test + public void testInMemoryEvaluator() throws Exception { + TestPipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of("foo", "bar")); + CreatePCollectionView> createView = + CreatePCollectionView.of( + PCollectionViews.iterableView(p, input.getWindowingStrategy(), StringUtf8Coder.of())); + PCollection> concat = + input.apply(WithKeys.of((Void) null)) + .setCoder(KvCoder.of(VoidCoder.of(), StringUtf8Coder.of())) + .apply(GroupByKey.create()) + .apply(Values.>create()); + PCollectionView> view = + concat.apply(new ViewEvaluatorFactory.WriteView<>(createView)); + + InProcessEvaluationContext context = mock(InProcessEvaluationContext.class); + TestViewWriter> viewWriter = new TestViewWriter<>(); + when(context.createPCollectionViewWriter(concat, view)).thenReturn(viewWriter); + + CommittedBundle inputBundle = InProcessBundle.unkeyed(input).commit(Instant.now()); + TransformEvaluator> evaluator = + new ViewEvaluatorFactory() + .forApplication(view.getProducingTransformInternal(), inputBundle, context); + + evaluator.processElement( + WindowedValue.>valueInGlobalWindow(ImmutableList.of("foo", "bar"))); + assertThat(viewWriter.latest, nullValue()); + + evaluator.finishBundle(); + assertThat( + viewWriter.latest, + containsInAnyOrder( + WindowedValue.valueInGlobalWindow("foo"), WindowedValue.valueInGlobalWindow("bar"))); + } + + private static class TestViewWriter implements PCollectionViewWriter { + private Iterable> latest; + + @Override + public void add(Iterable> values) { + latest = values; + } + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/CoderPropertiesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/CoderPropertiesTest.java new file mode 100644 index 000000000000..4564d95fc355 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/CoderPropertiesTest.java @@ -0,0 +1,214 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.common.base.Strings; + +import org.hamcrest.CoreMatchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** Unit tests for {@link CoderProperties}. 
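+ *
+ * As a quick orientation, and using only calls that appear verbatim in the tests below, the two
+ * properties exercised most often are:
+ *
+ *   CoderProperties.coderDeterministic(StringUtf8Coder.of(), "TestData", "TestData");
+ *   CoderProperties.coderDecodeEncodeEqual(StringUtf8Coder.of(), "TestData");
+ *
+ * Both return normally for a well-behaved coder and throw an {@link AssertionError} describing the
+ * violated property otherwise, which is what the tests for the deliberately broken coders expect.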
*/ +@RunWith(JUnit4.class) +public class CoderPropertiesTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testGoodCoderIsDeterministic() throws Exception { + CoderProperties.coderDeterministic(StringUtf8Coder.of(), "TestData", "TestData"); + } + + /** A coder that says it is not deterministic but actually is. */ + public static class NonDeterministicCoder extends CustomCoder { + @Override + public void encode(String value, OutputStream outStream, Context context) + throws CoderException, IOException { + StringUtf8Coder.of().encode(value, outStream, context); + } + + @Override + public String decode(InputStream inStream, Context context) + throws CoderException, IOException { + return StringUtf8Coder.of().decode(inStream, context); + } + } + + @Test + public void testNonDeterministicCoder() throws Exception { + try { + CoderProperties.coderDeterministic(new NonDeterministicCoder(), "TestData", "TestData"); + fail("Expected AssertionError"); + } catch (AssertionError error) { + assertThat(error.getMessage(), + CoreMatchers.containsString("Expected that the coder is deterministic")); + } + } + + @Test + public void testPassingInNonEqualValuesWithDeterministicCoder() throws Exception { + try { + CoderProperties.coderDeterministic(StringUtf8Coder.of(), "AAA", "BBB"); + fail("Expected AssertionError"); + } catch (AssertionError error) { + assertThat(error.getMessage(), + CoreMatchers.containsString("Expected that the passed in values")); + } + } + + /** A coder that is non-deterministic because it adds a string to the value. */ + private static class BadDeterminsticCoder extends CustomCoder { + public BadDeterminsticCoder() { + } + + @Override + public void encode(String value, OutputStream outStream, Context context) + throws IOException, CoderException { + StringUtf8Coder.of().encode(value + System.nanoTime(), outStream, context); + } + + @Override + public String decode(InputStream inStream, Context context) + throws CoderException, IOException { + return StringUtf8Coder.of().decode(inStream, context); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { } + } + + @Test + public void testBadCoderIsNotDeterministic() throws Exception { + try { + CoderProperties.coderDeterministic(new BadDeterminsticCoder(), "TestData", "TestData"); + fail("Expected AssertionError"); + } catch (AssertionError error) { + assertThat(error.getMessage(), + CoreMatchers.containsString("<84>, <101>, <115>, <116>, <68>")); + } + } + + @Test + public void testGoodCoderEncodesEqualValues() throws Exception { + CoderProperties.coderDecodeEncodeEqual(StringUtf8Coder.of(), "TestData"); + } + + /** This coder changes state during encoding/decoding. 
*/ + private static class StateChangingSerializingCoder extends CustomCoder { + private int changedState; + + public StateChangingSerializingCoder() { + changedState = 10; + } + + @Override + public void encode(String value, OutputStream outStream, Context context) + throws CoderException, IOException { + changedState += 1; + StringUtf8Coder.of().encode(value + Strings.repeat("A", changedState), outStream, context); + } + + @Override + public String decode(InputStream inStream, Context context) + throws CoderException, IOException { + String decodedValue = StringUtf8Coder.of().decode(inStream, context); + return decodedValue.substring(0, decodedValue.length() - changedState); + } + } + + @Test + public void testBadCoderThatDependsOnChangingState() throws Exception { + try { + CoderProperties.coderDecodeEncodeEqual(new StateChangingSerializingCoder(), "TestData"); + fail("Expected AssertionError"); + } catch (AssertionError error) { + assertThat(error.getMessage(), CoreMatchers.containsString("TestData")); + } + } + + /** This coder loses information critical to its operation. */ + private static class ForgetfulSerializingCoder extends CustomCoder { + private transient int lostState; + + public ForgetfulSerializingCoder(int lostState) { + this.lostState = lostState; + } + + @Override + public void encode(String value, OutputStream outStream, Context context) + throws CoderException, IOException { + if (lostState == 0) { + throw new RuntimeException("I forgot something..."); + } + StringUtf8Coder.of().encode(value, outStream, context); + } + + @Override + public String decode(InputStream inStream, Context context) + throws CoderException, IOException { + return StringUtf8Coder.of().decode(inStream, context); + } + } + + @Test + public void testBadCoderThatDependsOnStateThatIsLost() throws Exception { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("I forgot something..."); + CoderProperties.coderDecodeEncodeEqual(new ForgetfulSerializingCoder(1), "TestData"); + } + + /** A coder which closes the underlying stream during encoding and decoding. */ + public static class ClosingCoder extends CustomCoder { + @Override + public void encode(String value, OutputStream outStream, Context context) throws IOException { + outStream.close(); + } + + @Override + public String decode(InputStream inStream, Context context) throws IOException { + inStream.close(); + return null; + } + } + + @Test + public void testClosingCoderFailsWhenDecoding() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderProperties.decode(new ClosingCoder(), Context.NESTED, new byte[0]); + } + + @Test + public void testClosingCoderFailsWhenEncoding() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderProperties.encode(new ClosingCoder(), Context.NESTED, "test-value"); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/DataflowAssertTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/DataflowAssertTest.java new file mode 100644 index 000000000000..2cd3014b110b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/DataflowAssertTest.java @@ -0,0 +1,326 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.anything; +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.not; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * Test case for {@link DataflowAssert}. + */ +@RunWith(JUnit4.class) +public class DataflowAssertTest implements Serializable { + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + private static class NotSerializableObject { + + @Override + public boolean equals(Object other) { + return (other instanceof NotSerializableObject); + } + + @Override + public int hashCode() { + return 73; + } + } + + private static class NotSerializableObjectCoder extends AtomicCoder { + private NotSerializableObjectCoder() { } + private static final NotSerializableObjectCoder INSTANCE = new NotSerializableObjectCoder(); + + @JsonCreator + public static NotSerializableObjectCoder of() { + return INSTANCE; + } + + @Override + public void encode(NotSerializableObject value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public NotSerializableObject decode(InputStream inStream, Context context) + throws CoderException, IOException { + return new NotSerializableObject(); + } + + @Override + public boolean isRegisterByteSizeObserverCheap(NotSerializableObject value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + NotSerializableObject value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(0L); + } + } + + /** + * A {@link DataflowAssert} about the contents of a {@link PCollection} + * must not require the contents of the {@link PCollection} to be + * serializable. 
+ */ + @Test + @Category(RunnableOnService.class) + public void testContainsInAnyOrderNotSerializable() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection pcollection = pipeline + .apply(Create.of( + new NotSerializableObject(), + new NotSerializableObject()) + .withCoder(NotSerializableObjectCoder.of())); + + DataflowAssert.that(pcollection).containsInAnyOrder( + new NotSerializableObject(), + new NotSerializableObject()); + + pipeline.run(); + } + + /** + * A {@link DataflowAssert} about the contents of a {@link PCollection} + * is allows to be verified by an arbitrary {@link SerializableFunction}, + * though. + */ + @Test + @Category(RunnableOnService.class) + public void testSerializablePredicate() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection pcollection = pipeline + .apply(Create.of( + new NotSerializableObject(), + new NotSerializableObject()) + .withCoder(NotSerializableObjectCoder.of())); + + DataflowAssert.that(pcollection).satisfies( + new SerializableFunction, Void>() { + @Override + public Void apply(Iterable contents) { + return null; // no problem! + } + }); + + pipeline.run(); + } + + /** + * Basic test of succeeding {@link DataflowAssert} using a {@link SerializableMatcher}. + */ + @Test + @Category(RunnableOnService.class) + public void testBasicMatcherSuccess() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(42)); + DataflowAssert.that(pcollection).containsInAnyOrder(anything()); + pipeline.run(); + } + + /** + * Basic test of failing {@link DataflowAssert} using a {@link SerializableMatcher}. + */ + @Test + @Category(RunnableOnService.class) + public void testBasicMatcherFailure() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(42)); + DataflowAssert.that(pcollection).containsInAnyOrder(not(anything())); + runExpectingAssertionFailure(pipeline); + } + + /** + * Test that we throw an error at pipeline construction time when the user mistakenly uses + * {@code DataflowAssert.thatSingleton().equals()} instead of the test method {@code .isEqualTo}. + */ + @SuppressWarnings("deprecation") // test of deprecated function + @Test + public void testDataflowAssertEqualsSingletonUnsupported() throws Exception { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage("isEqualTo"); + + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(42)); + DataflowAssert.thatSingleton(pcollection).equals(42); + } + + /** + * Test that we throw an error at pipeline construction time when the user mistakenly uses + * {@code DataflowAssert.that().equals()} instead of the test method {@code .containsInAnyOrder}. + */ + @SuppressWarnings("deprecation") // test of deprecated function + @Test + public void testDataflowAssertEqualsIterableUnsupported() throws Exception { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage("containsInAnyOrder"); + + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(42)); + DataflowAssert.that(pcollection).equals(42); + } + + /** + * Test that {@code DataflowAssert.thatSingleton().hashCode()} is unsupported. + * See {@link #testDataflowAssertEqualsSingletonUnsupported}. 
+ */ + @SuppressWarnings("deprecation") // test of deprecated function + @Test + public void testDataflowAssertHashCodeSingletonUnsupported() throws Exception { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage(".hashCode() is not supported."); + + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(42)); + DataflowAssert.thatSingleton(pcollection).hashCode(); + } + + /** + * Test that {@code DataflowAssert.thatIterable().hashCode()} is unsupported. + * See {@link #testDataflowAssertEqualsIterableUnsupported}. + */ + @SuppressWarnings("deprecation") // test of deprecated function + @Test + public void testDataflowAssertHashCodeIterableUnsupported() throws Exception { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage(".hashCode() is not supported."); + + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(42)); + DataflowAssert.that(pcollection).hashCode(); + } + + /** + * Basic test for {@code isEqualTo}. + */ + @Test + @Category(RunnableOnService.class) + public void testIsEqualTo() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(43)); + DataflowAssert.thatSingleton(pcollection).isEqualTo(43); + pipeline.run(); + } + + /** + * Basic test for {@code notEqualTo}. + */ + @Test + @Category(RunnableOnService.class) + public void testNotEqualTo() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(43)); + DataflowAssert.thatSingleton(pcollection).notEqualTo(42); + pipeline.run(); + } + + /** + * Tests that {@code containsInAnyOrder} is actually order-independent. + */ + @Test + @Category(RunnableOnService.class) + public void testContainsInAnyOrder() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection pcollection = pipeline.apply(Create.of(1, 2, 3, 4)); + DataflowAssert.that(pcollection).containsInAnyOrder(2, 1, 4, 3); + pipeline.run(); + } + + /** + * Tests that {@code containsInAnyOrder} fails when and how it should. + */ + @Test + @Category(RunnableOnService.class) + public void testContainsInAnyOrderFalse() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection pcollection = pipeline + .apply(Create.of(1, 2, 3, 4)); + + DataflowAssert.that(pcollection).containsInAnyOrder(2, 1, 4, 3, 7); + + // The service runner does not give an exception we can usefully inspect. + @Nullable + Throwable exc = runExpectingAssertionFailure(pipeline); + Pattern expectedPattern = Pattern.compile( + "Expected: iterable over \\[((<4>|<7>|<3>|<2>|<1>)(, )?){5}\\] in any order"); + if (exc != null) { + // A loose pattern, but should get the job done. + assertTrue("Expected error message from DataflowAssert with substring matching " + + expectedPattern + " but the message was \"" + exc.getMessage() + "\"", + expectedPattern.matcher(exc.getMessage()).find()); + } + } + + private static Throwable runExpectingAssertionFailure(Pipeline pipeline) { + // Even though this test will succeed or fail adequately whether local or on the service, + // it results in a different exception depending on the runner. + if (pipeline.getRunner() instanceof DirectPipelineRunner) { + // We cannot use thrown.expect(AssertionError.class) because the AssertionError + // is first caught by JUnit and causes a test failure. 
+ try { + pipeline.run(); + } catch (AssertionError exc) { + return exc; + } + } else if (pipeline.getRunner() instanceof TestDataflowPipelineRunner) { + // Separately, if this is run on the service, then the TestDataflowPipelineRunner throws + // an IllegalStateException with a basic message. + try { + pipeline.run(); + } catch (IllegalStateException exc) { + assertThat(exc.getMessage(), containsString("The dataflow failed.")); + return null; + } + } + fail("assertion should have failed"); + throw new RuntimeException("unreachable"); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/DataflowJUnitTestRunner.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/DataflowJUnitTestRunner.java new file mode 100644 index 000000000000..c5b5fcacb6b6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/DataflowJUnitTestRunner.java @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.common.base.Predicate; +import com.google.common.collect.Iterables; +import com.google.common.reflect.ClassPath; +import com.google.common.reflect.ClassPath.ClassInfo; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.runner.JUnitCore; +import org.junit.runner.Request; +import org.junit.runner.Result; +import org.junit.runner.notification.Failure; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * A test runner which can invoke a series of method or class test targets configuring + * the {@link TestPipeline} to run against the Dataflow service based upon command-line + * parameters specified. + * + *
Supported target definitions as command line parameters are:
+ *   • Class: "ClassName"
+ *   • Method: "ClassName#methodName"
+ * Multiple parameters can be specified in sequence, which will cause the test
+ * runner to invoke the tests in the specified order.
+ *
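+ * As a rough illustration (the {@code --test} flag spelling is assumed from the usual mapping of
+ * the {@code getTest()} option below onto a command-line flag; the patch itself does not spell it
+ * out), both target forms can be mixed in one invocation:
+ *
+ *   DataflowJUnitTestRunner.main(
+ *       "--test=DataflowAssertTest",
+ *       "--test=DataflowAssertTest#testIsEqualTo");
+ *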
    All tests will be executed after which, if any test had failed, the runner + * will exit with a non-zero status code. + */ +public class DataflowJUnitTestRunner { + private static final Logger LOG = LoggerFactory.getLogger(DataflowJUnitTestRunner.class); + + /** + * Options which control a Dataflow JUnit test runner to invoke + * a series of method and/or class targets. + */ + @Description("Options which control a Dataflow JUnit test runner to invoke " + + "a series of method and/or class targets.") + public interface Options extends PipelineOptions { + @Description("A list of supported test targets. Supported targets are 'ClassName' " + + "or 'ClassName#MethodName'") + @Validation.Required + String[] getTest(); + void setTest(String[] values); + } + + public static void main(String ... args) throws Exception { + PipelineOptionsFactory.register(Options.class); + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Set classes = + ClassPath.from(ClassLoader.getSystemClassLoader()).getAllClasses(); + + // Build a list of requested test targets + List requests = new ArrayList<>(); + for (String testTarget : options.getTest()) { + if (testTarget.contains("#")) { + String[] parts = testTarget.split("#", 2); + Class klass = findClass(parts[0], classes); + requests.add(Request.method(klass, parts[1])); + } else { + requests.add(Request.aClass(findClass(testTarget, classes))); + } + } + + // Set system properties required by TestPipeline so that it is able to execute tests + // on the service. + String dataflowPipelineOptions = new ObjectMapper().writeValueAsString(args); + System.setProperty("runIntegrationTestOnService", "true"); + System.setProperty("dataflowOptions", dataflowPipelineOptions); + + // Run the set of tests + boolean success = true; + JUnitCore core = new JUnitCore(); + for (Request request : requests) { + Result result = core.run(request); + if (!result.wasSuccessful()) { + for (Failure failure : result.getFailures()) { + LOG.error(failure.getTestHeader(), failure.getException()); + } + success = false; + } + } + if (!success) { + throw new IllegalStateException("Tests failed, check output logs for details."); + } + } + + private static final Class findClass( + final String simpleName, Set classes) + throws ClassNotFoundException { + Iterable matches = + Iterables.filter(classes, new Predicate() { + @Override + public boolean apply(@Nullable ClassInfo input) { + return input != null && simpleName.equals(input.getSimpleName()); + } + }); + return Class.forName(Iterables.getOnlyElement(matches).getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogs.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogs.java new file mode 100644 index 000000000000..48449b340044 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogs.java @@ -0,0 +1,306 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.fail; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +import java.util.Collection; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.logging.Formatter; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import java.util.logging.SimpleFormatter; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * This {@link TestRule} enables the ability to capture JUL logging events during test execution and + * assert expectations that they contain certain messages (with or without {@link Throwable}) at + * certain log levels. For logs generated via the SLF4J logging frontend, the JUL backend must be + * used. + */ +public class ExpectedLogs extends ExternalResource { + /** + * Returns a {@link TestRule} that captures logs for the given logger name. + * + * @param name The logger name to capture logs for. + * @return A {@link ExpectedLogs} test rule. + */ + public static ExpectedLogs none(String name) { + return new ExpectedLogs(name); + } + + /** + * Returns a {@link TestRule} that captures logs for the given class. + * + * @param klass The class to capture logs for. + * @return A {@link ExpectedLogs} test rule. + */ + public static ExpectedLogs none(Class klass) { + return ExpectedLogs.none(klass.getName()); + } + + /** + * Verify a logging event at the trace level with the given message. + * + * @param substring The message to match against. + */ + public void verifyTrace(String substring) { + verify(Level.FINEST, substring); + } + + /** + * Verify a logging event at the trace level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. + */ + public void verifyTrace(String substring, Throwable t) { + verify(Level.FINEST, substring, t); + } + + /** + * Verify a logging event at the debug level with the given message. + * + * @param substring The message to match against. + */ + public void verifyDebug(String substring) { + verify(Level.FINE, substring); + } + + /** + * Verify a logging event at the debug level with the given message and throwable. + * + * @param message The message to match against. + * @param t The throwable to match against. + */ + public void verifyDebug(String message, Throwable t) { + verify(Level.FINE, message, t); + } + + /** + * Verify a logging event at the info level with the given message. + * + * @param substring The message to match against. + */ + public void verifyInfo(String substring) { + verify(Level.INFO, substring); + } + + /** + * Verify a logging event at the info level with the given message and throwable. + * + * @param message The message to match against. + * @param t The throwable to match against. + */ + public void verifyInfo(String message, Throwable t) { + verify(Level.INFO, message, t); + } + + /** + * Verify a logging event at the warn level with the given message. + * + * @param substring The message to match against. + */ + public void verifyWarn(String substring) { + verify(Level.WARNING, substring); + } + + /** + * Verify a logging event at the warn level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. 
+ */ + public void verifyWarn(String substring, Throwable t) { + verify(Level.WARNING, substring, t); + } + + /** + * Verify a logging event at the error level with the given message. + * + * @param substring The message to match against. + */ + public void verifyError(String substring) { + verify(Level.SEVERE, substring); + } + + /** + * Verify a logging event at the error level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. + */ + public void verifyError(String substring, Throwable t) { + verify(Level.SEVERE, substring, t); + } + + /** + * Verify there are no logging events with messages containing the given substring. + * + * @param substring The message to match against. + */ + public void verifyNotLogged(String substring) { + verifyNotLogged(matcher(substring)); + } + + /** + * Verify there is no logging event at the error level with the given message and throwable. + * + * @param substring The message to match against. + * @param t The throwable to match against. + */ + public void verifyNoError(String substring, Throwable t) { + verifyNo(Level.SEVERE, substring, t); + } + + private void verify(final Level level, final String substring) { + verifyLogged(matcher(level, substring)); + } + + private TypeSafeMatcher matcher(final String substring) { + return new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText(String.format("log message containing message [%s]", substring)); + } + + @Override + protected boolean matchesSafely(LogRecord item) { + return item.getMessage().contains(substring); + } + }; + } + + private TypeSafeMatcher matcher(final Level level, final String substring) { + return new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText(String.format( + "log message of level [%s] containing message [%s]", level, substring)); + } + + @Override + protected boolean matchesSafely(LogRecord item) { + return level.equals(item.getLevel()) + && item.getMessage().contains(substring); + } + }; + } + + private void verify(final Level level, final String substring, final Throwable throwable) { + verifyLogged(matcher(level, substring, throwable)); + } + + private void verifyNo(final Level level, final String substring, final Throwable throwable) { + verifyNotLogged(matcher(level, substring, throwable)); + } + + private TypeSafeMatcher matcher( + final Level level, final String substring, final Throwable throwable) { + return new TypeSafeMatcher() { + @Override + public void describeTo(Description description) { + description.appendText(String.format( + "log message of level [%s] containg message [%s] with exception [%s] " + + "containing message [%s]", + level, substring, throwable.getClass(), throwable.getMessage())); + } + + @Override + protected boolean matchesSafely(LogRecord item) { + return level.equals(item.getLevel()) + && item.getMessage().contains(substring) + && item.getThrown().getClass().equals(throwable.getClass()) + && item.getThrown().getMessage().contains(throwable.getMessage()); + } + }; + } + + private void verifyLogged(Matcher matcher) { + for (LogRecord record : logSaver.getLogs()) { + if (matcher.matches(record)) { + return; + } + } + + fail(String.format("Missing match for [%s]", matcher)); + } + + private void verifyNotLogged(Matcher matcher) { + // Don't use Matchers.everyItem(Matchers.not(matcher)) because it doesn't format the logRecord + for (LogRecord record 
: logSaver.getLogs()) { + if (matcher.matches(record)) { + fail(String.format("Unexpected match of [%s]: [%s]", matcher, logFormatter.format(record))); + } + } + } + + @Override + protected void before() throws Throwable { + previousLevel = log.getLevel(); + log.setLevel(Level.ALL); + log.addHandler(logSaver); + } + + @Override + protected void after() { + log.removeHandler(logSaver); + log.setLevel(previousLevel); + } + + private final Logger log; + private final LogSaver logSaver; + private final Formatter logFormatter = new SimpleFormatter(); + private Level previousLevel; + + private ExpectedLogs(String name) { + log = Logger.getLogger(name); + logSaver = new LogSaver(); + } + + /** + * A JUL logging {@link Handler} that records all logging events that are passed to it. + */ + @ThreadSafe + private static class LogSaver extends Handler { + Collection logRecords = new ConcurrentLinkedDeque<>(); + + public Collection getLogs() { + return logRecords; + } + + @Override + public void publish(LogRecord record) { + logRecords.add(record); + } + + @Override + public void flush() {} + + @Override + public void close() throws SecurityException {} + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogsTest.java new file mode 100644 index 000000000000..2dce880ffc39 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ExpectedLogsTest.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static com.google.cloud.dataflow.sdk.testing.SystemNanoTimeSleeper.sleepMillis; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +/** Tests for {@link FastNanoClockAndSleeper}. 
*/ +@RunWith(JUnit4.class) +public class ExpectedLogsTest { + private static final Logger LOG = LoggerFactory.getLogger(ExpectedLogsTest.class); + + private Random random = new Random(); + + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(ExpectedLogsTest.class); + + @Test + public void testWhenNoExpectations() throws Throwable { + LOG.error(generateRandomString()); + } + + @Test + public void testVerifyWhenMatchedFully() throws Throwable { + String expected = generateRandomString(); + + LOG.error(expected); + expectedLogs.verifyError(expected); + } + + @Test + public void testVerifyWhenMatchedPartially() throws Throwable { + String expected = generateRandomString(); + LOG.error("Extra stuff around expected " + expected + " blah"); + expectedLogs.verifyError(expected); + } + + @Test + public void testVerifyWhenMatchedWithExceptionBeingLogged() throws Throwable { + String expected = generateRandomString(); + LOG.error(expected, new IOException("Fake Exception")); + expectedLogs.verifyError(expected); + } + + @Test(expected = AssertionError.class) + public void testVerifyWhenNotMatched() throws Throwable { + String expected = generateRandomString(); + + expectedLogs.verifyError(expected); + } + + @Test(expected = AssertionError.class) + public void testVerifyNotLoggedWhenMatchedFully() throws Throwable { + String expected = generateRandomString(); + + LOG.error(expected); + expectedLogs.verifyNotLogged(expected); + } + + @Test(expected = AssertionError.class) + public void testVerifyNotLoggedWhenMatchedPartially() throws Throwable { + String expected = generateRandomString(); + LOG.error("Extra stuff around expected " + expected + " blah"); + expectedLogs.verifyNotLogged(expected); + } + + @Test(expected = AssertionError.class) + public void testVerifyNotLoggedWhenMatchedWithException() throws Throwable { + String expected = generateRandomString(); + LOG.error(expected, new IOException("Fake Exception")); + expectedLogs.verifyNotLogged(expected); + } + + @Test + public void testVerifyNotLoggedWhenNotMatched() throws Throwable { + String expected = generateRandomString(); + expectedLogs.verifyNotLogged(expected); + } + + @Test + public void testLogCaptureOccursAtLowestLogLevel() throws Throwable { + String expected = generateRandomString(); + LOG.trace(expected); + expectedLogs.verifyTrace(expected); + } + + @Test + public void testThreadSafetyOfLogSaver() throws Throwable { + CompletionService completionService = + new ExecutorCompletionService<>(Executors.newCachedThreadPool()); + final long scheduledLogTime = + TimeUnit.MILLISECONDS.convert(System.nanoTime(), TimeUnit.NANOSECONDS) + 500L; + + List expectedStrings = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + final String expected = generateRandomString(); + expectedStrings.add(expected); + completionService.submit(new Callable() { + @Override + public Void call() throws Exception { + // Have all threads started and waiting to log at about the same moment. + sleepMillis(Math.max(1, scheduledLogTime + - TimeUnit.MILLISECONDS.convert(System.nanoTime(), TimeUnit.NANOSECONDS))); + LOG.trace(expected); + return null; + } + }); + } + + // Wait for all the threads to complete. + for (int i = 0; i < 100; i++) { + completionService.take(); + } + + for (String expected : expectedStrings) { + expectedLogs.verifyTrace(expected); + } + } + + // Generates a random fake error message. 
+ private String generateRandomString() { + return "Fake error message: " + random.nextInt(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeper.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeper.java new file mode 100644 index 000000000000..944795673399 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeper.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; + +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +/** + * This object quickly moves time forward based upon how much it has been asked to sleep, + * without actually sleeping, to simulate the backoff. + */ +public class FastNanoClockAndSleeper extends ExternalResource + implements NanoClock, Sleeper, TestRule { + private long fastNanoTime; + + @Override + public long nanoTime() { + return fastNanoTime; + } + + @Override + protected void before() throws Throwable { + fastNanoTime = NanoClock.SYSTEM.nanoTime(); + } + + @Override + public void sleep(long millis) throws InterruptedException { + fastNanoTime += millis * 1000000L; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeperTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeperTest.java new file mode 100644 index 000000000000..76256962ad4d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/FastNanoClockAndSleeperTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.concurrent.TimeUnit; + +/** Tests for {@link FastNanoClockAndSleeper}. 
*/ +@RunWith(JUnit4.class) +public class FastNanoClockAndSleeperTest { + @Rule public FastNanoClockAndSleeper fastNanoClockAndSleeper = new FastNanoClockAndSleeper(); + + @Test + public void testClockAndSleeper() throws Exception { + long sleepTimeMs = TimeUnit.SECONDS.toMillis(30); + long sleepTimeNano = TimeUnit.MILLISECONDS.toNanos(sleepTimeMs); + long fakeTimeNano = fastNanoClockAndSleeper.nanoTime(); + long startTimeNano = System.nanoTime(); + fastNanoClockAndSleeper.sleep(sleepTimeMs); + long maxTimeNano = startTimeNano + TimeUnit.SECONDS.toNanos(1); + // Verify that actual time didn't progress as much as was requested + assertTrue(System.nanoTime() < maxTimeNano); + // Verify that the fake time did go up by the amount requested + assertEquals(fakeTimeNano + sleepTimeNano, fastNanoClockAndSleeper.nanoTime()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/PCollectionViewTesting.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/PCollectionViewTesting.java new file mode 100644 index 000000000000..e6555c056b96 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/PCollectionViewTesting.java @@ -0,0 +1,295 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValueBase; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Function; +import com.google.common.base.MoreObjects; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.List; +import java.util.Objects; + +/** + * Methods for creating and using {@link PCollectionView} instances. + */ +public final class PCollectionViewTesting { + + // Do not instantiate; static methods only + private PCollectionViewTesting() { } + + /** + * The length of the default window, which is an {@link IntervalWindow}, but kept encapsulated + * as it is not for the user to know what sort of window it is. + */ + private static final long DEFAULT_WINDOW_MSECS = 1000 * 60 * 60; + + /** + * A default windowing strategy. Tests that are not concerned with the windowing + * strategy should not specify it, and all views will use this. 
+ */ + public static final WindowingStrategy DEFAULT_WINDOWING_STRATEGY = + WindowingStrategy.of(FixedWindows.of(new Duration(DEFAULT_WINDOW_MSECS))); + + /** + * A default window into which test elements will be placed, if the window is + * not explicitly overridden. + */ + public static final BoundedWindow DEFAULT_NONEMPTY_WINDOW = + new IntervalWindow(new Instant(0), new Instant(DEFAULT_WINDOW_MSECS)); + + /** + * A timestamp in the {@link #DEFAULT_NONEMPTY_WINDOW}. + */ + public static final Instant DEFAULT_TIMESTAMP = DEFAULT_NONEMPTY_WINDOW.maxTimestamp().minus(1); + + /** + * A window into which no element will be placed by methods in this class, unless explicitly + * requested. + */ + public static final BoundedWindow DEFAULT_EMPTY_WINDOW = new IntervalWindow( + DEFAULT_NONEMPTY_WINDOW.maxTimestamp(), + DEFAULT_NONEMPTY_WINDOW.maxTimestamp().plus(DEFAULT_WINDOW_MSECS)); + + /** + * A specialization of {@link SerializableFunction} just for putting together + * {@link PCollectionView} instances. + */ + public static interface ViewFn + extends SerializableFunction>, ViewT> { } + + /** + * A {@link ViewFn} that returns the provided contents as a fully lazy iterable. + */ + public static class IdentityViewFn implements ViewFn> { + @Override + public Iterable apply(Iterable> contents) { + return Iterables.transform(contents, new Function, T>() { + @Override + public T apply(WindowedValue windowedValue) { + return windowedValue.getValue(); + } + }); + } + } + + /** + * A {@link ViewFn} that traverses the whole iterable eagerly and returns the number of elements. + * + *
    Only for use in testing scenarios with small collections. If there are more elements + * provided than {@code Integer.MAX_VALUE} then behavior is unpredictable. + */ + public static class LengthViewFn implements ViewFn { + @Override + public Long apply(Iterable> contents) { + return (long) Iterables.size(contents); + } + } + + /** + * A {@link ViewFn} that always returns the value with which it is instantiated. + */ + public static class ConstantViewFn implements ViewFn { + private ViewT value; + + public ConstantViewFn(ViewT value) { + this.value = value; + } + + @Override + public ViewT apply(Iterable> contents) { + return value; + } + } + + /** + * A {@link PCollectionView} explicitly built from a {@link TupleTag} + * and conversion {@link ViewFn}, and an element coder, using the + * {@link #DEFAULT_WINDOWING_STRATEGY}. + * + *
    This method is only recommended for use by runner implementors to test their + * implementations. It is very easy to construct a {@link PCollectionView} that does + * not respect the invariants required for proper functioning. + * + *
    Note that if the provided {@code WindowingStrategy} does not match that of the windowed + * values provided to the view during execution, results are unpredictable. It is recommended + * that the values be prepared via {@link #contentsInDefaultWindow}. + */ + public static PCollectionView testingView( + TupleTag>> tag, + ViewFn viewFn, + Coder elemCoder) { + return testingView( + tag, + viewFn, + elemCoder, + DEFAULT_WINDOWING_STRATEGY); + } + + /** + * The default {@link Coder} used for windowed values, given an element {@link Coder}. + */ + public static Coder> defaultWindowedValueCoder(Coder elemCoder) { + return WindowedValue.getFullCoder( + elemCoder, DEFAULT_WINDOWING_STRATEGY.getWindowFn().windowCoder()); + } + + /** + * A {@link PCollectionView} explicitly built from its {@link TupleTag}, + * {@link WindowingStrategy}, {@link Coder}, and conversion function. + * + *
    This method is only recommended for use by runner implementors to test their + * implementations. It is very easy to construct a {@link PCollectionView} that does + * not respect the invariants required for proper functioning. + * + *
    Note that if the provided {@code WindowingStrategy} does not match that of the windowed + * values provided to the view during execution, results are unpredictable. + */ + public static PCollectionView testingView( + TupleTag>> tag, + ViewFn viewFn, + Coder elemCoder, + WindowingStrategy windowingStrategy) { + return new PCollectionViewFromParts<>( + tag, + viewFn, + windowingStrategy, + IterableCoder.of( + WindowedValue.getFullCoder(elemCoder, windowingStrategy.getWindowFn().windowCoder()))); + } + + /** + * Places the given {@code value} in the {@link #DEFAULT_NONEMPTY_WINDOW}. + */ + public static WindowedValue valueInDefaultWindow(T value) { + return WindowedValue.of(value, DEFAULT_TIMESTAMP, DEFAULT_NONEMPTY_WINDOW, PaneInfo.NO_FIRING); + } + + /** + * Prepares {@code values} for reading as the contents of a {@link PCollectionView} side input. + */ + @SafeVarargs + public static Iterable> contentsInDefaultWindow(T... values) + throws Exception { + List> windowedValues = Lists.newArrayList(); + for (T value : values) { + windowedValues.add(valueInDefaultWindow(value)); + } + return windowedValues; + } + + /** + * Prepares {@code values} for reading as the contents of a {@link PCollectionView} side input. + */ + public static Iterable> contentsInDefaultWindow(Iterable values) + throws Exception { + List> windowedValues = Lists.newArrayList(); + for (T value : values) { + windowedValues.add(valueInDefaultWindow(value)); + } + return windowedValues; + } + + // Internal details below here + + /** + * A {@link PCollectionView} explicitly built from its {@link TupleTag}, + * {@link WindowingStrategy}, and conversion function. + * + *

    Instantiate via {@link #testingView}. + */ + private static class PCollectionViewFromParts + extends PValueBase + implements PCollectionView { + private TupleTag>> tag; + private ViewFn viewFn; + private WindowingStrategy windowingStrategy; + private Coder>> coder; + + public PCollectionViewFromParts( + TupleTag>> tag, + ViewFn viewFn, + WindowingStrategy windowingStrategy, + Coder>> coder) { + this.tag = tag; + this.viewFn = viewFn; + this.windowingStrategy = windowingStrategy; + this.coder = coder; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public TupleTag>> getTagInternal() { + return (TupleTag) tag; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public ViewT fromIterableInternal(Iterable> contents) { + return (ViewT) viewFn.apply((Iterable) contents); + } + + @Override + public WindowingStrategy getWindowingStrategyInternal() { + return windowingStrategy; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Override + public Coder>> getCoderInternal() { + return (Coder) coder; + } + + @Override + public int hashCode() { + return Objects.hash(tag); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PCollectionView)) { + return false; + } + @SuppressWarnings("unchecked") + PCollectionView otherView = (PCollectionView) other; + return tag.equals(otherView.getTagInternal()); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("tag", tag) + .add("viewFn", viewFn) + .toString(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProvider.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProvider.java new file mode 100644 index 000000000000..34c901b3941d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProvider.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import org.joda.time.DateTimeUtils; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +/** + * This {@link TestRule} resets the date time provider in Joda to the system date + * time provider after tests. 
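A brief usage sketch (illustrative only, not part of the original file, and assumed to sit in the same package as ResetDateTimeProvider): pin Joda's clock for a single test and let the rule restore the system clock afterwards. The class name and timestamp are invented.

import static org.junit.Assert.assertEquals;

import org.joda.time.DateTimeUtils;
import org.joda.time.format.ISODateTimeFormat;
import org.junit.Rule;
import org.junit.Test;

/** Illustrative sketch: running a test against a frozen Joda clock. */
public class FrozenClockSketch {
  @Rule public ResetDateTimeProvider resetClock = new ResetDateTimeProvider();

  @Test
  public void runsAgainstAFrozenClock() {
    String instant = "2015-06-01T12:00:00.000Z";
    // Freeze Joda's notion of "now"; the rule resets it to the system clock after the test.
    resetClock.setDateTimeFixed(instant);
    assertEquals(
        ISODateTimeFormat.dateTime().parseMillis(instant), DateTimeUtils.currentTimeMillis());
  }
}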
+ */ +public class ResetDateTimeProvider extends ExternalResource { + public void setDateTimeFixed(String iso8601) { + setDateTimeFixed(ISODateTimeFormat.dateTime().parseMillis(iso8601)); + } + + public void setDateTimeFixed(long millis) { + DateTimeUtils.setCurrentMillisFixed(millis); + } + + @Override + protected void after() { + DateTimeUtils.setCurrentMillisSystem(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProviderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProviderTest.java new file mode 100644 index 000000000000..294bb41c67c2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/ResetDateTimeProviderTest.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import org.joda.time.DateTimeUtils; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ResetDateTimeProvider}. */ +@RunWith(JUnit4.class) +public class ResetDateTimeProviderTest { + private static final String TEST_TIME = "2014-12-08T19:07:06.698Z"; + private static final long TEST_TIME_MS = + ISODateTimeFormat.dateTime().parseMillis(TEST_TIME); + + @Rule public ResetDateTimeProvider resetDateTimeProviderRule = new ResetDateTimeProvider(); + + /* + * Since these tests can run out of order, both test A and B change the provider + * and verify that the provider was reset. + */ + @Test + public void testResetA() { + assertNotEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + resetDateTimeProviderRule.setDateTimeFixed(TEST_TIME); + assertEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + } + + @Test + public void testResetB() { + assertNotEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + resetDateTimeProviderRule.setDateTimeFixed(TEST_TIME); + assertEquals(TEST_TIME_MS, DateTimeUtils.currentTimeMillis()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemProperties.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemProperties.java new file mode 100644 index 000000000000..03bc6a530c99 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemProperties.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.common.base.Throwables; + +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** + * Saves and restores the current system properties for tests. + */ +public class RestoreSystemProperties extends ExternalResource implements TestRule { + private byte[] originalProperties; + + @Override + protected void before() throws Throwable { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + System.getProperties().store(baos, ""); + baos.close(); + originalProperties = baos.toByteArray(); + } + + @Override + protected void after() { + try (ByteArrayInputStream bais = new ByteArrayInputStream(originalProperties)) { + System.getProperties().clear(); + System.getProperties().load(bais); + } catch (IOException e) { + throw Throwables.propagate(e); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemPropertiesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemPropertiesTest.java new file mode 100644 index 000000000000..ab49c75adc17 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/RestoreSystemPropertiesTest.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RestoreSystemProperties}. */ +@RunWith(JUnit4.class) +public class RestoreSystemPropertiesTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + + /* + * Since these tests can run out of order, both test A and B verify that they + * could insert their property and that the other does not exist. + */ + @Test + public void testThatPropertyIsClearedA() { + System.getProperties().put("TestA", "TestA"); + assertNotNull(System.getProperty("TestA")); + assertNull(System.getProperty("TestB")); + } + + @Test + public void testThatPropertyIsClearedB() { + System.getProperties().put("TestB", "TestB"); + assertNotNull(System.getProperty("TestB")); + assertNull(System.getProperty("TestA")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SerializableMatchersTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SerializableMatchersTest.java new file mode 100644 index 000000000000..1ab94c5379a0 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SerializableMatchersTest.java @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.allOf; +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.anything; +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.containsInAnyOrder; +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.kvWithKey; +import static com.google.cloud.dataflow.sdk.testing.SerializableMatchers.not; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.util.SerializableUtils; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; + +/** + * Test case for {@link SerializableMatchers}. + * + *

    Since the only new matchers are those for {@link KV}, only those are tested here, to avoid + * tediously repeating all of hamcrest's tests. + * + *

    A few wrappers of a hamcrest matchers are tested for serializability. Beyond that, + * the boilerplate that is identical to each is considered thoroughly tested. + */ +@RunWith(JUnit4.class) +public class SerializableMatchersTest implements Serializable { + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + public void testAnythingSerializable() throws Exception { + SerializableUtils.ensureSerializable(anything()); + } + + @Test + public void testAllOfSerializable() throws Exception { + SerializableUtils.ensureSerializable(allOf(anything())); + } + + @Test + public void testContainsInAnyOrderSerializable() throws Exception { + assertThat(ImmutableList.of(2, 1, 3), + SerializableUtils.ensureSerializable(containsInAnyOrder(1, 2, 3))); + } + + @Test + public void testContainsInAnyOrderNotSerializable() throws Exception { + assertThat( + ImmutableList.of(new NotSerializableClass()), + SerializableUtils.ensureSerializable(containsInAnyOrder( + new NotSerializableClassCoder(), + new NotSerializableClass()))); + } + + @Test + public void testKvKeyMatcherSerializable() throws Exception { + assertThat( + KV.of("hello", 42L), + SerializableUtils.ensureSerializable(kvWithKey("hello"))); + } + + @Test + public void testKvMatcherBasicSuccess() throws Exception { + assertThat( + KV.of(1, 2), + SerializableMatchers.kv(anything(), anything())); + } + + @Test + public void testKvMatcherKeyFailure() throws Exception { + try { + assertThat( + KV.of(1, 2), + SerializableMatchers.kv(not(anything()), anything())); + } catch (AssertionError exc) { + assertThat(exc.getMessage(), Matchers.containsString("key did not match")); + return; + } + fail("Should have failed"); + } + + @Test + public void testKvMatcherValueFailure() throws Exception { + try { + assertThat( + KV.of(1, 2), + SerializableMatchers.kv(anything(), not(anything()))); + } catch (AssertionError exc) { + assertThat(exc.getMessage(), Matchers.containsString("value did not match")); + return; + } + fail("Should have failed"); + } + + @Test + public void testKvMatcherGBKLikeSuccess() throws Exception { + assertThat( + KV.of("key", ImmutableList.of(1, 2, 3)), + SerializableMatchers.>kv( + anything(), containsInAnyOrder(3, 2, 1))); + } + + @Test + public void testKvMatcherGBKLikeFailure() throws Exception { + try { + assertThat( + KV.of("key", ImmutableList.of(1, 2, 3)), + SerializableMatchers.>kv( + anything(), containsInAnyOrder(1, 2, 3, 4))); + } catch (AssertionError exc) { + assertThat(exc.getMessage(), Matchers.containsString("value did not match")); + return; + } + fail("Should have failed."); + } + + private static class NotSerializableClass { + @Override public boolean equals(Object other) { + return other instanceof NotSerializableClass; + } + + @Override public int hashCode() { + return 0; + } + } + + private static class NotSerializableClassCoder extends AtomicCoder { + @Override + public void encode(NotSerializableClass value, OutputStream outStream, Coder.Context context) { + } + + @Override + public NotSerializableClass decode(InputStream inStream, Coder.Context context) { + return new NotSerializableClass(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SystemNanoTimeSleeper.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SystemNanoTimeSleeper.java new file mode 100644 index 000000000000..d8507f79b08f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SystemNanoTimeSleeper.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 
Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import com.google.api.client.util.Sleeper; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.LockSupport; + +/** + * This class provides an expensive sleeper to deal with issues around Java's + * accuracy of {@link System#currentTimeMillis} and methods such as + * {@link Object#wait} and {@link Thread#sleep} which depend on it. This + * article goes into further detail about this issue. + * + * This {@link Sleeper} uses {@link System#nanoTime} + * as the timing source and {@link LockSupport#parkNanos} as the wait method. + * Note that usage of this sleeper may impact performance because + * of the relatively more expensive methods being invoked when compared to + * {@link Thread#sleep}. + */ +public class SystemNanoTimeSleeper implements Sleeper { + public static final Sleeper INSTANCE = new SystemNanoTimeSleeper(); + + /** Limit visibility to prevent instantiation. */ + private SystemNanoTimeSleeper() { + } + + @Override + public void sleep(long millis) throws InterruptedException { + long currentTime; + long endTime = System.nanoTime() + TimeUnit.NANOSECONDS.convert(millis, TimeUnit.MILLISECONDS); + while ((currentTime = System.nanoTime()) < endTime) { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + LockSupport.parkNanos(endTime - currentTime); + } + if (Thread.interrupted()) { + throw new InterruptedException(); + } + return; + } + + /** + * Causes the currently executing thread to sleep (temporarily cease + * execution) for the specified number of milliseconds. The thread does not + * lose ownership of any monitors. + */ + public static void sleepMillis(long millis) throws InterruptedException { + INSTANCE.sleep(millis); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SystemNanoTimeSleeperTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SystemNanoTimeSleeperTest.java new file mode 100644 index 000000000000..33b6b693a297 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/SystemNanoTimeSleeperTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
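Returning to the SystemNanoTimeSleeper class above: since it implements the google-http-client Sleeper interface, it can stand in wherever a Sleeper is accepted. A small illustrative sketch follows (not part of the original files; the class name is invented and it is assumed to sit in the same package as the sleeper).

import com.google.api.client.util.Sleeper;

import java.util.concurrent.TimeUnit;

/** Illustrative sketch: using SystemNanoTimeSleeper through the Sleeper interface. */
public class NanoSleeperSketch {
  public static void main(String[] args) throws InterruptedException {
    Sleeper sleeper = SystemNanoTimeSleeper.INSTANCE;

    long start = System.nanoTime();
    sleeper.sleep(50); // blocks for at least 50 ms, measured against System.nanoTime()
    long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
    System.out.println("Slept for about " + elapsedMillis + " ms");
  }
}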
+ */ + +package com.google.cloud.dataflow.sdk.testing; + +import static com.google.cloud.dataflow.sdk.testing.SystemNanoTimeSleeper.sleepMillis; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link SystemNanoTimeSleeper}. */ +@RunWith(JUnit4.class) +public class SystemNanoTimeSleeperTest { + @Test + public void testSleep() throws Exception { + long startTime = System.nanoTime(); + sleepMillis(100); + long endTime = System.nanoTime(); + assertTrue(endTime - startTime >= 100); + } + + @Test + public void testNegativeSleep() throws Exception { + sleepMillis(-100); + } + + @Test(expected = InterruptedException.class) + public void testInterruptionInLoop() throws Exception { + Thread.currentThread().interrupt(); + sleepMillis(0); + } + + @Test(expected = InterruptedException.class) + public void testInterruptionOutsideOfLoop() throws Exception { + Thread.currentThread().interrupt(); + sleepMillis(-100); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunnerTest.java new file mode 100644 index 000000000000..d39ef2e4ae8b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestDataflowPipelineRunnerTest.java @@ -0,0 +1,317 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.Json; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.JobMetrics; +import com.google.api.services.dataflow.model.MetricStructuredName; +import com.google.api.services.dataflow.model.MetricUpdate; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineJob; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil; +import com.google.cloud.dataflow.sdk.util.MonitoringUtil.JobMessagesHandler; +import com.google.cloud.dataflow.sdk.util.NoopPathValidator; +import com.google.cloud.dataflow.sdk.util.TestCredential; +import com.google.cloud.dataflow.sdk.util.TimeUtil; +import com.google.cloud.dataflow.sdk.util.Transport; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; + +/** Tests for {@link TestDataflowPipelineRunner}. 
*/ +@RunWith(JUnit4.class) +public class TestDataflowPipelineRunnerTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + @Mock private MockHttpTransport transport; + @Mock private MockLowLevelHttpRequest request; + @Mock private GcsUtil mockGcsUtil; + + private TestDataflowPipelineOptions options; + private Dataflow service; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + when(transport.buildRequest(anyString(), anyString())).thenReturn(request); + doCallRealMethod().when(request).getContentAsString(); + service = new Dataflow(transport, Transport.getJsonFactory(), null); + + options = PipelineOptionsFactory.as(TestDataflowPipelineOptions.class); + options.setAppName("TestAppName"); + options.setProject("test-project"); + options.setTempLocation("gs://test/temp/location"); + options.setGcpCredential(new TestCredential()); + options.setDataflowClient(service); + options.setRunner(TestDataflowPipelineRunner.class); + options.setPathValidatorClass(NoopPathValidator.class); + } + + @Test + public void testToString() { + assertEquals("TestDataflowPipelineRunner#TestAppName", + new TestDataflowPipelineRunner(options).toString()); + } + + @Test + public void testRunBatchJobThatSucceeds() throws Exception { + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + DataflowPipelineJob mockJob = Mockito.mock(DataflowPipelineJob.class); + when(mockJob.getDataflowClient()).thenReturn(service); + when(mockJob.getState()).thenReturn(State.DONE); + when(mockJob.getProjectId()).thenReturn("test-project"); + when(mockJob.getJobId()).thenReturn("test-job"); + + DataflowPipelineRunner mockRunner = Mockito.mock(DataflowPipelineRunner.class); + when(mockRunner.run(any(Pipeline.class))).thenReturn(mockJob); + + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + when(request.execute()).thenReturn( + generateMockMetricResponse(true /* success */, true /* tentative */)); + assertEquals(mockJob, runner.run(p, mockRunner)); + } + + @Test + public void testRunBatchJobThatFails() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("The dataflow failed."); + + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + DataflowPipelineJob mockJob = Mockito.mock(DataflowPipelineJob.class); + when(mockJob.getDataflowClient()).thenReturn(service); + when(mockJob.getState()).thenReturn(State.FAILED); + when(mockJob.getProjectId()).thenReturn("test-project"); + when(mockJob.getJobId()).thenReturn("test-job"); + + DataflowPipelineRunner mockRunner = Mockito.mock(DataflowPipelineRunner.class); + when(mockRunner.run(any(Pipeline.class))).thenReturn(mockJob); + + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + runner.run(p, mockRunner); + } + + @Test + public void testRunStreamingJobThatSucceeds() throws Exception { + options.setStreaming(true); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + DataflowPipelineJob mockJob = Mockito.mock(DataflowPipelineJob.class); + when(mockJob.getDataflowClient()).thenReturn(service); + when(mockJob.getState()).thenReturn(State.RUNNING); + when(mockJob.getProjectId()).thenReturn("test-project"); + 
when(mockJob.getJobId()).thenReturn("test-job"); + + DataflowPipelineRunner mockRunner = Mockito.mock(DataflowPipelineRunner.class); + when(mockRunner.run(any(Pipeline.class))).thenReturn(mockJob); + + when(request.execute()).thenReturn( + generateMockMetricResponse(true /* success */, true /* tentative */)); + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + runner.run(p, mockRunner); + } + + @Test + public void testRunStreamingJobThatFails() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("The dataflow failed."); + + options.setStreaming(true); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + DataflowPipelineJob mockJob = Mockito.mock(DataflowPipelineJob.class); + when(mockJob.getDataflowClient()).thenReturn(service); + when(mockJob.getState()).thenReturn(State.RUNNING); + when(mockJob.getProjectId()).thenReturn("test-project"); + when(mockJob.getJobId()).thenReturn("test-job"); + + DataflowPipelineRunner mockRunner = Mockito.mock(DataflowPipelineRunner.class); + when(mockRunner.run(any(Pipeline.class))).thenReturn(mockJob); + + when(request.execute()).thenReturn( + generateMockMetricResponse(false /* success */, true /* tentative */)); + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + runner.run(p, mockRunner); + } + + @Test + public void testCheckingForSuccessWhenDataflowAssertSucceeds() throws Exception { + DataflowPipelineJob job = + spy(new DataflowPipelineJob("test-project", "test-job", service, null)); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + when(request.execute()).thenReturn( + generateMockMetricResponse(true /* success */, true /* tentative */)); + doReturn(State.DONE).when(job).getState(); + assertEquals(Optional.of(true), runner.checkForSuccess(job)); + } + + @Test + public void testCheckingForSuccessWhenDataflowAssertFails() throws Exception { + DataflowPipelineJob job = + spy(new DataflowPipelineJob("test-project", "test-job", service, null)); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + when(request.execute()).thenReturn( + generateMockMetricResponse(false /* success */, true /* tentative */)); + doReturn(State.DONE).when(job).getState(); + assertEquals(Optional.of(false), runner.checkForSuccess(job)); + } + + @Test + public void testCheckingForSuccessSkipsNonTentativeMetrics() throws Exception { + DataflowPipelineJob job = + spy(new DataflowPipelineJob("test-project", "test-job", service, null)); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + when(request.execute()).thenReturn( + generateMockMetricResponse(true /* success */, false /* tentative */)); + doReturn(State.RUNNING).when(job).getState(); + assertEquals(Optional.absent(), runner.checkForSuccess(job)); + } + + private LowLevelHttpResponse generateMockMetricResponse(boolean success, boolean tentative) + throws Exception { + 
MetricStructuredName name = new MetricStructuredName(); + name.setName(success ? "DataflowAssertSuccess" : "DataflowAssertFailure"); + name.setContext( + tentative ? ImmutableMap.of("tentative", "") : ImmutableMap.of()); + + MetricUpdate metric = new MetricUpdate(); + metric.setName(name); + metric.setScalar(BigDecimal.ONE); + + MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); + response.setContentType(Json.MEDIA_TYPE); + JobMetrics jobMetrics = new JobMetrics(); + jobMetrics.setMetrics(Lists.newArrayList(metric)); + // N.B. Setting the factory is necessary in order to get valid JSON. + jobMetrics.setFactory(Transport.getJsonFactory()); + response.setContent(jobMetrics.toPrettyString()); + return response; + } + + @Test + public void testStreamingPipelineFailsIfServiceFails() throws Exception { + DataflowPipelineJob job = + spy(new DataflowPipelineJob("test-project", "test-job", service, null)); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + when(request.execute()).thenReturn( + generateMockMetricResponse(true /* success */, false /* tentative */)); + doReturn(State.FAILED).when(job).getState(); + assertEquals(Optional.of(false), runner.checkForSuccess(job)); + } + + @Test + public void testStreamingPipelineFailsIfException() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("The dataflow failed."); + + options.setStreaming(true); + Pipeline p = TestPipeline.create(options); + PCollection pc = p.apply(Create.of(1, 2, 3)); + DataflowAssert.that(pc).containsInAnyOrder(1, 2, 3); + + DataflowPipelineJob mockJob = Mockito.mock(DataflowPipelineJob.class); + when(mockJob.getDataflowClient()).thenReturn(service); + when(mockJob.getState()).thenReturn(State.RUNNING); + when(mockJob.getProjectId()).thenReturn("test-project"); + when(mockJob.getJobId()).thenReturn("test-job"); + when(mockJob.waitToFinish(any(Long.class), any(TimeUnit.class), any(JobMessagesHandler.class))) + .thenAnswer(new Answer() { + @Override + public State answer(InvocationOnMock invocation) { + JobMessage message = new JobMessage(); + message.setMessageText("FooException"); + message.setTime(TimeUtil.toCloudTime(Instant.now())); + message.setMessageImportance("JOB_MESSAGE_ERROR"); + ((MonitoringUtil.JobMessagesHandler) invocation.getArguments()[2]) + .process(Arrays.asList(message)); + return State.CANCELLED; + } + }); + + DataflowPipelineRunner mockRunner = Mockito.mock(DataflowPipelineRunner.class); + when(mockRunner.run(any(Pipeline.class))).thenReturn(mockJob); + + when(request.execute()).thenReturn( + generateMockMetricResponse(false /* success */, true /* tentative */)); + TestDataflowPipelineRunner runner = (TestDataflowPipelineRunner) p.getRunner(); + runner.run(p, mockRunner); + + verify(mockJob).cancel(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestPipelineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestPipelineTest.java new file mode 100644 index 000000000000..397920a1dbb7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/testing/TestPipelineTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
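For context on testCreationOfPipelineOptions below: the dataflowOptions system property carries an ordinary argument list encoded as a JSON array of strings, and those flags behave like arguments handed to PipelineOptionsFactory. The following sketch is illustrative only (not part of the original change); the class name and flag values are invented, and the direct use of PipelineOptionsFactory is an assumption about how such flags are normally parsed, not a claim about TestPipeline internals.

import static org.junit.Assert.assertEquals;

import com.google.cloud.dataflow.sdk.options.GcpOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

/** Illustrative sketch: the kind of flag list carried by the dataflowOptions property. */
public class DataflowOptionsFlagsSketch {
  public static void main(String[] args) {
    // The JSON array ["--project=testProject"] decodes to this ordinary flag list.
    String[] flags = new String[] {"--project=testProject"};

    GcpOptions options = PipelineOptionsFactory.fromArgs(flags).as(GcpOptions.class);
    assertEquals("testProject", options.getProject());
  }
}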
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.testing; + +import static org.hamcrest.CoreMatchers.startsWith; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.GcpOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TestPipeline}. */ +@RunWith(JUnit4.class) +public class TestPipelineTest { + @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + + @Test + public void testCreationUsingDefaults() { + assertNotNull(TestPipeline.create()); + } + + @Test + public void testCreationOfPipelineOptions() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + String stringOptions = mapper.writeValueAsString(new String[]{ + "--runner=DataflowPipelineRunner", + "--project=testProject", + "--apiRootUrl=testApiRootUrl", + "--dataflowEndpoint=testDataflowEndpoint", + "--tempLocation=testTempLocation", + "--serviceAccountName=testServiceAccountName", + "--serviceAccountKeyfile=testServiceAccountKeyfile", + "--zone=testZone", + "--numWorkers=1", + "--diskSizeGb=2" + }); + System.getProperties().put("dataflowOptions", stringOptions); + DataflowPipelineOptions options = + TestPipeline.testingPipelineOptions().as(DataflowPipelineOptions.class); + assertEquals(DataflowPipelineRunner.class, options.getRunner()); + assertThat(options.getJobName(), startsWith("testpipelinetest0testcreationofpipelineoptions-")); + assertEquals("testProject", options.as(GcpOptions.class).getProject()); + assertEquals("testApiRootUrl", options.getApiRootUrl()); + assertEquals("testDataflowEndpoint", options.getDataflowEndpoint()); + assertEquals("testTempLocation", options.getTempLocation()); + assertEquals("testServiceAccountName", options.getServiceAccountName()); + assertEquals( + "testServiceAccountKeyfile", options.as(GcpOptions.class).getServiceAccountKeyfile()); + assertEquals("testZone", options.getZone()); + assertEquals(2, options.getDiskSizeGb()); + } + + @Test + public void testCreationOfPipelineOptionsFromReallyVerboselyNamedTestCase() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + String stringOptions = mapper.writeValueAsString(new String[]{}); + System.getProperties().put("dataflowOptions", stringOptions); + PipelineOptions options = TestPipeline.testingPipelineOptions(); + assertThat(options.as(ApplicationNameOptions.class).getAppName(), startsWith( + "TestPipelineTest-testCreationOfPipelineOptionsFromReallyVerboselyNamedTestCase")); + } + + @Test + public void testToString() { + assertEquals("TestPipeline#TestPipelineTest-testToString", 
TestPipeline.create().toString()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantilesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantilesTest.java new file mode 100644 index 000000000000..e366e6948d88 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateQuantilesTest.java @@ -0,0 +1,299 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.ApproximateQuantiles.ApproximateQuantilesCombineFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +/** + * Tests for {@link ApproximateQuantiles}. 
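One piece of arithmetic worth seeing up front (an illustrative aside, not part of the original file): the quantileMatcher helper near the end of this test expects the k-th of N quantiles over the values 0..size-1 to land near (size - 1) * k / (N - 1), within a tolerance. For size 101 and five quantiles that gives 0, 25, 50, 75, 100, which is exactly what testQuantilesGlobally and testSimpleQuantiles assert. The class below is invented for the sketch.

/** Illustrative sketch: the expected quantile positions used by quantileMatcher. */
public class ExpectedQuantileIndexSketch {
  public static void main(String[] args) {
    int size = 101;
    int numQuantiles = 5;
    // Same formula as quantileMatcher; prints 0, 25, 50, 75 and 100.
    for (int k = 0; k < numQuantiles; k++) {
      int expected = (int) (((double) (size - 1)) * k / (numQuantiles - 1));
      System.out.println(expected);
    }
  }
}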
+ */ +@RunWith(JUnit4.class) +public class ApproximateQuantilesTest { + + static final List> TABLE = Arrays.asList( + KV.of("a", 1), + KV.of("a", 2), + KV.of("a", 3), + KV.of("b", 1), + KV.of("b", 10), + KV.of("b", 10), + KV.of("b", 100) + ); + + public PCollection> createInputTable(Pipeline p) { + return p.apply(Create.of(TABLE).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + } + + @Test + public void testQuantilesGlobally() { + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = intRangeCollection(p, 101); + PCollection> quantiles = + input.apply(ApproximateQuantiles.globally(5)); + + p.run(); + + DataflowAssert.that(quantiles) + .containsInAnyOrder(Arrays.asList(0, 25, 50, 75, 100)); + } + + @Test + public void testQuantilesGobally_comparable() { + DirectPipeline p = DirectPipeline.createForTest(); + + PCollection input = intRangeCollection(p, 101); + PCollection> quantiles = + input.apply( + ApproximateQuantiles.globally(5, new DescendingIntComparator())); + + p.run(); + + DataflowAssert.that(quantiles) + .containsInAnyOrder(Arrays.asList(100, 75, 50, 25, 0)); + } + + @Test + public void testQuantilesPerKey() { + Pipeline p = TestPipeline.create(); + + PCollection> input = createInputTable(p); + PCollection>> quantiles = input.apply( + ApproximateQuantiles.perKey(2)); + + DataflowAssert.that(quantiles) + .containsInAnyOrder( + KV.of("a", Arrays.asList(1, 3)), + KV.of("b", Arrays.asList(1, 100))); + p.run(); + + } + + @Test + public void testQuantilesPerKey_reversed() { + Pipeline p = TestPipeline.create(); + + PCollection> input = createInputTable(p); + PCollection>> quantiles = input.apply( + ApproximateQuantiles.perKey( + 2, new DescendingIntComparator())); + + DataflowAssert.that(quantiles) + .containsInAnyOrder( + KV.of("a", Arrays.asList(3, 1)), + KV.of("b", Arrays.asList(100, 1))); + p.run(); + } + + @Test + public void testSingleton() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + Arrays.asList(389), + Arrays.asList(389, 389, 389, 389, 389)); + } + + @Test + public void testSimpleQuantiles() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + intRange(101), + Arrays.asList(0, 25, 50, 75, 100)); + } + + @Test + public void testUnevenQuantiles() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(37), + intRange(5000), + quantileMatcher(5000, 37, 20 /* tolerance */)); + } + + @Test + public void testLargerQuantiles() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(50), + intRange(10001), + quantileMatcher(10001, 50, 20 /* tolerance */)); + } + + @Test + public void testTightEpsilon() { + checkCombineFn( + ApproximateQuantilesCombineFn.create(10).withEpsilon(0.01), + intRange(10001), + quantileMatcher(10001, 10, 5 /* tolerance */)); + } + + @Test + public void testDuplicates() { + int size = 101; + List all = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + all.addAll(intRange(size)); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(0, 25, 50, 75, 100)); + } + + @Test + public void testLotsOfDuplicates() { + List all = new ArrayList<>(); + all.add(1); + for (int i = 1; i < 300; i++) { + all.add(2); + } + for (int i = 300; i < 1000; i++) { + all.add(3); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(1, 2, 3, 3, 3)); + } + + @Test + public void testLogDistribution() { + List all = new ArrayList<>(); + for (int i = 1; i < 1000; i++) { + all.add((int) Math.log(i)); + } + checkCombineFn( + 
ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(0, 5, 6, 6, 6)); + } + + @Test + public void testZipfianDistribution() { + List all = new ArrayList<>(); + for (int i = 1; i < 1000; i++) { + all.add(1000 / i); + } + checkCombineFn( + ApproximateQuantilesCombineFn.create(5), + all, + Arrays.asList(1, 1, 2, 4, 1000)); + } + + @Test + public void testAlternateComparator() { + List inputs = Arrays.asList( + "aa", "aaa", "aaaa", "b", "ccccc", "dddd", "zz"); + checkCombineFn( + ApproximateQuantilesCombineFn.create(3), + inputs, + Arrays.asList("aa", "b", "zz")); + checkCombineFn( + ApproximateQuantilesCombineFn.create(3, new OrderByLength()), + inputs, + Arrays.asList("b", "aaa", "ccccc")); + } + + private Matcher> quantileMatcher( + int size, int numQuantiles, int absoluteError) { + List> quantiles = new ArrayList<>(); + quantiles.add(CoreMatchers.is(0)); + for (int k = 1; k < numQuantiles - 1; k++) { + int expected = (int) (((double) (size - 1)) * k / (numQuantiles - 1)); + quantiles.add(new Between<>( + expected - absoluteError, expected + absoluteError)); + } + quantiles.add(CoreMatchers.is(size - 1)); + return contains(quantiles); + } + + private static class Between> + extends TypeSafeDiagnosingMatcher { + private final T min; + private final T max; + private Between(T min, T max) { + this.min = min; + this.max = max; + } + @Override + public void describeTo(Description description) { + description.appendText("is between " + min + " and " + max); + } + + @Override + protected boolean matchesSafely(T item, Description mismatchDescription) { + return min.compareTo(item) <= 0 && item.compareTo(max) <= 0; + } + } + + private static class DescendingIntComparator implements + SerializableComparator { + @Override + public int compare(Integer o1, Integer o2) { + return o2.compareTo(o1); + } + } + + private static class OrderByLength implements Comparator, Serializable { + @Override + public int compare(String a, String b) { + if (a.length() != b.length()) { + return a.length() - b.length(); + } else { + return a.compareTo(b); + } + } + } + + + private PCollection intRangeCollection(Pipeline p, int size) { + return p.apply("CreateIntsUpTo(" + size + ")", Create.of(intRange(size))); + } + + private List intRange(int size) { + List all = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + all.add(i); + } + return all; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUniqueTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUniqueTest.java new file mode 100644 index 000000000000..39731bb93157 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ApproximateUniqueTest.java @@ -0,0 +1,291 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.TestUtils; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for the ApproximateUnique aggregator transform. + */ +@RunWith(JUnit4.class) +public class ApproximateUniqueTest implements Serializable { + // implements Serializable just to make it easy to use anonymous inner DoFn subclasses + + @Test + public void testEstimationErrorToSampleSize() { + assertEquals(40000, ApproximateUnique.sampleSizeFromEstimationError(0.01)); + assertEquals(10000, ApproximateUnique.sampleSizeFromEstimationError(0.02)); + assertEquals(2500, ApproximateUnique.sampleSizeFromEstimationError(0.04)); + assertEquals(1600, ApproximateUnique.sampleSizeFromEstimationError(0.05)); + assertEquals(400, ApproximateUnique.sampleSizeFromEstimationError(0.1)); + assertEquals(100, ApproximateUnique.sampleSizeFromEstimationError(0.2)); + assertEquals(25, ApproximateUnique.sampleSizeFromEstimationError(0.4)); + assertEquals(16, ApproximateUnique.sampleSizeFromEstimationError(0.5)); + } + + @Test + @Category(RunnableOnService.class) + public void testApproximateUniqueWithSmallInput() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply( + Create.of(Arrays.asList(1, 2, 3, 3))); + + PCollection estimate = input + .apply(ApproximateUnique.globally(1000)); + + DataflowAssert.thatSingleton(estimate).isEqualTo(3L); + + p.run(); + } + + @Test + public void testApproximateUniqueWithDuplicates() { + runApproximateUniqueWithDuplicates(100, 100, 100); + runApproximateUniqueWithDuplicates(1000, 1000, 100); + runApproximateUniqueWithDuplicates(1500, 1000, 100); + runApproximateUniqueWithDuplicates(10000, 1000, 100); + } + + private void runApproximateUniqueWithDuplicates(int elementCount, + int uniqueCount, int sampleSize) { + + assert elementCount >= uniqueCount; + List elements = Lists.newArrayList(); + for (int i = 0; i < elementCount; i++) { + elements.add(1.0 / (i % uniqueCount + 1)); + } + Collections.shuffle(elements); + + Pipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(elements)); + PCollection estimate = + input.apply(ApproximateUnique.globally(sampleSize)); + + DataflowAssert.thatSingleton(estimate).satisfies(new VerifyEstimateFn(uniqueCount, sampleSize)); + + p.run(); + } + + @Test + public void testApproximateUniqueWithSkewedDistributions() { + runApproximateUniqueWithSkewedDistributions(100, 100, 100); + runApproximateUniqueWithSkewedDistributions(10000, 10000, 100); + runApproximateUniqueWithSkewedDistributions(10000, 1000, 100); + runApproximateUniqueWithSkewedDistributions(10000, 200, 100); + } + + @Test + public void 
testApproximateUniqueWithSkewedDistributionsAndLargeSampleSize() { + runApproximateUniqueWithSkewedDistributions(10000, 2000, 1000); + } + + private void runApproximateUniqueWithSkewedDistributions(int elementCount, + final int uniqueCount, final int sampleSize) { + List elements = Lists.newArrayList(); + // Zipf distribution with approximately elementCount items. + double s = 1 - 1.0 * uniqueCount / elementCount; + double maxCount = Math.pow(uniqueCount, s); + for (int k = 0; k < uniqueCount; k++) { + int count = Math.max(1, (int) Math.round(maxCount * Math.pow(k, -s))); + // Element k occurs count times. + for (int c = 0; c < count; c++) { + elements.add(k); + } + } + + Pipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(elements)); + PCollection estimate = + input.apply(ApproximateUnique.globally(sampleSize)); + + DataflowAssert.thatSingleton(estimate).satisfies(new VerifyEstimateFn(uniqueCount, sampleSize)); + + p.run(); + } + + @Test + public void testApproximateUniquePerKey() { + List> elements = Lists.newArrayList(); + List keys = ImmutableList.of(20L, 50L, 100L); + int elementCount = 1000; + int sampleSize = 100; + // Use the key as the number of unique values. + for (long uniqueCount : keys) { + for (long value = 0; value < elementCount; value++) { + elements.add(KV.of(uniqueCount, value % uniqueCount)); + } + } + + Pipeline p = TestPipeline.create(); + PCollection> input = p.apply(Create.of(elements)); + PCollection> counts = + input.apply(ApproximateUnique.perKey(sampleSize)); + + DataflowAssert.that(counts).satisfies(new VerifyEstimatePerKeyFn(sampleSize)); + + p.run(); + + } + + /** + * Applies {@link ApproximateUnique} for different sample sizes and verifies + * that the estimation error falls within the maximum allowed error of + * {@code 2 / sqrt(sampleSize)}. + */ + @Test + public void testApproximateUniqueWithDifferentSampleSizes() { + runApproximateUniquePipeline(16); + runApproximateUniquePipeline(64); + runApproximateUniquePipeline(128); + runApproximateUniquePipeline(256); + runApproximateUniquePipeline(512); + runApproximateUniquePipeline(1000); + runApproximateUniquePipeline(1024); + try { + runApproximateUniquePipeline(15); + fail("Accepted sampleSize < 16"); + } catch (IllegalArgumentException e) { + assertTrue("Expected an exception due to sampleSize < 16", e.getMessage() + .startsWith("ApproximateUnique needs a sampleSize >= 16")); + } + } + + @Test + public void testApproximateUniqueGetName() { + assertEquals("ApproximateUnique.PerKey", ApproximateUnique.perKey(16).getName()); + assertEquals("ApproximateUnique.Globally", ApproximateUnique.globally(16).getName()); + } + + /** + * Applies {@code ApproximateUnique(sampleSize)} verifying that the estimation + * error falls within the maximum allowed error of {@code 2/sqrt(sampleSize)}. 
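As a quick arithmetic check on that bound (an illustrative aside, not part of the original file): the bound 2 / sqrt(sampleSize) and the helper exercised by testEstimationErrorToSampleSize are inverses of each other, with sampleSize = (2 / error)^2. The class below is invented for the sketch.

/** Illustrative sketch: the error bound and the sample-size helper are inverses. */
public class ApproximateUniqueErrorBoundSketch {
  public static void main(String[] args) {
    // A sample of 100 hashes tolerates a relative error of 2 / sqrt(100) = 0.2, i.e. 20%.
    System.out.println(2.0 / Math.sqrt(100));
    // Conversely, a 1% error target needs (2 / 0.01)^2 = 40000 hashes, matching the
    // assertEquals(40000, sampleSizeFromEstimationError(0.01)) assertion earlier in this test.
    System.out.println(Math.pow(2.0 / 0.01, 2));
  }
}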
+ */ + private static void runApproximateUniquePipeline(int sampleSize) { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(TEST_LINES)); + PCollection approximate = input.apply(ApproximateUnique.globally(sampleSize)); + final PCollectionView exact = + input + .apply(RemoveDuplicates.create()) + .apply(Count.globally()) + .apply(View.asSingleton()); + + PCollection> approximateAndExact = approximate + .apply(ParDo.of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), c.sideInput(exact))); + } + }) + .withSideInputs(exact)); + + DataflowAssert.that(approximateAndExact).satisfies(new VerifyEstimatePerKeyFn(sampleSize)); + + p.run(); + } + + private static final int TEST_PAGES = 100; + private static final List TEST_LINES = + new ArrayList<>(TEST_PAGES * TestUtils.LINES.size()); + + static { + for (int i = 0; i < TEST_PAGES; i++) { + TEST_LINES.addAll(TestUtils.LINES); + } + } + + /** + * Checks that the estimation error, i.e., the difference between + * {@code uniqueCount} and {@code estimate} is less than + * {@code 2 / sqrt(sampleSize}). + */ + private static void verifyEstimate(long uniqueCount, int sampleSize, long estimate) { + if (uniqueCount < sampleSize) { + assertEquals("Number of hashes is less than the sample size. " + + "Estimate should be exact", uniqueCount, estimate); + } + + double error = 100.0 * Math.abs(estimate - uniqueCount) / uniqueCount; + double maxError = 100.0 * 2 / Math.sqrt(sampleSize); + + assertTrue("Estimate= " + estimate + " Actual=" + uniqueCount + " Error=" + + error + "%, MaxError=" + maxError + "%.", error < maxError); + + assertTrue("Estimate= " + estimate + " Actual=" + uniqueCount + " Error=" + + error + "%, MaxError=" + maxError + "%.", error < maxError); + } + + private static class VerifyEstimateFn implements SerializableFunction { + private long uniqueCount; + private int sampleSize; + + public VerifyEstimateFn(long uniqueCount, int sampleSize) { + this.uniqueCount = uniqueCount; + this.sampleSize = sampleSize; + } + + @Override + public Void apply(Long estimate) { + verifyEstimate(uniqueCount, sampleSize, estimate); + return null; + } + } + + private static class VerifyEstimatePerKeyFn + implements SerializableFunction>, Void> { + + private int sampleSize; + + public VerifyEstimatePerKeyFn(int sampleSize) { + this.sampleSize = sampleSize; + } + + @Override + public Void apply(Iterable> estimatePerKey) { + for (KV result : estimatePerKey) { + verifyEstimate(result.getKey(), sampleSize, result.getValue()); + } + return null; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineTest.java new file mode 100644 index 000000000000..37ed20415fbb --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineTest.java @@ -0,0 +1,1137 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterPane; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Repeatedly; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior; +import com.google.cloud.dataflow.sdk.util.PerKeyCombineFnRunners; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Random; +import java.util.Set; + +/** + * Tests for Combine transforms. 
+ */ +@RunWith(JUnit4.class) +public class CombineTest implements Serializable { + // This test is Serializable, just so that it's easy to have + // anonymous inner classes inside the non-static test methods. + + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] TABLE = new KV[] { + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13), + }; + + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] EMPTY_TABLE = new KV[] { + }; + + static final Integer[] NUMBERS = new Integer[] { + 1, 1, 2, 3, 5, 8, 13, 21, 34, 55 + }; + + @Mock private DoFn.ProcessContext processContext; + + PCollection> createInput(Pipeline p, + KV[] table) { + return p.apply(Create.of(Arrays.asList(table)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + } + + private void runTestSimpleCombine(KV[] table, + int globalSum, + KV[] perKeyCombines) { + Pipeline p = TestPipeline.create(); + PCollection> input = createInput(p, table); + + PCollection sum = input + .apply(Values.create()) + .apply(Combine.globally(new SumInts())); + + // Java 8 will infer. + PCollection> sumPerKey = input + .apply(Combine.perKey(new TestKeyedCombineFn())); + + DataflowAssert.that(sum).containsInAnyOrder(globalSum); + DataflowAssert.that(sumPerKey).containsInAnyOrder(perKeyCombines); + + p.run(); + } + + private void runTestSimpleCombineWithContext(KV[] table, + int globalSum, + KV[] perKeyCombines, + String[] globallyCombines) { + Pipeline p = TestPipeline.create(); + PCollection> perKeyInput = createInput(p, table); + PCollection globallyInput = perKeyInput.apply(Values.create()); + + PCollection sum = globallyInput.apply("Sum", Combine.globally(new SumInts())); + + PCollectionView globallySumView = sum.apply(View.asSingleton()); + + // Java 8 will infer. 
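+ // The singleton view computed above is handed to the combine via withSideInputs();
+ // TestKeyedCombineFnWithContext reads it back through c.sideInput(...), so each
+ // accumulator is seeded with the key plus the globally computed sum.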
+ PCollection> combinePerKey = perKeyInput + .apply(Combine.perKey(new TestKeyedCombineFnWithContext(globallySumView)) + .withSideInputs(Arrays.asList(globallySumView))); + + PCollection combineGlobally = globallyInput + .apply(Combine.globally(new TestKeyedCombineFnWithContext(globallySumView) + .forKey("G", StringUtf8Coder.of())) + .withoutDefaults() + .withSideInputs(Arrays.asList(globallySumView))); + + DataflowAssert.that(sum).containsInAnyOrder(globalSum); + DataflowAssert.that(combinePerKey).containsInAnyOrder(perKeyCombines); + DataflowAssert.that(combineGlobally).containsInAnyOrder(globallyCombines); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testSimpleCombine() { + runTestSimpleCombine(TABLE, 20, new KV[] { + KV.of("a", "114a"), KV.of("b", "113b") }); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testSimpleCombineWithContext() { + runTestSimpleCombineWithContext(TABLE, 20, + new KV[] {KV.of("a", "01124a"), KV.of("b", "01123b") }, + new String[] {"01111234G"}); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testSimpleCombineWithContextEmpty() { + runTestSimpleCombineWithContext(EMPTY_TABLE, 0, new KV[] {}, new String[] {}); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testSimpleCombineEmpty() { + runTestSimpleCombine(EMPTY_TABLE, 0, new KV[] { }); + } + + @SuppressWarnings("unchecked") + private void runTestBasicCombine(KV[] table, + Set globalUnique, + KV>[] perKeyUnique) { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(Set.class, SetCoder.class); + PCollection> input = createInput(p, table); + + PCollection> unique = input + .apply(Values.create()) + .apply(Combine.globally(new UniqueInts())); + + // Java 8 will infer. + PCollection>> uniquePerKey = input + .apply(Combine.>perKey(new UniqueInts())); + + DataflowAssert.that(unique).containsInAnyOrder(globalUnique); + DataflowAssert.that(uniquePerKey).containsInAnyOrder(perKeyUnique); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings({"rawtypes", "unchecked"}) + public void testBasicCombine() { + runTestBasicCombine(TABLE, ImmutableSet.of(1, 13, 4), new KV[] { + KV.of("a", (Set) ImmutableSet.of(1, 4)), + KV.of("b", (Set) ImmutableSet.of(1, 13)) }); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings("rawtypes") + public void testBasicCombineEmpty() { + runTestBasicCombine(EMPTY_TABLE, ImmutableSet.of(), new KV[] { }); + } + + private void runTestAccumulatingCombine(KV[] table, + Double globalMean, + KV[] perKeyMeans) { + Pipeline p = TestPipeline.create(); + PCollection> input = createInput(p, table); + + PCollection mean = input + .apply(Values.create()) + .apply(Combine.globally(new MeanInts())); + + // Java 8 will infer. 
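+ // MeanInts (defined later in this file) is an AccumulatingCombineFn whose CountSum
+ // accumulator keeps a running count and sum; extractOutput() returns sum / count,
+ // or 0.0 for an empty accumulator.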
+ PCollection> meanPerKey = input.apply( + Combine.perKey(new MeanInts())); + + DataflowAssert.that(mean).containsInAnyOrder(globalMean); + DataflowAssert.that(meanPerKey).containsInAnyOrder(perKeyMeans); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFixedWindowsCombine() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.timestamped(Arrays.asList(TABLE), + Arrays.asList(0L, 1L, 6L, 7L, 8L)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(FixedWindows.of(Duration.millis(2)))); + + PCollection sum = input + .apply(Values.create()) + .apply(Combine.globally(new SumInts()).withoutDefaults()); + + PCollection> sumPerKey = input + .apply(Combine.perKey(new TestKeyedCombineFn())); + + DataflowAssert.that(sum).containsInAnyOrder(2, 5, 13); + DataflowAssert.that(sumPerKey).containsInAnyOrder( + KV.of("a", "11a"), + KV.of("a", "4a"), + KV.of("b", "1b"), + KV.of("b", "13b")); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFixedWindowsCombineWithContext() { + Pipeline p = TestPipeline.create(); + + PCollection> perKeyInput = + p.apply(Create.timestamped(Arrays.asList(TABLE), + Arrays.asList(0L, 1L, 6L, 7L, 8L)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(FixedWindows.of(Duration.millis(2)))); + + PCollection globallyInput = perKeyInput.apply(Values.create()); + + PCollection sum = globallyInput + .apply("Sum", Combine.globally(new SumInts()).withoutDefaults()); + + PCollectionView globallySumView = sum.apply(View.asSingleton()); + + PCollection> combinePerKeyWithContext = perKeyInput + .apply(Combine.perKey(new TestKeyedCombineFnWithContext(globallySumView)) + .withSideInputs(Arrays.asList(globallySumView))); + + PCollection combineGloballyWithContext = globallyInput + .apply(Combine.globally(new TestKeyedCombineFnWithContext(globallySumView) + .forKey("G", StringUtf8Coder.of())) + .withoutDefaults() + .withSideInputs(Arrays.asList(globallySumView))); + + DataflowAssert.that(sum).containsInAnyOrder(2, 5, 13); + DataflowAssert.that(combinePerKeyWithContext).containsInAnyOrder( + KV.of("a", "112a"), + KV.of("a", "45a"), + KV.of("b", "15b"), + KV.of("b", "1133b")); + DataflowAssert.that(combineGloballyWithContext).containsInAnyOrder("112G", "145G", "1133G"); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSlidingWindowsCombineWithContext() { + Pipeline p = TestPipeline.create(); + + PCollection> perKeyInput = + p.apply(Create.timestamped(Arrays.asList(TABLE), + Arrays.asList(2L, 3L, 8L, 9L, 10L)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(SlidingWindows.of(Duration.millis(2)))); + + PCollection globallyInput = perKeyInput.apply(Values.create()); + + PCollection sum = globallyInput + .apply("Sum", Combine.globally(new SumInts()).withoutDefaults()); + + PCollectionView globallySumView = sum.apply(View.asSingleton()); + + PCollection> combinePerKeyWithContext = perKeyInput + .apply(Combine.perKey(new TestKeyedCombineFnWithContext(globallySumView)) + .withSideInputs(Arrays.asList(globallySumView))); + + PCollection combineGloballyWithContext = globallyInput + .apply(Combine.globally(new TestKeyedCombineFnWithContext(globallySumView) + .forKey("G", StringUtf8Coder.of())) + .withoutDefaults() + .withSideInputs(Arrays.asList(globallySumView))); + + DataflowAssert.that(sum).containsInAnyOrder(1, 2, 1, 4, 5, 14, 13); + 
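+ // SlidingWindows.of(2 ms) with its default 1 ms period places each of the five inputs
+ // into two overlapping windows, which is why seven per-window sums (rather than five)
+ // are expected in the assertion above.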
DataflowAssert.that(combinePerKeyWithContext).containsInAnyOrder( + KV.of("a", "11a"), + KV.of("a", "112a"), + KV.of("a", "11a"), + KV.of("a", "44a"), + KV.of("a", "45a"), + KV.of("b", "15b"), + KV.of("b", "11134b"), + KV.of("b", "1133b")); + DataflowAssert.that(combineGloballyWithContext).containsInAnyOrder( + "11G", "112G", "11G", "44G", "145G", "11134G", "1133G"); + p.run(); + } + + private static class FormatPaneInfo extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + ": " + c.pane().isLast()); + } + } + + @Test + @Category(RunnableOnService.class) + public void testGlobalCombineWithDefaultsAndTriggers() { + Pipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(1, 1)); + + PCollection output = input + .apply(Window.into(new GlobalWindows()) + .triggering(AfterPane.elementCountAtLeast(1)) + .accumulatingFiredPanes() + .withAllowedLateness(new Duration(0))) + .apply(Sum.integersGlobally()) + .apply(ParDo.of(new FormatPaneInfo())); + + DataflowAssert.that(output).containsInAnyOrder("1: false", "2: true"); + } + + @Test + @Category(RunnableOnService.class) + public void testSessionsCombine() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.timestamped(Arrays.asList(TABLE), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(Sessions.withGapDuration(Duration.millis(5)))); + + PCollection sum = input + .apply(Values.create()) + .apply(Combine.globally(new SumInts()).withoutDefaults()); + + PCollection> sumPerKey = input + .apply(Combine.perKey(new TestKeyedCombineFn())); + + DataflowAssert.that(sum).containsInAnyOrder(7, 13); + DataflowAssert.that(sumPerKey).containsInAnyOrder( + KV.of("a", "114a"), + KV.of("b", "1b"), + KV.of("b", "13b")); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSessionsCombineWithContext() { + Pipeline p = TestPipeline.create(); + + PCollection> perKeyInput = + p.apply(Create.timestamped(Arrays.asList(TABLE), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection globallyInput = perKeyInput.apply(Values.create()); + + PCollection fixedWindowsSum = globallyInput + .apply("FixedWindows", + Window.into(FixedWindows.of(Duration.millis(5)))) + .apply("Sum", Combine.globally(new SumInts()).withoutDefaults()); + + PCollectionView globallyFixedWindowsView = + fixedWindowsSum.apply(View.asSingleton().withDefaultValue(0)); + + PCollection> sessionsCombinePerKey = perKeyInput + .apply("PerKey Input Sessions", + Window.>into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(Combine.perKey(new TestKeyedCombineFnWithContext(globallyFixedWindowsView)) + .withSideInputs(Arrays.asList(globallyFixedWindowsView))); + + PCollection sessionsCombineGlobally = globallyInput + .apply("Globally Input Sessions", + Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(Combine.globally(new TestKeyedCombineFnWithContext(globallyFixedWindowsView) + .forKey("G", StringUtf8Coder.of())) + .withoutDefaults() + .withSideInputs(Arrays.asList(globallyFixedWindowsView))); + + DataflowAssert.that(fixedWindowsSum).containsInAnyOrder(2, 4, 1, 13); + DataflowAssert.that(sessionsCombinePerKey).containsInAnyOrder( + KV.of("a", "1114a"), + KV.of("b", "11b"), + KV.of("b", "013b")); + DataflowAssert.that(sessionsCombineGlobally).containsInAnyOrder("11114G", "013G"); + p.run(); + } + + @Test + 
@Category(RunnableOnService.class) + public void testWindowedCombineEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection mean = p + .apply(Create.of().withCoder(BigEndianIntegerCoder.of())) + .apply(Window.into(FixedWindows.of(Duration.millis(1)))) + .apply(Combine.globally(new MeanInts()).withoutDefaults()); + + DataflowAssert.that(mean).empty(); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testAccumulatingCombine() { + runTestAccumulatingCombine(TABLE, 4.0, new KV[] { + KV.of("a", 2.0), KV.of("b", 7.0) }); + } + + @Test + @Category(RunnableOnService.class) + public void testAccumulatingCombineEmpty() { + runTestAccumulatingCombine(EMPTY_TABLE, 0.0, new KV[] { }); + } + + // Checks that Min, Max, Mean, Sum (operations that pass-through to Combine), + // provide their own top-level name. + @Test + public void testCombinerNames() { + Combine.PerKey min = Min.integersPerKey(); + Combine.PerKey max = Max.integersPerKey(); + Combine.PerKey mean = Mean.perKey(); + Combine.PerKey sum = Sum.integersPerKey(); + + assertThat(min.getName(), Matchers.startsWith("Min")); + assertThat(max.getName(), Matchers.startsWith("Max")); + assertThat(mean.getName(), Matchers.startsWith("Mean")); + assertThat(sum.getName(), Matchers.startsWith("Sum")); + } + + @Test + public void testAddInputsRandomly() { + TestCounter counter = new TestCounter(); + Combine.KeyedCombineFn< + String, Integer, TestCounter.Counter, Iterable> fn = + counter.asKeyedFn(); + + List accums = DirectPipelineRunner.TestCombineDoFn.addInputsRandomly( + PerKeyCombineFnRunners.create(fn), "bob", Arrays.asList(NUMBERS), new Random(42), + processContext); + + assertThat(accums, Matchers.contains( + counter.new Counter(3, 2, 0, 0), + counter.new Counter(131, 5, 0, 0), + counter.new Counter(8, 2, 0, 0), + counter.new Counter(1, 1, 0, 0))); + } + + private static final SerializableFunction hotKeyFanout = + new SerializableFunction() { + @Override + public Integer apply(String input) { + return input.equals("a") ? 3 : 0; + } + }; + + private static final SerializableFunction splitHotKeyFanout = + new SerializableFunction() { + @Override + public Integer apply(String input) { + return Math.random() < 0.5 ? 
3 : 0; + } + }; + + @Test + @Category(RunnableOnService.class) + public void testHotKeyCombining() { + Pipeline p = TestPipeline.create(); + PCollection> input = copy(createInput(p, TABLE), 10); + + KeyedCombineFn mean = + new MeanInts().asKeyedFn(); + PCollection> coldMean = input.apply("ColdMean", + Combine.perKey(mean).withHotKeyFanout(0)); + PCollection> warmMean = input.apply("WarmMean", + Combine.perKey(mean).withHotKeyFanout(hotKeyFanout)); + PCollection> hotMean = input.apply("HotMean", + Combine.perKey(mean).withHotKeyFanout(5)); + PCollection> splitMean = input.apply("SplitMean", + Combine.perKey(mean).withHotKeyFanout(splitHotKeyFanout)); + + List> expected = Arrays.asList(KV.of("a", 2.0), KV.of("b", 7.0)); + DataflowAssert.that(coldMean).containsInAnyOrder(expected); + DataflowAssert.that(warmMean).containsInAnyOrder(expected); + DataflowAssert.that(hotMean).containsInAnyOrder(expected); + DataflowAssert.that(splitMean).containsInAnyOrder(expected); + + p.run(); + } + + private static class GetLast extends DoFn { + @Override + public void processElement(ProcessContext c) { + if (c.pane().isLast()) { + c.output(c.element()); + } + } + } + + @Test + @Category(RunnableOnService.class) + public void testHotKeyCombiningWithAccumulationMode() { + Pipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(1, 2, 3, 4, 5)); + + PCollection output = input + .apply(Window.into(new GlobalWindows()) + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .accumulatingFiredPanes() + .withAllowedLateness(new Duration(0), ClosingBehavior.FIRE_ALWAYS)) + .apply(Sum.integersGlobally().withoutDefaults().withFanout(2)) + .apply(ParDo.of(new GetLast())); + + DataflowAssert.that(output).containsInAnyOrder(15); + + p.run(); + } + + @Test + public void testBinaryCombineFn() { + Pipeline p = TestPipeline.create(); + PCollection> input = copy(createInput(p, TABLE), 2); + PCollection> intProduct = input + .apply("IntProduct", Combine.perKey(new TestProdInt())); + PCollection> objProduct = input + .apply("ObjProduct", Combine.perKey(new TestProdObj())); + + List> expected = Arrays.asList(KV.of("a", 16), KV.of("b", 169)); + DataflowAssert.that(intProduct).containsInAnyOrder(expected); + DataflowAssert.that(objProduct).containsInAnyOrder(expected); + + p.run(); + } + + @Test + public void testBinaryCombineFnWithNulls() { + checkCombineFn(new NullCombiner(), Arrays.asList(3, 3, 5), 45); + checkCombineFn(new NullCombiner(), Arrays.asList(null, 3, 5), 30); + checkCombineFn(new NullCombiner(), Arrays.asList(3, 3, null), 18); + checkCombineFn(new NullCombiner(), Arrays.asList(null, 3, null), 12); + checkCombineFn(new NullCombiner(), Arrays.asList(null, null, null), 8); + } + + private static final class TestProdInt extends Combine.BinaryCombineIntegerFn { + @Override + public int apply(int left, int right) { + return left * right; + } + + @Override + public int identity() { + return 1; + } + } + + private static final class TestProdObj extends Combine.BinaryCombineFn { + @Override + public Integer apply(Integer left, Integer right) { + return left * right; + } + } + + /** + * Computes the product, considering null values to be 2. + */ + private static final class NullCombiner extends Combine.BinaryCombineFn { + @Override + public Integer apply(Integer left, Integer right) { + return (left == null ? 2 : left) * (right == null ? 
2 : right); + } + } + + @Test + @Category(RunnableOnService.class) + public void testCombineGloballyAsSingletonView() { + Pipeline p = TestPipeline.create(); + final PCollectionView view = p + .apply("CreateEmptySideInput", Create.of().withCoder(BigEndianIntegerCoder.of())) + .apply(Sum.integersGlobally().asSingletonView()); + + PCollection output = p + .apply("CreateVoidMainInput", Create.of((Void) null)) + .apply("OutputSideInput", ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + }).withSideInputs(view)); + + DataflowAssert.thatSingleton(output).isEqualTo(0); + p.run(); + } + + @Test + public void testCombineGetName() { + assertEquals("Combine.Globally", Combine.globally(new SumInts()).getName()); + assertEquals( + "MyCombineGlobally", Combine.globally(new SumInts()).named("MyCombineGlobally").getName()); + assertEquals( + "Combine.GloballyAsSingletonView", + Combine.globally(new SumInts()).asSingletonView().getName()); + assertEquals("Combine.PerKey", Combine.perKey(new TestKeyedCombineFn()).getName()); + assertEquals( + "MyCombinePerKey", + Combine.perKey(new TestKeyedCombineFn()).named("MyCombinePerKey").getName()); + assertEquals( + "Combine.PerKeyWithHotKeyFanout", + Combine.perKey(new TestKeyedCombineFn()).withHotKeyFanout(10).getName()); + } + + //////////////////////////////////////////////////////////////////////////// + // Test classes, for different kinds of combining fns. + + /** Example SerializableFunction combiner. */ + public static class SumInts + implements SerializableFunction, Integer> { + @Override + public Integer apply(Iterable input) { + int sum = 0; + for (int item : input) { + sum += item; + } + return sum; + } + } + + /** Example CombineFn. */ + public static class UniqueInts extends + Combine.CombineFn, Set> { + + @Override + public Set createAccumulator() { + return new HashSet<>(); + } + + @Override + public Set addInput(Set accumulator, Integer input) { + accumulator.add(input); + return accumulator; + } + + @Override + public Set mergeAccumulators(Iterable> accumulators) { + Set all = new HashSet<>(); + for (Set part : accumulators) { + all.addAll(part); + } + return all; + } + + @Override + public Set extractOutput(Set accumulator) { + return accumulator; + } + } + + // Note: not a deterministic encoding + private static class SetCoder extends StandardCoder> { + + public static SetCoder of(Coder elementCoder) { + return new SetCoder<>(elementCoder); + } + + @JsonCreator + public static SetCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + Preconditions.checkArgument(components.size() == 1, + "Expecting 1 component, got " + components.size()); + return of((Coder) components.get(0)); + } + + @SuppressWarnings("unused") // required for coder instantiation + public static List getInstanceComponents(Set exampleValue) { + return IterableCoder.getInstanceComponents(exampleValue); + } + + private final Coder> iterableCoder; + + private SetCoder(Coder elementCoder) { + iterableCoder = IterableCoder.of(elementCoder); + } + + @Override + public void encode(Set value, OutputStream outStream, Context context) + throws CoderException, IOException { + iterableCoder.encode(value, outStream, context); + } + + @Override + public Set decode(InputStream inStream, Context context) + throws CoderException, IOException { + // TODO: Eliminate extra copy if used in production. 
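+ // Decoding delegates to the wrapped IterableCoder and copies the elements into a
+ // HashSet; since HashSet iteration order is unspecified, verifyDeterministic()
+ // below reports this coder as non-deterministic.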
+ return Sets.newHashSet(iterableCoder.decode(inStream, context)); + } + + @Override + public List> getCoderArguments() { + return iterableCoder.getCoderArguments(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "CombineTest.SetCoder does not encode in a deterministic order."); + } + + @Override + public boolean isRegisterByteSizeObserverCheap(Set value, Context context) { + return iterableCoder.isRegisterByteSizeObserverCheap(value, context); + } + + @Override + public void registerByteSizeObserver( + Set value, ElementByteSizeObserver observer, Context context) + throws Exception { + iterableCoder.registerByteSizeObserver(value, observer, context); + } + } + + /** Example AccumulatingCombineFn. */ + private static class MeanInts extends + Combine.AccumulatingCombineFn { + private static final Coder LONG_CODER = BigEndianLongCoder.of(); + private static final Coder DOUBLE_CODER = DoubleCoder.of(); + + class CountSum implements + Combine.AccumulatingCombineFn.Accumulator { + long count = 0; + double sum = 0.0; + + CountSum(long count, double sum) { + this.count = count; + this.sum = sum; + } + + @Override + public void addInput(Integer element) { + count++; + sum += element.doubleValue(); + } + + @Override + public void mergeAccumulator(CountSum accumulator) { + count += accumulator.count; + sum += accumulator.sum; + } + + @Override + public Double extractOutput() { + return count == 0 ? 0.0 : sum / count; + } + + @Override + public int hashCode() { + return Objects.hash(count, sum); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof CountSum)) { + return false; + } + CountSum other = (CountSum) obj; + return this.count == other.count + && (Math.abs(this.sum - other.sum) < 0.1); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("count", count) + .add("sum", sum) + .toString(); + } + } + + @Override + public CountSum createAccumulator() { + return new CountSum(0, 0.0); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + return new CountSumCoder(); + } + + /** + * A {@link Coder} for {@link CountSum}. + */ + private class CountSumCoder extends CustomCoder { + @Override + public void encode(CountSum value, OutputStream outStream, + Context context) throws CoderException, IOException { + LONG_CODER.encode(value.count, outStream, context); + DOUBLE_CODER.encode(value.sum, outStream, context); + } + + @Override + public CountSum decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + long count = LONG_CODER.decode(inStream, context); + double sum = DOUBLE_CODER.decode(inStream, context); + return new CountSum(count, sum); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { } + + @Override + public boolean isRegisterByteSizeObserverCheap( + CountSum value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + CountSum value, ElementByteSizeObserver observer, Context context) + throws Exception { + LONG_CODER.registerByteSizeObserver(value.count, observer, context); + DOUBLE_CODER.registerByteSizeObserver(value.sum, observer, context); + } + } + } + + /** + * A KeyedCombineFn that exercises the full generality of [Keyed]CombineFn. + * + *

    The net result of applying this CombineFn is a sorted list of all + * characters occurring in the key and the decimal representations of + * each value. + */ + public static class TestKeyedCombineFn + extends KeyedCombineFn { + + // Not serializable. + static class Accumulator { + String value; + public Accumulator(String value) { + this.value = value; + } + + public static Coder getCoder() { + return new CustomCoder() { + @Override + public void encode(Accumulator accumulator, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + StringUtf8Coder.of().encode(accumulator.value, outStream, context); + } + + @Override + public Accumulator decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + return new Accumulator(StringUtf8Coder.of().decode(inStream, context)); + } + + @Override + public String getEncodingId() { + return "CombineTest.TestKeyedCombineFn.getAccumulatorCoder()"; + } + }; + } + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) { + return Accumulator.getCoder(); + } + + @Override + public Accumulator createAccumulator(String key) { + return new Accumulator(key); + } + + @Override + public Accumulator addInput(String key, Accumulator accumulator, Integer value) { + checkNotNull(key); + try { + assertThat(accumulator.value, Matchers.startsWith(key)); + return new Accumulator(accumulator.value + String.valueOf(value)); + } finally { + accumulator.value = "cleared in addInput"; + } + } + + @Override + public Accumulator mergeAccumulators(String key, Iterable accumulators) { + String all = key; + for (Accumulator accumulator : accumulators) { + assertThat(accumulator.value, Matchers.startsWith(key)); + all += accumulator.value.substring(key.length()); + accumulator.value = "cleared in mergeAccumulators"; + } + return new Accumulator(all); + } + + @Override + public String extractOutput(String key, Accumulator accumulator) { + assertThat(accumulator.value, Matchers.startsWith(key)); + char[] chars = accumulator.value.toCharArray(); + Arrays.sort(chars); + return new String(chars); + } + } + + /** + * A {@link KeyedCombineFnWithContext} that exercises the full generality + * of [Keyed]CombineFnWithContext. + * + *

    The net result of applying this CombineFn is a sorted list of all + * characters occurring in the key and the decimal representations of + * main and side inputs values. + */ + public class TestKeyedCombineFnWithContext + extends KeyedCombineFnWithContext { + private final PCollectionView view; + + public TestKeyedCombineFnWithContext(PCollectionView view) { + this.view = view; + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder inputCoder) { + return TestKeyedCombineFn.Accumulator.getCoder(); + } + + @Override + public TestKeyedCombineFn.Accumulator createAccumulator(String key, Context c) { + return new TestKeyedCombineFn.Accumulator(key + c.sideInput(view).toString()); + } + + @Override + public TestKeyedCombineFn.Accumulator addInput( + String key, TestKeyedCombineFn.Accumulator accumulator, Integer value, Context c) { + try { + assertThat(accumulator.value, Matchers.startsWith(key + c.sideInput(view).toString())); + return new TestKeyedCombineFn.Accumulator(accumulator.value + String.valueOf(value)); + } finally { + accumulator.value = "cleared in addInput"; + } + + } + + @Override + public TestKeyedCombineFn.Accumulator mergeAccumulators( + String key, Iterable accumulators, Context c) { + String keyPrefix = key + c.sideInput(view).toString(); + String all = keyPrefix; + for (TestKeyedCombineFn.Accumulator accumulator : accumulators) { + assertThat(accumulator.value, Matchers.startsWith(keyPrefix)); + all += accumulator.value.substring(keyPrefix.length()); + accumulator.value = "cleared in mergeAccumulators"; + } + return new TestKeyedCombineFn.Accumulator(all); + } + + @Override + public String extractOutput(String key, TestKeyedCombineFn.Accumulator accumulator, Context c) { + assertThat(accumulator.value, Matchers.startsWith(key + c.sideInput(view).toString())); + char[] chars = accumulator.value.toCharArray(); + Arrays.sort(chars); + return new String(chars); + } + } + + /** Another example AccumulatingCombineFn. */ + public static class TestCounter extends + Combine.AccumulatingCombineFn< + Integer, TestCounter.Counter, Iterable> { + + /** An accumulator that observes its merges and outputs. 
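+ * Each Counter records its sum together with counts of inputs, merges, and outputs,
+ * and uses Preconditions checks to reject new inputs once a merge has occurred.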
*/ + public class Counter implements + Combine.AccumulatingCombineFn.Accumulator>, + Serializable { + + public long sum = 0; + public long inputs = 0; + public long merges = 0; + public long outputs = 0; + + public Counter(long sum, long inputs, long merges, long outputs) { + this.sum = sum; + this.inputs = inputs; + this.merges = merges; + this.outputs = outputs; + } + + @Override + public void addInput(Integer element) { + Preconditions.checkState(merges == 0); + Preconditions.checkState(outputs == 0); + + inputs++; + sum += element; + } + + @Override + public void mergeAccumulator(Counter accumulator) { + Preconditions.checkState(outputs == 0); + Preconditions.checkArgument(accumulator.outputs == 0); + + merges += accumulator.merges + 1; + inputs += accumulator.inputs; + sum += accumulator.sum; + } + + @Override + public Iterable extractOutput() { + Preconditions.checkState(outputs == 0); + + return Arrays.asList(sum, inputs, merges, outputs); + } + + @Override + public int hashCode() { + return (int) (sum * 17 + inputs * 31 + merges * 43 + outputs * 181); + } + + @Override + public boolean equals(Object otherObj) { + if (otherObj instanceof Counter) { + Counter other = (Counter) otherObj; + return (sum == other.sum + && inputs == other.inputs + && merges == other.merges + && outputs == other.outputs); + } + return false; + } + + @Override + public String toString() { + return sum + ":" + inputs + ":" + merges + ":" + outputs; + } + } + + @Override + public Counter createAccumulator() { + return new Counter(0, 0, 0, 0); + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) { + // This is a *very* inefficient encoding to send over the wire, but suffices + // for tests. + return SerializableCoder.of(Counter.class); + } + } + + private static PCollection copy(PCollection pc, final int n) { + return pc.apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + for (int i = 0; i < n; i++) { + c.output(c.element()); + } + } + })); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CountTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CountTest.java new file mode 100644 index 000000000000..0fe554787296 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CountTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for Count. + */ +@RunWith(JUnit4.class) +public class CountTest { + static final String[] WORDS_ARRAY = new String[] { + "hi", "there", "hi", "hi", "sue", "bob", + "hi", "sue", "", "", "ZOW", "bob", "" }; + + static final List WORDS = Arrays.asList(WORDS_ARRAY); + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings("unchecked") + public void testCountPerElementBasic() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(WORDS)); + + PCollection> output = + input.apply(Count.perElement()); + + DataflowAssert.that(output) + .containsInAnyOrder( + KV.of("hi", 4L), + KV.of("there", 1L), + KV.of("sue", 2L), + KV.of("bob", 2L), + KV.of("", 3L), + KV.of("ZOW", 1L)); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + @SuppressWarnings("unchecked") + public void testCountPerElementEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(NO_LINES).withCoder(StringUtf8Coder.of())); + + PCollection> output = + input.apply(Count.perElement()); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCountGloballyBasic() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(WORDS)); + + PCollection output = + input.apply(Count.globally()); + + DataflowAssert.that(output) + .containsInAnyOrder(13L); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCountGloballyEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(NO_LINES).withCoder(StringUtf8Coder.of())); + + PCollection output = + input.apply(Count.globally()); + + DataflowAssert.that(output) + .containsInAnyOrder(0L); + p.run(); + } + + @Test + public void testCountGetName() { + assertEquals("Count.PerElement", Count.perElement().getName()); + assertEquals("Count.Globally", Count.globally().getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CreateTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CreateTest.java new file mode 100644 index 000000000000..1af072112aba --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CreateTest.java @@ -0,0 +1,240 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES_ARRAY; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for Create. + */ +@RunWith(JUnit4.class) +@SuppressWarnings("unchecked") +public class CreateTest { + @Rule public final ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(RunnableOnService.class) + public void testCreate() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(LINES)); + + DataflowAssert.that(output) + .containsInAnyOrder(LINES_ARRAY); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(NO_LINES) + .withCoder(StringUtf8Coder.of())); + + DataflowAssert.that(output) + .containsInAnyOrder(NO_LINES_ARRAY); + p.run(); + } + + @Test + public void testCreateEmptyInfersCoder() { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of()); + + assertEquals(VoidCoder.of(), output.getCoder()); + } + + static class Record implements Serializable { + } + + static class Record2 extends Record { + } + + @Test + public void testPolymorphicType() throws Exception { + thrown.expect(RuntimeException.class); + thrown.expectMessage( + Matchers.containsString("Unable to infer a coder")); + + Pipeline p = TestPipeline.create(); + + // Create won't infer a default coder in this case. 
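+ // Mixing Record and Record2 instances defeats coder inference, so the apply below is
+ // expected to fail with "Unable to infer a coder". In non-test code this would
+ // typically be resolved by supplying a coder explicitly, for example
+ // Create.of(...).withCoder(SerializableCoder.of(Record.class)).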
+ p.apply(Create.of(new Record(), new Record2())); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateWithNullsAndValues() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection output = + p.apply(Create.of(null, "test1", null, "test2", null) + .withCoder(SerializableCoder.of(String.class))); + DataflowAssert.that(output) + .containsInAnyOrder(null, "test1", null, "test2", null); + p.run(); + } + + @Test + public void testCreateParameterizedType() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection> output = + p.apply(Create.of( + TimestampedValue.of("a", new Instant(0)), + TimestampedValue.of("b", new Instant(0)))); + + DataflowAssert.that(output) + .containsInAnyOrder( + TimestampedValue.of("a", new Instant(0)), + TimestampedValue.of("b", new Instant(0))); + } + private static class PrintTimestamps extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + ":" + c.timestamp().getMillis()); + } + } + + @Test + @Category(RunnableOnService.class) + public void testCreateTimestamped() { + Pipeline p = TestPipeline.create(); + + List> data = Arrays.asList( + TimestampedValue.of("a", new Instant(1L)), + TimestampedValue.of("b", new Instant(2L)), + TimestampedValue.of("c", new Instant(3L))); + + PCollection output = + p.apply(Create.timestamped(data)) + .apply(ParDo.of(new PrintTimestamps())); + + DataflowAssert.that(output) + .containsInAnyOrder("a:1", "b:2", "c:3"); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateTimestampedEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.timestamped(new ArrayList>()) + .withCoder(StringUtf8Coder.of())); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Test + public void testCreateTimestampedEmptyInfersCoder() { + Pipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.timestamped()); + + assertEquals(VoidCoder.of(), output.getCoder()); + } + + @Test + public void testCreateTimestampedPolymorphicType() throws Exception { + thrown.expect(RuntimeException.class); + thrown.expectMessage( + Matchers.containsString("Unable to infer a coder")); + + Pipeline p = TestPipeline.create(); + + // Create won't infer a default coder in this case. 
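+ // Same coder-inference failure as testPolymorphicType above, but through
+ // Create.timestamped; the trailing throw only executes if inference unexpectedly
+ // succeeds, in which case it fails the test and reports the inferred coder.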
+ PCollection c = p.apply(Create.timestamped( + TimestampedValue.of(new Record(), new Instant(0)), + TimestampedValue.of(new Record2(), new Instant(0)))); + + p.run(); + + throw new RuntimeException("Coder: " + c.getCoder()); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateWithVoidType() throws Exception { + Pipeline p = TestPipeline.create(); + PCollection output = p.apply(Create.of((Void) null, (Void) null)); + DataflowAssert.that(output).containsInAnyOrder((Void) null, (Void) null); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateWithKVVoidType() throws Exception { + Pipeline p = TestPipeline.create(); + + PCollection> output = p.apply(Create.of( + KV.of((Void) null, (Void) null), + KV.of((Void) null, (Void) null))); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of((Void) null, (Void) null), + KV.of((Void) null, (Void) null)); + + p.run(); + } + + @Test + public void testCreateGetName() { + assertEquals("Create.Values", Create.of(1, 2, 3).getName()); + assertEquals("Create.TimestampedValues", Create.timestamped(Collections.EMPTY_LIST).getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnContextTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnContextTest.java new file mode 100644 index 000000000000..c4716f9a392e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnContextTest.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link DoFn.Context}. 
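+ * Verifies that an aggregator created via DoFn.createAggregator is linked to the
+ * aggregator returned by the context's createAggregatorInternal once
+ * setupDelegateAggregators() has run.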
+ */ +@RunWith(JUnit4.class) +public class DoFnContextTest { + + @Mock + private Aggregator agg; + + private DoFn fn; + private DoFn.Context context; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + // Need to be real objects to call the constructor, and to reference the + // outer instance of DoFn + NoOpDoFn noOpFn = new NoOpDoFn<>(); + DoFn.Context noOpContext = noOpFn.context(); + + fn = spy(noOpFn); + context = spy(noOpContext); + } + + @Test + public void testSetupDelegateAggregatorsCreatesAndLinksDelegateAggregators() { + Sum.SumLongFn combiner = new Sum.SumLongFn(); + Aggregator delegateAggregator = + fn.createAggregator("test", combiner); + + when(context.createAggregatorInternal("test", combiner)).thenReturn(agg); + + context.setupDelegateAggregators(); + delegateAggregator.addValue(1L); + + verify(agg).addValue(1L); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnDelegatingAggregatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnDelegatingAggregatorTest.java new file mode 100644 index 000000000000..0b82c51f7cfd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnDelegatingAggregatorTest.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.DelegatingAggregator; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Tests for DoFn.DelegatingAggregator. 
+ */ +@RunWith(JUnit4.class) +public class DoFnDelegatingAggregatorTest { + + @Mock + private Aggregator delegate; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testAddValueWithoutDelegateThrowsException() { + DoFn doFn = new NoOpDoFn<>(); + + String name = "agg"; + CombineFn combiner = mockCombineFn(Double.class); + + DelegatingAggregator aggregator = + (DelegatingAggregator) doFn.createAggregator(name, combiner); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("cannot be called"); + thrown.expectMessage("DoFn"); + + aggregator.addValue(21.2); + } + + @Test + public void testSetDelegateThenAddValueCallsDelegate() { + String name = "agg"; + CombineFn combiner = mockCombineFn(Long.class); + + DoFn doFn = new NoOpDoFn<>(); + + DelegatingAggregator aggregator = + (DelegatingAggregator) doFn.createAggregator(name, combiner); + + aggregator.setDelegate(delegate); + + aggregator.addValue(12L); + + verify(delegate).addValue(12L); + } + + @Test + public void testSetDelegateWithExistingDelegateStartsDelegatingToSecond() { + String name = "agg"; + CombineFn combiner = mockCombineFn(Double.class); + + DoFn doFn = new NoOpDoFn<>(); + + DelegatingAggregator aggregator = + (DelegatingAggregator) doFn.createAggregator(name, combiner); + + @SuppressWarnings("unchecked") + Aggregator secondDelegate = + mock(Aggregator.class, "secondDelegate"); + + aggregator.setDelegate(aggregator); + aggregator.setDelegate(secondDelegate); + + aggregator.addValue(2.25); + + verify(secondDelegate).addValue(2.25); + verify(delegate, never()).addValue(anyLong()); + } + + @Test + public void testGetNameReturnsName() { + String name = "agg"; + CombineFn combiner = mockCombineFn(Double.class); + + DoFn doFn = new NoOpDoFn<>(); + + DelegatingAggregator aggregator = + (DelegatingAggregator) doFn.createAggregator(name, combiner); + + assertEquals(name, aggregator.getName()); + } + + @Test + public void testGetCombineFnReturnsCombineFn() { + String name = "agg"; + CombineFn combiner = mockCombineFn(Double.class); + + DoFn doFn = new NoOpDoFn<>(); + + DelegatingAggregator aggregator = + (DelegatingAggregator) doFn.createAggregator(name, combiner); + + assertEquals(combiner, aggregator.getCombineFn()); + } + + @SuppressWarnings("unchecked") + private static CombineFn mockCombineFn( + @SuppressWarnings("unused") Class clazz) { + return mock(CombineFn.class); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnReflectorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnReflectorTest.java new file mode 100644 index 000000000000..2a770c25aaea --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnReflectorTest.java @@ -0,0 +1,493 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.ExtraContextFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.ProcessContext; +import com.google.cloud.dataflow.sdk.transforms.DoFnWithContext.ProcessElement; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.lang.reflect.Method; + +/** + * Tests for {@link DoFnReflector}. + */ +@RunWith(JUnit4.class) +public class DoFnReflectorTest { + + private boolean wasProcessElementInvoked = false; + private boolean wasStartBundleInvoked = false; + private boolean wasFinishBundleInvoked = false; + + private DoFnWithContext fn; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock + private DoFnWithContext.ProcessContext mockContext; + @Mock + private BoundedWindow mockWindow; + @Mock + private WindowingInternals mockWindowingInternals; + + private ExtraContextFactory extraContextFactory; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + this.extraContextFactory = new ExtraContextFactory() { + @Override + public BoundedWindow window() { + return mockWindow; + } + + @Override + public WindowingInternals windowingInternals() { + return mockWindowingInternals; + } + }; + } + + private DoFnReflector underTest(DoFnWithContext fn) { + this.fn = fn; + return DoFnReflector.of(fn.getClass()); + } + + private void checkInvokeProcessElementWorks(DoFnReflector r) throws Exception { + assertFalse(wasProcessElementInvoked); + r.invokeProcessElement(fn, mockContext, extraContextFactory); + assertTrue(wasProcessElementInvoked); + } + + private void checkInvokeStartBundleWorks(DoFnReflector r) throws Exception { + assertFalse(wasStartBundleInvoked); + r.invokeStartBundle(fn, mockContext, extraContextFactory); + assertTrue(wasStartBundleInvoked); + } + + private void checkInvokeFinishBundleWorks(DoFnReflector r) throws Exception { + assertFalse(wasFinishBundleInvoked); + r.invokeFinishBundle(fn, mockContext, extraContextFactory); + assertTrue(wasFinishBundleInvoked); + } + + @Test + public void testDoFnWithNoExtraContext() throws Exception { + DoFnReflector reflector = underTest(new DoFnWithContext() { + + @ProcessElement + public void processElement(ProcessContext c) + throws Exception { + wasProcessElementInvoked = true; + assertSame(c, mockContext); + } + }); + + assertFalse(reflector.usesSingleWindow()); + + checkInvokeProcessElementWorks(reflector); + } + + interface InterfaceWithProcessElement { + @ProcessElement + void processElement(DoFnWithContext.ProcessContext c); + } + + interface LayersOfInterfaces extends InterfaceWithProcessElement {} + + private class IdentityUsingInterfaceWithProcessElement + extends DoFnWithContext + implements LayersOfInterfaces { + + @Override + public void processElement(DoFnWithContext.ProcessContext c) { + wasProcessElementInvoked = true; + assertSame(c, mockContext); + } + } + + @Test + public void 
testDoFnWithProcessElementInterface() throws Exception { + DoFnReflector reflector = underTest(new IdentityUsingInterfaceWithProcessElement()); + assertFalse(reflector.usesSingleWindow()); + checkInvokeProcessElementWorks(reflector); + } + + private class IdentityParent extends DoFnWithContext { + @ProcessElement + public void process(ProcessContext c) { + wasProcessElementInvoked = true; + assertSame(c, mockContext); + } + } + + private class IdentityChild extends IdentityParent {} + + @Test + public void testDoFnWithMethodInSuperclass() throws Exception { + DoFnReflector reflector = underTest(new IdentityChild()); + assertFalse(reflector.usesSingleWindow()); + checkInvokeProcessElementWorks(reflector); + } + + @Test + public void testDoFnWithWindow() throws Exception { + DoFnReflector reflector = underTest(new DoFnWithContext() { + + @ProcessElement + public void processElement(ProcessContext c, BoundedWindow w) + throws Exception { + wasProcessElementInvoked = true; + assertSame(c, mockContext); + assertSame(w, mockWindow); + } + }); + + assertTrue(reflector.usesSingleWindow()); + + checkInvokeProcessElementWorks(reflector); + } + + @Test + public void testDoFnWithWindowingInternals() throws Exception { + DoFnReflector reflector = underTest(new DoFnWithContext() { + + @ProcessElement + public void processElement(ProcessContext c, WindowingInternals w) + throws Exception { + wasProcessElementInvoked = true; + assertSame(c, mockContext); + assertSame(w, mockWindowingInternals); + } + }); + + assertFalse(reflector.usesSingleWindow()); + + checkInvokeProcessElementWorks(reflector); + } + + @Test + public void testDoFnWithStartBundle() throws Exception { + DoFnReflector reflector = underTest(new DoFnWithContext() { + @ProcessElement + public void processElement(@SuppressWarnings("unused") ProcessContext c) {} + + @StartBundle + public void startBundle(Context c) { + wasStartBundleInvoked = true; + assertSame(c, mockContext); + } + + @FinishBundle + public void finishBundle(Context c) { + wasFinishBundleInvoked = true; + assertSame(c, mockContext); + } + }); + + checkInvokeStartBundleWorks(reflector); + checkInvokeFinishBundleWorks(reflector); + } + + @Test + public void testNoProcessElement() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("No method annotated with @ProcessElement found"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() {}); + } + + @Test + public void testMultipleProcessElement() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Found multiple methods annotated with @ProcessElement"); + thrown.expectMessage("foo()"); + thrown.expectMessage("bar()"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() { + @ProcessElement + public void foo() {} + + @ProcessElement + public void bar() {} + }); + } + + @Test + public void testMultipleStartBundleElement() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Found multiple methods annotated with @StartBundle"); + thrown.expectMessage("bar()"); + thrown.expectMessage("baz()"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() { + @ProcessElement + public void foo() {} + + @StartBundle + public void bar() {} + + @StartBundle + public void baz() {} + }); + } + + @Test + public void testMultipleFinishBundleElement() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Found 
multiple methods annotated with @FinishBundle"); + thrown.expectMessage("bar()"); + thrown.expectMessage("baz()"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() { + @ProcessElement + public void foo() {} + + @FinishBundle + public void bar() {} + + @FinishBundle + public void baz() {} + }); + } + + @Test + public void testPrivateProcessElement() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("process() must be public"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() { + @ProcessElement + private void process() {} + }); + } + + @Test + public void testPrivateStartBundle() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("startBundle() must be public"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() { + @ProcessElement + public void processElement() {} + + @StartBundle + void startBundle() {} + }); + } + + @Test + public void testPrivateFinishBundle() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("finishBundle() must be public"); + thrown.expectMessage(getClass().getName() + "$"); + underTest(new DoFnWithContext() { + @ProcessElement + public void processElement() {} + + @FinishBundle + void finishBundle() {} + }); + } + + @SuppressWarnings({"unused", "rawtypes"}) + private void missingProcessContext() {} + + @Test + public void testMissingProcessContext() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage(getClass().getName() + + "#missingProcessContext() must take a ProcessContext as its first argument"); + + DoFnReflector.verifyProcessMethodArguments( + getClass().getDeclaredMethod("missingProcessContext")); + } + + @SuppressWarnings({"unused", "rawtypes"}) + private void badProcessContext(String s) {} + + @Test + public void testBadProcessContextType() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage(getClass().getName() + + "#badProcessContext(String) must take a ProcessContext as its first argument"); + + DoFnReflector.verifyProcessMethodArguments( + getClass().getDeclaredMethod("badProcessContext", String.class)); + } + + @SuppressWarnings({"unused", "rawtypes"}) + private void badExtraContext(DoFnWithContext.Context c, int n) {} + + @Test + public void testBadExtraContext() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "int is not a valid context parameter for method " + + getClass().getName() + "#badExtraContext(Context, int). Should be one of ["); + + DoFnReflector.verifyBundleMethodArguments( + getClass().getDeclaredMethod("badExtraContext", Context.class, int.class)); + } + + @SuppressWarnings({"unused", "rawtypes"}) + private void badExtraProcessContext( + DoFnWithContext.ProcessContext c, Integer n) {} + + @Test + public void testBadExtraProcessContextType() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "Integer is not a valid context parameter for method " + + getClass().getName() + "#badExtraProcessContext(ProcessContext, Integer)" + + ". 
Should be one of [BoundedWindow, WindowingInternals]"); + + DoFnReflector.verifyProcessMethodArguments( + getClass().getDeclaredMethod("badExtraProcessContext", + ProcessContext.class, Integer.class)); + } + + @SuppressWarnings("unused") + private int badReturnType() { + return 0; + } + + @Test + public void testBadReturnType() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage(getClass().getName() + "#badReturnType() must have a void return type"); + + DoFnReflector.verifyProcessMethodArguments(getClass().getDeclaredMethod("badReturnType")); + } + + @SuppressWarnings("unused") + private void goodGenerics(DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testValidGenerics() throws Exception { + Method method = getClass().getDeclaredMethod("goodGenerics", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + DoFnReflector.verifyProcessMethodArguments(method); + } + + @SuppressWarnings("unused") + private void goodWildcards(DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testGoodWildcards() throws Exception { + Method method = getClass().getDeclaredMethod("goodWildcards", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + DoFnReflector.verifyProcessMethodArguments(method); + } + + @SuppressWarnings("unused") + private void goodBoundedWildcards(DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testGoodBoundedWildcards() throws Exception { + Method method = getClass().getDeclaredMethod("goodBoundedWildcards", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + DoFnReflector.verifyProcessMethodArguments(method); + } + + @SuppressWarnings("unused") + private void goodTypeVariables( + DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testGoodTypeVariables() throws Exception { + Method method = getClass().getDeclaredMethod("goodTypeVariables", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + DoFnReflector.verifyProcessMethodArguments(method); + } + + @SuppressWarnings("unused") + private void badGenericTwoArgs(DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testBadGenericsTwoArgs() throws Exception { + Method method = getClass().getDeclaredMethod("badGenericTwoArgs", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Incompatible generics in context parameter " + + "WindowingInternals " + + "for method " + getClass().getName() + + "#badGenericTwoArgs(ProcessContext, WindowingInternals). Should be " + + "WindowingInternals"); + + DoFnReflector.verifyProcessMethodArguments(method); + } + + @SuppressWarnings("unused") + private void badGenericWildCards(DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testBadGenericWildCards() throws Exception { + Method method = getClass().getDeclaredMethod("badGenericWildCards", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Incompatible generics in context parameter " + + "WindowingInternals for method " + + getClass().getName() + + "#badGenericWildCards(ProcessContext, WindowingInternals). 
Should be " + + "WindowingInternals"); + + DoFnReflector.verifyProcessMethodArguments(method); + } + + @SuppressWarnings("unused") + private void badTypeVariables(DoFnWithContext.ProcessContext c, + WindowingInternals i1) {} + + @Test + public void testBadTypeVariables() throws Exception { + Method method = getClass().getDeclaredMethod("badTypeVariables", + DoFnWithContext.ProcessContext.class, WindowingInternals.class); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Incompatible generics in context parameter " + + "WindowingInternals for method " + getClass().getName() + + "#badTypeVariables(ProcessContext, WindowingInternals). Should be " + + "WindowingInternals"); + + DoFnReflector.verifyProcessMethodArguments(method); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java new file mode 100644 index 000000000000..f6df14a72b2d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.CoreMatchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for DoFn. 
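+ * Covers {@link DoFn#createAggregator} naming rules and the bundle lifecycle points at which creating an aggregator is rejected.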
+ */ +@RunWith(JUnit4.class) +public class DoFnTest implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + public void testCreateAggregatorWithCombinerSucceeds() { + String name = "testAggregator"; + Sum.SumLongFn combiner = new Sum.SumLongFn(); + + DoFn doFn = new NoOpDoFn<>(); + + Aggregator aggregator = doFn.createAggregator(name, combiner); + + assertEquals(name, aggregator.getName()); + assertEquals(combiner, aggregator.getCombineFn()); + } + + @Test + public void testCreateAggregatorWithNullNameThrowsException() { + thrown.expect(NullPointerException.class); + thrown.expectMessage("name cannot be null"); + + DoFn doFn = new NoOpDoFn<>(); + + doFn.createAggregator(null, new Sum.SumLongFn()); + } + + @Test + public void testCreateAggregatorWithNullCombineFnThrowsException() { + CombineFn combiner = null; + + thrown.expect(NullPointerException.class); + thrown.expectMessage("combiner cannot be null"); + + DoFn doFn = new NoOpDoFn<>(); + + doFn.createAggregator("testAggregator", combiner); + } + + @Test + public void testCreateAggregatorWithNullSerializableFnThrowsException() { + SerializableFunction, Object> combiner = null; + + thrown.expect(NullPointerException.class); + thrown.expectMessage("combiner cannot be null"); + + DoFn doFn = new NoOpDoFn<>(); + + doFn.createAggregator("testAggregator", combiner); + } + + @Test + public void testCreateAggregatorWithSameNameThrowsException() { + String name = "testAggregator"; + CombineFn combiner = new Max.MaxDoubleFn(); + + DoFn doFn = new NoOpDoFn<>(); + + doFn.createAggregator(name, combiner); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Cannot create"); + thrown.expectMessage(name); + thrown.expectMessage("already exists"); + + doFn.createAggregator(name, combiner); + } + + @Test + public void testCreateAggregatorsWithDifferentNamesSucceeds() { + String nameOne = "testAggregator"; + String nameTwo = "aggregatorPrime"; + CombineFn combiner = new Max.MaxDoubleFn(); + + DoFn doFn = new NoOpDoFn<>(); + + Aggregator aggregatorOne = + doFn.createAggregator(nameOne, combiner); + Aggregator aggregatorTwo = + doFn.createAggregator(nameTwo, combiner); + + assertNotEquals(aggregatorOne, aggregatorTwo); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateAggregatorInStartBundleThrows() { + TestPipeline p = createTestPipeline(new DoFn() { + @Override + public void startBundle(DoFn.Context c) throws Exception { + createAggregator("anyAggregate", new MaxIntegerFn()); + } + + @Override + public void processElement(DoFn.ProcessContext c) throws Exception {} + }); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateAggregatorInProcessElementThrows() { + TestPipeline p = createTestPipeline(new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + createAggregator("anyAggregate", new MaxIntegerFn()); + } + }); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateAggregatorInFinishBundleThrows() { + TestPipeline p = createTestPipeline(new DoFn() { + @Override + public void finishBundle(DoFn.Context c) throws Exception { + createAggregator("anyAggregate", new MaxIntegerFn()); + } + + @Override + public void 
processElement(DoFn.ProcessContext c) throws Exception {} + }); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + + p.run(); + } + + /** + * Initialize a test pipeline with the specified @{link DoFn}. + */ + private TestPipeline createTestPipeline(DoFn fn) { + TestPipeline pipeline = TestPipeline.create(); + pipeline.apply(Create.of((InputT) null)) + .apply(ParDo.of(fn)); + + return pipeline; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTesterTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTesterTest.java new file mode 100644 index 000000000000..89588104913d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTesterTest.java @@ -0,0 +1,253 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.DoFnTester.OutputElementWithTimestamp; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** + * Tests for {@link DoFnTester}. + */ +@RunWith(JUnit4.class) +public class DoFnTesterTest { + + @Test + public void processElement() { + CounterDoFn counterDoFn = new CounterDoFn(); + DoFnTester tester = DoFnTester.of(counterDoFn); + + tester.processElement(1L); + + List take = tester.takeOutputElements(); + + assertThat(take, hasItems("1")); + + // Following takeOutputElements(), neither takeOutputElements() + // nor peekOutputElements() return anything. + assertTrue(tester.takeOutputElements().isEmpty()); + assertTrue(tester.peekOutputElements().isEmpty()); + + // processElement() caused startBundle() to be called, but finishBundle() was never called. + CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn; + assertTrue(deserializedDoFn.wasStartBundleCalled()); + assertFalse(deserializedDoFn.wasFinishBundleCalled()); + } + + @Test + public void processElementsWithPeeks() { + CounterDoFn counterDoFn = new CounterDoFn(); + DoFnTester tester = DoFnTester.of(counterDoFn); + + // Explicitly call startBundle(). + tester.startBundle(); + + // verify startBundle() was called but not finishBundle(). + CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn; + assertTrue(deserializedDoFn.wasStartBundleCalled()); + assertFalse(deserializedDoFn.wasFinishBundleCalled()); + + // process a couple of elements. + tester.processElement(1L); + tester.processElement(2L); + + // peek the first 2 outputs. + List peek = tester.peekOutputElements(); + assertThat(peek, hasItems("1", "2")); + + // process a couple more. 
+ tester.processElement(3L); + tester.processElement(4L); + + // peek all the outputs so far. + peek = tester.peekOutputElements(); + assertThat(peek, hasItems("1", "2", "3", "4")); + // take the outputs. + List take = tester.takeOutputElements(); + assertThat(take, hasItems("1", "2", "3", "4")); + + // Following takeOutputElements(), neither takeOutputElements() + // nor peekOutputElements() return anything. + assertTrue(tester.peekOutputElements().isEmpty()); + assertTrue(tester.takeOutputElements().isEmpty()); + + // verify finishBundle() hasn't been called yet. + assertTrue(deserializedDoFn.wasStartBundleCalled()); + assertFalse(deserializedDoFn.wasFinishBundleCalled()); + + // process a couple more. + tester.processElement(5L); + tester.processElement(6L); + + // peek and take now have only the 2 last outputs. + peek = tester.peekOutputElements(); + assertThat(peek, hasItems("5", "6")); + take = tester.takeOutputElements(); + assertThat(take, hasItems("5", "6")); + + tester.finishBundle(); + + // verify finishBundle() was called. + assertTrue(deserializedDoFn.wasStartBundleCalled()); + assertTrue(deserializedDoFn.wasFinishBundleCalled()); + } + + @Test + public void processBatch() { + CounterDoFn counterDoFn = new CounterDoFn(); + DoFnTester tester = DoFnTester.of(counterDoFn); + + // processBatch() returns all the output like takeOutputElements(). + List take = tester.processBatch(1L, 2L, 3L, 4L); + + assertThat(take, hasItems("1", "2", "3", "4")); + + // peek now returns nothing. + assertTrue(tester.peekOutputElements().isEmpty()); + + // verify startBundle() and finishBundle() were both called. + CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn; + assertTrue(deserializedDoFn.wasStartBundleCalled()); + assertTrue(deserializedDoFn.wasFinishBundleCalled()); + } + + @Test + public void processElementWithTimestamp() { + CounterDoFn counterDoFn = new CounterDoFn(); + DoFnTester tester = DoFnTester.of(counterDoFn); + + tester.processElement(1L); + tester.processElement(2L); + + List> peek = tester.peekOutputElementsWithTimestamp(); + OutputElementWithTimestamp one = + new OutputElementWithTimestamp<>("1", new Instant(1000L)); + OutputElementWithTimestamp two = + new OutputElementWithTimestamp<>("2", new Instant(2000L)); + assertThat(peek, hasItems(one, two)); + + tester.processElement(3L); + tester.processElement(4L); + + OutputElementWithTimestamp three = + new OutputElementWithTimestamp<>("3", new Instant(3000L)); + OutputElementWithTimestamp four = + new OutputElementWithTimestamp<>("4", new Instant(4000L)); + peek = tester.peekOutputElementsWithTimestamp(); + assertThat(peek, hasItems(one, two, three, four)); + List> take = tester.takeOutputElementsWithTimestamp(); + assertThat(take, hasItems(one, two, three, four)); + + // Following takeOutputElementsWithTimestamp(), neither takeOutputElementsWithTimestamp() + // nor peekOutputElementsWithTimestamp() return anything. + assertTrue(tester.takeOutputElementsWithTimestamp().isEmpty()); + assertTrue(tester.peekOutputElementsWithTimestamp().isEmpty()); + + // peekOutputElements() and takeOutputElements() also return nothing. 
+ assertTrue(tester.peekOutputElements().isEmpty()); + assertTrue(tester.takeOutputElements().isEmpty()); + } + + @Test + public void getAggregatorValuesShouldGetValueOfCounter() { + CounterDoFn counterDoFn = new CounterDoFn(); + DoFnTester tester = DoFnTester.of(counterDoFn); + tester.processBatch(1L, 2L, 4L, 8L); + + Long aggregatorVal = tester.getAggregatorValue(counterDoFn.agg); + + assertThat(aggregatorVal, equalTo(15L)); + } + + @Test + public void getAggregatorValuesWithEmptyCounterShouldSucceed() { + CounterDoFn counterDoFn = new CounterDoFn(); + DoFnTester tester = DoFnTester.of(counterDoFn); + tester.processBatch(); + Long aggregatorVal = tester.getAggregatorValue(counterDoFn.agg); + // empty bundle + assertThat(aggregatorVal, equalTo(0L)); + } + + @Test + public void getAggregatorValuesInStartFinishBundleShouldGetValues() { + CounterDoFn fn = new CounterDoFn(1L, 2L); + DoFnTester tester = DoFnTester.of(fn); + tester.processBatch(0L, 0L); + + Long aggValue = tester.getAggregatorValue(fn.agg); + assertThat(aggValue, equalTo(1L + 2L)); + } + + /** + * A DoFn that adds values to an aggregator and converts input to String in processElement. + */ + private static class CounterDoFn extends DoFn { + Aggregator agg = createAggregator("ctr", new Sum.SumLongFn()); + private final long startBundleVal; + private final long finishBundleVal; + private boolean startBundleCalled; + private boolean finishBundleCalled; + + public CounterDoFn() { + this(0L, 0L); + } + + public CounterDoFn(long start, long finish) { + this.startBundleVal = start; + this.finishBundleVal = finish; + } + + @Override + public void startBundle(Context c) { + agg.addValue(startBundleVal); + startBundleCalled = true; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + agg.addValue(c.element()); + Instant instant = new Instant(1000L * c.element()); + c.outputWithTimestamp(c.element().toString(), instant); + } + + @Override + public void finishBundle(Context c) { + agg.addValue(finishBundleVal); + finishBundleCalled = true; + } + + boolean wasStartBundleCalled() { + return startBundleCalled; + } + + boolean wasFinishBundleCalled() { + return finishBundleCalled; + } + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnWithContextTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnWithContextTest.java new file mode 100644 index 000000000000..228b96142194 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnWithContextTest.java @@ -0,0 +1,225 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.CoreMatchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** Tests for {@link DoFnWithContext}. */ +@RunWith(JUnit4.class) +public class DoFnWithContextTest implements Serializable { + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + private class NoOpDoFnWithContext extends DoFnWithContext { + + /** + * @param c context + */ + @ProcessElement + public void processElement(ProcessContext c) { + } + } + + @Test + public void testCreateAggregatorWithCombinerSucceeds() { + String name = "testAggregator"; + Sum.SumLongFn combiner = new Sum.SumLongFn(); + + DoFnWithContext doFn = new NoOpDoFnWithContext(); + + Aggregator aggregator = doFn.createAggregator(name, combiner); + + assertEquals(name, aggregator.getName()); + assertEquals(combiner, aggregator.getCombineFn()); + } + + @Test + public void testCreateAggregatorWithNullNameThrowsException() { + thrown.expect(NullPointerException.class); + thrown.expectMessage("name cannot be null"); + + DoFnWithContext doFn = new NoOpDoFnWithContext(); + + doFn.createAggregator(null, new Sum.SumLongFn()); + } + + @Test + public void testCreateAggregatorWithNullCombineFnThrowsException() { + CombineFn combiner = null; + + thrown.expect(NullPointerException.class); + thrown.expectMessage("combiner cannot be null"); + + DoFnWithContext doFn = new NoOpDoFnWithContext(); + + doFn.createAggregator("testAggregator", combiner); + } + + @Test + public void testCreateAggregatorWithNullSerializableFnThrowsException() { + SerializableFunction, Object> combiner = null; + + thrown.expect(NullPointerException.class); + thrown.expectMessage("combiner cannot be null"); + + DoFnWithContext doFn = new NoOpDoFnWithContext(); + + doFn.createAggregator("testAggregator", combiner); + } + + @Test + public void testCreateAggregatorWithSameNameThrowsException() { + String name = "testAggregator"; + CombineFn combiner = new Max.MaxDoubleFn(); + + DoFnWithContext doFn = new NoOpDoFnWithContext(); + + doFn.createAggregator(name, combiner); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Cannot create"); + thrown.expectMessage(name); + thrown.expectMessage("already exists"); + + doFn.createAggregator(name, combiner); + } + + @Test + public void testCreateAggregatorsWithDifferentNamesSucceeds() { + String nameOne = "testAggregator"; + String nameTwo = "aggregatorPrime"; + CombineFn combiner = new Max.MaxDoubleFn(); + + DoFnWithContext doFn = new NoOpDoFnWithContext(); + + Aggregator aggregatorOne = + doFn.createAggregator(nameOne, combiner); + Aggregator aggregatorTwo = + doFn.createAggregator(nameTwo, combiner); + + assertNotEquals(aggregatorOne, 
aggregatorTwo); + } + + @Test + public void testDoFnWithContextUsingAggregators() { + NoOpDoFn noOpFn = new NoOpDoFn<>(); + DoFn.Context context = noOpFn.context(); + + DoFn fn = spy(noOpFn); + context = spy(context); + + @SuppressWarnings("unchecked") + Aggregator agg = mock(Aggregator.class); + + Sum.SumLongFn combiner = new Sum.SumLongFn(); + Aggregator delegateAggregator = + fn.createAggregator("test", combiner); + + when(context.createAggregatorInternal("test", combiner)).thenReturn(agg); + + context.setupDelegateAggregators(); + delegateAggregator.addValue(1L); + + verify(agg).addValue(1L); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateAggregatorInStartBundleThrows() { + TestPipeline p = createTestPipeline(new DoFnWithContext() { + @StartBundle + public void startBundle(Context c) { + createAggregator("anyAggregate", new MaxIntegerFn()); + } + + @ProcessElement + public void processElement(ProcessContext c) {} + }); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateAggregatorInProcessElementThrows() { + TestPipeline p = createTestPipeline(new DoFnWithContext() { + @ProcessElement + public void processElement(ProcessContext c) { + createAggregator("anyAggregate", new MaxIntegerFn()); + } + }); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCreateAggregatorInFinishBundleThrows() { + TestPipeline p = createTestPipeline(new DoFnWithContext() { + @FinishBundle + public void finishBundle(Context c) { + createAggregator("anyAggregate", new MaxIntegerFn()); + } + + @ProcessElement + public void processElement(ProcessContext c) {} + }); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + + p.run(); + } + + /** + * Initialize a test pipeline with the specified @{link DoFn}. + */ + private TestPipeline createTestPipeline(DoFnWithContext fn) { + TestPipeline pipeline = TestPipeline.create(); + pipeline.apply(Create.of((InputT) null)) + .apply(ParDo.of(fn)); + + return pipeline; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FilterTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FilterTest.java new file mode 100644 index 000000000000..41a4ff2fbea4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FilterTest.java @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for {@link Filter}. + */ +@RunWith(JUnit4.class) +public class FilterTest implements Serializable { + + static class TrivialFn implements SerializableFunction { + private final Boolean returnVal; + + TrivialFn(Boolean returnVal) { + this.returnVal = returnVal; + } + + @Override + public Boolean apply(Integer elem) { + return this.returnVal; + } + } + + static class EvenFn implements SerializableFunction { + @Override + public Boolean apply(Integer elem) { + return elem % 2 == 0; + } + } + + @Deprecated + @Test + @Category(RunnableOnService.class) + public void testIdentityFilterBy() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(591, 11789, 1257, 24578, 24799, 307)) + .apply(Filter.by(new TrivialFn(true))); + + DataflowAssert.that(output).containsInAnyOrder(591, 11789, 1257, 24578, 24799, 307); + p.run(); + } + + @Deprecated + @Test + public void testNoFilter() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(1, 2, 4, 5)) + .apply(Filter.by(new TrivialFn(false))); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Deprecated + @Test + @Category(RunnableOnService.class) + public void testFilterBy() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.by(new EvenFn())); + + DataflowAssert.that(output).containsInAnyOrder(2, 4, 6); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testIdentityFilterByPredicate() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(591, 11789, 1257, 24578, 24799, 307)) + .apply(Filter.byPredicate(new TrivialFn(true))); + + DataflowAssert.that(output).containsInAnyOrder(591, 11789, 1257, 24578, 24799, 307); + p.run(); + } + + @Test + public void testNoFilterByPredicate() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(1, 2, 4, 5)) + .apply(Filter.byPredicate(new TrivialFn(false))); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFilterByPredicate() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.byPredicate(new EvenFn())); + + DataflowAssert.that(output).containsInAnyOrder(2, 4, 6); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFilterLessThan() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.lessThan(4)); + + DataflowAssert.that(output).containsInAnyOrder(1, 2, 3); + p.run(); + } + + @Test + public void testFilterGreaterThan() { + TestPipeline p = TestPipeline.create(); + + PCollection output = p + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.greaterThan(4)); + + DataflowAssert.that(output).containsInAnyOrder(5, 6, 7); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlatMapElementsTest.java 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlatMapElementsTest.java new file mode 100644 index 000000000000..8938b0398fad --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlatMapElementsTest.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +/** + * Tests for {@link FlatMapElements}. + */ +@RunWith(JUnit4.class) +public class FlatMapElementsTest implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + /** + * Basic test of {@link FlatMapElements} with a {@link SimpleFunction}. + */ + @Test + public void testFlatMapBasic() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + + // Note that FlatMapElements takes a SimpleFunction> + // so the use of List here (as opposed to Iterable) deliberately exercises + // the use of an upper bound. + .apply(FlatMapElements.via(new SimpleFunction>() { + @Override + public List apply(Integer input) { + return ImmutableList.of(-input, input); + } + })); + + DataflowAssert.that(output).containsInAnyOrder(1, -2, -1, -3, 2, 3); + pipeline.run(); + } + + /** + * Tests that when built with a concrete subclass of {@link SimpleFunction}, the type descriptor + * of the output reflects its static type. 
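+ * (The inferred {@link TypeDescriptor} is also used to look up the output's default Coder, as asserted below.)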
+ */ + @Test + public void testFlatMapFnOutputTypeDescriptor() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of("hello")) + .apply(FlatMapElements.via(new SimpleFunction>() { + @Override + public Set apply(String input) { + return ImmutableSet.copyOf(input.split("")); + } + })); + + assertThat(output.getTypeDescriptor(), + equalTo((TypeDescriptor) new TypeDescriptor() {})); + assertThat(pipeline.getCoderRegistry().getDefaultCoder(output.getTypeDescriptor()), + equalTo(pipeline.getCoderRegistry().getDefaultCoder(new TypeDescriptor() {}))); + + // Make sure the pipeline runs + pipeline.run(); + } + + @Test + public void testVoidValues() throws Exception { + Pipeline pipeline = TestPipeline.create(); + pipeline + .apply(Create.of("hello")) + .apply(WithKeys.of("k")) + .apply(new VoidValues() {}); + // Make sure the pipeline runs + pipeline.run(); + } + + static class VoidValues + extends PTransform>, PCollection>> { + + @Override + public PCollection> apply(PCollection> input) { + return input.apply(FlatMapElements., KV>via( + new SimpleFunction, Iterable>>() { + @Override + public Iterable> apply(KV input) { + return Collections.singletonList(KV.of(input.getKey(), null)); + } + })); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlattenTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlattenTest.java new file mode 100644 index 000000000000..0c9d3315db68 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/FlattenTest.java @@ -0,0 +1,369 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES2; +import static com.google.cloud.dataflow.sdk.TestUtils.LINES_ARRAY; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES_ARRAY; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CollectionCoder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.SetCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.collect.ImmutableSet; + +import org.joda.time.Duration; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Set; + +/** + * Tests for Flatten. 
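+ * Exercises {@link Flatten#pCollections} and {@link Flatten#iterables}, including window function propagation and coder inference for empty inputs.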
+ */ +@RunWith(JUnit4.class) +public class FlattenTest implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + private static class ClassWithoutCoder { } + + + @Test + @Category(RunnableOnService.class) + public void testFlattenPCollectionList() { + Pipeline p = TestPipeline.create(); + + List> inputs = Arrays.asList( + LINES, NO_LINES, LINES2, NO_LINES, LINES, NO_LINES); + + PCollection output = + makePCollectionListOfStrings(p, inputs) + .apply(Flatten.pCollections()); + + DataflowAssert.that(output).containsInAnyOrder(flattenLists(inputs)); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenPCollectionListThenParDo() { + Pipeline p = TestPipeline.create(); + + List> inputs = Arrays.asList( + LINES, NO_LINES, LINES2, NO_LINES, LINES, NO_LINES); + + PCollection output = + makePCollectionListOfStrings(p, inputs) + .apply(Flatten.pCollections()) + .apply(ParDo.of(new IdentityFn(){})); + + DataflowAssert.that(output).containsInAnyOrder(flattenLists(inputs)); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenPCollectionListEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection output = + PCollectionList.empty(p) + .apply(Flatten.pCollections()).setCoder(StringUtf8Coder.of()); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testEmptyFlattenAsSideInput() { + Pipeline p = TestPipeline.create(); + + final PCollectionView> view = + PCollectionList.empty(p) + .apply(Flatten.pCollections()).setCoder(StringUtf8Coder.of()) + .apply(View.asIterable()); + + PCollection output = p + .apply(Create.of((Void) null).withCoder(VoidCoder.of())) + .apply(ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (String side : c.sideInput(view)) { + c.output(side); + } + } + })); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenPCollectionListEmptyThenParDo() { + + Pipeline p = TestPipeline.create(); + + PCollection output = + PCollectionList.empty(p) + .apply(Flatten.pCollections()).setCoder(StringUtf8Coder.of()) + .apply(ParDo.of(new IdentityFn(){})); + + DataflowAssert.that(output).empty(); + p.run(); + } + + @Test + public void testFlattenNoListsNoCoder() { + // not RunnableOnService because it should fail at pipeline construction time anyhow. 
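+ // With no input PCollections there is no element coder to infer for the flattened output, hence the expected failure below.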
+ thrown.expect(IllegalStateException.class); + thrown.expectMessage("cannot provide a Coder for empty"); + + Pipeline p = TestPipeline.create(); + + PCollectionList.empty(p) + .apply(Flatten.pCollections()); + + p.run(); + } + + ///////////////////////////////////////////////////////////////////////////// + + @Test + @Category(RunnableOnService.class) + public void testFlattenIterables() { + Pipeline p = TestPipeline.create(); + + PCollection> input = p + .apply(Create.>of(LINES) + .withCoder(IterableCoder.of(StringUtf8Coder.of()))); + + PCollection output = + input.apply(Flatten.iterables()); + + DataflowAssert.that(output) + .containsInAnyOrder(LINES_ARRAY); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenIterablesLists() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.>of(LINES).withCoder(ListCoder.of(StringUtf8Coder.of()))); + + PCollection output = input.apply(Flatten.iterables()); + + DataflowAssert.that(output).containsInAnyOrder(LINES_ARRAY); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenIterablesSets() { + Pipeline p = TestPipeline.create(); + + Set linesSet = ImmutableSet.copyOf(LINES); + + PCollection> input = + p.apply(Create.>of(linesSet).withCoder(SetCoder.of(StringUtf8Coder.of()))); + + PCollection output = input.apply(Flatten.iterables()); + + DataflowAssert.that(output).containsInAnyOrder(LINES_ARRAY); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenIterablesCollections() { + + Pipeline p = TestPipeline.create(); + + Set linesSet = ImmutableSet.copyOf(LINES); + + PCollection> input = + p.apply(Create.>of(linesSet) + .withCoder(CollectionCoder.of(StringUtf8Coder.of()))); + + PCollection output = input.apply(Flatten.iterables()); + + DataflowAssert.that(output).containsInAnyOrder(LINES_ARRAY); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFlattenIterablesEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = p + .apply(Create.>of(NO_LINES) + .withCoder(IterableCoder.of(StringUtf8Coder.of()))); + + PCollection output = + input.apply(Flatten.iterables()); + + DataflowAssert.that(output) + .containsInAnyOrder(NO_LINES_ARRAY); + + p.run(); + } + + ///////////////////////////////////////////////////////////////////////////// + + @Test + public void testEqualWindowFnPropagation() { + Pipeline p = TestPipeline.create(); + + PCollection input1 = + p.apply("CreateInput1", Create.of("Input1")) + .apply("Window1", Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + PCollection input2 = + p.apply("CreateInput2", Create.of("Input2")) + .apply("Window2", Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + + PCollection output = + PCollectionList.of(input1).and(input2) + .apply(Flatten.pCollections()); + + p.run(); + + Assert.assertTrue(output.getWindowingStrategy().getWindowFn().isCompatible( + FixedWindows.of(Duration.standardMinutes(1)))); + } + + @Test + public void testCompatibleWindowFnPropagation() { + Pipeline p = TestPipeline.create(); + + PCollection input1 = + p.apply("CreateInput1", Create.of("Input1")) + .apply("Window1", + Window.into(Sessions.withGapDuration(Duration.standardMinutes(1)))); + PCollection input2 = + p.apply("CreateInput2", Create.of("Input2")) + .apply("Window2", + Window.into(Sessions.withGapDuration(Duration.standardMinutes(2)))); + + PCollection output = + PCollectionList.of(input1).and(input2) + .apply(Flatten.pCollections()); 
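+ // Sessions window functions with different gap durations are still compatible, so this flatten is accepted.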
+ + p.run(); + + Assert.assertTrue(output.getWindowingStrategy().getWindowFn().isCompatible( + Sessions.withGapDuration(Duration.standardMinutes(2)))); + } + + @Test + public void testIncompatibleWindowFnPropagationFailure() { + Pipeline p = TestPipeline.create(); + + PCollection input1 = + p.apply("CreateInput1", Create.of("Input1")) + .apply("Window1", Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + PCollection input2 = + p.apply("CreateInput2", Create.of("Input2")) + .apply("Window2", Window.into(FixedWindows.of(Duration.standardMinutes(2)))); + + try { + PCollectionList.of(input1).and(input2) + .apply(Flatten.pCollections()); + Assert.fail("Exception should have been thrown"); + } catch (IllegalStateException e) { + Assert.assertTrue(e.getMessage().startsWith( + "Inputs to Flatten had incompatible window windowFns")); + } + } + + @Test + public void testFlattenGetName() { + Assert.assertEquals("Flatten.FlattenIterables", Flatten.iterables().getName()); + Assert.assertEquals("Flatten.FlattenPCollectionList", Flatten.pCollections().getName()); + } + + ///////////////////////////////////////////////////////////////////////////// + + private static class IdentityFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + c.output(c.element()); + } + } + + private PCollectionList makePCollectionListOfStrings( + Pipeline p, + List> lists) { + return makePCollectionList(p, StringUtf8Coder.of(), lists); + } + + private PCollectionList makePCollectionList( + Pipeline p, + Coder coder, + List> lists) { + List> pcs = new ArrayList<>(); + int index = 0; + for (List list : lists) { + PCollection pc = p.apply("Create" + (index++), Create.of(list).withCoder(coder)); + pcs.add(pc); + } + return PCollectionList.of(pcs); + } + + private List flattenLists(List> lists) { + List flattened = new ArrayList<>(); + for (List list : lists) { + flattened.addAll(list); + } + return flattened; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/GroupByKeyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/GroupByKeyTest.java new file mode 100644 index 000000000000..75eb92fcb436 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/GroupByKeyTest.java @@ -0,0 +1,438 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.KvMatcher.isKv; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.MapCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.InvalidWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.NoopPathValidator; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * Tests for GroupByKey. 
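+ * Covers grouping with and without windows, key coder determinism checks, window function propagation, and output timestamp behavior.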
+ */ +@RunWith(JUnit4.class) +@SuppressWarnings({"rawtypes", "unchecked"}) +public class GroupByKeyTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(RunnableOnService.class) + public void testGroupByKey() { + List> ungroupedPairs = Arrays.asList( + KV.of("k1", 3), + KV.of("k5", Integer.MAX_VALUE), + KV.of("k5", Integer.MIN_VALUE), + KV.of("k2", 66), + KV.of("k1", 4), + KV.of("k2", -33), + KV.of("k3", 0)); + + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection>> output = + input.apply(GroupByKey.create()); + + DataflowAssert.that(output) + .satisfies(new AssertThatHasExpectedContentsForTestGroupByKey()); + + p.run(); + } + + static class AssertThatHasExpectedContentsForTestGroupByKey + implements SerializableFunction>>, + Void> { + @Override + public Void apply(Iterable>> actual) { + assertThat(actual, containsInAnyOrder( + isKv(is("k1"), containsInAnyOrder(3, 4)), + isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, + Integer.MIN_VALUE)), + isKv(is("k2"), containsInAnyOrder(66, -33)), + isKv(is("k3"), containsInAnyOrder(0)))); + return null; + } + } + + @Test + @Category(RunnableOnService.class) + public void testGroupByKeyAndWindows() { + List> ungroupedPairs = Arrays.asList( + KV.of("k1", 3), // window [0, 5) + KV.of("k5", Integer.MAX_VALUE), // window [0, 5) + KV.of("k5", Integer.MIN_VALUE), // window [0, 5) + KV.of("k2", 66), // window [0, 5) + KV.of("k1", 4), // window [5, 10) + KV.of("k2", -33), // window [5, 10) + KV.of("k3", 0)); // window [5, 10) + + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.timestamped(ungroupedPairs, Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L)) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + PCollection>> output = + input.apply(Window.>into(FixedWindows.of(new Duration(5)))) + .apply(GroupByKey.create()); + + DataflowAssert.that(output) + .satisfies(new AssertThatHasExpectedContentsForTestGroupByKeyAndWindows()); + + p.run(); + } + + static class AssertThatHasExpectedContentsForTestGroupByKeyAndWindows + implements SerializableFunction>>, + Void> { + @Override + public Void apply(Iterable>> actual) { + assertThat(actual, containsInAnyOrder( + isKv(is("k1"), containsInAnyOrder(3)), + isKv(is("k1"), containsInAnyOrder(4)), + isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, + Integer.MIN_VALUE)), + isKv(is("k2"), containsInAnyOrder(66)), + isKv(is("k2"), containsInAnyOrder(-33)), + isKv(is("k3"), containsInAnyOrder(0)))); + return null; + } + } + + @Test + @Category(RunnableOnService.class) + public void testGroupByKeyEmpty() { + List> ungroupedPairs = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection>> output = + input.apply(GroupByKey.create()); + + DataflowAssert.that(output).empty(); + + p.run(); + } + + @Test + public void testGroupByKeyNonDeterministic() throws Exception { + + List, Integer>> ungroupedPairs = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection, Integer>> input = + p.apply(Create.of(ungroupedPairs) + .withCoder( + KvCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), + BigEndianIntegerCoder.of()))); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("must be deterministic"); 
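+ // The key coder here is a MapCoder, which is not deterministic, so GroupByKey must reject it.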
+ input.apply(GroupByKey., Integer>create()); + } + + @Test + public void testIdentityWindowFnPropagation() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(FixedWindows.of(Duration.standardMinutes(1)))); + + PCollection>> output = + input.apply(GroupByKey.create()); + + p.run(); + + Assert.assertTrue(output.getWindowingStrategy().getWindowFn().isCompatible( + FixedWindows.of(Duration.standardMinutes(1)))); + } + + @Test + public void testWindowFnInvalidation() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + + PCollection>> output = + input.apply(GroupByKey.create()); + + p.run(); + + Assert.assertTrue( + output.getWindowingStrategy().getWindowFn().isCompatible( + new InvalidWindows( + "Invalid", + Sessions.withGapDuration( + Duration.standardMinutes(1))))); + } + + /** + * Create a test pipeline that uses the {@link DataflowPipelineRunner} so that {@link GroupByKey} + * is not expanded. This is used for verifying that even without expansion the proper errors show + * up. + */ + private Pipeline createTestServiceRunner() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("someproject"); + options.setStagingLocation("gs://staging"); + options.setPathValidatorClass(NoopPathValidator.class); + options.setDataflowClient(null); + return Pipeline.create(options); + } + + private Pipeline createTestDirectRunner() { + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setRunner(DirectPipelineRunner.class); + return Pipeline.create(options); + } + + @Test + public void testInvalidWindowsDirect() { + Pipeline p = createTestDirectRunner(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("GroupByKey must have a valid Window merge function"); + input + .apply("GroupByKey", GroupByKey.create()) + .apply("GroupByKeyAgain", GroupByKey.>create()); + } + + @Test + public void testInvalidWindowsService() { + Pipeline p = createTestServiceRunner(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("GroupByKey must have a valid Window merge function"); + input + .apply("GroupByKey", GroupByKey.create()) + .apply("GroupByKeyAgain", GroupByKey.>create()); + } + + @Test + public void testRemerge() { + Pipeline p = TestPipeline.create(); + + List> ungroupedPairs = Arrays.asList(); + + PCollection> input = + p.apply(Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into( + 
Sessions.withGapDuration(Duration.standardMinutes(1)))); + + PCollection>>> middle = input + .apply("GroupByKey", GroupByKey.create()) + .apply("Remerge", Window.>>remerge()) + .apply("GroupByKeyAgain", GroupByKey.>create()) + .apply("RemergeAgain", Window.>>>remerge()); + + p.run(); + + Assert.assertTrue( + middle.getWindowingStrategy().getWindowFn().isCompatible( + Sessions.withGapDuration(Duration.standardMinutes(1)))); + } + + @Test + public void testGroupByKeyDirectUnbounded() { + Pipeline p = createTestDirectRunner(); + + PCollection> input = + p.apply( + new PTransform>>() { + @Override + public PCollection> apply(PBegin input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + PCollection.IsBounded.UNBOUNDED) + .setTypeDescriptorInternal(new TypeDescriptor>() {}); + } + }); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without " + + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey."); + + input.apply("GroupByKey", GroupByKey.create()); + } + + @Test + public void testGroupByKeyServiceUnbounded() { + Pipeline p = createTestServiceRunner(); + + PCollection> input = + p.apply( + new PTransform>>() { + @Override + public PCollection> apply(PBegin input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + PCollection.IsBounded.UNBOUNDED) + .setTypeDescriptorInternal(new TypeDescriptor>() {}); + } + }); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without " + + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey."); + + input.apply("GroupByKey", GroupByKey.create()); + } + + /** + * Tests that when two elements are combined via a GroupByKey their output timestamp agrees + * with the windowing function customized to actually be the same as the default, the earlier of + * the two values. + */ + @Test + @Category(RunnableOnService.class) + public void testOutputTimeFnEarliest() { + Pipeline pipeline = TestPipeline.create(); + + pipeline.apply( + Create.timestamped( + TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), + TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) + .apply(Window.>into(FixedWindows.of(Duration.standardMinutes(10))) + .withOutputTimeFn(OutputTimeFns.outputAtEarliestInputTimestamp())) + .apply(GroupByKey.create()) + .apply(ParDo.of(new AssertTimestamp(new Instant(0)))); + + pipeline.run(); + } + + + /** + * Tests that when two elements are combined via a GroupByKey their output timestamp agrees + * with the windowing function customized to use the latest value. 
+ */ + @Test + @Category(RunnableOnService.class) + public void testOutputTimeFnLatest() { + Pipeline pipeline = TestPipeline.create(); + + pipeline.apply( + Create.timestamped( + TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), + TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) + .apply(Window.>into(FixedWindows.of(Duration.standardMinutes(10))) + .withOutputTimeFn(OutputTimeFns.outputAtLatestInputTimestamp())) + .apply(GroupByKey.create()) + .apply(ParDo.of(new AssertTimestamp(new Instant(10)))); + + pipeline.run(); + } + + private static class AssertTimestamp extends DoFn, Void> { + private final Instant timestamp; + + public AssertTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + assertThat(c.timestamp(), equalTo(timestamp)); + } + } + + @Test + public void testGroupByKeyGetName() { + Assert.assertEquals("GroupByKey", GroupByKey.create().getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/IntraBundleParallelizationTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/IntraBundleParallelizationTest.java new file mode 100644 index 000000000000..76dd67e097df --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/IntraBundleParallelizationTest.java @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.testing.SystemNanoTimeSleeper.sleepMillis; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Tests for RateLimiter. 
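+ * Specifically, covers the {@link IntraBundleParallelization} transform exercised below.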
+ */ +@RunWith(JUnit4.class) +public class IntraBundleParallelizationTest { + private static final int PARALLELISM_FACTOR = 16; + private static final AtomicInteger numSuccesses = new AtomicInteger(); + private static final AtomicInteger numProcessed = new AtomicInteger(); + private static final AtomicInteger numFailures = new AtomicInteger(); + private static int concurrentElements = 0; + private static int maxDownstreamConcurrency = 0; + + private static final AtomicInteger maxFnConcurrency = new AtomicInteger(); + private static final AtomicInteger currentFnConcurrency = new AtomicInteger(); + + @Before + public void setUp() { + numSuccesses.set(0); + numProcessed.set(0); + numFailures.set(0); + concurrentElements = 0; + maxDownstreamConcurrency = 0; + + maxFnConcurrency.set(0); + currentFnConcurrency.set(0); + } + + /** + * Introduces a delay in processing, then passes thru elements. + */ + private static class DelayFn extends DoFn { + public static final long DELAY_MS = 25; + + @Override + public void processElement(ProcessContext c) { + startConcurrentCall(); + try { + sleepMillis(DELAY_MS); + } catch (InterruptedException e) { + e.printStackTrace(); + throw new RuntimeException("Interrupted"); + } + c.output(c.element()); + finishConcurrentCall(); + } + } + + /** + * Throws an exception after some number of calls. + */ + private static class ExceptionThrowingFn extends DoFn { + private ExceptionThrowingFn(int numSuccesses) { + IntraBundleParallelizationTest.numSuccesses.set(numSuccesses); + } + + @Override + public void processElement(ProcessContext c) { + startConcurrentCall(); + try { + numProcessed.incrementAndGet(); + if (numSuccesses.decrementAndGet() >= 0) { + c.output(c.element()); + return; + } + + numFailures.incrementAndGet(); + throw new RuntimeException("Expected failure"); + } finally { + finishConcurrentCall(); + } + } + } + + /** + * Measures concurrency of the processElement method. + */ + private static class ConcurrencyMeasuringFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + // Synchronize on the class to provide synchronous access irrespective of + // how this DoFn is called. + synchronized (ConcurrencyMeasuringFn.class) { + concurrentElements++; + if (concurrentElements > maxDownstreamConcurrency) { + maxDownstreamConcurrency = concurrentElements; + } + } + + c.output(c.element()); + + synchronized (ConcurrencyMeasuringFn.class) { + concurrentElements--; + } + } + } + + private static void startConcurrentCall() { + int currentlyExecuting = currentFnConcurrency.incrementAndGet(); + int maxConcurrency; + do { + maxConcurrency = maxFnConcurrency.get(); + } while (maxConcurrency < currentlyExecuting + && !maxFnConcurrency.compareAndSet(maxConcurrency, currentlyExecuting)); + } + + private static void finishConcurrentCall() { + currentFnConcurrency.decrementAndGet(); + } + + /** + * Test that the DoFn is parallelized up the the Max Parallelism factor within a bundle, but not + * greater than that amount. + */ + @Test + public void testParallelization() { + int maxConcurrency = Integer.MIN_VALUE; + // Take the minimum from multiple runs. + for (int i = 0; i < 5; ++i) { + maxConcurrency = Math.max(maxConcurrency, + run(2 * PARALLELISM_FACTOR, PARALLELISM_FACTOR, new DelayFn())); + } + + // We should run at least some elements in parallel on some run + assertThat(maxConcurrency, + greaterThanOrEqualTo(2)); + // No run should execute more elements concurrency than the maximum concurrency allowed. 
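+    // (IntraBundleParallelization.withMaxParallelism bounds how many elements may be processed
+    // concurrently within a bundle, so the observed maximum should never exceed that bound.)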
+ assertThat(maxConcurrency, + lessThanOrEqualTo(PARALLELISM_FACTOR)); + } + + @Test(timeout = 5000L) + public void testExceptionHandling() { + ExceptionThrowingFn fn = new ExceptionThrowingFn<>(10); + try { + run(100, PARALLELISM_FACTOR, fn); + fail("Expected exception to propagate"); + } catch (RuntimeException e) { + assertThat(e.getMessage(), containsString("Expected failure")); + } + + // Should have processed 10 elements, but stopped before processing all + // of them. + assertThat(numProcessed.get(), + is(both(greaterThanOrEqualTo(10)) + .and(lessThan(100)))); + + // The first failure should prevent the scheduling of any more elements. + assertThat(numFailures.get(), + is(both(greaterThanOrEqualTo(1)) + .and(lessThanOrEqualTo(PARALLELISM_FACTOR)))); + } + + @Test(timeout = 5000L) + public void testExceptionHandlingOnLastElement() { + ExceptionThrowingFn fn = new ExceptionThrowingFn<>(9); + try { + run(10, PARALLELISM_FACTOR, fn); + fail("Expected exception to propagate"); + } catch (RuntimeException e) { + assertThat(e.getMessage(), containsString("Expected failure")); + } + + // Should have processed 10 elements, but stopped before processing all + // of them. + assertEquals(10, numProcessed.get()); + assertEquals(1, numFailures.get()); + } + + @Test + public void testIntraBundleParallelizationGetName() { + assertEquals( + "IntraBundleParallelization", + IntraBundleParallelization.of(new DelayFn()).withMaxParallelism(1).getName()); + } + + /** + * Runs the provided doFn inside of an {@link IntraBundleParallelization} transform. + * + *

    This method assumes that the DoFn passed to it will call {@link #startConcurrentCall()} + * before processing each elements and {@link #finishConcurrentCall()} after each element. + * + * @param numElements the size of the input + * @param maxParallelism how many threads to execute in parallel + * @param doFn the DoFn to execute + * @return the maximum observed parallelism of the DoFn + */ + private int run(int numElements, int maxParallelism, DoFn doFn) { + Pipeline pipeline = TestPipeline.create(); + + ArrayList data = new ArrayList<>(numElements); + for (int i = 0; i < numElements; ++i) { + data.add(i); + } + + ConcurrencyMeasuringFn downstream = new ConcurrencyMeasuringFn<>(); + pipeline + .apply(Create.of(data)) + .apply(IntraBundleParallelization.of(doFn).withMaxParallelism(maxParallelism)) + .apply(ParDo.of(downstream)); + + pipeline.run(); + + // All elements should have completed. + assertEquals(0, currentFnConcurrency.get()); + // Downstream methods should not see parallel threads. + assertEquals(1, maxDownstreamConcurrency); + + return maxFnConcurrency.get(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KeysTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KeysTest.java new file mode 100644 index 000000000000..e9edbb71c8b2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KeysTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for Keys transform. 
+ */ +@RunWith(JUnit4.class) +public class KeysTest { + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] TABLE = new KV[] { + KV.of("one", 1), + KV.of("two", 2), + KV.of("three", 3), + KV.of("dup", 4), + KV.of("dup", 5) + }; + + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] EMPTY_TABLE = new KV[] { + }; + + @Test + @Category(RunnableOnService.class) + public void testKeys() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection output = input.apply(Keys.create()); + DataflowAssert.that(output) + .containsInAnyOrder("one", "two", "three", "dup", "dup"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testKeysEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(EMPTY_TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection output = input.apply(Keys.create()); + DataflowAssert.that(output).empty(); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KvSwapTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KvSwapTest.java new file mode 100644 index 000000000000..06abbadd1b62 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/KvSwapTest.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for KvSwap transform. 
+ */ +@RunWith(JUnit4.class) +@SuppressWarnings({"rawtypes", "unchecked"}) +public class KvSwapTest { + static final KV[] TABLE = new KV[] { + KV.of("one", 1), + KV.of("two", 2), + KV.of("three", 3), + KV.of("four", 4), + KV.of("dup", 4), + KV.of("dup", 5) + }; + + static final KV[] EMPTY_TABLE = new KV[] { + }; + + @Test + @Category(RunnableOnService.class) + public void testKvSwap() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection> output = input.apply( + KvSwap.create()); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of(1, "one"), + KV.of(2, "two"), + KV.of(3, "three"), + KV.of(4, "four"), + KV.of(4, "dup"), + KV.of(5, "dup")); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testKvSwapEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(EMPTY_TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection> output = input.apply( + KvSwap.create()); + + DataflowAssert.that(output).empty(); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MapElementsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MapElementsTest.java new file mode 100644 index 000000000000..be3e720646dc --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MapElementsTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for {@link MapElements}. + */ +@RunWith(JUnit4.class) +public class MapElementsTest implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + /** + * Basic test of {@link MapElements} with a {@link SimpleFunction}. 
+ */ + @Test + public void testMapBasic() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(MapElements.via(new SimpleFunction() { + @Override + public Integer apply(Integer input) { + return -input; + } + })); + + DataflowAssert.that(output).containsInAnyOrder(-2, -1, -3); + pipeline.run(); + } + + /** + * Basic test of {@link MapElements} with a {@link SerializableFunction}. This style is + * generally discouraged in Java 7, in favor of {@link SimpleFunction}. + */ + @Test + public void testMapBasicSerializableFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(MapElements.via(new SerializableFunction() { + @Override + public Integer apply(Integer input) { + return -input; + } + }).withOutputType(new TypeDescriptor() {})); + + DataflowAssert.that(output).containsInAnyOrder(-2, -1, -3); + pipeline.run(); + } + + /** + * Tests that when built with a concrete subclass of {@link SimpleFunction}, the type descriptor + * of the output reflects its static type. + */ + @Test + public void testSimpleFunctionOutputTypeDescriptor() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of("hello")) + .apply(MapElements.via(new SimpleFunction() { + @Override + public String apply(String input) { + return input; + } + })); + assertThat(output.getTypeDescriptor(), + equalTo((TypeDescriptor) new TypeDescriptor() {})); + assertThat(pipeline.getCoderRegistry().getDefaultCoder(output.getTypeDescriptor()), + equalTo(pipeline.getCoderRegistry().getDefaultCoder(new TypeDescriptor() {}))); + + // Make sure the pipelien runs too + pipeline.run(); + } + + @Test + public void testVoidValues() throws Exception { + Pipeline pipeline = TestPipeline.create(); + pipeline + .apply(Create.of("hello")) + .apply(WithKeys.of("k")) + .apply(new VoidValues() {}); + // Make sure the pipeline runs + pipeline.run(); + } + + static class VoidValues + extends PTransform>, PCollection>> { + + @Override + public PCollection> apply(PCollection> input) { + return input.apply(MapElements., KV>via( + new SimpleFunction, KV>() { + @Override + public KV apply(KV input) { + return KV.of(input.getKey(), null); + } + })); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MaxTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MaxTest.java new file mode 100644 index 000000000000..e1ea33bcf28f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MaxTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for Max. + */ +@RunWith(JUnit4.class) +public class MaxTest { + @Test + public void testMeanGetNames() { + assertEquals("Max.Globally", Max.integersGlobally().getName()); + assertEquals("Max.Globally", Max.doublesGlobally().getName()); + assertEquals("Max.Globally", Max.longsGlobally().getName()); + assertEquals("Max.PerKey", Max.integersPerKey().getName()); + assertEquals("Max.PerKey", Max.doublesPerKey().getName()); + assertEquals("Max.PerKey", Max.longsPerKey().getName()); + } + + @Test + public void testMaxIntegerFn() { + checkCombineFn( + new Max.MaxIntegerFn(), + Lists.newArrayList(1, 2, 3, 4), + 4); + } + + @Test + public void testMaxLongFn() { + checkCombineFn( + new Max.MaxLongFn(), + Lists.newArrayList(1L, 2L, 3L, 4L), + 4L); + } + + @Test + public void testMaxDoubleFn() { + checkCombineFn( + new Max.MaxDoubleFn(), + Lists.newArrayList(1.0, 2.0, 3.0, 4.0), + 4.0); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MeanTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MeanTest.java new file mode 100644 index 000000000000..64a2f9de94bd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MeanTest.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.transforms.Mean.CountSum; +import com.google.cloud.dataflow.sdk.transforms.Mean.CountSumCoder; +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for Mean. 
+ */ +@RunWith(JUnit4.class) +public class MeanTest { + @Test + public void testMeanGetNames() { + assertEquals("Mean.Globally", Mean.globally().getName()); + assertEquals("Mean.PerKey", Mean.perKey().getName()); + } + + private static final Coder> TEST_CODER = new CountSumCoder<>(); + + private static final List> TEST_VALUES = Arrays.asList( + new CountSum<>(1, 5.7), + new CountSum<>(42, 42.0), + new CountSum<>(29, 2.2)); + + @Test + public void testCountSumCoderEncodeDecode() throws Exception { + for (CountSum value : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, value); + } + } + + @Test + public void testCountSumCoderSerializable() throws Exception { + CoderProperties.coderSerializable(TEST_CODER); + } + + @Test + public void testMeanFn() throws Exception { + checkCombineFn( + new Mean.MeanFn(), + Lists.newArrayList(1, 2, 3, 4), + 2.5); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MinTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MinTest.java new file mode 100644 index 000000000000..a291537adf17 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/MinTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for Min. + */ +@RunWith(JUnit4.class) +public class MinTest { + @Test + public void testMeanGetNames() { + assertEquals("Min.Globally", Min.integersGlobally().getName()); + assertEquals("Min.Globally", Min.doublesGlobally().getName()); + assertEquals("Min.Globally", Min.longsGlobally().getName()); + assertEquals("Min.PerKey", Min.integersPerKey().getName()); + assertEquals("Min.PerKey", Min.doublesPerKey().getName()); + assertEquals("Min.PerKey", Min.longsPerKey().getName()); + } + + @Test + public void testMinIntegerFn() { + checkCombineFn( + new Min.MinIntegerFn(), + Lists.newArrayList(1, 2, 3, 4), + 1); + } + + @Test + public void testMinLongFn() { + checkCombineFn( + new Min.MinLongFn(), + Lists.newArrayList(1L, 2L, 3L, 4L), + 1L); + } + + @Test + public void testMinDoubleFn() { + checkCombineFn( + new Min.MinDoubleFn(), + Lists.newArrayList(1.0, 2.0, 3.0, 4.0), + 1.0); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/NoOpDoFn.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/NoOpDoFn.java new file mode 100644 index 000000000000..20646cfb8e38 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/NoOpDoFn.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.joda.time.Instant; + +/** + * A {@link DoFn} that does nothing with provided elements. Used for testing + * methods provided by the DoFn abstract class. + * + * @param unused. + * @param unused. + */ +class NoOpDoFn extends DoFn { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + } + + /** + * Returns a new NoOp Context. + */ + public DoFn.Context context() { + return new NoOpDoFnContext(); + } + + /** + * Returns a new NoOp Process Context. + */ + public DoFn.ProcessContext processContext() { + return new NoOpDoFnProcessContext(); + } + + /** + * A {@link DoFn.Context} that does nothing and returns exclusively null. + */ + private class NoOpDoFnContext extends DoFn.Context { + @Override + public PipelineOptions getPipelineOptions() { + return null; + } + @Override + public void output(OutputT output) { + } + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + } + @Override + public void sideOutput(TupleTag tag, T output) { + } + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, + Instant timestamp) { + } + @Override + protected Aggregator + createAggregatorInternal(String name, CombineFn combiner) { + return null; + } + } + + /** + * A {@link DoFn.ProcessContext} that does nothing and returns exclusively + * null. 
+ */ + private class NoOpDoFnProcessContext extends DoFn.ProcessContext { + @Override + public InputT element() { + return null; + } + + @Override + public T sideInput(PCollectionView view) { + return null; + } + + @Override + public Instant timestamp() { + return null; + } + + @Override + public BoundedWindow window() { + return null; + } + + @Override + public PaneInfo pane() { + return null; + } + + @Override + public WindowingInternals windowingInternals() { + return null; + } + + @Override + public PipelineOptions getPipelineOptions() { + return null; + } + + @Override + public void output(OutputT output) {} + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) {} + + @Override + public void sideOutput(TupleTag tag, T output) {} + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, + Instant timestamp) {} + + @Override + protected Aggregator + createAggregatorInternal(String name, CombineFn combiner) { + return null; + } + + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java new file mode 100644 index 000000000000..f3f9bde92d6d --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java @@ -0,0 +1,1518 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; +import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; +import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.hamcrest.core.AnyOf.anyOf; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.IllegalMutationException; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.base.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests for ParDo. + */ +@RunWith(JUnit4.class) +public class ParDoTest implements Serializable { + // This test is Serializable, just so that it's easy to have + // anonymous inner classes inside the non-static test methods. 
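+  // The rule is transient because this test class is Serializable (so anonymous DoFns can
+  // capture it); ExpectedException itself is not serializable and need not be serialized.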
+ + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + private static class PrintingDoFn extends DoFn implements RequiresWindowAccess { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + ":" + c.timestamp().getMillis() + + ":" + c.window().maxTimestamp().getMillis()); + } + } + + static class TestDoFn extends DoFn { + enum State { UNSTARTED, STARTED, PROCESSING, FINISHED } + State state = State.UNSTARTED; + + final List> sideInputViews = new ArrayList<>(); + final List> sideOutputTupleTags = new ArrayList<>(); + + public TestDoFn() { + } + + public TestDoFn(List> sideInputViews, + List> sideOutputTupleTags) { + this.sideInputViews.addAll(sideInputViews); + this.sideOutputTupleTags.addAll(sideOutputTupleTags); + } + + @Override + public void startBundle(Context c) { + assertEquals(State.UNSTARTED, state); + state = State.STARTED; + outputToAll(c, "started"); + } + + @Override + public void processElement(ProcessContext c) { + assertThat(state, + anyOf(equalTo(State.STARTED), equalTo(State.PROCESSING))); + state = State.PROCESSING; + outputToAllWithSideInputs(c, "processing: " + c.element()); + } + + @Override + public void finishBundle(Context c) { + assertThat(state, + anyOf(equalTo(State.STARTED), equalTo(State.PROCESSING))); + state = State.FINISHED; + outputToAll(c, "finished"); + } + + private void outputToAll(Context c, String value) { + c.output(value); + for (TupleTag sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); + } + } + + private void outputToAllWithSideInputs(ProcessContext c, String value) { + if (!sideInputViews.isEmpty()) { + List sideInputValues = new ArrayList<>(); + for (PCollectionView sideInputView : sideInputViews) { + sideInputValues.add(c.sideInput(sideInputView)); + } + value += ": " + sideInputValues; + } + c.output(value); + for (TupleTag sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); + } + } + } + + static class TestNoOutputDoFn extends DoFn { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception {} + } + + static class TestDoFnWithContext extends DoFnWithContext { + enum State { UNSTARTED, STARTED, PROCESSING, FINISHED } + State state = State.UNSTARTED; + + final List> sideInputViews = new ArrayList<>(); + final List> sideOutputTupleTags = new ArrayList<>(); + + public TestDoFnWithContext() { + } + + public TestDoFnWithContext(List> sideInputViews, + List> sideOutputTupleTags) { + this.sideInputViews.addAll(sideInputViews); + this.sideOutputTupleTags.addAll(sideOutputTupleTags); + } + + @StartBundle + public void startBundle(Context c) { + assertEquals(State.UNSTARTED, state); + state = State.STARTED; + outputToAll(c, "started"); + } + + @ProcessElement + public void processElement(ProcessContext c) { + assertThat(state, + anyOf(equalTo(State.STARTED), equalTo(State.PROCESSING))); + state = State.PROCESSING; + outputToAllWithSideInputs(c, "processing: " + c.element()); + } + + @FinishBundle + public void finishBundle(Context c) { + assertThat(state, + anyOf(equalTo(State.STARTED), equalTo(State.PROCESSING))); + state = State.FINISHED; + outputToAll(c, "finished"); + } + + private void outputToAll(Context c, String value) { + c.output(value); + for (TupleTag sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); + } + } + + private void 
outputToAllWithSideInputs(ProcessContext c, String value) { + if (!sideInputViews.isEmpty()) { + List sideInputValues = new ArrayList<>(); + for (PCollectionView sideInputView : sideInputViews) { + sideInputValues.add(c.sideInput(sideInputView)); + } + value += ": " + sideInputValues; + } + c.output(value); + for (TupleTag sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); + } + } + } + + static class TestStartBatchErrorDoFn extends DoFn { + @Override + public void startBundle(Context c) { + throw new RuntimeException("test error in initialize"); + } + + @Override + public void processElement(ProcessContext c) { + // This has to be here. + } + } + + static class TestProcessElementErrorDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + throw new RuntimeException("test error in process"); + } + } + + static class TestFinishBatchErrorDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + // This has to be here. + } + + @Override + public void finishBundle(Context c) { + throw new RuntimeException("test error in finalize"); + } + } + + private static class StrangelyNamedDoer extends DoFn { + @Override + public void processElement(ProcessContext c) { + } + } + + static class TestOutputTimestampDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Integer value = c.element(); + c.outputWithTimestamp(value, new Instant(value.longValue())); + } + } + + static class TestShiftTimestampDoFn extends DoFn { + private Duration allowedTimestampSkew; + private Duration durationToShift; + + public TestShiftTimestampDoFn(Duration allowedTimestampSkew, + Duration durationToShift) { + this.allowedTimestampSkew = allowedTimestampSkew; + this.durationToShift = durationToShift; + } + + @Override + public Duration getAllowedTimestampSkew() { + return allowedTimestampSkew; + } + @Override + public void processElement(ProcessContext c) { + Instant timestamp = c.timestamp(); + Preconditions.checkNotNull(timestamp); + Integer value = c.element(); + c.outputWithTimestamp(value, timestamp.plus(durationToShift)); + } + } + + static class TestFormatTimestampDoFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + Preconditions.checkNotNull(c.timestamp()); + c.output("processing: " + c.element() + ", timestamp: " + c.timestamp().getMillis()); + } + } + + static class MultiFilter + extends PTransform, PCollectionTuple> { + + private static final TupleTag BY2 = new TupleTag("by2"){}; + private static final TupleTag BY3 = new TupleTag("by3"){}; + + @Override + public PCollectionTuple apply(PCollection input) { + PCollection by2 = input.apply("Filter2s", ParDo.of(new FilterFn(2))); + PCollection by3 = input.apply("Filter3s", ParDo.of(new FilterFn(3))); + return PCollectionTuple.of(BY2, by2).and(BY3, by3); + } + + static class FilterFn extends DoFn { + private final int divisor; + + FilterFn(int divisor) { + this.divisor = divisor; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + if (c.element() % divisor == 0) { + c.output(c.element()); + } + } + } + } + + @Test + @Category(RunnableOnService.class) + public void testParDo() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection output = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.of(new TestDoFn())); + + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + 
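+    // DataflowAssert only registers the check here; it is actually evaluated when the pipeline runs.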
+ pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDo2() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection output = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.of(new TestDoFnWithContext())); + + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoEmpty() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(); + + PCollection output = pipeline + .apply(Create.of(inputs).withCoder(VarIntCoder.of())) + .apply("TestDoFn", ParDo.of(new TestDoFn())); + + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoEmptyOutputs() { + + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(); + + PCollection output = pipeline + .apply(Create.of(inputs).withCoder(VarIntCoder.of())) + .apply("TestDoFn", ParDo.of(new TestNoOutputDoFn())); + + DataflowAssert.that(output).empty(); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoWithSideOutputs() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + TupleTag mainOutputTag = new TupleTag("main"){}; + TupleTag sideOutputTag1 = new TupleTag("side1"){}; + TupleTag sideOutputTag2 = new TupleTag("side2"){}; + TupleTag sideOutputTag3 = new TupleTag("side3"){}; + TupleTag sideOutputTagUnwritten = new TupleTag("sideUnwritten"){}; + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo + .of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideOutputTag1, sideOutputTag2, sideOutputTag3))) + .withOutputTags( + mainOutputTag, + TupleTagList.of(sideOutputTag3) + .and(sideOutputTag1) + .and(sideOutputTagUnwritten) + .and(sideOutputTag2))); + + DataflowAssert.that(outputs.get(mainOutputTag)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + + DataflowAssert.that(outputs.get(sideOutputTag1)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideOutputTag1)); + DataflowAssert.that(outputs.get(sideOutputTag2)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideOutputTag2)); + DataflowAssert.that(outputs.get(sideOutputTag3)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideOutputTag3)); + DataflowAssert.that(outputs.get(sideOutputTagUnwritten)).empty(); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoEmptyWithSideOutputs() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(); + + TupleTag mainOutputTag = new TupleTag("main"){}; + TupleTag sideOutputTag1 = new TupleTag("side1"){}; + TupleTag sideOutputTag2 = new TupleTag("side2"){}; + TupleTag sideOutputTag3 = new TupleTag("side3"){}; + TupleTag sideOutputTagUnwritten = new TupleTag("sideUnwritten"){}; + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo + .of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideOutputTag1, sideOutputTag2, sideOutputTag3))) + .withOutputTags( + mainOutputTag, + TupleTagList.of(sideOutputTag3).and(sideOutputTag1) + .and(sideOutputTagUnwritten).and(sideOutputTag2))); + + DataflowAssert.that(outputs.get(mainOutputTag)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + + 
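+    // TestDoFn prefixes every side output element with its tag id, which the matcher expects below.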
DataflowAssert.that(outputs.get(sideOutputTag1)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideOutputTag1)); + DataflowAssert.that(outputs.get(sideOutputTag2)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideOutputTag2)); + DataflowAssert.that(outputs.get(sideOutputTag3)) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs) + .fromSideOutput(sideOutputTag3)); + DataflowAssert.that(outputs.get(sideOutputTagUnwritten)).empty(); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoWithEmptySideOutputs() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(); + + TupleTag mainOutputTag = new TupleTag("main"){}; + TupleTag sideOutputTag1 = new TupleTag("side1"){}; + TupleTag sideOutputTag2 = new TupleTag("side2"){}; + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo + .of(new TestNoOutputDoFn()) + .withOutputTags( + mainOutputTag, + TupleTagList.of(sideOutputTag1).and(sideOutputTag2))); + + DataflowAssert.that(outputs.get(mainOutputTag)).empty(); + + DataflowAssert.that(outputs.get(sideOutputTag1)).empty(); + DataflowAssert.that(outputs.get(sideOutputTag2)).empty(); + + pipeline.run(); + } + + + @Test + @Category(RunnableOnService.class) + public void testParDoWithOnlySideOutputs() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + final TupleTag mainOutputTag = new TupleTag("main"){}; + final TupleTag sideOutputTag = new TupleTag("side"){}; + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) + .of(new DoFn(){ + @Override + public void processElement(ProcessContext c) { + c.sideOutput(sideOutputTag, c.element()); + }})); + + DataflowAssert.that(outputs.get(mainOutputTag)).empty(); + DataflowAssert.that(outputs.get(sideOutputTag)).containsInAnyOrder(inputs); + + pipeline.run(); + } + + @Test + public void testParDoWritingToUndeclaredSideOutput() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + TupleTag sideTag = new TupleTag("side"){}; + + PCollection output = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideTag)))); + + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + + pipeline.run(); + } + + @Test + public void testParDoUndeclaredSideOutputLimit() { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline.apply(Create.of(Arrays.asList(3))); + + // Success for a total of 1000 outputs. + input + .apply("Success1000", ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + TupleTag specialSideTag = new TupleTag(){}; + c.sideOutput(specialSideTag, "side"); + c.sideOutput(specialSideTag, "side"); + c.sideOutput(specialSideTag, "side"); + + for (int i = 0; i < 998; i++) { + c.sideOutput(new TupleTag(){}, "side"); + } + }})); + pipeline.run(); + + // Failure for a total of 1001 outputs. 
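+    // Unlike the 1000-output case above, this one should exceed the undeclared side output limit
+    // and make the pipeline run fail with the message asserted below.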
+ input + .apply("Failure1001", ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (int i = 0; i < 1000; i++) { + c.sideOutput(new TupleTag(){}, "side"); + } + }})); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("the number of side outputs has exceeded a limit"); + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoWithSideInputs() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollectionView sideInput1 = pipeline + .apply("CreateSideInput1", Create.of(11)) + .apply("ViewSideInput1", View.asSingleton()); + PCollectionView sideInputUnread = pipeline + .apply("CreateSideInputUnread", Create.of(-3333)) + .apply("ViewSideInputUnread", View.asSingleton()); + PCollectionView sideInput2 = pipeline + .apply("CreateSideInput2", Create.of(222)) + .apply("ViewSideInput2", View.asSingleton()); + + PCollection output = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.withSideInputs(sideInput1, sideInputUnread, sideInput2) + .of(new TestDoFn( + Arrays.asList(sideInput1, sideInput2), + Arrays.>asList()))); + + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput + .forInput(inputs) + .andSideInputs(11, 222)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoWithSideInputsIsCumulative() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollectionView sideInput1 = pipeline + .apply("CreateSideInput1", Create.of(11)) + .apply("ViewSideInput1", View.asSingleton()); + PCollectionView sideInputUnread = pipeline + .apply("CreateSideInputUnread", Create.of(-3333)) + .apply("ViewSideInputUnread", View.asSingleton()); + PCollectionView sideInput2 = pipeline + .apply("CreateSideInput2", Create.of(222)) + .apply("ViewSideInput2", View.asSingleton()); + + PCollection output = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.withSideInputs(sideInput1) + .withSideInputs(sideInputUnread) + .withSideInputs(sideInput2) + .of(new TestDoFn( + Arrays.asList(sideInput1, sideInput2), + Arrays.>asList()))); + + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput + .forInput(inputs) + .andSideInputs(11, 222)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMultiOutputParDoWithSideInputs() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + final TupleTag mainOutputTag = new TupleTag("main"){}; + final TupleTag sideOutputTag = new TupleTag("sideOutput"){}; + + PCollectionView sideInput1 = pipeline + .apply("CreateSideInput1", Create.of(11)) + .apply("ViewSideInput1", View.asSingleton()); + PCollectionView sideInputUnread = pipeline + .apply("CreateSideInputUnread", Create.of(-3333)) + .apply("ViewSideInputUnread", View.asSingleton()); + PCollectionView sideInput2 = pipeline + .apply("CreateSideInput2", Create.of(222)) + .apply("ViewSideInput2", View.asSingleton()); + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.withSideInputs(sideInput1) + .withSideInputs(sideInputUnread) + .withSideInputs(sideInput2) + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) + .of(new TestDoFn( + Arrays.asList(sideInput1, sideInput2), + Arrays.>asList()))); + + DataflowAssert.that(outputs.get(mainOutputTag)) + .satisfies(ParDoTest.HasExpectedOutput + .forInput(inputs) + .andSideInputs(11, 222)); + + pipeline.run(); + } + + @Test + 
@Category(RunnableOnService.class) + public void testMultiOutputParDoWithSideInputsIsCumulative() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + final TupleTag mainOutputTag = new TupleTag("main"){}; + final TupleTag sideOutputTag = new TupleTag("sideOutput"){}; + + PCollectionView sideInput1 = pipeline + .apply("CreateSideInput1", Create.of(11)) + .apply("ViewSideInput1", View.asSingleton()); + PCollectionView sideInputUnread = pipeline + .apply("CreateSideInputUnread", Create.of(-3333)) + .apply("ViewSideInputUnread", View.asSingleton()); + PCollectionView sideInput2 = pipeline + .apply("CreateSideInput2", Create.of(222)) + .apply("ViewSideInput2", View.asSingleton()); + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.withSideInputs(sideInput1) + .withSideInputs(sideInputUnread) + .withSideInputs(sideInput2) + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) + .of(new TestDoFn( + Arrays.asList(sideInput1, sideInput2), + Arrays.>asList()))); + + DataflowAssert.that(outputs.get(mainOutputTag)) + .satisfies(ParDoTest.HasExpectedOutput + .forInput(inputs) + .andSideInputs(11, 222)); + + pipeline.run(); + } + + @Test + public void testParDoReadingFromUnknownSideInput() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollectionView sideView = pipeline + .apply("Create3", Create.of(3)) + .apply(View.asSingleton()); + + pipeline.apply("CreateMain", Create.of(inputs)) + .apply(ParDo.of(new TestDoFn( + Arrays.>asList(sideView), + Arrays.>asList()))); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("calling sideInput() with unknown view"); + pipeline.run(); + } + + @Test + public void testParDoWithErrorInStartBatch() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + pipeline.apply(Create.of(inputs)) + .apply(ParDo.of(new TestStartBatchErrorDoFn())); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("test error in initialize"); + pipeline.run(); + } + + @Test + public void testParDoWithErrorInProcessElement() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + pipeline.apply(Create.of(inputs)) + .apply(ParDo.of(new TestProcessElementErrorDoFn())); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("test error in process"); + pipeline.run(); + } + + @Test + public void testParDoWithErrorInFinishBatch() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + pipeline.apply(Create.of(inputs)) + .apply(ParDo.of(new TestFinishBatchErrorDoFn())); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("test error in finalize"); + pipeline.run(); + } + + @Test + public void testParDoGetName() { + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(3, -42, 666))) + .setName("MyInput"); + + { + PCollection output1 = + input + .apply(ParDo.of(new TestDoFn())); + assertEquals("ParDo(Test).out", output1.getName()); + } + + { + PCollection output2 = + input + .apply(ParDo.named("MyParDo").of(new TestDoFn())); + assertEquals("MyParDo.out", output2.getName()); + } + + { + PCollection output3 = + input + .apply(ParDo.of(new TestDoFn()).named("HerParDo")); + assertEquals("HerParDo.out", output3.getName()); + } + + { + PCollection output4 = + input + .apply(ParDo.of(new TestDoFn()).named("TestDoFn")); + assertEquals("TestDoFn.out", 
output4.getName()); + } + + { + PCollection output5 = + input + .apply(ParDo.of(new StrangelyNamedDoer())); + assertEquals("ParDo(StrangelyNamedDoer).out", + output5.getName()); + } + + assertEquals("ParDo(Printing)", ParDo.of(new PrintingDoFn()).getName()); + + assertEquals( + "ParMultiDo(SideOutputDummy)", + ParDo.of(new SideOutputDummyFn(null)).withOutputTags(null, null).getName()); + } + + @Test + public void testParDoWithSideOutputsName() { + Pipeline p = TestPipeline.create(); + + TupleTag mainOutputTag = new TupleTag("main"){}; + TupleTag sideOutputTag1 = new TupleTag("side1"){}; + TupleTag sideOutputTag2 = new TupleTag("side2"){}; + TupleTag sideOutputTag3 = new TupleTag("side3"){}; + TupleTag sideOutputTagUnwritten = new TupleTag("sideUnwritten"){}; + + PCollectionTuple outputs = p + .apply(Create.of(Arrays.asList(3, -42, 666))).setName("MyInput") + .apply(ParDo + .named("MyParDo") + .of(new TestDoFn( + Arrays.>asList(), + Arrays.asList(sideOutputTag1, sideOutputTag2, sideOutputTag3))) + .withOutputTags( + mainOutputTag, + TupleTagList.of(sideOutputTag3).and(sideOutputTag1) + .and(sideOutputTagUnwritten).and(sideOutputTag2))); + + assertEquals("MyParDo.main", outputs.get(mainOutputTag).getName()); + assertEquals("MyParDo.side1", outputs.get(sideOutputTag1).getName()); + assertEquals("MyParDo.side2", outputs.get(sideOutputTag2).getName()); + assertEquals("MyParDo.side3", outputs.get(sideOutputTag3).getName()); + assertEquals("MyParDo.sideUnwritten", + outputs.get(sideOutputTagUnwritten).getName()); + } + + @Test + @Category(RunnableOnService.class) + public void testParDoInCustomTransform() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + PCollection output = pipeline + .apply(Create.of(inputs)) + .apply("CustomTransform", new PTransform, PCollection>() { + @Override + public PCollection apply(PCollection input) { + return input.apply(ParDo.of(new TestDoFn())); + } + }); + + // Test that Coder inference of the result works through + // user-defined PTransforms. + DataflowAssert.that(output) + .satisfies(ParDoTest.HasExpectedOutput.forInput(inputs)); + + pipeline.run(); + } + + @Test + public void testMultiOutputChaining() { + Pipeline pipeline = TestPipeline.create(); + + PCollectionTuple filters = pipeline + .apply(Create.of(Arrays.asList(3, 4, 5, 6))) + .apply(new MultiFilter()); + PCollection by2 = filters.get(MultiFilter.BY2); + PCollection by3 = filters.get(MultiFilter.BY3); + + // Apply additional filters to each operation. 
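+    // Of the inputs 3, 4, 5, 6 only 6 is divisible by both 2 and 3, so each two-filter chain keeps just 6.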
+ PCollection by2then3 = by2 + .apply("Filter3sAgain", ParDo.of(new MultiFilter.FilterFn(3))); + PCollection by3then2 = by3 + .apply("Filter2sAgain", ParDo.of(new MultiFilter.FilterFn(2))); + + DataflowAssert.that(by2then3).containsInAnyOrder(6); + DataflowAssert.that(by3then2).containsInAnyOrder(6); + pipeline.run(); + } + + @Test + public void testJsonEscaping() { + // Declare an arbitrary function and make sure we can serialize it + DoFn doFn = new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + 1); + } + }; + + byte[] serializedBytes = serializeToByteArray(doFn); + String serializedJson = byteArrayToJsonString(serializedBytes); + assertArrayEquals( + serializedBytes, jsonStringToByteArray(serializedJson)); + } + + private static class TestDummy { } + + private static class TestDummyCoder extends AtomicCoder { + private TestDummyCoder() { } + private static final TestDummyCoder INSTANCE = new TestDummyCoder(); + + @JsonCreator + public static TestDummyCoder of() { + return INSTANCE; + } + + @SuppressWarnings("unused") // used to create a CoderFactory + public static List getInstanceComponents(TestDummy exampleValue) { + return Collections.emptyList(); + } + + @Override + public void encode(TestDummy value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public TestDummy decode(InputStream inStream, Context context) + throws CoderException, IOException { + return new TestDummy(); + } + + @Override + public boolean isRegisterByteSizeObserverCheap(TestDummy value, Context context) { + return true; + } + + @Override + public void registerByteSizeObserver( + TestDummy value, ElementByteSizeObserver observer, Context context) + throws Exception { + observer.update(0L); + } + } + + private static class SideOutputDummyFn extends DoFn { + private TupleTag sideTag; + public SideOutputDummyFn(TupleTag sideTag) { + this.sideTag = sideTag; + } + @Override + public void processElement(ProcessContext c) { + c.output(1); + c.sideOutput(sideTag, new TestDummy()); + } + } + + private static class MainOutputDummyFn extends DoFn { + private TupleTag sideTag; + public MainOutputDummyFn(TupleTag sideTag) { + this.sideTag = sideTag; + } + @Override + public void processElement(ProcessContext c) { + c.output(new TestDummy()); + c.sideOutput(sideTag, 1); + } + } + + /** DataflowAssert "matcher" for expected output. */ + static class HasExpectedOutput + implements SerializableFunction, Void>, Serializable { + private final List inputs; + private final List sideInputs; + private final String sideOutput; + private final boolean ordered; + + public static HasExpectedOutput forInput(List inputs) { + return new HasExpectedOutput( + new ArrayList(inputs), + new ArrayList(), + "", + false); + } + + private HasExpectedOutput(List inputs, + List sideInputs, + String sideOutput, + boolean ordered) { + this.inputs = inputs; + this.sideInputs = sideInputs; + this.sideOutput = sideOutput; + this.ordered = ordered; + } + + public HasExpectedOutput andSideInputs(Integer... 
sideInputValues) { + List sideInputs = new ArrayList<>(); + for (Integer sideInputValue : sideInputValues) { + sideInputs.add(sideInputValue); + } + return new HasExpectedOutput(inputs, sideInputs, sideOutput, ordered); + } + + public HasExpectedOutput fromSideOutput(TupleTag sideOutputTag) { + return fromSideOutput(sideOutputTag.getId()); + } + public HasExpectedOutput fromSideOutput(String sideOutput) { + return new HasExpectedOutput(inputs, sideInputs, sideOutput, ordered); + } + + public HasExpectedOutput inOrder() { + return new HasExpectedOutput(inputs, sideInputs, sideOutput, true); + } + + @Override + public Void apply(Iterable outputs) { + List starteds = new ArrayList<>(); + List processeds = new ArrayList<>(); + List finisheds = new ArrayList<>(); + for (String output : outputs) { + if (output.contains("started")) { + starteds.add(output); + } else if (output.contains("finished")) { + finisheds.add(output); + } else { + processeds.add(output); + } + } + + String sideInputsSuffix; + if (sideInputs.isEmpty()) { + sideInputsSuffix = ""; + } else { + sideInputsSuffix = ": " + sideInputs; + } + + String sideOutputPrefix; + if (sideOutput.isEmpty()) { + sideOutputPrefix = ""; + } else { + sideOutputPrefix = sideOutput + ": "; + } + + List expectedProcesseds = new ArrayList<>(); + for (Integer input : inputs) { + expectedProcesseds.add( + sideOutputPrefix + "processing: " + input + sideInputsSuffix); + } + String[] expectedProcessedsArray = + expectedProcesseds.toArray(new String[expectedProcesseds.size()]); + if (!ordered || expectedProcesseds.isEmpty()) { + assertThat(processeds, containsInAnyOrder(expectedProcessedsArray)); + } else { + assertThat(processeds, contains(expectedProcessedsArray)); + } + + assertEquals(starteds.size(), finisheds.size()); + for (String started : starteds) { + assertEquals(sideOutputPrefix + "started", started); + } + for (String finished : finisheds) { + assertEquals(sideOutputPrefix + "finished", finished); + } + + return null; + } + } + + @Test + public void testSideOutputUnknownCoder() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline + .apply(Create.of(Arrays.asList(1, 2, 3))); + + final TupleTag mainOutputTag = new TupleTag("main"); + final TupleTag sideOutputTag = new TupleTag("unknownSide"); + input.apply(ParDo.of(new SideOutputDummyFn(sideOutputTag)) + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag))); + + thrown.expect(PipelineExecutionException.class); + thrown.expectMessage("Unable to return a default Coder"); + pipeline.run(); + } + + @Test + public void testSideOutputUnregisteredExplicitCoder() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline + .apply(Create.of(Arrays.asList(1, 2, 3))); + + final TupleTag mainOutputTag = new TupleTag("main"); + final TupleTag sideOutputTag = new TupleTag("unregisteredSide"); + PCollectionTuple outputTuple = input.apply(ParDo.of(new SideOutputDummyFn(sideOutputTag)) + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag))); + + outputTuple.get(sideOutputTag).setCoder(new TestDummyCoder()); + + outputTuple.get(sideOutputTag).apply(View.asSingleton()); + + assertEquals(new TestDummyCoder(), outputTuple.get(sideOutputTag).getCoder()); + outputTuple.get(sideOutputTag).finishSpecifyingOutput(); // Check for crashes + assertEquals(new TestDummyCoder(), + outputTuple.get(sideOutputTag).getCoder()); // Check for corruption + pipeline.run(); + } + + @Test + public void 
testMainOutputUnregisteredExplicitCoder() { + Pipeline pipeline = TestPipeline.create(); + PCollection input = pipeline + .apply(Create.of(Arrays.asList(1, 2, 3))); + + final TupleTag mainOutputTag = new TupleTag("unregisteredMain"); + final TupleTag sideOutputTag = new TupleTag("side") {}; + PCollectionTuple outputTuple = input.apply(ParDo.of(new MainOutputDummyFn(sideOutputTag)) + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag))); + + outputTuple.get(mainOutputTag).setCoder(new TestDummyCoder()); + + pipeline.run(); + } + + @Test + public void testMainOutputApplySideOutputNoCoder() { + // Regression test: applying a transform to the main output + // should not cause a crash based on lack of a coder for the + // side output. + + Pipeline pipeline = TestPipeline.create(); + final TupleTag mainOutputTag = new TupleTag("main"); + final TupleTag sideOutputTag = new TupleTag("side"); + PCollectionTuple tuple = pipeline + .apply(Create.of(new TestDummy()) + .withCoder(TestDummyCoder.of())) + .apply(ParDo + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) + .of( + new DoFn() { + @Override public void processElement(ProcessContext context) { + TestDummy element = context.element(); + context.output(element); + context.sideOutput(sideOutputTag, element); + } + }) + ); + + // Before fix, tuple.get(mainOutputTag).apply(...) would indirectly trigger + // tuple.get(sideOutputTag).finishSpecifyingOutput(), which would crash + // on a missing coder. + tuple.get(mainOutputTag) + .setCoder(TestDummyCoder.of()) + .apply("Output1", ParDo.of(new DoFn() { + @Override public void processElement(ProcessContext context) { + context.output(1); + } + })); + + tuple.get(sideOutputTag).setCoder(TestDummyCoder.of()); + + pipeline.run(); + } + + @Test + public void testParDoOutputWithTimestamp() { + Pipeline pipeline = TestPipeline.create(); + + PCollection input = + pipeline.apply(Create.of(Arrays.asList(3, 42, 6))); + + PCollection output = + input + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.ZERO, Duration.ZERO))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + DataflowAssert.that(output).containsInAnyOrder( + "processing: 3, timestamp: 3", + "processing: 42, timestamp: 42", + "processing: 6, timestamp: 6"); + + pipeline.run(); + } + + @Test + public void testParDoSideOutputWithTimestamp() { + Pipeline pipeline = TestPipeline.create(); + + PCollection input = + pipeline.apply(Create.of(Arrays.asList(3, 42, 6))); + + final TupleTag mainOutputTag = new TupleTag("main"){}; + final TupleTag sideOutputTag = new TupleTag("side"){}; + + PCollection output = + input + .apply(ParDo.withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.sideOutputWithTimestamp( + sideOutputTag, c.element(), new Instant(c.element().longValue())); + } + })).get(sideOutputTag) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.ZERO, Duration.ZERO))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + DataflowAssert.that(output).containsInAnyOrder( + "processing: 3, timestamp: 3", + "processing: 42, timestamp: 42", + "processing: 6, timestamp: 6"); + + pipeline.run(); + } + + @Test + public void testParDoShiftTimestamp() { + Pipeline pipeline = TestPipeline.create(); + + PCollection input = + pipeline.apply(Create.of(Arrays.asList(3, 42, 6))); + + PCollection output = + input + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new 
TestShiftTimestampDoFn(Duration.millis(1000), + Duration.millis(-1000)))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + DataflowAssert.that(output).containsInAnyOrder( + "processing: 3, timestamp: -997", + "processing: 42, timestamp: -958", + "processing: 6, timestamp: -994"); + + pipeline.run(); + } + + @Test + public void testParDoShiftTimestampInvalid() { + Pipeline pipeline = TestPipeline.create(); + + pipeline.apply(Create.of(Arrays.asList(3, 42, 6))) + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.millis(1000), // allowed skew = 1 second + Duration.millis(-1001)))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Cannot output with timestamp"); + thrown.expectMessage( + "Output timestamps must be no earlier than the timestamp of the current input"); + thrown.expectMessage("minus the allowed skew (1 second)."); + pipeline.run(); + } + + @Test + public void testParDoShiftTimestampInvalidZeroAllowed() { + Pipeline pipeline = TestPipeline.create(); + + pipeline.apply(Create.of(Arrays.asList(3, 42, 6))) + .apply(ParDo.of(new TestOutputTimestampDoFn())) + .apply(ParDo.of(new TestShiftTimestampDoFn(Duration.ZERO, + Duration.millis(-1001)))) + .apply(ParDo.of(new TestFormatTimestampDoFn())); + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Cannot output with timestamp"); + thrown.expectMessage( + "Output timestamps must be no earlier than the timestamp of the current input"); + thrown.expectMessage("minus the allowed skew (0 milliseconds)."); + pipeline.run(); + } + + private static class Checker implements SerializableFunction, Void> { + @Override + public Void apply(Iterable input) { + boolean foundStart = false; + boolean foundElement = false; + boolean foundFinish = false; + for (String str : input) { + if (str.equals("elem:1:1")) { + if (foundElement) { + throw new AssertionError("Received duplicate element"); + } + foundElement = true; + } else if (str.equals("start:2:2")) { + foundStart = true; + } else if (str.equals("finish:3:3")) { + foundFinish = true; + } else { + throw new AssertionError("Got unexpected value: " + str); + } + } + if (!foundStart) { + throw new AssertionError("Missing \"start:2:2\""); + } + if (!foundElement) { + throw new AssertionError("Missing \"elem:1:1\""); + } + if (!foundFinish) { + throw new AssertionError("Missing \"finish:3:3\""); + } + + return null; + } + } + + @Test + @Category(RunnableOnService.class) + public void testWindowingInStartAndFinishBundle() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.timestamped(TimestampedValue.of("elem", new Instant(1)))) + .apply(Window.into(FixedWindows.of(Duration.millis(1)))) + .apply(ParDo.of(new DoFn() { + @Override + public void startBundle(Context c) { + c.outputWithTimestamp("start", new Instant(2)); + System.out.println("Start: 2"); + } + + @Override + public void processElement(ProcessContext c) { + c.output(c.element()); + System.out.println("Process: " + c.element() + ":" + c.timestamp().getMillis()); + } + + @Override + public void finishBundle(Context c) { + c.outputWithTimestamp("finish", new Instant(3)); + System.out.println("Finish: 3"); + } + })) + .apply(ParDo.of(new PrintingDoFn())); + + DataflowAssert.that(output).satisfies(new Checker()); + + pipeline.run(); + } + + @Test + public void testWindowingInStartBundleException() { + Pipeline pipeline = TestPipeline.create(); + + pipeline + 
.apply(Create.timestamped(TimestampedValue.of("elem", new Instant(1)))) + .apply(Window.into(FixedWindows.of(Duration.millis(1)))) + .apply(ParDo.of(new DoFn() { + @Override + public void startBundle(Context c) { + c.output("start"); + } + + @Override + public void processElement(ProcessContext c) { + c.output(c.element()); + } + })); + + thrown.expectMessage("WindowFn attempted to access input timestamp when none was available"); + pipeline.run(); + } + + /** + * Tests that a {@link DoFn} that mutates an output with a good equals() fails in the + * {@link DirectPipelineRunner}. + */ + @Test + public void testMutatingOutputThenOutputDoFnError() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of(42)) + .apply(ParDo.of(new DoFn>() { + @Override public void processElement(ProcessContext c) { + List outputList = Arrays.asList(1, 2, 3, 4); + c.output(outputList); + outputList.set(0, 37); + c.output(outputList); + } + })); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalMutationException.class)); + thrown.expectMessage("output"); + thrown.expectMessage("must not be mutated"); + pipeline.run(); + } + + /** + * Tests that a {@link DoFn} that mutates an output with a good equals() fails in the + * {@link DirectPipelineRunner}. + */ + @Test + public void testMutatingOutputThenTerminateDoFnError() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of(42)) + .apply(ParDo.of(new DoFn>() { + @Override public void processElement(ProcessContext c) { + List outputList = Arrays.asList(1, 2, 3, 4); + c.output(outputList); + outputList.set(0, 37); + } + })); + + thrown.expect(IllegalMutationException.class); + thrown.expectMessage("output"); + thrown.expectMessage("must not be mutated"); + pipeline.run(); + } + + /** + * Tests that a {@link DoFn} that mutates an output with a bad equals() still fails + * in the {@link DirectPipelineRunner}. + */ + @Test + public void testMutatingOutputCoderDoFnError() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of(42)) + .apply(ParDo.of(new DoFn() { + @Override public void processElement(ProcessContext c) { + byte[] outputArray = new byte[]{0x1, 0x2, 0x3}; + c.output(outputArray); + outputArray[0] = 0xa; + c.output(outputArray); + } + })); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalMutationException.class)); + thrown.expectMessage("output"); + thrown.expectMessage("must not be mutated"); + pipeline.run(); + } + + /** + * Tests that a {@link DoFn} that mutates its input with a good equals() fails in the + * {@link DirectPipelineRunner}. + */ + @Test + public void testMutatingInputDoFnError() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of(Arrays.asList(1, 2, 3), Arrays.asList(4, 5, 6)) + .withCoder(ListCoder.of(VarIntCoder.of()))) + .apply(ParDo.of(new DoFn, Integer>() { + @Override public void processElement(ProcessContext c) { + List inputList = c.element(); + inputList.set(0, 37); + c.output(12); + } + })); + + thrown.expect(IllegalMutationException.class); + thrown.expectMessage("input"); + thrown.expectMessage("must not be mutated"); + pipeline.run(); + } + + /** + * Tests that a {@link DoFn} that mutates an input with a bad equals() still fails + * in the {@link DirectPipelineRunner}. 
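+   *
+   * <p>A minimal sketch of the rejected pattern (hypothetical, mirrors the test body below):
+   * <pre>{@code
+   * byte[] in = c.element();
+   * in[0] = 0xa;  // mutating an element handed out by the runner must be detected
+   * }</pre>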
+ */ + @Test + public void testMutatingInputCoderDoFnError() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of(new byte[]{0x1, 0x2, 0x3}, new byte[]{0x4, 0x5, 0x6})) + .apply(ParDo.of(new DoFn() { + @Override public void processElement(ProcessContext c) { + byte[] inputArray = c.element(); + inputArray[0] = 0xa; + c.output(13); + } + })); + + thrown.expect(IllegalMutationException.class); + thrown.expectMessage("input"); + thrown.expectMessage("must not be mutated"); + pipeline.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PartitionTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PartitionTest.java new file mode 100644 index 000000000000..5121a0a88172 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PartitionTest.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Partition.PartitionFn; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for {@link Partition}. 
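+ *
+ * <p>{@code Partition.of(n, fn)} routes each element to the output whose index is
+ * {@code fn.partitionFor(elem, n)}, which must lie in {@code [0..n)}. For example
+ * (a sketch mirroring the ModFn defined below):
+ * <pre>{@code
+ * PCollectionList<Integer> parts = input.apply(Partition.of(3, new ModFn()));
+ * // parts.get(1) then holds the elements congruent to 1 mod 3.
+ * }</pre>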
+ */ +@RunWith(JUnit4.class) +public class PartitionTest implements Serializable { + + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + static class ModFn implements PartitionFn { + @Override + public int partitionFor(Integer elem, int numPartitions) { + return elem % numPartitions; + } + } + + static class IdentityFn implements PartitionFn { + @Override + public int partitionFor(Integer elem, int numPartitions) { + return elem; + } + } + + @Test + @Category(RunnableOnService.class) + public void testEvenOddPartition() { + Pipeline pipeline = TestPipeline.create(); + + PCollectionList outputs = pipeline + .apply(Create.of(591, 11789, 1257, 24578, 24799, 307)) + .apply(Partition.of(2, new ModFn())); + assertTrue(outputs.size() == 2); + DataflowAssert.that(outputs.get(0)).containsInAnyOrder(24578); + DataflowAssert.that(outputs.get(1)).containsInAnyOrder(591, 11789, 1257, + 24799, 307); + pipeline.run(); + } + + @Test + public void testModPartition() { + Pipeline pipeline = TestPipeline.create(); + + PCollectionList outputs = pipeline + .apply(Create.of(1, 2, 4, 5)) + .apply(Partition.of(3, new ModFn())); + assertTrue(outputs.size() == 3); + DataflowAssert.that(outputs.get(0)).empty(); + DataflowAssert.that(outputs.get(1)).containsInAnyOrder(1, 4); + DataflowAssert.that(outputs.get(2)).containsInAnyOrder(2, 5); + pipeline.run(); + } + + @Test + public void testOutOfBoundsPartitions() { + Pipeline pipeline = TestPipeline.create(); + + pipeline + .apply(Create.of(-1)) + .apply(Partition.of(5, new IdentityFn())); + + thrown.expect(RuntimeException.class); + thrown.expectMessage( + "Partition function returned out of bounds index: -1 not in [0..5)"); + pipeline.run(); + } + + @Test + public void testZeroNumPartitions() { + Pipeline pipeline = TestPipeline.create(); + + PCollection input = pipeline.apply(Create.of(591)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("numPartitions must be > 0"); + input.apply(Partition.of(0, new IdentityFn())); + } + + @Test + public void testDroppedPartition() { + Pipeline pipeline = TestPipeline.create(); + + // Compute the set of integers either 1 or 2 mod 3, the hard way. + PCollectionList outputs = pipeline + .apply(Create.of(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)) + .apply(Partition.of(3, new ModFn())); + + List> outputsList = new ArrayList<>(outputs.getAll()); + outputsList.remove(0); + outputs = PCollectionList.of(outputsList); + assertTrue(outputs.size() == 2); + + PCollection output = outputs.apply(Flatten.pCollections()); + DataflowAssert.that(output).containsInAnyOrder(2, 4, 5, 7, 8, 10, 11); + pipeline.run(); + } + + @Test + public void testPartitionGetName() { + assertEquals("Partition", Partition.of(3, new ModFn()).getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesTest.java new file mode 100644 index 000000000000..a6fe7e82a14f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesTest.java @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Tests for RemovedDuplicates. + */ +@RunWith(JUnit4.class) +public class RemoveDuplicatesTest { + @Test + @Category(RunnableOnService.class) + public void testRemoveDuplicates() { + List strings = Arrays.asList( + "k1", + "k5", + "k5", + "k2", + "k1", + "k2", + "k3"); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(strings) + .withCoder(StringUtf8Coder.of())); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + DataflowAssert.that(output) + .containsInAnyOrder("k1", "k5", "k2", "k3"); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testRemoveDuplicatesEmpty() { + List strings = Arrays.asList(); + + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(strings) + .withCoder(StringUtf8Coder.of())); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + DataflowAssert.that(output).empty(); + p.run(); + } + + private static class Keys implements SerializableFunction, String> { + @Override + public String apply(KV input) { + return input.getKey(); + } + } + + private static class Checker implements SerializableFunction>, Void> { + @Override + public Void apply(Iterable> input) { + Map values = new HashMap<>(); + for (KV kv : input) { + values.put(kv.getKey(), kv.getValue()); + } + assertEquals(2, values.size()); + assertTrue(values.get("k1").equals("v1") || values.get("k1").equals("v2")); + assertEquals("v1", values.get("k2")); + return null; + } + } + + + @Test + @Category(RunnableOnService.class) + public void testRemoveDuplicatesWithRepresentativeValue() { + List> strings = Arrays.asList( + KV.of("k1", "v1"), + KV.of("k1", "v2"), + KV.of("k2", "v1")); + + Pipeline p = TestPipeline.create(); + + PCollection> input = p.apply(Create.of(strings)); + + PCollection> output = + input.apply(RemoveDuplicates.withRepresentativeValueFn(new Keys())); + + + DataflowAssert.that(output).satisfies(new Checker()); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SampleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SampleTest.java new file mode 100644 index 000000000000..d8605134636b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SampleTest.java @@ -0,0 +1,260 @@ +/* + * 
Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; +import static com.google.cloud.dataflow.sdk.TestUtils.NO_LINES; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +/** + * Tests for Sample transform. + */ +@RunWith(JUnit4.class) +public class SampleTest { + static final Integer[] EMPTY = new Integer[] { }; + static final Integer[] DATA = new Integer[] {1, 2, 3, 4, 5}; + static final Integer[] REPEATED_DATA = new Integer[] {1, 1, 2, 2, 3, 3, 4, 4, 5, 5}; + + /** + * Verifies that the result of a Sample operation contains the expected number of elements, + * and that those elements are a subset of the items in expected. + */ + @SuppressWarnings("rawtypes") + public static class VerifyCorrectSample + implements SerializableFunction, Void> { + private T[] expectedValues; + private int expectedSize; + + /** + * expectedSize is the number of elements that the Sample should contain. expected is the set + * of elements that the sample may contain. + */ + @SafeVarargs + VerifyCorrectSample(int expectedSize, T... expected) { + this.expectedValues = expected; + this.expectedSize = expectedSize; + } + + @Override + @SuppressWarnings("unchecked") + public Void apply(Iterable in) { + List actual = new ArrayList<>(); + for (T elem : in) { + actual.add(elem); + } + + assertEquals(expectedSize, actual.size()); + + Collections.sort(actual); // We assume that @expected is already sorted. + int i = 0; // Index into @expected + for (T s : actual) { + boolean matchFound = false; + for (; i < expectedValues.length; i++) { + if (s.equals(expectedValues[i])) { + matchFound = true; + break; + } + } + assertTrue("Invalid sample: " + Joiner.on(',').join(actual), matchFound); + i++; // Don't match the same element again. 
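+        // Worked example (hypothetical): expected (sorted) = {1, 1, 2, 3}, actual (sorted) = {1, 2, 2}.
+        // Each match advances the cursor past the expected element it consumed, so an over-sampled
+        // value (here the second 2) finds no remaining match and the assertion above fails.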
+ } + return null; + } + } + + @Test + @Category(RunnableOnService.class) + public void testSample() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA) + .withCoder(BigEndianIntegerCoder.of())); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(3)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(3, DATA)); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSampleEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(EMPTY) + .withCoder(BigEndianIntegerCoder.of())); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(3)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(0, EMPTY)); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSampleZero() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA) + .withCoder(BigEndianIntegerCoder.of())); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(0)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(0, DATA)); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSampleInsufficientElements() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA) + .withCoder(BigEndianIntegerCoder.of())); + PCollection> output = input.apply( + Sample.fixedSizeGlobally(10)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(5, DATA)); + p.run(); + } + + @Test(expected = IllegalArgumentException.class) + public void testSampleNegative() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(DATA) + .withCoder(BigEndianIntegerCoder.of())); + input.apply(Sample.fixedSizeGlobally(-1)); + } + + @Test + @Category(RunnableOnService.class) + public void testSampleMultiplicity() { + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(REPEATED_DATA) + .withCoder(BigEndianIntegerCoder.of())); + // At least one value must be selected with multiplicity. + PCollection> output = input.apply( + Sample.fixedSizeGlobally(6)); + + DataflowAssert.thatSingletonIterable(output) + .satisfies(new VerifyCorrectSample<>(6, REPEATED_DATA)); + p.run(); + } + + private static class VerifyAnySample implements SerializableFunction, Void> { + private final List lines; + private final int limit; + private VerifyAnySample(List lines, int limit) { + this.lines = lines; + this.limit = limit; + } + + @Override + public Void apply(Iterable actualIter) { + final int expectedSize = Math.min(limit, lines.size()); + + // Make sure actual is the right length, and is a + // subset of expected. 
+ List actual = new ArrayList<>(); + for (String s : actualIter) { + actual.add(s); + } + assertEquals(expectedSize, actual.size()); + Set actualAsSet = new TreeSet<>(actual); + Set linesAsSet = new TreeSet<>(lines); + assertEquals(actual.size(), actualAsSet.size()); + assertEquals(lines.size(), linesAsSet.size()); + assertTrue(linesAsSet.containsAll(actualAsSet)); + return null; + } + } + + void runPickAnyTest(final List lines, int limit) { + Preconditions.checkArgument(new HashSet(lines).size() == lines.size(), + "Duplicates are unsupported."); + Pipeline p = TestPipeline.create(); + + PCollection input = p.apply(Create.of(lines) + .withCoder(StringUtf8Coder.of())); + + PCollection output = + input.apply(Sample.any(limit)); + + + DataflowAssert.that(output) + .satisfies(new VerifyAnySample(lines, limit)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testPickAny() { + runPickAnyTest(LINES, 0); + runPickAnyTest(LINES, LINES.size() / 2); + runPickAnyTest(LINES, LINES.size() * 2); + } + + @Test + // Extra tests, not worth the time to run on the real service. + public void testPickAnyMore() { + runPickAnyTest(LINES, LINES.size() - 1); + runPickAnyTest(LINES, LINES.size()); + runPickAnyTest(LINES, LINES.size() + 1); + } + + @Test + @Category(RunnableOnService.class) + public void testPickAnyWhenEmpty() { + runPickAnyTest(NO_LINES, 0); + runPickAnyTest(NO_LINES, 1); + } + + @Test + public void testSampleGetName() { + assertEquals("Sample.SampleAny", Sample.any(1).getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SimpleStatsFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SimpleStatsFnsTest.java new file mode 100644 index 000000000000..914f95dc2e73 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SimpleStatsFnsTest.java @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests of Min, Max, Mean, and Sum. + */ +@RunWith(JUnit4.class) +public class SimpleStatsFnsTest { + static final double DOUBLE_COMPARISON_ACCURACY = 1e-7; + + private static class TestCase> { + final List data; + final NumT min; + final NumT max; + final NumT sum; + final Double mean; + + @SafeVarargs + @SuppressWarnings("all") + public TestCase(NumT min, NumT max, NumT sum, NumT... values) { + this.data = Arrays.asList(values); + this.min = min; + this.max = max; + this.sum = sum; + this.mean = + values.length == 0 ? 
Double.NaN : sum.doubleValue() / values.length; + } + } + + static final List> DOUBLE_CASES = Arrays.asList( + new TestCase<>(-312.31, 6312.31, 11629.13, + -312.31, 29.13, 112.158, 6312.31, -312.158, -312.158, 112.158, + -312.31, 6312.31, 0.0), + new TestCase<>(3.14, 3.14, 3.14, 3.14), + new TestCase<>(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0.0)); + + static final List> LONG_CASES = Arrays.asList( + new TestCase<>(-50000000000000000L, + 70000000000000000L, + 60000033123213121L, + 0L, 1L, 10000000000000000L, -50000000000000000L, + 70000000000000000L, 0L, 10000000000000000L, -1L, + -50000000000000000L, 70000000000000000L, 33123213121L), + new TestCase<>(3L, 3L, 3L, 3L), + new TestCase<>(Long.MAX_VALUE, Long.MIN_VALUE, 0L)); + + static final List> INTEGER_CASES = Arrays.asList( + new TestCase<>(-3, 6, 22, + 1, -3, 2, 6, 3, 4, -3, 5, 6, 1), + new TestCase<>(3, 3, 3, 3), + new TestCase<>(Integer.MAX_VALUE, Integer.MIN_VALUE, 0)); + + @Test + public void testInstantStats() { + assertEquals(new Instant(1000), Min.MinFn.naturalOrder().apply( + Arrays.asList(new Instant(1000), new Instant(2000)))); + assertEquals(null, Min.MinFn.naturalOrder().apply( + Collections.emptyList())); + assertEquals(new Instant(5000), Min.MinFn.naturalOrder(new Instant(5000)).apply( + Collections.emptyList())); + + assertEquals(new Instant(2000), Max.MaxFn.naturalOrder().apply( + Arrays.asList(new Instant(1000), new Instant(2000)))); + assertEquals(null, Max.MaxFn.naturalOrder().apply( + Collections.emptyList())); + assertEquals(new Instant(5000), Max.MaxFn.naturalOrder(new Instant(5000)).apply( + Collections.emptyList())); + } + + @Test + public void testDoubleStats() { + for (TestCase t : DOUBLE_CASES) { + assertEquals(t.sum, new Sum.SumDoubleFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + assertEquals(t.min, new Min.MinDoubleFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + assertEquals(t.max, new Max.MaxDoubleFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + assertEquals(t.mean, new Mean.MeanFn().apply(t.data), + DOUBLE_COMPARISON_ACCURACY); + } + } + + @Test + public void testIntegerStats() { + for (TestCase t : INTEGER_CASES) { + assertEquals(t.sum, new Sum.SumIntegerFn().apply(t.data)); + assertEquals(t.min, new Min.MinIntegerFn().apply(t.data)); + assertEquals(t.max, new Max.MaxIntegerFn().apply(t.data)); + assertEquals(t.mean, new Mean.MeanFn().apply(t.data)); + } + } + + @Test + public void testLongStats() { + for (TestCase t : LONG_CASES) { + assertEquals(t.sum, new Sum.SumLongFn().apply(t.data)); + assertEquals(t.min, new Min.MinLongFn().apply(t.data)); + assertEquals(t.max, new Max.MaxLongFn().apply(t.data)); + assertEquals(t.mean, new Mean.MeanFn().apply(t.data)); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SumTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SumTest.java new file mode 100644 index 000000000000..b5ad51ce132f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/SumTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.cloud.dataflow.sdk.TestUtils.checkCombineFn; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.Lists; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for Sum. + */ +@RunWith(JUnit4.class) +public class SumTest { + + @Test + public void testSumGetNames() { + assertEquals("Sum.Globally", Sum.integersGlobally().getName()); + assertEquals("Sum.Globally", Sum.doublesGlobally().getName()); + assertEquals("Sum.Globally", Sum.longsGlobally().getName()); + assertEquals("Sum.PerKey", Sum.integersPerKey().getName()); + assertEquals("Sum.PerKey", Sum.doublesPerKey().getName()); + assertEquals("Sum.PerKey", Sum.longsPerKey().getName()); + } + + @Test + public void testSumIntegerFn() { + checkCombineFn( + new Sum.SumIntegerFn(), + Lists.newArrayList(1, 2, 3, 4), + 10); + } + + @Test + public void testSumLongFn() { + checkCombineFn( + new Sum.SumLongFn(), + Lists.newArrayList(1L, 2L, 3L, 4L), + 10L); + } + + @Test + public void testSumDoubleFn() { + checkCombineFn( + new Sum.SumDoubleFn(), + Lists.newArrayList(1.0, 2.0, 3.0, 4.0), + 10.0); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/TopTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/TopTest.java new file mode 100644 index 000000000000..ac06bff21392 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/TopTest.java @@ -0,0 +1,259 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window.Bound; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +/** Tests for Top. 
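+ * <p>The per-key assertions below rely on the (observed) ordering of Top's output lists:
+ * {@code Top.largestPerKey(n)} yields each key's top values sorted descending (e.g.
+ * {@code KV.of("a", Arrays.asList(3, 2))}) and {@code Top.smallestPerKey(n)} sorted ascending.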
*/ +@RunWith(JUnit4.class) +public class TopTest { + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @SuppressWarnings("unchecked") + static final String[] COLLECTION = new String[] { + "a", "bb", "c", "c", "z" + }; + + @SuppressWarnings("unchecked") + static final String[] EMPTY_COLLECTION = new String[] { + }; + + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] TABLE = new KV[] { + KV.of("a", 1), + KV.of("a", 2), + KV.of("a", 3), + KV.of("b", 1), + KV.of("b", 10), + KV.of("b", 10), + KV.of("b", 100), + }; + + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] EMPTY_TABLE = new KV[] { + }; + + public PCollection> createInputTable(Pipeline p) { + return p.apply("CreateInputTable", Create.of(Arrays.asList(TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + } + + public PCollection> createEmptyInputTable(Pipeline p) { + return p.apply("CreateEmptyInputTable", Create.of(Arrays.asList(EMPTY_TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + } + + @Test + @SuppressWarnings("unchecked") + public void testTop() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION)) + .withCoder(StringUtf8Coder.of())); + + PCollection> top1 = input.apply(Top.of(1, new OrderByLength())); + PCollection> top2 = input.apply(Top.largest(2)); + PCollection> top3 = input.apply(Top.smallest(3)); + + PCollection> inputTable = createInputTable(p); + PCollection>> largestPerKey = inputTable + .apply(Top.largestPerKey(2)); + PCollection>> smallestPerKey = inputTable + .apply(Top.smallestPerKey(2)); + + DataflowAssert.thatSingletonIterable(top1).containsInAnyOrder(Arrays.asList("bb")); + DataflowAssert.thatSingletonIterable(top2).containsInAnyOrder("z", "c"); + DataflowAssert.thatSingletonIterable(top3).containsInAnyOrder("a", "bb", "c"); + DataflowAssert.that(largestPerKey).containsInAnyOrder( + KV.of("a", Arrays.asList(3, 2)), + KV.of("b", Arrays.asList(100, 10))); + DataflowAssert.that(smallestPerKey).containsInAnyOrder( + KV.of("a", Arrays.asList(1, 2)), + KV.of("b", Arrays.asList(1, 10))); + + p.run(); + } + + @Test + @SuppressWarnings("unchecked") + public void testTopEmpty() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.of(Arrays.asList(EMPTY_COLLECTION)) + .withCoder(StringUtf8Coder.of())); + + PCollection> top1 = input.apply(Top.of(1, new OrderByLength())); + PCollection> top2 = input.apply(Top.largest(2)); + PCollection> top3 = input.apply(Top.smallest(3)); + + PCollection> inputTable = createEmptyInputTable(p); + PCollection>> largestPerKey = inputTable + .apply(Top.largestPerKey(2)); + PCollection>> smallestPerKey = inputTable + .apply(Top.smallestPerKey(2)); + + DataflowAssert.thatSingletonIterable(top1).empty(); + DataflowAssert.thatSingletonIterable(top2).empty(); + DataflowAssert.thatSingletonIterable(top3).empty(); + DataflowAssert.that(largestPerKey).empty(); + DataflowAssert.that(smallestPerKey).empty(); + + p.run(); + } + + @Test + public void testTopEmptyWithIncompatibleWindows() { + Pipeline p = TestPipeline.create(); + Bound windowingFn = Window.into(FixedWindows.of(Duration.standardDays(10L))); + PCollection input = + p.apply(Create.timestamped(Collections.emptyList(), Collections.emptyList())) + .apply(windowingFn); + + expectedEx.expect(IllegalStateException.class); + expectedEx.expectMessage("Top"); + expectedEx.expectMessage("GlobalWindows"); + 
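+    // The remaining message fragments name the (assumed) remedies the error suggests for
+    // non-globally-windowed input: Top.of(...).withoutDefaults() or Top.of(...).asSingletonView().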
expectedEx.expectMessage("withoutDefaults"); + expectedEx.expectMessage("asSingletonView"); + + input.apply(Top.of(1, new OrderByLength())); + } + + @Test + @SuppressWarnings("unchecked") + public void testTopZero() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION)) + .withCoder(StringUtf8Coder.of())); + + PCollection> top1 = input.apply(Top.of(0, new OrderByLength())); + PCollection> top2 = input.apply(Top.largest(0)); + PCollection> top3 = input.apply(Top.smallest(0)); + + PCollection> inputTable = createInputTable(p); + PCollection>> largestPerKey = inputTable + .apply(Top.largestPerKey(0)); + + PCollection>> smallestPerKey = inputTable + .apply(Top.smallestPerKey(0)); + + DataflowAssert.thatSingletonIterable(top1).empty(); + DataflowAssert.thatSingletonIterable(top2).empty(); + DataflowAssert.thatSingletonIterable(top3).empty(); + DataflowAssert.that(largestPerKey).containsInAnyOrder( + KV.of("a", Arrays.asList()), + KV.of("b", Arrays.asList())); + DataflowAssert.that(smallestPerKey).containsInAnyOrder( + KV.of("a", Arrays.asList()), + KV.of("b", Arrays.asList())); + + p.run(); + } + + // This is a purely compile-time test. If the code compiles, then it worked. + @Test + public void testPerKeySerializabilityRequirement() { + Pipeline p = TestPipeline.create(); + p.apply("CreateCollection", Create.of(Arrays.asList(COLLECTION)) + .withCoder(StringUtf8Coder.of())); + + PCollection> inputTable = createInputTable(p); + inputTable + .apply(Top.perKey(1, + new IntegerComparator())); + + inputTable + .apply("PerKey2", Top.perKey(1, + new IntegerComparator2())); + } + + @Test + public void testCountConstraint() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION)) + .withCoder(StringUtf8Coder.of())); + + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString(">= 0")); + + input.apply(Top.of(-1, new OrderByLength())); + } + + @Test + public void testTopGetNames() { + assertEquals("Top.Globally", Top.of(1, new OrderByLength()).getName()); + assertEquals("Smallest.Globally", Top.smallest(1).getName()); + assertEquals("Largest.Globally", Top.largest(2).getName()); + assertEquals("Top.PerKey", Top.perKey(1, new IntegerComparator()).getName()); + assertEquals("Smallest.PerKey", Top.smallestPerKey(1).getName()); + assertEquals("Largest.PerKey", Top.largestPerKey(2).getName()); + } + + private static class OrderByLength implements Comparator, Serializable { + @Override + public int compare(String a, String b) { + if (a.length() != b.length()) { + return a.length() - b.length(); + } else { + return a.compareTo(b); + } + } + } + + private static class IntegerComparator implements Comparator, Serializable { + @Override + public int compare(Integer o1, Integer o2) { + return o1.compareTo(o2); + } + } + + private static class IntegerComparator2 implements Comparator, Serializable { + @Override + public int compare(Integer o1, Integer o2) { + return o1.compareTo(o2); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ValuesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ValuesTest.java new file mode 100644 index 000000000000..9663c453dd19 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ValuesTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for Values transform. + */ +@RunWith(JUnit4.class) +public class ValuesTest { + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] TABLE = new KV[] { + KV.of("one", 1), + KV.of("two", 2), + KV.of("three", 3), + KV.of("four", 4), + KV.of("dup", 4) + }; + + @SuppressWarnings({"rawtypes", "unchecked"}) + static final KV[] EMPTY_TABLE = new KV[] { + }; + + @Test + @Category(RunnableOnService.class) + public void testValues() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection output = input.apply(Values.create()); + + DataflowAssert.that(output) + .containsInAnyOrder(1, 2, 3, 4, 4); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testValuesEmpty() { + Pipeline p = TestPipeline.create(); + + PCollection> input = + p.apply(Create.of(Arrays.asList(EMPTY_TABLE)).withCoder( + KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + + PCollection output = input.apply(Values.create()); + + DataflowAssert.that(output).empty(); + + p.run(); + } + + @Test + public void testValuesGetName() { + assertEquals("Values", Values.create().getName()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ViewTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ViewTest.java new file mode 100644 index 000000000000..145956961f1f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ViewTest.java @@ -0,0 +1,1548 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.NullableCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.InvalidWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.util.NoopPathValidator; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; + +/** + * Tests for {@link View}. See also {@link ParDoTest}, which + * provides additional coverage since views can only be + * observed via {@link ParDo}. + */ +@RunWith(JUnit4.class) +public class ViewTest implements Serializable { + // This test is Serializable, just so that it's easy to have + // anonymous inner classes inside the non-static test methods. 
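+  // The tests that follow share one shape: materialize a PCollectionView with a View
+  // transform (asSingleton(), asList(), asIterable(), ...), hand it to a ParDo via
+  // ParDo.withSideInputs(view), and read it inside processElement() with c.sideInput(view).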
+ + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(RunnableOnService.class) + public void testSingletonSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView view = + pipeline.apply("Create47", Create.of(47)).apply(View.asSingleton()); + + PCollection output = + pipeline.apply("Create123", Create.of(1, 2, 3)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + DataflowAssert.that(output).containsInAnyOrder(47, 47, 47); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedSingletonSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView view = + pipeline.apply("Create47", Create.timestamped( + TimestampedValue.of(47, new Instant(1)), + TimestampedValue.of(48, new Instant(11)))) + .apply("SideWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asSingleton()); + + PCollection output = + pipeline.apply("Create123", Create.timestamped( + TimestampedValue.of(1, new Instant(4)), + TimestampedValue.of(2, new Instant(8)), + TimestampedValue.of(3, new Instant(12)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + DataflowAssert.that(output).containsInAnyOrder(47, 47, 48); + + pipeline.run(); + } + + @Test + public void testEmptySingletonSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView view = + pipeline.apply("CreateEmptyIntegers", Create.of().withCoder(VarIntCoder.of())) + .apply(View.asSingleton()); + + pipeline.apply("Create123", Create.of(1, 2, 3)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(NoSuchElementException.class)); + thrown.expectMessage("Empty"); + thrown.expectMessage("PCollection"); + thrown.expectMessage("singleton"); + + pipeline.run(); + } + + @Test + public void testNonSingletonSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + PCollection oneTwoThree = pipeline.apply(Create.of(1, 2, 3)); + final PCollectionView view = oneTwoThree.apply(View.asSingleton()); + + oneTwoThree.apply( + "OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view)); + } + })); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalArgumentException.class)); + thrown.expectMessage("PCollection"); + thrown.expectMessage("more than one"); + thrown.expectMessage("singleton"); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testListSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(11, 13, 17, 23)).apply(View.asList()); + + PCollection output = + pipeline.apply("CreateMainInput", Create.of(29, 31)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + 
+                Preconditions.checkArgument(c.sideInput(view).size() == 4);
+                Preconditions.checkArgument(c.sideInput(view).get(0) == c.sideInput(view).get(0));
+                for (Integer i : c.sideInput(view)) {
+                  c.output(i);
+                }
+              }
+            }));
+
+    DataflowAssert.that(output).containsInAnyOrder(11, 13, 17, 23, 11, 13, 17, 23);
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testWindowedListSideInput() {
+    Pipeline pipeline = TestPipeline.create();
+
+    final PCollectionView<List<Integer>> view =
+        pipeline.apply("CreateSideInput", Create.timestamped(
+                TimestampedValue.of(11, new Instant(1)),
+                TimestampedValue.of(13, new Instant(1)),
+                TimestampedValue.of(17, new Instant(1)),
+                TimestampedValue.of(23, new Instant(1)),
+                TimestampedValue.of(31, new Instant(11)),
+                TimestampedValue.of(33, new Instant(11)),
+                TimestampedValue.of(37, new Instant(11)),
+                TimestampedValue.of(43, new Instant(11))))
+            .apply("SideWindowInto", Window.<Integer>into(FixedWindows.of(Duration.millis(10))))
+            .apply(View.<Integer>asList());
+
+    PCollection<Integer> output =
+        pipeline.apply("CreateMainInput", Create.timestamped(
+                TimestampedValue.of(29, new Instant(1)),
+                TimestampedValue.of(35, new Instant(11))))
+            .apply("MainWindowInto", Window.<Integer>into(FixedWindows.of(Duration.millis(10))))
+            .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn<Integer, Integer>() {
+              @Override
+              public void processElement(ProcessContext c) {
+                Preconditions.checkArgument(c.sideInput(view).size() == 4);
+                Preconditions.checkArgument(c.sideInput(view).get(0) == c.sideInput(view).get(0));
+                for (Integer i : c.sideInput(view)) {
+                  c.output(i);
+                }
+              }
+            }));
+
+    DataflowAssert.that(output).containsInAnyOrder(11, 13, 17, 23, 31, 33, 37, 43);
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testEmptyListSideInput() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+
+    final PCollectionView<List<Integer>> view =
+        pipeline.apply("CreateEmptyView", Create.<Integer>of().withCoder(VarIntCoder.of()))
+            .apply(View.<Integer>asList());
+
+    PCollection<Integer> results =
+        pipeline.apply("Create1", Create.of(1))
+            .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn<Integer, Integer>() {
+              @Override
+              public void processElement(ProcessContext c) {
+                assertTrue(c.sideInput(view).isEmpty());
+                assertFalse(c.sideInput(view).iterator().hasNext());
+                c.output(1);
+              }
+            }));
+
+    // Pass at least one value through to guarantee that DoFn executes.
+ DataflowAssert.that(results).containsInAnyOrder(1); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testListSideInputIsImmutable() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(11)).apply(View.asList()); + + PCollection output = + pipeline.apply("CreateMainInput", Create.of(29)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + try { + c.sideInput(view).clear(); + fail("Expected UnsupportedOperationException on clear()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).add(4); + fail("Expected UnsupportedOperationException on add()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).addAll(new ArrayList()); + fail("Expected UnsupportedOperationException on addAll()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).remove(0); + fail("Expected UnsupportedOperationException on remove()"); + } catch (UnsupportedOperationException expected) { + } + for (Integer i : c.sideInput(view)) { + c.output(i); + } + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(output).containsInAnyOrder(11); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testIterableSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(11, 13, 17, 23)) + .apply(View.asIterable()); + + PCollection output = + pipeline.apply("CreateMainInput", Create.of(29, 31)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (Integer i : c.sideInput(view)) { + c.output(i); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder(11, 13, 17, 23, 11, 13, 17, 23); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedIterableSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.timestamped( + TimestampedValue.of(11, new Instant(1)), + TimestampedValue.of(13, new Instant(1)), + TimestampedValue.of(17, new Instant(1)), + TimestampedValue.of(23, new Instant(1)), + TimestampedValue.of(31, new Instant(11)), + TimestampedValue.of(33, new Instant(11)), + TimestampedValue.of(37, new Instant(11)), + TimestampedValue.of(43, new Instant(11)))) + .apply("SideWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asIterable()); + + PCollection output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of(29, new Instant(1)), + TimestampedValue.of(35, new Instant(11)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (Integer i : c.sideInput(view)) { + c.output(i); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder(11, 13, 17, 23, 31, 33, 37, 43); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testEmptyIterableSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateEmptyView", Create.of().withCoder(VarIntCoder.of())) + 
.apply(View.asIterable()); + + PCollection results = + pipeline.apply("Create1", Create.of(1)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + assertFalse(c.sideInput(view).iterator().hasNext()); + c.output(1); + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(results).containsInAnyOrder(1); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testIterableSideInputIsImmutable() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(11)).apply(View.asIterable()); + + PCollection output = + pipeline.apply("CreateMainInput", Create.of(29)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + Iterator iterator = c.sideInput(view).iterator(); + while (iterator.hasNext()) { + try { + iterator.remove(); + fail("Expected UnsupportedOperationException on remove()"); + } catch (UnsupportedOperationException expected) { + } + c.output(iterator.next()); + } + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(output).containsInAnyOrder(11); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMultimapSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1), KV.of("a", 2), KV.of("b", 3))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple", "banana", "blackberry")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + for (Integer v : c.sideInput(view).get(c.element().substring(0, 1))) { + c.output(KV.of(c.element(), v)); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("apple", 2), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMultimapAsEntrySetSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1), KV.of("a", 2), KV.of("b", 3))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of(2 /* size */)) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + assertEquals((int) c.element(), c.sideInput(view).size()); + assertEquals((int) c.element(), c.sideInput(view).entrySet().size()); + for (Entry> entry : c.sideInput(view).entrySet()) { + for (Integer value : entry.getValue()) { + c.output(KV.of(entry.getKey(), value)); + } + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", 1), KV.of("a", 2), KV.of("b", 3)); + + pipeline.run(); + } + + private static class NonDeterministicStringCoder extends CustomCoder { + @Override + public void encode(String value, OutputStream outStream, Coder.Context context) + throws CoderException, IOException { + StringUtf8Coder.of().encode(value, outStream, context); + } + + @Override + public String decode(InputStream inStream, Coder.Context context) + throws CoderException, IOException { + return StringUtf8Coder.of().decode(inStream, context); + } 
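+  // verifyDeterministic() below throws on purpose: the *WithNonDeterministicKeyCoder tests use this
+  // coder as the key coder to check that View.asMap() and View.asMultimap() still build a usable
+  // view when key encodings cannot be compared deterministically.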
+ + @Override + public void verifyDeterministic() + throws com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException { + throw new NonDeterministicException(this, "Test coder is not deterministic on purpose."); + } + } + + @Test + @Category(RunnableOnService.class) + public void testMultimapSideInputWithNonDeterministicKeyCoder() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", + Create.of(KV.of("a", 1), KV.of("a", 2), KV.of("b", 3)) + .withCoder(KvCoder.of(new NonDeterministicStringCoder(), VarIntCoder.of()))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple", "banana", "blackberry")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + for (Integer v : c.sideInput(view).get(c.element().substring(0, 1))) { + c.output(KV.of(c.element(), v)); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("apple", 2), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedMultimapSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(1)), + TimestampedValue.of(KV.of("a", 2), new Instant(7)), + TimestampedValue.of(KV.of("b", 3), new Instant(14)))) + .apply( + "SideWindowInto", + Window.>into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("apple", new Instant(5)), + TimestampedValue.of("banana", new Instant(13)), + TimestampedValue.of("blackberry", new Instant(16)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + for (Integer v : + c.sideInput(view) + .get(c.element().substring(0, 1))) { + c.output(KV.of(c.element(), v)); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("apple", 2), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedMultimapAsEntrySetSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(1)), + TimestampedValue.of(KV.of("a", 2), new Instant(7)), + TimestampedValue.of(KV.of("b", 3), new Instant(14)))) + .apply( + "SideWindowInto", + Window.>into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of(1 /* size */, new Instant(5)), + TimestampedValue.of(1 /* size */, new Instant(16)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + assertEquals((int) c.element(), + c.sideInput(view).size()); + assertEquals((int) c.element(), + c.sideInput(view).entrySet().size()); + for (Entry> entry + : c.sideInput(view).entrySet()) { + for (Integer value : 
entry.getValue()) { + c.output(KV.of(entry.getKey(), value)); + } + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", 1), KV.of("a", 2), KV.of("b", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedMultimapSideInputWithNonDeterministicKeyCoder() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(1)), + TimestampedValue.of(KV.of("a", 2), new Instant(7)), + TimestampedValue.of(KV.of("b", 3), new Instant(14))) + .withCoder(KvCoder.of(new NonDeterministicStringCoder(), VarIntCoder.of()))) + .apply("SideWindowInto", + Window.>into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("apple", new Instant(5)), + TimestampedValue.of("banana", new Instant(13)), + TimestampedValue.of("blackberry", new Instant(16)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + for (Integer v : + c.sideInput(view) + .get(c.element().substring(0, 1))) { + c.output(KV.of(c.element(), v)); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("apple", 2), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testEmptyMultimapSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateEmptyView", Create.>of().withCoder( + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(View.asMultimap()); + + PCollection results = + pipeline.apply("Create1", Create.of(1)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + assertTrue(c.sideInput(view).isEmpty()); + assertTrue(c.sideInput(view).entrySet().isEmpty()); + assertFalse(c.sideInput(view).entrySet().iterator().hasNext()); + c.output(c.element()); + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(results).containsInAnyOrder(1); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testEmptyMultimapSideInputWithNonDeterministicKeyCoder() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateEmptyView", + Create.>of().withCoder( + KvCoder.of(new NonDeterministicStringCoder(), VarIntCoder.of()))) + .apply(View.asMultimap()); + + PCollection results = + pipeline.apply("Create1", Create.of(1)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + assertTrue(c.sideInput(view).isEmpty()); + assertTrue(c.sideInput(view).entrySet().isEmpty()); + assertFalse(c.sideInput(view).entrySet().iterator().hasNext()); + c.output(c.element()); + } + })); + + // Pass at least one value through to guarantee that DoFn executes. 
+ DataflowAssert.that(results).containsInAnyOrder(1); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMultimapSideInputIsImmutable() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView>> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1))) + .apply(View.asMultimap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + try { + c.sideInput(view).clear(); + fail("Expected UnsupportedOperationException on clear()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).put("c", ImmutableList.of(3)); + fail("Expected UnsupportedOperationException on put()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).remove("c"); + fail("Expected UnsupportedOperationException on remove()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).putAll(new HashMap>()); + fail("Expected UnsupportedOperationException on putAll()"); + } catch (UnsupportedOperationException expected) { + } + for (Integer v : c.sideInput(view).get(c.element().substring(0, 1))) { + c.output(KV.of(c.element(), v)); + } + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(output).containsInAnyOrder(KV.of("apple", 1)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMapSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1), KV.of("b", 3))) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple", "banana", "blackberry")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output( + KV.of(c.element(), c.sideInput(view).get(c.element().substring(0, 1)))); + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMapAsEntrySetSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1), KV.of("b", 3))) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of(2 /* size */)) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + assertEquals((int) c.element(), c.sideInput(view).size()); + assertEquals((int) c.element(), c.sideInput(view).entrySet().size()); + for (Entry entry : c.sideInput(view).entrySet()) { + c.output(KV.of(entry.getKey(), entry.getValue())); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", 1), KV.of("b", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMapSideInputWithNonDeterministicKeyCoder() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", + Create.of(KV.of("a", 1), KV.of("b", 3)) + .withCoder(KvCoder.of(new NonDeterministicStringCoder(), VarIntCoder.of()))) + .apply(View.asMap()); + + PCollection> output = + 
pipeline.apply("CreateMainInput", Create.of("apple", "banana", "blackberry")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output( + KV.of(c.element(), c.sideInput(view).get(c.element().substring(0, 1)))); + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedMapSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(1)), + TimestampedValue.of(KV.of("b", 2), new Instant(4)), + TimestampedValue.of(KV.of("b", 3), new Instant(18)))) + .apply( + "SideWindowInto", + Window.>into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("apple", new Instant(5)), + TimestampedValue.of("banana", new Instant(4)), + TimestampedValue.of("blackberry", new Instant(16)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of( + c.element(), + c.sideInput(view).get( + c.element().substring(0, 1)))); + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("banana", 2), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedMapAsEntrySetSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(1)), + TimestampedValue.of(KV.of("b", 2), new Instant(4)), + TimestampedValue.of(KV.of("b", 3), new Instant(18)))) + .apply( + "SideWindowInto", + Window.>into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of(2 /* size */, new Instant(5)), + TimestampedValue.of(1 /* size */, new Instant(16)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + assertEquals((int) c.element(), + c.sideInput(view).size()); + assertEquals((int) c.element(), + c.sideInput(view).entrySet().size()); + for (Entry entry + : c.sideInput(view).entrySet()) { + c.output(KV.of(entry.getKey(), entry.getValue())); + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", 1), KV.of("b", 2), KV.of("b", 3)); + + pipeline.run(); + } + + + @Test + @Category(RunnableOnService.class) + public void testWindowedMapSideInputWithNonDeterministicKeyCoder() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(1)), + TimestampedValue.of(KV.of("b", 2), new Instant(4)), + TimestampedValue.of(KV.of("b", 3), new Instant(18))) + .withCoder(KvCoder.of(new NonDeterministicStringCoder(), VarIntCoder.of()))) + .apply( + "SideWindowInto", + Window.>into(FixedWindows.of(Duration.millis(10)))) + .apply(View.asMap()); + + 
PCollection> output = + pipeline.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("apple", new Instant(5)), + TimestampedValue.of("banana", new Instant(4)), + TimestampedValue.of("blackberry", new Instant(16)))) + .apply("MainWindowInto", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of( + new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of( + c.element(), + c.sideInput(view).get( + c.element().substring(0, 1)))); + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("banana", 2), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testEmptyMapSideInput() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateEmptyView", Create.>of().withCoder( + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(View.asMap()); + + PCollection results = + pipeline.apply("Create1", Create.of(1)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + assertTrue(c.sideInput(view).isEmpty()); + assertTrue(c.sideInput(view).entrySet().isEmpty()); + assertFalse(c.sideInput(view).entrySet().iterator().hasNext()); + c.output(c.element()); + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(results).containsInAnyOrder(1); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testEmptyMapSideInputWithNonDeterministicKeyCoder() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateEmptyView", Create.>of().withCoder( + KvCoder.of(new NonDeterministicStringCoder(), VarIntCoder.of()))) + .apply(View.asMap()); + + PCollection results = + pipeline.apply("Create1", Create.of(1)) + .apply("OutputSideInputs", ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + assertTrue(c.sideInput(view).isEmpty()); + assertTrue(c.sideInput(view).entrySet().isEmpty()); + assertFalse(c.sideInput(view).entrySet().iterator().hasNext()); + c.output(c.element()); + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(results).containsInAnyOrder(1); + + pipeline.run(); + } + + @Test + public void testMapSideInputWithNullValuesCatchesDuplicates() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline + .apply( + "CreateSideInput", + Create.of(KV.of("a", (Integer) null), KV.of("a", (Integer) null)) + .withCoder( + KvCoder.of(StringUtf8Coder.of(), NullableCoder.of(VarIntCoder.of())))) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple", "banana", "blackberry")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output( + KV.of(c.element(), c.sideInput(view).get(c.element().substring(0, 1)))); + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 1), KV.of("banana", 3), KV.of("blackberry", 3)); + + // PipelineExecutionException is thrown with cause having a message stating that a + // duplicate is not allowed. 
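+    // Null values themselves are acceptable here (the value coder is NullableCoder); the expected
+    // failure is specifically that the key "a" appears twice in a View.asMap() side input.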
+ thrown.expectCause( + ThrowableMessageMatcher.hasMessage(Matchers.containsString("Duplicate values for a"))); + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMapSideInputIsImmutable() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1))) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple")) + .apply( + "OutputSideInputs", + ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + try { + c.sideInput(view).clear(); + fail("Expected UnsupportedOperationException on clear()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).put("c", 3); + fail("Expected UnsupportedOperationException on put()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).remove("c"); + fail("Expected UnsupportedOperationException on remove()"); + } catch (UnsupportedOperationException expected) { + } + try { + c.sideInput(view).putAll(new HashMap()); + fail("Expected UnsupportedOperationException on putAll()"); + } catch (UnsupportedOperationException expected) { + } + c.output( + KV.of(c.element(), c.sideInput(view).get(c.element().substring(0, 1)))); + } + })); + + // Pass at least one value through to guarantee that DoFn executes. + DataflowAssert.that(output).containsInAnyOrder(KV.of("apple", 1)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testCombinedMapSideInput() { + Pipeline pipeline = TestPipeline.create(); + + final PCollectionView> view = + pipeline.apply("CreateSideInput", Create.of(KV.of("a", 1), KV.of("a", 20), KV.of("b", 3))) + .apply("SumIntegers", Combine.perKey(new Sum.SumIntegerFn().asKeyedFn())) + .apply(View.asMap()); + + PCollection> output = + pipeline.apply("CreateMainInput", Create.of("apple", "banana", "blackberry")) + .apply("Output", ParDo.withSideInputs(view).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(KV.of(c.element(), c.sideInput(view).get(c.element().substring(0, 1)))); + } + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("apple", 21), KV.of("banana", 3), KV.of("blackberry", 3)); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedSideInputFixedToFixed() { + Pipeline p = TestPipeline.create(); + + final PCollectionView view = + p.apply( + "CreateSideInput", + Create.timestamped(TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(2, new Instant(11)), TimestampedValue.of(3, new Instant(13)))) + .apply("WindowSideInput", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.integersGlobally().withoutDefaults()) + .apply(View.asSingleton()); + + PCollection output = + p.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("A", new Instant(4)), + TimestampedValue.of("B", new Instant(15)), + TimestampedValue.of("C", new Instant(7)))) + .apply("WindowMainInput", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputMainAndSideInputs", ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + c.sideInput(view)); + } + })); + + DataflowAssert.that(output).containsInAnyOrder("A1", "B5", "C1"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedSideInputFixedToGlobal() { + Pipeline p = 
TestPipeline.create(); + + final PCollectionView view = + p.apply( + "CreateSideInput", + Create.timestamped(TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(2, new Instant(11)), TimestampedValue.of(3, new Instant(13)))) + .apply("WindowSideInput", Window.into(new GlobalWindows())) + .apply(Sum.integersGlobally()) + .apply(View.asSingleton()); + + PCollection output = + p.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("A", new Instant(4)), + TimestampedValue.of("B", new Instant(15)), + TimestampedValue.of("C", new Instant(7)))) + .apply("WindowMainInput", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputMainAndSideInputs", ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + c.sideInput(view)); + } + })); + + DataflowAssert.that(output).containsInAnyOrder("A6", "B6", "C6"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowedSideInputFixedToFixedWithDefault() { + Pipeline p = TestPipeline.create(); + + final PCollectionView view = + p.apply("CreateSideInput", Create.timestamped( + TimestampedValue.of(2, new Instant(11)), + TimestampedValue.of(3, new Instant(13)))) + .apply("WindowSideInput", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.integersGlobally().asSingletonView()); + + PCollection output = + p.apply("CreateMainInput", Create.timestamped( + TimestampedValue.of("A", new Instant(4)), + TimestampedValue.of("B", new Instant(15)), + TimestampedValue.of("C", new Instant(7)))) + .apply("WindowMainInput", Window.into(FixedWindows.of(Duration.millis(10)))) + .apply("OutputMainAndSideInputs", ParDo.withSideInputs(view).of( + new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + c.sideInput(view)); + } + })); + + DataflowAssert.that(output).containsInAnyOrder("A0", "B5", "C0"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSideInputWithNullDefault() { + Pipeline p = TestPipeline.create(); + + final PCollectionView view = + p.apply("CreateSideInput", Create.of((Void) null).withCoder(VoidCoder.of())) + .apply(Combine.globally(new SerializableFunction, Void>() { + @Override + public Void apply(Iterable input) { + return null; + } + }).asSingletonView()); + + PCollection output = + p.apply("CreateMainInput", Create.of("")) + .apply( + "OutputMainAndSideInputs", + ParDo.withSideInputs(view).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element() + c.sideInput(view)); + } + })); + + DataflowAssert.that(output).containsInAnyOrder("null"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testSideInputWithNestedIterables() { + Pipeline pipeline = TestPipeline.create(); + final PCollectionView> view1 = + pipeline.apply("CreateVoid1", Create.of((Void) null).withCoder(VoidCoder.of())) + .apply("OutputOneInteger", ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.output(17); + } + })) + .apply("View1", View.asIterable()); + + final PCollectionView>> view2 = + pipeline.apply("CreateVoid2", Create.of((Void) null).withCoder(VoidCoder.of())) + .apply( + "OutputSideInput", + ParDo.withSideInputs(view1).of(new DoFn>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.sideInput(view1)); + } + })) + .apply("View2", View.>asIterable()); + + PCollection output = + pipeline.apply("CreateVoid3", Create.of((Void) 
null).withCoder(VoidCoder.of())) + .apply( + "ReadIterableSideInput", ParDo.withSideInputs(view2).of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + for (Iterable input : c.sideInput(view2)) { + for (Integer i : input) { + c.output(i); + } + } + } + })); + + DataflowAssert.that(output).containsInAnyOrder(17); + + pipeline.run(); + } + + @Test + public void testViewGetName() { + assertEquals("View.AsSingleton", View.asSingleton().getName()); + assertEquals("View.AsIterable", View.asIterable().getName()); + assertEquals("View.AsMap", View.asMap().getName()); + assertEquals("View.AsMultimap", View.asMultimap().getName()); + } + + private Pipeline createTestBatchRunner() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setProject("someproject"); + options.setStagingLocation("gs://staging"); + options.setPathValidatorClass(NoopPathValidator.class); + options.setDataflowClient(null); + return Pipeline.create(options); + } + + private Pipeline createTestStreamingRunner() { + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setRunner(DataflowPipelineRunner.class); + options.setStreaming(true); + options.setProject("someproject"); + options.setStagingLocation("gs://staging"); + options.setPathValidatorClass(NoopPathValidator.class); + options.setDataflowClient(null); + return Pipeline.create(options); + } + + private Pipeline createTestDirectRunner() { + DirectPipelineOptions options = PipelineOptionsFactory.as(DirectPipelineOptions.class); + options.setRunner(DirectPipelineRunner.class); + return Pipeline.create(options); + } + + private void testViewUnbounded( + Pipeline pipeline, + PTransform>, ? extends PCollectionView> view) { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Unable to create a side-input view from input"); + thrown.expectCause( + ThrowableMessageMatcher.hasMessage(Matchers.containsString("non-bounded PCollection"))); + pipeline + .apply( + new PTransform>>() { + @Override + public PCollection> apply(PBegin input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + PCollection.IsBounded.UNBOUNDED) + .setTypeDescriptorInternal(new TypeDescriptor>() {}); + } + }) + .apply(view); + } + + private void testViewNonmerging( + Pipeline pipeline, + PTransform>, ? 
extends PCollectionView> view) { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Unable to create a side-input view from input"); + thrown.expectCause( + ThrowableMessageMatcher.hasMessage(Matchers.containsString("Consumed by GroupByKey"))); + pipeline.apply(Create.>of(KV.of("hello", 5))) + .apply(Window.>into(new InvalidWindows<>( + "Consumed by GroupByKey", FixedWindows.of(Duration.standardHours(1))))) + .apply(view); + } + + @Test + public void testViewUnboundedAsSingletonBatch() { + testViewUnbounded(createTestBatchRunner(), View.>asSingleton()); + } + + @Test + public void testViewUnboundedAsSingletonStreaming() { + testViewUnbounded(createTestStreamingRunner(), View.>asSingleton()); + } + + @Test + public void testViewUnboundedAsSingletonDirect() { + testViewUnbounded(createTestDirectRunner(), View.>asSingleton()); + } + + @Test + public void testViewUnboundedAsIterableBatch() { + testViewUnbounded(createTestBatchRunner(), View.>asIterable()); + } + + @Test + public void testViewUnboundedAsIterableStreaming() { + testViewUnbounded(createTestStreamingRunner(), View.>asIterable()); + } + + @Test + public void testViewUnboundedAsIterableDirect() { + testViewUnbounded(createTestDirectRunner(), View.>asIterable()); + } + + @Test + public void testViewUnboundedAsListBatch() { + testViewUnbounded(createTestBatchRunner(), View.>asList()); + } + + @Test + public void testViewUnboundedAsListStreaming() { + testViewUnbounded(createTestStreamingRunner(), View.>asList()); + } + + @Test + public void testViewUnboundedAsListDirect() { + testViewUnbounded(createTestDirectRunner(), View.>asList()); + } + + @Test + public void testViewUnboundedAsMapBatch() { + testViewUnbounded(createTestBatchRunner(), View.asMap()); + } + + @Test + public void testViewUnboundedAsMapStreaming() { + testViewUnbounded(createTestStreamingRunner(), View.asMap()); + } + + @Test + public void testViewUnboundedAsMapDirect() { + testViewUnbounded(createTestDirectRunner(), View.asMap()); + } + + + @Test + public void testViewUnboundedAsMultimapBatch() { + testViewUnbounded(createTestBatchRunner(), View.asMultimap()); + } + + @Test + public void testViewUnboundedAsMultimapStreaming() { + testViewUnbounded(createTestStreamingRunner(), View.asMultimap()); + } + + @Test + public void testViewUnboundedAsMultimapDirect() { + testViewUnbounded(createTestDirectRunner(), View.asMultimap()); + } + + @Test + public void testViewNonmergingAsSingletonBatch() { + testViewNonmerging(createTestBatchRunner(), View.>asSingleton()); + } + + @Test + public void testViewNonmergingAsSingletonStreaming() { + testViewNonmerging(createTestStreamingRunner(), View.>asSingleton()); + } + + @Test + public void testViewNonmergingAsSingletonDirect() { + testViewNonmerging(createTestDirectRunner(), View.>asSingleton()); + } + + @Test + public void testViewNonmergingAsIterableBatch() { + testViewNonmerging(createTestBatchRunner(), View.>asIterable()); + } + + @Test + public void testViewNonmergingAsIterableStreaming() { + testViewNonmerging(createTestStreamingRunner(), View.>asIterable()); + } + + @Test + public void testViewNonmergingAsIterableDirect() { + testViewNonmerging(createTestDirectRunner(), View.>asIterable()); + } + + @Test + public void testViewNonmergingAsListBatch() { + testViewNonmerging(createTestBatchRunner(), View.>asList()); + } + + @Test + public void testViewNonmergingAsListStreaming() { + testViewNonmerging(createTestStreamingRunner(), View.>asList()); + } + + @Test + public void testViewNonmergingAsListDirect() { 
+ testViewNonmerging(createTestDirectRunner(), View.>asList()); + } + + @Test + public void testViewNonmergingAsMapBatch() { + testViewNonmerging(createTestBatchRunner(), View.asMap()); + } + + @Test + public void testViewNonmergingAsMapStreaming() { + testViewNonmerging(createTestStreamingRunner(), View.asMap()); + } + + @Test + public void testViewNonmergingAsMapDirect() { + testViewNonmerging(createTestDirectRunner(), View.asMap()); + } + + + @Test + public void testViewNonmergingAsMultimapBatch() { + testViewNonmerging(createTestBatchRunner(), View.asMultimap()); + } + + @Test + public void testViewNonmergingAsMultimapStreaming() { + testViewNonmerging(createTestStreamingRunner(), View.asMultimap()); + } + + @Test + public void testViewNonmergingAsMultimapDirect() { + testViewNonmerging(createTestDirectRunner(), View.asMultimap()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithKeysTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithKeysTest.java new file mode 100644 index 000000000000..0f9abd487f93 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithKeysTest.java @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for ExtractKeys transform. 
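+ * (i.e. the {@link WithKeys} transform: the cases below key strings by their length and by a constant).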
+ */ +@RunWith(JUnit4.class) +public class WithKeysTest { + static final String[] COLLECTION = new String[] { + "a", + "aa", + "b", + "bb", + "bbb" + }; + + static final List> WITH_KEYS = Arrays.asList( + KV.of(1, "a"), + KV.of(2, "aa"), + KV.of(1, "b"), + KV.of(2, "bb"), + KV.of(3, "bbb") + ); + + static final List> WITH_CONST_KEYS = Arrays.asList( + KV.of(100, "a"), + KV.of(100, "aa"), + KV.of(100, "b"), + KV.of(100, "bb"), + KV.of(100, "bbb") + ); + + @Test + public void testExtractKeys() { + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION)).withCoder( + StringUtf8Coder.of())); + + PCollection> output = input.apply(WithKeys.of( + new LengthAsKey())); + DataflowAssert.that(output) + .containsInAnyOrder(WITH_KEYS); + + p.run(); + } + + @Test + public void testConstantKeys() { + Pipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION)).withCoder( + StringUtf8Coder.of())); + + PCollection> output = + input.apply(WithKeys.of(100)); + DataflowAssert.that(output) + .containsInAnyOrder(WITH_CONST_KEYS); + + p.run(); + } + + @Test + public void testWithKeysGetName() { + assertEquals("WithKeys", WithKeys.of(100).getName()); + } + + @Test + public void testWithKeysWithUnneededWithKeyTypeSucceeds() { + TestPipeline p = TestPipeline.create(); + + PCollection input = + p.apply(Create.of(Arrays.asList(COLLECTION)).withCoder( + StringUtf8Coder.of())); + + PCollection> output = + input.apply(WithKeys.of(new LengthAsKey()).withKeyType(TypeDescriptor.of(Integer.class))); + DataflowAssert.that(output).containsInAnyOrder(WITH_KEYS); + + p.run(); + } + + /** + * Key a value by its length. + */ + public static class LengthAsKey + implements SerializableFunction { + @Override + public Integer apply(String value) { + return value.length(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithTimestampsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithTimestampsTest.java new file mode 100644 index 000000000000..60ec2654f9e4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/WithTimestampsTest.java @@ -0,0 +1,210 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.isA; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for {@link WithTimestamps}. + */ +@RunWith(JUnit4.class) +public class WithTimestampsTest implements Serializable { + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(RunnableOnService.class) + public void withTimestampsShouldApplyTimestamps() { + TestPipeline p = TestPipeline.create(); + + SerializableFunction timestampFn = + new SerializableFunction() { + @Override + public Instant apply(String input) { + return new Instant(Long.valueOf(input)); + } + }; + + String yearTwoThousand = "946684800000"; + PCollection timestamped = + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) + .apply(WithTimestamps.of(timestampFn)); + + PCollection> timestampedVals = + timestamped.apply(ParDo.of(new DoFn>() { + @Override + public void processElement(DoFn>.ProcessContext c) + throws Exception { + c.output(KV.of(c.element(), c.timestamp())); + } + })); + + DataflowAssert.that(timestamped) + .containsInAnyOrder(yearTwoThousand, "0", "1234", Integer.toString(Integer.MAX_VALUE)); + DataflowAssert.that(timestampedVals) + .containsInAnyOrder( + KV.of("0", new Instant(0)), + KV.of("1234", new Instant(1234L)), + KV.of(Integer.toString(Integer.MAX_VALUE), new Instant(Integer.MAX_VALUE)), + KV.of(yearTwoThousand, new Instant(Long.valueOf(yearTwoThousand)))); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void withTimestampsBackwardsInTimeShouldThrow() { + TestPipeline p = TestPipeline.create(); + + SerializableFunction timestampFn = + new SerializableFunction() { + @Override + public Instant apply(String input) { + return new Instant(Long.valueOf(input)); + } + }; + SerializableFunction backInTimeFn = + new SerializableFunction() { + @Override + public Instant apply(String input) { + return new Instant(Long.valueOf(input)).minus(Duration.millis(1000L)); + } + }; + + + String yearTwoThousand = "946684800000"; + + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) + .apply("WithTimestamps", WithTimestamps.of(timestampFn)) + .apply("AddSkew", WithTimestamps.of(backInTimeFn)); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(IllegalArgumentException.class)); + thrown.expectMessage("no earlier than the timestamp of the current input"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void withTimestampsBackwardsInTimeAndWithAllowedTimestampSkewShouldSucceed() { + TestPipeline p = TestPipeline.create(); + + SerializableFunction timestampFn = + new SerializableFunction() { + @Override + public Instant apply(String input) { + return new Instant(Long.valueOf(input)); + } + }; + + final Duration skew = Duration.millis(1000L); + 
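+        // Moving timestamps 1000 ms backwards is only legal because the second WithTimestamps below
+        // declares withAllowedTimestampSkew(...) of at least that amount; without the declared skew
+        // the pipeline fails, as withTimestampsBackwardsInTimeShouldThrow() verifies.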
SerializableFunction backInTimeFn = + new SerializableFunction() { + @Override + public Instant apply(String input) { + return new Instant(Long.valueOf(input)).minus(skew); + } + }; + + String yearTwoThousand = "946684800000"; + PCollection timestampedWithSkew = + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) + .apply("FirstTimestamp", WithTimestamps.of(timestampFn)) + .apply( + "WithSkew", + WithTimestamps.of(backInTimeFn).withAllowedTimestampSkew(skew.plus(100L))); + + PCollection> timestampedVals = + timestampedWithSkew.apply(ParDo.of(new DoFn>() { + @Override + public void processElement(DoFn>.ProcessContext c) + throws Exception { + c.output(KV.of(c.element(), c.timestamp())); + } + })); + + DataflowAssert.that(timestampedWithSkew) + .containsInAnyOrder(yearTwoThousand, "0", "1234", Integer.toString(Integer.MAX_VALUE)); + DataflowAssert.that(timestampedVals) + .containsInAnyOrder( + KV.of("0", new Instant(0L).minus(skew)), + KV.of("1234", new Instant(1234L).minus(skew)), + KV.of( + Integer.toString(Integer.MAX_VALUE), + new Instant(Long.valueOf(Integer.MAX_VALUE)).minus(skew)), + KV.of(yearTwoThousand, new Instant(Long.valueOf(yearTwoThousand)).minus(skew))); + + p.run(); + } + + @Test + public void withTimestampsWithNullTimestampShouldThrow() { + SerializableFunction timestampFn = + new SerializableFunction() { + @Override + public Instant apply(String input) { + return null; + } + }; + + TestPipeline p = TestPipeline.create(); + String yearTwoThousand = "946684800000"; + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) + .apply(WithTimestamps.of(timestampFn)); + + thrown.expect(PipelineExecutionException.class); + thrown.expectCause(isA(NullPointerException.class)); + thrown.expectMessage("WithTimestamps"); + thrown.expectMessage("cannot be null"); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void withTimestampsWithNullFnShouldThrowOnConstruction() { + TestPipeline p = TestPipeline.create(); + + SerializableFunction timestampFn = null; + + thrown.expect(NullPointerException.class); + thrown.expectMessage("WithTimestamps fn cannot be null"); + + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE))) + .apply(WithTimestamps.of(timestampFn)); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultCoderTest.java new file mode 100644 index 000000000000..7ac278bc8d6e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultCoderTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static org.junit.Assert.assertFalse; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.DoubleCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult.CoGbkResultCoder; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.collect.ImmutableList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests the CoGbkResult.CoGbkResultCoder. + */ +@RunWith(JUnit4.class) +public class CoGbkResultCoderTest { + + private static final CoGbkResultSchema TEST_SCHEMA = + new CoGbkResultSchema(TupleTagList.of(new TupleTag()).and( + new TupleTag())); + + private static final UnionCoder TEST_UNION_CODER = + UnionCoder.of(ImmutableList.>of( + StringUtf8Coder.of(), + VarIntCoder.of())); + + private static final UnionCoder COMPATIBLE_UNION_CODER = + UnionCoder.of(ImmutableList.>of( + StringUtf8Coder.of(), + BigEndianIntegerCoder.of())); + + private static final CoGbkResultSchema INCOMPATIBLE_SCHEMA = + new CoGbkResultSchema(TupleTagList.of(new TupleTag()).and( + new TupleTag())); + + private static final UnionCoder INCOMPATIBLE_UNION_CODER = + UnionCoder.of(ImmutableList.>of( + StringUtf8Coder.of(), + DoubleCoder.of())); + + private static final CoGbkResultCoder TEST_CODER = + CoGbkResultCoder.of(TEST_SCHEMA, TEST_UNION_CODER); + + private static final CoGbkResultCoder COMPATIBLE_TEST_CODER = + CoGbkResultCoder.of(TEST_SCHEMA, COMPATIBLE_UNION_CODER); + + private static final CoGbkResultCoder INCOMPATIBLE_TEST_CODER = + CoGbkResultCoder.of(INCOMPATIBLE_SCHEMA, INCOMPATIBLE_UNION_CODER); + + @Test + public void testEquals() { + assertFalse(TEST_CODER.equals(new Object())); + assertFalse(TEST_CODER.equals(COMPATIBLE_TEST_CODER)); + assertFalse(TEST_CODER.equals(INCOMPATIBLE_TEST_CODER)); + } + + @Test + public void testSerializationDeserialization() { + CoderProperties.coderSerializable(TEST_CODER); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultTest.java new file mode 100644 index 000000000000..da14d8a412e5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGbkResultTest.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.util.common.Reiterable; +import com.google.cloud.dataflow.sdk.util.common.Reiterator; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.List; + +/** + * Tests the CoGbkResult. + */ +@RunWith(JUnit4.class) +public class CoGbkResultTest { + + @Test + public void testLazyResults() { + runLazyResult(0); + runLazyResult(1); + runLazyResult(3); + runLazyResult(10); + } + + public void runLazyResult(int cacheSize) { + int valueLen = 7; + TestUnionValues values = new TestUnionValues(0, 1, 0, 3, 0, 3, 3); + CoGbkResult result = new CoGbkResult(createSchema(5), values, cacheSize); + assertThat(values.maxPos(), equalTo(Math.min(cacheSize, valueLen))); + assertThat(result.getAll(new TupleTag("tag0")), contains(0, 2, 4)); + assertThat(values.maxPos(), equalTo(valueLen)); + assertThat(result.getAll(new TupleTag("tag3")), contains(3, 5, 6)); + assertThat(result.getAll(new TupleTag("tag2")), emptyIterable()); + assertThat(result.getOnly(new TupleTag("tag1")), equalTo(1)); + assertThat(result.getAll(new TupleTag("tag0")), contains(0, 2, 4)); + } + + private CoGbkResultSchema createSchema(int size) { + List> tags = new ArrayList<>(); + for (int i = 0; i < size; i++) { + tags.add(new TupleTag("tag" + i)); + } + return new CoGbkResultSchema(TupleTagList.of(tags)); + } + + private static class TestUnionValues implements Reiterable { + + final int[] tags; + int maxPos = 0; + + /** + * This will create a list of RawUnionValues whose tags are as given and + * values are increasing starting at 0 (i.e. the index in the constructor). + */ + public TestUnionValues(int... tags) { + this.tags = tags; + } + + /** + * Returns the highest position iterated to so far, useful for ensuring + * laziness. + */ + public int maxPos() { + return maxPos; + } + + @Override + public Reiterator iterator() { + return iterator(0); + } + + public Reiterator iterator(final int start) { + return new Reiterator() { + int pos = start; + + @Override + public boolean hasNext() { + return pos < tags.length; + } + + @Override + public RawUnionValue next() { + maxPos = Math.max(pos + 1, maxPos); + return new RawUnionValue(tags[pos], pos++); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public Reiterator copy() { + return iterator(pos); + } + }; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKeyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKeyTest.java new file mode 100644 index 000000000000..1345289739c4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/CoGroupByKeyTest.java @@ -0,0 +1,507 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.join; + +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Iterables; + +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * Tests for CoGroupByKeyTest. Implements Serializable for anonymous DoFns. + */ +@RunWith(JUnit4.class) +public class CoGroupByKeyTest implements Serializable { + + /** + * Converts the given list into a PCollection belonging to the provided + * Pipeline in such a way that coder inference needs to be performed. + */ + private PCollection> createInput(String name, + Pipeline p, List> list) { + return createInput(name, p, list, new ArrayList()); + } + + /** + * Converts the given list with timestamps into a PCollection. 
+ */ + private PCollection> createInput(String name, + Pipeline p, List> list, List timestamps) { + PCollection> input; + if (timestamps.isEmpty()) { + input = p.apply("Create" + name, Create.of(list) + .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))); + } else { + input = p.apply("Create" + name, Create.timestamped(list, timestamps) + .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))); + } + return input + .apply("Identity" + name, ParDo.of(new DoFn, + KV>() { + @Override + public void processElement(ProcessContext c) { + c.output(c.element()); + } + })); + } + + /** + * Returns a {@code PCollection>} containing the result + * of a {@link CoGroupByKey} over 2 {@code PCollection>}, + * where each {@link PCollection} has no duplicate keys and the key sets of + * each {@link PCollection} are intersecting but neither is a subset of the other. + */ + private PCollection> buildGetOnlyGbk( + Pipeline p, + TupleTag tag1, + TupleTag tag2) { + List> list1 = + Arrays.asList( + KV.of(1, "collection1-1"), + KV.of(2, "collection1-2")); + List> list2 = + Arrays.asList( + KV.of(2, "collection2-2"), + KV.of(3, "collection2-3")); + PCollection> collection1 = createInput("CreateList1", p, list1); + PCollection> collection2 = createInput("CreateList2", p, list2); + PCollection> coGbkResults = + KeyedPCollectionTuple.of(tag1, collection1) + .and(tag2, collection2) + .apply(CoGroupByKey.create()); + return coGbkResults; + } + + @Test + public void testCoGroupByKeyGetOnly() { + final TupleTag tag1 = new TupleTag<>(); + final TupleTag tag2 = new TupleTag<>(); + + Pipeline p = TestPipeline.create(); + + PCollection> coGbkResults = + buildGetOnlyGbk(p, tag1, tag2); + + DataflowAssert.thatMap(coGbkResults).satisfies( + new SerializableFunction, Void>() { + @Override + public Void apply(Map results) { + assertEquals("collection1-1", results.get(1).getOnly(tag1)); + assertEquals("collection1-2", results.get(2).getOnly(tag1)); + assertEquals("collection2-2", results.get(2).getOnly(tag2)); + assertEquals("collection2-3", results.get(3).getOnly(tag2)); + return null; + } + }); + + p.run(); + } + + /** + * Returns a {@code PCollection>} containing the + * results of the {@code CoGroupByKey} over three + * {@code PCollection>}, each of which correlates + * a customer id to purchases, addresses, or names, respectively. + */ + private PCollection> buildPurchasesCoGbk( + Pipeline p, + TupleTag purchasesTag, + TupleTag addressesTag, + TupleTag namesTag) { + List> idToPurchases = + Arrays.asList( + KV.of(2, "Boat"), + KV.of(1, "Shoes"), + KV.of(3, "Car"), + KV.of(1, "Book"), + KV.of(10, "Pens"), + KV.of(8, "House"), + KV.of(4, "Suit"), + KV.of(11, "House"), + KV.of(14, "Shoes"), + KV.of(2, "Suit"), + KV.of(8, "Suit Case"), + KV.of(3, "House")); + + List> idToAddress = + Arrays.asList( + KV.of(2, "53 S. 3rd"), + KV.of(10, "383 Jackson Street"), + KV.of(20, "3 W. 
Arizona"), + KV.of(3, "29 School Rd"), + KV.of(8, "6 Watling Rd")); + + List> idToName = + Arrays.asList( + KV.of(1, "John Smith"), + KV.of(2, "Sally James"), + KV.of(8, "Jeffery Spalding"), + KV.of(20, "Joan Lichtfield")); + + PCollection> purchasesTable = + createInput("CreateIdToPurchases", p, idToPurchases); + + PCollection> addressTable = + createInput("CreateIdToAddress", p, idToAddress); + + PCollection> nameTable = + createInput("CreateIdToName", p, idToName); + + PCollection> coGbkResults = + KeyedPCollectionTuple.of(namesTag, nameTable) + .and(addressesTag, addressTable) + .and(purchasesTag, purchasesTable) + .apply(CoGroupByKey.create()); + return coGbkResults; + } + + /** + * Returns a {@code PCollection>} containing the + * results of the {@code CoGroupByKey} over 2 {@code PCollection>}, + * each of which correlates a customer id to clicks, purchases, respectively. + */ + private PCollection> buildPurchasesCoGbkWithWindowing( + Pipeline p, + TupleTag clicksTag, + TupleTag purchasesTag) { + List> idToClick = + Arrays.asList( + KV.of(1, "Click t0"), + KV.of(2, "Click t2"), + KV.of(1, "Click t4"), + KV.of(1, "Click t6"), + KV.of(2, "Click t8")); + + List> idToPurchases = + Arrays.asList( + KV.of(1, "Boat t1"), + KV.of(1, "Shoesi t2"), + KV.of(1, "Pens t3"), + KV.of(2, "House t4"), + KV.of(2, "Suit t5"), + KV.of(1, "Car t6"), + KV.of(1, "Book t7"), + KV.of(2, "House t8"), + KV.of(2, "Shoes t9"), + KV.of(2, "House t10")); + + PCollection> clicksTable = + createInput("CreateClicks", + p, + idToClick, + Arrays.asList(0L, 2L, 4L, 6L, 8L)) + .apply("WindowClicks", Window.>into( + FixedWindows.of(new Duration(4)))); + + PCollection> purchasesTable = + createInput("CreatePurchases", + p, + idToPurchases, + Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L)) + .apply("WindowPurchases", Window.>into( + FixedWindows.of(new Duration(4)))); + + PCollection> coGbkResults = + KeyedPCollectionTuple.of(clicksTag, clicksTable) + .and(purchasesTag, purchasesTable) + .apply(CoGroupByKey.create()); + return coGbkResults; + } + + @Test + public void testCoGroupByKey() { + final TupleTag namesTag = new TupleTag<>(); + final TupleTag addressesTag = new TupleTag<>(); + final TupleTag purchasesTag = new TupleTag<>(); + + Pipeline p = TestPipeline.create(); + + PCollection> coGbkResults = + buildPurchasesCoGbk(p, purchasesTag, addressesTag, namesTag); + + DataflowAssert.thatMap(coGbkResults).satisfies( + new SerializableFunction, Void>() { + @Override + public Void apply(Map results) { + CoGbkResult result1 = results.get(1); + assertEquals("John Smith", result1.getOnly(namesTag)); + assertThat(result1.getAll(purchasesTag), containsInAnyOrder("Shoes", "Book")); + + CoGbkResult result2 = results.get(2); + assertEquals("Sally James", result2.getOnly(namesTag)); + assertEquals("53 S. 3rd", result2.getOnly(addressesTag)); + assertThat(result2.getAll(purchasesTag), containsInAnyOrder("Suit", "Boat")); + + CoGbkResult result3 = results.get(3); + assertEquals("29 School Rd", result3.getOnly(addressesTag), "29 School Rd"); + assertThat(result3.getAll(purchasesTag), containsInAnyOrder("Car", "House")); + + CoGbkResult result8 = results.get(8); + assertEquals("Jeffery Spalding", result8.getOnly(namesTag)); + assertEquals("6 Watling Rd", result8.getOnly(addressesTag)); + assertThat(result8.getAll(purchasesTag), containsInAnyOrder("House", "Suit Case")); + + CoGbkResult result20 = results.get(20); + assertEquals("Joan Lichtfield", result20.getOnly(namesTag)); + assertEquals("3 W. 
Arizona", result20.getOnly(addressesTag)); + + assertEquals("383 Jackson Street", results.get(10).getOnly(addressesTag)); + + assertThat(results.get(4).getAll(purchasesTag), containsInAnyOrder("Suit")); + assertThat(results.get(10).getAll(purchasesTag), containsInAnyOrder("Pens")); + assertThat(results.get(11).getAll(purchasesTag), containsInAnyOrder("House")); + assertThat(results.get(14).getAll(purchasesTag), containsInAnyOrder("Shoes")); + + return null; + } + }); + + p.run(); + } + + /** + * A DoFn used in testCoGroupByKeyWithWindowing(), to test processing the + * results of a CoGroupByKey. + */ + private static class ClickOfPurchaseFn extends + DoFn, KV> implements RequiresWindowAccess { + private final TupleTag clicksTag; + + private final TupleTag purchasesTag; + + private ClickOfPurchaseFn( + TupleTag clicksTag, + TupleTag purchasesTag) { + this.clicksTag = clicksTag; + this.purchasesTag = purchasesTag; + } + + @Override + public void processElement(ProcessContext c) { + BoundedWindow w = c.window(); + KV e = c.element(); + CoGbkResult row = e.getValue(); + Iterable clicks = row.getAll(clicksTag); + Iterable purchases = row.getAll(purchasesTag); + for (String click : clicks) { + for (String purchase : purchases) { + c.output(KV.of(click + ":" + purchase, + c.timestamp().getMillis() + ":" + w.maxTimestamp().getMillis())); + } + } + } + } + + + /** + * A DoFn used in testCoGroupByKeyHandleResults(), to test processing the + * results of a CoGroupByKey. + */ + private static class CorrelatePurchaseCountForAddressesWithoutNamesFn extends + DoFn, KV> { + private final TupleTag purchasesTag; + + private final TupleTag addressesTag; + + private final TupleTag namesTag; + + private CorrelatePurchaseCountForAddressesWithoutNamesFn( + TupleTag purchasesTag, + TupleTag addressesTag, + TupleTag namesTag) { + this.purchasesTag = purchasesTag; + this.addressesTag = addressesTag; + this.namesTag = namesTag; + } + + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + CoGbkResult row = e.getValue(); + // Don't actually care about the id. + Iterable names = row.getAll(namesTag); + if (names.iterator().hasNext()) { + // Nothing to do. There was a name. + return; + } + Iterable addresses = row.getAll(addressesTag); + if (!addresses.iterator().hasNext()) { + // Nothing to do, there was no address. + return; + } + // Buffer the addresses so we can accredit all of them with + // corresponding purchases. All addresses are for the same id, so + // if there are multiple, we apply the same purchase count to all. + ArrayList addressList = new ArrayList(); + for (String address : addresses) { + addressList.add(address); + } + + Iterable purchases = row.getAll(purchasesTag); + + int purchaseCount = Iterables.size(purchases); + + for (String address : addressList) { + c.output(KV.of(address, purchaseCount)); + } + } + } + + /** + * Tests that the consuming DoFn + * (CorrelatePurchaseCountForAddressesWithoutNamesFn) performs as expected. + */ + @SuppressWarnings("unchecked") + @Test + public void testConsumingDoFn() { + TupleTag purchasesTag = new TupleTag<>(); + TupleTag addressesTag = new TupleTag<>(); + TupleTag namesTag = new TupleTag<>(); + + // result1 should get filtered out because it has a name. + CoGbkResult result1 = CoGbkResult + .of(purchasesTag, Arrays.asList("3a", "3b")) + .and(addressesTag, Arrays.asList("2a", "2b")) + .and(namesTag, Arrays.asList("1a")); + // result 2 should be counted because it has an address and purchases. 
+ CoGbkResult result2 = CoGbkResult + .of(purchasesTag, Arrays.asList("5a", "5b")) + .and(addressesTag, Arrays.asList("4a")) + .and(namesTag, new ArrayList()); + // result 3 should not be counted because it has no addresses. + CoGbkResult result3 = CoGbkResult + .of(purchasesTag, Arrays.asList("7a", "7b")) + .and(addressesTag, new ArrayList()) + .and(namesTag, new ArrayList()); + // result 4 should be counted as 0, because it has no purchases. + CoGbkResult result4 = CoGbkResult + .of(purchasesTag, new ArrayList()) + .and(addressesTag, Arrays.asList("8a")) + .and(namesTag, new ArrayList()); + + List> results = + DoFnTester.of( + new CorrelatePurchaseCountForAddressesWithoutNamesFn( + purchasesTag, + addressesTag, + namesTag)) + .processBatch( + KV.of(1, result1), + KV.of(2, result2), + KV.of(3, result3), + KV.of(4, result4)); + + assertThat(results, containsInAnyOrder(KV.of("4a", 2), KV.of("8a", 0))); + } + + /** + * Tests the pipeline end-to-end. Builds the purchases CoGroupByKey, and + * applies CorrelatePurchaseCountForAddressesWithoutNamesFn to the results. + */ + @SuppressWarnings("unchecked") + @Test + @Category(RunnableOnService.class) + public void testCoGroupByKeyHandleResults() { + TupleTag namesTag = new TupleTag<>(); + TupleTag addressesTag = new TupleTag<>(); + TupleTag purchasesTag = new TupleTag<>(); + + Pipeline p = TestPipeline.create(); + + PCollection> coGbkResults = + buildPurchasesCoGbk(p, purchasesTag, addressesTag, namesTag); + + // Do some simple processing on the result of the CoGroupByKey. Count the + // purchases for each address on record that has no associated name. + PCollection> + purchaseCountByKnownAddressesWithoutKnownNames = + coGbkResults.apply(ParDo.of( + new CorrelatePurchaseCountForAddressesWithoutNamesFn( + purchasesTag, addressesTag, namesTag))); + + DataflowAssert.that(purchaseCountByKnownAddressesWithoutKnownNames) + .containsInAnyOrder( + KV.of("29 School Rd", 2), + KV.of("383 Jackson Street", 1)); + p.run(); + } + + /** + * Tests the pipeline end-to-end with FixedWindows. + */ + @SuppressWarnings("unchecked") + @Test + @Category(RunnableOnService.class) + public void testCoGroupByKeyWithWindowing() { + TupleTag clicksTag = new TupleTag<>(); + TupleTag purchasesTag = new TupleTag<>(); + + Pipeline p = TestPipeline.create(); + + PCollection> coGbkResults = + buildPurchasesCoGbkWithWindowing(p, clicksTag, purchasesTag); + + PCollection> + clickOfPurchase = coGbkResults.apply(ParDo.of( + new ClickOfPurchaseFn(clicksTag, purchasesTag))); + DataflowAssert.that(clickOfPurchase) + .containsInAnyOrder( + KV.of("Click t0:Boat t1", "0:3"), + KV.of("Click t0:Shoesi t2", "0:3"), + KV.of("Click t0:Pens t3", "0:3"), + KV.of("Click t4:Car t6", "4:7"), + KV.of("Click t4:Book t7", "4:7"), + KV.of("Click t6:Car t6", "4:7"), + KV.of("Click t6:Book t7", "4:7"), + KV.of("Click t8:House t8", "8:11"), + KV.of("Click t8:Shoes t9", "8:11"), + KV.of("Click t8:House t10", "8:11")); + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoderTest.java new file mode 100644 index 000000000000..8fbe0d4badb9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/join/UnionCoderTest.java @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.transforms.join;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.DoubleCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.util.CloudObject;
+import com.google.cloud.dataflow.sdk.util.Serializer;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Arrays;
+
+/**
+ * Tests the UnionCoder.
+ */
+@RunWith(JUnit4.class)
+public class UnionCoderTest {
+
+  @Test
+  public void testSerializationDeserialization() {
+    UnionCoder newCoder =
+        UnionCoder.of(Arrays.<Coder<?>>asList(StringUtf8Coder.of(),
+            DoubleCoder.of()));
+    CloudObject encoding = newCoder.asCloudObject();
+    Coder decodedCoder = Serializer.deserialize(encoding, Coder.class);
+    assertEquals(newCoder, decodedCoder);
+  }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterAllTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterAllTest.java
new file mode 100644
index 000000000000..06f0c3f759b2
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterAllTest.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.transforms.windowing;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger;
+import com.google.cloud.dataflow.sdk.util.TriggerTester;
+import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester;
+
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link AfterAll}.
+ */ +@RunWith(JUnit4.class) +public class AfterAllTest { + + private SimpleTriggerTester tester; + + @Test + public void testT1FiresFirst() throws Exception { + tester = TriggerTester.forTrigger( + AfterAll.of( + AfterPane.elementCountAtLeast(1), + AfterPane.elementCountAtLeast(2)), + FixedWindows.of(Duration.millis(100))); + + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + tester.injectElements(1); + assertFalse(tester.shouldFire(window)); + + tester.injectElements(2); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + @Test + public void testT2FiresFirst() throws Exception { + tester = TriggerTester.forTrigger( + AfterAll.of( + AfterPane.elementCountAtLeast(2), + AfterPane.elementCountAtLeast(1)), + FixedWindows.of(Duration.millis(100))); + + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + tester.injectElements(1); + assertFalse(tester.shouldFire(window)); + + tester.injectElements(2); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + /** + * Tests that the AfterAll properly unsets finished bits when a merge causing it to become + * unfinished. + */ + @Test + public void testOnMergeRewinds() throws Exception { + tester = TriggerTester.forTrigger( + AfterEach.inOrder( + AfterAll.of( + AfterWatermark.pastEndOfWindow(), + AfterPane.elementCountAtLeast(1)), + Repeatedly.forever(AfterPane.elementCountAtLeast(1))), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements(1); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + + tester.injectElements(5); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + + // Finish the AfterAll in the first window + tester.advanceInputWatermark(new Instant(11)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + + // Merge them; the AfterAll should not be finished + tester.mergeWindows(); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + assertFalse(tester.isMarkedFinished(mergedWindow)); + + // Confirm that we are back on the first trigger by probing that it is not ready to fire + // after an element (with merging) + tester.injectElements(3); + tester.mergeWindows(); + assertFalse(tester.shouldFire(mergedWindow)); + + // Fire the AfterAll in the merged window + tester.advanceInputWatermark(new Instant(15)); + assertTrue(tester.shouldFire(mergedWindow)); + tester.fireIfShouldFire(mergedWindow); + + // Confirm that we are on the second trigger by probing + tester.injectElements(2); + tester.mergeWindows(); + assertTrue(tester.shouldFire(mergedWindow)); + tester.fireIfShouldFire(mergedWindow); + tester.injectElements(2); + tester.mergeWindows(); + assertTrue(tester.shouldFire(mergedWindow)); + tester.fireIfShouldFire(mergedWindow); + } + + @Test + public void testFireDeadline() throws Exception { + BoundedWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + AfterAll.of(AfterWatermark.pastEndOfWindow(), AfterPane.elementCountAtLeast(1)) + .getWatermarkThatGuaranteesFiring(window)); + } + + @Test + public void testContinuation() throws Exception { + OnceTrigger trigger1 = AfterProcessingTime.pastFirstElementInPane(); + OnceTrigger trigger2 = 
AfterWatermark.pastEndOfWindow(); + Trigger afterAll = AfterAll.of(trigger1, trigger2); + assertEquals( + AfterAll.of(trigger1.getContinuationTrigger(), trigger2.getContinuationTrigger()), + afterAll.getContinuationTrigger()); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterEachTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterEachTest.java new file mode 100644 index 000000000000..5d6632d1c4ee --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterEachTest.java @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link AfterEach}. + */ +@RunWith(JUnit4.class) +public class AfterEachTest { + + private SimpleTriggerTester tester; + + @Before + public void initMocks() { + MockitoAnnotations.initMocks(this); + } + + /** + * Tests that the {@link AfterEach} trigger fires and finishes the first trigger then the second. 
+ */ + @Test + public void testAfterEachInSequence() throws Exception { + tester = TriggerTester.forTrigger( + AfterEach.inOrder( + Repeatedly.forever(AfterPane.elementCountAtLeast(2)) + .orFinally(AfterPane.elementCountAtLeast(3)), + Repeatedly.forever(AfterPane.elementCountAtLeast(5)) + .orFinally(AfterWatermark.pastEndOfWindow())), + FixedWindows.of(Duration.millis(10))); + + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + // AfterCount(2) not ready + tester.injectElements(1); + assertFalse(tester.shouldFire(window)); + + // AfterCount(2) ready, not finished + tester.injectElements(2); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + // orFinally(AfterCount(3)) ready and will finish the first + tester.injectElements(1, 2, 3); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + // Now running as the second trigger + assertFalse(tester.shouldFire(window)); + // This quantity of elements would fire and finish if it were erroneously still the first + tester.injectElements(1, 2, 3, 4); + assertFalse(tester.shouldFire(window)); + + // Now fire once + tester.injectElements(5); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + // This time advance the watermark to finish the whole mess. + tester.advanceInputWatermark(new Instant(10)); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + @Test + public void testFireDeadline() throws Exception { + BoundedWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + assertEquals(new Instant(9), + AfterEach.inOrder(AfterWatermark.pastEndOfWindow(), + AfterPane.elementCountAtLeast(4)) + .getWatermarkThatGuaranteesFiring(window)); + + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + AfterEach.inOrder(AfterPane.elementCountAtLeast(2), AfterWatermark.pastEndOfWindow()) + .getWatermarkThatGuaranteesFiring(window)); + } + + @Test + public void testContinuation() throws Exception { + OnceTrigger trigger1 = AfterProcessingTime.pastFirstElementInPane(); + OnceTrigger trigger2 = AfterWatermark.pastEndOfWindow(); + Trigger afterEach = AfterEach.inOrder(trigger1, trigger2); + assertEquals( + Repeatedly.forever(AfterFirst.of( + trigger1.getContinuationTrigger(), trigger2.getContinuationTrigger())), + afterEach.getContinuationTrigger()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterFirstTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterFirstTest.java new file mode 100644 index 000000000000..135638c363e8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterFirstTest.java @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link AfterFirst}. + */ +@RunWith(JUnit4.class) +public class AfterFirstTest { + + @Mock private OnceTrigger mockTrigger1; + @Mock private OnceTrigger mockTrigger2; + private SimpleTriggerTester tester; + private static Trigger.TriggerContext anyTriggerContext() { + return Mockito..TriggerContext>any(); + } + + @Before + public void initMocks() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testNeitherShouldFireFixedWindows() throws Exception { + tester = TriggerTester.forTrigger( + AfterFirst.of(mockTrigger1, mockTrigger2), FixedWindows.of(Duration.millis(10))); + + tester.injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + when(mockTrigger1.shouldFire(anyTriggerContext())).thenReturn(false); + when(mockTrigger2.shouldFire(anyTriggerContext())).thenReturn(false); + + assertFalse(tester.shouldFire(window)); // should not fire + assertFalse(tester.isMarkedFinished(window)); // not finished + } + + @Test + public void testOnlyT1ShouldFireFixedWindows() throws Exception { + tester = TriggerTester.forTrigger( + AfterFirst.of(mockTrigger1, mockTrigger2), FixedWindows.of(Duration.millis(10))); + tester.injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(1), new Instant(11)); + + when(mockTrigger1.shouldFire(anyTriggerContext())).thenReturn(true); + when(mockTrigger2.shouldFire(anyTriggerContext())).thenReturn(false); + + assertTrue(tester.shouldFire(window)); // should fire + + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + @Test + public void testOnlyT2ShouldFireFixedWindows() throws Exception { + tester = TriggerTester.forTrigger( + AfterFirst.of(mockTrigger1, mockTrigger2), FixedWindows.of(Duration.millis(10))); + tester.injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(1), new Instant(11)); + + when(mockTrigger1.shouldFire(anyTriggerContext())).thenReturn(false); + when(mockTrigger2.shouldFire(anyTriggerContext())).thenReturn(true); + assertTrue(tester.shouldFire(window)); // should fire + + tester.fireIfShouldFire(window); // now finished + assertTrue(tester.isMarkedFinished(window)); + } + + @Test + public void testBothShouldFireFixedWindows() throws Exception { + tester = TriggerTester.forTrigger( + AfterFirst.of(mockTrigger1, mockTrigger2), FixedWindows.of(Duration.millis(10))); + tester.injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(1), new Instant(11)); + + when(mockTrigger1.shouldFire(anyTriggerContext())).thenReturn(true); + when(mockTrigger2.shouldFire(anyTriggerContext())).thenReturn(true); + assertTrue(tester.shouldFire(window)); // should fire + + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + /** + * 
Tests that if the first trigger rewinds to be non-finished in the merged window, + * then it becomes the currently active trigger again, with real triggers. + */ + @Test + public void testShouldFireAfterMerge() throws Exception { + tester = TriggerTester.forTrigger( + AfterEach.inOrder( + AfterFirst.of(AfterPane.elementCountAtLeast(5), + AfterWatermark.pastEndOfWindow()), + Repeatedly.forever(AfterPane.elementCountAtLeast(1))), + Sessions.withGapDuration(Duration.millis(10))); + + // Finished the AfterFirst in the first window + tester.injectElements(1); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + assertFalse(tester.shouldFire(firstWindow)); + tester.advanceInputWatermark(new Instant(11)); + assertTrue(tester.shouldFire(firstWindow)); + tester.fireIfShouldFire(firstWindow); + + // Set up second window where it is not done + tester.injectElements(5); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + assertFalse(tester.shouldFire(secondWindow)); + + // Merge them, if the merged window were on the second trigger, it would be ready + tester.mergeWindows(); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + assertFalse(tester.shouldFire(mergedWindow)); + + // Now adding 3 more makes the AfterFirst ready to fire + tester.injectElements(1, 2, 3, 4, 5); + tester.mergeWindows(); + assertTrue(tester.shouldFire(mergedWindow)); + } + + @Test + public void testFireDeadline() throws Exception { + BoundedWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + assertEquals(new Instant(9), + AfterFirst.of(AfterWatermark.pastEndOfWindow(), AfterPane.elementCountAtLeast(4)) + .getWatermarkThatGuaranteesFiring(window)); + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + AfterFirst.of(AfterPane.elementCountAtLeast(2), AfterPane.elementCountAtLeast(1)) + .getWatermarkThatGuaranteesFiring(window)); + } + + @Test + public void testContinuation() throws Exception { + OnceTrigger trigger1 = AfterProcessingTime.pastFirstElementInPane(); + OnceTrigger trigger2 = AfterWatermark.pastEndOfWindow(); + Trigger afterFirst = AfterFirst.of(trigger1, trigger2); + assertEquals( + AfterFirst.of(trigger1.getContinuationTrigger(), trigger2.getContinuationTrigger()), + afterFirst.getContinuationTrigger()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterPaneTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterPaneTest.java new file mode 100644 index 000000000000..9139bc588495 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterPaneTest.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link AfterPane}. + */ +@RunWith(JUnit4.class) +public class AfterPaneTest { + + SimpleTriggerTester tester; + /** + * Tests that the trigger does fire when enough elements are in a window, and that it only + * fires that window (no leakage). + */ + @Test + public void testAfterPaneElementCountFixedWindows() throws Exception { + tester = TriggerTester.forTrigger( + AfterPane.elementCountAtLeast(2), + FixedWindows.of(Duration.millis(10))); + + tester.injectElements(1); // [0, 10) + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + assertFalse(tester.shouldFire(window)); + + tester.injectElements(2); // [0, 10) + tester.injectElements(11); // [10, 20) + + assertTrue(tester.shouldFire(window)); // ready to fire + tester.fireIfShouldFire(window); // and finished + assertTrue(tester.isMarkedFinished(window)); + + // But don't finish the other window + assertFalse(tester.isMarkedFinished(new IntervalWindow(new Instant(10), new Instant(20)))); + } + + @Test + public void testClear() throws Exception { + SimpleTriggerTester tester = TriggerTester.forTrigger( + AfterPane.elementCountAtLeast(2), + FixedWindows.of(Duration.millis(10))); + + tester.injectElements(1, 2, 3); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + tester.clearState(window); + tester.assertCleared(window); + } + + @Test + public void testAfterPaneElementCountSessions() throws Exception { + tester = TriggerTester.forTrigger( + AfterPane.elementCountAtLeast(2), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements( + 1, // in [1, 11) + 2); // in [2, 12) + + assertFalse(tester.shouldFire(new IntervalWindow(new Instant(1), new Instant(11)))); + assertFalse(tester.shouldFire(new IntervalWindow(new Instant(2), new Instant(12)))); + + tester.mergeWindows(); + + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(12)); + assertTrue(tester.shouldFire(mergedWindow)); + tester.fireIfShouldFire(mergedWindow); + assertTrue(tester.isMarkedFinished(mergedWindow)); + + // Because we closed the previous window, we don't have it around to merge with. So there + // will be a new FIRE_AND_FINISH result. 
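+    // (Elements at 7 and 9 start sessions [7, 17) and [9, 19), which merge into [7, 19) below.)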
+ tester.injectElements( + 7, // in [7, 17) + 9); // in [9, 19) + + tester.mergeWindows(); + + IntervalWindow newMergedWindow = new IntervalWindow(new Instant(7), new Instant(19)); + assertTrue(tester.shouldFire(newMergedWindow)); + tester.fireIfShouldFire(newMergedWindow); + assertTrue(tester.isMarkedFinished(newMergedWindow)); + } + + @Test + public void testFireDeadline() throws Exception { + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + AfterPane.elementCountAtLeast(1).getWatermarkThatGuaranteesFiring( + new IntervalWindow(new Instant(0), new Instant(10)))); + } + + @Test + public void testContinuation() throws Exception { + assertEquals( + AfterPane.elementCountAtLeast(1), + AfterPane.elementCountAtLeast(100).getContinuationTrigger()); + assertEquals( + AfterPane.elementCountAtLeast(1), + AfterPane.elementCountAtLeast(100).getContinuationTrigger().getContinuationTrigger()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterProcessingTimeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterProcessingTimeTest.java new file mode 100644 index 000000000000..a3bb3c372eab --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterProcessingTimeTest.java @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests the {@link AfterProcessingTime}. + */ +@RunWith(JUnit4.class) +public class AfterProcessingTimeTest { + + /** + * Tests the basic property that the trigger does wait for processing time to be + * far enough advanced. 
+ */ + @Test + public void testAfterProcessingTimeFixedWindows() throws Exception { + Duration windowDuration = Duration.millis(10); + SimpleTriggerTester tester = TriggerTester.forTrigger( + AfterProcessingTime + .pastFirstElementInPane() + .plusDelayOf(Duration.millis(5)), + FixedWindows.of(windowDuration)); + + tester.advanceProcessingTime(new Instant(10)); + + // Timer at 15 + tester.injectElements(1); + IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(10)); + tester.advanceProcessingTime(new Instant(12)); + assertFalse(tester.shouldFire(firstWindow)); + + // Load up elements in the next window, timer at 17 for them + tester.injectElements(11, 12, 13); + IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new Instant(20)); + assertFalse(tester.shouldFire(secondWindow)); + + // Not quite time to fire + tester.advanceProcessingTime(new Instant(14)); + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + + // Timer at 19 for these in the first window; it should be ignored since the 15 will fire first + tester.injectElements(2, 3); + + // Advance past the first timer and fire, finishing the first window + tester.advanceProcessingTime(new Instant(16)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + assertTrue(tester.isMarkedFinished(firstWindow)); + + // The next window fires and finishes now + tester.advanceProcessingTime(new Instant(18)); + assertTrue(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(secondWindow); + assertTrue(tester.isMarkedFinished(secondWindow)); + } + + /** + * Tests that when windows merge, if the trigger is waiting for "N millis after the first + * element" that it is relative to the earlier of the two merged windows. 
+ */ + @Test + public void testClear() throws Exception { + SimpleTriggerTester tester = TriggerTester.forTrigger( + AfterProcessingTime + .pastFirstElementInPane() + .plusDelayOf(Duration.millis(5)), + FixedWindows.of(Duration.millis(10))); + + tester.injectElements(1, 2, 3); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + tester.clearState(window); + tester.assertCleared(window); + } + + @Test + public void testAfterProcessingTimeWithMergingWindow() throws Exception { + SimpleTriggerTester tester = TriggerTester.forTrigger( + AfterProcessingTime + .pastFirstElementInPane() + .plusDelayOf(Duration.millis(5)), + Sessions.withGapDuration(Duration.millis(10))); + + tester.advanceProcessingTime(new Instant(10)); + tester.injectElements(1); // in [1, 11), timer for 15 + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + assertFalse(tester.shouldFire(firstWindow)); + + tester.advanceProcessingTime(new Instant(12)); + tester.injectElements(3); // in [3, 13), timer for 17 + IntervalWindow secondWindow = new IntervalWindow(new Instant(3), new Instant(13)); + assertFalse(tester.shouldFire(secondWindow)); + + tester.mergeWindows(); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(13)); + + tester.advanceProcessingTime(new Instant(16)); + assertTrue(tester.shouldFire(mergedWindow)); + } + + @Test + public void testFireDeadline() throws Exception { + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + AfterProcessingTime.pastFirstElementInPane().getWatermarkThatGuaranteesFiring( + new IntervalWindow(new Instant(0), new Instant(10)))); + } + + @Test + public void testContinuation() throws Exception { + OnceTrigger firstElementPlus1 = + AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.standardHours(1)); + assertEquals( + new AfterSynchronizedProcessingTime<>(), + firstElementPlus1.getContinuationTrigger()); + } + + /** + * Basic test of compatibility check between identical triggers. + */ + @Test + public void testCompatibilityIdentical() throws Exception { + Trigger t1 = AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(Duration.standardMinutes(1L)); + Trigger t2 = AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(Duration.standardMinutes(1L)); + assertTrue(t1.isCompatible(t2)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterSynchronizedProcessingTimeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterSynchronizedProcessingTimeTest.java new file mode 100644 index 000000000000..b37bba42e9ed --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterSynchronizedProcessingTimeTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests the {@link AfterSynchronizedProcessingTime}. + */ +@RunWith(JUnit4.class) +public class AfterSynchronizedProcessingTimeTest { + + private Trigger underTest = + new AfterSynchronizedProcessingTime(); + + @Test + public void testAfterProcessingTimeWithFixedWindows() throws Exception { + Duration windowDuration = Duration.millis(10); + SimpleTriggerTester tester = TriggerTester.forTrigger( + AfterProcessingTime + .pastFirstElementInPane() + .plusDelayOf(Duration.millis(5)), + FixedWindows.of(windowDuration)); + + tester.advanceProcessingTime(new Instant(10)); + + // Timer at 15 + tester.injectElements(1); + IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(10)); + tester.advanceProcessingTime(new Instant(12)); + assertFalse(tester.shouldFire(firstWindow)); + + // Load up elements in the next window, timer at 17 for them + tester.injectElements(11, 12, 13); + IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new Instant(20)); + assertFalse(tester.shouldFire(secondWindow)); + + // Not quite time to fire + tester.advanceProcessingTime(new Instant(14)); + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + + // Timer at 19 for these in the first window; it should be ignored since the 15 will fire first + tester.injectElements(2, 3); + + // Advance past the first timer and fire, finishing the first window + tester.advanceProcessingTime(new Instant(16)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + assertTrue(tester.isMarkedFinished(firstWindow)); + + // The next window fires and finishes now + tester.advanceProcessingTime(new Instant(18)); + assertTrue(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(secondWindow); + assertTrue(tester.isMarkedFinished(secondWindow)); + } + + @Test + public void testAfterProcessingTimeWithMergingWindow() throws Exception { + Duration windowDuration = Duration.millis(10); + SimpleTriggerTester tester = TriggerTester.forTrigger( + AfterProcessingTime + .pastFirstElementInPane() + .plusDelayOf(Duration.millis(5)), + Sessions.withGapDuration(windowDuration)); + + tester.advanceProcessingTime(new Instant(10)); + tester.injectElements(1); // in [1, 11), timer for 15 + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + assertFalse(tester.shouldFire(firstWindow)); + + tester.advanceProcessingTime(new Instant(12)); + tester.injectElements(3); // in [3, 13), timer for 17 + IntervalWindow secondWindow = new IntervalWindow(new Instant(3), new Instant(13)); + assertFalse(tester.shouldFire(secondWindow)); + + tester.mergeWindows(); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(13)); + + tester.advanceProcessingTime(new Instant(16)); + assertTrue(tester.shouldFire(mergedWindow)); + } + + @Test + public void testFireDeadline() throws Exception { + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + 
underTest.getWatermarkThatGuaranteesFiring( + new IntervalWindow(new Instant(0), new Instant(10)))); + } + + @Test + public void testContinuation() throws Exception { + assertEquals(underTest, underTest.getContinuationTrigger()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermarkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermarkTest.java new file mode 100644 index 000000000000..fb610aadd16b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermarkTest.java @@ -0,0 +1,338 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +/** + * Tests the {@link AfterWatermark} triggers. + */ +@RunWith(JUnit4.class) +public class AfterWatermarkTest { + + @Mock private OnceTrigger mockEarly; + @Mock private OnceTrigger mockLate; + + private SimpleTriggerTester tester; + private static Trigger.TriggerContext anyTriggerContext() { + return Mockito..TriggerContext>any(); + } + private static Trigger.OnElementContext anyElementContext() { + return Mockito..OnElementContext>any(); + } + + private void injectElements(int... 
elements) throws Exception { + for (int element : elements) { + doNothing().when(mockEarly).onElement(anyElementContext()); + doNothing().when(mockLate).onElement(anyElementContext()); + tester.injectElements(element); + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + public void testRunningAsTrigger(OnceTrigger mockTrigger, IntervalWindow window) + throws Exception { + + // Don't fire due to mock saying no + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(false); + assertFalse(tester.shouldFire(window)); // not ready + + // Fire due to mock trigger; early trigger is required to be a OnceTrigger + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + assertTrue(tester.shouldFire(window)); // ready + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + } + + @Test + public void testEarlyAndAtWatermark() throws Exception { + tester = TriggerTester.forTrigger( + AfterWatermark.pastEndOfWindow() + .withEarlyFirings(mockEarly), + FixedWindows.of(Duration.millis(100))); + + injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + testRunningAsTrigger(mockEarly, window); + + // Fire due to watermark + when(mockEarly.shouldFire(anyTriggerContext())).thenReturn(false); + tester.advanceInputWatermark(new Instant(100)); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + @Test + public void testAtWatermarkAndLate() throws Exception { + tester = TriggerTester.forTrigger( + AfterWatermark.pastEndOfWindow() + .withLateFirings(mockLate), + FixedWindows.of(Duration.millis(100))); + + injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + // No early firing, just double checking + when(mockEarly.shouldFire(anyTriggerContext())).thenReturn(true); + assertFalse(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + // Fire due to watermark + when(mockEarly.shouldFire(anyTriggerContext())).thenReturn(false); + tester.advanceInputWatermark(new Instant(100)); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + testRunningAsTrigger(mockLate, window); + } + + @Test + public void testEarlyAndAtWatermarkAndLate() throws Exception { + tester = TriggerTester.forTrigger( + AfterWatermark.pastEndOfWindow() + .withEarlyFirings(mockEarly) + .withLateFirings(mockLate), + FixedWindows.of(Duration.millis(100))); + + injectElements(1); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + testRunningAsTrigger(mockEarly, window); + + // Fire due to watermark + when(mockEarly.shouldFire(anyTriggerContext())).thenReturn(false); + tester.advanceInputWatermark(new Instant(100)); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + testRunningAsTrigger(mockLate, window); + } + + /** + * Tests that if the EOW is finished in both as well as the merged window, then + * it is finished in the merged result. + * + *

    Because windows are discarded when a trigger finishes, we need to embed this + * in a sequence in order to check that it is re-activated. So this test is potentially + * sensitive to other triggers' correctness. + */ + @Test + public void testOnMergeAlreadyFinished() throws Exception { + tester = TriggerTester.forTrigger( + AfterEach.inOrder( + AfterWatermark.pastEndOfWindow(), + Repeatedly.forever(AfterPane.elementCountAtLeast(1))), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements(1); + tester.injectElements(5); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + + // Finish the AfterWatermark.pastEndOfWindow() trigger in both windows + tester.advanceInputWatermark(new Instant(15)); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + tester.fireIfShouldFire(secondWindow); + + // Confirm that we are on the second trigger by probing + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.injectElements(1); + tester.injectElements(5); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + tester.fireIfShouldFire(secondWindow); + + // Merging should leave it finished + tester.mergeWindows(); + + // Confirm that we are on the second trigger by probing + assertFalse(tester.shouldFire(mergedWindow)); + tester.injectElements(1); + assertTrue(tester.shouldFire(mergedWindow)); + } + + /** + * Tests that the trigger rewinds to be non-finished in the merged window. + * + *
<p>
    Because windows are discarded when a trigger finishes, we need to embed this + * in a sequence in order to check that it is re-activated. So this test is potentially + * sensitive to other triggers' correctness. + */ + @Test + public void testOnMergeRewinds() throws Exception { + tester = TriggerTester.forTrigger( + AfterEach.inOrder( + AfterWatermark.pastEndOfWindow(), + Repeatedly.forever(AfterPane.elementCountAtLeast(1))), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements(1); + tester.injectElements(5); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + + // Finish the AfterWatermark.pastEndOfWindow() trigger in only the first window + tester.advanceInputWatermark(new Instant(11)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + + // Confirm that we are on the second trigger by probing + assertFalse(tester.shouldFire(firstWindow)); + tester.injectElements(1); + assertTrue(tester.shouldFire(firstWindow)); + tester.fireIfShouldFire(firstWindow); + + // Merging should re-activate the watermark trigger in the merged window + tester.mergeWindows(); + + // Confirm that we are not on the second trigger by probing + assertFalse(tester.shouldFire(mergedWindow)); + tester.injectElements(1); + assertFalse(tester.shouldFire(mergedWindow)); + + // And confirm that advancing the watermark fires again + tester.advanceInputWatermark(new Instant(15)); + assertTrue(tester.shouldFire(mergedWindow)); + } + + /** + * Tests that if the EOW is finished in both as well as the merged window, then + * it is finished in the merged result. + * + *
<p>
    Because windows are discarded when a trigger finishes, we need to embed this + * in a sequence in order to check that it is re-activated. So this test is potentially + * sensitive to other triggers' correctness. + */ + @Test + public void testEarlyAndLateOnMergeAlreadyFinished() throws Exception { + tester = TriggerTester.forTrigger( + AfterWatermark.pastEndOfWindow() + .withEarlyFirings(AfterPane.elementCountAtLeast(100)) + .withLateFirings(AfterPane.elementCountAtLeast(1)), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements(1); + tester.injectElements(5); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + + // Finish the AfterWatermark.pastEndOfWindow() bit of the trigger in both windows + tester.advanceInputWatermark(new Instant(15)); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + tester.fireIfShouldFire(secondWindow); + + // Confirm that we are on the late trigger by probing + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.injectElements(1); + tester.injectElements(5); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + tester.fireIfShouldFire(secondWindow); + + // Merging should leave it on the late trigger + tester.mergeWindows(); + + // Confirm that we are on the late trigger by probing + assertFalse(tester.shouldFire(mergedWindow)); + tester.injectElements(1); + assertTrue(tester.shouldFire(mergedWindow)); + } + + /** + * Tests that the trigger rewinds to be non-finished in the merged window. + * + *
<p>
    Because windows are discarded when a trigger finishes, we need to embed this + * in a sequence in order to check that it is re-activated. So this test is potentially + * sensitive to other triggers' correctness. + */ + @Test + public void testEarlyAndLateOnMergeRewinds() throws Exception { + tester = TriggerTester.forTrigger( + AfterWatermark.pastEndOfWindow() + .withEarlyFirings(AfterPane.elementCountAtLeast(100)) + .withLateFirings(AfterPane.elementCountAtLeast(1)), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements(1); + tester.injectElements(5); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + + // Finish the AfterWatermark.pastEndOfWindow() bit of the trigger in only the first window + tester.advanceInputWatermark(new Instant(11)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + tester.fireIfShouldFire(firstWindow); + + // Confirm that we are on the late trigger by probing + assertFalse(tester.shouldFire(firstWindow)); + tester.injectElements(1); + assertTrue(tester.shouldFire(firstWindow)); + tester.fireIfShouldFire(firstWindow); + + // Merging should re-activate the early trigger in the merged window + tester.mergeWindows(); + + // Confirm that we are not on the second trigger by probing + assertFalse(tester.shouldFire(mergedWindow)); + tester.injectElements(1); + assertFalse(tester.shouldFire(mergedWindow)); + + // And confirm that advancing the watermark fires again + tester.advanceInputWatermark(new Instant(15)); + assertTrue(tester.shouldFire(mergedWindow)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindowsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindowsTest.java new file mode 100644 index 000000000000..512fcbc292d2 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/CalendarWindowsTest.java @@ -0,0 +1,260 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.runWindowFn; +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.set; +import static org.junit.Assert.assertEquals; + +import org.joda.time.DateTime; +import org.joda.time.DateTimeConstants; +import org.joda.time.DateTimeZone; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Tests for CalendarWindows WindowFn. 
+ */ +@RunWith(JUnit4.class) +public class CalendarWindowsTest { + + private static Instant makeTimestamp(int year, int month, int day, int hours, int minutes) { + return new DateTime(year, month, day, hours, minutes, DateTimeZone.UTC).toInstant(); + } + + @Test + public void testDays() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 1, 1, 0, 0).getMillis(), + makeTimestamp(2014, 1, 1, 23, 59).getMillis(), + + makeTimestamp(2014, 1, 2, 0, 0).getMillis(), + makeTimestamp(2014, 1, 2, 5, 5).getMillis(), + + makeTimestamp(2015, 1, 1, 0, 0).getMillis(), + makeTimestamp(2015, 1, 1, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 1, 0, 0), + makeTimestamp(2014, 1, 2, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 2, 0, 0), + makeTimestamp(2014, 1, 3, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2015, 1, 1, 0, 0), + makeTimestamp(2015, 1, 2, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, runWindowFn(CalendarWindows.days(1), timestamps)); + } + + @Test + public void testWeeks() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 1, 1, 0, 0).getMillis(), + makeTimestamp(2014, 1, 5, 5, 5).getMillis(), + + makeTimestamp(2014, 1, 8, 0, 0).getMillis(), + makeTimestamp(2014, 1, 12, 5, 5).getMillis(), + + makeTimestamp(2015, 1, 1, 0, 0).getMillis(), + makeTimestamp(2015, 1, 6, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 1, 0, 0), + makeTimestamp(2014, 1, 8, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 8, 0, 0), + makeTimestamp(2014, 1, 15, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 12, 31, 0, 0), + makeTimestamp(2015, 1, 7, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, + runWindowFn(CalendarWindows.weeks(1, DateTimeConstants.WEDNESDAY), timestamps)); + } + + @Test + public void testMonths() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 1, 1, 0, 0).getMillis(), + makeTimestamp(2014, 1, 31, 5, 5).getMillis(), + + makeTimestamp(2014, 2, 1, 0, 0).getMillis(), + makeTimestamp(2014, 2, 15, 5, 5).getMillis(), + + makeTimestamp(2015, 1, 1, 0, 0).getMillis(), + makeTimestamp(2015, 1, 31, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 1, 1, 0, 0), + makeTimestamp(2014, 2, 1, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 2, 1, 0, 0), + makeTimestamp(2014, 3, 1, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2015, 1, 1, 0, 0), + makeTimestamp(2015, 2, 1, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, + runWindowFn(CalendarWindows.months(1), timestamps)); + } + + @Test + public void testMultiMonths() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2014, 3, 5, 0, 0).getMillis(), + makeTimestamp(2014, 10, 4, 23, 59).getMillis(), + + makeTimestamp(2014, 10, 5, 0, 0).getMillis(), + makeTimestamp(2015, 3, 1, 0, 
0).getMillis(), + + makeTimestamp(2016, 1, 5, 0, 0).getMillis(), + makeTimestamp(2016, 1, 31, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 3, 5, 0, 0), + makeTimestamp(2014, 10, 5, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2014, 10, 5, 0, 0), + makeTimestamp(2015, 5, 5, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2015, 12, 5, 0, 0), + makeTimestamp(2016, 7, 5, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, runWindowFn( + CalendarWindows.months(7).withStartingMonth(2014, 3).beginningOnDay(5), timestamps)); + } + + @Test + public void testYears() throws Exception { + Map> expected = new HashMap<>(); + + final List timestamps = Arrays.asList( + makeTimestamp(2000, 5, 5, 0, 0).getMillis(), + makeTimestamp(2010, 5, 4, 23, 59).getMillis(), + + makeTimestamp(2010, 5, 5, 0, 0).getMillis(), + makeTimestamp(2015, 3, 1, 0, 0).getMillis(), + + makeTimestamp(2052, 1, 5, 0, 0).getMillis(), + makeTimestamp(2060, 5, 4, 5, 5).getMillis()); + + + expected.put( + new IntervalWindow( + makeTimestamp(2000, 5, 5, 0, 0), + makeTimestamp(2010, 5, 5, 0, 0)), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + makeTimestamp(2010, 5, 5, 0, 0), + makeTimestamp(2020, 5, 5, 0, 0)), + set(timestamps.get(2), timestamps.get(3))); + + expected.put( + new IntervalWindow( + makeTimestamp(2050, 5, 5, 0, 0), + makeTimestamp(2060, 5, 5, 0, 0)), + set(timestamps.get(4), timestamps.get(5))); + + assertEquals(expected, runWindowFn( + CalendarWindows.years(10).withStartingYear(2000).beginningOnDay(5, 5), timestamps)); + } + + @Test + public void testTimeZone() throws Exception { + Map> expected = new HashMap<>(); + + DateTimeZone timeZone = DateTimeZone.forID("America/Los_Angeles"); + + final List timestamps = Arrays.asList( + new DateTime(2014, 1, 1, 0, 0, timeZone).getMillis(), + new DateTime(2014, 1, 1, 23, 59, timeZone).getMillis(), + + new DateTime(2014, 1, 2, 8, 0, DateTimeZone.UTC).getMillis(), + new DateTime(2014, 1, 3, 7, 59, DateTimeZone.UTC).getMillis()); + + expected.put( + new IntervalWindow( + new DateTime(2014, 1, 1, 0, 0, timeZone).toInstant(), + new DateTime(2014, 1, 2, 0, 0, timeZone).toInstant()), + set(timestamps.get(0), timestamps.get(1))); + + expected.put( + new IntervalWindow( + new DateTime(2014, 1, 2, 0, 0, timeZone).toInstant(), + new DateTime(2014, 1, 3, 0, 0, timeZone).toInstant()), + set(timestamps.get(2), timestamps.get(3))); + + assertEquals(expected, runWindowFn( + CalendarWindows.days(1).withTimeZone(timeZone), + timestamps)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/DefaultTriggerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/DefaultTriggerTest.java new file mode 100644 index 000000000000..bdc31e27fd83 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/DefaultTriggerTest.java @@ -0,0 +1,176 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests the {@link DefaultTrigger}, which should be equivalent to + * {@code Repeatedly.forever(AfterWatermark.pastEndOfWindow())}. + */ +@RunWith(JUnit4.class) +public class DefaultTriggerTest { + + SimpleTriggerTester tester; + + @Test + public void testDefaultTriggerFixedWindows() throws Exception { + tester = TriggerTester.forTrigger( + DefaultTrigger.of(), + FixedWindows.of(Duration.millis(100))); + + tester.injectElements( + 1, // [0, 100) + 101); // [100, 200) + + IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(100)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(100), new Instant(200)); + + // Advance the watermark almost to the end of the first window. + tester.advanceInputWatermark(new Instant(99)); + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + + // Advance watermark past end of the first window, which is then ready + tester.advanceInputWatermark(new Instant(100)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + + // Fire, but the first window is still allowed to fire + tester.fireIfShouldFire(firstWindow); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + + // Advance watermark to 200, then both are ready + tester.advanceInputWatermark(new Instant(200)); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + + assertFalse(tester.isMarkedFinished(firstWindow)); + assertFalse(tester.isMarkedFinished(secondWindow)); + } + + @Test + public void testDefaultTriggerSlidingWindows() throws Exception { + tester = TriggerTester.forTrigger( + DefaultTrigger.of(), + SlidingWindows.of(Duration.millis(100)).every(Duration.millis(50))); + + tester.injectElements( + 1, // [-50, 50), [0, 100) + 50); // [0, 100), [50, 150) + + IntervalWindow firstWindow = new IntervalWindow(new Instant(-50), new Instant(50)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(0), new Instant(100)); + IntervalWindow thirdWindow = new IntervalWindow(new Instant(50), new Instant(150)); + + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(thirdWindow)); + + // At 50, the first becomes ready; it stays ready after firing + tester.advanceInputWatermark(new Instant(50)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(thirdWindow)); + tester.fireIfShouldFire(firstWindow); + assertTrue(tester.shouldFire(firstWindow)); 
+ assertFalse(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(thirdWindow)); + + // At 99, the first is still the only one ready + tester.advanceInputWatermark(new Instant(99)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(thirdWindow)); + + // At 100, the first and second are ready + tester.advanceInputWatermark(new Instant(100)); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(thirdWindow)); + tester.fireIfShouldFire(firstWindow); + + assertFalse(tester.isMarkedFinished(firstWindow)); + assertFalse(tester.isMarkedFinished(secondWindow)); + assertFalse(tester.isMarkedFinished(thirdWindow)); + } + + @Test + public void testDefaultTriggerSessions() throws Exception { + tester = TriggerTester.forTrigger( + DefaultTrigger.of(), + Sessions.withGapDuration(Duration.millis(100))); + + tester.injectElements( + 1, // [1, 101) + 50); // [50, 150) + tester.mergeWindows(); + + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(101)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(50), new Instant(150)); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(150)); + + // Not ready in any window yet + tester.advanceInputWatermark(new Instant(100)); + assertFalse(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(mergedWindow)); + + // The first window is "ready": the caller owns knowledge of which windows are merged away + tester.advanceInputWatermark(new Instant(149)); + assertTrue(tester.shouldFire(firstWindow)); + assertFalse(tester.shouldFire(secondWindow)); + assertFalse(tester.shouldFire(mergedWindow)); + + // Now ready on all windows + tester.advanceInputWatermark(new Instant(150)); + assertTrue(tester.shouldFire(firstWindow)); + assertTrue(tester.shouldFire(secondWindow)); + assertTrue(tester.shouldFire(mergedWindow)); + + // Ensure it repeats + tester.fireIfShouldFire(mergedWindow); + assertTrue(tester.shouldFire(mergedWindow)); + + assertFalse(tester.isMarkedFinished(mergedWindow)); + } + + @Test + public void testFireDeadline() throws Exception { + assertEquals(new Instant(9), DefaultTrigger.of().getWatermarkThatGuaranteesFiring( + new IntervalWindow(new Instant(0), new Instant(10)))); + assertEquals(GlobalWindow.INSTANCE.maxTimestamp(), + DefaultTrigger.of().getWatermarkThatGuaranteesFiring(GlobalWindow.INSTANCE)); + } + + @Test + public void testContinuation() throws Exception { + assertEquals(DefaultTrigger.of(), DefaultTrigger.of().getContinuationTrigger()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindowsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindowsTest.java new file mode 100644 index 000000000000..935f22e01614 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/FixedWindowsTest.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.runWindowFn; +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.set; +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Tests for FixedWindows WindowFn. + */ +@RunWith(JUnit4.class) +public class FixedWindowsTest { + + @Test + public void testSimpleFixedWindow() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(1, 2, 5, 9)); + expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10, 11)); + expected.put(new IntervalWindow(new Instant(100), new Instant(110)), set(100)); + assertEquals( + expected, + runWindowFn( + FixedWindows.of(new Duration(10)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L, 100L))); + } + + @Test + public void testFixedOffsetWindow() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5), new Instant(5)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(5), new Instant(15)), set(5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(95), new Instant(105)), set(100)); + assertEquals( + expected, + runWindowFn( + FixedWindows.of(new Duration(10)).withOffset(new Duration(5)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L, 100L))); + } + + @Test + public void testTimeUnit() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5000), new Instant(5000)), set(1, 2, 1000)); + expected.put(new IntervalWindow(new Instant(5000), new Instant(15000)), set(5000, 5001, 10000)); + assertEquals( + expected, + runWindowFn( + FixedWindows.of(Duration.standardSeconds(10)).withOffset(Duration.standardSeconds(5)), + Arrays.asList(1L, 2L, 1000L, 5000L, 5001L, 10000L))); + } + + void checkConstructionFailure(int size, int offset) { + try { + FixedWindows.of(Duration.standardSeconds(size)).withOffset(Duration.standardSeconds(offset)); + fail("should have failed"); + } catch (IllegalArgumentException e) { + assertThat(e.toString(), + containsString("FixedWindows WindowingStrategies must have 0 <= offset < size")); + } + } + + @Test + public void testInvalidInput() throws Exception { + checkConstructionFailure(-1, 0); + checkConstructionFailure(1, 2); + checkConstructionFailure(1, -1); + } + + @Test + public void testEquality() { + assertTrue(FixedWindows.of(new Duration(10)).isCompatible(FixedWindows.of(new Duration(10)))); + assertTrue( + FixedWindows.of(new 
Duration(10)).isCompatible( + FixedWindows.of(new Duration(10)))); + assertTrue( + FixedWindows.of(new Duration(10)).isCompatible( + FixedWindows.of(new Duration(10)))); + + assertFalse(FixedWindows.of(new Duration(10)).isCompatible(FixedWindows.of(new Duration(20)))); + assertFalse(FixedWindows.of(new Duration(10)).isCompatible( + FixedWindows.of(new Duration(20)))); + } + + @Test + public void testValidOutputTimes() throws Exception { + for (long timestamp : Arrays.asList(200, 800, 700)) { + WindowFnTestUtils.validateGetOutputTimestamp( + FixedWindows.of(new Duration(500)), timestamp); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindowTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindowTest.java new file mode 100644 index 000000000000..968063e45fb5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/IntervalWindowTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.InstantCoder; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.collect.Lists; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** + * Tests for {@link Window}. + */ +@RunWith(JUnit4.class) +public class IntervalWindowTest { + + private static final Coder TEST_CODER = IntervalWindow.getCoder(); + + private static final List TEST_VALUES = Lists.newArrayList( + new IntervalWindow(new Instant(0), new Instant(0)), + new IntervalWindow(new Instant(0), new Instant(1000)), + new IntervalWindow(new Instant(-1000), new Instant(735)), + new IntervalWindow(new Instant(350), new Instant(60 * 60 * 1000)), + new IntervalWindow(new Instant(0), new Instant(24 * 60 * 60 * 1000)), + new IntervalWindow( + Instant.parse("2015-04-01T00:00:00Z"), Instant.parse("2015-04-01T11:45:13Z"))); + + @Test + public void testBasicEncoding() throws Exception { + for (IntervalWindow window : TEST_VALUES) { + CoderProperties.coderDecodeEncodeEqual(TEST_CODER, window); + } + } + + /** + * This is a change detector test for the sizes of encoded windows. Since these are present + * for every element of every windowed PCollection, the size matters. + * + *
<p>
    This test documents the expectation that encoding as a (endpoint, duration) pair + * using big endian for the endpoint and variable length long for the duration should be about 25% + * smaller than encoding two big endian Long values. + */ + @Test + public void testLengthsOfEncodingChoices() throws Exception { + Instant start = Instant.parse("2015-04-01T00:00:00Z"); + Instant minuteEnd = Instant.parse("2015-04-01T00:01:00Z"); + Instant hourEnd = Instant.parse("2015-04-01T01:00:00Z"); + Instant dayEnd = Instant.parse("2015-04-02T00:00:00Z"); + + Coder instantCoder = InstantCoder.of(); + byte[] encodedStart = CoderUtils.encodeToByteArray(instantCoder, start); + byte[] encodedMinuteEnd = CoderUtils.encodeToByteArray(instantCoder, minuteEnd); + byte[] encodedHourEnd = CoderUtils.encodeToByteArray(instantCoder, hourEnd); + byte[] encodedDayEnd = CoderUtils.encodeToByteArray(instantCoder, dayEnd); + + byte[] encodedMinuteWindow = CoderUtils.encodeToByteArray( + TEST_CODER, new IntervalWindow(start, minuteEnd)); + byte[] encodedHourWindow = CoderUtils.encodeToByteArray( + TEST_CODER, new IntervalWindow(start, hourEnd)); + byte[] encodedDayWindow = CoderUtils.encodeToByteArray( + TEST_CODER, new IntervalWindow(start, dayEnd)); + + assertThat(encodedMinuteWindow.length, + equalTo(encodedStart.length + encodedMinuteEnd.length - 5)); + assertThat(encodedHourWindow.length, + equalTo(encodedStart.length + encodedHourEnd.length - 4)); + assertThat(encodedDayWindow.length, + equalTo(encodedStart.length + encodedDayEnd.length - 4)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/OrFinallyTriggerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/OrFinallyTriggerTest.java new file mode 100644 index 000000000000..a416e6042506 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/OrFinallyTriggerTest.java @@ -0,0 +1,209 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger.OnceTrigger; +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link OrFinallyTrigger}. + */ +@RunWith(JUnit4.class) +public class OrFinallyTriggerTest { + + private SimpleTriggerTester tester; + + /** + * Tests that for {@code OrFinally(actual, ...)} when {@code actual} + * fires and finishes, the {@code OrFinally} also fires and finishes. 
+ */ + @Test + public void testActualFiresAndFinishes() throws Exception { + tester = TriggerTester.forTrigger( + new OrFinallyTrigger<>( + AfterPane.elementCountAtLeast(2), + AfterPane.elementCountAtLeast(100)), + FixedWindows.of(Duration.millis(100))); + + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + // Not yet firing + tester.injectElements(1); + assertFalse(tester.shouldFire(window)); + assertFalse(tester.isMarkedFinished(window)); + + // The actual fires and finishes + tester.injectElements(2); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + /** + * Tests that for {@code OrFinally(actual, ...)} when {@code actual} + * fires but does not finish, the {@code OrFinally} also fires and also does not + * finish. + */ + @Test + public void testActualFiresOnly() throws Exception { + tester = TriggerTester.forTrigger( + new OrFinallyTrigger<>( + Repeatedly.forever(AfterPane.elementCountAtLeast(2)), + AfterPane.elementCountAtLeast(100)), + FixedWindows.of(Duration.millis(100))); + + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100)); + + // Not yet firing + tester.injectElements(1); + assertFalse(tester.shouldFire(window)); + assertFalse(tester.isMarkedFinished(window)); + + // The actual fires but does not finish + tester.injectElements(2); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + // And again + tester.injectElements(3, 4); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + } + + /** + * Tests that if the first trigger rewinds to be non-finished in the merged window, + * then it becomes the currently active trigger again, with real triggers. + */ + @Test + public void testShouldFireAfterMerge() throws Exception { + tester = TriggerTester.forTrigger( + AfterEach.inOrder( + AfterPane.elementCountAtLeast(5) + .orFinally(AfterWatermark.pastEndOfWindow()), + Repeatedly.forever(AfterPane.elementCountAtLeast(1))), + Sessions.withGapDuration(Duration.millis(10))); + + // Finished the orFinally in the first window + tester.injectElements(1); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + assertFalse(tester.shouldFire(firstWindow)); + tester.advanceInputWatermark(new Instant(11)); + assertTrue(tester.shouldFire(firstWindow)); + tester.fireIfShouldFire(firstWindow); + + // Set up second window where it is not done + tester.injectElements(5); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + assertFalse(tester.shouldFire(secondWindow)); + + // Merge them, if the merged window were on the second trigger, it would be ready + tester.mergeWindows(); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + assertFalse(tester.shouldFire(mergedWindow)); + + // Now adding 3 more makes the main trigger ready to fire + tester.injectElements(1, 2, 3, 4, 5); + tester.mergeWindows(); + assertTrue(tester.shouldFire(mergedWindow)); + } + + /** + * Tests that for {@code OrFinally(actual, until)} when {@code actual} + * fires but does not finish, then {@code until} fires and finishes, the + * whole thing fires and finished. 
+ */ + @Test + public void testActualFiresButUntilFinishes() throws Exception { + tester = TriggerTester.forTrigger( + new OrFinallyTrigger( + Repeatedly.forever(AfterPane.elementCountAtLeast(2)), + AfterPane.elementCountAtLeast(3)), + FixedWindows.of(Duration.millis(10))); + + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + // Before any firing + tester.injectElements(1); + assertFalse(tester.shouldFire(window)); + assertFalse(tester.isMarkedFinished(window)); + + // The actual fires but doesn't finish + tester.injectElements(2); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertFalse(tester.isMarkedFinished(window)); + + // The until fires and finishes; the trigger is finished + tester.injectElements(3); + assertTrue(tester.shouldFire(window)); + tester.fireIfShouldFire(window); + assertTrue(tester.isMarkedFinished(window)); + } + + @Test + public void testFireDeadline() throws Exception { + BoundedWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + + assertEquals(new Instant(9), + Repeatedly.forever(AfterWatermark.pastEndOfWindow()) + .getWatermarkThatGuaranteesFiring(window)); + assertEquals(new Instant(9), Repeatedly.forever(AfterWatermark.pastEndOfWindow()) + .orFinally(AfterPane.elementCountAtLeast(1)) + .getWatermarkThatGuaranteesFiring(window)); + assertEquals(new Instant(9), Repeatedly.forever(AfterPane.elementCountAtLeast(1)) + .orFinally(AfterWatermark.pastEndOfWindow()) + .getWatermarkThatGuaranteesFiring(window)); + assertEquals(new Instant(9), + AfterPane.elementCountAtLeast(100) + .orFinally(AfterWatermark.pastEndOfWindow()) + .getWatermarkThatGuaranteesFiring(window)); + + assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, + Repeatedly.forever(AfterPane.elementCountAtLeast(1)) + .orFinally(AfterPane.elementCountAtLeast(10)) + .getWatermarkThatGuaranteesFiring(window)); + } + + @Test + public void testContinuation() throws Exception { + OnceTrigger triggerA = AfterProcessingTime.pastFirstElementInPane(); + OnceTrigger triggerB = AfterWatermark.pastEndOfWindow(); + Trigger aOrFinallyB = triggerA.orFinally(triggerB); + Trigger bOrFinallyA = triggerB.orFinally(triggerA); + assertEquals( + Repeatedly.forever( + triggerA.getContinuationTrigger().orFinally(triggerB.getContinuationTrigger())), + aOrFinallyB.getContinuationTrigger()); + assertEquals( + Repeatedly.forever( + triggerB.getContinuationTrigger().orFinally(triggerA.getContinuationTrigger())), + bOrFinallyA.getContinuationTrigger()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/PaneInfoTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/PaneInfoTest.java new file mode 100644 index 000000000000..62ac2a1fad81 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/PaneInfoTest.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link PaneInfo}. + */ +@RunWith(JUnit4.class) +public class PaneInfoTest { + + @Test + public void testInterned() throws Exception { + assertSame( + PaneInfo.createPane(true, true, Timing.EARLY), + PaneInfo.createPane(true, true, Timing.EARLY)); + } + + @Test + public void testEncodingRoundTrip() throws Exception { + Coder coder = PaneInfo.PaneInfoCoder.INSTANCE; + for (Timing timing : Timing.values()) { + long onTimeIndex = timing == Timing.EARLY ? -1 : 37; + CoderProperties.coderDecodeEncodeEqual( + coder, PaneInfo.createPane(false, false, timing, 389, onTimeIndex)); + CoderProperties.coderDecodeEncodeEqual( + coder, PaneInfo.createPane(false, true, timing, 5077, onTimeIndex)); + CoderProperties.coderDecodeEncodeEqual( + coder, PaneInfo.createPane(true, false, timing, 0, 0)); + CoderProperties.coderDecodeEncodeEqual( + coder, PaneInfo.createPane(true, true, timing, 0, 0)); + } + } + + @Test + public void testEncodings() { + assertEquals("PaneInfo encoding assumes that there are only 4 Timing values.", + 4, Timing.values().length); + assertEquals("PaneInfo encoding should remain the same.", + 0x0, PaneInfo.createPane(false, false, Timing.EARLY, 1, -1).getEncodedByte()); + assertEquals("PaneInfo encoding should remain the same.", + 0x1, PaneInfo.createPane(true, false, Timing.EARLY).getEncodedByte()); + assertEquals("PaneInfo encoding should remain the same.", + 0x3, PaneInfo.createPane(true, true, Timing.EARLY).getEncodedByte()); + assertEquals("PaneInfo encoding should remain the same.", + 0x7, PaneInfo.createPane(true, true, Timing.ON_TIME).getEncodedByte()); + assertEquals("PaneInfo encoding should remain the same.", + 0xB, PaneInfo.createPane(true, true, Timing.LATE).getEncodedByte()); + assertEquals("PaneInfo encoding should remain the same.", + 0xF, PaneInfo.createPane(true, true, Timing.UNKNOWN).getEncodedByte()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/RepeatedlyTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/RepeatedlyTest.java new file mode 100644 index 000000000000..f445b52565ed --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/RepeatedlyTest.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.util.TriggerTester; +import com.google.cloud.dataflow.sdk.util.TriggerTester.SimpleTriggerTester; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link Repeatedly}. + */ +@RunWith(JUnit4.class) +public class RepeatedlyTest { + + @Mock private Trigger mockTrigger; + private SimpleTriggerTester tester; + private static Trigger.TriggerContext anyTriggerContext() { + return Mockito..TriggerContext>any(); + } + + public void setUp(WindowFn windowFn) throws Exception { + MockitoAnnotations.initMocks(this); + tester = TriggerTester.forTrigger(Repeatedly.forever(mockTrigger), windowFn); + } + + /** + * Tests that onElement correctly passes the data on to the subtrigger. + */ + @Test + public void testOnElement() throws Exception { + setUp(FixedWindows.of(Duration.millis(10))); + tester.injectElements(37); + verify(mockTrigger).onElement(Mockito..OnElementContext>any()); + } + + /** + * Tests that the repeatedly is ready to fire whenever the subtrigger is ready. + */ + @Test + public void testShouldFire() throws Exception { + setUp(FixedWindows.of(Duration.millis(10))); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + assertTrue(tester.shouldFire(new IntervalWindow(new Instant(0), new Instant(10)))); + + when(mockTrigger.shouldFire(Mockito..TriggerContext>any())) + .thenReturn(false); + assertFalse(tester.shouldFire(new IntervalWindow(new Instant(0), new Instant(10)))); + } + + /** + * Tests that the watermark that guarantees firing is that of the subtrigger. 
+ */ + @Test + public void testFireDeadline() throws Exception { + setUp(FixedWindows.of(Duration.millis(10))); + IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(10)); + Instant arbitraryInstant = new Instant(34957849); + + when(mockTrigger.getWatermarkThatGuaranteesFiring(Mockito.any())) + .thenReturn(arbitraryInstant); + + assertThat( + Repeatedly.forever(mockTrigger).getWatermarkThatGuaranteesFiring(window), + equalTo(arbitraryInstant)); + } + + @Test + public void testContinuation() throws Exception { + Trigger trigger = AfterProcessingTime.pastFirstElementInPane(); + Trigger repeatedly = Repeatedly.forever(trigger); + assertEquals( + Repeatedly.forever(trigger.getContinuationTrigger()), repeatedly.getContinuationTrigger()); + assertEquals( + Repeatedly.forever(trigger.getContinuationTrigger().getContinuationTrigger()), + repeatedly.getContinuationTrigger().getContinuationTrigger()); + } + + @Test + public void testShouldFireAfterMerge() throws Exception { + tester = TriggerTester.forTrigger( + Repeatedly.forever(AfterPane.elementCountAtLeast(2)), + Sessions.withGapDuration(Duration.millis(10))); + + tester.injectElements(1); + IntervalWindow firstWindow = new IntervalWindow(new Instant(1), new Instant(11)); + assertFalse(tester.shouldFire(firstWindow)); + + tester.injectElements(5); + IntervalWindow secondWindow = new IntervalWindow(new Instant(5), new Instant(15)); + assertFalse(tester.shouldFire(secondWindow)); + + // Merge them, if the merged window were on the second trigger, it would be ready + tester.mergeWindows(); + IntervalWindow mergedWindow = new IntervalWindow(new Instant(1), new Instant(15)); + assertTrue(tester.shouldFire(mergedWindow)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SessionsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SessionsTest.java new file mode 100644 index 000000000000..91049cd9d9d8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SessionsTest.java @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.runWindowFn; +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.set; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Tests for Sessions WindowFn. + */ +@RunWith(JUnit4.class) +public class SessionsTest { + + @Test + public void testSimple() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(0)); + expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10)); + expected.put(new IntervalWindow(new Instant(101), new Instant(111)), set(101)); + assertEquals( + expected, + runWindowFn( + Sessions.withGapDuration(new Duration(10)), + Arrays.asList(0L, 10L, 101L))); + } + + @Test + public void testConsecutive() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(1), new Instant(19)), set(1, 2, 5, 9)); + expected.put(new IntervalWindow(new Instant(100), new Instant(111)), set(100, 101)); + assertEquals( + expected, + runWindowFn( + Sessions.withGapDuration(new Duration(10)), + Arrays.asList(1L, 2L, 5L, 9L, 100L, 101L))); + } + + @Test + public void testMerging() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(1), new Instant(40)), set(1, 10, 15, 22, 30)); + expected.put(new IntervalWindow(new Instant(95), new Instant(111)), set(95, 100, 101)); + assertEquals( + expected, + runWindowFn( + Sessions.withGapDuration(new Duration(10)), + Arrays.asList(1L, 15L, 30L, 100L, 101L, 95L, 22L, 10L))); + } + + @Test + public void testTimeUnit() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(1), new Instant(2000)), set(1, 2, 1000)); + expected.put(new IntervalWindow(new Instant(5000), new Instant(6001)), set(5000, 5001)); + expected.put(new IntervalWindow(new Instant(10000), new Instant(11000)), set(10000)); + assertEquals( + expected, + runWindowFn( + Sessions.withGapDuration(Duration.standardSeconds(1)), + Arrays.asList(1L, 2L, 1000L, 5000L, 5001L, 10000L))); + } + + @Test + public void testEquality() { + assertTrue( + Sessions.withGapDuration(new Duration(10)).isCompatible( + Sessions.withGapDuration(new Duration(10)))); + assertTrue( + Sessions.withGapDuration(new Duration(10)).isCompatible( + Sessions.withGapDuration(new Duration(20)))); + } + + /** + * Validates that the output timestamp for aggregate data falls within the acceptable range. + */ + @Test + public void testValidOutputTimes() throws Exception { + for (long timestamp : Arrays.asList(200, 800, 700)) { + WindowFnTestUtils.validateGetOutputTimestamp( + Sessions.withGapDuration(Duration.millis(500)), timestamp); + } + } + + /** + * Test to confirm that {@link Sessions} with the default {@link OutputTimeFn} holds up the + * watermark potentially indefinitely. 
+ */ + @Test + public void testInvalidOutputAtEarliest() throws Exception { + try { + WindowFnTestUtils.validateGetOutputTimestamps( + Sessions.withGapDuration(Duration.millis(10)), + OutputTimeFns.outputAtEarliestInputTimestamp(), + ImmutableList.of( + (List) ImmutableList.of(1L, 3L), + (List) ImmutableList.of(0L, 5L, 10L, 15L, 20L))); + } catch (AssertionError exc) { + assertThat( + exc.getMessage(), + // These are the non-volatile pieces of the error message that a timestamp + // was not greater than what it should be. + allOf(containsString("a value greater than"), containsString("was less than"))); + } + } + + /** + * When a user explicitly requests per-key aggregate values have their derived timestamp to be + * the end of the window (instead of the earliest possible), the session here should not hold + * each other up, even though they overlap. + */ + @Test + public void testValidOutputAtEndTimes() throws Exception { + WindowFnTestUtils.validateGetOutputTimestamps( + Sessions.withGapDuration(Duration.millis(10)), + OutputTimeFns.outputAtEndOfWindow(), + ImmutableList.of( + (List) ImmutableList.of(1L, 3L), + (List) ImmutableList.of(0L, 5L, 10L, 15L, 20L))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindowsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindowsTest.java new file mode 100644 index 000000000000..33c4b8b81672 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/SlidingWindowsTest.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.runWindowFn; +import static com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils.set; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.testing.WindowFnTestUtils; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Tests for the SlidingWindows WindowFn. 
+ */ +@RunWith(JUnit4.class) +public class SlidingWindowsTest { + + @Test + public void testSimple() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5), new Instant(5)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(1, 2, 5, 9)); + expected.put(new IntervalWindow(new Instant(5), new Instant(15)), set(5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10, 11)); + assertEquals( + expected, + runWindowFn( + SlidingWindows.of(new Duration(10)).every(new Duration(5)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + } + + @Test + public void testSlightlyOverlapping() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5), new Instant(2)), set(1)); + expected.put(new IntervalWindow(new Instant(0), new Instant(7)), set(1, 2, 5)); + expected.put(new IntervalWindow(new Instant(5), new Instant(12)), set(5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(10), new Instant(17)), set(10, 11)); + assertEquals( + expected, + runWindowFn( + SlidingWindows.of(new Duration(7)).every(new Duration(5)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + } + + @Test + public void testElidings() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(3)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(10), new Instant(13)), set(10, 11)); + expected.put(new IntervalWindow(new Instant(100), new Instant(103)), set(100)); + assertEquals( + expected, + runWindowFn( + // Only look at the first 3 millisecs of every 10-millisec interval. + SlidingWindows.of(new Duration(3)).every(new Duration(10)), + Arrays.asList(1L, 2L, 3L, 5L, 9L, 10L, 11L, 100L))); + } + + @Test + public void testOffset() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-8), new Instant(2)), set(1)); + expected.put(new IntervalWindow(new Instant(-3), new Instant(7)), set(1, 2, 5)); + expected.put(new IntervalWindow(new Instant(2), new Instant(12)), set(2, 5, 9, 10, 11)); + expected.put(new IntervalWindow(new Instant(7), new Instant(17)), set(9, 10, 11)); + assertEquals( + expected, + runWindowFn( + SlidingWindows.of(new Duration(10)).every(new Duration(5)).withOffset(new Duration(2)), + Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + } + + @Test + public void testTimeUnit() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(-5000), new Instant(5000)), set(1, 2, 1000)); + expected.put(new IntervalWindow(new Instant(0), new Instant(10000)), + set(1, 2, 1000, 5000, 5001)); + expected.put(new IntervalWindow(new Instant(5000), new Instant(15000)), set(5000, 5001, 10000)); + expected.put(new IntervalWindow(new Instant(10000), new Instant(20000)), set(10000)); + assertEquals( + expected, + runWindowFn( + SlidingWindows.of(Duration.standardSeconds(10)).every(Duration.standardSeconds(5)), + Arrays.asList(1L, 2L, 1000L, 5000L, 5001L, 10000L))); + } + + @Test + public void testDefaultPeriods() throws Exception { + assertEquals(Duration.standardHours(1), + SlidingWindows.getDefaultPeriod(Duration.standardDays(1))); + assertEquals(Duration.standardHours(1), + SlidingWindows.getDefaultPeriod(Duration.standardHours(2))); + assertEquals(Duration.standardMinutes(1), + SlidingWindows.getDefaultPeriod(Duration.standardHours(1))); + assertEquals(Duration.standardMinutes(1), + 
SlidingWindows.getDefaultPeriod(Duration.standardMinutes(10))); + assertEquals(Duration.standardSeconds(1), + SlidingWindows.getDefaultPeriod(Duration.standardMinutes(1))); + assertEquals(Duration.standardSeconds(1), + SlidingWindows.getDefaultPeriod(Duration.standardSeconds(10))); + assertEquals(Duration.millis(1), + SlidingWindows.getDefaultPeriod(Duration.standardSeconds(1))); + assertEquals(Duration.millis(1), + SlidingWindows.getDefaultPeriod(Duration.millis(10))); + assertEquals(Duration.millis(1), + SlidingWindows.getDefaultPeriod(Duration.millis(1))); + } + + @Test + public void testEquality() { + assertTrue( + SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(10)))); + assertTrue( + SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(10)))); + + assertFalse(SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(20)))); + assertFalse(SlidingWindows.of(new Duration(10)).isCompatible( + SlidingWindows.of(new Duration(20)))); + } + + @Test + public void testGetSideInputWindow() { + // [40, 1040), [340, 1340), [640, 1640) ... + SlidingWindows slidingWindows = SlidingWindows.of(new Duration(1000)) + .every(new Duration(300)).withOffset(new Duration(40)); + // Prior + assertEquals( + new IntervalWindow(new Instant(340), new Instant(1340)), + slidingWindows.getSideInputWindow( + new IntervalWindow(new Instant(0), new Instant(1041)))); + assertEquals( + new IntervalWindow(new Instant(340), new Instant(1340)), + slidingWindows.getSideInputWindow( + new IntervalWindow(new Instant(0), new Instant(1339)))); + // Align + assertEquals( + new IntervalWindow(new Instant(340), new Instant(1340)), + slidingWindows.getSideInputWindow( + new IntervalWindow(new Instant(0), new Instant(1340)))); + // After + assertEquals( + new IntervalWindow(new Instant(640), new Instant(1640)), + slidingWindows.getSideInputWindow( + new IntervalWindow(new Instant(0), new Instant(1341)))); + } + + @Test + public void testValidOutputTimes() throws Exception { + for (long timestamp : Arrays.asList(200, 800, 499, 500, 501, 700, 1000)) { + WindowFnTestUtils.validateGetOutputTimestamp( + SlidingWindows.of(new Duration(1000)).every(new Duration(500)), timestamp); + } + } + + @Test + public void testOutputTimesNonInterference() throws Exception { + for (long timestamp : Arrays.asList(200, 800, 700)) { + WindowFnTestUtils.validateNonInterferingOutputTimes( + SlidingWindows.of(new Duration(1000)).every(new Duration(500)), timestamp); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/TriggerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/TriggerTest.java new file mode 100644 index 000000000000..ddff33fbda54 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/TriggerTest.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */
+package com.google.cloud.dataflow.sdk.transforms.windowing;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Tests for {@link Trigger}.
+ */
+@RunWith(JUnit4.class)
+public class TriggerTest {
+
+  @Test
+  public void testTriggerToString() throws Exception {
+    assertEquals("AfterWatermark.pastEndOfWindow()", AfterWatermark.pastEndOfWindow().toString());
+    assertEquals("Repeatedly(AfterWatermark.pastEndOfWindow())",
+        Repeatedly.forever(AfterWatermark.pastEndOfWindow()).toString());
+  }
+
+  @Test
+  public void testIsCompatible() throws Exception {
+    assertTrue(new Trigger1(null).isCompatible(new Trigger1(null)));
+    assertTrue(new Trigger1(Arrays.<Trigger<IntervalWindow>>asList(new Trigger2(null)))
+        .isCompatible(new Trigger1(Arrays.<Trigger<IntervalWindow>>asList(new Trigger2(null)))));
+
+    assertFalse(new Trigger1(null).isCompatible(new Trigger2(null)));
+    assertFalse(new Trigger1(Arrays.<Trigger<IntervalWindow>>asList(new Trigger1(null)))
+        .isCompatible(new Trigger1(Arrays.<Trigger<IntervalWindow>>asList(new Trigger2(null)))));
+  }
+
+  private static class Trigger1 extends Trigger<IntervalWindow> {
+
+    private Trigger1(List<Trigger<IntervalWindow>> subTriggers) {
+      super(subTriggers);
+    }
+
+    @Override
+    public void onElement(Trigger.OnElementContext c) { }
+
+    @Override
+    public void onMerge(Trigger.OnMergeContext c) { }
+
+    @Override
+    protected Trigger<IntervalWindow> getContinuationTrigger(
+        List<Trigger<IntervalWindow>> continuationTriggers) {
+      return null;
+    }
+
+    @Override
+    public Instant getWatermarkThatGuaranteesFiring(IntervalWindow window) {
+      return null;
+    }
+
+    @Override
+    public boolean shouldFire(Trigger.TriggerContext context) throws Exception {
+      return false;
+    }
+
+    @Override
+    public void onFire(Trigger.TriggerContext context) throws Exception { }
+  }
+
+  private static class Trigger2 extends Trigger<IntervalWindow> {
+
+    private Trigger2(List<Trigger<IntervalWindow>> subTriggers) {
+      super(subTriggers);
+    }
+
+    @Override
+    public void onElement(Trigger.OnElementContext c) { }
+
+    @Override
+    public void onMerge(Trigger.OnMergeContext c) { }
+
+    @Override
+    protected Trigger<IntervalWindow> getContinuationTrigger(
+        List<Trigger<IntervalWindow>> continuationTriggers) {
+      return null;
+    }
+
+    @Override
+    public Instant getWatermarkThatGuaranteesFiring(IntervalWindow window) {
+      return null;
+    }
+
+    @Override
+    public boolean shouldFire(Trigger.TriggerContext context) throws Exception {
+      return false;
+    }
+
+    @Override
+    public void onFire(Trigger.TriggerContext context) throws Exception { }
+  }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowTest.java
new file mode 100644
index 000000000000..72f2b4c12d1c
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowTest.java
@@ -0,0 +1,226 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.NonDeterministicException; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +import java.io.Serializable; + +/** + * Tests for {@link Window}. + */ +@RunWith(JUnit4.class) +public class WindowTest implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + public void testWindowIntoSetWindowfn() { + WindowingStrategy strategy = TestPipeline.create() + .apply(Create.of("hello", "world").withCoder(StringUtf8Coder.of())) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(10)))) + .getWindowingStrategy(); + assertTrue(strategy.getWindowFn() instanceof FixedWindows); + assertTrue(strategy.getTrigger().getSpec() instanceof DefaultTrigger); + assertEquals(AccumulationMode.DISCARDING_FIRED_PANES, strategy.getMode()); + } + + @Test + public void testWindowIntoTriggersAndAccumulating() { + FixedWindows fixed10 = FixedWindows.of(Duration.standardMinutes(10)); + Repeatedly trigger = Repeatedly.forever(AfterPane.elementCountAtLeast(5)); + WindowingStrategy strategy = TestPipeline.create() + .apply(Create.of("hello", "world").withCoder(StringUtf8Coder.of())) + .apply(Window.into(fixed10) + .triggering(trigger) + .accumulatingFiredPanes() + .withAllowedLateness(Duration.ZERO)) + .getWindowingStrategy(); + + assertEquals(fixed10, strategy.getWindowFn()); + assertEquals(trigger, strategy.getTrigger().getSpec()); + assertEquals(AccumulationMode.ACCUMULATING_FIRED_PANES, strategy.getMode()); + } + + @Test + public void testWindowPropagatesEachPart() { + FixedWindows fixed10 = FixedWindows.of(Duration.standardMinutes(10)); + Repeatedly trigger = Repeatedly.forever(AfterPane.elementCountAtLeast(5)); + WindowingStrategy strategy = TestPipeline.create() + .apply(Create.of("hello", "world").withCoder(StringUtf8Coder.of())) + .apply("Mode", Window.accumulatingFiredPanes()) + .apply("Lateness", Window.withAllowedLateness(Duration.standardDays(1))) + .apply("Trigger", Window.triggering(trigger)) + .apply("Window", Window.into(fixed10)) + .getWindowingStrategy(); + + 
assertEquals(fixed10, strategy.getWindowFn()); + assertEquals(trigger, strategy.getTrigger().getSpec()); + assertEquals(AccumulationMode.ACCUMULATING_FIRED_PANES, strategy.getMode()); + assertEquals(Duration.standardDays(1), strategy.getAllowedLateness()); + } + + @Test + public void testWindowIntoPropagatesLateness() { + FixedWindows fixed10 = FixedWindows.of(Duration.standardMinutes(10)); + FixedWindows fixed25 = FixedWindows.of(Duration.standardMinutes(25)); + WindowingStrategy strategy = TestPipeline.create() + .apply(Create.of("hello", "world").withCoder(StringUtf8Coder.of())) + .apply(Window.named("WindowInto10").into(fixed10) + .withAllowedLateness(Duration.standardDays(1)) + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(5))) + .accumulatingFiredPanes()) + .apply(Window.named("WindowInto25").into(fixed25)) + .getWindowingStrategy(); + + assertEquals(Duration.standardDays(1), strategy.getAllowedLateness()); + assertEquals(fixed25, strategy.getWindowFn()); + } + + @Test + public void testWindowGetName() { + assertEquals("Window.Into()", + Window.into(FixedWindows.of(Duration.standardMinutes(10))).getName()); + } + + @Test + public void testNonDeterministicWindowCoder() throws NonDeterministicException { + FixedWindows mockWindowFn = Mockito.mock(FixedWindows.class); + @SuppressWarnings({"unchecked", "rawtypes"}) + Class> coderClazz = (Class) Coder.class; + Coder mockCoder = Mockito.mock(coderClazz); + when(mockWindowFn.windowCoder()).thenReturn(mockCoder); + NonDeterministicException toBeThrown = + new NonDeterministicException(mockCoder, "Its just not deterministic."); + Mockito.doThrow(toBeThrown).when(mockCoder).verifyDeterministic(); + + thrown.expect(IllegalArgumentException.class); + thrown.expectCause(Matchers.sameInstance(toBeThrown)); + thrown.expectMessage("Window coders must be deterministic"); + Window.into(mockWindowFn); + } + + @Test + public void testMissingMode() { + FixedWindows fixed10 = FixedWindows.of(Duration.standardMinutes(10)); + Repeatedly trigger = Repeatedly.forever(AfterPane.elementCountAtLeast(5)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("requires that the accumulation mode"); + TestPipeline.create() + .apply(Create.of("hello", "world").withCoder(StringUtf8Coder.of())) + .apply("Window", Window.into(fixed10)) + .apply("Lateness", Window.withAllowedLateness(Duration.standardDays(1))) + .apply("Trigger", Window.triggering(trigger)); + } + + @Test + public void testMissingLateness() { + FixedWindows fixed10 = FixedWindows.of(Duration.standardMinutes(10)); + Repeatedly trigger = Repeatedly.forever(AfterPane.elementCountAtLeast(5)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("requires that the allowed lateness"); + TestPipeline.create() + .apply(Create.of("hello", "world").withCoder(StringUtf8Coder.of())) + .apply("Mode", Window.accumulatingFiredPanes()) + .apply("Window", Window.into(fixed10)) + .apply("Trigger", Window.triggering(trigger)); + } + + /** + * Tests that when two elements are combined via a GroupByKey their output timestamp agrees + * with the windowing function default, the earlier of the two values. 
+ */
+  @Test
+  @Category(RunnableOnService.class)
+  public void testOutputTimeFnDefault() {
+    Pipeline pipeline = TestPipeline.create();
+
+    pipeline.apply(
+        Create.timestamped(
+            TimestampedValue.of(KV.of(0, "hello"), new Instant(0)),
+            TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10))))
+        .apply(Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))))
+        .apply(GroupByKey.<Integer, String>create())
+        .apply(ParDo.of(new DoFn<KV<Integer, Iterable<String>>, Void>() {
+          @Override
+          public void processElement(ProcessContext c) throws Exception {
+            assertThat(c.timestamp(), equalTo(new Instant(0)));
+          }
+        }));
+
+    pipeline.run();
+  }
+
+  /**
+   * Tests that when two elements are combined via a GroupByKey their output timestamp agrees
+   * with the windowing function customized to use the end of the window.
+   */
+  @Test
+  @Category(RunnableOnService.class)
+  public void testOutputTimeFnEndOfWindow() {
+    Pipeline pipeline = TestPipeline.create();
+
+    pipeline.apply(
+        Create.timestamped(
+            TimestampedValue.of(KV.of(0, "hello"), new Instant(0)),
+            TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10))))
+        .apply(Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10)))
+            .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow()))
+        .apply(GroupByKey.<Integer, String>create())
+        .apply(ParDo.of(new DoFn<KV<Integer, Iterable<String>>, Void>() {
+          @Override
+          public void processElement(ProcessContext c) throws Exception {
+            assertThat(c.timestamp(), equalTo(new Instant(10 * 60 * 1000 - 1)));
+          }
+        }));
+
+    pipeline.run();
+  }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingTest.java
new file mode 100644
index 000000000000..1c1248bd0cf9
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/windowing/WindowingTest.java
@@ -0,0 +1,244 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */ + +package com.google.cloud.dataflow.sdk.transforms.windowing; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.PrintStream; +import java.io.Serializable; + +/** Unit tests for bucketing. */ +@RunWith(JUnit4.class) +@SuppressWarnings("unchecked") +public class WindowingTest implements Serializable { + @Rule + public transient TemporaryFolder tmpFolder = new TemporaryFolder(); + + private static class WindowedCount extends PTransform, PCollection> { + + private final class FormatCountsDoFn + extends DoFn, String> implements RequiresWindowAccess { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey() + ":" + c.element().getValue() + + ":" + c.timestamp().getMillis() + ":" + c.window()); + } + } + private WindowFn windowFn; + public WindowedCount(WindowFn windowFn) { + this.windowFn = windowFn; + } + @Override + public PCollection apply(PCollection in) { + return in + .apply(Window.named("Window").into(windowFn)) + .apply(Count.perElement()) + .apply(ParDo + .named("FormatCounts").of(new FormatCountsDoFn())) + .setCoder(StringUtf8Coder.of()); + } + } + + private String output(String value, int count, int timestamp, int windowStart, int windowEnd) { + return value + ":" + count + ":" + timestamp + + ":[" + new Instant(windowStart) + ".." 
+ new Instant(windowEnd) + ")"; + } + + @Test + @Category(RunnableOnService.class) + public void testPartitioningWindowing() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("b", new Instant(2)), + TimestampedValue.of("b", new Instant(3)), + TimestampedValue.of("c", new Instant(11)), + TimestampedValue.of("d", new Instant(11)))); + + PCollection output = + input + .apply(new WindowedCount(FixedWindows.of(new Duration(10)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 1, 1, 0, 10), + output("b", 2, 2, 0, 10), + output("c", 1, 11, 10, 20), + output("d", 1, 11, 10, 20)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testNonPartitioningWindowing() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(7)), + TimestampedValue.of("b", new Instant(8)))); + + PCollection output = + input + .apply(new WindowedCount( + SlidingWindows.of(new Duration(10)).every(new Duration(5)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 1, 1, -5, 5), + output("a", 2, 5, 0, 10), + output("a", 1, 10, 5, 15), + output("b", 1, 8, 0, 10), + output("b", 1, 10, 5, 15)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testMergingWindowing() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(5)), + TimestampedValue.of("a", new Instant(20)))); + + PCollection output = + input + .apply(new WindowedCount(Sessions.withGapDuration(new Duration(10)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 2, 1, 1, 15), + output("a", 1, 20, 20, 30)); + + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testWindowPreservation() { + Pipeline p = TestPipeline.create(); + PCollection input1 = p.apply("Create12", + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("b", new Instant(2)))); + + PCollection input2 = p.apply("Create34", + Create.timestamped( + TimestampedValue.of("a", new Instant(3)), + TimestampedValue.of("b", new Instant(4)))); + + PCollectionList input = PCollectionList.of(input1).and(input2); + + PCollection output = + input + .apply(Flatten.pCollections()) + .apply(new WindowedCount(FixedWindows.of(new Duration(5)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 2, 1, 0, 5), + output("b", 2, 2, 0, 5)); + + p.run(); + } + + @Test + public void testEmptyInput() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.timestamped() + .withCoder(StringUtf8Coder.of())); + + PCollection output = + input + .apply(new WindowedCount(FixedWindows.of(new Duration(10)))); + + DataflowAssert.that(output).empty(); + + p.run(); + } + + @Test + public void testTextIoInput() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + try (PrintStream writer = new PrintStream(new FileOutputStream(tmpFile))) { + writer.println("a 1"); + writer.println("b 2"); + writer.println("b 3"); + writer.println("c 11"); + writer.println("d 11"); + } + + Pipeline p = TestPipeline.create(); + PCollection output = p.begin() + .apply(TextIO.Read.named("ReadLines").from(filename)) + .apply(ParDo.of(new 
ExtractWordsWithTimestampsFn())) + .apply(new WindowedCount(FixedWindows.of(Duration.millis(10)))); + + DataflowAssert.that(output).containsInAnyOrder( + output("a", 1, 1, 0, 10), + output("b", 2, 2, 0, 10), + output("c", 1, 11, 10, 20), + output("d", 1, 11, 10, 20)); + + p.run(); + } + + /** A DoFn that tokenizes lines of text into individual words. */ + static class ExtractWordsWithTimestampsFn extends DoFn { + @Override + public void processElement(ProcessContext c) { + String[] words = c.element().split("[^a-zA-Z0-9']+"); + if (words.length == 2) { + c.outputWithTimestamp(words[0], new Instant(Long.parseLong(words[1]))); + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java new file mode 100644 index 000000000000..e995b821de69 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java @@ -0,0 +1,186 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Tests for ApiSurface. These both test the functionality and also that our + * public API is conformant to a hard-coded policy. 
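+ *
+ * <p>Illustrative only: a minimal sketch of how the class under test is driven here, mirroring
+ * the calls exercised in the tests below (the variable names are hypothetical):
+ *
+ * <pre>{@code
+ * ApiSurface surface = ApiSurface.getSdkApiSurface()
+ *     .pruningClassName("com.google.cloud.dataflow.sdk.TestUtils");
+ * for (Class<?> clazz : surface.getExposedClasses()) {
+ *   // Each exposed class can be traced back to a chain of public API members that expose it.
+ *   List<Class<?>> path = surface.getAnyExposurePath(clazz);
+ * }
+ * }</pre>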
+ */
+@RunWith(JUnit4.class)
+public class ApiSurfaceTest {
+
+  @Test
+  public void testOurApiSurface() throws Exception {
+    ApiSurface checkedApiSurface = ApiSurface.getSdkApiSurface()
+        .pruningClassName("com.google.cloud.dataflow.sdk.runners.worker.StateFetcher")
+        .pruningClassName("com.google.cloud.dataflow.sdk.util.common.ReflectHelpers")
+        .pruningClassName("com.google.cloud.dataflow.sdk.DataflowMatchers")
+        .pruningClassName("com.google.cloud.dataflow.sdk.TestUtils")
+        .pruningClassName("com.google.cloud.dataflow.sdk.WindowMatchers");
+
+    checkedApiSurface.getExposedClasses();
+
+    Map<Class<?>, List<Class<?>>> disallowedClasses = Maps.newHashMap();
+    for (Class<?> clazz : checkedApiSurface.getExposedClasses()) {
+      if (!classIsAllowed(clazz)) {
+        disallowedClasses.put(clazz, checkedApiSurface.getAnyExposurePath(clazz));
+      }
+    }
+
+    List<String> disallowedMessages = Lists.newArrayList();
+    for (Map.Entry<Class<?>, List<Class<?>>> entry : disallowedClasses.entrySet()) {
+      disallowedMessages.add(entry.getKey() + " exposed via:\n\t\t"
+          + Joiner.on("\n\t\t").join(entry.getValue()));
+    }
+    Collections.sort(disallowedMessages);
+
+    if (!disallowedMessages.isEmpty()) {
+      fail("The following disallowed classes appear in the public API surface of the SDK:\n\t"
+          + Joiner.on("\n\t").join(disallowedMessages));
+    }
+  }
+
+  private boolean classIsAllowed(Class<?> clazz) {
+    return clazz.getName().startsWith("com.google.cloud.dataflow");
+  }
+
+  //////////////////////////////////////////////////////////////////////////////////
+
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  private void assertExposed(Class classToExamine, Class... exposedClasses) {
+    ApiSurface apiSurface = ApiSurface
+        .ofClass(classToExamine)
+        .pruningPrefix("java");
+
+    Set expectedExposed = Sets.newHashSet(classToExamine);
+    for (Class clazz : exposedClasses) {
+      expectedExposed.add(clazz);
+    }
+    assertThat(apiSurface.getExposedClasses(), containsInAnyOrder(expectedExposed.toArray()));
+  }
+
+  private static interface Exposed { }
+
+  private static interface ExposedReturnType {
+    Exposed zero();
+  }
+
+  @Test
+  public void testExposedReturnType() throws Exception {
+    assertExposed(ExposedReturnType.class, Exposed.class);
+  }
+
+  private static interface ExposedParameterTypeVarBound<T extends Exposed> {
+    void getList(T whatever);
+  }
+
+  @Test
+  public void testExposedParameterTypeVarBound() throws Exception {
+    assertExposed(ExposedParameterTypeVarBound.class, Exposed.class);
+  }
+
+  private static interface ExposedWildcardBound {
+    void acceptList(List<? extends Exposed> arg);
+  }
+
+  @Test
+  public void testExposedWildcardBound() throws Exception {
+    assertExposed(ExposedWildcardBound.class, Exposed.class);
+  }
+
+  private static interface ExposedActualTypeArgument extends List<Exposed> { }
+
+  @Test
+  public void testExposedActualTypeArgument() throws Exception {
+    assertExposed(ExposedActualTypeArgument.class, Exposed.class);
+  }
+
+  @Test
+  public void testIgnoreAll() throws Exception {
+    ApiSurface apiSurface = ApiSurface.ofClass(ExposedWildcardBound.class)
+        .includingClass(Object.class)
+        .includingClass(ApiSurface.class)
+        .pruningPattern(".*");
+    assertThat(apiSurface.getExposedClasses(), emptyIterable());
+  }
+
+  private static interface PrunedPattern { }
+  private static interface NotPruned extends PrunedPattern { }
+
+  @Test
+  public void testprunedPattern() throws Exception {
+    ApiSurface apiSurface = ApiSurface.ofClass(NotPruned.class)
+        .pruningClass(PrunedPattern.class);
+    assertThat(apiSurface.getExposedClasses(), containsInAnyOrder((Class) NotPruned.class));
+  }
+
+  private static interface
ExposedTwice { + Exposed zero(); + Exposed one(); + } + + @Test + public void testExposedTwice() throws Exception { + assertExposed(ExposedTwice.class, Exposed.class); + } + + private static interface ExposedCycle { + ExposedCycle zero(Exposed foo); + } + + @Test + public void testExposedCycle() throws Exception { + assertExposed(ExposedCycle.class, Exposed.class); + } + + private static interface ExposedGenericCycle { + Exposed zero(List foo); + } + + @Test + public void testExposedGenericCycle() throws Exception { + assertExposed(ExposedGenericCycle.class, Exposed.class); + } + + private static interface ExposedArrayCycle { + Exposed zero(ExposedArrayCycle[] foo); + } + + @Test + public void testExposedArrayCycle() throws Exception { + assertExposed(ExposedArrayCycle.class, Exposed.class); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptAndTimeBoundedExponentialBackOffTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptAndTimeBoundedExponentialBackOffTest.java new file mode 100644 index 000000000000..1d1f27f61b32 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptAndTimeBoundedExponentialBackOffTest.java @@ -0,0 +1,212 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.api.client.util.BackOff; +import com.google.cloud.dataflow.sdk.testing.FastNanoClockAndSleeper; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link AttemptAndTimeBoundedExponentialBackOff}. 
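+ *
+ * <p>Illustrative only, inferred from the constructor arguments and assertions in these tests:
+ * the three-argument form bounds retries by both attempt count and total wait time, and
+ * {@code nextBackOffMillis()} returns {@code BackOff.STOP} once either bound is reached.
+ *
+ * <pre>{@code
+ * BackOff backOff = new AttemptAndTimeBoundedExponentialBackOff(
+ *     3,     // maximum number of attempts
+ *     500L,  // initial backoff interval, in milliseconds
+ *     1000L  // maximum total wait time, in milliseconds
+ * );
+ * long sleepMillis = backOff.nextBackOffMillis();  // randomized around 500 ms on the first call
+ * // ... sleep and retry until nextBackOffMillis() returns BackOff.STOP ...
+ * }</pre>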
*/ +@RunWith(JUnit4.class) +public class AttemptAndTimeBoundedExponentialBackOffTest { + @Rule public ExpectedException exception = ExpectedException.none(); + @Rule public FastNanoClockAndSleeper fastClock = new FastNanoClockAndSleeper(); + + @Test + public void testUsingInvalidInitialInterval() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Initial interval must be greater than zero."); + new AttemptAndTimeBoundedExponentialBackOff(10, 0L, 1000L); + } + + @Test + public void testUsingInvalidTimeInterval() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Maximum total wait time must be greater than zero."); + new AttemptAndTimeBoundedExponentialBackOff(10, 2L, 0L); + } + + @Test + public void testUsingInvalidMaximumNumberOfRetries() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Maximum number of attempts must be greater than zero."); + new AttemptAndTimeBoundedExponentialBackOff(-1, 10L, 1000L); + } + + @Test + public void testThatFixedNumberOfAttemptsExits() throws Exception { + BackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, + 500L, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ALL, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testThatResettingAllowsReuse() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, + 500, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ALL, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + + backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 30, + 500, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ALL, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + fastClock.sleep(2000L); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + } + + @Test + public void testThatResettingAttemptsAllowsReuse() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, + 500, + 1000, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ATTEMPTS, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + 
} + + @Test + public void testThatResettingAttemptsDoesNotAllowsReuse() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 30, + 500, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ATTEMPTS, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + fastClock.sleep(2000L); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testThatResettingTimerAllowsReuse() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 30, + 500, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.TIMER, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + fastClock.sleep(2000L); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(561L), lessThan(1688L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(843L), lessThan(2531L))); + } + + @Test + public void testThatResettingTimerDoesNotAllowReuse() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, + 500, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.TIMER, + fastClock); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testTimeBound() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, 500L, 5L, AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ALL, fastClock); + assertEquals(backOff.nextBackOffMillis(), 5L); + } + + @Test + public void testAtMaxAttempts() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, + 500L, + 1000L, + AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ALL, + fastClock); + assertFalse(backOff.atMaxAttempts()); + backOff.nextBackOffMillis(); + assertFalse(backOff.atMaxAttempts()); + backOff.nextBackOffMillis(); + assertTrue(backOff.atMaxAttempts()); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testAtMaxTime() throws Exception { + AttemptBoundedExponentialBackOff backOff = + new AttemptAndTimeBoundedExponentialBackOff( + 3, 500L, 1L, AttemptAndTimeBoundedExponentialBackOff.ResetPolicy.ALL, fastClock); + fastClock.sleep(2); + assertTrue(backOff.atMaxAttempts()); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOffTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOffTest.java new file mode 100644 index 000000000000..6f86e61b8ab4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AttemptBoundedExponentialBackOffTest.java @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + + +import com.google.api.client.util.BackOff; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link AttemptBoundedExponentialBackOff}. */ +@RunWith(JUnit4.class) +public class AttemptBoundedExponentialBackOffTest { + @Rule public ExpectedException exception = ExpectedException.none(); + + @Test + public void testUsingInvalidInitialInterval() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Initial interval must be greater than zero."); + new AttemptBoundedExponentialBackOff(10, 0L); + } + + @Test + public void testUsingInvalidMaximumNumberOfRetries() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Maximum number of attempts must be greater than zero."); + new AttemptBoundedExponentialBackOff(-1, 10L); + } + + @Test + public void testThatFixedNumberOfAttemptsExits() throws Exception { + BackOff backOff = new AttemptBoundedExponentialBackOff(3, 500); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testThatResettingAllowsReuse() throws Exception { + BackOff backOff = new AttemptBoundedExponentialBackOff(3, 500); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThan(374L), lessThan(1126L))); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } + + @Test + public void testAtMaxAttempts() throws Exception { + AttemptBoundedExponentialBackOff backOff = new AttemptBoundedExponentialBackOff(3, 500); + assertFalse(backOff.atMaxAttempts()); + backOff.nextBackOffMillis(); + assertFalse(backOff.atMaxAttempts()); + backOff.nextBackOffMillis(); + assertTrue(backOff.atMaxAttempts()); + assertEquals(BackOff.STOP, backOff.nextBackOffMillis()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AvroUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AvroUtilsTest.java new file mode 100644 index 000000000000..d03ac89eab7f --- /dev/null +++ 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/AvroUtilsTest.java @@ -0,0 +1,225 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.util.AvroUtils.AvroMetadata; +import com.google.common.collect.Lists; + +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.reflect.Nullable; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Tests for AvroUtils. + */ +@RunWith(JUnit4.class) +public class AvroUtilsTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private static final int DEFAULT_RECORD_COUNT = 10000; + + /** + * Generates an input Avro file containing the given records in the temporary directory and + * returns the full path of the file. + */ + @SuppressWarnings("deprecation") // test of internal functionality + private String generateTestFile(String filename, List elems, AvroCoder coder, + String codec) throws IOException { + File tmpFile = tmpFolder.newFile(filename); + String path = tmpFile.toString(); + + FileOutputStream os = new FileOutputStream(tmpFile); + DatumWriter datumWriter = coder.createDatumWriter(); + try (DataFileWriter writer = new DataFileWriter<>(datumWriter)) { + writer.setCodec(CodecFactory.fromString(codec)); + writer.create(coder.getSchema(), os); + for (T elem : elems) { + writer.append(elem); + } + } + return path; + } + + @Test + public void testReadMetadataWithCodecs() throws Exception { + // Test reading files generated using all codecs. 
+ String codecs[] = {DataFileConstants.NULL_CODEC, DataFileConstants.BZIP2_CODEC, + DataFileConstants.DEFLATE_CODEC, DataFileConstants.SNAPPY_CODEC, + DataFileConstants.XZ_CODEC}; + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + + for (String codec : codecs) { + String filename = generateTestFile( + codec, expected, AvroCoder.of(Bird.class), codec); + AvroMetadata metadata = AvroUtils.readMetadataFromFile(filename); + assertEquals(codec, metadata.getCodec()); + } + } + + @Test + public void testReadSchemaString() throws Exception { + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + String codec = DataFileConstants.NULL_CODEC; + String filename = generateTestFile( + codec, expected, AvroCoder.of(Bird.class), codec); + AvroMetadata metadata = AvroUtils.readMetadataFromFile(filename); + // By default, parse validates the schema, which is what we want. + Schema schema = new Schema.Parser().parse(metadata.getSchemaString()); + assertEquals(8, schema.getFields().size()); + } + + @Test + public void testConvertGenericRecordToTableRow() throws Exception { + TableSchema tableSchema = new TableSchema(); + List subFields = Lists.newArrayList( + new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); + /* + * Note that the quality and quantity fields do not have their mode set, so they should default + * to NULLABLE. This is an important test of BigQuery semantics. + * + * All the other fields we set in this function are required on the Schema response. + * + * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema + */ + List fields = + Lists.newArrayList( + new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), + new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, + new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, + new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), + new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), + new TableFieldSchema().setName("scion").setType("RECORD").setMode("NULLABLE") + .setFields(subFields), + new TableFieldSchema().setName("associates").setType("RECORD").setMode("REPEATED") + .setFields(subFields)); + tableSchema.setFields(fields); + Schema avroSchema = AvroCoder.of(Bird.class).getSchema(); + + { + // Test nullable fields. + GenericRecord record = new GenericData.Record(avroSchema); + record.put("number", 5L); + TableRow convertedRow = AvroUtils.convertGenericRecordToTableRow(record, tableSchema); + TableRow row = new TableRow() + .set("number", "5") + .set("associates", new ArrayList()); + assertEquals(row, convertedRow); + } + { + // Test type conversion for TIMESTAMP, INTEGER, BOOLEAN, and FLOAT. + GenericRecord record = new GenericData.Record(avroSchema); + record.put("number", 5L); + record.put("quality", 5.0); + record.put("birthday", 5L); + record.put("flighted", Boolean.TRUE); + TableRow convertedRow = AvroUtils.convertGenericRecordToTableRow(record, tableSchema); + TableRow row = new TableRow() + .set("number", "5") + .set("birthday", "1970-01-01 00:00:00.000005 UTC") + .set("quality", 5.0) + .set("associates", new ArrayList()) + .set("flighted", Boolean.TRUE); + assertEquals(row, convertedRow); + } + { + // Test repeated fields. 
+ Schema subBirdSchema = AvroCoder.of(Bird.SubBird.class).getSchema(); + GenericRecord nestedRecord = new GenericData.Record(subBirdSchema); + nestedRecord.put("species", "other"); + GenericRecord record = new GenericData.Record(avroSchema); + record.put("number", 5L); + record.put("associates", Lists.newArrayList(nestedRecord)); + TableRow convertedRow = AvroUtils.convertGenericRecordToTableRow(record, tableSchema); + TableRow row = new TableRow() + .set("associates", Lists.newArrayList( + new TableRow().set("species", "other"))) + .set("number", "5"); + assertEquals(row, convertedRow); + } + } + + /** + * Pojo class used as the record type in tests. + */ + @DefaultCoder(AvroCoder.class) + static class Bird { + long number; + @Nullable String species; + @Nullable Double quality; + @Nullable Long quantity; + @Nullable Long birthday; // Exercises TIMESTAMP. + @Nullable Boolean flighted; + @Nullable SubBird scion; + SubBird[] associates; + + static class SubBird { + @Nullable String species; + + public SubBird() {} + } + + public Bird() { + associates = new SubBird[1]; + associates[0] = new SubBird(); + } + } + + /** + * Create a list of n random records. + */ + private static List createRandomRecords(long n) { + String[] species = {"pigeons", "owls", "gulls", "hawks", "robins", "jays"}; + Random random = new Random(0); + + List records = new ArrayList<>(); + for (long i = 0; i < n; i++) { + Bird bird = new Bird(); + bird.quality = random.nextDouble(); + bird.species = species[random.nextInt(species.length)]; + bird.number = i; + bird.quantity = random.nextLong(); + records.add(bird); + } + return records; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BatchTimerInternalsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BatchTimerInternalsTest.java new file mode 100644 index 000000000000..25d07d62d792 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BatchTimerInternalsTest.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaceForTest; + +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link BatchTimerInternals}. 
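+ *
+ * <p>Illustrative only, based on the calls made in these tests: timers are registered with
+ * {@code setTimer(...)} and fire when processing time or the input watermark is advanced past
+ * their timestamps. The namespace and runner below are stand-ins, not real values:
+ *
+ * <pre>{@code
+ * BatchTimerInternals timerInternals = new BatchTimerInternals(new Instant(0));
+ * timerInternals.setTimer(
+ *     TimerData.of(someStateNamespace, new Instant(19), TimeDomain.PROCESSING_TIME));
+ * // Fires the processing-time timer set for 19, as the tests below verify.
+ * timerInternals.advanceProcessingTime(someReduceFnRunner, new Instant(20));
+ * }</pre>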
+ */ +@RunWith(JUnit4.class) +public class BatchTimerInternalsTest { + + private static final StateNamespace NS1 = new StateNamespaceForTest("NS1"); + + @Mock + private ReduceFnRunner mockRunner; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testFiringTimers() throws Exception { + BatchTimerInternals underTest = new BatchTimerInternals(new Instant(0)); + TimerData processingTime1 = TimerData.of(NS1, new Instant(19), TimeDomain.PROCESSING_TIME); + TimerData processingTime2 = TimerData.of(NS1, new Instant(29), TimeDomain.PROCESSING_TIME); + + underTest.setTimer(processingTime1); + underTest.setTimer(processingTime2); + + underTest.advanceProcessingTime(mockRunner, new Instant(20)); + Mockito.verify(mockRunner).onTimer(processingTime1); + Mockito.verifyNoMoreInteractions(mockRunner); + + // Advancing just a little shouldn't refire + underTest.advanceProcessingTime(mockRunner, new Instant(21)); + Mockito.verifyNoMoreInteractions(mockRunner); + + // Adding the timer and advancing a little should refire + underTest.setTimer(processingTime1); + Mockito.verify(mockRunner).onTimer(processingTime1); + underTest.advanceProcessingTime(mockRunner, new Instant(21)); + Mockito.verifyNoMoreInteractions(mockRunner); + + // And advancing the rest of the way should still have the other timer + underTest.advanceProcessingTime(mockRunner, new Instant(30)); + Mockito.verify(mockRunner).onTimer(processingTime2); + Mockito.verifyNoMoreInteractions(mockRunner); + } + + @Test + public void testTimerOrdering() throws Exception { + BatchTimerInternals underTest = new BatchTimerInternals(new Instant(0)); + TimerData watermarkTime1 = TimerData.of(NS1, new Instant(19), TimeDomain.EVENT_TIME); + TimerData processingTime1 = TimerData.of(NS1, new Instant(19), TimeDomain.PROCESSING_TIME); + TimerData watermarkTime2 = TimerData.of(NS1, new Instant(29), TimeDomain.EVENT_TIME); + TimerData processingTime2 = TimerData.of(NS1, new Instant(29), TimeDomain.PROCESSING_TIME); + + underTest.setTimer(processingTime1); + underTest.setTimer(watermarkTime1); + underTest.setTimer(processingTime2); + underTest.setTimer(watermarkTime2); + + underTest.advanceInputWatermark(mockRunner, new Instant(30)); + Mockito.verify(mockRunner).onTimer(watermarkTime1); + Mockito.verify(mockRunner).onTimer(watermarkTime2); + Mockito.verifyNoMoreInteractions(mockRunner); + + underTest.advanceProcessingTime(mockRunner, new Instant(30)); + Mockito.verify(mockRunner).onTimer(processingTime1); + Mockito.verify(mockRunner).onTimer(processingTime2); + Mockito.verifyNoMoreInteractions(mockRunner); + } + + @Test + public void testDeduplicate() throws Exception { + BatchTimerInternals underTest = new BatchTimerInternals(new Instant(0)); + TimerData watermarkTime = TimerData.of(NS1, new Instant(19), TimeDomain.EVENT_TIME); + TimerData processingTime = TimerData.of(NS1, new Instant(19), TimeDomain.PROCESSING_TIME); + underTest.setTimer(watermarkTime); + underTest.setTimer(watermarkTime); + underTest.setTimer(processingTime); + underTest.setTimer(processingTime); + underTest.advanceProcessingTime(mockRunner, new Instant(20)); + underTest.advanceInputWatermark(mockRunner, new Instant(20)); + + Mockito.verify(mockRunner).onTimer(processingTime); + Mockito.verify(mockRunner).onTimer(watermarkTime); + Mockito.verifyNoMoreInteractions(mockRunner); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserterTest.java 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserterTest.java new file mode 100644 index 000000000000..d53315ba6688 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryTableInserterTest.java @@ -0,0 +1,239 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Verify.verifyNotNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.client.googleapis.json.GoogleJsonError; +import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo; +import com.google.api.client.googleapis.json.GoogleJsonErrorContainer; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.GenericJson; +import com.google.api.client.json.Json; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.util.BackOff; +import com.google.api.client.util.Sleeper; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableReference; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.hadoop.util.RetryBoundedBackOff; +import com.google.common.collect.ImmutableList; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * Tests of {@link BigQueryTableInserter}. + */ +@RunWith(JUnit4.class) +public class BigQueryTableInserterTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(BigQueryTableInserter.class); + @Mock private LowLevelHttpResponse response; + private Bigquery bigquery; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + // A mock transport that lets us mock the API responses. 
+ MockHttpTransport transport = + new MockHttpTransport.Builder() + .setLowLevelHttpRequest( + new MockLowLevelHttpRequest() { + @Override + public LowLevelHttpResponse execute() throws IOException { + return response; + } + }) + .build(); + + // A sample BigQuery API client that uses default JsonFactory and RetryHttpInitializer. + bigquery = + new Bigquery.Builder( + transport, Transport.getJsonFactory(), new RetryHttpRequestInitializer()) + .build(); + } + + @After + public void tearDown() throws IOException { + // These three interactions happen for every request in the normal response parsing. + verify(response, atLeastOnce()).getContentEncoding(); + verify(response, atLeastOnce()).getHeaderCount(); + verify(response, atLeastOnce()).getReasonPhrase(); + verifyNoMoreInteractions(response); + } + + /** A helper to wrap a {@link GenericJson} object in a content stream. */ + private static InputStream toStream(GenericJson content) throws IOException { + return new ByteArrayInputStream(JacksonFactory.getDefaultInstance().toByteArray(content)); + } + + /** A helper that generates the error JSON payload that Google APIs produce. */ + private static GoogleJsonErrorContainer errorWithReasonAndStatus(String reason, int status) { + ErrorInfo info = new ErrorInfo(); + info.setReason(reason); + info.setDomain("global"); + // GoogleJsonError contains one or more ErrorInfo objects; our utiities read the first one. + GoogleJsonError error = new GoogleJsonError(); + error.setErrors(ImmutableList.of(info)); + error.setCode(status); + // The actual JSON response is an error container. + GoogleJsonErrorContainer container = new GoogleJsonErrorContainer(); + container.setError(error); + return container; + } + + /** + * Tests that {@link BigQueryTableInserter} succeeds on the first try. + */ + @Test + public void testCreateTableSucceeds() throws IOException { + Table testTable = new Table().setDescription("a table"); + + when(response.getContentType()).thenReturn(Json.MEDIA_TYPE); + when(response.getStatusCode()).thenReturn(200); + when(response.getContent()).thenReturn(toStream(testTable)); + + BigQueryTableInserter inserter = new BigQueryTableInserter(bigquery); + Table ret = + inserter.tryCreateTable( + new Table(), + "project", + "dataset", + new RetryBoundedBackOff(0, BackOff.ZERO_BACKOFF), + Sleeper.DEFAULT); + assertEquals(testTable, ret); + verify(response, times(1)).getStatusCode(); + verify(response, times(1)).getContent(); + verify(response, times(1)).getContentType(); + } + + /** + * Tests that {@link BigQueryTableInserter} succeeds when the table already exists. + */ + @Test + public void testCreateTableSucceedsAlreadyExists() throws IOException { + when(response.getStatusCode()).thenReturn(409); // 409 means already exists + + BigQueryTableInserter inserter = new BigQueryTableInserter(bigquery); + Table ret = + inserter.tryCreateTable( + new Table(), + "project", + "dataset", + new RetryBoundedBackOff(0, BackOff.ZERO_BACKOFF), + Sleeper.DEFAULT); + + assertNull(ret); + verify(response, times(1)).getStatusCode(); + verify(response, times(1)).getContent(); + verify(response, times(1)).getContentType(); + } + + /** + * Tests that {@link BigQueryTableInserter} retries quota rate limited attempts. 
+ */ + @Test + public void testCreateTableRetry() throws IOException { + TableReference ref = + new TableReference().setProjectId("project").setDatasetId("dataset").setTableId("table"); + Table testTable = new Table().setTableReference(ref); + + // First response is 403 rate limited, second response has valid payload. + when(response.getContentType()).thenReturn(Json.MEDIA_TYPE); + when(response.getStatusCode()).thenReturn(403).thenReturn(200); + when(response.getContent()) + .thenReturn(toStream(errorWithReasonAndStatus("rateLimitExceeded", 403))) + .thenReturn(toStream(testTable)); + + BigQueryTableInserter inserter = new BigQueryTableInserter(bigquery); + Table ret = + inserter.tryCreateTable( + testTable, + "project", + "dataset", + new RetryBoundedBackOff(3, BackOff.ZERO_BACKOFF), + Sleeper.DEFAULT); + assertEquals(testTable, ret); + verify(response, times(2)).getStatusCode(); + verify(response, times(2)).getContent(); + verify(response, times(2)).getContentType(); + verifyNotNull(ret.getTableReference()); + expectedLogs.verifyInfo( + "Quota limit reached when creating table project:dataset.table, " + + "retrying up to 5.0 minutes"); + } + + /** + * Tests that {@link BigQueryTableInserter} does not retry non-rate-limited attempts. + */ + @Test + public void testCreateTableDoesNotRetry() throws IOException { + Table testTable = new Table().setDescription("a table"); + + // First response is 403 not-rate-limited, second response has valid payload but should not + // be invoked. + when(response.getContentType()).thenReturn(Json.MEDIA_TYPE); + when(response.getStatusCode()).thenReturn(403).thenReturn(200); + when(response.getContent()) + .thenReturn(toStream(errorWithReasonAndStatus("actually forbidden", 403))) + .thenReturn(toStream(testTable)); + + thrown.expect(GoogleJsonResponseException.class); + thrown.expectMessage("actually forbidden"); + + BigQueryTableInserter inserter = new BigQueryTableInserter(bigquery); + try { + inserter.tryCreateTable( + new Table(), + "project", + "dataset", + new RetryBoundedBackOff(3, BackOff.ZERO_BACKOFF), + Sleeper.DEFAULT); + fail(); + } catch (IOException e) { + verify(response, times(1)).getStatusCode(); + verify(response, times(1)).getContent(); + verify(response, times(1)).getContentType(); + throw e; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIteratorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIteratorTest.java new file mode 100644 index 000000000000..b82e62595cd8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryTableRowIteratorTest.java @@ -0,0 +1,255 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Dataset; +import com.google.api.services.bigquery.model.ErrorProto; +import com.google.api.services.bigquery.model.Job; +import com.google.api.services.bigquery.model.JobConfiguration; +import com.google.api.services.bigquery.model.JobConfigurationQuery; +import com.google.api.services.bigquery.model.JobReference; +import com.google.api.services.bigquery.model.JobStatus; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableCell; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for {@link BigQueryTableRowIterator}. 
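+ * The BigQuery client is fully mocked: job insertion, job polling, table schema and row
+ * fetches, and cleanup of the temporary dataset and table are all verified against the mocks.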
+ */ +@RunWith(JUnit4.class) +public class BigQueryTableRowIteratorTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Mock private Bigquery mockClient; + @Mock private Bigquery.Datasets mockDatasets; + @Mock private Bigquery.Datasets.Delete mockDatasetsDelete; + @Mock private Bigquery.Datasets.Insert mockDatasetsInsert; + @Mock private Bigquery.Jobs mockJobs; + @Mock private Bigquery.Jobs.Get mockJobsGet; + @Mock private Bigquery.Jobs.Insert mockJobsInsert; + @Mock private Bigquery.Tables mockTables; + @Mock private Bigquery.Tables.Get mockTablesGet; + @Mock private Bigquery.Tables.Delete mockTablesDelete; + @Mock private Bigquery.Tabledata mockTabledata; + @Mock private Bigquery.Tabledata.List mockTabledataList; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + when(mockClient.tabledata()).thenReturn(mockTabledata); + when(mockTabledata.list(anyString(), anyString(), anyString())).thenReturn(mockTabledataList); + + when(mockClient.tables()).thenReturn(mockTables); + when(mockTables.delete(anyString(), anyString(), anyString())).thenReturn(mockTablesDelete); + when(mockTables.get(anyString(), anyString(), anyString())).thenReturn(mockTablesGet); + + when(mockClient.datasets()).thenReturn(mockDatasets); + when(mockDatasets.delete(anyString(), anyString())).thenReturn(mockDatasetsDelete); + when(mockDatasets.insert(anyString(), any(Dataset.class))).thenReturn(mockDatasetsInsert); + + when(mockClient.jobs()).thenReturn(mockJobs); + when(mockJobs.insert(anyString(), any(Job.class))).thenReturn(mockJobsInsert); + when(mockJobs.get(anyString(), anyString())).thenReturn(mockJobsGet); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockClient); + verifyNoMoreInteractions(mockDatasets); + verifyNoMoreInteractions(mockDatasetsDelete); + verifyNoMoreInteractions(mockDatasetsInsert); + verifyNoMoreInteractions(mockJobs); + verifyNoMoreInteractions(mockJobsGet); + verifyNoMoreInteractions(mockJobsInsert); + verifyNoMoreInteractions(mockTables); + verifyNoMoreInteractions(mockTablesDelete); + verifyNoMoreInteractions(mockTablesGet); + verifyNoMoreInteractions(mockTabledata); + verifyNoMoreInteractions(mockTabledataList); + } + + private static Table tableWithBasicSchema() { + return new Table() + .setSchema( + new TableSchema() + .setFields( + Arrays.asList( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("answer").setType("INTEGER")))); + } + + private TableRow rawRow(Object... args) { + List cells = new LinkedList<>(); + for (Object a : args) { + cells.add(new TableCell().setV(a)); + } + return new TableRow().setF(cells); + } + + private TableDataList rawDataList(TableRow... rows) { + return new TableDataList().setRows(Arrays.asList(rows)); + } + + /** + * Verifies that when the query runs, the correct data is returned and the temporary dataset and + * table are both cleaned up. + */ + @Test + public void testReadFromQuery() throws IOException, InterruptedException { + // Mock job inserting. + Job insertedJob = new Job().setJobReference(new JobReference()); + when(mockJobsInsert.execute()).thenReturn(insertedJob); + + // Mock job polling. 
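+    // The polled job reports state DONE and points at the temporary destination table that
+    // holds the query results.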
+ JobStatus status = new JobStatus().setState("DONE"); + TableReference tableRef = + new TableReference().setProjectId("project").setDatasetId("dataset").setTableId("table"); + JobConfigurationQuery queryConfig = new JobConfigurationQuery().setDestinationTable(tableRef); + Job getJob = + new Job() + .setJobReference(new JobReference()) + .setStatus(status) + .setConfiguration(new JobConfiguration().setQuery(queryConfig)); + when(mockJobsGet.execute()).thenReturn(getJob); + + // Mock table schema fetch. + when(mockTablesGet.execute()).thenReturn(tableWithBasicSchema()); + + // Mock table data fetch. + when(mockTabledataList.execute()).thenReturn(rawDataList(rawRow("Arthur", 42))); + + // Run query and verify + String query = "SELECT name, count from table"; + try (BigQueryTableRowIterator iterator = + BigQueryTableRowIterator.fromQuery(query, "project", mockClient, null)) { + iterator.open(); + assertTrue(iterator.advance()); + TableRow row = iterator.getCurrent(); + + assertTrue(row.containsKey("name")); + assertTrue(row.containsKey("answer")); + assertEquals("Arthur", row.get("name")); + assertEquals(42, row.get("answer")); + + assertFalse(iterator.advance()); + } + + // Temp dataset created and later deleted. + verify(mockClient, times(2)).datasets(); + verify(mockDatasets).insert(anyString(), any(Dataset.class)); + verify(mockDatasetsInsert).execute(); + verify(mockDatasets).delete(anyString(), anyString()); + verify(mockDatasetsDelete).execute(); + // Job inserted to run the query, polled once. + verify(mockClient, times(2)).jobs(); + verify(mockJobs).insert(anyString(), any(Job.class)); + verify(mockJobsInsert).execute(); + verify(mockJobs).get(anyString(), anyString()); + verify(mockJobsGet).execute(); + // Temp table get after query finish, deleted after reading. + verify(mockClient, times(2)).tables(); + verify(mockTables).get("project", "dataset", "table"); + verify(mockTablesGet).execute(); + verify(mockTables).delete(anyString(), anyString(), anyString()); + verify(mockTablesDelete).execute(); + // Table data read. + verify(mockClient).tabledata(); + verify(mockTabledata).list("project", "dataset", "table"); + verify(mockTabledataList).execute(); + } + + /** + * Verifies that when the query fails, the user gets a useful exception and the temporary dataset + * is cleaned up. Also verifies that the temporary table (which is never created) is not + * erroneously attempted to be deleted. + */ + @Test + public void testQueryFailed() throws IOException { + // Job can be created. + JobReference ref = new JobReference(); + Job insertedJob = new Job().setJobReference(ref); + when(mockJobsInsert.execute()).thenReturn(insertedJob); + + // Job state polled with an error. + String errorReason = "bad query"; + JobStatus status = + new JobStatus().setState("DONE").setErrorResult(new ErrorProto().setMessage(errorReason)); + Job getJob = new Job().setJobReference(ref).setStatus(status); + when(mockJobsGet.execute()).thenReturn(getJob); + + String query = "NOT A QUERY"; + try (BigQueryTableRowIterator iterator = + BigQueryTableRowIterator.fromQuery(query, "project", mockClient, null)) { + try { + iterator.open(); + fail(); + } catch (Exception expected) { + // Verify message explains cause and reports the query. + assertThat(expected.getMessage(), containsString("failed")); + assertThat(expected.getMessage(), containsString(errorReason)); + assertThat(expected.getMessage(), containsString(query)); + } + } + + // Temp dataset created and then later deleted. 
+ verify(mockClient, times(2)).datasets(); + verify(mockDatasets).insert(anyString(), any(Dataset.class)); + verify(mockDatasetsInsert).execute(); + verify(mockDatasets).delete(anyString(), anyString()); + verify(mockDatasetsDelete).execute(); + // Job inserted to run the query, then polled once. + verify(mockClient, times(2)).jobs(); + verify(mockJobs).insert(anyString(), any(Job.class)); + verify(mockJobsInsert).execute(); + verify(mockJobs).get(anyString(), anyString()); + verify(mockJobsGet).execute(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryUtilTest.java new file mode 100644 index 000000000000..fab4aecb4708 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BigQueryUtilTest.java @@ -0,0 +1,479 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableCell; +import com.google.api.services.bigquery.model.TableDataInsertAllRequest; +import com.google.api.services.bigquery.model.TableDataInsertAllResponse; +import com.google.api.services.bigquery.model.TableDataList; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for util classes 
related to BigQuery. + */ +@RunWith(JUnit4.class) +public class BigQueryUtilTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock private Bigquery mockClient; + @Mock private Bigquery.Tables mockTables; + @Mock private Bigquery.Tables.Get mockTablesGet; + @Mock private Bigquery.Tabledata mockTabledata; + @Mock private Bigquery.Tabledata.List mockTabledataList; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockClient); + verifyNoMoreInteractions(mockTables); + verifyNoMoreInteractions(mockTablesGet); + verifyNoMoreInteractions(mockTabledata); + verifyNoMoreInteractions(mockTabledataList); + } + + private void onInsertAll(List> errorIndicesSequence) throws Exception { + when(mockClient.tabledata()) + .thenReturn(mockTabledata); + + final List responses = new ArrayList<>(); + for (List errorIndices : errorIndicesSequence) { + List errors = new ArrayList<>(); + for (long i : errorIndices) { + TableDataInsertAllResponse.InsertErrors error = + new TableDataInsertAllResponse.InsertErrors(); + error.setIndex(i); + } + TableDataInsertAllResponse response = new TableDataInsertAllResponse(); + response.setInsertErrors(errors); + responses.add(response); + } + + doAnswer( + new Answer() { + @Override + public Bigquery.Tabledata.InsertAll answer(InvocationOnMock invocation) throws Throwable { + Bigquery.Tabledata.InsertAll mockInsertAll = mock(Bigquery.Tabledata.InsertAll.class); + when(mockInsertAll.execute()) + .thenReturn(responses.get(0), + responses.subList(1, responses.size()).toArray( + new TableDataInsertAllResponse[responses.size() - 1])); + return mockInsertAll; + } + }) + .when(mockTabledata) + .insertAll(anyString(), anyString(), anyString(), any(TableDataInsertAllRequest.class)); + } + + private void verifyInsertAll(int expectedRetries) throws IOException { + verify(mockClient, times(expectedRetries)).tabledata(); + verify(mockTabledata, times(expectedRetries)) + .insertAll(anyString(), anyString(), anyString(), any(TableDataInsertAllRequest.class)); + } + + private void onTableGet(Table table) throws IOException { + when(mockClient.tables()) + .thenReturn(mockTables); + when(mockTables.get(anyString(), anyString(), anyString())) + .thenReturn(mockTablesGet); + when(mockTablesGet.execute()) + .thenReturn(table); + } + + private void verifyTableGet() throws IOException { + verify(mockClient).tables(); + verify(mockTables).get("project", "dataset", "table"); + verify(mockTablesGet, atLeastOnce()).execute(); + } + + private void onTableList(TableDataList result) throws IOException { + when(mockClient.tabledata()) + .thenReturn(mockTabledata); + when(mockTabledata.list(anyString(), anyString(), anyString())) + .thenReturn(mockTabledataList); + when(mockTabledataList.execute()) + .thenReturn(result); + } + + private void verifyTabledataList() throws IOException { + verify(mockClient, atLeastOnce()).tabledata(); + verify(mockTabledata, atLeastOnce()).list("project", "dataset", "table"); + verify(mockTabledataList, atLeastOnce()).execute(); + // Max results may be set when testing for an empty table. 
+ verify(mockTabledataList, atLeast(0)).setMaxResults(anyLong()); + } + + private Table basicTableSchema() { + return new Table() + .setSchema(new TableSchema() + .setFields(Arrays.asList( + new TableFieldSchema() + .setName("name") + .setType("STRING"), + new TableFieldSchema() + .setName("answer") + .setType("INTEGER") + ))); + } + + private Table basicTableSchemaWithTime() { + return new Table() + .setSchema(new TableSchema() + .setFields(Arrays.asList( + new TableFieldSchema() + .setName("time") + .setType("TIMESTAMP") + ))); + } + + @Test + public void testReadWithTime() throws IOException, InterruptedException { + // The BigQuery JSON API returns timestamps in the following format: floating-point seconds + // since epoch (UTC) with microsecond precision. Test that we faithfully preserve a set of + // known values. + TableDataList input = rawDataList( + rawRow("1.430397296789E9"), + rawRow("1.45206228E9"), + rawRow("1.452062291E9"), + rawRow("1.4520622911E9"), + rawRow("1.45206229112E9"), + rawRow("1.452062291123E9"), + rawRow("1.4520622911234E9"), + rawRow("1.45206229112345E9"), + rawRow("1.452062291123456E9")); + onTableGet(basicTableSchemaWithTime()); + onTableList(input); + + // Known results verified from BigQuery's export to JSON on GCS API. + List expected = ImmutableList.of( + "2015-04-30 12:34:56.789 UTC", + "2016-01-06 06:38:00 UTC", + "2016-01-06 06:38:11 UTC", + "2016-01-06 06:38:11.1 UTC", + "2016-01-06 06:38:11.12 UTC", + "2016-01-06 06:38:11.123 UTC", + "2016-01-06 06:38:11.1234 UTC", + "2016-01-06 06:38:11.12345 UTC", + "2016-01-06 06:38:11.123456 UTC"); + + // Download the rows, verify the interactions. + List rows = new ArrayList<>(); + try (BigQueryTableRowIterator iterator = + BigQueryTableRowIterator.fromTable( + BigQueryIO.parseTableSpec("project:dataset.table"), mockClient)) { + iterator.open(); + while (iterator.advance()) { + rows.add(iterator.getCurrent()); + } + } + verifyTableGet(); + verifyTabledataList(); + + // Verify the timestamp converted as desired. + assertEquals("Expected input and output rows to have the same size", + expected.size(), rows.size()); + for (int i = 0; i < expected.size(); ++i) { + assertEquals("i=" + i, expected.get(i), rows.get(i).get("time")); + } + + } + + private TableRow rawRow(Object...args) { + List cells = new LinkedList<>(); + for (Object a : args) { + cells.add(new TableCell().setV(a)); + } + return new TableRow().setF(cells); + } + + private TableDataList rawDataList(TableRow...rows) { + return new TableDataList() + .setRows(Arrays.asList(rows)); + } + + @Test + public void testRead() throws IOException, InterruptedException { + onTableGet(basicTableSchema()); + + TableDataList dataList = rawDataList(rawRow("Arthur", 42)); + onTableList(dataList); + + try (BigQueryTableRowIterator iterator = BigQueryTableRowIterator.fromTable( + BigQueryIO.parseTableSpec("project:dataset.table"), + mockClient)) { + iterator.open(); + Assert.assertTrue(iterator.advance()); + TableRow row = iterator.getCurrent(); + + Assert.assertTrue(row.containsKey("name")); + Assert.assertTrue(row.containsKey("answer")); + Assert.assertEquals("Arthur", row.get("name")); + Assert.assertEquals(42, row.get("answer")); + + Assert.assertFalse(iterator.advance()); + + verifyTableGet(); + verifyTabledataList(); + } + } + + @Test + public void testReadEmpty() throws IOException, InterruptedException { + onTableGet(basicTableSchema()); + + // BigQuery may respond with a page token for an empty table, ensure we + // handle it. 
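+    // The mocked response has a page token but zero total rows; advance() below is expected
+    // to return false instead of following the token.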
+ TableDataList dataList = new TableDataList() + .setPageToken("FEED==") + .setTotalRows(0L); + onTableList(dataList); + + try (BigQueryTableRowIterator iterator = BigQueryTableRowIterator.fromTable( + BigQueryIO.parseTableSpec("project:dataset.table"), + mockClient)) { + iterator.open(); + + Assert.assertFalse(iterator.advance()); + + verifyTableGet(); + verifyTabledataList(); + } + } + + @Test + public void testReadMultiPage() throws IOException, InterruptedException { + onTableGet(basicTableSchema()); + + TableDataList page1 = rawDataList(rawRow("Row1", 1)) + .setPageToken("page2"); + TableDataList page2 = rawDataList(rawRow("Row2", 2)) + .setTotalRows(2L); + + when(mockClient.tabledata()) + .thenReturn(mockTabledata); + when(mockTabledata.list(anyString(), anyString(), anyString())) + .thenReturn(mockTabledataList); + when(mockTabledataList.execute()) + .thenReturn(page1) + .thenReturn(page2); + + try (BigQueryTableRowIterator iterator = BigQueryTableRowIterator.fromTable( + BigQueryIO.parseTableSpec("project:dataset.table"), + mockClient)) { + iterator.open(); + + List names = new LinkedList<>(); + while (iterator.advance()) { + names.add((String) iterator.getCurrent().get("name")); + } + + Assert.assertThat(names, Matchers.hasItems("Row1", "Row2")); + + verifyTableGet(); + verifyTabledataList(); + // The second call should have used a page token. + verify(mockTabledataList).setPageToken("page2"); + } + } + + @Test + public void testReadOpenFailure() throws IOException, InterruptedException { + thrown.expect(IOException.class); + + when(mockClient.tables()) + .thenReturn(mockTables); + when(mockTables.get(anyString(), anyString(), anyString())) + .thenReturn(mockTablesGet); + when(mockTablesGet.execute()) + .thenThrow(new IOException("No such table")); + + try (BigQueryTableRowIterator iterator = BigQueryTableRowIterator.fromTable( + BigQueryIO.parseTableSpec("project:dataset.table"), + mockClient)) { + try { + iterator.open(); // throws. 
+ } finally { + verifyTableGet(); + } + } + } + + @Test + public void testWriteAppend() throws IOException { + onTableGet(basicTableSchema()); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + + BigQueryTableInserter inserter = new BigQueryTableInserter(mockClient); + + inserter.getOrCreateTable(ref, BigQueryIO.Write.WriteDisposition.WRITE_APPEND, + BigQueryIO.Write.CreateDisposition.CREATE_NEVER, null); + + verifyTableGet(); + } + + @Test + public void testWriteEmpty() throws IOException { + onTableGet(basicTableSchema()); + + TableDataList dataList = new TableDataList().setTotalRows(0L); + onTableList(dataList); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + + BigQueryTableInserter inserter = new BigQueryTableInserter(mockClient); + + inserter.getOrCreateTable(ref, BigQueryIO.Write.WriteDisposition.WRITE_EMPTY, + BigQueryIO.Write.CreateDisposition.CREATE_NEVER, null); + + verifyTableGet(); + verifyTabledataList(); + } + + @Test + public void testWriteEmptyFail() throws IOException { + thrown.expect(IOException.class); + + onTableGet(basicTableSchema()); + + TableDataList dataList = rawDataList(rawRow("Arthur", 42)); + onTableList(dataList); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + + BigQueryTableInserter inserter = new BigQueryTableInserter(mockClient); + + try { + inserter.getOrCreateTable(ref, BigQueryIO.Write.WriteDisposition.WRITE_EMPTY, + BigQueryIO.Write.CreateDisposition.CREATE_NEVER, null); + } finally { + verifyTableGet(); + verifyTabledataList(); + } + } + + @Test + public void testInsertAll() throws Exception, IOException { + // Build up a list of indices to fail on each invocation. This should result in + // 5 calls to insertAll. 
+ List> errorsIndices = new ArrayList<>(); + errorsIndices.add(Arrays.asList(0L, 5L, 10L, 15L, 20L)); + errorsIndices.add(Arrays.asList(0L, 2L, 4L)); + errorsIndices.add(Arrays.asList(0L, 2L)); + errorsIndices.add(new ArrayList()); + onInsertAll(errorsIndices); + + TableReference ref = BigQueryIO + .parseTableSpec("project:dataset.table"); + BigQueryTableInserter inserter = new BigQueryTableInserter(mockClient, 5); + + List rows = new ArrayList<>(); + List ids = new ArrayList<>(); + for (int i = 0; i < 25; ++i) { + rows.add(rawRow("foo", 1234)); + ids.add(new String()); + } + + InMemoryLongSumAggregator byteCountAggregator = new InMemoryLongSumAggregator("ByteCount"); + try { + inserter.insertAll(ref, rows, ids, byteCountAggregator); + } finally { + verifyInsertAll(5); + // Each of the 25 rows is 23 bytes: "{f=[{v=foo}, {v=1234}]}" + assertEquals("Incorrect byte count", 25L * 23L, byteCountAggregator.getSum()); + } + } + + private static class InMemoryLongSumAggregator implements Aggregator { + private final String name; + private long sum = 0; + + public InMemoryLongSumAggregator(String name) { + this.name = name; + } + + @Override + public void addValue(Long value) { + sum += value; + } + + @Override + public String getName() { + return name; + } + + @Override + public CombineFn getCombineFn() { + return new Sum.SumLongFn(); + } + + public long getSum() { + return sum; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BufferedElementCountingOutputStreamTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BufferedElementCountingOutputStreamTest.java new file mode 100644 index 000000000000..af2f4425507f --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BufferedElementCountingOutputStreamTest.java @@ -0,0 +1,205 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.collection.IsIterableContainingInOrder; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +/** + * Tests for {@link BufferedElementCountingOutputStream}. 
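+ * The expected encoding, exercised by the tests below, is a sequence of var-int element counts,
+ * each followed by the bytes of that many elements, with a zero count marking the end of the
+ * stream.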
+ */ +@RunWith(JUnit4.class) +public class BufferedElementCountingOutputStreamTest { + @Rule public final ExpectedException expectedException = ExpectedException.none(); + private static final int BUFFER_SIZE = 8; + + @Test + public void testEmptyValues() throws Exception { + testValues(Collections.emptyList()); + } + + @Test + public void testSingleValue() throws Exception { + testValues(toBytes("abc")); + } + + @Test + public void testSingleValueGreaterThanBuffer() throws Exception { + testValues(toBytes("abcdefghijklmnopqrstuvwxyz")); + } + + @Test + public void testMultipleValuesLessThanBuffer() throws Exception { + testValues(toBytes("a", "b", "c")); + } + + @Test + public void testMultipleValuesThatBecomeGreaterThanBuffer() throws Exception { + testValues(toBytes("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", + "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z")); + } + + @Test + public void testMultipleRandomSizedValues() throws Exception { + Random r = new Random(234589023580234890L); + byte[] randomData = new byte[r.nextInt(18)]; + for (int i = 0; i < 1000; ++i) { + List bytes = new ArrayList<>(); + for (int j = 0; j < 100; ++j) { + r.nextBytes(randomData); + bytes.add(randomData); + } + testValues(bytes); + } + } + + @Test + public void testFlushInMiddleOfElement() throws Exception { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + BufferedElementCountingOutputStream os = new BufferedElementCountingOutputStream(bos); + os.markElementStart(); + os.write(1); + os.flush(); + os.write(2); + os.close(); + assertArrayEquals(new byte[]{ 1, 1, 2, 0 }, bos.toByteArray()); + } + + @Test + public void testFlushInMiddleOfElementUsingByteArrays() throws Exception { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + BufferedElementCountingOutputStream os = new BufferedElementCountingOutputStream(bos); + os.markElementStart(); + os.write(new byte[]{ 1 }); + os.flush(); + os.write(new byte[]{ 2 }); + os.close(); + assertArrayEquals(new byte[]{ 1, 1, 2, 0 }, bos.toByteArray()); + } + + @Test + public void testFlushingWhenFinishedIsNoOp() throws Exception { + BufferedElementCountingOutputStream os = testValues(toBytes("a")); + os.flush(); + os.flush(); + os.flush(); + } + + @Test + public void testFinishingWhenFinishedIsNoOp() throws Exception { + BufferedElementCountingOutputStream os = testValues(toBytes("a")); + os.finish(); + os.finish(); + os.finish(); + } + + @Test + public void testClosingFinishesTheStream() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + BufferedElementCountingOutputStream os = createAndWriteValues(toBytes("abcdefghij"), baos); + os.close(); + verifyValues(toBytes("abcdefghij"), new ByteArrayInputStream(baos.toByteArray())); + } + + @Test + public void testAddingElementWhenFinishedThrows() throws Exception { + expectedException.expect(IOException.class); + expectedException.expectMessage("Stream has been finished."); + testValues(toBytes("a")).markElementStart(); + } + + @Test + public void testWritingByteWhenFinishedThrows() throws Exception { + expectedException.expect(IOException.class); + expectedException.expectMessage("Stream has been finished."); + testValues(toBytes("a")).write(1); + } + + @Test + public void testWritingBytesWhenFinishedThrows() throws Exception { + expectedException.expect(IOException.class); + expectedException.expectMessage("Stream has been finished."); + testValues(toBytes("a")).write("b".getBytes()); + } + + private List toBytes(String ... 
values) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (String value : values) { + builder.add(value.getBytes()); + } + return builder.build(); + } + + private BufferedElementCountingOutputStream + testValues(List expectedValues) throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + BufferedElementCountingOutputStream os = createAndWriteValues(expectedValues, baos); + os.finish(); + verifyValues(expectedValues, new ByteArrayInputStream(baos.toByteArray())); + return os; + } + + private void verifyValues(List expectedValues, InputStream is) throws Exception { + List values = new ArrayList<>(); + long count; + do { + count = VarInt.decodeLong(is); + for (int i = 0; i < count; ++i) { + values.add(ByteArrayCoder.of().decode(is, Context.NESTED)); + } + } while(count > 0); + + if (expectedValues.isEmpty()) { + assertTrue(values.isEmpty()); + } else { + assertThat(values, IsIterableContainingInOrder.contains(expectedValues.toArray())); + } + } + + private BufferedElementCountingOutputStream + createAndWriteValues(List values, OutputStream output) throws Exception { + BufferedElementCountingOutputStream os = + new BufferedElementCountingOutputStream(output, BUFFER_SIZE); + + for (byte[] value : values) { + os.markElementStart(); + ByteArrayCoder.of().encode(value, os, Context.NESTED); + } + return os; + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CoderUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CoderUtilsTest.java new file mode 100644 index 000000000000..e192f456bca6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CoderUtilsTest.java @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.CoderUtils.makeCloudEncoding; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.testing.CoderPropertiesTest.ClosingCoder; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * Tests for CoderUtils. 
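+ * Covers deserializing cloud encodings for atomic, composite, and untyped ("kind:stream",
+ * "kind:pair") coders, plus checks that coders may not close streams they do not own.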
+ */ +@RunWith(JUnit4.class) +public class CoderUtilsTest { + + @Rule + public transient ExpectedException expectedException = ExpectedException.none(); + + static class TestCoder extends AtomicCoder { + public static TestCoder of() { + return new TestCoder(); + } + + @Override + public void encode(Integer value, OutputStream outStream, Context context) { + throw new RuntimeException("not expecting to be called"); + } + + @Override + public Integer decode(InputStream inStream, Context context) { + throw new RuntimeException("not expecting to be called"); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + throw new NonDeterministicException(this, + "TestCoder does not actually encode or decode."); + } + } + + @Test + public void testCoderExceptionPropagation() throws Exception { + @SuppressWarnings("unchecked") + Coder crashingCoder = mock(Coder.class); + doThrow(new CoderException("testing exception")) + .when(crashingCoder) + .encode(anyString(), any(OutputStream.class), any(Coder.Context.class)); + + expectedException.expect(CoderException.class); + expectedException.expectMessage("testing exception"); + + CoderUtils.encodeToByteArray(crashingCoder, "hello"); + } + + @Test + public void testCreateAtomicCoders() throws Exception { + Assert.assertEquals( + BigEndianIntegerCoder.of(), + Serializer.deserialize(makeCloudEncoding("BigEndianIntegerCoder"), Coder.class)); + Assert.assertEquals( + StringUtf8Coder.of(), + Serializer.deserialize( + makeCloudEncoding(StringUtf8Coder.class.getName()), Coder.class)); + Assert.assertEquals( + VoidCoder.of(), + Serializer.deserialize(makeCloudEncoding("VoidCoder"), Coder.class)); + Assert.assertEquals( + TestCoder.of(), + Serializer.deserialize(makeCloudEncoding(TestCoder.class.getName()), Coder.class)); + } + + @Test + public void testCreateCompositeCoders() throws Exception { + Assert.assertEquals( + IterableCoder.of(StringUtf8Coder.of()), + Serializer.deserialize( + makeCloudEncoding("IterableCoder", + makeCloudEncoding("StringUtf8Coder")), Coder.class)); + Assert.assertEquals( + KvCoder.of(BigEndianIntegerCoder.of(), VoidCoder.of()), + Serializer.deserialize( + makeCloudEncoding( + "KvCoder", + makeCloudEncoding(BigEndianIntegerCoder.class.getName()), + makeCloudEncoding("VoidCoder")), Coder.class)); + Assert.assertEquals( + IterableCoder.of( + KvCoder.of(IterableCoder.of(BigEndianIntegerCoder.of()), + KvCoder.of(VoidCoder.of(), + TestCoder.of()))), + Serializer.deserialize( + makeCloudEncoding( + IterableCoder.class.getName(), + makeCloudEncoding( + KvCoder.class.getName(), + makeCloudEncoding( + "IterableCoder", + makeCloudEncoding("BigEndianIntegerCoder")), + makeCloudEncoding( + "KvCoder", + makeCloudEncoding("VoidCoder"), + makeCloudEncoding(TestCoder.class.getName())))), Coder.class)); + } + + @Test + public void testCreateUntypedCoders() throws Exception { + Assert.assertEquals( + IterableCoder.of(StringUtf8Coder.of()), + Serializer.deserialize( + makeCloudEncoding( + "kind:stream", + makeCloudEncoding("StringUtf8Coder")), Coder.class)); + Assert.assertEquals( + KvCoder.of(BigEndianIntegerCoder.of(), VoidCoder.of()), + Serializer.deserialize( + makeCloudEncoding( + "kind:pair", + makeCloudEncoding(BigEndianIntegerCoder.class.getName()), + makeCloudEncoding("VoidCoder")), Coder.class)); + Assert.assertEquals( + IterableCoder.of( + KvCoder.of(IterableCoder.of(BigEndianIntegerCoder.of()), + KvCoder.of(VoidCoder.of(), + TestCoder.of()))), + Serializer.deserialize( + makeCloudEncoding( + "kind:stream", + 
makeCloudEncoding( + "kind:pair", + makeCloudEncoding( + "kind:stream", + makeCloudEncoding("BigEndianIntegerCoder")), + makeCloudEncoding( + "kind:pair", + makeCloudEncoding("VoidCoder"), + makeCloudEncoding(TestCoder.class.getName())))), Coder.class)); + } + + @Test + public void testCreateUnknownCoder() throws Exception { + try { + Serializer.deserialize(makeCloudEncoding("UnknownCoder"), Coder.class); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + CoreMatchers.containsString( + "Unable to convert coder ID UnknownCoder to class")); + } + } + + @Test + public void testClosingCoderFailsWhenDecodingBase64() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderUtils.decodeFromBase64(new ClosingCoder(), "test-value"); + } + + @Test + public void testClosingCoderFailsWhenDecodingByteArray() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderUtils.decodeFromByteArray(new ClosingCoder(), new byte[0]); + } + + @Test + public void testClosingCoderFailsWhenDecodingByteArrayInContext() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderUtils.decodeFromByteArray(new ClosingCoder(), new byte[0], Context.NESTED); + } + + @Test + public void testClosingCoderFailsWhenEncodingToBase64() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderUtils.encodeToBase64(new ClosingCoder(), "test-value"); + } + + @Test + public void testClosingCoderFailsWhenEncodingToByteArray() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderUtils.encodeToByteArray(new ClosingCoder(), "test-value"); + } + + @Test + public void testClosingCoderFailsWhenEncodingToByteArrayInContext() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + CoderUtils.encodeToByteArray(new ClosingCoder(), "test-value", Context.NESTED); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java new file mode 100644 index 000000000000..978ee50f440b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.withSettings; + +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.state.StateContexts; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.NotSerializableException; +import java.io.ObjectOutputStream; + +/** + * Unit tests for {@link CombineFnUtil}. + */ +@RunWith(JUnit4.class) +public class CombineFnUtilTest { + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + KeyedCombineFnWithContext mockCombineFn; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + mockCombineFn = mock(KeyedCombineFnWithContext.class, withSettings().serializable()); + } + + @Test + public void testNonSerializable() throws Exception { + expectedException.expect(NotSerializableException.class); + expectedException.expectMessage( + "Cannot serialize the CombineFn resulting from CombineFnUtil.bindContext."); + + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(buffer); + oos.writeObject(CombineFnUtil.bindContext(mockCombineFn, StateContexts.nullContext())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CounterAggregatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CounterAggregatorTest.java new file mode 100644 index 000000000000..4cc2f6a30ef4 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CounterAggregatorTest.java @@ -0,0 +1,253 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MIN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.IterableCombineFn; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterProvider; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.common.CounterSet.AddCounterMutator; +import com.google.common.collect.Iterables; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for the {@link Aggregator} API. + */ +@RunWith(JUnit4.class) +public class CounterAggregatorTest { + @Rule + public final ExpectedException expectedEx = ExpectedException.none(); + + private static final String AGGREGATOR_NAME = "aggregator_name"; + + @SuppressWarnings("rawtypes") + private void testAggregator(List items, + Combine.CombineFn combiner, + Counter expectedCounter) { + CounterSet counters = new CounterSet(); + Aggregator aggregator = new CounterAggregator<>( + AGGREGATOR_NAME, combiner, counters.getAddCounterMutator()); + for (V item : items) { + aggregator.addValue(item); + } + + assertEquals(Iterables.getOnlyElement(counters), expectedCounter); + } + + @Test + public void testGetName() { + String name = "testAgg"; + CounterAggregator aggregator = new CounterAggregator<>( + name, new Sum.SumLongFn(), + new CounterSet().getAddCounterMutator()); + + assertEquals(name, aggregator.getName()); + } + + @Test + public void testGetCombineFn() { + CombineFn combineFn = new Min.MinLongFn(); + + CounterAggregator aggregator = new CounterAggregator<>("foo", + combineFn, new CounterSet().getAddCounterMutator()); + + assertEquals(combineFn, aggregator.getCombineFn()); + } + + @Test + + public void testSumInteger() throws Exception { + testAggregator(Arrays.asList(2, 4, 1, 3), new Sum.SumIntegerFn(), + Counter.ints(AGGREGATOR_NAME, SUM).resetToValue(10)); + } + + @Test + public void testSumLong() throws Exception { + testAggregator(Arrays.asList(2L, 4L, 1L, 3L), new Sum.SumLongFn(), + Counter.longs(AGGREGATOR_NAME, SUM).resetToValue(10L)); + } + + @Test + public void testSumDouble() throws Exception { + testAggregator(Arrays.asList(2.0, 4.1, 1.0, 3.1), new Sum.SumDoubleFn(), + Counter.doubles(AGGREGATOR_NAME, SUM).resetToValue(10.2)); + } + + @Test + public void testMinInteger() throws Exception { + 
testAggregator(Arrays.asList(2, 4, 1, 3), new Min.MinIntegerFn(), + Counter.ints(AGGREGATOR_NAME, MIN).resetToValue(1)); + } + + @Test + public void testMinLong() throws Exception { + testAggregator(Arrays.asList(2L, 4L, 1L, 3L), new Min.MinLongFn(), + Counter.longs(AGGREGATOR_NAME, MIN).resetToValue(1L)); + } + + @Test + public void testMinDouble() throws Exception { + testAggregator(Arrays.asList(2.0, 4.1, 1.0, 3.1), new Min.MinDoubleFn(), + Counter.doubles(AGGREGATOR_NAME, MIN).resetToValue(1.0)); + } + + @Test + public void testMaxInteger() throws Exception { + testAggregator(Arrays.asList(2, 4, 1, 3), new Max.MaxIntegerFn(), + Counter.ints(AGGREGATOR_NAME, MAX).resetToValue(4)); + } + + @Test + public void testMaxLong() throws Exception { + testAggregator(Arrays.asList(2L, 4L, 1L, 3L), new Max.MaxLongFn(), + Counter.longs(AGGREGATOR_NAME, MAX).resetToValue(4L)); + } + + @Test + public void testMaxDouble() throws Exception { + testAggregator(Arrays.asList(2.0, 4.1, 1.0, 3.1), new Max.MaxDoubleFn(), + Counter.doubles(AGGREGATOR_NAME, MAX).resetToValue(4.1)); + } + + @Test + public void testCounterProviderCallsProvidedCounterAddValue() { + @SuppressWarnings("unchecked") + CombineFn combiner = mock(CombineFn.class, + withSettings().extraInterfaces(CounterProvider.class)); + @SuppressWarnings("unchecked") + CounterProvider provider = (CounterProvider) combiner; + + @SuppressWarnings("unchecked") + Counter mockCounter = mock(Counter.class); + String name = "foo"; + when(provider.getCounter(name)).thenReturn(mockCounter); + + AddCounterMutator addCounterMutator = mock(AddCounterMutator.class); + when(addCounterMutator.addCounter(mockCounter)).thenReturn(mockCounter); + + Aggregator aggregator = + new CounterAggregator<>(name, combiner, addCounterMutator); + + aggregator.addValue("bar_baz"); + + verify(mockCounter).addValue("bar_baz"); + verify(addCounterMutator).addCounter(mockCounter); + } + + + @Test + public void testCompatibleDuplicateNames() throws Exception { + CounterSet counters = new CounterSet(); + Aggregator aggregator1 = new CounterAggregator<>( + AGGREGATOR_NAME, new Sum.SumIntegerFn(), + counters.getAddCounterMutator()); + + Aggregator aggregator2 = new CounterAggregator<>( + AGGREGATOR_NAME, new Sum.SumIntegerFn(), + counters.getAddCounterMutator()); + + // The duplicate aggregators should update the same counter. 
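+    // 3 and 4 added through different aggregator instances should show up as a single SUM
+    // counter totaling 7.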
+ aggregator1.addValue(3); + aggregator2.addValue(4); + Assert.assertEquals( + new CounterSet(Counter.ints(AGGREGATOR_NAME, SUM).resetToValue(7)), + counters); + } + + @Test + public void testIncompatibleDuplicateNames() throws Exception { + CounterSet counters = new CounterSet(); + new CounterAggregator<>( + AGGREGATOR_NAME, new Sum.SumIntegerFn(), + counters.getAddCounterMutator()); + + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString( + "aggregator's name collides with an existing aggregator or " + + "system-provided counter of an incompatible type")); + new CounterAggregator<>( + AGGREGATOR_NAME, new Sum.SumLongFn(), + counters.getAddCounterMutator()); + } + + @Test + public void testUnsupportedCombineFn() throws Exception { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString("unsupported combiner")); + new CounterAggregator<>( + AGGREGATOR_NAME, + new Combine.CombineFn, Integer>() { + @Override + public List createAccumulator() { + return null; + } + @Override + public List addInput(List accumulator, Integer input) { + return null; + } + @Override + public List mergeAccumulators(Iterable> accumulators) { + return null; + } + @Override + public Integer extractOutput(List accumulator) { + return null; + } + }, (new CounterSet()).getAddCounterMutator()); + } + + @Test + public void testUnsupportedSerializableFunction() throws Exception { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage(Matchers.containsString("unsupported combiner")); + CombineFn, Integer> combiner = IterableCombineFn + .of(new SerializableFunction, Integer>() { + @Override + public Integer apply(Iterable input) { + return null; + } + }); + new CounterAggregator<>(AGGREGATOR_NAME, combiner, + (new CounterSet()).getAddCounterMutator()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/DataflowPathValidatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/DataflowPathValidatorTest.java new file mode 100644 index 000000000000..19d5adc7cffa --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/DataflowPathValidatorTest.java @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link DataflowPathValidator}. */ +@RunWith(JUnit4.class) +public class DataflowPathValidatorTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Mock private GcsUtil mockGcsUtil; + private DataflowPathValidator validator; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + when(mockGcsUtil.bucketExists(any(GcsPath.class))).thenReturn(true); + when(mockGcsUtil.isGcsPatternSupported(anyString())).thenCallRealMethod(); + DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class); + options.setGcpCredential(new TestCredential()); + options.setRunner(DataflowPipelineRunner.class); + options.setGcsUtil(mockGcsUtil); + validator = new DataflowPathValidator(options); + } + + @Test + public void testValidFilePattern() { + validator.validateInputFilePatternSupported("gs://bucket/path"); + } + + @Test + public void testInvalidFilePattern() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "DataflowPipelineRunner expected a valid 'gs://' path but was given '/local/path'"); + validator.validateInputFilePatternSupported("/local/path"); + } + + @Test + public void testWhenBucketDoesNotExist() throws Exception { + when(mockGcsUtil.bucketExists(any(GcsPath.class))).thenReturn(false); + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "Could not find file gs://non-existent-bucket/location"); + validator.validateInputFilePatternSupported("gs://non-existent-bucket/location"); + } + + @Test + public void testValidOutputPrefix() { + validator.validateOutputFilePrefixSupported("gs://bucket/path"); + } + + @Test + public void testInvalidOutputPrefix() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage( + "DataflowPipelineRunner expected a valid 'gs://' path but was given '/local/path'"); + validator.validateOutputFilePrefixSupported("/local/path"); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExecutableTriggerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExecutableTriggerTest.java new file mode 100644 index 000000000000..7b8466a5faed --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExecutableTriggerTest.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Tests for {@link ExecutableTrigger}. + */ +@RunWith(JUnit4.class) +public class ExecutableTriggerTest { + + @Test + public void testIndexAssignmentLeaf() throws Exception { + StubTrigger t1 = new StubTrigger(); + ExecutableTrigger executable = ExecutableTrigger.create(t1); + assertEquals(0, executable.getTriggerIndex()); + } + + @Test + public void testIndexAssignmentOneLevel() throws Exception { + StubTrigger t1 = new StubTrigger(); + StubTrigger t2 = new StubTrigger(); + StubTrigger t = new StubTrigger(t1, t2); + + ExecutableTrigger executable = ExecutableTrigger.create(t); + + assertEquals(0, executable.getTriggerIndex()); + assertEquals(1, executable.subTriggers().get(0).getTriggerIndex()); + assertSame(t1, executable.subTriggers().get(0).getSpec()); + assertEquals(2, executable.subTriggers().get(1).getTriggerIndex()); + assertSame(t2, executable.subTriggers().get(1).getSpec()); + } + + @Test + public void testIndexAssignmentTwoLevel() throws Exception { + StubTrigger t11 = new StubTrigger(); + StubTrigger t12 = new StubTrigger(); + StubTrigger t13 = new StubTrigger(); + StubTrigger t14 = new StubTrigger(); + StubTrigger t21 = new StubTrigger(); + StubTrigger t22 = new StubTrigger(); + StubTrigger t1 = new StubTrigger(t11, t12, t13, t14); + StubTrigger t2 = new StubTrigger(t21, t22); + StubTrigger t = new StubTrigger(t1, t2); + + ExecutableTrigger executable = ExecutableTrigger.create(t); + + assertEquals(0, executable.getTriggerIndex()); + assertEquals(1, executable.subTriggers().get(0).getTriggerIndex()); + assertEquals(6, executable.subTriggers().get(0).getFirstIndexAfterSubtree()); + assertEquals(6, executable.subTriggers().get(1).getTriggerIndex()); + + assertSame(t1, executable.getSubTriggerContaining(1).getSpec()); + assertSame(t2, executable.getSubTriggerContaining(6).getSpec()); + assertSame(t1, executable.getSubTriggerContaining(2).getSpec()); + assertSame(t1, executable.getSubTriggerContaining(3).getSpec()); + assertSame(t1, executable.getSubTriggerContaining(5).getSpec()); + assertSame(t2, executable.getSubTriggerContaining(7).getSpec()); + } + + private static class StubTrigger extends Trigger { + + @SafeVarargs + protected StubTrigger(Trigger... 
subTriggers) { + super(Arrays.asList(subTriggers)); + } + + @Override + public void onElement(OnElementContext c) throws Exception { } + + @Override + public void onMerge(OnMergeContext c) throws Exception { } + + @Override + public void clear(TriggerContext c) throws Exception { + } + + @Override + public Instant getWatermarkThatGuaranteesFiring(IntervalWindow window) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } + + @Override + public boolean isCompatible(Trigger other) { + return false; + } + + @Override + public Trigger getContinuationTrigger( + List> continuationTriggers) { + return this; + } + + @Override + public boolean shouldFire(TriggerContext c) { + return false; + } + + @Override + public void onFire(TriggerContext c) { } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayInputStreamTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayInputStreamTest.java new file mode 100644 index 000000000000..00830e96cdb1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayInputStreamTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertSame; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.IOException; + +/** Unit tests for {@link ExposedByteArrayInputStream}. 
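+ *
+ * <p>The central property exercised here is that {@code readAll()} on a freshly constructed
+ * stream can hand back the backing array itself rather than a copy (note the {@code assertSame}
+ * in {@code testReadAll}). A rough usage sketch, with the behaviour inferred from those
+ * assertions rather than from the implementation:
+ * <pre>{@code
+ * byte[] data = "Hello World!".getBytes();
+ * ExposedByteArrayInputStream in = new ExposedByteArrayInputStream(data);
+ * byte[] all = in.readAll();  // expected to be the very same array instance as 'data'
+ * }</pre>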
*/ +@RunWith(JUnit4.class) +public class ExposedByteArrayInputStreamTest { + + private static final byte[] TEST_DATA = "Hello World!".getBytes(); + + private ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA); + + private ExposedByteArrayInputStream exposedStream = new ExposedByteArrayInputStream(TEST_DATA); + + @Test + public void testConstructWithEmptyArray() throws IOException { + try (ExposedByteArrayInputStream s = new ExposedByteArrayInputStream(new byte[0])) { + assertEquals(0, s.available()); + byte[] data = s.readAll(); + assertEquals(0, data.length); + } + } + + @Test + public void testReadAll() throws IOException { + assertEquals(TEST_DATA.length, exposedStream.available()); + byte[] data = exposedStream.readAll(); + assertArrayEquals(TEST_DATA, data); + assertSame(TEST_DATA, data); + assertEquals(0, exposedStream.available()); + } + + @Test + public void testReadPartial() throws IOException { + assertEquals(TEST_DATA.length, exposedStream.available()); + assertEquals(TEST_DATA.length, stream.available()); + byte[] data1 = new byte[4]; + byte[] data2 = new byte[4]; + int ret1 = exposedStream.read(data1); + int ret2 = stream.read(data2); + assertEquals(ret2, ret1); + assertArrayEquals(data2, data1); + assertEquals(stream.available(), exposedStream.available()); + } + + @Test + public void testReadAllAfterReadPartial() throws IOException { + assertNotEquals(-1, exposedStream.read()); + byte[] ret = exposedStream.readAll(); + assertArrayEquals("ello World!".getBytes(), ret); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayOutputStreamTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayOutputStreamTest.java new file mode 100644 index 000000000000..f40f0508dbb5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ExposedByteArrayOutputStreamTest.java @@ -0,0 +1,245 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +/** Unit tests for {@link ExposedByteArrayOutputStream}. 
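+ *
+ * <p>Every test writes the same data to an {@code ExposedByteArrayOutputStream} and to a plain
+ * {@code ByteArrayOutputStream} and compares the results. The distinguishing call is
+ * {@code writeAndOwn(byte[])}: the {@code assertSame} checks below suggest that when a single
+ * array is handed over this way, {@code toByteArray()} may return that exact array instead of a
+ * copy. A rough sketch of the intended usage:
+ * <pre>{@code
+ * byte[] data = "Hello World!".getBytes();
+ * ExposedByteArrayOutputStream out = new ExposedByteArrayOutputStream();
+ * out.writeAndOwn(data);             // the stream takes ownership of 'data'
+ * byte[] result = out.toByteArray(); // expected to be the same instance as 'data'
+ * }</pre>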
*/ +@RunWith(JUnit4.class) +public class ExposedByteArrayOutputStreamTest { + + private static final byte[] TEST_DATA = "Hello World!".getBytes(); + + private ExposedByteArrayOutputStream exposedStream = new ExposedByteArrayOutputStream(); + private ByteArrayOutputStream stream = new ByteArrayOutputStream(); + + @Test + public void testNoWrite() { + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteZeroLengthArray() throws IOException { + writeToBoth(new byte[0]); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteZeroLengthArrayWithOffset(){ + writeToBoth(new byte[0], 0, 0); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleByte() { + writeToBoth(32); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleByteTwice() { + writeToBoth(32); + writeToBoth(32); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleArray() throws IOException { + writeToBoth(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + assertNotSame(TEST_DATA, exposedStream.toByteArray()); + } + + @Test + public void testWriteSingleArrayFast() throws IOException { + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + assertSame(TEST_DATA, exposedStream.toByteArray()); + } + + + @Test + public void testWriteSingleArrayTwice() throws IOException { + writeToBoth(TEST_DATA); + writeToBoth(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleArrayTwiceFast() throws IOException { + writeToBothFast(TEST_DATA); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleArrayTwiceFast1() throws IOException { + writeToBothFast(TEST_DATA); + writeToBoth(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleArrayTwiceFast2() throws IOException { + writeToBoth(TEST_DATA); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteSingleArrayWithLength() { + writeToBoth(TEST_DATA, 0, TEST_DATA.length); + assertStreamContentsEquals(stream, exposedStream); + assertNotSame(TEST_DATA, exposedStream.toByteArray()); + } + + @Test + public void testWritePartial() { + writeToBoth(TEST_DATA, 0, TEST_DATA.length - 1); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWritePartialWithNonZeroBegin() { + writeToBoth(TEST_DATA, 1, TEST_DATA.length - 1); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteByteAfterWriteArrayFast() throws IOException { + writeToBothFast(TEST_DATA); + writeToBoth(32); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteArrayFastAfterByte() throws IOException { + writeToBoth(32); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testResetAfterWriteFast() throws IOException { + writeToBothFast(TEST_DATA); + resetBoth(); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteArrayFastAfterReset() throws IOException { + writeToBothFast(TEST_DATA); + resetBoth(); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + assertSame(TEST_DATA, exposedStream.toByteArray()); + } + + @Test + public void 
testWriteArrayFastAfterReset1() throws IOException { + writeToBothFast(TEST_DATA); + writeToBothFast(TEST_DATA); + resetBoth(); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + assertSame(TEST_DATA, exposedStream.toByteArray()); + } + + @Test + public void testWriteArrayFastAfterReset2() throws IOException { + resetBoth(); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + assertSame(TEST_DATA, exposedStream.toByteArray()); + } + + @Test + public void testWriteArrayFastTwiceAfterReset() throws IOException { + writeToBothFast(TEST_DATA); + resetBoth(); + writeToBothFast(TEST_DATA); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteArrayFastTwiceAfterReset1() throws IOException { + writeToBothFast(TEST_DATA); + writeToBothFast(TEST_DATA); + resetBoth(); + writeToBothFast(TEST_DATA); + writeToBothFast(TEST_DATA); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteByteAfterReset() { + writeToBoth(32); + resetBoth(); + writeToBoth(32); + assertStreamContentsEquals(stream, exposedStream); + } + + @Test + public void testWriteByteAfterReset1() { + resetBoth(); + writeToBoth(32); + assertStreamContentsEquals(stream, exposedStream); + } + + private void assertStreamContentsEquals( + ByteArrayOutputStream stream1, ByteArrayOutputStream stream2) { + assertArrayEquals(stream1.toByteArray(), stream2.toByteArray()); + assertEquals(stream1.toString(), stream2.toString()); + assertEquals(stream1.size(), stream2.size()); + } + + private void writeToBoth(int b) { + exposedStream.write(b); + stream.write(b); + } + + private void writeToBoth(byte[] b) throws IOException { + exposedStream.write(b); + stream.write(b); + } + + private void writeToBothFast(byte[] b) throws IOException { + exposedStream.writeAndOwn(b); + stream.write(b); + } + + private void writeToBoth(byte[] b, int off, int length) { + exposedStream.write(b, off, length); + stream.write(b, off, length); + } + + private void resetBoth() { + exposedStream.reset(); + stream.reset(); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactoryTest.java new file mode 100644 index 000000000000..80ec3a130883 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactoryTest.java @@ -0,0 +1,226 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import com.google.common.io.Files; +import com.google.common.io.LineReader; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.Reader; +import java.io.Writer; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.List; + +/** Tests for {@link FileIOChannelFactory}. */ +@RunWith(JUnit4.class) +public class FileIOChannelFactoryTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + private FileIOChannelFactory factory = new FileIOChannelFactory(); + + private void testCreate(Path path) throws Exception { + String expected = "my test string"; + try (Writer writer = Channels.newWriter( + factory.create(path.toString(), MimeTypes.TEXT), StandardCharsets.UTF_8.name())) { + writer.write(expected); + } + assertThat( + Files.readLines(path.toFile(), StandardCharsets.UTF_8), + containsInAnyOrder(expected)); + } + + @Test + public void testCreateWithExistingFile() throws Exception { + File existingFile = temporaryFolder.newFile(); + testCreate(existingFile.toPath()); + } + + @Test + public void testCreateWithinExistingDirectory() throws Exception { + testCreate(temporaryFolder.getRoot().toPath().resolve("file.txt")); + } + + @Test + public void testCreateWithNonExistentSubDirectory() throws Exception { + testCreate(temporaryFolder.getRoot().toPath().resolve("non-existent-dir").resolve("file.txt")); + } + + @Test + public void testReadWithExistingFile() throws Exception { + String expected = "my test string"; + File existingFile = temporaryFolder.newFile(); + Files.write(expected, existingFile, StandardCharsets.UTF_8); + String data; + try (Reader reader = + Channels.newReader(factory.open(existingFile.getPath()), StandardCharsets.UTF_8.name())) { + data = new LineReader(reader).readLine(); + } + assertEquals(expected, data); + } + + @Test + public void testReadNonExistentFile() throws Exception { + thrown.expect(FileNotFoundException.class); + factory + .open( + temporaryFolder + .getRoot() + .toPath() + .resolve("non-existent-file.txt") + .toString()) + .close(); + } + + @Test + public void testIsReadSeekEfficient() throws Exception { + assertTrue(factory.isReadSeekEfficient("somePath")); + } + + @Test + public void testMatchExact() throws Exception { + List expected = ImmutableList.of(temporaryFolder.newFile("a").toString()); + temporaryFolder.newFile("aa"); + temporaryFolder.newFile("ab"); + + assertThat(factory.match(temporaryFolder.getRoot().toPath().resolve("a").toString()), + containsInAnyOrder(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testMatchNone() throws Exception { + List expected = ImmutableList.of(); + temporaryFolder.newFile("a"); + temporaryFolder.newFile("aa"); + temporaryFolder.newFile("ab"); + + // Windows doesn't like resolving paths with * in them, so the * is appended after resolve. 
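+    // For illustration only (the real base directory comes from TemporaryFolder at runtime):
+    // resolve() joins the base directory and "b" first, and the wildcard is appended to the
+    // resulting string, e.g. factory.resolve("/tmp/junit-root", "b") + "*" would yield
+    // "/tmp/junit-root/b*". match() then expands that glob to an empty list here, because only
+    // "a", "aa" and "ab" exist under the folder.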
+ assertThat(factory.match(factory.resolve(temporaryFolder.getRoot().getPath(), "b") + "*"), + containsInAnyOrder(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testMatchUsingExplicitPath() throws Exception { + List expected = ImmutableList.of(temporaryFolder.newFile("a").toString()); + temporaryFolder.newFile("aa"); + + assertThat(factory.match(factory.resolve(temporaryFolder.getRoot().getPath(), "a")), + containsInAnyOrder(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testMatchUsingExplicitPathForNonExistentFile() throws Exception { + List expected = ImmutableList.of(); + temporaryFolder.newFile("aa"); + + assertThat(factory.match(factory.resolve(temporaryFolder.getRoot().getPath(), "a")), + containsInAnyOrder(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testMatchMultipleWithoutSubdirectoryExpansion() throws Exception { + File unmatchedSubDir = temporaryFolder.newFolder("aaa"); + File unmatchedSubDirFile = File.createTempFile("sub-dir-file", "", unmatchedSubDir); + unmatchedSubDirFile.deleteOnExit(); + List expected = ImmutableList.of(temporaryFolder.newFile("a").toString(), + temporaryFolder.newFile("aa").toString(), temporaryFolder.newFile("ab").toString()); + temporaryFolder.newFile("ba"); + temporaryFolder.newFile("bb"); + + // Windows doesn't like resolving paths with * in them, so the * is appended after resolve. + assertThat(factory.match(factory.resolve(temporaryFolder.getRoot().getPath(), "a") + "*"), + containsInAnyOrder(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testMatchMultipleWithSubdirectoryExpansion() throws Exception { + File matchedSubDir = temporaryFolder.newFolder("a"); + File matchedSubDirFile = File.createTempFile("sub-dir-file", "", matchedSubDir); + matchedSubDirFile.deleteOnExit(); + File unmatchedSubDir = temporaryFolder.newFolder("b"); + File unmatchedSubDirFile = File.createTempFile("sub-dir-file", "", unmatchedSubDir); + unmatchedSubDirFile.deleteOnExit(); + + List expected = ImmutableList.of(matchedSubDirFile.toString(), + temporaryFolder.newFile("aa").toString(), temporaryFolder.newFile("ab").toString()); + temporaryFolder.newFile("ba"); + temporaryFolder.newFile("bb"); + + // Windows doesn't like resolving paths with * in them, so the ** is appended after resolve. + assertThat(factory.match(factory.resolve(temporaryFolder.getRoot().getPath(), "a") + "**"), + Matchers.hasItems(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testMatchWithDirectoryFiltersOutDirectory() throws Exception { + List expected = ImmutableList.of(temporaryFolder.newFile("a").toString()); + temporaryFolder.newFolder("a_dir_that_should_not_be_matched"); + + // Windows doesn't like resolving paths with * in them, so the * is appended after resolve. 
+ assertThat(factory.match(factory.resolve(temporaryFolder.getRoot().getPath(), "a") + "*"), + containsInAnyOrder(expected.toArray(new String[expected.size()]))); + } + + @Test + public void testResolve() throws Exception { + String expected = temporaryFolder.getRoot().toPath().resolve("aa").toString(); + assertEquals(expected, factory.resolve(temporaryFolder.getRoot().toString(), "aa")); + } + + @Test + public void testResolveOtherIsFullPath() throws Exception { + String expected = temporaryFolder.getRoot().getPath().toString(); + assertEquals(expected, factory.resolve(expected, expected)); + } + + @Test + public void testResolveOtherIsEmptyPath() throws Exception { + String expected = temporaryFolder.getRoot().getPath().toString(); + assertEquals(expected, factory.resolve(expected, "")); + } + + @Test + public void testGetSizeBytes() throws Exception { + String data = "TestData!!!"; + File file = temporaryFolder.newFile(); + Files.write(data, file, StandardCharsets.UTF_8); + assertEquals(data.length(), factory.getSizeBytes(file.getPath())); + } + + @Test + public void testGetSizeBytesForNonExistentFile() throws Exception { + thrown.expect(FileNotFoundException.class); + factory.getSizeBytes( + factory.resolve(temporaryFolder.getRoot().getPath(), "non-existent-file")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersBitSetTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersBitSetTest.java new file mode 100644 index 000000000000..7e66683e2602 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersBitSetTest.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.theInstance; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link FinishedTriggersBitSet}. + */ +@RunWith(JUnit4.class) +public class FinishedTriggersBitSetTest { + /** + * Tests that after a trigger is set to finished, it reads back as finished. + */ + @Test + public void testSetGet() { + FinishedTriggersProperties.verifyGetAfterSet(FinishedTriggersBitSet.emptyWithCapacity(1)); + } + + /** + * Tests that clearing a trigger recursively clears all of that triggers subTriggers, but no + * others. 
+ */ + @Test + public void testClearRecursively() { + FinishedTriggersProperties.verifyClearRecursively(FinishedTriggersBitSet.emptyWithCapacity(1)); + } + + @Test + public void testCopy() throws Exception { + FinishedTriggersBitSet finishedSet = FinishedTriggersBitSet.emptyWithCapacity(10); + assertThat(finishedSet.copy().getBitSet(), not(theInstance(finishedSet.getBitSet()))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersProperties.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersProperties.java new file mode 100644 index 000000000000..7b3ac689af1e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersProperties.java @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterAll; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterFirst; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterPane; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; + +/** + * Generalized tests for {@link FinishedTriggers} implementations. + */ +public class FinishedTriggersProperties { + /** + * Tests that for the provided trigger and {@link FinishedTriggers}, when the trigger is set + * finished, it is correctly reported as finished. + */ + public static void verifyGetAfterSet(FinishedTriggers finishedSet, ExecutableTrigger trigger) { + assertFalse(finishedSet.isFinished(trigger)); + finishedSet.setFinished(trigger, true); + assertTrue(finishedSet.isFinished(trigger)); + } + + /** + * For a few arbitrary triggers, tests that when the trigger is set finished it is correctly + * reported as finished. + */ + public static void verifyGetAfterSet(FinishedTriggers finishedSet) { + ExecutableTrigger trigger = ExecutableTrigger.create(AfterAll.of( + AfterFirst.of(AfterPane.elementCountAtLeast(3), AfterWatermark.pastEndOfWindow()), + AfterAll.of( + AfterPane.elementCountAtLeast(10), AfterProcessingTime.pastFirstElementInPane()))); + + verifyGetAfterSet(finishedSet, trigger); + verifyGetAfterSet(finishedSet, trigger.subTriggers().get(0).subTriggers().get(1)); + verifyGetAfterSet(finishedSet, trigger.subTriggers().get(0)); + verifyGetAfterSet(finishedSet, trigger.subTriggers().get(1)); + verifyGetAfterSet(finishedSet, trigger.subTriggers().get(1).subTriggers().get(1)); + verifyGetAfterSet(finishedSet, trigger.subTriggers().get(1).subTriggers().get(0)); + } + + /** + * Tests that clearing a trigger recursively clears all of that triggers subTriggers, but no + * others. 
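+ *
+ * <p>A rough sketch of the trigger tree built below, annotated with the depth-first indices that
+ * {@code ExecutableTrigger.create} is expected to assign (indices inferred from
+ * {@code ExecutableTriggerTest}, not taken from the implementation):
+ * <pre>
+ *   0 AfterAll
+ *   1 +-- AfterFirst
+ *   2 |     +-- AfterPane.elementCountAtLeast(3)
+ *   3 |     +-- AfterWatermark.pastEndOfWindow()
+ *   4 +-- AfterAll
+ *   5       +-- AfterPane.elementCountAtLeast(10)
+ *   6       +-- AfterProcessingTime.pastFirstElementInPane()
+ * </pre>
+ * Clearing the subtree rooted at index 4 should leave indices 0-3 finished and 4-6 unfinished,
+ * which is what the assertions below check.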
+ */ + public static void verifyClearRecursively(FinishedTriggers finishedSet) { + ExecutableTrigger trigger = ExecutableTrigger.create(AfterAll.of( + AfterFirst.of(AfterPane.elementCountAtLeast(3), AfterWatermark.pastEndOfWindow()), + AfterAll.of( + AfterPane.elementCountAtLeast(10), AfterProcessingTime.pastFirstElementInPane()))); + + // Set them all finished. This method is not on a trigger as it makes no sense outside tests. + setFinishedRecursively(finishedSet, trigger); + assertTrue(finishedSet.isFinished(trigger)); + assertTrue(finishedSet.isFinished(trigger.subTriggers().get(0))); + assertTrue(finishedSet.isFinished(trigger.subTriggers().get(0).subTriggers().get(0))); + assertTrue(finishedSet.isFinished(trigger.subTriggers().get(0).subTriggers().get(1))); + + // Clear just the second AfterAll + finishedSet.clearRecursively(trigger.subTriggers().get(1)); + + // Check that the first and all that are still finished + assertTrue(finishedSet.isFinished(trigger)); + verifyFinishedRecursively(finishedSet, trigger.subTriggers().get(0)); + verifyUnfinishedRecursively(finishedSet, trigger.subTriggers().get(1)); + } + + private static void setFinishedRecursively( + FinishedTriggers finishedSet, ExecutableTrigger trigger) { + finishedSet.setFinished(trigger, true); + for (ExecutableTrigger subTrigger : trigger.subTriggers()) { + setFinishedRecursively(finishedSet, subTrigger); + } + } + + private static void verifyFinishedRecursively( + FinishedTriggers finishedSet, ExecutableTrigger trigger) { + assertTrue(finishedSet.isFinished(trigger)); + for (ExecutableTrigger subTrigger : trigger.subTriggers()) { + verifyFinishedRecursively(finishedSet, subTrigger); + } + } + + private static void verifyUnfinishedRecursively( + FinishedTriggers finishedSet, ExecutableTrigger trigger) { + assertFalse(finishedSet.isFinished(trigger)); + for (ExecutableTrigger subTrigger : trigger.subTriggers()) { + verifyUnfinishedRecursively(finishedSet, subTrigger); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersSetTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersSetTest.java new file mode 100644 index 000000000000..60384deebdd8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/FinishedTriggersSetTest.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.theInstance; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.HashSet; + +/** + * Tests for {@link FinishedTriggersSet}. + */ +@RunWith(JUnit4.class) +public class FinishedTriggersSetTest { + /** + * Tests that after a trigger is set to finished, it reads back as finished. 
+ */ + @Test + public void testSetGet() { + FinishedTriggersProperties.verifyGetAfterSet( + FinishedTriggersSet.fromSet(new HashSet>())); + } + + /** + * Tests that clearing a trigger recursively clears all of that triggers subTriggers, but no + * others. + */ + @Test + public void testClearRecursively() { + FinishedTriggersProperties.verifyClearRecursively( + FinishedTriggersSet.fromSet(new HashSet>())); + } + + @Test + public void testCopy() throws Exception { + FinishedTriggersSet finishedSet = + FinishedTriggersSet.fromSet(new HashSet>()); + assertThat(finishedSet.copy().getFinishedTriggers(), + not(theInstance(finishedSet.getFinishedTriggers()))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactoryTest.java new file mode 100644 index 000000000000..6e4605f7de75 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsIOChannelFactoryTest.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link GcsIOChannelFactoryTest}. */ +@RunWith(JUnit4.class) +public class GcsIOChannelFactoryTest { + private GcsIOChannelFactory factory; + + @Before + public void setUp() { + factory = new GcsIOChannelFactory(PipelineOptionsFactory.as(GcsOptions.class)); + } + + @Test + public void testResolve() throws Exception { + assertEquals("gs://bucket/object", factory.resolve("gs://bucket", "object")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsUtilTest.java new file mode 100644 index 000000000000..e7cd7d7c22c8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GcsUtilTest.java @@ -0,0 +1,490 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.when; + +import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpStatusCodes; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.LowLevelHttpRequest; +import com.google.api.client.json.GenericJson; +import com.google.api.client.json.Json; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.testing.http.HttpTesting; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import com.google.api.client.util.BackOff; +import com.google.api.services.storage.Storage; +import com.google.api.services.storage.model.Bucket; +import com.google.api.services.storage.model.Objects; +import com.google.api.services.storage.model.StorageObject; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.FastNanoClockAndSleeper; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.hadoop.gcsio.GoogleCloudStorageReadChannel; +import com.google.cloud.hadoop.util.ClientRequestHelper; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.math.BigInteger; +import java.net.SocketTimeoutException; +import java.nio.channels.SeekableByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +/** Test case for {@link GcsUtil}. 
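+ *
+ * <p>For orientation, the glob-to-regexp translation exercised by {@code testGlobTranslation}
+ * below; the expected values are copied from that test's assertions rather than re-derived:
+ * <pre>{@code
+ * GcsUtil.globToRegexp("foo");         // "foo"
+ * GcsUtil.globToRegexp("fo*o");        // "fo[^/]*o"
+ * GcsUtil.globToRegexp("f*o.?");       // "f[^/]*o\\.[^/]"
+ * GcsUtil.globToRegexp("foo-[0-9]*");  // "foo-[0-9][^/]*"
+ * }</pre>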
*/ +@RunWith(JUnit4.class) +public class GcsUtilTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testGlobTranslation() { + assertEquals("foo", GcsUtil.globToRegexp("foo")); + assertEquals("fo[^/]*o", GcsUtil.globToRegexp("fo*o")); + assertEquals("f[^/]*o\\.[^/]", GcsUtil.globToRegexp("f*o.?")); + assertEquals("foo-[0-9][^/]*", GcsUtil.globToRegexp("foo-[0-9]*")); + } + + private static GcsOptions gcsOptionsWithTestCredential() { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + pipelineOptions.setGcpCredential(new TestCredential()); + return pipelineOptions; + } + + @Test + public void testCreationWithDefaultOptions() { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + assertNotNull(pipelineOptions.getGcpCredential()); + } + + @Test + public void testUploadBufferSizeDefault() { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil util = pipelineOptions.getGcsUtil(); + assertNull(util.getUploadBufferSizeBytes()); + } + + @Test + public void testUploadBufferSizeUserSpecified() { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + pipelineOptions.setGcsUploadBufferSizeBytes(12345); + GcsUtil util = pipelineOptions.getGcsUtil(); + assertEquals((Integer) 12345, util.getUploadBufferSizeBytes()); + } + + @Test + public void testCreationWithExecutorServiceProvided() { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + pipelineOptions.setExecutorService(Executors.newCachedThreadPool()); + assertSame(pipelineOptions.getExecutorService(), pipelineOptions.getGcsUtil().executorService); + } + + @Test + public void testCreationWithGcsUtilProvided() { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + GcsUtil gcsUtil = Mockito.mock(GcsUtil.class); + pipelineOptions.setGcsUtil(gcsUtil); + assertSame(gcsUtil, pipelineOptions.getGcsUtil()); + } + + @Test + public void testMultipleThreadsCanCompleteOutOfOrderWithDefaultThreadPool() throws Exception { + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + ExecutorService executorService = pipelineOptions.getExecutorService(); + + int numThreads = 100; + final CountDownLatch[] countDownLatches = new CountDownLatch[numThreads]; + for (int i = 0; i < numThreads; i++) { + final int currentLatch = i; + countDownLatches[i] = new CountDownLatch(1); + executorService.execute(new Runnable() { + @Override + public void run() { + // Wait for latch N and then release latch N - 1 + try { + countDownLatches[currentLatch].await(); + if (currentLatch > 0) { + countDownLatches[currentLatch - 1].countDown(); + } + } catch (InterruptedException e) { + throw Throwables.propagate(e); + } + } + }); + } + + // Release the last latch starting the chain reaction. 
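+    // Task i blocks on latch i and, once released, releases latch i - 1, so the tasks can only
+    // finish in reverse submission order. That is only possible if the default executor service
+    // runs all 100 tasks concurrently instead of serially on a small fixed pool, which is the
+    // property this test verifies.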
+ countDownLatches[countDownLatches.length - 1].countDown(); + executorService.shutdown(); + assertTrue("Expected tasks to complete", + executorService.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testGlobExpansion() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Objects mockStorageObjects = Mockito.mock(Storage.Objects.class); + Storage.Objects.Get mockStorageGet = Mockito.mock(Storage.Objects.Get.class); + Storage.Objects.List mockStorageList = Mockito.mock(Storage.Objects.List.class); + + Objects modelObjects = new Objects(); + List items = new ArrayList<>(); + // A directory + items.add(new StorageObject().setBucket("testbucket").setName("testdirectory/")); + + // Files within the directory + items.add(new StorageObject().setBucket("testbucket").setName("testdirectory/file1name")); + items.add(new StorageObject().setBucket("testbucket").setName("testdirectory/file2name")); + items.add(new StorageObject().setBucket("testbucket").setName("testdirectory/file3name")); + items.add(new StorageObject().setBucket("testbucket").setName("testdirectory/otherfile")); + items.add(new StorageObject().setBucket("testbucket").setName("testdirectory/anotherfile")); + + modelObjects.setItems(items); + + when(mockStorage.objects()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get("testbucket", "testdirectory/otherfile")).thenReturn( + mockStorageGet); + when(mockStorageObjects.list("testbucket")).thenReturn(mockStorageList); + when(mockStorageGet.execute()).thenReturn( + new StorageObject().setBucket("testbucket").setName("testdirectory/otherfile")); + when(mockStorageList.execute()).thenReturn(modelObjects); + + // Test a single file. + { + GcsPath pattern = GcsPath.fromUri("gs://testbucket/testdirectory/otherfile"); + List expectedFiles = + ImmutableList.of(GcsPath.fromUri("gs://testbucket/testdirectory/otherfile")); + + assertThat(expectedFiles, contains(gcsUtil.expand(pattern).toArray())); + } + + // Test patterns. 
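+    // The mock listing above contains file1name, file2name, file3name, otherfile and
+    // anotherfile, so each of the patterns below ("file*", "file[1-3]*", "file?name" and
+    // "test*ectory/fi*name") is expected to expand to exactly the three file<N>name objects.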
+ { + GcsPath pattern = GcsPath.fromUri("gs://testbucket/testdirectory/file*"); + List expectedFiles = ImmutableList.of( + GcsPath.fromUri("gs://testbucket/testdirectory/file1name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file2name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file3name")); + + assertThat(expectedFiles, contains(gcsUtil.expand(pattern).toArray())); + } + + { + GcsPath pattern = GcsPath.fromUri("gs://testbucket/testdirectory/file[1-3]*"); + List expectedFiles = ImmutableList.of( + GcsPath.fromUri("gs://testbucket/testdirectory/file1name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file2name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file3name")); + + assertThat(expectedFiles, contains(gcsUtil.expand(pattern).toArray())); + } + + { + GcsPath pattern = GcsPath.fromUri("gs://testbucket/testdirectory/file?name"); + List expectedFiles = ImmutableList.of( + GcsPath.fromUri("gs://testbucket/testdirectory/file1name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file2name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file3name")); + + assertThat(expectedFiles, contains(gcsUtil.expand(pattern).toArray())); + } + + { + GcsPath pattern = GcsPath.fromUri("gs://testbucket/test*ectory/fi*name"); + List expectedFiles = ImmutableList.of( + GcsPath.fromUri("gs://testbucket/testdirectory/file1name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file2name"), + GcsPath.fromUri("gs://testbucket/testdirectory/file3name")); + + assertThat(expectedFiles, contains(gcsUtil.expand(pattern).toArray())); + } + } + + // Patterns that contain recursive wildcards ('**') are not supported. + @Test + public void testRecursiveGlobExpansionFails() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + GcsPath pattern = GcsPath.fromUri("gs://testbucket/test**"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Unsupported wildcard usage"); + gcsUtil.expand(pattern); + } + + // GCSUtil.expand() should fail when matching a single object when that object does not exist. + // We should return the empty result since GCS get object is strongly consistent. + @Test + public void testNonExistentObjectReturnsEmptyResult() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Objects mockStorageObjects = Mockito.mock(Storage.Objects.class); + Storage.Objects.Get mockStorageGet = Mockito.mock(Storage.Objects.Get.class); + + GcsPath pattern = GcsPath.fromUri("gs://testbucket/testdirectory/nonexistentfile"); + GoogleJsonResponseException expectedException = + googleJsonResponseException(HttpStatusCodes.STATUS_CODE_NOT_FOUND, + "It don't exist", "Nothing here to see"); + + when(mockStorage.objects()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get(pattern.getBucket(), pattern.getObject())).thenReturn( + mockStorageGet); + when(mockStorageGet.execute()).thenThrow(expectedException); + + assertEquals(Collections.EMPTY_LIST, gcsUtil.expand(pattern)); + } + + // GCSUtil.expand() should fail for other errors such as access denied. 
+ @Test + public void testAccessDeniedObjectThrowsIOException() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Objects mockStorageObjects = Mockito.mock(Storage.Objects.class); + Storage.Objects.Get mockStorageGet = Mockito.mock(Storage.Objects.Get.class); + + GcsPath pattern = GcsPath.fromUri("gs://testbucket/testdirectory/accessdeniedfile"); + GoogleJsonResponseException expectedException = + googleJsonResponseException(HttpStatusCodes.STATUS_CODE_FORBIDDEN, + "Waves hand mysteriously", "These aren't the buckets your looking for"); + + when(mockStorage.objects()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get(pattern.getBucket(), pattern.getObject())).thenReturn( + mockStorageGet); + when(mockStorageGet.execute()).thenThrow(expectedException); + + thrown.expect(IOException.class); + thrown.expectMessage("Unable to match files for pattern"); + gcsUtil.expand(pattern); + } + + @Test + public void testGetSizeBytes() throws Exception { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Objects mockStorageObjects = Mockito.mock(Storage.Objects.class); + Storage.Objects.Get mockStorageGet = Mockito.mock(Storage.Objects.Get.class); + + when(mockStorage.objects()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get("testbucket", "testobject")).thenReturn(mockStorageGet); + when(mockStorageGet.execute()).thenReturn( + new StorageObject().setSize(BigInteger.valueOf(1000))); + + assertEquals(1000, gcsUtil.fileSize(GcsPath.fromComponents("testbucket", "testobject"))); + } + + @Test + public void testGetSizeBytesWhenFileNotFound() throws Exception { + MockLowLevelHttpResponse notFoundResponse = new MockLowLevelHttpResponse(); + notFoundResponse.setContent(""); + notFoundResponse.setStatusCode(HttpStatusCodes.STATUS_CODE_NOT_FOUND); + + MockHttpTransport mockTransport = + new MockHttpTransport.Builder().setLowLevelHttpResponse(notFoundResponse).build(); + + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + gcsUtil.setStorageClient(new Storage(mockTransport, Transport.getJsonFactory(), null)); + + thrown.expect(FileNotFoundException.class); + gcsUtil.fileSize(GcsPath.fromComponents("testbucket", "testobject")); + } + + @Test + public void testRetryFileSize() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Objects mockStorageObjects = Mockito.mock(Storage.Objects.class); + Storage.Objects.Get mockStorageGet = Mockito.mock(Storage.Objects.Get.class); + + BackOff mockBackOff = new AttemptBoundedExponentialBackOff(3, 200); + + when(mockStorage.objects()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get("testbucket", "testobject")).thenReturn(mockStorageGet); + when(mockStorageGet.execute()) + .thenThrow(new SocketTimeoutException("SocketException")) + .thenThrow(new SocketTimeoutException("SocketException")) + .thenReturn(new StorageObject().setSize(BigInteger.valueOf(1000))); + + assertEquals(1000, gcsUtil.fileSize(GcsPath.fromComponents("testbucket", 
"testobject"), + mockBackOff, new FastNanoClockAndSleeper())); + assertEquals(mockBackOff.nextBackOffMillis(), BackOff.STOP); + } + + @Test + public void testBucketExists() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Buckets mockStorageObjects = Mockito.mock(Storage.Buckets.class); + Storage.Buckets.Get mockStorageGet = Mockito.mock(Storage.Buckets.Get.class); + + BackOff mockBackOff = new AttemptBoundedExponentialBackOff(3, 200); + + when(mockStorage.buckets()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get("testbucket")).thenReturn(mockStorageGet); + when(mockStorageGet.execute()) + .thenThrow(new SocketTimeoutException("SocketException")) + .thenReturn(new Bucket()); + + assertTrue(gcsUtil.bucketExists(GcsPath.fromComponents("testbucket", "testobject"), + mockBackOff, new FastNanoClockAndSleeper())); + } + + @Test + public void testBucketDoesNotExistBecauseOfAccessError() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Buckets mockStorageObjects = Mockito.mock(Storage.Buckets.class); + Storage.Buckets.Get mockStorageGet = Mockito.mock(Storage.Buckets.Get.class); + + BackOff mockBackOff = new AttemptBoundedExponentialBackOff(3, 200); + GoogleJsonResponseException expectedException = + googleJsonResponseException(HttpStatusCodes.STATUS_CODE_FORBIDDEN, + "Waves hand mysteriously", "These aren't the buckets your looking for"); + + when(mockStorage.buckets()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get("testbucket")).thenReturn(mockStorageGet); + when(mockStorageGet.execute()) + .thenThrow(expectedException); + + assertFalse(gcsUtil.bucketExists(GcsPath.fromComponents("testbucket", "testobject"), + mockBackOff, new FastNanoClockAndSleeper())); + } + + @Test + public void testBucketDoesNotExist() throws IOException { + GcsOptions pipelineOptions = gcsOptionsWithTestCredential(); + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + + Storage mockStorage = Mockito.mock(Storage.class); + gcsUtil.setStorageClient(mockStorage); + + Storage.Buckets mockStorageObjects = Mockito.mock(Storage.Buckets.class); + Storage.Buckets.Get mockStorageGet = Mockito.mock(Storage.Buckets.Get.class); + + BackOff mockBackOff = new AttemptBoundedExponentialBackOff(3, 200); + + when(mockStorage.buckets()).thenReturn(mockStorageObjects); + when(mockStorageObjects.get("testbucket")).thenReturn(mockStorageGet); + when(mockStorageGet.execute()) + .thenThrow(googleJsonResponseException(HttpStatusCodes.STATUS_CODE_NOT_FOUND, + "It don't exist", "Nothing here to see")); + + assertFalse(gcsUtil.bucketExists(GcsPath.fromComponents("testbucket", "testobject"), + mockBackOff, new FastNanoClockAndSleeper())); + } + + @Test + public void testGCSChannelCloseIdempotent() throws IOException { + SeekableByteChannel channel = + new GoogleCloudStorageReadChannel(null, "dummybucket", "dummyobject", null, + new ClientRequestHelper()); + channel.close(); + channel.close(); + } + + /** + * Builds a fake GoogleJsonResponseException for testing API error handling. 
+ */ + private static GoogleJsonResponseException googleJsonResponseException( + final int status, final String reason, final String message) throws IOException { + final JsonFactory jsonFactory = new JacksonFactory(); + HttpTransport transport = new MockHttpTransport() { + @Override + public LowLevelHttpRequest buildRequest(String method, String url) throws IOException { + ErrorInfo errorInfo = new ErrorInfo(); + errorInfo.setReason(reason); + errorInfo.setMessage(message); + errorInfo.setFactory(jsonFactory); + GenericJson error = new GenericJson(); + error.set("code", status); + error.set("errors", Arrays.asList(errorInfo)); + error.setFactory(jsonFactory); + GenericJson errorResponse = new GenericJson(); + errorResponse.set("error", error); + errorResponse.setFactory(jsonFactory); + return new MockLowLevelHttpRequest().setResponse( + new MockLowLevelHttpResponse().setContent(errorResponse.toPrettyString()) + .setContentType(Json.MEDIA_TYPE).setStatusCode(status)); + } + }; + HttpRequest request = + transport.createRequestFactory().buildGetRequest(HttpTesting.SIMPLE_GENERIC_URL); + request.setThrowExceptionOnExecuteError(false); + HttpResponse response = request.execute(); + return GoogleJsonResponseException.from(jsonFactory, response); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsProperties.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsProperties.java new file mode 100644 index 000000000000..ce06299d61cd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsProperties.java @@ -0,0 +1,718 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.TestUtils.KvMatcher; +import com.google.cloud.dataflow.sdk.WindowMatchers; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Properties of {@link GroupAlsoByWindowsDoFn}. + * + *
<p>
    Some properties may not hold of some implementations, due to restrictions on the context + * in which the implementation is applicable. For example, + * {@link GroupAlsoByWindowsViaIteratorsDoFn} does not support merging window functions. + */ +public class GroupAlsoByWindowsProperties { + + /** + * A factory of {@link GroupAlsoByWindowsDoFn} so that the various properties can provide + * the appropriate windowing strategy under test. + */ + public interface GroupAlsoByWindowsDoFnFactory { + GroupAlsoByWindowsDoFn + forStrategy(WindowingStrategy strategy); + } + + /** + * Tests that for empty input and the given {@link WindowingStrategy}, the provided GABW + * implementation produces no output. + * + *
<p>
    The input type is deliberately left as a wildcard, since it is not relevant. + */ + public static void emptyInputEmptyOutput( + GroupAlsoByWindowsDoFnFactory gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))); + + List result = runGABW( + gabwFactory, + windowingStrategy, + (K) null, // key should never be used + Collections.>emptyList()); + + assertThat(result.size(), equalTo(0)); + } + + /** + * Tests that for a simple sequence of elements on the same key, the given GABW implementation + * correctly groups them according to fixed windows. + */ + public static void groupsElementsIntoFixedWindows( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "key", + WindowedValue.of( + "v1", + new Instant(1), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(2), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(13), + Arrays.asList(window(10, 20)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getTimestamp(), equalTo(new Instant(1))); + assertThat(item0.getWindows(), contains(window(0, 10))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getTimestamp(), equalTo(new Instant(13))); + assertThat(item1.getWindows(), + contains(window(10, 20))); + } + + /** + * Tests that for a simple sequence of elements on the same key, the given GABW implementation + * correctly groups them into sliding windows. + * + *
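* <p>(A rough sketch of the window assignment at play: with windows of size 20 sliding
+ * every 10, each timestamp below lands in two windows.)
+ *
+ * <pre>{@code
+ * SlidingWindows.of(Duration.millis(20)).every(Duration.millis(10));
+ * // timestamp 5  -> [-10, 10) and [0, 20)
+ * // timestamp 15 -> [0, 20)   and [10, 30)
+ * }</pre>
+ *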
<p>
    In the input here, each element occurs in multiple windows. + */ + public static void groupsElementsIntoSlidingWindows( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = WindowingStrategy.of( + SlidingWindows.of(Duration.millis(20)).every(Duration.millis(10))); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "key", + WindowedValue.of( + "v1", + new Instant(5), + Arrays.asList(window(-10, 10), window(0, 20)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(15), + Arrays.asList(window(0, 20), window(10, 30)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(3)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), contains("v1")); + assertThat(item0.getTimestamp(), equalTo(new Instant(5))); + assertThat(item0.getWindows(), + contains(window(-10, 10))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item1.getTimestamp(), equalTo(new Instant(10))); + assertThat(item1.getWindows(), + contains(window(0, 20))); + + WindowedValue>> item2 = result.get(2); + assertThat(item2.getValue().getValue(), contains("v2")); + assertThat(item2.getTimestamp(), equalTo(new Instant(20))); + assertThat(item2.getWindows(), + contains(window(10, 30))); + } + + /** + * Tests that for a simple sequence of elements on the same key, the given GABW implementation + * correctly groups and combines them according to sliding windows. + * + *
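* <p>(Sketch of the expectation: the {@link CombineFn} is applied independently to the
+ * values assigned to each window, e.g. for the middle window of this test:)
+ *
+ * <pre>{@code
+ * // window [0, 20) receives 1L, 2L and 4L, so its expected output value is
+ * combineFn.apply(ImmutableList.of(1L, 2L, 4L));
+ * }</pre>
+ *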
<p>
    In the input here, each element occurs in multiple windows. + */ + public static void combinesElementsInSlidingWindows( + GroupAlsoByWindowsDoFnFactory gabwFactory, + CombineFn combineFn) + throws Exception { + + WindowingStrategy windowingStrategy = WindowingStrategy.of( + SlidingWindows.of(Duration.millis(20)).every(Duration.millis(10))); + + List>> result = + runGABW(gabwFactory, windowingStrategy, "k", + WindowedValue.of( + 1L, + new Instant(5), + Arrays.asList(window(-10, 10), window(0, 20)), + PaneInfo.NO_FIRING), + WindowedValue.of( + 2L, + new Instant(15), + Arrays.asList(window(0, 20), window(10, 30)), + PaneInfo.NO_FIRING), + WindowedValue.of( + 4L, + new Instant(18), + Arrays.asList(window(0, 20), window(10, 30)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(3)); + + assertThat(result, contains( + WindowMatchers.isSingleWindowedValue( + KvMatcher.isKv( + equalTo("k"), + equalTo(combineFn.apply(ImmutableList.of(1L)))), + 5, // aggregate timestamp + -10, // window start + 10), // window end + WindowMatchers.isSingleWindowedValue( + KvMatcher.isKv( + equalTo("k"), + equalTo(combineFn.apply(ImmutableList.of(1L, 2L, 4L)))), + 10, // aggregate timestamp + 0, // window start + 20), // window end + WindowMatchers.isSingleWindowedValue( + KvMatcher.isKv( + equalTo("k"), + equalTo(combineFn.apply(ImmutableList.of(2L, 4L)))), + 20, // aggregate timestamp + 10, // window start + 30))); // window end + } + + /** + * Tests that the given GABW implementation correctly groups elements that fall into overlapping + * windows that are not merged. + */ + public static void groupsIntoOverlappingNonmergingWindows( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "key", + WindowedValue.of( + "v1", + new Instant(1), + Arrays.asList(window(0, 5)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(4), + Arrays.asList(window(1, 5)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(4), + Arrays.asList(window(0, 5)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v3")); + assertThat(item0.getTimestamp(), equalTo(new Instant(1))); + assertThat(item0.getWindows(), + contains(window(0, 5))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v2")); + assertThat(item1.getTimestamp(), equalTo(new Instant(4))); + assertThat(item1.getWindows(), + contains(window(1, 5))); + } + + /** + * Tests that the given GABW implementation correctly groups elements into merged sessions. 
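+ * <p>(Sketch of the merge the assertions assume: with a 10ms gap, the session windows
+ * [0, 10) and [5, 15) overlap and merge into [0, 15), while [15, 25) stays separate.)
+ *
+ * <pre>{@code
+ * Sessions.withGapDuration(Duration.millis(10));
+ * // [0, 10) + [5, 15) -> merged session [0, 15)
+ * // [15, 25)          -> left as is
+ * }</pre>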
+ */ + public static void groupsElementsInMergedSessions( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(Sessions.withGapDuration(Duration.millis(10))); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "key", + WindowedValue.of( + "v1", + new Instant(0), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(5), + Arrays.asList(window(5, 15)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(15), + Arrays.asList(window(15, 25)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getTimestamp(), equalTo(new Instant(0))); + assertThat(item0.getWindows(), + contains(window(0, 15))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getTimestamp(), equalTo(new Instant(15))); + assertThat(item1.getWindows(), + contains(window(15, 25))); + } + + /** + * Tests that the given {@link GroupAlsoByWindowsDoFn} implementation combines elements per + * session window correctly according to the provided {@link CombineFn}. + */ + public static void combinesElementsPerSession( + GroupAlsoByWindowsDoFnFactory gabwFactory, + CombineFn combineFn) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(Sessions.withGapDuration(Duration.millis(10))); + + List>> result = + runGABW(gabwFactory, windowingStrategy, "k", + WindowedValue.of( + 1L, + new Instant(0), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + 2L, + new Instant(5), + Arrays.asList(window(5, 15)), + PaneInfo.NO_FIRING), + WindowedValue.of( + 4L, + new Instant(15), + Arrays.asList(window(15, 25)), + PaneInfo.NO_FIRING)); + + assertThat(result, contains( + WindowMatchers.isSingleWindowedValue( + KvMatcher.isKv( + equalTo("k"), + equalTo(combineFn.apply(ImmutableList.of(1L, 2L)))), + 0, // aggregate timestamp + 0, // window start + 15), // window end + WindowMatchers.isSingleWindowedValue( + KvMatcher.isKv( + equalTo("k"), + equalTo(combineFn.apply(ImmutableList.of(4L)))), + 15, // aggregate timestamp + 15, // window start + 25))); // window end + } + + /** + * Tests that for a simple sequence of elements on the same key, the given GABW implementation + * correctly groups them according to fixed windows and also sets the output timestamp + * according to the policy {@link OutputTimeFns#outputAtEndOfWindow()}. 
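+ * <p>(The strategy under test, as built below; every grouped output is expected to carry
+ * {@code window.maxTimestamp()} rather than any input timestamp.)
+ *
+ * <pre>{@code
+ * WindowingStrategy.of(FixedWindows.of(Duration.millis(10)))
+ *     .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow());
+ * }</pre>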
+ */ + public static void groupsElementsIntoFixedWindowsWithEndOfWindowTimestamp( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))) + .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow()); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "key", + WindowedValue.of( + "v1", + new Instant(1), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(2), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(13), + Arrays.asList(window(10, 20)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getTimestamp(), equalTo(window(0, 10).maxTimestamp())); + assertThat(item0.getTimestamp(), + equalTo(Iterables.getOnlyElement(item0.getWindows()).maxTimestamp())); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getTimestamp(), equalTo(window(10, 20).maxTimestamp())); + assertThat(item1.getTimestamp(), + equalTo(Iterables.getOnlyElement(item1.getWindows()).maxTimestamp())); + } + + /** + * Tests that for a simple sequence of elements on the same key, the given GABW implementation + * correctly groups them according to fixed windows and also sets the output timestamp + * according to a custom {@link OutputTimeFn}. + */ + public static void groupsElementsIntoFixedWindowsWithCustomTimestamp( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))) + .withOutputTimeFn(new OutputTimeFn.Defaults() { + @Override + public Instant assignOutputTime(Instant inputTimestamp, IntervalWindow window) { + return inputTimestamp.isBefore(window.maxTimestamp()) + ? inputTimestamp.plus(1) : window.maxTimestamp(); + } + + @Override + public Instant combine(Instant outputTime, Instant otherOutputTime) { + return outputTime.isBefore(otherOutputTime) ? outputTime : otherOutputTime; + } + + @Override + public boolean dependsOnlyOnEarliestInputTimestamp() { + return true; + } + }); + + List>>> result = runGABW(gabwFactory, + windowingStrategy, "key", + WindowedValue.of("v1", new Instant(1), Arrays.asList(window(0, 10)), PaneInfo.NO_FIRING), + WindowedValue.of("v2", new Instant(2), Arrays.asList(window(0, 10)), PaneInfo.NO_FIRING), + WindowedValue.of("v3", new Instant(13), Arrays.asList(window(10, 20)), PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getWindows(), contains(window(0, 10))); + assertThat(item0.getTimestamp(), equalTo(new Instant(2))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getWindows(), contains(window(10, 20))); + assertThat(item1.getTimestamp(), equalTo(new Instant(14))); + } + + /** + * Tests that for a simple sequence of elements on the same key, the given GABW implementation + * correctly groups them according to fixed windows and also sets the output timestamp + * according to the policy {@link OutputTimeFns#outputAtLatestInputTimestamp()}. 
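+ * <p>(Sketch: under {@code outputAtLatestInputTimestamp()} each window's output timestamp is
+ * the maximum input timestamp it received, e.g. 2 for the inputs at 1 and 2 below.)
+ *
+ * <pre>{@code
+ * WindowingStrategy.of(FixedWindows.of(Duration.millis(10)))
+ *     .withOutputTimeFn(OutputTimeFns.outputAtLatestInputTimestamp());
+ * }</pre>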
+ */ + public static void groupsElementsIntoFixedWindowsWithLatestTimestamp( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))) + .withOutputTimeFn(OutputTimeFns.outputAtLatestInputTimestamp()); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "k", + WindowedValue.of( + "v1", + new Instant(1), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(2), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(13), + Arrays.asList(window(10, 20)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getWindows(), contains(window(0, 10))); + assertThat(item0.getTimestamp(), equalTo(new Instant(2))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getWindows(), contains(window(10, 20))); + assertThat(item1.getTimestamp(), equalTo(new Instant(13))); + } + + /** + * Tests that the given GABW implementation correctly groups elements into merged sessions + * with output timestamps at the end of the merged window. + */ + public static void groupsElementsInMergedSessionsWithEndOfWindowTimestamp( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(Sessions.withGapDuration(Duration.millis(10))) + .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow()); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "k", + WindowedValue.of( + "v1", + new Instant(0), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(5), + Arrays.asList(window(5, 15)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(15), + Arrays.asList(window(15, 25)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getWindows(), contains(window(0, 15))); + assertThat(item0.getTimestamp(), + equalTo(Iterables.getOnlyElement(item0.getWindows()).maxTimestamp())); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getWindows(), contains(window(15, 25))); + assertThat(item1.getTimestamp(), + equalTo(Iterables.getOnlyElement(item1.getWindows()).maxTimestamp())); + } + + /** + * Tests that the given GABW implementation correctly groups elements into merged sessions + * with output timestamps at the end of the merged window. 
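+ * <p>(Note the output-time policy configured below is
+ * {@link OutputTimeFns#outputAtLatestInputTimestamp()}, so the merged session [0, 15) is
+ * expected to carry the latest input timestamp it contains, 5, and [15, 25) carries 15.)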
+ */ + public static void groupsElementsInMergedSessionsWithLatestTimestamp( + GroupAlsoByWindowsDoFnFactory> gabwFactory) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(Sessions.withGapDuration(Duration.millis(10))) + .withOutputTimeFn(OutputTimeFns.outputAtLatestInputTimestamp()); + + List>>> result = + runGABW(gabwFactory, windowingStrategy, "k", + WindowedValue.of( + "v1", + new Instant(0), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v2", + new Instant(5), + Arrays.asList(window(5, 15)), + PaneInfo.NO_FIRING), + WindowedValue.of( + "v3", + new Instant(15), + Arrays.asList(window(15, 25)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue>> item0 = result.get(0); + assertThat(item0.getValue().getValue(), containsInAnyOrder("v1", "v2")); + assertThat(item0.getWindows(), contains(window(0, 15))); + assertThat(item0.getTimestamp(), equalTo(new Instant(5))); + + WindowedValue>> item1 = result.get(1); + assertThat(item1.getValue().getValue(), contains("v3")); + assertThat(item1.getWindows(), contains(window(15, 25))); + assertThat(item1.getTimestamp(), equalTo(new Instant(15))); + } + + /** + * Tests that the given {@link GroupAlsoByWindowsDoFn} implementation combines elements per + * session window correctly according to the provided {@link CombineFn}. + */ + public static void combinesElementsPerSessionWithEndOfWindowTimestamp( + GroupAlsoByWindowsDoFnFactory gabwFactory, + CombineFn combineFn) + throws Exception { + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(Sessions.withGapDuration(Duration.millis(10))) + .withOutputTimeFn(OutputTimeFns.outputAtEndOfWindow()); + + + List>> result = + runGABW(gabwFactory, windowingStrategy, "k", + WindowedValue.of( + 1L, + new Instant(0), + Arrays.asList(window(0, 10)), + PaneInfo.NO_FIRING), + WindowedValue.of( + 2L, + new Instant(5), + Arrays.asList(window(5, 15)), + PaneInfo.NO_FIRING), + WindowedValue.of( + 4L, + new Instant(15), + Arrays.asList(window(15, 25)), + PaneInfo.NO_FIRING)); + + assertThat(result.size(), equalTo(2)); + + WindowedValue> item0 = result.get(0); + assertThat(item0.getValue().getValue(), equalTo(combineFn.apply(ImmutableList.of(1L, 2L)))); + assertThat(item0.getWindows(), contains(window(0, 15))); + assertThat(item0.getTimestamp(), + equalTo(Iterables.getOnlyElement(item0.getWindows()).maxTimestamp())); + + WindowedValue> item1 = result.get(1); + assertThat(item1.getValue().getValue(), equalTo(combineFn.apply(ImmutableList.of(4L)))); + assertThat(item1.getWindows(), contains(window(15, 25))); + assertThat(item1.getTimestamp(), + equalTo(Iterables.getOnlyElement(item1.getWindows()).maxTimestamp())); + } + + @SafeVarargs + private static + List>> runGABW( + GroupAlsoByWindowsDoFnFactory gabwFactory, + WindowingStrategy windowingStrategy, + K key, + WindowedValue... 
values) { + return runGABW(gabwFactory, windowingStrategy, key, Arrays.asList(values)); + } + + private static + List>> runGABW( + GroupAlsoByWindowsDoFnFactory gabwFactory, + WindowingStrategy windowingStrategy, + K key, + Collection> values) { + + TupleTag> outputTag = new TupleTag<>(); + DoFnRunnerBase.ListOutputManager outputManager = new DoFnRunnerBase.ListOutputManager(); + + DoFnRunner>>, KV> runner = + makeRunner( + gabwFactory.forStrategy(windowingStrategy), + windowingStrategy, + outputTag, + outputManager); + + runner.startBundle(); + + if (values.size() > 0) { + runner.processElement(WindowedValue.valueInEmptyWindows( + KV.of(key, (Iterable>) values))); + } + + runner.finishBundle(); + + List>> result = outputManager.getOutput(outputTag); + + // Sanity check for corruption + for (WindowedValue> elem : result) { + assertThat(elem.getValue().getKey(), equalTo(key)); + } + + return result; + } + + private static + DoFnRunner>>, KV> + makeRunner( + GroupAlsoByWindowsDoFn fn, + WindowingStrategy windowingStrategy, + TupleTag> outputTag, + DoFnRunners.OutputManager outputManager) { + + ExecutionContext executionContext = DirectModeExecutionContext.create(); + CounterSet counters = new CounterSet(); + + return DoFnRunners.simpleRunner( + PipelineOptionsFactory.create(), + fn, + NullSideInputReader.empty(), + outputManager, + outputTag, + new ArrayList>(), + executionContext.getOrCreateStepContext("GABWStep", "GABWTransform", null), + counters.getAddCounterMutator(), + windowingStrategy); + } + + private static BoundedWindow window(long start, long end) { + return new IntervalWindow(new Instant(start), new Instant(end)); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFnTest.java new file mode 100644 index 000000000000..1e8458de76ca --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFnTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowsProperties.GroupAlsoByWindowsDoFnFactory; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link GroupAlsoByWindowsViaOutputBufferDoFn}. 
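+ * <p>(Each test delegates to {@link GroupAlsoByWindowsProperties}, handing it a factory whose
+ * core, with type parameters omitted, is roughly:)
+ *
+ * <pre>{@code
+ * new GroupAlsoByWindowsViaOutputBufferDoFn(
+ *     windowingStrategy, SystemReduceFn.buffering(inputCoder));
+ * }</pre>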
+ */ +@RunWith(JUnit4.class) +public class GroupAlsoByWindowsViaOutputBufferDoFnTest { + + private class BufferingGABWViaOutputBufferDoFnFactory + implements GroupAlsoByWindowsDoFnFactory> { + + private final Coder inputCoder; + + public BufferingGABWViaOutputBufferDoFnFactory(Coder inputCoder) { + this.inputCoder = inputCoder; + } + + @Override + public GroupAlsoByWindowsDoFn, W> + forStrategy(WindowingStrategy windowingStrategy) { + return new GroupAlsoByWindowsViaOutputBufferDoFn, W>( + windowingStrategy, + SystemReduceFn.buffering(inputCoder)); + } + } + + @Test + public void testEmptyInputEmptyOutput() throws Exception { + GroupAlsoByWindowsProperties.emptyInputEmptyOutput( + new BufferingGABWViaOutputBufferDoFnFactory<>(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoFixedWindows() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsIntoFixedWindows( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoSlidingWindows() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsIntoSlidingWindows( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsIntoOverlappingNonmergingWindows() throws Exception { + GroupAlsoByWindowsProperties.groupsIntoOverlappingNonmergingWindows( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsIntoSessions() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsInMergedSessions( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoFixedWindowsWithEndOfWindowTimestamp() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsIntoFixedWindowsWithEndOfWindowTimestamp( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoFixedWindowsWithLatestTimestamp() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsIntoFixedWindowsWithLatestTimestamp( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoFixedWindowsWithCustomTimestamp() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsIntoFixedWindowsWithCustomTimestamp( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoSessionsWithEndOfWindowTimestamp() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsInMergedSessionsWithEndOfWindowTimestamp( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } + + @Test + public void testGroupsElementsIntoSessionsWithLatestTimestamp() throws Exception { + GroupAlsoByWindowsProperties.groupsElementsInMergedSessionsWithLatestTimestamp( + new BufferingGABWViaOutputBufferDoFnFactory(StringUtf8Coder.of())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOChannelUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOChannelUtilsTest.java new file mode 100644 index 000000000000..cb9e9785e784 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IOChannelUtilsTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.common.io.Files; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.nio.charset.StandardCharsets; + +/** + * Tests for IOChannelUtils. + */ +@RunWith(JUnit4.class) +public class IOChannelUtilsTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testShardFormatExpansion() { + assertEquals("output-001-of-123.txt", + IOChannelUtils.constructName("output", "-SSS-of-NNN", + ".txt", + 1, 123)); + + assertEquals("out.txt/part-00042", + IOChannelUtils.constructName("out.txt", "/part-SSSSS", "", + 42, 100)); + + assertEquals("out.txt", + IOChannelUtils.constructName("ou", "t.t", "xt", 1, 1)); + + assertEquals("out0102shard.txt", + IOChannelUtils.constructName("out", "SSNNshard", ".txt", 1, 2)); + + assertEquals("out-2/1.part-1-of-2.txt", + IOChannelUtils.constructName("out", "-N/S.part-S-of-N", + ".txt", 1, 2)); + } + + @Test(expected = IllegalArgumentException.class) + public void testShardNameCollision() throws Exception { + File outFolder = tmpFolder.newFolder(); + String filename = outFolder.toPath().resolve("output").toString(); + + IOChannelUtils.create(filename, "", "", 2, "text").close(); + fail("IOChannelUtils.create expected to fail due " + + "to filename collision"); + } + + @Test + public void testLargeShardCount() { + Assert.assertEquals("out-100-of-5000.txt", + IOChannelUtils.constructName("out", "-SS-of-NN", ".txt", + 100, 5000)); + } + + @Test + public void testGetSizeBytes() throws Exception { + String data = "TestData"; + File file = tmpFolder.newFile(); + Files.write(data, file, StandardCharsets.UTF_8); + assertEquals(data.length(), IOChannelUtils.getSizeBytes(file.getPath())); + } + + @Test + public void testResolve() throws Exception { + String expected = tmpFolder.getRoot().toPath().resolve("aa").toString(); + assertEquals(expected, IOChannelUtils.resolve(tmpFolder.getRoot().toString(), "aa")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/InstanceBuilderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/InstanceBuilderTest.java new file mode 100644 index 000000000000..21245ed74bc9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/InstanceBuilderTest.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.values.TupleTag; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests of InstanceBuilder. + */ +@RunWith(JUnit4.class) +@SuppressWarnings("rawtypes") +public class InstanceBuilderTest { + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @SuppressWarnings("unused") + private static TupleTag createTag(String id) { + return new TupleTag(id); + } + + @Test + public void testFullNameLookup() throws Exception { + TupleTag tag = InstanceBuilder.ofType(TupleTag.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("createTag") + .withArg(String.class, "hello world!") + .build(); + + Assert.assertEquals("hello world!", tag.getId()); + } + + @Test + public void testConstructor() throws Exception { + TupleTag tag = InstanceBuilder.ofType(TupleTag.class) + .withArg(String.class, "hello world!") + .build(); + + Assert.assertEquals("hello world!", tag.getId()); + } + + @Test + public void testBadMethod() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("Unable to find factory method")); + + InstanceBuilder.ofType(String.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("nonexistantFactoryMethod") + .withArg(String.class, "hello") + .withArg(String.class, " world!") + .build(); + } + + @Test + public void testBadArgs() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("Unable to find factory method")); + + InstanceBuilder.ofType(TupleTag.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("createTag") + .withArg(String.class, "hello") + .withArg(Integer.class, 42) + .build(); + } + + @Test + public void testBadReturnType() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("must be assignable to String")); + + InstanceBuilder.ofType(String.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .fromFactoryMethod("createTag") + .withArg(String.class, "hello") + .build(); + } + + @Test + public void testWrongType() throws Exception { + expectedEx.expect(RuntimeException.class); + expectedEx.expectMessage( + Matchers.containsString("must be assignable to TupleTag")); + + InstanceBuilder.ofType(TupleTag.class) + .fromClassName(InstanceBuilderTest.class.getName()) + .build(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IntervalBoundedExponentialBackOffTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IntervalBoundedExponentialBackOffTest.java new file mode 100644 index 000000000000..8ad7aa6592fe --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/IntervalBoundedExponentialBackOffTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link IntervalBoundedExponentialBackOff}. */ +@RunWith(JUnit4.class) +public class IntervalBoundedExponentialBackOffTest { + @Rule public ExpectedException exception = ExpectedException.none(); + + + @Test + public void testUsingInvalidInitialInterval() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Initial interval must be greater than zero."); + new IntervalBoundedExponentialBackOff(1000, 0L); + } + + @Test + public void testUsingInvalidMaximumInterval() throws Exception { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Maximum interval must be greater than zero."); + new IntervalBoundedExponentialBackOff(-1, 10L); + } + + @Test + public void testThatcertainNumberOfAttemptsReachesMaxInterval() throws Exception { + IntervalBoundedExponentialBackOff backOff = new IntervalBoundedExponentialBackOff(1000, 500); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(374L), lessThan(1126L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + } + + @Test + public void testThatResettingAllowsReuse() throws Exception { + IntervalBoundedExponentialBackOff backOff = new IntervalBoundedExponentialBackOff(1000, 500); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(374L), lessThan(1126L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + backOff.reset(); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(249L), lessThan(751L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(374L), lessThan(1126L))); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + 
assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + } + + @Test + public void testAtMaxInterval() throws Exception { + IntervalBoundedExponentialBackOff backOff = new IntervalBoundedExponentialBackOff(1000, 500); + assertFalse(backOff.atMaxInterval()); + backOff.nextBackOffMillis(); + assertFalse(backOff.atMaxInterval()); + backOff.nextBackOffMillis(); + assertTrue(backOff.atMaxInterval()); + assertThat(backOff.nextBackOffMillis(), allOf(greaterThanOrEqualTo(500L), + lessThanOrEqualTo(1500L))); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItemCoderTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItemCoderTest.java new file mode 100644 index 000000000000..e6cd454fbec3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/KeyedWorkItemCoderTest.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link KeyedWorkItems}. 
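+ * <p>(These exercise {@link KeyedWorkItemCoder} specifically: a coder built from a key coder,
+ * an element coder and a window coder, checked for serializability and encode/decode
+ * round-trips via {@code CoderProperties}.)
+ *
+ * <pre>{@code
+ * KeyedWorkItemCoder.of(StringUtf8Coder.of(), VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);
+ * }</pre>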
+ */ +@RunWith(JUnit4.class) +public class KeyedWorkItemCoderTest { + @Test + public void testCoderProperties() throws Exception { + CoderProperties.coderSerializable( + KeyedWorkItemCoder.of(StringUtf8Coder.of(), VarIntCoder.of(), GlobalWindow.Coder.INSTANCE)); + } + + @Test + public void testEncodeDecodeEqual() throws Exception { + Iterable timers = + ImmutableList.of( + TimerData.of(StateNamespaces.global(), new Instant(500L), TimeDomain.EVENT_TIME)); + Iterable> elements = + ImmutableList.of( + WindowedValue.valueInGlobalWindow(1), + WindowedValue.valueInGlobalWindow(4), + WindowedValue.valueInGlobalWindow(8)); + + KeyedWorkItemCoder coder = + KeyedWorkItemCoder.of(StringUtf8Coder.of(), VarIntCoder.of(), GlobalWindow.Coder.INSTANCE); + + CoderProperties.coderDecodeEncodeEqual(coder, KeyedWorkItems.workItem("foo", timers, elements)); + CoderProperties.coderDecodeEncodeEqual(coder, KeyedWorkItems.elementsWorkItem("foo", elements)); + CoderProperties.coderDecodeEncodeEqual( + coder, KeyedWorkItems.timersWorkItem("foo", timers)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/LateDataDroppingDoFnRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/LateDataDroppingDoFnRunnerTest.java new file mode 100644 index 000000000000..c951d4c9a174 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/LateDataDroppingDoFnRunnerTest.java @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.LateDataDroppingDoFnRunner.LateDataFilter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.Arrays; + +/** + * Unit tests for {@link LateDataDroppingDoFnRunner}. 
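+ * <p>(Sketch of the behaviour pinned down below: with the input watermark mocked at 15 and
+ * 10ms fixed windows, an element at timestamp 5 falls in [0, 10), which lies entirely behind
+ * the watermark, so it is dropped and counted; elements in [10, 20) pass through.)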
+ */ +@RunWith(JUnit4.class) +public class LateDataDroppingDoFnRunnerTest { + private static final FixedWindows WINDOW_FN = FixedWindows.of(Duration.millis(10)); + + @Mock private TimerInternals mockTimerInternals; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testLateDataFilter() throws Exception { + when(mockTimerInternals.currentInputWatermarkTime()).thenReturn(new Instant(15L)); + + InMemoryLongSumAggregator droppedDueToLateness = + new InMemoryLongSumAggregator("droppedDueToLateness"); + LateDataFilter lateDataFilter = new LateDataFilter( + WindowingStrategy.of(WINDOW_FN), mockTimerInternals, droppedDueToLateness); + + Iterable> actual = lateDataFilter.filter( + "a", + ImmutableList.of( + createDatum(13, 13L), + createDatum(5, 5L), // late element, earlier than 4L. + createDatum(16, 16L), + createDatum(18, 18L))); + + Iterable> expected = ImmutableList.of( + createDatum(13, 13L), + createDatum(16, 16L), + createDatum(18, 18L)); + assertThat(expected, containsInAnyOrder(Iterables.toArray(actual, WindowedValue.class))); + assertEquals(1, droppedDueToLateness.sum); + } + + private WindowedValue createDatum(T element, long timestampMillis) { + Instant timestamp = new Instant(timestampMillis); + return WindowedValue.of( + element, + timestamp, + Arrays.asList(WINDOW_FN.assignWindow(timestamp)), + PaneInfo.NO_FIRING); + } + + private static class InMemoryLongSumAggregator implements Aggregator { + private final String name; + private long sum = 0; + + public InMemoryLongSumAggregator(String name) { + this.name = name; + } + + @Override + public void addValue(Long value) { + sum += value; + } + + @Override + public String getName() { + return name; + } + + @Override + public CombineFn getCombineFn() { + return new Sum.SumLongFn(); + } + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MergingActiveWindowSetTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MergingActiveWindowSetTest.java new file mode 100644 index 000000000000..8f66e5a41c32 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MergingActiveWindowSetTest.java @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collection; + +/** + * Test NonMergingActiveWindowSet. + */ +@RunWith(JUnit4.class) +public class MergingActiveWindowSetTest { + private Sessions windowFn; + private StateInternals state; + private MergingActiveWindowSet set; + + @Before + public void before() { + windowFn = Sessions.withGapDuration(Duration.millis(10)); + state = InMemoryStateInternals.forKey("dummyKey"); + set = new MergingActiveWindowSet<>(windowFn, state); + } + + @After + public void after() { + set = null; + state = null; + windowFn = null; + } + + private void add(final long instant) { + System.out.println("ADD " + instant); + final Object element = new Long(instant); + Sessions.AssignContext context = windowFn.new AssignContext() { + @Override + public Object element() { + return element; + } + + @Override + public Instant timestamp() { + return new Instant(instant); + } + + @Override + public Collection windows() { + return ImmutableList.of(); + } + }; + + for (IntervalWindow window : windowFn.assignWindows(context)) { + set.addNew(window); + } + } + + private void merge(ActiveWindowSet.MergeCallback callback) throws Exception { + System.out.println("MERGE"); + set.merge(callback); + set.checkInvariants(); + System.out.println(set); + } + + private void pruneAndPersist() { + System.out.println("PRUNE"); + set.removeEphemeralWindows(); + set.checkInvariants(); + System.out.println(set); + set.persist(); + } + + private IntervalWindow window(long start, long size) { + return new IntervalWindow(new Instant(start), new Duration(size)); + } + + @Test + public void test() throws Exception { + @SuppressWarnings("unchecked") + ActiveWindowSet.MergeCallback callback = + mock(ActiveWindowSet.MergeCallback.class); + + // NEW 1+10 + // NEW 2+10 + // NEW 15+10 + // => + // ACTIVE 1+11 (target 1+11) + // EPHEMERAL 1+10 -> 1+11 + // EPHEMERAL 2+10 -> 1+11 + // ACTIVE 15+10 (target 15+10) + add(1); + add(2); + add(15); + merge(callback); + verify(callback).onMerge(ImmutableList.of(window(1, 10), window(2, 10)), + ImmutableList.of(), window(1, 11)); + assertEquals(ImmutableSet.of(window(1, 11), window(15, 10)), set.getActiveWindows()); + assertEquals(window(1, 11), set.representative(window(1, 10))); + assertEquals(window(1, 11), set.representative(window(2, 10))); + assertEquals(window(1, 11), set.representative(window(1, 11))); + assertEquals(window(15, 10), set.representative(window(15, 10))); + assertEquals( + ImmutableSet.of(window(1, 11)), set.readStateAddresses(window(1, 11))); + assertEquals( + ImmutableSet.of(window(15, 10)), set.readStateAddresses(window(15, 10))); + + // NEW 3+10 + // => + // ACTIVE 1+12 (target 1+11) + // EPHEMERAL 3+10 -> 1+12 + // ACTIVE 15+10 
(target 15+10) + add(3); + merge(callback); + verify(callback).onMerge(ImmutableList.of(window(1, 11), window(3, 10)), + ImmutableList.of(window(1, 11)), window(1, 12)); + assertEquals(ImmutableSet.of(window(1, 12), window(15, 10)), set.getActiveWindows()); + assertEquals(window(1, 12), set.representative(window(3, 10))); + + // NEW 8+10 + // => + // ACTIVE 1+24 (target 1+11, 15+10) + // MERGED 1+11 -> 1+24 + // MERGED 15+10 -> 1+24 + // EPHEMERAL 1+12 -> 1+24 + add(8); + merge(callback); + verify(callback).onMerge(ImmutableList.of(window(1, 12), window(8, 10), window(15, 10)), + ImmutableList.of(window(1, 12), window(15, 10)), window(1, 24)); + assertEquals(ImmutableSet.of(window(1, 24)), set.getActiveWindows()); + assertEquals(window(1, 24), set.representative(window(1, 12))); + assertEquals(window(1, 24), set.representative(window(1, 11))); + assertEquals(window(1, 24), set.representative(window(15, 10))); + + // NEW 9+10 + // => + // ACTIVE 1+24 (target 1+11, 15+10) + add(9); + merge(callback); + verify(callback).onMerge(ImmutableList.of(window(1, 24), window(9, 10)), + ImmutableList.of(window(1, 24)), window(1, 24)); + + pruneAndPersist(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MonitoringUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MonitoringUtilTest.java new file mode 100644 index 000000000000..c94450d6d204 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MonitoringUtilTest.java @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.Dataflow; +import com.google.api.services.dataflow.model.JobMessage; +import com.google.api.services.dataflow.model.ListJobMessagesResponse; +import com.google.cloud.dataflow.sdk.PipelineResult.State; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for MonitoringUtil. + */ +@RunWith(JUnit4.class) +public class MonitoringUtilTest { + private static final String PROJECT_ID = "someProject"; + private static final String JOB_ID = "1234"; + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testGetJobMessages() throws IOException { + Dataflow.Projects.Jobs.Messages mockMessages = mock(Dataflow.Projects.Jobs.Messages.class); + + // Two requests are needed to get all the messages. 
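+    // The first mocked page carries 100 messages plus a page token; the second carries the
+    // remaining 50, so getJobMessages is expected to follow the token and return all 150.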
+ Dataflow.Projects.Jobs.Messages.List firstRequest = + mock(Dataflow.Projects.Jobs.Messages.List.class); + Dataflow.Projects.Jobs.Messages.List secondRequest = + mock(Dataflow.Projects.Jobs.Messages.List.class); + + when(mockMessages.list(PROJECT_ID, JOB_ID)).thenReturn(firstRequest).thenReturn(secondRequest); + + ListJobMessagesResponse firstResponse = new ListJobMessagesResponse(); + firstResponse.setJobMessages(new ArrayList()); + for (int i = 0; i < 100; ++i) { + JobMessage message = new JobMessage(); + message.setId("message_" + i); + message.setTime(TimeUtil.toCloudTime(new Instant(i))); + firstResponse.getJobMessages().add(message); + } + String pageToken = "page_token"; + firstResponse.setNextPageToken(pageToken); + + ListJobMessagesResponse secondResponse = new ListJobMessagesResponse(); + secondResponse.setJobMessages(new ArrayList()); + for (int i = 100; i < 150; ++i) { + JobMessage message = new JobMessage(); + message.setId("message_" + i); + message.setTime(TimeUtil.toCloudTime(new Instant(i))); + secondResponse.getJobMessages().add(message); + } + + when(firstRequest.execute()).thenReturn(firstResponse); + when(secondRequest.execute()).thenReturn(secondResponse); + + MonitoringUtil util = new MonitoringUtil(PROJECT_ID, mockMessages); + + List messages = util.getJobMessages(JOB_ID, -1); + + verify(secondRequest).setPageToken(pageToken); + + assertEquals(150, messages.size()); + } + + @Test + public void testToStateCreatesState() { + String stateName = "JOB_STATE_DONE"; + + State result = MonitoringUtil.toState(stateName); + + assertEquals(State.DONE, result); + } + + @Test + public void testToStateWithNullReturnsUnknown() { + String stateName = null; + + State result = MonitoringUtil.toState(stateName); + + assertEquals(State.UNKNOWN, result); + } + + @Test + public void testToStateWithOtherValueReturnsUnknown() { + String stateName = "FOO_BAR_BAZ"; + + State result = MonitoringUtil.toState(stateName); + + assertEquals(State.UNKNOWN, result); + } + + @Test + public void testDontOverrideEndpointWithDefaultApi() { + DataflowPipelineOptions options = + PipelineOptionsFactory.create().as(DataflowPipelineOptions.class); + options.setProject(PROJECT_ID); + options.setGcpCredential(new TestCredential()); + String cancelCommand = MonitoringUtil.getGcloudCancelCommand(options, JOB_ID); + assertEquals("gcloud alpha dataflow jobs --project=someProject cancel 1234", cancelCommand); + } + + @Test + public void testOverridesEndpointWithStagedDataflowEndpoint() { + DataflowPipelineOptions options = + PipelineOptionsFactory.create().as(DataflowPipelineOptions.class); + options.setProject(PROJECT_ID); + options.setGcpCredential(new TestCredential()); + String stagingDataflowEndpoint = "v0neverExisted"; + options.setDataflowEndpoint(stagingDataflowEndpoint); + String cancelCommand = MonitoringUtil.getGcloudCancelCommand(options, JOB_ID); + assertEquals( + "CLOUDSDK_API_ENDPOINT_OVERRIDES_DATAFLOW=https://dataflow.googleapis.com/v0neverExisted/ " + + "gcloud alpha dataflow jobs --project=someProject cancel 1234", + cancelCommand); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MutationDetectorsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MutationDetectorsTest.java new file mode 100644 index 000000000000..0f77679158d1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MutationDetectorsTest.java @@ -0,0 +1,148 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +/** + * Tests for {@link MutationDetectors}. + */ +@RunWith(JUnit4.class) +public class MutationDetectorsTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + /** + * Tests that {@link MutationDetectors#forValueWithCoder} detects a mutation to a list. + */ + @Test + public void testMutatingList() throws Exception { + List value = Arrays.asList(1, 2, 3, 4); + MutationDetector detector = + MutationDetectors.forValueWithCoder(value, ListCoder.of(VarIntCoder.of())); + value.set(0, 37); + + thrown.expect(IllegalMutationException.class); + detector.verifyUnmodified(); + } + + /** + * Tests that {@link MutationDetectors#forValueWithCoder} does not false positive on a + * {@link LinkedList} that will clone as an {@code ArrayList}. + */ + @Test + public void testUnmodifiedLinkedList() throws Exception { + List value = Lists.newLinkedList(Arrays.asList(1, 2, 3, 4)); + MutationDetector detector = + MutationDetectors.forValueWithCoder(value, ListCoder.of(VarIntCoder.of())); + detector.verifyUnmodified(); + } + + /** + * Tests that {@link MutationDetectors#forValueWithCoder} does not false positive on a + * {@link LinkedList} coded as an {@link Iterable}. + */ + @Test + public void testImmutableList() throws Exception { + List value = Lists.newLinkedList(Arrays.asList(1, 2, 3, 4)); + MutationDetector detector = + MutationDetectors.forValueWithCoder(value, IterableCoder.of(VarIntCoder.of())); + detector.verifyUnmodified(); + } + + /** + * Tests that {@link MutationDetectors#forValueWithCoder} does not false positive on a + * {@link Set} coded as an {@link Iterable}. + */ + @Test + public void testImmutableSet() throws Exception { + Set value = Sets.newHashSet(Arrays.asList(1, 2, 3, 4)); + MutationDetector detector = + MutationDetectors.forValueWithCoder(value, IterableCoder.of(VarIntCoder.of())); + detector.verifyUnmodified(); + } + + /** + * Tests that {@link MutationDetectors#forValueWithCoder} does not false positive on an + * {@link Iterable} that is not known to be bounded; after coder-based cloning the bound + * will be known and it will be a {@link List} so it will encode more compactly the second + * time around. 
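+ * <p>(The iterable under test is built as below; its size is not known up front, which is
+ * why the first and second encodings can differ in form without any actual mutation.)
+ *
+ * <pre>{@code
+ * Iterable<Integer> value = FluentIterable.from(Arrays.asList(1, 2, 3, 4)).cycle().limit(50);
+ * }</pre>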
+   */
+  @Test
+  public void testImmutableIterable() throws Exception {
+    Iterable<Integer> value = FluentIterable.from(Arrays.asList(1, 2, 3, 4)).cycle().limit(50);
+    MutationDetector detector =
+        MutationDetectors.forValueWithCoder(value, IterableCoder.of(VarIntCoder.of()));
+    detector.verifyUnmodified();
+  }
+
+  /**
+   * Tests that {@link MutationDetectors#forValueWithCoder} detects a mutation to a byte array.
+   */
+  @Test
+  public void testMutatingArray() throws Exception {
+    byte[] value = new byte[]{0x1, 0x2, 0x3, 0x4};
+    MutationDetector detector =
+        MutationDetectors.forValueWithCoder(value, ByteArrayCoder.of());
+    value[0] = 0xa;
+    thrown.expect(IllegalMutationException.class);
+    detector.verifyUnmodified();
+  }
+
+  /**
+   * Tests that {@link MutationDetectors#forValueWithCoder} does not false positive on an
+   * array, even though it will decode as another array, which Java will not say is {@code equals}.
+   */
+  @Test
+  public void testUnmodifiedArray() throws Exception {
+    byte[] value = new byte[]{0x1, 0x2, 0x3, 0x4};
+    MutationDetector detector =
+        MutationDetectors.forValueWithCoder(value, ByteArrayCoder.of());
+    detector.verifyUnmodified();
+  }
+
+  /**
+   * Tests that {@link MutationDetectors#forValueWithCoder} does not false positive on a
+   * list of arrays, even when some array is set to a deeply equal array that is not {@code equals}.
+   */
+  @Test
+  public void testEquivalentListOfArrays() throws Exception {
+    List<byte[]> value = Arrays.asList(new byte[]{0x1}, new byte[]{0x2, 0x3}, new byte[]{0x4});
+    MutationDetector detector =
+        MutationDetectors.forValueWithCoder(value, ListCoder.of(ByteArrayCoder.of()));
+    value.set(0, new byte[]{0x1});
+    detector.verifyUnmodified();
+  }
+}
+
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PTupleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PTupleTest.java
new file mode 100644
index 000000000000..51b9938e1ea6
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PTupleTest.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link PTuple}.
*/ +@RunWith(JUnit4.class) +public final class PTupleTest { + @Test + public void accessingNullVoidValuesShouldNotCauseExceptions() { + TupleTag tag = new TupleTag() {}; + PTuple tuple = PTuple.of(tag, null); + assertTrue(tuple.has(tag)); + assertThat(tuple.get(tag), is(nullValue())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PackageUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PackageUtilTest.java new file mode 100644 index 000000000000..e051219b78ef --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PackageUtilTest.java @@ -0,0 +1,482 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpStatusCodes; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.LowLevelHttpRequest; +import com.google.api.client.json.GenericJson; +import com.google.api.client.json.Json; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.testing.http.HttpTesting; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import com.google.api.services.dataflow.model.DataflowPackage; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; +import com.google.cloud.dataflow.sdk.testing.FastNanoClockAndSleeper; +import com.google.cloud.dataflow.sdk.util.PackageUtil.PackageAttributes; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.io.Files; +import com.google.common.io.LineReader; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; 
+import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.Pipe; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.regex.Pattern; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; + +/** Tests for PackageUtil. */ +@RunWith(JUnit4.class) +public class PackageUtilTest { + @Rule public ExpectedLogs logged = ExpectedLogs.none(PackageUtil.class); + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule + public FastNanoClockAndSleeper fastNanoClockAndSleeper = new FastNanoClockAndSleeper(); + + @Mock + GcsUtil mockGcsUtil; + + // 128 bits, base64 encoded is 171 bits, rounds to 22 bytes + private static final String HASH_PATTERN = "[a-zA-Z0-9+-]{22}"; + + // Hamcrest matcher to assert a string matches a pattern + private static class RegexMatcher extends BaseMatcher { + private final Pattern pattern; + + public RegexMatcher(String regex) { + this.pattern = Pattern.compile(regex); + } + + @Override + public boolean matches(Object o) { + if (!(o instanceof String)) { + return false; + } + return pattern.matcher((String) o).matches(); + } + + @Override + public void describeTo(Description description) { + description.appendText(String.format("matches regular expression %s", pattern)); + } + + public static RegexMatcher matches(String regex) { + return new RegexMatcher(regex); + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + pipelineOptions.setGcsUtil(mockGcsUtil); + + IOChannelUtils.registerStandardIOFactories(pipelineOptions); + } + + private File makeFileWithContents(String name, String contents) throws Exception { + File tmpFile = tmpFolder.newFile(name); + Files.write(contents, tmpFile, StandardCharsets.UTF_8); + tmpFile.setLastModified(0); // required for determinism with directories + return tmpFile; + } + + static final String STAGING_PATH = GcsPath.fromComponents("somebucket", "base/path").toString(); + private static PackageAttributes makePackageAttributes(File file, String overridePackageName) { + return PackageUtil.createPackageAttributes(file, STAGING_PATH, overridePackageName); + } + + @Test + public void testFileWithExtensionPackageNamingAndSize() throws Exception { + String contents = "This is a test!"; + File tmpFile = makeFileWithContents("file.txt", contents); + PackageAttributes attr = makePackageAttributes(tmpFile, null); + DataflowPackage target = attr.getDataflowPackage(); + + assertThat(target.getName(), RegexMatcher.matches("file-" + HASH_PATTERN + ".txt")); + assertThat(target.getLocation(), equalTo(STAGING_PATH + '/' + target.getName())); + assertThat(attr.getSize(), equalTo((long) contents.length())); + } + + @Test + public void testPackageNamingWithFileNoExtension() throws Exception { + File tmpFile = makeFileWithContents("file", "This is a test!"); + DataflowPackage target = makePackageAttributes(tmpFile, null).getDataflowPackage(); + + assertThat(target.getName(), RegexMatcher.matches("file-" + HASH_PATTERN)); + 
assertThat(target.getLocation(), equalTo(STAGING_PATH + '/' + target.getName())); + } + + @Test + public void testPackageNamingWithDirectory() throws Exception { + File tmpDirectory = tmpFolder.newFolder("folder"); + DataflowPackage target = makePackageAttributes(tmpDirectory, null).getDataflowPackage(); + + assertThat(target.getName(), RegexMatcher.matches("folder-" + HASH_PATTERN + ".jar")); + assertThat(target.getLocation(), equalTo(STAGING_PATH + '/' + target.getName())); + } + + @Test + public void testPackageNamingWithFilesHavingSameContentsAndSameNames() throws Exception { + File tmpDirectory1 = tmpFolder.newFolder("folder1", "folderA"); + makeFileWithContents("folder1/folderA/sameName", "This is a test!"); + DataflowPackage target1 = makePackageAttributes(tmpDirectory1, null).getDataflowPackage(); + + File tmpDirectory2 = tmpFolder.newFolder("folder2", "folderA"); + makeFileWithContents("folder2/folderA/sameName", "This is a test!"); + DataflowPackage target2 = makePackageAttributes(tmpDirectory2, null).getDataflowPackage(); + + assertEquals(target1.getName(), target2.getName()); + assertEquals(target1.getLocation(), target2.getLocation()); + } + + @Test + public void testPackageNamingWithFilesHavingSameContentsButDifferentNames() throws Exception { + File tmpDirectory1 = tmpFolder.newFolder("folder1", "folderA"); + makeFileWithContents("folder1/folderA/uniqueName1", "This is a test!"); + DataflowPackage target1 = makePackageAttributes(tmpDirectory1, null).getDataflowPackage(); + + File tmpDirectory2 = tmpFolder.newFolder("folder2", "folderA"); + makeFileWithContents("folder2/folderA/uniqueName2", "This is a test!"); + DataflowPackage target2 = makePackageAttributes(tmpDirectory2, null).getDataflowPackage(); + + assertNotEquals(target1.getName(), target2.getName()); + assertNotEquals(target1.getLocation(), target2.getLocation()); + } + + @Test + public void testPackageNamingWithDirectoriesHavingSameContentsButDifferentNames() + throws Exception { + File tmpDirectory1 = tmpFolder.newFolder("folder1", "folderA"); + tmpFolder.newFolder("folder1", "folderA", "uniqueName1"); + DataflowPackage target1 = makePackageAttributes(tmpDirectory1, null).getDataflowPackage(); + + File tmpDirectory2 = tmpFolder.newFolder("folder2", "folderA"); + tmpFolder.newFolder("folder2", "folderA", "uniqueName2"); + DataflowPackage target2 = makePackageAttributes(tmpDirectory2, null).getDataflowPackage(); + + assertNotEquals(target1.getName(), target2.getName()); + assertNotEquals(target1.getLocation(), target2.getLocation()); + } + + @Test + public void testPackageUploadWithLargeClasspathLogsWarning() throws Exception { + File tmpFile = makeFileWithContents("file.txt", "This is a test!"); + // all files will be present and cached so no upload needed. 
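+    // Stubbing fileSize to report the local file's own length makes every staged copy look
+    // current, so this test exercises only the "too many classpath elements" warning below
+    // without triggering any GCS writes.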
+ when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(tmpFile.length()); + + List classpathElements = Lists.newLinkedList(); + for (int i = 0; i < 1005; ++i) { + String eltName = "element" + i; + classpathElements.add(eltName + '=' + tmpFile.getAbsolutePath()); + } + + PackageUtil.stageClasspathElements(classpathElements, STAGING_PATH); + + logged.verifyWarn("Your classpath contains 1005 elements, which Google Cloud Dataflow"); + } + + @Test + public void testPackageUploadWithFileSucceeds() throws Exception { + Pipe pipe = Pipe.open(); + String contents = "This is a test!"; + File tmpFile = makeFileWithContents("file.txt", contents); + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + List targets = PackageUtil.stageClasspathElements( + ImmutableList.of(tmpFile.getAbsolutePath()), STAGING_PATH); + DataflowPackage target = Iterables.getOnlyElement(targets); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + assertThat(target.getName(), RegexMatcher.matches("file-" + HASH_PATTERN + ".txt")); + assertThat(target.getLocation(), equalTo(STAGING_PATH + '/' + target.getName())); + assertThat(new LineReader(Channels.newReader(pipe.source(), "UTF-8")).readLine(), + equalTo(contents)); + } + + @Test + public void testPackageUploadWithDirectorySucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpDirectory = tmpFolder.newFolder("folder"); + tmpFolder.newFolder("folder", "empty_directory"); + tmpFolder.newFolder("folder", "directory"); + makeFileWithContents("folder/file.txt", "This is a test!"); + makeFileWithContents("folder/directory/file.txt", "This is also a test!"); + + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + PackageUtil.stageClasspathElements( + ImmutableList.of(tmpDirectory.getAbsolutePath()), STAGING_PATH); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + ZipInputStream inputStream = new ZipInputStream(Channels.newInputStream(pipe.source())); + List zipEntryNames = new ArrayList<>(); + for (ZipEntry entry = inputStream.getNextEntry(); entry != null; + entry = inputStream.getNextEntry()) { + zipEntryNames.add(entry.getName()); + } + + assertThat(zipEntryNames, + containsInAnyOrder("directory/file.txt", "empty_directory/", "file.txt")); + } + + @Test + public void testPackageUploadWithEmptyDirectorySucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpDirectory = tmpFolder.newFolder("folder"); + + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + List targets = PackageUtil.stageClasspathElements( + ImmutableList.of(tmpDirectory.getAbsolutePath()), STAGING_PATH); + DataflowPackage target = Iterables.getOnlyElement(targets); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + assertThat(target.getName(), RegexMatcher.matches("folder-" + HASH_PATTERN + ".jar")); + assertThat(target.getLocation(), equalTo(STAGING_PATH 
+ '/' + target.getName())); + assertNull(new ZipInputStream(Channels.newInputStream(pipe.source())).getNextEntry()); + } + + @Test(expected = RuntimeException.class) + public void testPackageUploadFailsWhenIOExceptionThrown() throws Exception { + File tmpFile = makeFileWithContents("file.txt", "This is a test!"); + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())) + .thenThrow(new IOException("Fake Exception: Upload error")); + + try { + PackageUtil.stageClasspathElements( + ImmutableList.of(tmpFile.getAbsolutePath()), + STAGING_PATH, fastNanoClockAndSleeper); + } finally { + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil, times(5)).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + } + + @Test + public void testPackageUploadFailsWithPermissionsErrorGivesDetailedMessage() throws Exception { + File tmpFile = makeFileWithContents("file.txt", "This is a test!"); + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())) + .thenThrow(new IOException("Failed to write to GCS path " + STAGING_PATH, + googleJsonResponseException( + HttpStatusCodes.STATUS_CODE_FORBIDDEN, "Permission denied", "Test message"))); + + try { + PackageUtil.stageClasspathElements( + ImmutableList.of(tmpFile.getAbsolutePath()), + STAGING_PATH, fastNanoClockAndSleeper); + fail("Expected RuntimeException"); + } catch (RuntimeException e) { + assertTrue("Expected IOException containing detailed message.", + e.getCause() instanceof IOException); + assertThat(e.getCause().getMessage(), + Matchers.allOf( + Matchers.containsString("Uploaded failed due to permissions error"), + Matchers.containsString( + "Stale credentials can be resolved by executing 'gcloud auth login'"))); + } finally { + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + } + + @Test + public void testPackageUploadEventuallySucceeds() throws Exception { + Pipe pipe = Pipe.open(); + File tmpFile = makeFileWithContents("file.txt", "This is a test!"); + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())) + .thenThrow(new IOException("Fake Exception: 410 Gone")) // First attempt fails + .thenReturn(pipe.sink()); // second attempt succeeds + + try { + PackageUtil.stageClasspathElements( + ImmutableList.of(tmpFile.getAbsolutePath()), + STAGING_PATH, + fastNanoClockAndSleeper); + } finally { + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil, times(2)).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + } + + @Test + public void testPackageUploadIsSkippedWhenFileAlreadyExists() throws Exception { + File tmpFile = makeFileWithContents("file.txt", "This is a test!"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(tmpFile.length()); + + PackageUtil.stageClasspathElements( + ImmutableList.of(tmpFile.getAbsolutePath()), STAGING_PATH); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verifyNoMoreInteractions(mockGcsUtil); + } + + @Test + public void testPackageUploadIsNotSkippedWhenSizesAreDifferent() throws Exception { + Pipe pipe = Pipe.open(); + File tmpDirectory = tmpFolder.newFolder("folder"); + 
tmpFolder.newFolder("folder", "empty_directory"); + tmpFolder.newFolder("folder", "directory"); + makeFileWithContents("folder/file.txt", "This is a test!"); + makeFileWithContents("folder/directory/file.txt", "This is also a test!"); + when(mockGcsUtil.fileSize(any(GcsPath.class))).thenReturn(Long.MAX_VALUE); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + PackageUtil.stageClasspathElements( + ImmutableList.of(tmpDirectory.getAbsolutePath()), STAGING_PATH); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + } + + @Test + public void testPackageUploadWithExplicitPackageName() throws Exception { + Pipe pipe = Pipe.open(); + File tmpFile = makeFileWithContents("file.txt", "This is a test!"); + final String overriddenName = "alias.txt"; + + when(mockGcsUtil.fileSize(any(GcsPath.class))) + .thenThrow(new FileNotFoundException("some/path")); + when(mockGcsUtil.create(any(GcsPath.class), anyString())).thenReturn(pipe.sink()); + + List targets = PackageUtil.stageClasspathElements( + ImmutableList.of(overriddenName + "=" + tmpFile.getAbsolutePath()), STAGING_PATH); + DataflowPackage target = Iterables.getOnlyElement(targets); + + verify(mockGcsUtil).fileSize(any(GcsPath.class)); + verify(mockGcsUtil).create(any(GcsPath.class), anyString()); + verifyNoMoreInteractions(mockGcsUtil); + + assertThat(target.getName(), equalTo(overriddenName)); + assertThat(target.getLocation(), + RegexMatcher.matches(STAGING_PATH + "/file-" + HASH_PATTERN + ".txt")); + } + + @Test + public void testPackageUploadIsSkippedWithNonExistentResource() throws Exception { + String nonExistentFile = + IOChannelUtils.resolve(tmpFolder.getRoot().getPath(), "non-existent-file"); + assertEquals(Collections.EMPTY_LIST, PackageUtil.stageClasspathElements( + ImmutableList.of(nonExistentFile), STAGING_PATH)); + } + + /** + * Builds a fake GoogleJsonResponseException for testing API error handling. 
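+   * The mock transport below serves a JSON error payload assembled from the given status,
+   * reason, and message; {@code GoogleJsonResponseException.from} then parses it back into a
+   * typed exception, mirroring how a real API error would surface to the caller.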
+ */ + private static GoogleJsonResponseException googleJsonResponseException( + final int status, final String reason, final String message) throws IOException { + final JsonFactory jsonFactory = new JacksonFactory(); + HttpTransport transport = new MockHttpTransport() { + @Override + public LowLevelHttpRequest buildRequest(String method, String url) throws IOException { + ErrorInfo errorInfo = new ErrorInfo(); + errorInfo.setReason(reason); + errorInfo.setMessage(message); + errorInfo.setFactory(jsonFactory); + GenericJson error = new GenericJson(); + error.set("code", status); + error.set("errors", Arrays.asList(errorInfo)); + error.setFactory(jsonFactory); + GenericJson errorResponse = new GenericJson(); + errorResponse.set("error", error); + errorResponse.setFactory(jsonFactory); + return new MockLowLevelHttpRequest().setResponse( + new MockLowLevelHttpResponse().setContent(errorResponse.toPrettyString()) + .setContentType(Json.MEDIA_TYPE).setStatusCode(status)); + } + }; + HttpRequest request = + transport.createRequestFactory().buildGetRequest(HttpTesting.SIMPLE_GENERIC_URL); + request.setThrowExceptionOnExecuteError(false); + HttpResponse response = request.execute(); + return GoogleJsonResponseException.from(jsonFactory, response); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RandomAccessDataTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RandomAccessDataTest.java new file mode 100644 index 000000000000..6d8830506878 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RandomAccessDataTest.java @@ -0,0 +1,205 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.coders.Coder.Context; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.util.RandomAccessData.RandomAccessDataCoder; +import com.google.common.primitives.UnsignedBytes; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.util.Arrays; + +/** + * Tests for {@link RandomAccessData}. 
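+ * Covers the {@link RandomAccessDataCoder}, the unsigned lexicographical comparator, equality
+ * and hashing, resizing, stream views, and increment/POSITIVE_INFINITY behavior.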
+ */ +@RunWith(JUnit4.class) +public class RandomAccessDataTest { + private static final byte[] TEST_DATA_A = new byte[]{ 0x01, 0x02, 0x03 }; + private static final byte[] TEST_DATA_B = new byte[]{ 0x06, 0x05, 0x04, 0x03 }; + private static final byte[] TEST_DATA_C = new byte[]{ 0x06, 0x05, 0x03, 0x03 }; + + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testCoder() throws Exception { + RandomAccessData streamA = new RandomAccessData(); + streamA.asOutputStream().write(TEST_DATA_A); + RandomAccessData streamB = new RandomAccessData(); + streamB.asOutputStream().write(TEST_DATA_A); + CoderProperties.coderDecodeEncodeEqual(RandomAccessDataCoder.of(), streamA); + CoderProperties.coderDeterministic(RandomAccessDataCoder.of(), streamA, streamB); + CoderProperties.coderConsistentWithEquals(RandomAccessDataCoder.of(), streamA, streamB); + CoderProperties.coderSerializable(RandomAccessDataCoder.of()); + CoderProperties.structuralValueConsistentWithEquals( + RandomAccessDataCoder.of(), streamA, streamB); + assertTrue(RandomAccessDataCoder.of().isRegisterByteSizeObserverCheap(streamA, Context.NESTED)); + assertTrue(RandomAccessDataCoder.of().isRegisterByteSizeObserverCheap(streamA, Context.OUTER)); + assertEquals(4, RandomAccessDataCoder.of().getEncodedElementByteSize(streamA, Context.NESTED)); + assertEquals(3, RandomAccessDataCoder.of().getEncodedElementByteSize(streamA, Context.OUTER)); + } + + @Test + public void testCoderWithPositiveInfinityIsError() throws Exception { + expectedException.expect(CoderException.class); + expectedException.expectMessage("Positive infinity can not be encoded"); + RandomAccessDataCoder.of().encode( + RandomAccessData.POSITIVE_INFINITY, new ByteArrayOutputStream(), Context.OUTER); + } + + @Test + public void testLexicographicalComparator() throws Exception { + RandomAccessData streamA = new RandomAccessData(); + streamA.asOutputStream().write(TEST_DATA_A); + RandomAccessData streamB = new RandomAccessData(); + streamB.asOutputStream().write(TEST_DATA_B); + RandomAccessData streamC = new RandomAccessData(); + streamC.asOutputStream().write(TEST_DATA_C); + assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + streamA, streamB) < 0); + assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + streamB, streamA) > 0); + assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + streamB, streamB) == 0); + // Check common prefix length. + assertEquals(2, RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.commonPrefixLength( + streamB, streamC)); + // Check that we honor the start offset. + assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + streamB, streamC, 3) == 0); + // Test positive infinity comparisons. 
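+    // POSITIVE_INFINITY should order after any finite stream and compare equal only to itself.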
+ assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + streamA, RandomAccessData.POSITIVE_INFINITY) < 0); + assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + RandomAccessData.POSITIVE_INFINITY, RandomAccessData.POSITIVE_INFINITY) == 0); + assertTrue(RandomAccessData.UNSIGNED_LEXICOGRAPHICAL_COMPARATOR.compare( + RandomAccessData.POSITIVE_INFINITY, streamA) > 0); + } + + @Test + public void testEqualsAndHashCode() throws Exception { + // Test that equality by reference works + RandomAccessData streamA = new RandomAccessData(); + streamA.asOutputStream().write(TEST_DATA_A); + assertEquals(streamA, streamA); + assertEquals(streamA.hashCode(), streamA.hashCode()); + + // Test different objects containing the same data are the same + RandomAccessData streamACopy = new RandomAccessData(); + streamACopy.asOutputStream().write(TEST_DATA_A); + assertEquals(streamA, streamACopy); + assertEquals(streamA.hashCode(), streamACopy.hashCode()); + + // Test same length streams with different data differ + RandomAccessData streamB = new RandomAccessData(); + streamB.asOutputStream().write(new byte[]{ 0x01, 0x02, 0x04 }); + assertNotEquals(streamA, streamB); + assertNotEquals(streamA.hashCode(), streamB.hashCode()); + + // Test different length streams differ + streamB.asOutputStream().write(TEST_DATA_B); + assertNotEquals(streamA, streamB); + assertNotEquals(streamA.hashCode(), streamB.hashCode()); + } + + @Test + public void testResetTo() throws Exception { + RandomAccessData stream = new RandomAccessData(); + stream.asOutputStream().write(TEST_DATA_A); + stream.resetTo(1); + assertEquals(1, stream.size()); + stream.asOutputStream().write(TEST_DATA_A); + assertArrayEquals(new byte[]{ 0x01, 0x01, 0x02, 0x03 }, + Arrays.copyOf(stream.array(), stream.size())); + } + + @Test + public void testAsInputStream() throws Exception { + RandomAccessData stream = new RandomAccessData(); + stream.asOutputStream().write(TEST_DATA_A); + InputStream in = stream.asInputStream(1, 1); + assertEquals(0x02, in.read()); + assertEquals(-1, in.read()); + in.close(); + } + + @Test + public void testReadFrom() throws Exception { + ByteArrayInputStream bais = new ByteArrayInputStream(TEST_DATA_A); + RandomAccessData stream = new RandomAccessData(); + stream.readFrom(bais, 3, 2); + assertArrayEquals(new byte[]{ 0x00, 0x00, 0x00, 0x01, 0x02 }, + Arrays.copyOf(stream.array(), stream.size())); + bais.close(); + } + + @Test + public void testWriteTo() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + RandomAccessData stream = new RandomAccessData(); + stream.asOutputStream().write(TEST_DATA_B); + stream.writeTo(baos, 1, 2); + assertArrayEquals(new byte[]{ 0x05, 0x04 }, baos.toByteArray()); + baos.close(); + } + + @Test + public void testThatRandomAccessDataGrowsWhenResettingToPositionBeyondEnd() throws Exception { + RandomAccessData stream = new RandomAccessData(0); + assertArrayEquals(new byte[0], stream.array()); + stream.resetTo(3); // force resize + assertArrayEquals(new byte[]{ 0x00, 0x00, 0x00 }, stream.array()); + } + + @Test + public void testThatRandomAccessDataGrowsWhenReading() throws Exception { + RandomAccessData stream = new RandomAccessData(0); + assertArrayEquals(new byte[0], stream.array()); + stream.readFrom(new ByteArrayInputStream(TEST_DATA_A), 0, TEST_DATA_A.length); + assertArrayEquals(TEST_DATA_A, + Arrays.copyOf(stream.array(), TEST_DATA_A.length)); + } + + @Test + public void testIncrement() throws Exception { + assertEquals(new 
RandomAccessData(new byte[]{ 0x00, 0x01 }), + new RandomAccessData(new byte[]{ 0x00, 0x00 }).increment()); + assertEquals(new RandomAccessData(new byte[]{ 0x01, UnsignedBytes.MAX_VALUE }), + new RandomAccessData(new byte[]{ 0x00, UnsignedBytes.MAX_VALUE }).increment()); + + // Test for positive infinity + assertSame(RandomAccessData.POSITIVE_INFINITY, new RandomAccessData(new byte[0]).increment()); + assertSame(RandomAccessData.POSITIVE_INFINITY, + new RandomAccessData(new byte[]{ UnsignedBytes.MAX_VALUE }).increment()); + assertSame(RandomAccessData.POSITIVE_INFINITY, RandomAccessData.POSITIVE_INFINITY.increment()); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java new file mode 100644 index 000000000000..c85b1ca4a5b6 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java @@ -0,0 +1,1011 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.WindowMatchers.isSingleWindowedValue; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.cloud.dataflow.sdk.WindowMatchers; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterEach; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterFirst; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterPane; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import 
com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Repeatedly;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions;
+import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior;
+import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.TimestampedValue;
+import com.google.common.base.Preconditions;
+
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Matchers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Tests for {@link ReduceFnRunner}. These tests instantiate a full "stack" of
+ * {@link ReduceFnRunner} with enclosed {@link ReduceFn}, down to the installed {@link Trigger}
+ * (sometimes mocked). They proceed by injecting elements and advancing watermark and
+ * processing time, then verifying produced panes and counters.
+ */
+@RunWith(JUnit4.class)
+public class ReduceFnRunnerTest {
+  @Mock private SideInputReader mockSideInputReader;
+  private Trigger<IntervalWindow> mockTrigger;
+  private PCollectionView<Integer> mockView;
+
+  private IntervalWindow firstWindow;
+
+  private static Trigger<IntervalWindow>.TriggerContext anyTriggerContext() {
+    return Mockito.<Trigger<IntervalWindow>.TriggerContext>any();
+  }
+  private static Trigger<IntervalWindow>.OnElementContext anyElementContext() {
+    return Mockito.<Trigger<IntervalWindow>.OnElementContext>any();
+  }
+
+  @Before
+  public void setUp() {
+    MockitoAnnotations.initMocks(this);
+
+    @SuppressWarnings("unchecked")
+    Trigger<IntervalWindow> mockTriggerUnchecked =
+        mock(Trigger.class, withSettings().serializable());
+    mockTrigger = mockTriggerUnchecked;
+    when(mockTrigger.buildTrigger()).thenReturn(mockTrigger);
+
+    @SuppressWarnings("unchecked")
+    PCollectionView<Integer> mockViewUnchecked =
+        mock(PCollectionView.class, withSettings().serializable());
+    mockView = mockViewUnchecked;
+    firstWindow = new IntervalWindow(new Instant(0), new Instant(10));
+  }
+
+  private void injectElement(ReduceFnTester<Integer, ?, IntervalWindow> tester, int element)
+      throws Exception {
+    doNothing().when(mockTrigger).onElement(anyElementContext());
+    tester.injectElements(TimestampedValue.of(element, new Instant(element)));
+  }
+
+  private void triggerShouldFinish(Trigger<IntervalWindow> mockTrigger) throws Exception {
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocation) throws Exception {
+        @SuppressWarnings("unchecked")
+        Trigger<IntervalWindow>.TriggerContext context =
+            (Trigger<IntervalWindow>.TriggerContext) invocation.getArguments()[0];
+        context.trigger().setFinished(true);
+        return null;
+      }
+    })
+    .when(mockTrigger).onFire(anyTriggerContext());
+  }
+
+  @Test
+  public void testOnElementBufferingDiscarding() throws Exception {
+    // Test basic execution of a trigger using a non-combining window set and discarding mode.
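+    // In discarding mode each fired pane should contain only the elements that arrived since
+    // the previous firing, so we expect {1, 2} followed by {3}.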
+ ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.DISCARDING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + // Pane of {1, 2} + injectElement(tester, 1); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 2); + assertThat(tester.extractOutput(), + contains(isSingleWindowedValue(containsInAnyOrder(1, 2), 1, 0, 10))); + + // Pane of just 3, and finish + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 3); + assertThat(tester.extractOutput(), + contains(isSingleWindowedValue(containsInAnyOrder(3), 3, 0, 10))); + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + + // This element shouldn't be seen, because the trigger has finished + injectElement(tester, 4); + + assertEquals(1, tester.getElementsDroppedDueToClosedWindow()); + } + + @Test + public void testOnElementBufferingAccumulating() throws Exception { + // Test basic execution of a trigger using a non-combining window set and accumulating mode. + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.ACCUMULATING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + injectElement(tester, 1); + + // Fires {1, 2} + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 2); + + // Fires {1, 2, 3} because we are in accumulating mode + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 3); + + // This element shouldn't be seen, because the trigger has finished + injectElement(tester, 4); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue(containsInAnyOrder(1, 2), 1, 0, 10), + isSingleWindowedValue(containsInAnyOrder(1, 2, 3), 3, 0, 10))); + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + } + + @Test + public void testOnElementCombiningDiscarding() throws Exception { + // Test basic execution of a trigger using a non-combining window set and discarding mode. + ReduceFnTester tester = ReduceFnTester.combining( + FixedWindows.of(Duration.millis(10)), mockTrigger, AccumulationMode.DISCARDING_FIRED_PANES, + new Sum.SumIntegerFn().asKeyedFn(), VarIntCoder.of(), Duration.millis(100)); + + injectElement(tester, 2); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 3); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 4); + + // This element shouldn't be seen, because the trigger has finished + injectElement(tester, 6); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue(equalTo(5), 2, 0, 10), + isSingleWindowedValue(equalTo(4), 4, 0, 10))); + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + } + + @Test + public void testOnElementCombiningAccumulating() throws Exception { + // Test basic execution of a trigger using a non-combining window set and accumulating mode. 
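+    // The combiner is Sum.SumIntegerFn, so accumulating mode re-emits the running sum:
+    // the first pane is 1 + 2 = 3 and the second is 1 + 2 + 3 = 6.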
+ ReduceFnTester tester = + ReduceFnTester.combining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.ACCUMULATING_FIRED_PANES, new Sum.SumIntegerFn().asKeyedFn(), + VarIntCoder.of(), Duration.millis(100)); + + injectElement(tester, 1); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 2); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 3); + + // This element shouldn't be seen, because the trigger has finished + injectElement(tester, 4); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue(equalTo(3), 1, 0, 10), + isSingleWindowedValue(equalTo(6), 3, 0, 10))); + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + } + + @Test + public void testOnElementCombiningWithContext() throws Exception { + Integer expectedValue = 5; + WindowingStrategy windowingStrategy = WindowingStrategy + .of(FixedWindows.of(Duration.millis(10))) + .withTrigger(mockTrigger) + .withMode(AccumulationMode.DISCARDING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)); + + TestOptions options = PipelineOptionsFactory.as(TestOptions.class); + options.setValue(5); + + when(mockSideInputReader.contains(Matchers.>any())).thenReturn(true); + when(mockSideInputReader.get( + Matchers.>any(), any(BoundedWindow.class))).thenReturn(5); + + @SuppressWarnings({"rawtypes", "unchecked", "unused"}) + Object suppressWarningsVar = when(mockView.getWindowingStrategyInternal()) + .thenReturn((WindowingStrategy) windowingStrategy); + + SumAndVerifyContextFn combineFn = new SumAndVerifyContextFn(mockView, expectedValue); + // Test basic execution of a trigger using a non-combining window set and discarding mode. + ReduceFnTester tester = ReduceFnTester.combining( + windowingStrategy, combineFn.asKeyedFn(), + VarIntCoder.of(), options, mockSideInputReader); + + injectElement(tester, 2); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 3); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 4); + + // This element shouldn't be seen, because the trigger has finished + injectElement(tester, 6); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue(equalTo(5), 2, 0, 10), + isSingleWindowedValue(equalTo(4), 4, 0, 10))); + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + } + + @Test + public void testWatermarkHoldAndLateData() throws Exception { + // Test handling of late data. Specifically, ensure the watermark hold is correct. + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.ACCUMULATING_FIRED_PANES, Duration.millis(10), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + // Input watermark -> null + assertEquals(null, tester.getWatermarkHold()); + assertEquals(null, tester.getOutputWatermark()); + + // All on time data, verify watermark hold. 
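+    // The hold should sit at the earliest buffered element timestamp (element 1 below).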
+ injectElement(tester, 1); + injectElement(tester, 3); + assertEquals(new Instant(1), tester.getWatermarkHold()); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 2); + List>> output = tester.extractOutput(); + assertThat(output, contains( + isSingleWindowedValue(containsInAnyOrder(1, 2, 3), + 1, // timestamp + 0, // window start + 10))); // window end + assertThat(output.get(0).getPane(), + equalTo(PaneInfo.createPane(true, false, Timing.EARLY, 0, -1))); + + // Holding for the end-of-window transition. + assertEquals(new Instant(9), tester.getWatermarkHold()); + // Nothing dropped. + assertEquals(0, tester.getElementsDroppedDueToClosedWindow()); + + // Input watermark -> 4, output watermark should advance that far as well + tester.advanceInputWatermark(new Instant(4)); + assertEquals(new Instant(4), tester.getOutputWatermark()); + + // Some late, some on time. Verify that we only hold to the minimum of on-time. + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(false); + tester.advanceInputWatermark(new Instant(4)); + injectElement(tester, 2); + injectElement(tester, 3); + assertEquals(new Instant(9), tester.getWatermarkHold()); + injectElement(tester, 5); + assertEquals(new Instant(5), tester.getWatermarkHold()); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 4); + output = tester.extractOutput(); + assertThat(output, + contains( + isSingleWindowedValue(containsInAnyOrder( + 1, 2, 3, // earlier firing + 2, 3, 4, 5), // new elements + 4, // timestamp + 0, // window start + 10))); // window end + assertThat(output.get(0).getPane(), + equalTo(PaneInfo.createPane(false, false, Timing.EARLY, 1, -1))); + + // All late -- output at end of window timestamp. + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(false); + tester.advanceInputWatermark(new Instant(8)); + injectElement(tester, 6); + injectElement(tester, 5); + assertEquals(new Instant(9), tester.getWatermarkHold()); + injectElement(tester, 4); + + // Fire the ON_TIME pane + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + tester.advanceInputWatermark(new Instant(10)); + + // Output time is end of the window, because all the new data was late, but the pane + // is the ON_TIME pane. + output = tester.extractOutput(); + assertThat(output, + contains(isSingleWindowedValue( + containsInAnyOrder(1, 2, 3, // earlier firing + 2, 3, 4, 5, // earlier firing + 4, 5, 6), // new elements + 9, // timestamp + 0, // window start + 10))); // window end + assertThat(output.get(0).getPane(), + equalTo(PaneInfo.createPane(false, false, Timing.ON_TIME, 2, 0))); + + // This is "pending" at the time the watermark makes it way-late. + // Because we're about to expire the window, we output it. + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(false); + injectElement(tester, 8); + assertEquals(0, tester.getElementsDroppedDueToClosedWindow()); + + // Exceed the GC limit, triggering the last pane to be fired + tester.advanceInputWatermark(new Instant(50)); + output = tester.extractOutput(); + // Output time is still end of the window, because the new data (8) was behind + // the output watermark. 
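+    // This closing pane is LATE and marked as the last pane, with pane index 3 and
+    // on-time index 1.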
+ assertThat(output, + contains(isSingleWindowedValue( + containsInAnyOrder(1, 2, 3, // earlier firing + 2, 3, 4, 5, // earlier firing + 4, 5, 6, // earlier firing + 8), // new element prior to window becoming expired + 9, // timestamp + 0, // window start + 10))); // window end + assertThat( + output.get(0).getPane(), + equalTo(PaneInfo.createPane(false, true, Timing.LATE, 3, 1))); + assertEquals(new Instant(50), tester.getOutputWatermark()); + assertEquals(null, tester.getWatermarkHold()); + + // Late timers are ignored + tester.fireTimer(new IntervalWindow(new Instant(0), new Instant(10)), new Instant(12), + TimeDomain.EVENT_TIME); + + // And because we're past the end of window + allowed lateness, everything should be cleaned up. + assertFalse(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(); + } + + @Test + public void dontSetHoldIfTooLateForEndOfWindowTimer() throws Exception { + // Make sure holds are only set if they are accompanied by an end-of-window timer. + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.ACCUMULATING_FIRED_PANES, Duration.millis(10), + ClosingBehavior.FIRE_ALWAYS); + tester.setAutoAdvanceOutputWatermark(false); + + // Case: Unobservably late + tester.advanceInputWatermark(new Instant(15)); + tester.advanceOutputWatermark(new Instant(11)); + injectElement(tester, 14); + // Hold was applied, waiting for end-of-window timer. + assertEquals(new Instant(14), tester.getWatermarkHold()); + assertEquals(new Instant(19), tester.getNextTimer(TimeDomain.EVENT_TIME)); + + // Trigger the end-of-window timer. + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + tester.advanceInputWatermark(new Instant(20)); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(false); + // Hold has been replaced with garbage collection hold. Waiting for garbage collection. + assertEquals(new Instant(29), tester.getWatermarkHold()); + assertEquals(new Instant(29), tester.getNextTimer(TimeDomain.EVENT_TIME)); + + // Case: Maybe late 1 + injectElement(tester, 13); + // No change to hold or timers. + assertEquals(new Instant(29), tester.getWatermarkHold()); + assertEquals(new Instant(29), tester.getNextTimer(TimeDomain.EVENT_TIME)); + + // Trigger the garbage collection timer. + tester.advanceInputWatermark(new Instant(30)); + + // Everything should be cleaned up. 
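+    // After garbage collection the window should no longer be marked finished and only global
+    // state (no per-window finished sets) should remain.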
+ assertFalse(tester.isMarkedFinished(new IntervalWindow(new Instant(10), new Instant(20)))); + tester.assertHasOnlyGlobalAndFinishedSetsFor(); + } + + @Test + public void testPaneInfoAllStates() throws Exception { + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.DISCARDING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + tester.advanceInputWatermark(new Instant(0)); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 1); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.EARLY)))); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 2); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.EARLY, 1, -1)))); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(false); + tester.advanceInputWatermark(new Instant(15)); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 3); + assertThat(tester.extractOutput(), contains( + // This is late, because the trigger wasn't waiting for AfterWatermark + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.EARLY, 2, -1)))); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 4); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.EARLY, 3, -1)))); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 5); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.EARLY, 4, -1)))); + } + + @Test + public void testPaneInfoAllStatesAfterWatermark() throws Exception { + ReduceFnTester, IntervalWindow> tester = ReduceFnTester.nonCombining( + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))) + .withTrigger(Repeatedly.forever(AfterFirst.of( + AfterPane.elementCountAtLeast(2), + AfterWatermark.pastEndOfWindow()))) + .withMode(AccumulationMode.DISCARDING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)) + .withClosingBehavior(ClosingBehavior.FIRE_ALWAYS)); + + tester.advanceInputWatermark(new Instant(0)); + tester.injectElements( + TimestampedValue.of(1, new Instant(1)), TimestampedValue.of(2, new Instant(2))); + + List>> output = tester.extractOutput(); + assertThat( + output, + contains(WindowMatchers.valueWithPaneInfo( + PaneInfo.createPane(true, false, Timing.EARLY, 0, -1)))); + assertThat( + output, + contains( + WindowMatchers.isSingleWindowedValue(containsInAnyOrder(1, 2), 1, 0, 10))); + + tester.advanceInputWatermark(new Instant(50)); + + // We should get the ON_TIME pane even though it is empty, + // because we have an AfterWatermark.pastEndOfWindow() trigger. + output = tester.extractOutput(); + assertThat( + output, + contains(WindowMatchers.valueWithPaneInfo( + PaneInfo.createPane(false, false, Timing.ON_TIME, 1, 0)))); + assertThat( + output, + contains( + WindowMatchers.isSingleWindowedValue(emptyIterable(), 9, 0, 10))); + + // We should get the final pane even though it is empty. 
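+    // Advancing past end-of-window plus allowed lateness (10ms + 100ms) closes the window, and
+    // ClosingBehavior.FIRE_ALWAYS forces this last firing even with no new data.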
+ tester.advanceInputWatermark(new Instant(150)); + output = tester.extractOutput(); + assertThat( + output, + contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.LATE, 2, 1)))); + assertThat( + output, + contains( + WindowMatchers.isSingleWindowedValue(emptyIterable(), 9, 0, 10))); + } + + @Test + public void testPaneInfoAllStatesAfterWatermarkAccumulating() throws Exception { + ReduceFnTester, IntervalWindow> tester = ReduceFnTester.nonCombining( + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))) + .withTrigger(Repeatedly.forever(AfterFirst.of( + AfterPane.elementCountAtLeast(2), + AfterWatermark.pastEndOfWindow()))) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)) + .withClosingBehavior(ClosingBehavior.FIRE_ALWAYS)); + + tester.advanceInputWatermark(new Instant(0)); + tester.injectElements( + TimestampedValue.of(1, new Instant(1)), TimestampedValue.of(2, new Instant(2))); + + List>> output = tester.extractOutput(); + assertThat( + output, + contains(WindowMatchers.valueWithPaneInfo( + PaneInfo.createPane(true, false, Timing.EARLY, 0, -1)))); + assertThat( + output, + contains( + WindowMatchers.isSingleWindowedValue(containsInAnyOrder(1, 2), 1, 0, 10))); + + tester.advanceInputWatermark(new Instant(50)); + + // We should get the ON_TIME pane even though it is empty, + // because we have an AfterWatermark.pastEndOfWindow() trigger. + output = tester.extractOutput(); + assertThat( + output, + contains(WindowMatchers.valueWithPaneInfo( + PaneInfo.createPane(false, false, Timing.ON_TIME, 1, 0)))); + assertThat( + output, + contains( + WindowMatchers.isSingleWindowedValue(containsInAnyOrder(1, 2), 9, 0, 10))); + + // We should get the final pane even though it is empty. 
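+    // "Empty" here means no new input since the last firing; in accumulating mode the closing
+    // pane still replays the buffered {1, 2}.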
+ tester.advanceInputWatermark(new Instant(150)); + output = tester.extractOutput(); + assertThat( + output, + contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.LATE, 2, 1)))); + assertThat( + output, + contains( + WindowMatchers.isSingleWindowedValue(containsInAnyOrder(1, 2), 9, 0, 10))); + } + + @Test + public void testPaneInfoFinalAndOnTime() throws Exception { + ReduceFnTester, IntervalWindow> tester = ReduceFnTester.nonCombining( + WindowingStrategy.of(FixedWindows.of(Duration.millis(10))) + .withTrigger( + Repeatedly.forever(AfterPane.elementCountAtLeast(2)) + .orFinally(AfterWatermark.pastEndOfWindow())) + .withMode(AccumulationMode.DISCARDING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)) + .withClosingBehavior(ClosingBehavior.FIRE_ALWAYS)); + + tester.advanceInputWatermark(new Instant(0)); + + // Should trigger due to element count + tester.injectElements( + TimestampedValue.of(1, new Instant(1)), TimestampedValue.of(2, new Instant(2))); + + assertThat( + tester.extractOutput(), + contains(WindowMatchers.valueWithPaneInfo( + PaneInfo.createPane(true, false, Timing.EARLY, 0, -1)))); + + tester.advanceInputWatermark(new Instant(150)); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.ON_TIME, 1, 0)))); + } + + @Test + public void testPaneInfoSkipToFinish() throws Exception { + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.DISCARDING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + tester.advanceInputWatermark(new Instant(0)); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 1); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, true, Timing.EARLY)))); + } + + @Test + public void testPaneInfoSkipToNonSpeculativeAndFinish() throws Exception { + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.DISCARDING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + tester.advanceInputWatermark(new Instant(15)); + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 1); + assertThat(tester.extractOutput(), contains( + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, true, Timing.LATE)))); + } + + @Test + public void testMergeBeforeFinalizing() throws Exception { + // Verify that we merge windows before producing output so users don't see undesired + // unmerged windows. + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(Sessions.withGapDuration(Duration.millis(10)), mockTrigger, + AccumulationMode.DISCARDING_FIRED_PANES, Duration.millis(0), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + // All on time data, verify watermark hold. 
+ // These two windows should pre-merge immediately to [1, 20) + tester.injectElements( + TimestampedValue.of(1, new Instant(1)), // in [1, 11) + TimestampedValue.of(10, new Instant(10))); // in [10, 20) + + // And this should fire the end-of-window timer + tester.advanceInputWatermark(new Instant(100)); + + List>> output = tester.extractOutput(); + assertThat(output.size(), equalTo(1)); + assertThat(output.get(0), + isSingleWindowedValue(containsInAnyOrder(1, 10), + 1, // timestamp + 1, // window start + 20)); // window end + assertThat( + output.get(0).getPane(), + equalTo(PaneInfo.createPane(true, true, Timing.ON_TIME, 0, 0))); + } + + /** + * Tests that when data is assigned to multiple windows but some of those windows have + * had their triggers finish, then the data is dropped and counted accurately. + */ + @Test + public void testDropDataMultipleWindowsFinishedTrigger() throws Exception { + ReduceFnTester tester = ReduceFnTester.combining( + WindowingStrategy.of( + SlidingWindows.of(Duration.millis(100)).every(Duration.millis(30))) + .withTrigger(AfterWatermark.pastEndOfWindow()) + .withAllowedLateness(Duration.millis(1000)), + new Sum.SumIntegerFn().asKeyedFn(), VarIntCoder.of()); + + tester.injectElements( + // assigned to [-60, 40), [-30, 70), [0, 100) + TimestampedValue.of(10, new Instant(23)), + // assigned to [-30, 70), [0, 100), [30, 130) + TimestampedValue.of(12, new Instant(40))); + + assertEquals(0, tester.getElementsDroppedDueToClosedWindow()); + + tester.advanceInputWatermark(new Instant(70)); + tester.injectElements( + // assigned to [-30, 70), [0, 100), [30, 130) + // but [-30, 70) is closed by the trigger + TimestampedValue.of(14, new Instant(60))); + + assertEquals(1, tester.getElementsDroppedDueToClosedWindow()); + + tester.advanceInputWatermark(new Instant(130)); + // assigned to [-30, 70), [0, 100), [30, 130) + // but they are all closed + tester.injectElements(TimestampedValue.of(16, new Instant(40))); + + assertEquals(4, tester.getElementsDroppedDueToClosedWindow()); + } + + @Test + public void testIdempotentEmptyPanesDiscarding() throws Exception { + // Test uninteresting (empty) panes don't increment the index or otherwise + // modify PaneInfo. + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.DISCARDING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + // Inject a couple of on-time elements and fire at the window end. + injectElement(tester, 1); + injectElement(tester, 2); + tester.advanceInputWatermark(new Instant(12)); + + // Fire the on-time pane + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + tester.fireTimer(firstWindow, new Instant(9), TimeDomain.EVENT_TIME); + + // Fire another timer (with no data, so it's an uninteresting pane that should not be output). + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + tester.fireTimer(firstWindow, new Instant(9), TimeDomain.EVENT_TIME); + + // Finish it off with another datum. + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 3); + + // The intermediate trigger firing shouldn't result in any output. + List>> output = tester.extractOutput(); + assertThat(output.size(), equalTo(2)); + + // The on-time pane is as expected. + assertThat(output.get(0), isSingleWindowedValue(containsInAnyOrder(1, 2), 1, 0, 10)); + + // The late pane has the correct indices. 
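+    // That is, the intervening empty firing did not consume a pane index: the late pane follows
+    // the ON_TIME pane directly with index 1 and non-speculative index 1.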
+ assertThat(output.get(1).getValue(), contains(3)); + assertThat( + output.get(1).getPane(), equalTo(PaneInfo.createPane(false, true, Timing.LATE, 1, 1))); + + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + + assertEquals(0, tester.getElementsDroppedDueToClosedWindow()); + } + + @Test + public void testIdempotentEmptyPanesAccumulating() throws Exception { + // Test uninteresting (empty) panes don't increment the index or otherwise + // modify PaneInfo. + ReduceFnTester, IntervalWindow> tester = + ReduceFnTester.nonCombining(FixedWindows.of(Duration.millis(10)), mockTrigger, + AccumulationMode.ACCUMULATING_FIRED_PANES, Duration.millis(100), + ClosingBehavior.FIRE_IF_NON_EMPTY); + + // Inject a couple of on-time elements and fire at the window end. + injectElement(tester, 1); + injectElement(tester, 2); + tester.advanceInputWatermark(new Instant(12)); + + // Trigger the on-time pane + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + tester.fireTimer(firstWindow, new Instant(9), TimeDomain.EVENT_TIME); + List>> output = tester.extractOutput(); + assertThat(output.size(), equalTo(1)); + assertThat(output.get(0), isSingleWindowedValue(containsInAnyOrder(1, 2), 1, 0, 10)); + assertThat(output.get(0).getPane(), + equalTo(PaneInfo.createPane(true, false, Timing.ON_TIME, 0, 0))); + + // Fire another timer with no data; the empty pane should not be output even though the + // trigger is ready to fire + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + tester.fireTimer(firstWindow, new Instant(9), TimeDomain.EVENT_TIME); + assertThat(tester.extractOutput().size(), equalTo(0)); + + // Finish it off with another datum, which is late + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 3); + output = tester.extractOutput(); + assertThat(output.size(), equalTo(1)); + + // The late pane has the correct indices. + assertThat(output.get(0).getValue(), containsInAnyOrder(1, 2, 3)); + assertThat(output.get(0).getPane(), + equalTo(PaneInfo.createPane(false, true, Timing.LATE, 1, 1))); + + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + + assertEquals(0, tester.getElementsDroppedDueToClosedWindow()); + } + + /** + * Test that we receive an empty on-time pane when an or-finally waiting for the watermark fires. + * Specifically, verify the proper triggerings and pane-info of a typical speculative/on-time/late + * when the on-time pane is empty. 
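+   * (Here "empty" means no new elements arrive between the speculative firing and the end of the
+   * window, so the ON_TIME pane simply restates the accumulated early result.)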
+   */
+  @Test
+  public void testEmptyOnTimeFromOrFinally() throws Exception {
+    ReduceFnTester<Integer, Integer, IntervalWindow> tester =
+        ReduceFnTester.combining(FixedWindows.of(Duration.millis(10)),
+            AfterEach.inOrder(
+                Repeatedly
+                    .forever(
+                        AfterProcessingTime.pastFirstElementInPane().plusDelayOf(
+                            new Duration(5)))
+                    .orFinally(AfterWatermark.pastEndOfWindow()),
+                Repeatedly.forever(
+                    AfterProcessingTime.pastFirstElementInPane().plusDelayOf(
+                        new Duration(25)))),
+            AccumulationMode.ACCUMULATING_FIRED_PANES, new Sum.SumIntegerFn().asKeyedFn(),
+            VarIntCoder.of(), Duration.millis(100));
+
+    tester.advanceInputWatermark(new Instant(0));
+    tester.advanceProcessingTime(new Instant(0));
+
+    // Processing time timer for 5
+    tester.injectElements(
+        TimestampedValue.of(1, new Instant(1)),
+        TimestampedValue.of(1, new Instant(3)),
+        TimestampedValue.of(1, new Instant(7)),
+        TimestampedValue.of(1, new Instant(5)));
+
+    // Should fire early pane
+    tester.advanceProcessingTime(new Instant(6));
+
+    // Should fire empty on time pane
+    tester.advanceInputWatermark(new Instant(11));
+    List<WindowedValue<Integer>> output = tester.extractOutput();
+    assertEquals(2, output.size());
+
+    assertThat(output.get(0), WindowMatchers.isSingleWindowedValue(4, 1, 0, 10));
+    assertThat(output.get(1), WindowMatchers.isSingleWindowedValue(4, 9, 0, 10));
+
+    assertThat(
+        output.get(0),
+        WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.EARLY, 0, -1)));
+    assertThat(
+        output.get(1),
+        WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.ON_TIME, 1, 0)));
+  }
+
+  /**
+   * Tests for processing time firings after the watermark passes the end of the window.
+   * Specifically, verify the proper triggerings and pane-info of a typical speculative/on-time/late
+   * when the on-time pane is non-empty.
+ */ + @Test + public void testProcessingTime() throws Exception { + ReduceFnTester tester = + ReduceFnTester.combining(FixedWindows.of(Duration.millis(10)), + AfterEach.inOrder( + Repeatedly + .forever( + AfterProcessingTime.pastFirstElementInPane().plusDelayOf( + new Duration(5))) + .orFinally(AfterWatermark.pastEndOfWindow()), + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane().plusDelayOf( + new Duration(25)))), + AccumulationMode.ACCUMULATING_FIRED_PANES, new Sum.SumIntegerFn().asKeyedFn(), + VarIntCoder.of(), Duration.millis(100)); + + tester.advanceInputWatermark(new Instant(0)); + tester.advanceProcessingTime(new Instant(0)); + + tester.injectElements(TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(1, new Instant(3)), TimestampedValue.of(1, new Instant(7)), + TimestampedValue.of(1, new Instant(5))); + // 4 elements all at processing time 0 + + tester.advanceProcessingTime(new Instant(6)); // fire [1,3,7,5] since 6 > 0 + 5 + tester.injectElements( + TimestampedValue.of(1, new Instant(8)), + TimestampedValue.of(1, new Instant(4))); + // 6 elements + + tester.advanceInputWatermark(new Instant(11)); // fire [1,3,7,5,8,4] since 11 > 9 + tester.injectElements( + TimestampedValue.of(1, new Instant(8)), + TimestampedValue.of(1, new Instant(4)), + TimestampedValue.of(1, new Instant(5))); + // 9 elements + + tester.advanceInputWatermark(new Instant(12)); + tester.injectElements( + TimestampedValue.of(1, new Instant(3))); + // 10 elements + + tester.advanceProcessingTime(new Instant(15)); + tester.injectElements( + TimestampedValue.of(1, new Instant(5))); + // 11 elements + tester.advanceProcessingTime(new Instant(32)); // fire since 32 > 6 + 25 + + tester.injectElements( + TimestampedValue.of(1, new Instant(3))); + // 12 elements + // fire [1,3,7,5,8,4,8,4,5,3,5,3] since 125 > 6 + 25 + tester.advanceInputWatermark(new Instant(125)); + + List> output = tester.extractOutput(); + assertEquals(4, output.size()); + + assertThat(output.get(0), WindowMatchers.isSingleWindowedValue(4, 1, 0, 10)); + assertThat(output.get(1), WindowMatchers.isSingleWindowedValue(6, 4, 0, 10)); + assertThat(output.get(2), WindowMatchers.isSingleWindowedValue(11, 9, 0, 10)); + assertThat(output.get(3), WindowMatchers.isSingleWindowedValue(12, 9, 0, 10)); + + assertThat( + output.get(0), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.EARLY, 0, -1))); + assertThat( + output.get(1), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.ON_TIME, 1, 0))); + assertThat( + output.get(2), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.LATE, 2, 1))); + assertThat( + output.get(3), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.LATE, 3, 2))); + } + + private static class SumAndVerifyContextFn extends CombineFnWithContext { + + private final PCollectionView view; + private final int expectedValue; + + private SumAndVerifyContextFn(PCollectionView view, int expectedValue) { + this.view = view; + this.expectedValue = expectedValue; + } + @Override + public int[] createAccumulator(Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + return wrap(0); + } + + @Override + public int[] addInput(int[] accumulator, Integer input, Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + 
Preconditions.checkArgument(c.sideInput(view) == expectedValue); + accumulator[0] += input.intValue(); + return accumulator; + } + + @Override + public int[] mergeAccumulators(Iterable accumulators, Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(c); + } else { + int[] running = iter.next(); + while (iter.hasNext()) { + running[0] += iter.next()[0]; + } + return running; + } + } + + @Override + public Integer extractOutput(int[] accumulator, Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + return accumulator[0]; + } + + private int[] wrap(int value) { + return new int[] { value }; + } + } + + /** + * A {@link PipelineOptions} to test combining with context. + */ + public interface TestOptions extends PipelineOptions { + Integer getValue(); + void setValue(Integer value); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java new file mode 100644 index 000000000000..d4620a7827c5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java @@ -0,0 +1,776 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.TriggerBuilder; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals; +import com.google.cloud.dataflow.sdk.util.state.State; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.WatermarkHoldState; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Function; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.PriorityQueue; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Test utility that runs a {@link ReduceFn}, {@link WindowFn}, {@link Trigger} using in-memory stub + * implementations to provide the {@link TimerInternals} and {@link WindowingInternals} needed to + * run {@code Trigger}s and {@code ReduceFn}s. + * + * @param The element types. 
+ * @param The final type for elements in the window (for instance, + * {@code Iterable}) + * @param The type of windows being used. + */ +public class ReduceFnTester { + private static final String KEY = "TEST_KEY"; + + private final TestInMemoryStateInternals stateInternals = + new TestInMemoryStateInternals<>(KEY); + private final TestTimerInternals timerInternals = new TestTimerInternals(); + + private final WindowFn windowFn; + private final TestWindowingInternals windowingInternals; + private final Coder outputCoder; + private final WindowingStrategy objectStrategy; + private final ReduceFn reduceFn; + private final PipelineOptions options; + + /** + * If true, the output watermark is automatically advanced to the latest possible + * point when the input watermark is advanced. This is the default for most tests. + * If false, the output watermark must be explicitly advanced by the test, which can + * be used to exercise some of the more subtle behavior of WatermarkHold. + */ + private boolean autoAdvanceOutputWatermark; + + private ExecutableTrigger executableTrigger; + + private final InMemoryLongSumAggregator droppedDueToClosedWindow = + new InMemoryLongSumAggregator(GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER); + + public static ReduceFnTester, W> + nonCombining(WindowingStrategy windowingStrategy) throws Exception { + return new ReduceFnTester, W>( + windowingStrategy, + SystemReduceFn.buffering(VarIntCoder.of()), + IterableCoder.of(VarIntCoder.of()), + PipelineOptionsFactory.create(), + NullSideInputReader.empty()); + } + + public static ReduceFnTester, W> + nonCombining(WindowFn windowFn, TriggerBuilder trigger, AccumulationMode mode, + Duration allowedDataLateness, ClosingBehavior closingBehavior) throws Exception { + WindowingStrategy strategy = + WindowingStrategy.of(windowFn) + .withTrigger(trigger.buildTrigger()) + .withMode(mode) + .withAllowedLateness(allowedDataLateness) + .withClosingBehavior(closingBehavior); + return nonCombining(strategy); + } + + public static ReduceFnTester + combining(WindowingStrategy strategy, + KeyedCombineFn combineFn, + Coder outputCoder) throws Exception { + + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + AppliedCombineFn fn = + AppliedCombineFn.withInputCoder( + combineFn, registry, KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); + + return new ReduceFnTester( + strategy, + SystemReduceFn.combining(StringUtf8Coder.of(), fn), + outputCoder, + PipelineOptionsFactory.create(), + NullSideInputReader.empty()); + } + + public static ReduceFnTester + combining(WindowingStrategy strategy, + KeyedCombineFnWithContext combineFn, + Coder outputCoder, + PipelineOptions options, + SideInputReader sideInputReader) throws Exception { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + AppliedCombineFn fn = + AppliedCombineFn.withInputCoder( + combineFn, registry, KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); + + return new ReduceFnTester( + strategy, + SystemReduceFn.combining(StringUtf8Coder.of(), fn), + outputCoder, + options, + sideInputReader); + } + public static ReduceFnTester + combining(WindowFn windowFn, Trigger trigger, AccumulationMode mode, + KeyedCombineFn combineFn, Coder outputCoder, + Duration allowedDataLateness) throws Exception { + + WindowingStrategy strategy = + WindowingStrategy.of(windowFn).withTrigger(trigger).withMode(mode).withAllowedLateness( + allowedDataLateness); + + return combining(strategy, combineFn, outputCoder); + } + 
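+  // A minimal usage sketch of the factories above (illustrative only, not a test in this file),
+  // assuming integers buffered under fixed windows with the default end-of-window trigger:
+  //
+  //   ReduceFnTester<Integer, Iterable<Integer>, IntervalWindow> tester =
+  //       ReduceFnTester.nonCombining(
+  //           WindowingStrategy.of(FixedWindows.of(Duration.millis(10)))
+  //               .withMode(AccumulationMode.DISCARDING_FIRED_PANES)
+  //               .withAllowedLateness(Duration.millis(100)));
+  //   tester.advanceInputWatermark(new Instant(0));
+  //   tester.injectElements(TimestampedValue.of(1, new Instant(1)));
+  //   tester.advanceInputWatermark(new Instant(20));
+  //   List<WindowedValue<Iterable<Integer>>> output = tester.extractOutput();
+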
+ private ReduceFnTester(WindowingStrategy wildcardStrategy, + ReduceFn reduceFn, Coder outputCoder, + PipelineOptions options, SideInputReader sideInputReader) throws Exception { + @SuppressWarnings("unchecked") + WindowingStrategy objectStrategy = (WindowingStrategy) wildcardStrategy; + + this.objectStrategy = objectStrategy; + this.reduceFn = reduceFn; + this.windowFn = objectStrategy.getWindowFn(); + this.windowingInternals = new TestWindowingInternals(sideInputReader); + this.outputCoder = outputCoder; + this.autoAdvanceOutputWatermark = true; + this.executableTrigger = wildcardStrategy.getTrigger(); + this.options = options; + } + + public void setAutoAdvanceOutputWatermark(boolean autoAdvanceOutputWatermark) { + this.autoAdvanceOutputWatermark = autoAdvanceOutputWatermark; + } + + @Nullable + public Instant getNextTimer(TimeDomain domain) { + return timerInternals.getNextTimer(domain); + } + + ReduceFnRunner createRunner() { + return new ReduceFnRunner<>( + KEY, + objectStrategy, + stateInternals, + timerInternals, + windowingInternals, + droppedDueToClosedWindow, + reduceFn, + options); + } + + public ExecutableTrigger getTrigger() { + return executableTrigger; + } + + public boolean isMarkedFinished(W window) { + return createRunner().isFinished(window); + } + + @SafeVarargs + public final void assertHasOnlyGlobalAndFinishedSetsFor(W... expectedWindows) { + assertHasOnlyGlobalAndAllowedTags( + ImmutableSet.copyOf(expectedWindows), + ImmutableSet.>of(TriggerRunner.FINISHED_BITS_TAG)); + } + + @SafeVarargs + public final void assertHasOnlyGlobalAndFinishedSetsAndPaneInfoFor(W... expectedWindows) { + assertHasOnlyGlobalAndAllowedTags( + ImmutableSet.copyOf(expectedWindows), + ImmutableSet.>of( + TriggerRunner.FINISHED_BITS_TAG, PaneInfoTracker.PANE_INFO_TAG, + WatermarkHold.watermarkHoldTagForOutputTimeFn(objectStrategy.getOutputTimeFn()), + WatermarkHold.EXTRA_HOLD_TAG)); + } + + public final void assertHasOnlyGlobalState() { + assertHasOnlyGlobalAndAllowedTags( + Collections.emptySet(), Collections.>emptySet()); + } + + @SafeVarargs + public final void assertHasOnlyGlobalAndPaneInfoFor(W... expectedWindows) { + assertHasOnlyGlobalAndAllowedTags( + ImmutableSet.copyOf(expectedWindows), + ImmutableSet.>of( + PaneInfoTracker.PANE_INFO_TAG, + WatermarkHold.watermarkHoldTagForOutputTimeFn(objectStrategy.getOutputTimeFn()), + WatermarkHold.EXTRA_HOLD_TAG)); + } + + /** + * Verifies that the the set of windows that have any state stored is exactly + * {@code expectedWindows} and that each of these windows has only tags from {@code allowedTags}. 
+ */ + private void assertHasOnlyGlobalAndAllowedTags( + Set expectedWindows, Set> allowedTags) { + Set expectedWindowsSet = new HashSet<>(); + for (W expectedWindow : expectedWindows) { + expectedWindowsSet.add(windowNamespace(expectedWindow)); + } + Map>> actualWindows = new HashMap<>(); + + for (StateNamespace namespace : stateInternals.getNamespacesInUse()) { + if (namespace instanceof StateNamespaces.GlobalNamespace) { + continue; + } else if (namespace instanceof StateNamespaces.WindowNamespace) { + Set> tagsInUse = stateInternals.getTagsInUse(namespace); + if (tagsInUse.isEmpty()) { + continue; + } + actualWindows.put(namespace, tagsInUse); + Set> unexpected = Sets.difference(tagsInUse, allowedTags); + if (unexpected.isEmpty()) { + continue; + } else { + fail(namespace + " has unexpected states: " + tagsInUse); + } + } else if (namespace instanceof StateNamespaces.WindowAndTriggerNamespace) { + Set> tagsInUse = stateInternals.getTagsInUse(namespace); + assertTrue(namespace + " contains " + tagsInUse, tagsInUse.isEmpty()); + } else { + fail("Unrecognized namespace " + namespace); + } + } + + assertEquals("Still in use: " + actualWindows.toString(), expectedWindowsSet, + actualWindows.keySet()); + } + + private StateNamespace windowNamespace(W window) { + return StateNamespaces.window(windowFn.windowCoder(), window); + } + + public Instant getWatermarkHold() { + return stateInternals.earliestWatermarkHold(); + } + + public Instant getOutputWatermark() { + return timerInternals.currentOutputWatermarkTime(); + } + + public long getElementsDroppedDueToClosedWindow() { + return droppedDueToClosedWindow.getSum(); + } + + /** + * How many panes do we have in the output? + */ + public int getOutputSize() { + return windowingInternals.outputs.size(); + } + + /** + * Retrieve the values that have been output to this time, and clear out the output accumulator. + */ + public List> extractOutput() { + ImmutableList> result = + FluentIterable.from(windowingInternals.outputs) + .transform(new Function>, WindowedValue>() { + @Override + public WindowedValue apply(WindowedValue> input) { + return input.withValue(input.getValue().getValue()); + } + }) + .toList(); + windowingInternals.outputs.clear(); + return result; + } + + /** + * Advance the input watermark to the specified time, firing any timers that should + * fire. Then advance the output watermark as far as possible. + */ + public void advanceInputWatermark(Instant newInputWatermark) throws Exception { + ReduceFnRunner runner = createRunner(); + timerInternals.advanceInputWatermark(runner, newInputWatermark); + runner.persist(); + } + + /** + * If {@link #autoAdvanceOutputWatermark} is {@literal false}, advance the output watermark + * to the given value. Otherwise throw. + */ + public void advanceOutputWatermark(Instant newOutputWatermark) throws Exception { + timerInternals.advanceOutputWatermark(newOutputWatermark); + } + + /** Advance the processing time to the specified time, firing any timers that should fire. */ + public void advanceProcessingTime(Instant newProcessingTime) throws Exception { + ReduceFnRunner runner = createRunner(); + timerInternals.advanceProcessingTime(runner, newProcessingTime); + runner.persist(); + } + + /** + * Advance the synchronized processing time to the specified time, + * firing any timers that should fire. 
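+   * (This delegates to the test {@code TimerInternals} below, which fires any pending
+   * processing-time timers whose timestamps have been passed.)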
+ */ + public void advanceSynchronizedProcessingTime(Instant newProcessingTime) throws Exception { + ReduceFnRunner runner = createRunner(); + timerInternals.advanceSynchronizedProcessingTime(runner, newProcessingTime); + runner.persist(); + } + + /** + * Inject all the timestamped values (after passing through the window function) as if they + * arrived in a single chunk of a bundle (or work-unit). + */ + @SafeVarargs + public final void injectElements(TimestampedValue... values) throws Exception { + for (TimestampedValue value : values) { + WindowTracing.trace("TriggerTester.injectElements: {}", value); + } + ReduceFnRunner runner = createRunner(); + runner.processElements(Iterables.transform( + Arrays.asList(values), new Function, WindowedValue>() { + @Override + public WindowedValue apply(TimestampedValue input) { + try { + InputT value = input.getValue(); + Instant timestamp = input.getTimestamp(); + Collection windows = windowFn.assignWindows(new TestAssignContext( + windowFn, value, timestamp, Arrays.asList(GlobalWindow.INSTANCE))); + return WindowedValue.of(value, timestamp, windows, PaneInfo.NO_FIRING); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + })); + + // Persist after each bundle. + runner.persist(); + } + + public void fireTimer(W window, Instant timestamp, TimeDomain domain) throws Exception { + ReduceFnRunner runner = createRunner(); + runner.onTimer( + TimerData.of(StateNamespaces.window(windowFn.windowCoder(), window), timestamp, domain)); + runner.persist(); + } + + /** + * Simulate state. + */ + private static class TestInMemoryStateInternals extends InMemoryStateInternals { + + public TestInMemoryStateInternals(K key) { + super(key); + } + + public Set> getTagsInUse(StateNamespace namespace) { + Set> inUse = new HashSet<>(); + for (Entry, State> entry : + inMemoryState.getTagsInUse(namespace).entrySet()) { + if (!isEmptyForTesting(entry.getValue())) { + inUse.add(entry.getKey()); + } + } + return inUse; + } + + public Set getNamespacesInUse() { + return inMemoryState.getNamespacesInUse(); + } + + /** Return the earliest output watermark hold in state, or null if none. */ + public Instant earliestWatermarkHold() { + Instant minimum = null; + for (State storage : inMemoryState.values()) { + if (storage instanceof WatermarkHoldState) { + Instant hold = ((WatermarkHoldState) storage).read(); + if (minimum == null || (hold != null && hold.isBefore(minimum))) { + minimum = hold; + } + } + } + return minimum; + } + } + + /** + * Convey the simulated state and implement {@link #outputWindowedValue} to capture all output + * elements. + */ + private class TestWindowingInternals implements WindowingInternals> { + private List>> outputs = new ArrayList<>(); + private SideInputReader sideInputReader; + + private TestWindowingInternals(SideInputReader sideInputReader) { + this.sideInputReader = sideInputReader; + } + + @Override + public void outputWindowedValue(KV output, Instant timestamp, + Collection windows, PaneInfo pane) { + // Copy the output value (using coders) before capturing it. 
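+      // This mirrors what a real runner requires: a value that cannot round-trip through its
+      // declared coder fails here instead of being captured silently.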
+ KV copy = SerializableUtils.>ensureSerializableByCoder( + KvCoder.of(StringUtf8Coder.of(), outputCoder), output, "outputForWindow"); + WindowedValue> value = WindowedValue.of(copy, timestamp, windows, pane); + outputs.add(value); + } + + @Override + public TimerInternals timerInternals() { + throw new UnsupportedOperationException( + "Testing triggers should not use timers from WindowingInternals."); + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException( + "Testing triggers should not use windows from WindowingInternals."); + } + + @Override + public PaneInfo pane() { + throw new UnsupportedOperationException( + "Testing triggers should not use pane from WindowingInternals."); + } + + @Override + public void writePCollectionViewData( + TupleTag tag, Iterable> data, Coder elemCoder) throws IOException { + throw new UnsupportedOperationException( + "Testing triggers should not use writePCollectionViewData from WindowingInternals."); + } + + @Override + public StateInternals stateInternals() { + // Safe for testing only + @SuppressWarnings({"unchecked", "rawtypes"}) + TestInMemoryStateInternals untypedStateInternals = + (TestInMemoryStateInternals) stateInternals; + return untypedStateInternals; + } + + @Override + public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { + if (!sideInputReader.contains(view)) { + throw new IllegalArgumentException("calling sideInput() with unknown view"); + } + BoundedWindow sideInputWindow = + view.getWindowingStrategyInternal().getWindowFn().getSideInputWindow(mainInputWindow); + return sideInputReader.get(view, sideInputWindow); + } + } + + private static class TestAssignContext + extends WindowFn.AssignContext { + private Object element; + private Instant timestamp; + private Collection windows; + + public TestAssignContext(WindowFn windowFn, Object element, Instant timestamp, + Collection windows) { + windowFn.super(); + this.element = element; + this.timestamp = timestamp; + this.windows = windows; + } + + @Override + public Object element() { + return element; + } + + @Override + public Instant timestamp() { + return timestamp; + } + + @Override + public Collection windows() { + return windows; + } + } + + private static class InMemoryLongSumAggregator implements Aggregator { + private final String name; + private long sum = 0; + + public InMemoryLongSumAggregator(String name) { + this.name = name; + } + + @Override + public void addValue(Long value) { + sum += value; + } + + @Override + public String getName() { + return name; + } + + @Override + public CombineFn getCombineFn() { + return new Sum.SumLongFn(); + } + + public long getSum() { + return sum; + } + } + + /** + * Simulate the firing of timers and progression of input and output watermarks for a + * single computation and key in a Windmill-like streaming environment. Similar to + * {@link BatchTimerInternals}, but also tracks the output watermark. + */ + private class TestTimerInternals implements TimerInternals { + /** At most one timer per timestamp is kept. */ + private Set existingTimers = new HashSet<>(); + + /** Pending input watermark timers, in timestamp order. */ + private PriorityQueue watermarkTimers = new PriorityQueue<>(11); + + /** Pending processing time timers, in timestamp order. */ + private PriorityQueue processingTimers = new PriorityQueue<>(11); + + /** Current input watermark. */ + @Nullable + private Instant inputWatermarkTime = null; + + /** Current output watermark. 
*/ + @Nullable + private Instant outputWatermarkTime = null; + + /** Current processing time. */ + private Instant processingTime = BoundedWindow.TIMESTAMP_MIN_VALUE; + + /** Current synchronized processing time. */ + @Nullable + private Instant synchronizedProcessingTime = null; + + @Nullable + public Instant getNextTimer(TimeDomain domain) { + TimerData data = null; + switch (domain) { + case EVENT_TIME: + data = watermarkTimers.peek(); + break; + case PROCESSING_TIME: + case SYNCHRONIZED_PROCESSING_TIME: + data = processingTimers.peek(); + break; + } + Preconditions.checkNotNull(data); // cases exhaustive + return data == null ? null : data.getTimestamp(); + } + + private PriorityQueue queue(TimeDomain domain) { + switch (domain) { + case EVENT_TIME: + return watermarkTimers; + case PROCESSING_TIME: + case SYNCHRONIZED_PROCESSING_TIME: + return processingTimers; + } + throw new RuntimeException(); // cases exhaustive + } + + @Override + public void setTimer(TimerData timer) { + WindowTracing.trace("TestTimerInternals.setTimer: {}", timer); + if (existingTimers.add(timer)) { + queue(timer.getDomain()).add(timer); + } + } + + @Override + public void deleteTimer(TimerData timer) { + WindowTracing.trace("TestTimerInternals.deleteTimer: {}", timer); + existingTimers.remove(timer); + queue(timer.getDomain()).remove(timer); + } + + @Override + public Instant currentProcessingTime() { + return processingTime; + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return synchronizedProcessingTime; + } + + @Override + @Nullable + public Instant currentInputWatermarkTime() { + return inputWatermarkTime; + } + + @Override + @Nullable + public Instant currentOutputWatermarkTime() { + return outputWatermarkTime; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("watermarkTimers", watermarkTimers) + .add("processingTimers", processingTimers) + .add("inputWatermarkTime", inputWatermarkTime) + .add("outputWatermarkTime", outputWatermarkTime) + .add("processingTime", processingTime) + .toString(); + } + + public void advanceInputWatermark( + ReduceFnRunner runner, Instant newInputWatermark) throws Exception { + Preconditions.checkNotNull(newInputWatermark); + Preconditions.checkState( + inputWatermarkTime == null || !newInputWatermark.isBefore(inputWatermarkTime), + "Cannot move input watermark time backwards from %s to %s", inputWatermarkTime, + newInputWatermark); + WindowTracing.trace("TestTimerInternals.advanceInputWatermark: from {} to {}", + inputWatermarkTime, newInputWatermark); + inputWatermarkTime = newInputWatermark; + advanceAndFire(runner, newInputWatermark, TimeDomain.EVENT_TIME); + + Instant hold = stateInternals.earliestWatermarkHold(); + if (hold == null) { + WindowTracing.trace("TestTimerInternals.advanceInputWatermark: no holds, " + + "so output watermark = input watermark"); + hold = inputWatermarkTime; + } + if (autoAdvanceOutputWatermark) { + advanceOutputWatermark(hold); + } + } + + public void advanceOutputWatermark(Instant newOutputWatermark) { + Preconditions.checkNotNull(newOutputWatermark); + Preconditions.checkNotNull(inputWatermarkTime); + if (newOutputWatermark.isAfter(inputWatermarkTime)) { + WindowTracing.trace( + "TestTimerInternals.advanceOutputWatermark: clipping output watermark from {} to {}", + newOutputWatermark, inputWatermarkTime); + newOutputWatermark = inputWatermarkTime; + } + Preconditions.checkState( + outputWatermarkTime == null || 
!newOutputWatermark.isBefore(outputWatermarkTime), + "Cannot move output watermark time backwards from %s to %s", outputWatermarkTime, + newOutputWatermark); + WindowTracing.trace("TestTimerInternals.advanceOutputWatermark: from {} to {}", + outputWatermarkTime, newOutputWatermark); + outputWatermarkTime = newOutputWatermark; + } + + public void advanceProcessingTime( + ReduceFnRunner runner, Instant newProcessingTime) throws Exception { + Preconditions.checkState(!newProcessingTime.isBefore(processingTime), + "Cannot move processing time backwards from %s to %s", processingTime, newProcessingTime); + WindowTracing.trace("TestTimerInternals.advanceProcessingTime: from {} to {}", processingTime, + newProcessingTime); + processingTime = newProcessingTime; + advanceAndFire(runner, newProcessingTime, TimeDomain.PROCESSING_TIME); + } + + public void advanceSynchronizedProcessingTime( + ReduceFnRunner runner, Instant newSynchronizedProcessingTime) throws Exception { + Preconditions.checkState(!newSynchronizedProcessingTime.isBefore(synchronizedProcessingTime), + "Cannot move processing time backwards from %s to %s", processingTime, + newSynchronizedProcessingTime); + WindowTracing.trace("TestTimerInternals.advanceProcessingTime: from {} to {}", + synchronizedProcessingTime, newSynchronizedProcessingTime); + synchronizedProcessingTime = newSynchronizedProcessingTime; + advanceAndFire( + runner, newSynchronizedProcessingTime, TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + } + + private void advanceAndFire( + ReduceFnRunner runner, Instant currentTime, TimeDomain domain) + throws Exception { + PriorityQueue queue = queue(domain); + boolean shouldFire = false; + + do { + TimerData timer = queue.peek(); + // Timers fire when the current time progresses past the timer time. + shouldFire = timer != null && currentTime.isAfter(timer.getTimestamp()); + if (shouldFire) { + WindowTracing.trace( + "TestTimerInternals.advanceAndFire: firing {} at {}", timer, currentTime); + // Remove before firing, so that if the trigger adds another identical + // timer we don't remove it. + queue.remove(); + + runner.onTimer(timer); + } + } while (shouldFire); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReshuffleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReshuffleTest.java new file mode 100644 index 000000000000..c2a4fc40fc52 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReshuffleTest.java @@ -0,0 +1,208 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.List; + +/** + * Tests for {@link Reshuffle}. + */ +@RunWith(JUnit4.class) +public class ReshuffleTest { + + private static final List> ARBITRARY_KVS = ImmutableList.of( + KV.of("k1", 3), + KV.of("k5", Integer.MAX_VALUE), + KV.of("k5", Integer.MIN_VALUE), + KV.of("k2", 66), + KV.of("k1", 4), + KV.of("k2", -33), + KV.of("k3", 0)); + + // TODO: test with more than one value per key + private static final List> GBK_TESTABLE_KVS = ImmutableList.of( + KV.of("k1", 3), + KV.of("k2", 4)); + + private static final List>> GROUPED_TESTABLE_KVS = ImmutableList.of( + KV.of("k1", (Iterable) ImmutableList.of(3)), + KV.of("k2", (Iterable) ImmutableList.of(4))); + + @Test + @Category(RunnableOnService.class) + public void testJustReshuffle() { + Pipeline pipeline = TestPipeline.create(); + + PCollection> input = pipeline + .apply(Create.of(ARBITRARY_KVS) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))); + + PCollection> output = input + .apply(Reshuffle.of()); + + DataflowAssert.that(output).containsInAnyOrder(ARBITRARY_KVS); + + assertEquals( + input.getWindowingStrategy(), + output.getWindowingStrategy()); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testReshuffleAfterSessionsAndGroupByKey() { + Pipeline pipeline = TestPipeline.create(); + + PCollection>> input = pipeline + .apply(Create.of(GBK_TESTABLE_KVS) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(Window.>into( + Sessions.withGapDuration(Duration.standardMinutes(10)))) + .apply(GroupByKey.create()); + + PCollection>> output = input + .apply(Reshuffle.>of()); + + DataflowAssert.that(output).containsInAnyOrder(GROUPED_TESTABLE_KVS); + + assertEquals( + input.getWindowingStrategy(), + output.getWindowingStrategy()); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testReshuffleAfterFixedWindowsAndGroupByKey() { + Pipeline pipeline = TestPipeline.create(); + + PCollection>> input = pipeline + .apply(Create.of(GBK_TESTABLE_KVS) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(Window.>into( + FixedWindows.of(Duration.standardMinutes(10L)))) + .apply(GroupByKey.create()); + + PCollection>> output = input + .apply(Reshuffle.>of()); + + DataflowAssert.that(output).containsInAnyOrder(GROUPED_TESTABLE_KVS); + + assertEquals( + 
input.getWindowingStrategy(), + output.getWindowingStrategy()); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testReshuffleAfterSlidingWindowsAndGroupByKey() { + Pipeline pipeline = TestPipeline.create(); + + PCollection>> input = pipeline + .apply(Create.of(GBK_TESTABLE_KVS) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(Window.>into( + FixedWindows.of(Duration.standardMinutes(10L)))) + .apply(GroupByKey.create()); + + PCollection>> output = input + .apply(Reshuffle.>of()); + + DataflowAssert.that(output).containsInAnyOrder(GROUPED_TESTABLE_KVS); + + assertEquals( + input.getWindowingStrategy(), + output.getWindowingStrategy()); + + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testReshuffleAfterFixedWindows() { + Pipeline pipeline = TestPipeline.create(); + + PCollection> input = pipeline + .apply(Create.of(ARBITRARY_KVS) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(Window.>into( + FixedWindows.of(Duration.standardMinutes(10L)))); + + PCollection> output = input + .apply(Reshuffle.of()); + + DataflowAssert.that(output).containsInAnyOrder(ARBITRARY_KVS); + + assertEquals( + input.getWindowingStrategy(), + output.getWindowingStrategy()); + + pipeline.run(); + } + + + @Test + @Category(RunnableOnService.class) + public void testReshuffleAfterSlidingWindows() { + Pipeline pipeline = TestPipeline.create(); + + PCollection> input = pipeline + .apply(Create.of(ARBITRARY_KVS) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))) + .apply(Window.>into( + FixedWindows.of(Duration.standardMinutes(10L)))); + + PCollection> output = input + .apply(Reshuffle.of()); + + DataflowAssert.that(output).containsInAnyOrder(ARBITRARY_KVS); + + assertEquals( + input.getWindowingStrategy(), + output.getWindowingStrategy()); + + pipeline.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReshuffleTriggerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReshuffleTriggerTest.java new file mode 100644 index 000000000000..4b3a77ce61c1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReshuffleTriggerTest.java @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ReshuffleTrigger}. + */ +@RunWith(JUnit4.class) +public class ReshuffleTriggerTest { + + /** Public so that other tests can instantiate ReshufleTrigger. 
*/ + public static ReshuffleTrigger forTest() { + return new ReshuffleTrigger<>(); + } + + @Test + public void testShouldFire() throws Exception { + TriggerTester tester = TriggerTester.forTrigger( + new ReshuffleTrigger(), FixedWindows.of(Duration.millis(100))); + IntervalWindow arbitraryWindow = new IntervalWindow(new Instant(300), new Instant(400)); + assertTrue(tester.shouldFire(arbitraryWindow)); + } + + @Test + public void testOnTimer() throws Exception { + TriggerTester tester = TriggerTester.forTrigger( + new ReshuffleTrigger(), FixedWindows.of(Duration.millis(100))); + IntervalWindow arbitraryWindow = new IntervalWindow(new Instant(100), new Instant(200)); + tester.fireIfShouldFire(arbitraryWindow); + assertFalse(tester.isMarkedFinished(arbitraryWindow)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializerTest.java new file mode 100644 index 000000000000..709719080258 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/RetryHttpRequestInitializerTest.java @@ -0,0 +1,296 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.client.auth.oauth2.Credential; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpResponseException; +import com.google.api.client.http.HttpResponseInterceptor; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.LowLevelHttpRequest; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.util.NanoClock; +import com.google.api.client.util.Sleeper; +import com.google.api.services.bigquery.Bigquery; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.storage.Storage; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import 
org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.security.PrivateKey; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Tests for RetryHttpRequestInitializer. + */ +@RunWith(JUnit4.class) +public class RetryHttpRequestInitializerTest { + + @Mock private Credential mockCredential; + @Mock private PrivateKey mockPrivateKey; + @Mock private LowLevelHttpRequest mockLowLevelRequest; + @Mock private LowLevelHttpResponse mockLowLevelResponse; + @Mock private HttpResponseInterceptor mockHttpResponseInterceptor; + + private final JsonFactory jsonFactory = JacksonFactory.getDefaultInstance(); + private Storage storage; + + // Used to test retrying a request more than the default 10 times. + static class MockNanoClock implements NanoClock { + private int timesMs[] = {500, 750, 1125, 1688, 2531, 3797, 5695, 8543, + 12814, 19222, 28833, 43249, 64873, 97310, 145965, 218945, 328420}; + private int i = 0; + + @Override + public long nanoTime() { + return timesMs[i++ / 2] * 1000000; + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + HttpTransport lowLevelTransport = new HttpTransport() { + @Override + protected LowLevelHttpRequest buildRequest(String method, String url) + throws IOException { + return mockLowLevelRequest; + } + }; + + // Retry initializer will pass through to credential, since we can have + // only a single HttpRequestInitializer, and we use multiple Credential + // types in the SDK, not all of which allow for retry configuration. + RetryHttpRequestInitializer initializer = new RetryHttpRequestInitializer( + mockCredential, new MockNanoClock(), new Sleeper() { + @Override + public void sleep(long millis) throws InterruptedException {} + }, Arrays.asList(418 /* I'm a teapot */), mockHttpResponseInterceptor); + storage = new Storage.Builder(lowLevelTransport, jsonFactory, initializer) + .setApplicationName("test").build(); + } + + @After + public void tearDown() { + verifyNoMoreInteractions(mockPrivateKey); + verifyNoMoreInteractions(mockLowLevelRequest); + verifyNoMoreInteractions(mockCredential); + verifyNoMoreInteractions(mockHttpResponseInterceptor); + } + + @Test + public void testBasicOperation() throws IOException { + when(mockLowLevelRequest.execute()) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(200); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest).execute(); + verify(mockLowLevelResponse).getStatusCode(); + } + + /** + * Tests that a non-retriable error is not retried. + */ + @Test + public void testErrorCodeForbidden() throws IOException { + when(mockLowLevelRequest.execute()) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(403) // Non-retryable error. + .thenReturn(200); // Shouldn't happen. 
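+    // 403 is not in the retryable set configured in setUp(), so the failure must surface after a
+    // single execute() call; the verifications below expect exactly one request.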
+ + try { + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + } catch (HttpResponseException e) { + Assert.assertThat(e.getMessage(), Matchers.containsString("403")); + } + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest).execute(); + verify(mockLowLevelResponse).getStatusCode(); + } + + /** + * Tests that a retriable error is retried. + */ + @Test + public void testRetryableError() throws IOException { + when(mockLowLevelRequest.execute()) + .thenReturn(mockLowLevelResponse) + .thenReturn(mockLowLevelResponse) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(503) // Retryable + .thenReturn(429) // We also retry on 429 Too Many Requests. + .thenReturn(200); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest, times(3)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(3)).execute(); + verify(mockLowLevelResponse, times(3)).getStatusCode(); + } + + /** + * Tests that an IOException is retried. + */ + @Test + public void testThrowIOException() throws IOException { + when(mockLowLevelRequest.execute()) + .thenThrow(new IOException("Fake Error")) + .thenReturn(mockLowLevelResponse); + when(mockLowLevelResponse.getStatusCode()) + .thenReturn(200); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); + verify(mockLowLevelRequest, atLeastOnce()) + .addHeader(anyString(), anyString()); + verify(mockLowLevelRequest, times(2)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(2)).execute(); + verify(mockLowLevelResponse).getStatusCode(); + } + + /** + * Tests that a retryable error is retried enough times. + */ + @Test + public void testRetryableErrorRetryEnoughTimes() throws IOException { + when(mockLowLevelRequest.execute()).thenReturn(mockLowLevelResponse); + final int retries = 10; + when(mockLowLevelResponse.getStatusCode()).thenAnswer(new Answer(){ + int n = 0; + @Override + public Integer answer(InvocationOnMock invocation) { + return (n++ < retries - 1) ? 
503 : 200; + }}); + + Storage.Buckets.Get result = storage.buckets().get("test"); + HttpResponse response = result.executeUnparsed(); + assertNotNull(response); + + verify(mockCredential).initialize(any(HttpRequest.class)); + verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); + verify(mockLowLevelRequest, atLeastOnce()).addHeader(anyString(), + anyString()); + verify(mockLowLevelRequest, times(retries)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(retries)).execute(); + verify(mockLowLevelResponse, times(retries)).getStatusCode(); + } + + /** + * Tests that when RPCs fail with {@link SocketTimeoutException}, the IO exception handler + * is invoked. + */ + @Test + public void testIOExceptionHandlerIsInvokedOnTimeout() throws Exception { + // Counts the number of calls to execute the HTTP request. + final AtomicLong executeCount = new AtomicLong(); + + // 10 is a private internal constant in the Google API Client library. See + // com.google.api.client.http.HttpRequest#setNumberOfRetries + // TODO: update this test once the private internal constant is public. + final int defaultNumberOfRetries = 10; + + // A mock HTTP request that always throws SocketTimeoutException. + MockHttpTransport transport = + new MockHttpTransport.Builder().setLowLevelHttpRequest(new MockLowLevelHttpRequest() { + @Override + public LowLevelHttpResponse execute() throws IOException { + executeCount.incrementAndGet(); + throw new SocketTimeoutException("Fake forced timeout exception"); + } + }).build(); + + // A sample HTTP request to BigQuery that uses both default Transport and default + // RetryHttpInitializer. + Bigquery b = new Bigquery.Builder( + transport, Transport.getJsonFactory(), new RetryHttpRequestInitializer()).build(); + BigQueryTableInserter inserter = new BigQueryTableInserter(b); + TableReference t = new TableReference() + .setProjectId("project").setDatasetId("dataset").setTableId("table"); + + try { + inserter.insertAll(t, ImmutableList.of(new TableRow())); + fail(); + } catch (Throwable e) { + assertThat(e, Matchers.instanceOf(RuntimeException.class)); + assertThat(e.getCause(), Matchers.instanceOf(SocketTimeoutException.class)); + assertEquals(1 + defaultNumberOfRetries, executeCount.get()); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializableUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializableUtilsTest.java new file mode 100644 index 000000000000..5d52c31134e3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializableUtilsTest.java @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.DeterministicStandardCoder; +import com.google.common.collect.ImmutableList; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.List; + +/** Tests for {@link SerializableUtils}. */ +@RunWith(JUnit4.class) +public class SerializableUtilsTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + /** A class that is serializable by Java. */ + private static class SerializableByJava implements Serializable { + final String stringValue; + final int intValue; + + public SerializableByJava(String stringValue, int intValue) { + this.stringValue = stringValue; + this.intValue = intValue; + } + } + + @Test + public void testTranscode() { + String stringValue = "hi bob"; + int intValue = 42; + + SerializableByJava testObject = new SerializableByJava(stringValue, intValue); + SerializableByJava testCopy = SerializableUtils.ensureSerializable(testObject); + + assertEquals(stringValue, testCopy.stringValue); + assertEquals(intValue, testCopy.intValue); + } + + @Test + public void testDeserializationError() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("unable to deserialize a bogus string"); + SerializableUtils.deserializeFromByteArray( + "this isn't legal".getBytes(), + "a bogus string"); + } + + /** A class that is not serializable by Java. */ + private static class UnserializableByJava implements Serializable { + @SuppressWarnings("unused") + private Object unserializableField = new Object(); + } + + @Test + public void testSerializationError() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("unable to serialize"); + SerializableUtils.serializeToByteArray(new UnserializableByJava()); + } + + /** A {@link Coder} that is not serializable by Java. */ + private static class UnserializableCoderByJava extends DeterministicStandardCoder { + private final Object unserializableField = new Object(); + + @Override + public void encode(Object value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public Object decode(InputStream inStream, Context context) + throws CoderException, IOException { + return unserializableField; + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(); + } + } + + @Test + public void testEnsureSerializableWithUnserializableCoderByJava() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("unable to serialize"); + SerializableUtils.ensureSerializable(new UnserializableCoderByJava()); + } + + /** A {@link Coder} that is not serializable by Jackson. 
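+   * The coder's {@code asCloudObject} stores a raw {@code SerializableByJava} value in its
+   * {@link CloudObject}, and Jackson has no annotated way to rebuild that value, so the
+   * round trip in {@code ensureSerializable} is expected to fail with
+   * "Unable to deserialize Coder:".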
*/ + private static class UnserializableCoderByJackson extends DeterministicStandardCoder { + private final SerializableByJava unserializableField; + + public UnserializableCoderByJackson(SerializableByJava unserializableField) { + this.unserializableField = unserializableField; + } + + @JsonCreator + public static UnserializableCoderByJackson of( + @JsonProperty("unserializableField") SerializableByJava unserializableField) { + return new UnserializableCoderByJackson(unserializableField); + } + + @Override + public CloudObject asCloudObject() { + CloudObject result = super.asCloudObject(); + result.put("unserializableField", unserializableField); + return result; + } + + @Override + public void encode(Object value, OutputStream outStream, Context context) + throws CoderException, IOException { + } + + @Override + public Object decode(InputStream inStream, Context context) + throws CoderException, IOException { + return unserializableField; + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(); + } + } + + @Test + public void testEnsureSerializableWithUnserializableCoderByJackson() throws Exception { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Unable to deserialize Coder:"); + SerializableUtils.ensureSerializable( + new UnserializableCoderByJackson(new SerializableByJava("TestData", 5))); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializerTest.java new file mode 100644 index 000000000000..80c797a23151 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SerializerTest.java @@ -0,0 +1,162 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDouble; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests Serializer implementation. + */ +@RunWith(JUnit4.class) +public class SerializerTest { + /** + * A POJO to use for testing serialization. + */ + @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, + property = PropertyNames.OBJECT_TYPE_NAME) + public static class TestRecord { + // TODO: When we apply property name typing to all non-final classes, the + // annotation on this class should be removed. 
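+    // Illustrative sketch of the struct this record is deserialized from in
+    // testStatefulDeserialization below (assuming PropertyNames.OBJECT_TYPE_NAME resolves to
+    // the "@type" key):
+    //   {"@type": "...SerializerTest$TestRecord", "name": "foobar", "ok": true,
+    //    "value": 42, "dValue": 0.25}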
+ public String name; + public boolean ok; + public int value; + public double dValue; + } + + @Test + public void testStatefulDeserialization() { + CloudObject object = CloudObject.forClass(TestRecord.class); + + addString(object, "name", "foobar"); + addBoolean(object, "ok", true); + addLong(object, "value", 42L); + addDouble(object, "dValue", .25); + + TestRecord record = Serializer.deserialize(object, TestRecord.class); + Assert.assertEquals("foobar", record.name); + Assert.assertEquals(true, record.ok); + Assert.assertEquals(42L, record.value); + Assert.assertEquals(0.25, record.dValue, 0.0001); + } + + private static class InjectedTestRecord { + private final String n; + private final int v; + + @SuppressWarnings("unused") // used for JSON serialization + public InjectedTestRecord( + @JsonProperty("name") String name, + @JsonProperty("value") int value) { + this.n = name; + this.v = value; + } + + public String getName() { + return n; + } + public int getValue() { + return v; + } + } + + @Test + public void testDeserializationInjection() { + CloudObject object = CloudObject.forClass(InjectedTestRecord.class); + addString(object, "name", "foobar"); + addLong(object, "value", 42L); + + InjectedTestRecord record = + Serializer.deserialize(object, InjectedTestRecord.class); + + Assert.assertEquals("foobar", record.getName()); + Assert.assertEquals(42L, record.getValue()); + } + + private static class FactoryInjectedTestRecord { + @JsonCreator + public static FactoryInjectedTestRecord of( + @JsonProperty("name") String name, + @JsonProperty("value") int value) { + return new FactoryInjectedTestRecord(name, value); + } + + private final String n; + private final int v; + + private FactoryInjectedTestRecord(String name, int value) { + this.n = name; + this.v = value; + } + + public String getName() { + return n; + } + public int getValue() { + return v; + } + } + + @Test + public void testDeserializationFactoryInjection() { + CloudObject object = CloudObject.forClass(FactoryInjectedTestRecord.class); + addString(object, "name", "foobar"); + addLong(object, "value", 42L); + + FactoryInjectedTestRecord record = + Serializer.deserialize(object, FactoryInjectedTestRecord.class); + Assert.assertEquals("foobar", record.getName()); + Assert.assertEquals(42L, record.getValue()); + } + + private static class DerivedTestRecord extends TestRecord { + public String derived; + } + + @Test + public void testSubclassDeserialization() { + CloudObject object = CloudObject.forClass(DerivedTestRecord.class); + + addString(object, "name", "foobar"); + addBoolean(object, "ok", true); + addLong(object, "value", 42L); + addDouble(object, "dValue", .25); + addString(object, "derived", "baz"); + + TestRecord result = Serializer.deserialize(object, TestRecord.class); + Assert.assertThat(result, Matchers.instanceOf(DerivedTestRecord.class)); + + DerivedTestRecord record = (DerivedTestRecord) result; + Assert.assertEquals("foobar", record.name); + Assert.assertEquals(true, record.ok); + Assert.assertEquals(42L, record.value); + Assert.assertEquals(0.25, record.dValue, 0.0001); + Assert.assertEquals("baz", record.derived); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SimpleDoFnRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SimpleDoFnRunnerTest.java new file mode 100644 index 000000000000..4af60dbe9125 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/SimpleDoFnRunnerTest.java @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2016 Google Inc. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.util.BaseExecutionContext.StepContext;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Tests for base {@link DoFnRunnerBase} functionality.
+ */
+@RunWith(JUnit4.class)
+public class SimpleDoFnRunnerTest {
+  @Rule
+  public ExpectedException thrown = ExpectedException.none();
+
+  @Test
+  public void testExceptionsWrappedAsUserCodeException() {
+    ThrowingDoFn fn = new ThrowingDoFn();
+    DoFnRunner<String, String> runner = createRunner(fn);
+
+    thrown.expect(UserCodeException.class);
+    thrown.expectCause(is(fn.exceptionToThrow));
+
+    runner.processElement(WindowedValue.valueInGlobalWindow("anyValue"));
+  }
+
+  @Test
+  public void testSystemDoFnInternalExceptionsNotWrapped() {
+    ThrowingSystemDoFn fn = new ThrowingSystemDoFn();
+    DoFnRunner<String, String> runner = createRunner(fn);
+
+    thrown.expect(is(fn.exceptionToThrow));
+
+    runner.processElement(WindowedValue.valueInGlobalWindow("anyValue"));
+  }
+
+  private DoFnRunner<String, String> createRunner(DoFn<String, String> fn) {
+    // Pass in only necessary parameters for the test
+    List<TupleTag<String>> sideOutputTags = Arrays.asList();
+    StepContext context = mock(StepContext.class);
+    return DoFnRunners.simpleRunner(
+        null, fn, null, null, null, sideOutputTags, context, null, null);
+  }
+
+  static class ThrowingDoFn extends DoFn<String, String> {
+    final Exception exceptionToThrow =
+        new UnsupportedOperationException("Expected exception");
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      throw exceptionToThrow;
+    }
+  }
+
+  @SystemDoFnInternal
+  static class ThrowingSystemDoFn extends ThrowingDoFn {
+  }
+}
+
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StreamUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StreamUtilsTest.java
new file mode 100644
index 000000000000..595dbdc155bf
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StreamUtilsTest.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */ +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.BufferedInputStream; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; + +/** Unit tests for {@link ExposedByteArrayInputStream}. */ +@RunWith(JUnit4.class) +public class StreamUtilsTest { + + private byte[] testData = null; + + @Before + public void setUp() { + testData = new byte[60 * 1024]; + Arrays.fill(testData, (byte) 32); + } + + @Test + public void testGetBytesFromExposedByteArrayInputStream() throws IOException { + InputStream stream = new ExposedByteArrayInputStream(testData); + byte[] bytes = StreamUtils.getBytes(stream); + assertArrayEquals(testData, bytes); + assertSame(testData, bytes); + assertEquals(0, stream.available()); + } + + @Test + public void testGetBytesFromByteArrayInputStream() throws IOException { + InputStream stream = new ByteArrayInputStream(testData); + byte[] bytes = StreamUtils.getBytes(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, stream.available()); + } + + @Test + public void testGetBytesFromInputStream() throws IOException { + // Any stream which is not a ByteArrayInputStream. + InputStream stream = + new BufferedInputStream(new ByteArrayInputStream(testData)); + byte[] bytes = StreamUtils.getBytes(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, stream.available()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StringUtilsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StringUtilsTest.java new file mode 100644 index 000000000000..49d8688ace72 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StringUtilsTest.java @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PDone; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for StringUtils. 
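+ *
+ * <p>A minimal usage sketch of the utilities under test (values mirror the cases below):
+ * <pre>{@code
+ * StringUtils.byteArrayToJsonString(new byte[] {});    // ""
+ * StringUtils.jsonStringToByteArray("%00%05");         // bytes {0, 5}
+ * StringUtils.getLevenshteinDistance("abc", "a1c");    // 1 (one substitution)
+ * }</pre>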
+ */ +@RunWith(JUnit4.class) +public class StringUtilsTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testTranscodeEmptyByteArray() { + byte[] bytes = { }; + String string = ""; + assertEquals(string, StringUtils.byteArrayToJsonString(bytes)); + assertArrayEquals(bytes, StringUtils.jsonStringToByteArray(string)); + } + + @Test + public void testTranscodeMixedByteArray() { + byte[] bytes = { + 0, 5, 12, 16, 31, 32, 65, 66, 126, 127, (byte) 128, (byte) 255, 67, 0 }; + String string = "%00%05%0c%10%1f AB~%7f%80%ffC%00"; + assertEquals(string, StringUtils.byteArrayToJsonString(bytes)); + assertArrayEquals(bytes, StringUtils.jsonStringToByteArray(string)); + } + + /** + * Inner class for simple name test. + */ + private class EmbeddedDoFn { + + private class DeeperEmbeddedDoFn extends EmbeddedDoFn {} + + private EmbeddedDoFn getEmbedded() { + return new DeeperEmbeddedDoFn(); + } + } + + private class EmbeddedPTransform extends PTransform { + private class Bound extends PTransform {} + + private Bound getBound() { + return new Bound(); + } + } + + private interface AnonymousClass { + Object getInnerClassInstance(); + } + + @Test + public void testSimpleName() { + assertEquals("Embedded", + StringUtils.approximateSimpleName(EmbeddedDoFn.class)); + } + + @Test + public void testAnonSimpleName() throws Exception { + thrown.expect(IllegalArgumentException.class); + + EmbeddedDoFn anon = new EmbeddedDoFn(){}; + + StringUtils.approximateSimpleName(anon.getClass()); + } + + @Test + public void testNestedSimpleName() { + EmbeddedDoFn fn = new EmbeddedDoFn(); + EmbeddedDoFn inner = fn.getEmbedded(); + + assertEquals("DeeperEmbedded", StringUtils.approximateSimpleName(inner.getClass())); + } + + @Test + public void testPTransformName() { + EmbeddedPTransform transform = new EmbeddedPTransform(); + assertEquals( + "StringUtilsTest.EmbeddedPTransform", + StringUtils.approximatePTransformName(transform.getClass())); + assertEquals( + "StringUtilsTest.EmbeddedPTransform", + StringUtils.approximatePTransformName(transform.getBound().getClass())); + assertEquals("TextIO.Write", StringUtils.approximatePTransformName(TextIO.Write.Bound.class)); + } + + @Test + public void testPTransformNameWithAnonOuterClass() throws Exception { + AnonymousClass anonymousClassObj = new AnonymousClass() { + class NamedInnerClass extends PTransform {} + + @Override + public Object getInnerClassInstance() { + return new NamedInnerClass(); + } + }; + + assertEquals("NamedInnerClass", + StringUtils.approximateSimpleName(anonymousClassObj.getInnerClassInstance().getClass())); + assertEquals("StringUtilsTest.NamedInnerClass", + StringUtils.approximatePTransformName( + anonymousClassObj.getInnerClassInstance().getClass())); + } + + @Test + public void testLevenshteinDistance() { + assertEquals(0, StringUtils.getLevenshteinDistance("", "")); // equal + assertEquals(3, StringUtils.getLevenshteinDistance("", "abc")); // first empty + assertEquals(3, StringUtils.getLevenshteinDistance("abc", "")); // second empty + assertEquals(5, StringUtils.getLevenshteinDistance("abc", "12345")); // completely different + assertEquals(1, StringUtils.getLevenshteinDistance("abc", "ac")); // deletion + assertEquals(1, StringUtils.getLevenshteinDistance("abc", "ab1c")); // insertion + assertEquals(1, StringUtils.getLevenshteinDistance("abc", "a1c")); // modification + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StructsTest.java 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StructsTest.java new file mode 100644 index 000000000000..bc0bdd82ec1c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/StructsTest.java @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.Structs.addBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.addDouble; +import static com.google.cloud.dataflow.sdk.util.Structs.addList; +import static com.google.cloud.dataflow.sdk.util.Structs.addLong; +import static com.google.cloud.dataflow.sdk.util.Structs.addLongs; +import static com.google.cloud.dataflow.sdk.util.Structs.addNull; +import static com.google.cloud.dataflow.sdk.util.Structs.addString; +import static com.google.cloud.dataflow.sdk.util.Structs.addStringList; +import static com.google.cloud.dataflow.sdk.util.Structs.getBoolean; +import static com.google.cloud.dataflow.sdk.util.Structs.getDictionary; +import static com.google.cloud.dataflow.sdk.util.Structs.getInt; +import static com.google.cloud.dataflow.sdk.util.Structs.getListOfMaps; +import static com.google.cloud.dataflow.sdk.util.Structs.getLong; +import static com.google.cloud.dataflow.sdk.util.Structs.getObject; +import static com.google.cloud.dataflow.sdk.util.Structs.getString; +import static com.google.cloud.dataflow.sdk.util.Structs.getStrings; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Tests for Structs. 
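+ *
+ * <p>A minimal sketch of the add/get round trip exercised below:
+ * <pre>{@code
+ * Map<String, Object> o = new HashMap<>();
+ * addString(o, "singletonStringKey", "stringValue");
+ * getString(o, "singletonStringKey");            // "stringValue"
+ * getString(o, "missingKey", "defaultValue");    // falls back to "defaultValue"
+ * }</pre>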
+ */ +@RunWith(JUnit4.class) +public class StructsTest { + private List> makeCloudObjects() { + List> objects = new ArrayList<>(); + { + CloudObject o = CloudObject.forClassName("string"); + addString(o, "singletonStringKey", "stringValue"); + objects.add(o); + } + { + CloudObject o = CloudObject.forClassName("long"); + addLong(o, "singletonLongKey", 42L); + objects.add(o); + } + return objects; + } + + private Map makeCloudDictionary() { + Map o = new HashMap<>(); + addList(o, "emptyKey", Collections.>emptyList()); + addNull(o, "noStringsKey"); + addString(o, "singletonStringKey", "stringValue"); + addStringList(o, "multipleStringsKey", Arrays.asList("hi", "there", "bob")); + addLongs(o, "multipleLongsKey", 47L, 1L << 42, -5L); + addLong(o, "singletonLongKey", 42L); + addDouble(o, "singletonDoubleKey", 3.14); + addBoolean(o, "singletonBooleanKey", true); + addNull(o, "noObjectsKey"); + addList(o, "multipleObjectsKey", makeCloudObjects()); + return o; + } + + @Test + public void testGetStringParameter() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + "stringValue", + getString(o, "singletonStringKey")); + Assert.assertEquals( + "stringValue", + getString(o, "singletonStringKey", "defaultValue")); + Assert.assertEquals( + "defaultValue", + getString(o, "missingKey", "defaultValue")); + + try { + getString(o, "missingKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString( + "didn't find required parameter missingKey")); + } + + try { + getString(o, "noStringsKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a string")); + } + + Assert.assertThat(getStrings(o, "noStringsKey", null), Matchers.emptyIterable()); + Assert.assertThat(getObject(o, "noStringsKey").keySet(), Matchers.emptyIterable()); + Assert.assertThat(getDictionary(o, "noStringsKey").keySet(), Matchers.emptyIterable()); + Assert.assertThat(getDictionary(o, "noStringsKey", null).keySet(), + Matchers.emptyIterable()); + + try { + getString(o, "multipleStringsKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a string")); + } + + try { + getString(o, "emptyKey"); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a string")); + } + } + + @Test + public void testGetBooleanParameter() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + true, + getBoolean(o, "singletonBooleanKey", false)); + Assert.assertEquals( + false, + getBoolean(o, "missingKey", false)); + + try { + getBoolean(o, "emptyKey", false); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a boolean")); + } + } + + @Test + public void testGetLongParameter() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + (Long) 42L, + getLong(o, "singletonLongKey", 666L)); + Assert.assertEquals( + (Integer) 42, + getInt(o, "singletonLongKey", 666)); + Assert.assertEquals( + (Long) 666L, + getLong(o, "missingKey", 666L)); + + try { + getLong(o, "emptyKey", 666L); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a long")); + } + try { 
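+      // "emptyKey" holds an (empty) list, so reading it as an int should fail with "not an int".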
+ getInt(o, "emptyKey", 666); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not an int")); + } + } + + @Test + public void testGetListOfMaps() throws Exception { + Map o = makeCloudDictionary(); + + Assert.assertEquals( + makeCloudObjects(), + getListOfMaps(o, "multipleObjectsKey", null)); + + try { + getListOfMaps(o, "singletonLongKey", null); + Assert.fail("should have thrown an exception"); + } catch (Exception exn) { + Assert.assertThat(exn.toString(), + Matchers.containsString("not a list")); + } + } + + // TODO: Test builder operations. +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimeUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimeUtilTest.java new file mode 100644 index 000000000000..d7c3384be114 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimeUtilTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.fromCloudTime; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudDuration; +import static com.google.cloud.dataflow.sdk.util.TimeUtil.toCloudTime; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TimeUtil}. 
*/ +@RunWith(JUnit4.class) +public final class TimeUtilTest { + @Test + public void toCloudTimeShouldPrintTimeStrings() { + assertEquals("1970-01-01T00:00:00Z", toCloudTime(new Instant(0))); + assertEquals("1970-01-01T00:00:00.001Z", toCloudTime(new Instant(1))); + } + + @Test + public void fromCloudTimeShouldParseTimeStrings() { + assertEquals(new Instant(0), fromCloudTime("1970-01-01T00:00:00Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001001Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000000Z")); + assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000001Z")); + assertNull(fromCloudTime("")); + assertNull(fromCloudTime("1970-01-01T00:00:00")); + } + + @Test + public void toCloudDurationShouldPrintDurationStrings() { + assertEquals("0s", toCloudDuration(Duration.ZERO)); + assertEquals("4s", toCloudDuration(Duration.millis(4000))); + assertEquals("4.001s", toCloudDuration(Duration.millis(4001))); + } + + @Test + public void fromCloudDurationShouldParseDurationStrings() { + assertEquals(Duration.millis(4000), fromCloudDuration("4s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001000s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001001s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001000000s")); + assertEquals(Duration.millis(4001), fromCloudDuration("4.001000001s")); + assertNull(fromCloudDuration("")); + assertNull(fromCloudDuration("4")); + assertNull(fromCloudDuration("4.1")); + assertNull(fromCloudDuration("4.1s")); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimerInternalsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimerInternalsTest.java new file mode 100644 index 000000000000..68aecf0b53b3 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TimerInternalsTest.java @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.testing.CoderProperties; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerDataCoder; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link TimerInternals}. 
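+ *
+ * <p>A minimal sketch of the coder property being checked (mirrors the test below):
+ * <pre>{@code
+ * TimerData timer = TimerData.of(StateNamespaces.global(), new Instant(0), TimeDomain.EVENT_TIME);
+ * CoderProperties.coderDecodeEncodeEqual(TimerDataCoder.of(GlobalWindow.Coder.INSTANCE), timer);
+ * }</pre>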
+ */ +@RunWith(JUnit4.class) +public class TimerInternalsTest { + + @Test + public void testTimerDataCoder() throws Exception { + CoderProperties.coderDecodeEncodeEqual( + TimerDataCoder.of(GlobalWindow.Coder.INSTANCE), + TimerData.of(StateNamespaces.global(), new Instant(0), TimeDomain.EVENT_TIME)); + + Coder windowCoder = IntervalWindow.getCoder(); + CoderProperties.coderDecodeEncodeEqual( + TimerDataCoder.of(windowCoder), + TimerData.of( + StateNamespaces.window( + windowCoder, new IntervalWindow(new Instant(0), new Instant(100))), + new Instant(99), TimeDomain.PROCESSING_TIME)); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TriggerTester.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TriggerTester.java new file mode 100644 index 000000000000..0c7183020291 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/TriggerTester.java @@ -0,0 +1,585 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.transforms.windowing.TriggerBuilder; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.ActiveWindowSet.MergeCallback; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals; +import com.google.cloud.dataflow.sdk.util.state.State; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces.WindowAndTriggerNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces.WindowNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.WatermarkHoldState; +import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.common.base.MoreObjects; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import 
java.util.Map; +import java.util.PriorityQueue; +import java.util.Set; + +import javax.annotation.Nullable; + +/** + * Test utility that runs a {@link Trigger}, using in-memory stub implementation to provide + * the {@link StateInternals}. + * + * @param The type of windows being used. + */ +public class TriggerTester { + + /** + * A {@link TriggerTester} specialized to {@link Integer} values, so elements and timestamps + * can be conflated. Today, triggers should not observed the element type, so this is the + * only trigger tester that needs to be used. + */ + public static class SimpleTriggerTester + extends TriggerTester { + + private SimpleTriggerTester(WindowingStrategy windowingStrategy) throws Exception { + super(windowingStrategy); + } + + public void injectElements(int... values) throws Exception { + List> timestampedValues = + Lists.newArrayListWithCapacity(values.length); + for (int value : values) { + timestampedValues.add(TimestampedValue.of(value, new Instant(value))); + } + injectElements(timestampedValues); + } + + public SimpleTriggerTester withAllowedLateness(Duration allowedLateness) throws Exception { + return new SimpleTriggerTester<>( + windowingStrategy.withAllowedLateness(allowedLateness)); + } + } + + protected final WindowingStrategy windowingStrategy; + + private final TestInMemoryStateInternals stateInternals = + new TestInMemoryStateInternals(); + private final TestTimerInternals timerInternals = new TestTimerInternals(); + private final TriggerContextFactory contextFactory; + private final WindowFn windowFn; + private final ActiveWindowSet activeWindows; + + /** + * An {@link ExecutableTrigger} built from the {@link Trigger} or {@link TriggerBuilder} + * under test. + */ + private final ExecutableTrigger executableTrigger; + + /** + * A map from a window and trigger to whether that trigger is finished for the window. + */ + private final Map finishedSets; + + public static SimpleTriggerTester forTrigger( + TriggerBuilder trigger, WindowFn windowFn) + throws Exception { + WindowingStrategy windowingStrategy = + WindowingStrategy.of(windowFn).withTrigger(trigger.buildTrigger()) + // Merging requires accumulation mode or early firings can break up a session. + // Not currently an issue with the tester (because we never GC) but we don't want + // mystery failures due to violating this need. + .withMode(windowFn.isNonMerging() + ? AccumulationMode.DISCARDING_FIRED_PANES + : AccumulationMode.ACCUMULATING_FIRED_PANES); + + return new SimpleTriggerTester<>(windowingStrategy); + } + + public static TriggerTester forAdvancedTrigger( + TriggerBuilder trigger, WindowFn windowFn) throws Exception { + WindowingStrategy strategy = + WindowingStrategy.of(windowFn).withTrigger(trigger.buildTrigger()) + // Merging requires accumulation mode or early firings can break up a session. + // Not currently an issue with the tester (because we never GC) but we don't want + // mystery failures due to violating this need. + .withMode(windowFn.isNonMerging() + ? AccumulationMode.DISCARDING_FIRED_PANES + : AccumulationMode.ACCUMULATING_FIRED_PANES); + + return new TriggerTester<>(strategy); + } + + protected TriggerTester(WindowingStrategy windowingStrategy) throws Exception { + this.windowingStrategy = windowingStrategy; + this.windowFn = windowingStrategy.getWindowFn(); + this.executableTrigger = windowingStrategy.getTrigger(); + this.finishedSets = new HashMap<>(); + + this.activeWindows = + windowFn.isNonMerging() + ? 
new NonMergingActiveWindowSet() + : new MergingActiveWindowSet(windowFn, stateInternals); + + this.contextFactory = + new TriggerContextFactory<>(windowingStrategy, stateInternals, activeWindows); + } + + /** + * Instructs the trigger to clear its state for the given window. + */ + public void clearState(W window) throws Exception { + executableTrigger.invokeClear(contextFactory.base(window, + new TestTimers(windowNamespace(window)), executableTrigger, getFinishedSet(window))); + } + + /** + * Asserts that the trigger has actually cleared all of its state for the given window. Since + * the trigger under test is the root, this makes the assert for all triggers regardless + * of their position in the trigger tree. + */ + public void assertCleared(W window) { + for (StateNamespace untypedNamespace : stateInternals.getNamespacesInUse()) { + if (untypedNamespace instanceof WindowAndTriggerNamespace) { + @SuppressWarnings("unchecked") + WindowAndTriggerNamespace namespace = (WindowAndTriggerNamespace) untypedNamespace; + if (namespace.getWindow().equals(window)) { + Set tagsInUse = stateInternals.getTagsInUse(namespace); + assertTrue("Trigger has not cleared tags: " + tagsInUse, tagsInUse.isEmpty()); + } + } + } + } + + /** + * Returns {@code true} if the {@link Trigger} under test is finished for the given window. + */ + public boolean isMarkedFinished(W window) { + FinishedTriggers finishedSet = finishedSets.get(window); + if (finishedSet == null) { + return false; + } + + return finishedSet.isFinished(executableTrigger); + } + + private StateNamespace windowNamespace(W window) { + return StateNamespaces.window(windowFn.windowCoder(), checkNotNull(window)); + } + + /** + * Advance the input watermark to the specified time, firing any timers that should + * fire. Then advance the output watermark as far as possible. + */ + public void advanceInputWatermark(Instant newInputWatermark) throws Exception { + timerInternals.advanceInputWatermark(newInputWatermark); + } + + /** Advance the processing time to the specified time, firing any timers that should fire. */ + public void advanceProcessingTime(Instant newProcessingTime) throws Exception { + timerInternals.advanceProcessingTime(newProcessingTime); + } + + /** Advance the processing time to the specified time, firing any timers that should fire. */ + public void advanceSynchronizedProcessingTime(Instant newSynchronizedProcessingTime) + throws Exception { + timerInternals.advanceSynchronizedProcessingTime(newSynchronizedProcessingTime); + } + + /** + * Inject all the timestamped values (after passing through the window function) as if they + * arrived in a single chunk of a bundle (or work-unit). + */ + @SafeVarargs + public final void injectElements(TimestampedValue... 
values) throws Exception { + injectElements(Arrays.asList(values)); + } + + public final void injectElements(Collection> values) throws Exception { + for (TimestampedValue value : values) { + WindowTracing.trace("TriggerTester.injectElements: {}", value); + } + + List> windowedValues = Lists.newArrayListWithCapacity(values.size()); + + for (TimestampedValue input : values) { + try { + InputT value = input.getValue(); + Instant timestamp = input.getTimestamp(); + Collection assignedWindows = windowFn.assignWindows(new TestAssignContext( + windowFn, value, timestamp, Arrays.asList(GlobalWindow.INSTANCE))); + + for (W window : assignedWindows) { + activeWindows.addActive(window); + + // Today, triggers assume onTimer firing at the watermark time, whether or not they + // explicitly set the timer themselves. So this tester must set it. + timerInternals.setTimer( + TimerData.of(windowNamespace(window), window.maxTimestamp(), TimeDomain.EVENT_TIME)); + } + + windowedValues.add(WindowedValue.of(value, timestamp, assignedWindows, PaneInfo.NO_FIRING)); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + for (WindowedValue windowedValue : windowedValues) { + for (BoundedWindow untypedWindow : windowedValue.getWindows()) { + // SDK is responsible for type safety + @SuppressWarnings("unchecked") + W window = activeWindows.representative((W) untypedWindow); + + Trigger.OnElementContext context = contextFactory.createOnElementContext(window, + new TestTimers(windowNamespace(window)), windowedValue.getTimestamp(), + executableTrigger, getFinishedSet(window)); + + if (!context.trigger().isFinished()) { + executableTrigger.invokeOnElement(context); + } + } + } + } + + public boolean shouldFire(W window) throws Exception { + Trigger.TriggerContext context = contextFactory.base( + window, + new TestTimers(windowNamespace(window)), + executableTrigger, getFinishedSet(window)); + executableTrigger.getSpec().prefetchShouldFire(context.state()); + return executableTrigger.invokeShouldFire(context); + } + + public void fireIfShouldFire(W window) throws Exception { + Trigger.TriggerContext context = contextFactory.base( + window, + new TestTimers(windowNamespace(window)), + executableTrigger, getFinishedSet(window)); + + executableTrigger.getSpec().prefetchShouldFire(context.state()); + if (executableTrigger.invokeShouldFire(context)) { + executableTrigger.getSpec().prefetchOnFire(context.state()); + executableTrigger.invokeOnFire(context); + if (context.trigger().isFinished()) { + activeWindows.remove(context.window()); + executableTrigger.invokeClear(context); + } + } + } + + public void setSubTriggerFinishedForWindow(int subTriggerIndex, W window, boolean value) { + getFinishedSet(window).setFinished(executableTrigger.subTriggers().get(subTriggerIndex), value); + } + + /** + * Invokes merge from the {@link WindowFn} a single time and passes the resulting merge + * events on to the trigger under test. Does not persist the fact that merging happened, + * since it is just to test the trigger's {@code OnMerge} method. 
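+   *
+   * <p>Illustrative use in a merging-window test (the {@code tester} variable is hypothetical;
+   * the values assume a session-style {@link WindowFn} that produces overlapping windows):
+   * <pre>{@code
+   * tester.injectElements(1, 5);   // elements land in windows that should merge
+   * tester.mergeWindows();         // exercises the trigger's onMerge callback
+   * }</pre>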
+ */ + public final void mergeWindows() throws Exception { + activeWindows.merge(new MergeCallback() { + @Override + public void prefetchOnMerge(Collection toBeMerged, Collection activeToBeMerged, + W mergeResult) throws Exception {} + + @Override + public void onMerge(Collection toBeMerged, Collection activeToBeMerged, W mergeResult) + throws Exception { + Map mergingFinishedSets = + Maps.newHashMapWithExpectedSize(activeToBeMerged.size()); + for (W oldWindow : activeToBeMerged) { + mergingFinishedSets.put(oldWindow, getFinishedSet(oldWindow)); + } + executableTrigger.invokeOnMerge(contextFactory.createOnMergeContext(mergeResult, + new TestTimers(windowNamespace(mergeResult)), executableTrigger, + getFinishedSet(mergeResult), mergingFinishedSets)); + timerInternals.setTimer(TimerData.of( + windowNamespace(mergeResult), mergeResult.maxTimestamp(), TimeDomain.EVENT_TIME)); + } + }); + } + + private FinishedTriggers getFinishedSet(W window) { + FinishedTriggers finishedSet = finishedSets.get(window); + if (finishedSet == null) { + finishedSet = FinishedTriggersSet.fromSet(new HashSet>()); + finishedSets.put(window, finishedSet); + } + return finishedSet; + } + + /** + * Simulate state. + */ + private static class TestInMemoryStateInternals extends InMemoryStateInternals { + + public TestInMemoryStateInternals() { + super(null); + } + + public Set> getTagsInUse(StateNamespace namespace) { + Set> inUse = new HashSet<>(); + for (Map.Entry, State> entry : + inMemoryState.getTagsInUse(namespace).entrySet()) { + if (!isEmptyForTesting(entry.getValue())) { + inUse.add(entry.getKey()); + } + } + return inUse; + } + + public Set getNamespacesInUse() { + return inMemoryState.getNamespacesInUse(); + } + + /** Return the earliest output watermark hold in state, or null if none. */ + public Instant earliestWatermarkHold() { + Instant minimum = null; + for (State storage : inMemoryState.values()) { + if (storage instanceof WatermarkHoldState) { + @SuppressWarnings("unchecked") + Instant hold = ((WatermarkHoldState) storage).read(); + if (minimum == null || (hold != null && hold.isBefore(minimum))) { + minimum = hold; + } + } + } + return minimum; + } + } + + private static class TestAssignContext + extends WindowFn.AssignContext { + private Object element; + private Instant timestamp; + private Collection windows; + + public TestAssignContext(WindowFn windowFn, Object element, Instant timestamp, + Collection windows) { + windowFn.super(); + this.element = element; + this.timestamp = timestamp; + this.windows = windows; + } + + @Override + public Object element() { + return element; + } + + @Override + public Instant timestamp() { + return timestamp; + } + + @Override + public Collection windows() { + return windows; + } + } + + /** + * Simulate the firing of timers and progression of input and output watermarks for a + * single computation and key in a Windmill-like streaming environment. Similar to + * {@link BatchTimerInternals}, but also tracks the output watermark. + */ + private class TestTimerInternals implements TimerInternals { + /** At most one timer per timestamp is kept. */ + private Set existingTimers = new HashSet<>(); + + /** Pending input watermark timers, in timestamp order. */ + private PriorityQueue watermarkTimers = new PriorityQueue<>(11); + + /** Pending processing time timers, in timestamp order. */ + private PriorityQueue processingTimers = new PriorityQueue<>(11); + + /** Current input watermark. 
*/ + @Nullable + private Instant inputWatermarkTime = null; + + /** Current output watermark. */ + @Nullable + private Instant outputWatermarkTime = null; + + /** Current processing time. */ + private Instant processingTime = BoundedWindow.TIMESTAMP_MIN_VALUE; + + /** Current processing time. */ + private Instant synchronizedProcessingTime = null; + + private PriorityQueue queue(TimeDomain domain) { + return TimeDomain.EVENT_TIME.equals(domain) ? watermarkTimers : processingTimers; + } + + @Override + public void setTimer(TimerData timer) { + WindowTracing.trace("TestTimerInternals.setTimer: {}", timer); + if (existingTimers.add(timer)) { + queue(timer.getDomain()).add(timer); + } + } + + @Override + public void deleteTimer(TimerData timer) { + WindowTracing.trace("TestTimerInternals.deleteTimer: {}", timer); + existingTimers.remove(timer); + queue(timer.getDomain()).remove(timer); + } + + @Override + public Instant currentProcessingTime() { + return processingTime; + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return synchronizedProcessingTime; + } + + @Override + @Nullable + public Instant currentInputWatermarkTime() { + return inputWatermarkTime; + } + + @Override + @Nullable + public Instant currentOutputWatermarkTime() { + return outputWatermarkTime; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("watermarkTimers", watermarkTimers) + .add("processingTimers", processingTime) + .add("inputWatermarkTime", inputWatermarkTime) + .add("outputWatermarkTime", outputWatermarkTime) + .add("processingTime", processingTime) + .toString(); + } + + public void advanceInputWatermark(Instant newInputWatermark) throws Exception { + checkNotNull(newInputWatermark); + checkState(inputWatermarkTime == null || !newInputWatermark.isBefore(inputWatermarkTime), + "Cannot move input watermark time backwards from %s to %s", inputWatermarkTime, + newInputWatermark); + WindowTracing.trace("TestTimerInternals.advanceInputWatermark: from {} to {}", + inputWatermarkTime, newInputWatermark); + inputWatermarkTime = newInputWatermark; + + Instant hold = stateInternals.earliestWatermarkHold(); + if (hold == null) { + WindowTracing.trace("TestTimerInternals.advanceInputWatermark: no holds, " + + "so output watermark = input watermark"); + hold = inputWatermarkTime; + } + advanceOutputWatermark(hold); + } + + private void advanceOutputWatermark(Instant newOutputWatermark) throws Exception { + checkNotNull(newOutputWatermark); + checkNotNull(inputWatermarkTime); + if (newOutputWatermark.isAfter(inputWatermarkTime)) { + WindowTracing.trace( + "TestTimerInternals.advanceOutputWatermark: clipping output watermark from {} to {}", + newOutputWatermark, inputWatermarkTime); + newOutputWatermark = inputWatermarkTime; + } + checkState(outputWatermarkTime == null || !newOutputWatermark.isBefore(outputWatermarkTime), + "Cannot move output watermark time backwards from %s to %s", outputWatermarkTime, + newOutputWatermark); + WindowTracing.trace("TestTimerInternals.advanceOutputWatermark: from {} to {}", + outputWatermarkTime, newOutputWatermark); + outputWatermarkTime = newOutputWatermark; + } + + public void advanceProcessingTime(Instant newProcessingTime) throws Exception { + checkState(!newProcessingTime.isBefore(processingTime), + "Cannot move processing time backwards from %s to %s", processingTime, newProcessingTime); + WindowTracing.trace("TestTimerInternals.advanceProcessingTime: from {} to {}", processingTime, + 
newProcessingTime); + processingTime = newProcessingTime; + } + + public void advanceSynchronizedProcessingTime(Instant newSynchronizedProcessingTime) + throws Exception { + checkState(!newSynchronizedProcessingTime.isBefore(synchronizedProcessingTime), + "Cannot move processing time backwards from %s to %s", synchronizedProcessingTime, + newSynchronizedProcessingTime); + WindowTracing.trace("TestTimerInternals.advanceProcessingTime: from {} to {}", + synchronizedProcessingTime, newSynchronizedProcessingTime); + synchronizedProcessingTime = newSynchronizedProcessingTime; + } + } + + private class TestTimers implements Timers { + private final StateNamespace namespace; + + public TestTimers(StateNamespace namespace) { + checkArgument(namespace instanceof WindowNamespace); + this.namespace = namespace; + } + + @Override + public void setTimer(Instant timestamp, TimeDomain timeDomain) { + timerInternals.setTimer(TimerData.of(namespace, timestamp, timeDomain)); + } + + @Override + public void deleteTimer(Instant timestamp, TimeDomain timeDomain) { + timerInternals.deleteTimer(TimerData.of(namespace, timestamp, timeDomain)); + } + + @Override + public Instant currentProcessingTime() { + return timerInternals.currentProcessingTime(); + } + + @Override + @Nullable + public Instant currentSynchronizedProcessingTime() { + return timerInternals.currentSynchronizedProcessingTime(); + } + + @Override + @Nullable + public Instant currentEventTime() { + return timerInternals.currentInputWatermarkTime(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UnownedInputStreamTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UnownedInputStreamTest.java new file mode 100644 index 000000000000..30da6ae7238b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UnownedInputStreamTest.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; + +/** Unit tests for {@link UnownedInputStream}. 
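+ *
+ * <p>A minimal sketch of the contract under test (mirrors the cases below):
+ * <pre>{@code
+ * InputStream in = new UnownedInputStream(new ByteArrayInputStream(new byte[] {1, 2, 3}));
+ * in.read();    // expected to delegate to the wrapped stream
+ * in.close();   // throws UnsupportedOperationException; the caller does not own the stream
+ * }</pre>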
*/ +@RunWith(JUnit4.class) +public class UnownedInputStreamTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + private ByteArrayInputStream bais; + private UnownedInputStream os; + + @Before + public void setup() { + bais = new ByteArrayInputStream(new byte[]{ 1, 2, 3 }); + os = new UnownedInputStream(bais); + } + + @Test + public void testHashCodeEqualsAndToString() throws Exception { + assertEquals(bais.hashCode(), os.hashCode()); + assertEquals("UnownedInputStream{in=" + bais + "}", os.toString()); + assertEquals(new UnownedInputStream(bais), os); + } + + @Test + public void testClosingThrows() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + expectedException.expectMessage("close()"); + os.close(); + } + + @Test + public void testMarkThrows() throws Exception { + assertFalse(os.markSupported()); + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + expectedException.expectMessage("mark()"); + os.mark(1); + } + + @Test + public void testResetThrows() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + expectedException.expectMessage("reset()"); + os.reset(); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UnownedOutputStreamTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UnownedOutputStreamTest.java new file mode 100644 index 000000000000..eea70fe6cb9c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UnownedOutputStreamTest.java @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.junit.Assert.assertEquals; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; + +/** Unit tests for {@link UnownedOutputStream}. 
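+ *
+ * <p>A minimal sketch of the contract under test (mirrors the cases below):
+ * <pre>{@code
+ * OutputStream os = new UnownedOutputStream(new ByteArrayOutputStream());
+ * os.write(42);   // expected to delegate to the wrapped stream
+ * os.close();     // throws UnsupportedOperationException; the caller does not own the stream
+ * }</pre>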
*/ +@RunWith(JUnit4.class) +public class UnownedOutputStreamTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + private ByteArrayOutputStream baos; + private UnownedOutputStream os; + + @Before + public void setup() { + baos = new ByteArrayOutputStream(); + os = new UnownedOutputStream(baos); + } + + @Test + public void testHashCodeEqualsAndToString() throws Exception { + assertEquals(baos.hashCode(), os.hashCode()); + assertEquals("UnownedOutputStream{out=" + baos + "}", os.toString()); + assertEquals(new UnownedOutputStream(baos), os); + } + + @Test + public void testClosingThrows() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Caller does not own the underlying"); + os.close(); + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UploadIdResponseInterceptorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UploadIdResponseInterceptorTest.java new file mode 100644 index 000000000000..698d0cb5a517 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UploadIdResponseInterceptorTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpResponse; + +import com.google.api.client.testing.http.HttpTesting; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +/** + * A test for {@link UploadIdResponseInterceptor}. + */ + +@RunWith(JUnit4.class) +public class UploadIdResponseInterceptorTest { + + @Rule public ExpectedException expectedException = ExpectedException.none(); + // Note that the ExpectedLogs rule also turns on debug logging. + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(UploadIdResponseInterceptor.class); + + /** + * Builds an HttpResponse with the given header and URL parameters. + * + * @param header header value to provide or null if none. + * @param uploadId upload id to provide in the url upload id param or null if none. + * @param uploadType upload type to provide in url upload type param or null if none. 
+ * @return HttpResponse with the given parameters + * @throws IOException + */ + private HttpResponse buildHttpResponse(String header, String uploadId, String uploadType) + throws IOException { + MockHttpTransport.Builder builder = new MockHttpTransport.Builder(); + MockLowLevelHttpResponse resp = new MockLowLevelHttpResponse(); + builder.setLowLevelHttpResponse(resp); + resp.setStatusCode(200); + GenericUrl url = new GenericUrl(HttpTesting.SIMPLE_URL); + if (header != null) { + resp.addHeader("X-GUploader-UploadID", header); + } + if (uploadId != null) { + url.put("upload_id", uploadId); + } + if (uploadType != null) { + url.put("uploadType", uploadType); + } + return builder.build().createRequestFactory().buildGetRequest(url).execute(); + } + + /** + * Tests the responses that should not log. + */ + @Test + public void testResponseNoLogging() throws IOException { + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse(null, null, null)); + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse("hh", "a", null)); + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse(null, "h", null)); + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse("hh", null, null)); + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse(null, null, "type")); + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse("hh", "a", "type")); + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse(null, "h", "type")); + expectedLogs.verifyNotLogged(""); + } + + /** + * Check that a response logs with the correct log. + */ + @Test + public void testResponseLogs() throws IOException { + new UploadIdResponseInterceptor().interceptResponse(buildHttpResponse("abc", null, "type")); + GenericUrl url = new GenericUrl(HttpTesting.SIMPLE_URL); + url.put("uploadType", "type"); + String worker = System.getProperty("worker_id"); + expectedLogs.verifyDebug("Upload ID for url " + url + " on worker " + worker + " is abc"); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UserCodeExceptionTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UserCodeExceptionTest.java new file mode 100644 index 000000000000..5cf385cf31bd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/UserCodeExceptionTest.java @@ -0,0 +1,176 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.hamcrest.Description; +import org.hamcrest.FeatureMatcher; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +/** + * Tests for @{link UserCodeException} functionality. + */ +@RunWith(JUnit4.class) +public class UserCodeExceptionTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void existingUserCodeExceptionsNotWrapped() { + UserCodeException existing = UserCodeException.wrap(new IOException()); + UserCodeException wrapped = UserCodeException.wrap(existing); + + assertEquals(existing, wrapped); + } + + @Test + public void testCauseIsSet() { + thrown.expectCause(isA(IOException.class)); + throwUserCodeException(); + } + + @Test + public void testStackTraceIsTruncatedToUserCode() { + thrown.expectCause(hasBottomStackFrame(method("userCode"))); + throwUserCodeException(); + } + + @Test + public void testStackTraceIsTruncatedProperlyFromHelperMethod() { + thrown.expectCause(hasBottomStackFrame(method("userCode"))); + throwUserCodeExceptionFromHelper(); + } + + @Test + public void testWrapIfOnlyWrapsWhenTrue() { + IOException cause = new IOException(); + RuntimeException wrapped = UserCodeException.wrapIf(true, cause); + + assertThat(wrapped, is(instanceOf(UserCodeException.class))); + } + + @Test + public void testWrapIfReturnsRuntimeExceptionWhenFalse() { + IOException cause = new IOException(); + RuntimeException wrapped = UserCodeException.wrapIf(false, cause); + + assertThat(wrapped, is(not(instanceOf(UserCodeException.class)))); + assertEquals(cause, wrapped.getCause()); + } + + @Test + public void testWrapIfReturnsSourceRuntimeExceptionWhenFalse() { + RuntimeException runtimeException = new RuntimeException("oh noes!"); + RuntimeException wrapped = UserCodeException.wrapIf(false, runtimeException); + + assertEquals(runtimeException, wrapped); + } + + + private void throwUserCodeException() { + try { + userCode(); + } catch (Exception ex) { + throw UserCodeException.wrap(ex); + } + } + + private void throwUserCodeExceptionFromHelper() { + try { + userCode(); + } catch (Exception ex) { + throw wrap(ex); + } + } + + private UserCodeException wrap(Throwable t) { + throw UserCodeException.wrap(t); + } + + private void userCode() throws IOException { + userCode2(); + } + + private void userCode2() throws IOException { + userCode3(); + } + + private void userCode3() throws IOException { + IOException ex = new IOException("User processing error!"); + throw ex; + } + + private static ThrowableBottomStackFrameMethodMatcher hasBottomStackFrame( + Matcher frameMatcher) { + return new ThrowableBottomStackFrameMethodMatcher(frameMatcher); + } + + private static StackFrameMethodMatcher method(String methodName) { + return new StackFrameMethodMatcher(is(methodName)); + } + + static class ThrowableBottomStackFrameMethodMatcher + extends FeatureMatcher { + + public ThrowableBottomStackFrameMethodMatcher(Matcher subMatcher) { + super(subMatcher, "Throwable with bottom stack frame:", "stack frame"); + } + + @Override + 
protected StackTraceElement featureValueOf(Throwable actual) { + StackTraceElement[] stackTrace = actual.getStackTrace(); + return stackTrace[stackTrace.length - 1]; + } + } + + static class StackFrameMethodMatcher extends TypeSafeMatcher { + + private Matcher methodNameMatcher; + + public StackFrameMethodMatcher(Matcher methodNameMatcher) { + this.methodNameMatcher = methodNameMatcher; + } + + @Override + public void describeTo(Description description) { + description.appendText("stack frame where method name "); + methodNameMatcher.describeTo(description); + } + + @Override + protected boolean matchesSafely(StackTraceElement item) { + return methodNameMatcher.matches(item.getMethodName()); + } + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/VarIntTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/VarIntTest.java new file mode 100644 index 000000000000..ca233ed12440 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/VarIntTest.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.EOFException; +import java.io.IOException; + +/** Unit tests for {@link VarInt}. */ +@RunWith(JUnit4.class) +public class VarIntTest { + @Rule public final ExpectedException thrown = ExpectedException.none(); + + // Long values to check for boundary cases. + private static final long[] LONG_VALUES = { + 0, + 1, + 127, + 128, + 16383, + 16384, + 2097151, + 2097152, + 268435455, + 268435456, + 34359738367L, + 34359738368L, + 9223372036854775807L, + -9223372036854775808L, + -1, + }; + + // VarInt encoding of the above VALUES. 
+ private static final byte[][] LONG_ENCODED = { + // 0 + { 0x00 }, + // 1 + { 0x01 }, + // 127 + { 0x7f }, + // 128 + { (byte) 0x80, 0x01 }, + // 16383 + { (byte) 0xff, 0x7f }, + // 16384 + { (byte) 0x80, (byte) 0x80, 0x01 }, + // 2097151 + { (byte) 0xff, (byte) 0xff, 0x7f }, + // 2097152 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x01 }, + // 268435455 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x7f }, + // 268435456 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x01 }, + // 34359738367 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x7f }, + // 34359738368 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, + 0x01 }, + // 9223372036854775807 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, + (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x7f }, + // -9223372036854775808L + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, + (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x01 }, + // -1 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, + (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x01 } + }; + + // Integer values to check for boundary cases. + private static final int[] INT_VALUES = { + 0, + 1, + 127, + 128, + 16383, + 16384, + 2097151, + 2097152, + 268435455, + 268435456, + 2147483647, + -2147483648, + -1, + }; + + // VarInt encoding of the above VALUES. + private static final byte[][] INT_ENCODED = { + // 0 + { (byte) 0x00 }, + // 1 + { (byte) 0x01 }, + // 127 + { (byte) 0x7f }, + // 128 + { (byte) 0x80, (byte) 0x01 }, + // 16383 + { (byte) 0xff, (byte) 0x7f }, + // 16384 + { (byte) 0x80, (byte) 0x80, (byte) 0x01 }, + // 2097151 + { (byte) 0xff, (byte) 0xff, (byte) 0x7f }, + // 2097152 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x01 }, + // 268435455 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x7f }, + // 268435456 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x01 }, + // 2147483647 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x07 }, + // -2147483648 + { (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x08 }, + // -1 + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0x0f } + }; + + private static byte[] encodeInt(int v) throws IOException { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + VarInt.encode(v, stream); + return stream.toByteArray(); + } + + private static byte[] encodeLong(long v) throws IOException { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + VarInt.encode(v, stream); + return stream.toByteArray(); + } + + private static int decodeInt(byte[] encoded) throws IOException { + ByteArrayInputStream stream = new ByteArrayInputStream(encoded); + return VarInt.decodeInt(stream); + } + + private static long decodeLong(byte[] encoded) throws IOException { + ByteArrayInputStream stream = new ByteArrayInputStream(encoded); + return VarInt.decodeLong(stream); + } + + @Test + public void decodeValues() throws IOException { + assertEquals(LONG_VALUES.length, LONG_ENCODED.length); + for (int i = 0; i < LONG_ENCODED.length; ++i) { + ByteArrayInputStream stream = new ByteArrayInputStream(LONG_ENCODED[i]); + long parsed = VarInt.decodeLong(stream); + assertEquals(LONG_VALUES[i], parsed); + assertEquals(-1, stream.read()); + } + + assertEquals(INT_VALUES.length, INT_ENCODED.length); + for (int i = 0; i < INT_ENCODED.length; ++i) { + ByteArrayInputStream stream = new ByteArrayInputStream(INT_ENCODED[i]); + int parsed = VarInt.decodeInt(stream); + assertEquals(INT_VALUES[i], 
parsed); + assertEquals(-1, stream.read()); + } + } + + @Test + public void encodeValuesAndGetLength() throws IOException { + assertEquals(LONG_VALUES.length, LONG_ENCODED.length); + for (int i = 0; i < LONG_VALUES.length; ++i) { + byte[] encoded = encodeLong(LONG_VALUES[i]); + assertThat(encoded, equalTo(LONG_ENCODED[i])); + assertEquals(LONG_ENCODED[i].length, VarInt.getLength(LONG_VALUES[i])); + } + + assertEquals(INT_VALUES.length, INT_ENCODED.length); + for (int i = 0; i < INT_VALUES.length; ++i) { + byte[] encoded = encodeInt(INT_VALUES[i]); + assertThat(encoded, equalTo(INT_ENCODED[i])); + assertEquals(INT_ENCODED[i].length, VarInt.getLength(INT_VALUES[i])); + } + } + + @Test + public void decodeThrowsExceptionForOverflow() throws IOException { + final byte[] tooLargeNumber = + { (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, + (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x02 }; + + thrown.expect(IOException.class); + decodeLong(tooLargeNumber); + } + + @Test + public void decodeThrowsExceptionForIntOverflow() throws IOException { + byte[] encoded = encodeLong(1L << 32); + + thrown.expect(IOException.class); + decodeInt(encoded); + } + + @Test + public void decodeThrowsExceptionForIntUnderflow() throws IOException { + byte[] encoded = encodeLong(-1); + + thrown.expect(IOException.class); + decodeInt(encoded); + } + + @Test + public void decodeThrowsExceptionForNonterminated() throws IOException { + final byte[] nonTerminatedNumber = + { (byte) 0xff, (byte) 0xff }; + + thrown.expect(IOException.class); + decodeLong(nonTerminatedNumber); + } + + @Test + public void decodeParsesEncodedValues() throws IOException { + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + for (int i = 10; i < Integer.MAX_VALUE; i = (int) (i * 1.1)) { + VarInt.encode(i, outStream); + VarInt.encode(-i, outStream); + } + for (long i = 10; i < Long.MAX_VALUE; i = (long) (i * 1.1)) { + VarInt.encode(i, outStream); + VarInt.encode(-i, outStream); + } + + ByteArrayInputStream inStream = + new ByteArrayInputStream(outStream.toByteArray()); + for (int i = 10; i < Integer.MAX_VALUE; i = (int) (i * 1.1)) { + assertEquals(i, VarInt.decodeInt(inStream)); + assertEquals(-i, VarInt.decodeInt(inStream)); + } + for (long i = 10; i < Long.MAX_VALUE; i = (long) (i * 1.1)) { + assertEquals(i, VarInt.decodeLong(inStream)); + assertEquals(-i, VarInt.decodeLong(inStream)); + } + } + + @Test + public void endOfFileThrowsException() throws Exception { + ByteArrayInputStream inStream = + new ByteArrayInputStream(new byte[0]); + thrown.expect(EOFException.class); + VarInt.decodeInt(inStream); + } + + @Test + public void unterminatedThrowsException() throws Exception { + byte[] e = encodeLong(Long.MAX_VALUE); + byte[] s = new byte[1]; + s[0] = e[0]; + ByteArrayInputStream inStream = new ByteArrayInputStream(s); + thrown.expect(IOException.class); + VarInt.decodeInt(inStream); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/WindowedValueTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/WindowedValueTest.java new file mode 100644 index 000000000000..01b31ad6e6ba --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/WindowedValueTest.java @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; + +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** Test case for {@link WindowedValue}. */ +@RunWith(JUnit4.class) +public class WindowedValueTest { + @Test + public void testWindowedValueCoder() throws CoderException { + Instant timestamp = new Instant(1234); + WindowedValue value = WindowedValue.of( + "abc", + new Instant(1234), + Arrays.asList(new IntervalWindow(timestamp, timestamp.plus(1000)), + new IntervalWindow(timestamp.plus(1000), timestamp.plus(2000))), + PaneInfo.NO_FIRING); + + Coder> windowedValueCoder = + WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder()); + + byte[] encodedValue = CoderUtils.encodeToByteArray(windowedValueCoder, value); + WindowedValue decodedValue = + CoderUtils.decodeFromByteArray(windowedValueCoder, encodedValue); + + Assert.assertEquals(value.getValue(), decodedValue.getValue()); + Assert.assertEquals(value.getTimestamp(), decodedValue.getTimestamp()); + Assert.assertArrayEquals(value.getWindows().toArray(), decodedValue.getWindows().toArray()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ZipFilesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ZipFilesTest.java new file mode 100644 index 000000000000..8c079a6bd50b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ZipFilesTest.java @@ -0,0 +1,311 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.util; + +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.io.ByteSource; +import com.google.common.io.CharSource; +import com.google.common.io.Files; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; + +/** + * Tests for the {@link ZipFiles} class. These tests make sure that the handling + * of zip-files works fine. + */ +@RunWith(JUnit4.class) +public class ZipFilesTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + private File tmpDir; + + @Rule + public TemporaryFolder tmpOutputFolder = new TemporaryFolder(); + private File zipFile; + + @Before + public void setUp() throws Exception { + tmpDir = tmpFolder.getRoot(); + zipFile = createZipFileHandle(); // the file is not actually created + } + + /** + * Verify that zipping and unzipping works fine. We zip a directory having + * some subdirectories, unzip it again and verify the structure to be in + * place. + */ + @Test + public void testZipWithSubdirectories() throws Exception { + File zipDir = new File(tmpDir, "zip"); + File subDir1 = new File(zipDir, "subDir1"); + File subDir2 = new File(subDir1, "subdir2"); + assertTrue(subDir2.mkdirs()); + createFileWithContents(subDir2, "myTextFile.txt", "Simple Text"); + + assertZipAndUnzipOfDirectoryMatchesOriginal(tmpDir); + } + + /** + * An empty subdirectory must have its own zip-entry. + */ + @Test + public void testEmptySubdirectoryHasZipEntry() throws Exception { + File zipDir = new File(tmpDir, "zip"); + File subDirEmpty = new File(zipDir, "subDirEmpty"); + assertTrue(subDirEmpty.mkdirs()); + + ZipFiles.zipDirectory(tmpDir, zipFile); + assertZipOnlyContains("zip/subDirEmpty/"); + } + + /** + * A directory with contents should not have a zip entry. 
+ */ + @Test + public void testSubdirectoryWithContentsHasNoZipEntry() throws Exception { + File zipDir = new File(tmpDir, "zip"); + File subDirContent = new File(zipDir, "subdirContent"); + assertTrue(subDirContent.mkdirs()); + createFileWithContents(subDirContent, "myTextFile.txt", "Simple Text"); + + ZipFiles.zipDirectory(tmpDir, zipFile); + assertZipOnlyContains("zip/subdirContent/myTextFile.txt"); + } + + @Test + public void testZipDirectoryToOutputStream() throws Exception { + createFileWithContents(tmpDir, "myTextFile.txt", "Simple Text"); + File[] sourceFiles = tmpDir.listFiles(); + Arrays.sort(sourceFiles); + assertThat(sourceFiles, not(arrayWithSize(0))); + + try (FileOutputStream outputStream = new FileOutputStream(zipFile)) { + ZipFiles.zipDirectory(tmpDir, outputStream); + } + File outputDir = Files.createTempDir(); + ZipFiles.unzipFile(zipFile, outputDir); + File[] outputFiles = outputDir.listFiles(); + Arrays.sort(outputFiles); + + assertThat(outputFiles, arrayWithSize(sourceFiles.length)); + for (int i = 0; i < sourceFiles.length; i++) { + compareFileContents(sourceFiles[i], outputFiles[i]); + } + + removeRecursive(outputDir.toPath()); + assertTrue(zipFile.delete()); + } + + @Test + public void testEntries() throws Exception { + File zipDir = new File(tmpDir, "zip"); + File subDir1 = new File(zipDir, "subDir1"); + File subDir2 = new File(subDir1, "subdir2"); + assertTrue(subDir2.mkdirs()); + createFileWithContents(subDir2, "myTextFile.txt", "Simple Text"); + + ZipFiles.zipDirectory(tmpDir, zipFile); + + ZipFile zip = new ZipFile(zipFile); + try { + Enumeration entries = zip.entries(); + for (ZipEntry entry : ZipFiles.entries(zip)) { + assertTrue(entries.hasMoreElements()); + // ZipEntry doesn't override equals + assertEquals(entry.getName(), entries.nextElement().getName()); + } + assertFalse(entries.hasMoreElements()); + } finally { + zip.close(); + } + } + + @Test + public void testAsByteSource() throws Exception { + File zipDir = new File(tmpDir, "zip"); + assertTrue(zipDir.mkdirs()); + createFileWithContents(zipDir, "myTextFile.txt", "Simple Text"); + + ZipFiles.zipDirectory(tmpDir, zipFile); + + try (ZipFile zip = new ZipFile(zipFile)) { + ZipEntry entry = zip.getEntry("zip/myTextFile.txt"); + ByteSource byteSource = ZipFiles.asByteSource(zip, entry); + if (entry.getSize() != -1) { + assertEquals(entry.getSize(), byteSource.size()); + } + assertArrayEquals("Simple Text".getBytes(StandardCharsets.UTF_8), byteSource.read()); + } + } + + @Test + public void testAsCharSource() throws Exception { + File zipDir = new File(tmpDir, "zip"); + assertTrue(zipDir.mkdirs()); + createFileWithContents(zipDir, "myTextFile.txt", "Simple Text"); + + ZipFiles.zipDirectory(tmpDir, zipFile); + + try (ZipFile zip = new ZipFile(zipFile)) { + ZipEntry entry = zip.getEntry("zip/myTextFile.txt"); + CharSource charSource = ZipFiles.asCharSource(zip, entry, StandardCharsets.UTF_8); + assertEquals("Simple Text", charSource.read()); + } + } + + private void assertZipOnlyContains(String zipFileEntry) throws IOException { + try (ZipFile zippedFile = new ZipFile(zipFile)) { + assertEquals(1, zippedFile.size()); + ZipEntry entry = zippedFile.entries().nextElement(); + assertEquals(zipFileEntry, entry.getName()); + } + } + + /** + * try to unzip to a non-existent directory and make sure that it fails. 
+ */ + @Test + public void testInvalidTargetDirectory() throws IOException { + File zipDir = new File(tmpDir, "zipdir"); + assertTrue(zipDir.mkdir()); + ZipFiles.zipDirectory(tmpDir, zipFile); + File invalidDirectory = new File("/foo/bar"); + assertFalse(invalidDirectory.exists()); + try { + ZipFiles.unzipFile(zipFile, invalidDirectory); + fail("We expected an IllegalArgumentException, but it never occurred"); + } catch (IllegalArgumentException e) { + // This is the expected exception - we passed the test. + } + } + + /** + * Try to unzip to an existing directory where one of the required subdirectories cannot be created. + */ + @Test + public void testDirectoryCreateFailed() throws IOException { + File zipDir = new File(tmpDir, "zipdir"); + assertTrue(zipDir.mkdir()); + ZipFiles.zipDirectory(tmpDir, zipFile); + File targetDirectory = Files.createTempDir(); + // Touch a file where the directory should be. + Files.touch(new File(targetDirectory, "zipdir")); + try { + ZipFiles.unzipFile(zipFile, targetDirectory); + fail("We expected an IOException, but it never occurred"); + } catch (IOException e) { + // This is the expected exception - we passed the test. + } + } + + /** + * Zip and unzip a given directory, and verify that the content afterwards is + * identical. + * @param sourceDir the directory to zip + */ + private void assertZipAndUnzipOfDirectoryMatchesOriginal(File sourceDir) throws IOException { + File[] sourceFiles = sourceDir.listFiles(); + Arrays.sort(sourceFiles); + + File zipFile = createZipFileHandle(); + ZipFiles.zipDirectory(sourceDir, zipFile); + File outputDir = Files.createTempDir(); + ZipFiles.unzipFile(zipFile, outputDir); + File[] outputFiles = outputDir.listFiles(); + Arrays.sort(outputFiles); + + assertThat(outputFiles, arrayWithSize(sourceFiles.length)); + for (int i = 0; i < sourceFiles.length; i++) { + compareFileContents(sourceFiles[i], outputFiles[i]); + } + + removeRecursive(outputDir.toPath()); + assertTrue(zipFile.delete()); + } + + /** + * Compare the content of two files or directories recursively. + * @param expected the expected directory or file content + * @param actual the actual directory or file content + */ + private void compareFileContents(File expected, File actual) throws IOException { + assertEquals(expected.isDirectory(), actual.isDirectory()); + assertEquals(expected.getName(), actual.getName()); + if (expected.isDirectory()) { + // Go through the children step by step. + File[] expectedChildren = expected.listFiles(); + Arrays.sort(expectedChildren); + File[] actualChildren = actual.listFiles(); + Arrays.sort(actualChildren); + assertThat(actualChildren, arrayWithSize(expectedChildren.length)); + for (int i = 0; i < expectedChildren.length; i++) { + compareFileContents(expectedChildren[i], actualChildren[i]); + } + } else { + // Compare the file content itself. + assertTrue(Files.equal(expected, actual)); + } + } + + /** + * Create a File handle to which we can safely write a zip file. + */ + private File createZipFileHandle() throws IOException { + File zipFile = File.createTempFile("test", "zip", tmpOutputFolder.getRoot()); + assertTrue(zipFile.delete()); + return zipFile; + } + + // This is not generally safe as it does not handle symlinks, etc. However it is safe + // enough for these tests. 
+ private static void removeRecursive(Path path) throws IOException { + Iterable files = Files.fileTreeTraverser().postOrderTraversal(path.toFile()); + for (File f : files) { + java.nio.file.Files.delete(f.toPath()); + } + } + + /** Create file dir/fileName with contents fileContents. */ + private void createFileWithContents(File dir, String fileName, String fileContents) + throws IOException { + File txtFile = new File(dir, fileName); + Files.asCharSink(txtFile, StandardCharsets.UTF_8).write(fileContents); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterSetTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterSetTest.java new file mode 100644 index 000000000000..55ed714ad7e1 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterSetTest.java @@ -0,0 +1,225 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link CounterSet}. 
+ */ +@RunWith(JUnit4.class) +public class CounterSetTest { + private CounterSet set; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Before + public void setup() { + set = new CounterSet(); + } + + @Test + public void testAddWithDifferentNamesAddsAll() { + Counter c1 = Counter.longs("c1", SUM); + Counter c2 = Counter.ints("c2", MAX); + + boolean c1Add = set.add(c1); + boolean c2Add = set.add(c2); + + assertTrue(c1Add); + assertTrue(c2Add); + assertThat(set, containsInAnyOrder(c1, c2)); + } + + @Test + public void testAddWithAlreadyPresentNameReturnsFalse() { + Counter c1 = Counter.longs("c1", SUM); + Counter c1Dup = Counter.longs("c1", SUM); + + boolean c1Add = set.add(c1); + boolean c1DupAdd = set.add(c1Dup); + + assertTrue(c1Add); + assertFalse(c1DupAdd); + assertThat(set, containsInAnyOrder((Counter) c1)); + } + + @Test + public void testAddOrReuseWithAlreadyPresentReturnsPresent() { + Counter c1 = Counter.longs("c1", SUM); + Counter c1Dup = Counter.longs("c1", SUM); + + Counter c1AddResult = set.addOrReuseCounter(c1); + Counter c1DupAddResult = set.addOrReuseCounter(c1Dup); + + assertSame(c1, c1AddResult); + assertSame(c1AddResult, c1DupAddResult); + assertThat(set, containsInAnyOrder((Counter) c1)); + } + + @Test + public void testAddOrReuseWithNoCounterReturnsProvided() { + Counter c1 = Counter.longs("c1", SUM); + + Counter c1AddResult = set.addOrReuseCounter(c1); + + assertSame(c1, c1AddResult); + assertThat(set, containsInAnyOrder((Counter) c1)); + } + + @Test + public void testAddOrReuseWithIncompatibleTypesThrowsException() { + Counter c1 = Counter.longs("c1", SUM); + Counter c1Incompatible = Counter.ints("c1", MAX); + + set.addOrReuseCounter(c1); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Counter " + c1Incompatible + + " duplicates incompatible counter " + c1 + " in " + set); + + set.addOrReuseCounter(c1Incompatible); + } + + @Test + public void testMergeWithDifferentNamesAddsAll() { + Counter c1 = Counter.longs("c1", SUM); + Counter c2 = Counter.ints("c2", MAX); + + set.add(c1); + set.add(c2); + + CounterSet newSet = new CounterSet(); + newSet.merge(set); + + assertThat(newSet, containsInAnyOrder(c1, c2)); + } + + @SuppressWarnings("unchecked") + @Test + public void testMergeWithSameNamesMerges() { + Counter c1 = Counter.longs("c1", SUM); + Counter c2 = Counter.ints("c2", MAX); + + set.add(c1); + set.add(c2); + + c1.addValue(3L); + c2.addValue(22); + + CounterSet newSet = new CounterSet(); + Counter c1Prime = Counter.longs("c1", SUM); + Counter c2Prime = Counter.ints("c2", MAX); + + c1Prime.addValue(7L); + c2Prime.addValue(14); + + newSet.add(c1Prime); + newSet.add(c2Prime); + + newSet.merge(set); + + assertThat((Counter) newSet.getExistingCounter("c1"), equalTo(c1Prime)); + assertThat((Long) newSet.getExistingCounter("c1").getAggregate(), equalTo(10L)); + + assertThat((Counter) newSet.getExistingCounter("c2"), equalTo(c2Prime)); + assertThat((Integer) newSet.getExistingCounter("c2").getAggregate(), equalTo(22)); + } + + @SuppressWarnings("unchecked") + @Test + public void testMergeWithIncompatibleTypesThrowsException() { + Counter c1 = Counter.longs("c1", SUM); + + set.add(c1); + + CounterSet newSet = new CounterSet(); + Counter c1Prime = Counter.longs("c1", MAX); + + newSet.add(c1Prime); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("c1"); + thrown.expectMessage("incompatible counters with the same name"); + + newSet.merge(set); + } + + @Test + public void 
testAddCounterMutatorAddCounterAddsCounter() { + Counter c1 = Counter.longs("c1", SUM); + + Counter addC1Result = set.getAddCounterMutator().addCounter(c1); + + assertSame(c1, addC1Result); + assertThat(set, containsInAnyOrder((Counter) c1)); + } + + @Test + public void testAddCounterMutatorAddEqualCounterReusesCounter() { + Counter c1 = Counter.longs("c1", SUM); + Counter c1dup = Counter.longs("c1", SUM); + + Counter addC1Result = set.getAddCounterMutator().addCounter(c1); + Counter addC1DupResult = set.getAddCounterMutator().addCounter(c1dup); + + assertThat(set, containsInAnyOrder((Counter) c1)); + assertSame(c1, addC1Result); + assertSame(c1, addC1DupResult); + } + + @Test + public void testAddCounterMutatorIncompatibleTypesThrowsException() { + Counter c1 = Counter.longs("c1", SUM); + Counter c1Incompatible = Counter.longs("c1", MAX); + + set.getAddCounterMutator().addCounter(c1); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Counter " + c1Incompatible + + " duplicates incompatible counter " + c1 + " in " + set); + + set.getAddCounterMutator().addCounter(c1Incompatible); + } + + @Test + public void testAddCounterMutatorAddMultipleCounters() { + Counter c1 = Counter.longs("c1", SUM); + Counter c2 = Counter.longs("c2", MAX); + + set.getAddCounterMutator().addCounter(c1); + set.getAddCounterMutator().addCounter(c2); + + assertThat(set, containsInAnyOrder(c1, c2)); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTest.java new file mode 100644 index 000000000000..619f52344552 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTest.java @@ -0,0 +1,589 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/ + +package com.google.cloud.dataflow.sdk.util.common; + +import static com.google.cloud.dataflow.sdk.util.Values.asDouble; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.AND; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MAX; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MIN; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.OR; +import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.SUM; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.util.common.Counter.CounterMean; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for the {@link Counter} API. + */ +@RunWith(JUnit4.class) +public class CounterTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private static void flush(Counter c) { + switch (c.getKind()) { + case SUM: + case MAX: + case MIN: + case AND: + case OR: + c.getAndResetDelta(); + break; + case MEAN: + c.getAndResetMeanDelta(); + break; + default: + throw new IllegalArgumentException("Unknown counter kind " + c.getKind()); + } + } + + private static final double EPSILON = 0.00000000001; + + @Test + public void testCompatibility() { + // Equal counters are compatible, of all kinds. + assertTrue( + Counter.longs("c", SUM).isCompatibleWith(Counter.longs("c", SUM))); + assertTrue( + Counter.ints("c", SUM).isCompatibleWith(Counter.ints("c", SUM))); + assertTrue( + Counter.doubles("c", SUM).isCompatibleWith(Counter.doubles("c", SUM))); + assertTrue( + Counter.booleans("c", OR).isCompatibleWith( + Counter.booleans("c", OR))); + + // The name, kind, and type of the counter must match. + assertFalse( + Counter.longs("c", SUM).isCompatibleWith(Counter.longs("c2", SUM))); + assertFalse( + Counter.longs("c", SUM).isCompatibleWith(Counter.longs("c", MAX))); + assertFalse( + Counter.longs("c", SUM).isCompatibleWith(Counter.ints("c", SUM))); + + // The value of the counters are ignored. + assertTrue( + Counter.longs("c", SUM).resetToValue(666L).isCompatibleWith( + Counter.longs("c", SUM).resetToValue(42L))); + } + + + private void assertOK(long total, long delta, Counter c) { + assertEquals(total, c.getAggregate().longValue()); + assertEquals(delta, c.getAndResetDelta().longValue()); + } + + private void assertOK(double total, double delta, Counter c) { + assertEquals(total, asDouble(c.getAggregate()), EPSILON); + assertEquals(delta, asDouble(c.getAndResetDelta()), EPSILON); + } + + + // Tests for SUM. 
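+ // (The SUM tests below check both the running total, via getAggregate(), and the delta + // since it was last reset, via getAndResetDelta(); see assertOK above.)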
+ + @Test + public void testSumLong() { + Counter c = Counter.longs("sum-long", SUM); + long expectedTotal = 0; + long expectedDelta = 0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(0L); + expectedTotal += 55; + expectedDelta += 55; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = 174; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = 0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(15L).addValue(42L); + expectedTotal += 57; + expectedDelta += 57; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(17L).addValue(49L); + expectedTotal = expectedDelta = 166; + assertOK(expectedTotal, expectedDelta, c); + + Counter other = Counter.longs("sum-long", SUM); + other.addValue(12L); + expectedDelta = 12L; + expectedTotal += 12L; + c.merge(other); + assertOK(expectedTotal, expectedDelta, c); + } + + @Test + public void testSumDouble() { + Counter c = Counter.doubles("sum-double", SUM); + double expectedTotal = 0.0; + double expectedDelta = 0.0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(0.0); + expectedTotal += Math.E + Math.PI; + expectedDelta += Math.E + Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(2)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = expectedDelta = Math.sqrt(2) + 2 * Math.PI + 3 * Math.E; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = 0.0; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedTotal += 7 * Math.PI + 5 * Math.E; + expectedDelta += 7 * Math.PI + 5 * Math.E; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(17.0).addValue(49.0); + expectedTotal = expectedDelta = Math.sqrt(17.0) + 17.0 + 49.0; + assertOK(expectedTotal, expectedDelta, c); + + Counter other = Counter.doubles("sum-double", SUM); + other.addValue(12 * Math.PI); + expectedDelta = 12 * Math.PI; + expectedTotal += 12 * Math.PI; + c.merge(other); + assertOK(expectedTotal, expectedDelta, c); + } + + + // Tests for MAX. 
+ + @Test + public void testMaxLong() { + Counter c = Counter.longs("max-long", MAX); + long expectedTotal = Long.MIN_VALUE; + long expectedDelta = Long.MIN_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(Long.MIN_VALUE); + expectedTotal = expectedDelta = 42; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = 120; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Long.MIN_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(42L).addValue(15L); + expectedDelta = 42; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(171L).addValue(49L); + expectedTotal = expectedDelta = 171; + assertOK(expectedTotal, expectedDelta, c); + + Counter other = Counter.longs("max-long", MAX); + other.addValue(12L); + expectedDelta = 12L; + c.merge(other); + assertOK(expectedTotal, expectedDelta, c); + } + + @Test + public void testMaxDouble() { + Counter c = Counter.doubles("max-double", MAX); + double expectedTotal = Double.NEGATIVE_INFINITY; + double expectedDelta = Double.NEGATIVE_INFINITY; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(Double.NEGATIVE_INFINITY); + expectedTotal = expectedDelta = Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(12345)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = expectedDelta = Math.sqrt(12345); + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Double.NEGATIVE_INFINITY; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedDelta = 7 * Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(171.0).addValue(49.0); + expectedTotal = expectedDelta = 171.0; + assertOK(expectedTotal, expectedDelta, c); + + Counter other = Counter.doubles("max-double", MAX); + other.addValue(12 * Math.PI); + expectedDelta = 12 * Math.PI; + c.merge(other); + assertOK(expectedTotal, expectedDelta, c); + } + + + // Tests for MIN. 
+ + @Test + public void testMinLong() { + Counter c = Counter.longs("min-long", MIN); + long expectedTotal = Long.MAX_VALUE; + long expectedDelta = Long.MAX_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(13L).addValue(42L).addValue(Long.MAX_VALUE); + expectedTotal = expectedDelta = 13; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(120L).addValue(17L).addValue(37L); + expectedTotal = expectedDelta = 17; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Long.MAX_VALUE; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(42L).addValue(18L); + expectedDelta = 18; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(100L).addValue(171L).addValue(49L); + expectedTotal = expectedDelta = 49; + assertOK(expectedTotal, expectedDelta, c); + + Counter other = Counter.longs("min-long", MIN); + other.addValue(42L); + expectedTotal = expectedDelta = 42L; + c.merge(other); + assertOK(expectedTotal, expectedDelta, c); + } + + @Test + public void testMinDouble() { + Counter c = Counter.doubles("min-double", MIN); + double expectedTotal = Double.POSITIVE_INFINITY; + double expectedDelta = Double.POSITIVE_INFINITY; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(Double.POSITIVE_INFINITY); + expectedTotal = expectedDelta = Math.E; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(12345)).addValue(2 * Math.PI).addValue(3 * Math.E); + expectedTotal = expectedDelta = 2 * Math.PI; + assertOK(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = Double.POSITIVE_INFINITY; + assertOK(expectedTotal, expectedDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expectedDelta = 5 * Math.E; + assertOK(expectedTotal, expectedDelta, c); + + c.resetToValue(Math.sqrt(17)).addValue(171.0).addValue(0.0); + expectedTotal = expectedDelta = 0.0; + assertOK(expectedTotal, expectedDelta, c); + + Counter other = Counter.doubles("min-double", MIN); + other.addValue(42 * Math.E); + expectedDelta = 42 * Math.E; + c.merge(other); + assertOK(expectedTotal, expectedDelta, c); + } + + + // Tests for MEAN. 
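+ // (MEAN counters track a sum together with an element count; the assertMean helpers below + // check both through getMean() and getAndResetMeanDelta().)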
+ + private void assertMean(long s, long sd, long c, long cd, Counter cn) { + CounterMean mean = cn.getMean(); + CounterMean deltaMean = cn.getAndResetMeanDelta(); + assertEquals(s, mean.getAggregate().longValue()); + assertEquals(sd, deltaMean.getAggregate().longValue()); + assertEquals(c, mean.getCount()); + assertEquals(cd, deltaMean.getCount()); + } + + private void assertMean(double s, double sd, long c, long cd, + Counter cn) { + CounterMean mean = cn.getMean(); + CounterMean deltaMean = cn.getAndResetMeanDelta(); + assertEquals(s, mean.getAggregate().doubleValue(), EPSILON); + assertEquals(sd, deltaMean.getAggregate().doubleValue(), EPSILON); + assertEquals(c, mean.getCount()); + assertEquals(cd, deltaMean.getCount()); + } + + @Test + public void testMeanLong() { + Counter c = Counter.longs("mean-long", MEAN); + long expTotal = 0; + long expDelta = 0; + long expCountTotal = 0; + long expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(13L).addValue(42L).addValue(0L); + expTotal += 55; + expDelta += 55; + expCountTotal += 3; + expCountDelta += 3; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetMeanToValue(1L, 120L).addValue(17L).addValue(37L); + expTotal = expDelta = 174; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + flush(c); + expDelta = 0; + expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(15L).addValue(42L); + expTotal += 57; + expDelta += 57; + expCountTotal += 2; + expCountDelta += 2; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetMeanToValue(3L, 100L).addValue(17L).addValue(49L); + expTotal = expDelta = 166; + expCountTotal = expCountDelta = 5; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + Counter other = Counter.longs("mean-long", MEAN); + other.addValue(12L).addValue(44L).addValue(-5L); + expTotal += 12L + 44L - 5L; + expDelta += 12L + 44L - 5L; + expCountTotal += 3; + expCountDelta += 3; + c.merge(other); + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + } + + @Test + public void testMeanDouble() { + Counter c = Counter.doubles("mean-double", MEAN); + double expTotal = 0.0; + double expDelta = 0.0; + long expCountTotal = 0; + long expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(Math.E).addValue(Math.PI).addValue(0.0); + expTotal += Math.E + Math.PI; + expDelta += Math.E + Math.PI; + expCountTotal += 3; + expCountDelta += 3; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetMeanToValue(1L, Math.sqrt(2)).addValue(2 * Math.PI) + .addValue(3 * Math.E); + expTotal = expDelta = Math.sqrt(2) + 2 * Math.PI + 3 * Math.E; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + flush(c); + expDelta = 0.0; + expCountDelta = 0; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.addValue(7 * Math.PI).addValue(5 * Math.E); + expTotal += 7 * Math.PI + 5 * Math.E; + expDelta += 7 * Math.PI + 5 * Math.E; + expCountTotal += 2; + expCountDelta += 2; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + c.resetMeanToValue(3L, Math.sqrt(17)).addValue(17.0).addValue(49.0); + expTotal = expDelta = Math.sqrt(17.0) + 17.0 + 49.0; + expCountTotal = expCountDelta = 5; + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + + Counter other = Counter.doubles("mean-double", MEAN); + other.addValue(3 * Math.PI).addValue(12 * Math.E); 
+ expTotal += 3 * Math.PI + 12 * Math.E; + expDelta += 3 * Math.PI + 12 * Math.E; + expCountTotal += 2; + expCountDelta += 2; + c.merge(other); + assertMean(expTotal, expDelta, expCountTotal, expCountDelta, c); + } + + + // Test for AND and OR. + private void assertBool(boolean total, boolean delta, Counter c) { + assertEquals(total, c.getAggregate().booleanValue()); + assertEquals(delta, c.getAndResetDelta().booleanValue()); + } + + @Test + public void testBoolAnd() { + Counter c = Counter.booleans("bool-and", AND); + boolean expectedTotal = true; + boolean expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + expectedTotal = expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.resetToValue(true).addValue(true); + expectedTotal = expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + expectedTotal = expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + } + + @Test + public void testBoolOr() { + Counter c = Counter.booleans("bool-or", OR); + boolean expectedTotal = false; + boolean expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(false); + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + expectedTotal = expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + c.resetToValue(false).addValue(false); + expectedTotal = expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + expectedTotal = expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + + flush(c); + expectedDelta = false; + assertBool(expectedTotal, expectedDelta, c); + + c.addValue(true); + expectedDelta = true; + assertBool(expectedTotal, expectedDelta, c); + } + + // Incompatibility tests. 
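+ // (Each case below expects an IllegalArgumentException: boolean counters only support AND + // and OR, numeric counters do not support AND or OR, and merge() rejects counters whose + // name, kind, or value type differ.)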
+ + @Test(expected = IllegalArgumentException.class) + public void testSumBool() { + Counter.booleans("counter", SUM); + } + + @Test(expected = IllegalArgumentException.class) + public void testMinBool() { + Counter.booleans("counter", MIN); + } + + @Test(expected = IllegalArgumentException.class) + public void testMaxBool() { + Counter.booleans("counter", MAX); + } + + @Test(expected = IllegalArgumentException.class) + public void testMeanBool() { + Counter.booleans("counter", MEAN); + } + + @Test(expected = IllegalArgumentException.class) + public void testAndLong() { + Counter.longs("counter", AND); + } + + @Test(expected = IllegalArgumentException.class) + public void testAndDouble() { + Counter.doubles("counter", AND); + } + + @Test(expected = IllegalArgumentException.class) + public void testOrLong() { + Counter.longs("counter", OR); + } + + @Test(expected = IllegalArgumentException.class) + public void testOrDouble() { + Counter.doubles("counter", OR); + } + + @Test + public void testMergeIncompatibleCounters() { + Counter longSums = Counter.longs("longsums", SUM); + Counter longMean = Counter.longs("longmean", MEAN); + Counter longMin = Counter.longs("longmin", MIN); + + Counter otherLongSums = Counter.longs("othersums", SUM); + Counter otherLongMean = Counter.longs("otherlongmean", MEAN); + + Counter doubleSums = Counter.doubles("doublesums", SUM); + Counter doubleMean = Counter.doubles("doublemean", MEAN); + + Counter boolAnd = Counter.booleans("and", AND); + Counter boolOr = Counter.booleans("or", OR); + + List> longCounters = + Arrays.asList(longSums, longMean, longMin, otherLongSums, otherLongMean); + for (Counter left : longCounters) { + for (Counter right : longCounters) { + if (left != right) { + assertIncompatibleMerge(left, right); + } + } + } + + assertIncompatibleMerge(doubleSums, doubleMean); + assertIncompatibleMerge(boolAnd, boolOr); + } + + private void assertIncompatibleMerge(Counter left, Counter right) { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Counters"); + thrown.expectMessage("are incompatible"); + left.merge(right); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTestUtils.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTestUtils.java new file mode 100644 index 000000000000..faaa34110d8c --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/CounterTestUtils.java @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ ******************************************************************************/
+
+package com.google.cloud.dataflow.sdk.util.common;
+
+import static com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind.MEAN;
+
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.util.common.Counter.CounterMean;
+
+import org.junit.Assert;
+
+import java.io.ByteArrayOutputStream;
+
+/**
+ * Utilities for testing {@link Counter}s.
+ */
+public class CounterTestUtils {
+  /**
+   * A utility method that passes the given (unencoded) elements through
+   * coder's registerByteSizeObserver() and encode() methods, and confirms
+   * they are mutually consistent. This is useful for testing coder
+   * implementations.
+   */
+  public static <T> void testByteCount(Coder<T> coder, Coder.Context context, T[] elements)
+      throws Exception {
+    Counter<Long> meanByteCount = Counter.longs("meanByteCount", MEAN);
+    ElementByteSizeObserver observer = new ElementByteSizeObserver(meanByteCount);
+
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    for (T elem : elements) {
+      coder.registerByteSizeObserver(elem, observer, context);
+      coder.encode(elem, os, context);
+      observer.advance();
+    }
+    long expectedLength = os.toByteArray().length;
+
+    CounterMean<Long> mean = meanByteCount.getMean();
+
+    Assert.assertEquals(expectedLength, mean.getAggregate().longValue());
+    Assert.assertEquals(elements.length, mean.getCount());
+  }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/ReflectHelpersTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/ReflectHelpersTest.java
new file mode 100644
index 000000000000..f7d678ba8e18
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/common/ReflectHelpersTest.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.util.common;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Tests for {@link ReflectHelpers}.
+ */ +@RunWith(JUnit4.class) +public class ReflectHelpersTest { + + @Test + public void testClassName() { + assertEquals(getClass().getName(), ReflectHelpers.CLASS_NAME.apply(getClass())); + } + + @Test + public void testClassSimpleName() { + assertEquals(getClass().getSimpleName(), + ReflectHelpers.CLASS_SIMPLE_NAME.apply(getClass())); + } + + @Test + public void testMethodFormatter() throws Exception { + assertEquals("testMethodFormatter()", + ReflectHelpers.METHOD_FORMATTER.apply(getClass().getMethod("testMethodFormatter"))); + + assertEquals("oneArg(int)", + ReflectHelpers.METHOD_FORMATTER.apply(getClass().getDeclaredMethod("oneArg", int.class))); + assertEquals("twoArg(String, List)", + ReflectHelpers.METHOD_FORMATTER.apply( + getClass().getDeclaredMethod("twoArg", String.class, List.class))); + } + + @Test + public void testClassMethodFormatter() throws Exception { + assertEquals( + getClass().getName() + "#testMethodFormatter()", + ReflectHelpers.CLASS_AND_METHOD_FORMATTER + .apply(getClass().getMethod("testMethodFormatter"))); + + assertEquals( + getClass().getName() + "#oneArg(int)", + ReflectHelpers.CLASS_AND_METHOD_FORMATTER + .apply(getClass().getDeclaredMethod("oneArg", int.class))); + assertEquals( + getClass().getName() + "#twoArg(String, List)", + ReflectHelpers.CLASS_AND_METHOD_FORMATTER.apply( + getClass().getDeclaredMethod("twoArg", String.class, List.class))); + } + + @SuppressWarnings("unused") + void oneArg(int n) {} + @SuppressWarnings("unused") + void twoArg(String foo, List bar) {} + + @Test + public void testTypeFormatterOnClasses() throws Exception { + assertEquals("Integer", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(Integer.class)); + assertEquals("int", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(int.class)); + assertEquals("Map", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(Map.class)); + assertEquals(getClass().getSimpleName(), + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(getClass())); + } + + @Test + public void testTypeFormatterOnArrays() throws Exception { + assertEquals("Integer[]", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(Integer[].class)); + assertEquals("int[]", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(int[].class)); + } + + @Test + public void testTypeFormatterWithGenerics() throws Exception { + assertEquals("Map", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply( + new TypeDescriptor>() {}.getType())); + assertEquals("Map", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply( + new TypeDescriptor>() {}.getType())); + assertEquals("Map", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply( + new TypeDescriptor>() {}.getType())); + } + + @Test + public void testTypeFormatterWithWildcards() throws Exception { + assertEquals("Map", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply( + new TypeDescriptor>() {}.getType())); + } + + @Test + public void testTypeFormatterWithMultipleWildcards() throws Exception { + assertEquals("Map", + ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply( + new TypeDescriptor>() {}.getType())); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPathTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPathTest.java new file mode 100644 index 000000000000..3aefa6d0f008 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/gcsfs/GcsPathTest.java @@ -0,0 +1,333 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.util.gcsfs; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * Tests of GcsPath. + */ +@RunWith(JUnit4.class) +public class GcsPathTest { + + /** + * Test case, which tests parsing and building of GcsPaths. + */ + static final class TestCase { + + final String uri; + final String expectedBucket; + final String expectedObject; + final String[] namedComponents; + + TestCase(String uri, String... namedComponents) { + this.uri = uri; + this.expectedBucket = namedComponents[0]; + this.namedComponents = namedComponents; + this.expectedObject = uri.substring(expectedBucket.length() + 6); + } + } + + // Each test case is an expected URL, then the components used to build it. + // Empty components result in a double slash. + static final List PATH_TEST_CASES = Arrays.asList( + new TestCase("gs://bucket/then/object", "bucket", "then", "object"), + new TestCase("gs://bucket//then/object", "bucket", "", "then", "object"), + new TestCase("gs://bucket/then//object", "bucket", "then", "", "object"), + new TestCase("gs://bucket/then///object", "bucket", "then", "", "", "object"), + new TestCase("gs://bucket/then/object/", "bucket", "then", "object/"), + new TestCase("gs://bucket/then/object/", "bucket", "then/", "object/"), + new TestCase("gs://bucket/then/object//", "bucket", "then", "object", ""), + new TestCase("gs://bucket/then/object//", "bucket", "then", "object/", ""), + new TestCase("gs://bucket/", "bucket") + ); + + @Test + public void testGcsPathParsing() throws Exception { + for (TestCase testCase : PATH_TEST_CASES) { + String uriString = testCase.uri; + + GcsPath path = GcsPath.fromUri(URI.create(uriString)); + // Deconstruction - check bucket, object, and components. + assertEquals(testCase.expectedBucket, path.getBucket()); + assertEquals(testCase.expectedObject, path.getObject()); + assertEquals(testCase.uri, + testCase.namedComponents.length, path.getNameCount()); + + // Construction - check that the path can be built from components. 
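+      // Starting from an empty path, the first resolved component becomes the
+      // bucket and the remaining components the object, so resolving them in
+      // order should round-trip back to the original URI.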
+ GcsPath built = GcsPath.fromComponents(null, null); + for (String component : testCase.namedComponents) { + built = built.resolve(component); + } + assertEquals(testCase.uri, built.toString()); + } + } + + @Test + public void testParentRelationship() throws Exception { + GcsPath path = GcsPath.fromComponents("bucket", "then/object"); + assertEquals("bucket", path.getBucket()); + assertEquals("then/object", path.getObject()); + assertEquals(3, path.getNameCount()); + assertTrue(path.endsWith("object")); + assertTrue(path.startsWith("bucket/then")); + + GcsPath parent = path.getParent(); // gs://bucket/then/ + assertEquals("bucket", parent.getBucket()); + assertEquals("then/", parent.getObject()); + assertEquals(2, parent.getNameCount()); + assertThat(path, Matchers.not(Matchers.equalTo(parent))); + assertTrue(path.startsWith(parent)); + assertFalse(parent.startsWith(path)); + assertTrue(parent.endsWith("then/")); + assertTrue(parent.startsWith("bucket/then")); + assertTrue(parent.isAbsolute()); + + GcsPath root = path.getRoot(); + assertEquals(0, root.getNameCount()); + assertEquals("gs://", root.toString()); + assertEquals("", root.getBucket()); + assertEquals("", root.getObject()); + assertTrue(root.isAbsolute()); + assertThat(root, Matchers.equalTo(parent.getRoot())); + + GcsPath grandParent = parent.getParent(); // gs://bucket/ + assertEquals(1, grandParent.getNameCount()); + assertEquals("gs://bucket/", grandParent.toString()); + assertTrue(grandParent.isAbsolute()); + assertThat(root, Matchers.equalTo(grandParent.getParent())); + assertThat(root.getParent(), Matchers.nullValue()); + + assertTrue(path.startsWith(path.getRoot())); + assertTrue(parent.startsWith(path.getRoot())); + } + + @Test + public void testRelativeParent() throws Exception { + GcsPath path = GcsPath.fromComponents(null, "a/b"); + GcsPath parent = path.getParent(); + assertEquals("a/", parent.toString()); + + GcsPath grandParent = parent.getParent(); + assertNull(grandParent); + } + + @Test + public void testUriSupport() throws Exception { + URI uri = URI.create("gs://bucket/some/path"); + + GcsPath path = GcsPath.fromUri(uri); + assertEquals("bucket", path.getBucket()); + assertEquals("some/path", path.getObject()); + + URI reconstructed = path.toUri(); + assertEquals(uri, reconstructed); + + path = GcsPath.fromUri("gs://bucket"); + assertEquals("gs://bucket/", path.toString()); + } + + @Test + public void testBucketParsing() throws Exception { + GcsPath path = GcsPath.fromUri("gs://bucket"); + GcsPath path2 = GcsPath.fromUri("gs://bucket/"); + + assertEquals(path, path2); + assertEquals(path.toString(), path2.toString()); + assertEquals(path.toUri(), path2.toUri()); + } + + @Test + public void testGcsPathToString() throws Exception { + String filename = "gs://some_bucket/some/file.txt"; + GcsPath path = GcsPath.fromUri(filename); + assertEquals(filename, path.toString()); + } + + @Test + public void testEquals() { + GcsPath a = GcsPath.fromComponents(null, "a/b/c"); + GcsPath a2 = GcsPath.fromComponents(null, "a/b/c"); + assertFalse(a.isAbsolute()); + assertFalse(a2.isAbsolute()); + + GcsPath b = GcsPath.fromComponents("bucket", "a/b/c"); + GcsPath b2 = GcsPath.fromComponents("bucket", "a/b/c"); + assertTrue(b.isAbsolute()); + assertTrue(b2.isAbsolute()); + + assertEquals(a, a); + assertThat(a, Matchers.not(Matchers.equalTo(b))); + assertThat(b, Matchers.not(Matchers.equalTo(a))); + + assertEquals(a, a2); + assertEquals(a2, a); + assertEquals(b, b2); + assertEquals(b2, b); + + assertThat(a, 
Matchers.not(Matchers.equalTo(Paths.get("/tmp/foo")))); + assertTrue(a != null); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidGcsPath() { + @SuppressWarnings("unused") + GcsPath filename = + GcsPath.fromUri("file://invalid/gcs/path"); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidBucket() { + GcsPath.fromComponents("invalid/", ""); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidObject_newline() { + GcsPath.fromComponents(null, "a\nb"); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidObject_cr() { + GcsPath.fromComponents(null, "a\rb"); + } + + @Test + public void testResolveUri() { + GcsPath path = GcsPath.fromComponents("bucket", "a/b/c"); + GcsPath d = path.resolve("gs://bucket2/d"); + assertEquals("gs://bucket2/d", d.toString()); + } + + @Test + public void testResolveOther() { + GcsPath a = GcsPath.fromComponents("bucket", "a"); + GcsPath b = a.resolve(Paths.get("b")); + assertEquals("a/b", b.getObject()); + } + + @Test + public void testCompareTo() { + GcsPath a = GcsPath.fromComponents("bucket", "a"); + GcsPath b = GcsPath.fromComponents("bucket", "b"); + GcsPath b2 = GcsPath.fromComponents("bucket2", "b"); + GcsPath brel = GcsPath.fromComponents(null, "b"); + GcsPath a2 = GcsPath.fromComponents("bucket", "a"); + GcsPath arel = GcsPath.fromComponents(null, "a"); + + assertThat(a.compareTo(b), Matchers.lessThan(0)); + assertThat(b.compareTo(a), Matchers.greaterThan(0)); + assertThat(a.compareTo(a2), Matchers.equalTo(0)); + + assertThat(a.hashCode(), Matchers.equalTo(a2.hashCode())); + assertThat(a.hashCode(), Matchers.not(Matchers.equalTo(b.hashCode()))); + assertThat(b.hashCode(), Matchers.not(Matchers.equalTo(brel.hashCode()))); + + assertThat(brel.compareTo(b), Matchers.lessThan(0)); + assertThat(b.compareTo(brel), Matchers.greaterThan(0)); + assertThat(arel.compareTo(brel), Matchers.lessThan(0)); + assertThat(brel.compareTo(arel), Matchers.greaterThan(0)); + + assertThat(b.compareTo(b2), Matchers.lessThan(0)); + assertThat(b2.compareTo(b), Matchers.greaterThan(0)); + } + + @Test + public void testCompareTo_ordering() { + GcsPath ab = GcsPath.fromComponents("bucket", "a/b"); + GcsPath abc = GcsPath.fromComponents("bucket", "a/b/c"); + GcsPath a1b = GcsPath.fromComponents("bucket", "a-1/b"); + + assertThat(ab.compareTo(a1b), Matchers.lessThan(0)); + assertThat(a1b.compareTo(ab), Matchers.greaterThan(0)); + + assertThat(ab.compareTo(abc), Matchers.lessThan(0)); + assertThat(abc.compareTo(ab), Matchers.greaterThan(0)); + } + + @Test + public void testCompareTo_buckets() { + GcsPath a = GcsPath.fromComponents(null, "a/b/c"); + GcsPath b = GcsPath.fromComponents("bucket", "a/b/c"); + + assertThat(a.compareTo(b), Matchers.lessThan(0)); + assertThat(b.compareTo(a), Matchers.greaterThan(0)); + } + + @Test + public void testIterator() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c"); + Iterator it = a.iterator(); + + assertTrue(it.hasNext()); + assertEquals("gs://bucket/", it.next().toString()); + assertTrue(it.hasNext()); + assertEquals("a", it.next().toString()); + assertTrue(it.hasNext()); + assertEquals("b", it.next().toString()); + assertTrue(it.hasNext()); + assertEquals("c", it.next().toString()); + assertFalse(it.hasNext()); + } + + @Test + public void testSubpath() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c/d"); + assertThat(a.subpath(0, 1).toString(), Matchers.equalTo("gs://bucket/")); + assertThat(a.subpath(0, 
2).toString(), Matchers.equalTo("gs://bucket/a")); + assertThat(a.subpath(0, 3).toString(), Matchers.equalTo("gs://bucket/a/b")); + assertThat(a.subpath(0, 4).toString(), Matchers.equalTo("gs://bucket/a/b/c")); + assertThat(a.subpath(1, 2).toString(), Matchers.equalTo("a")); + assertThat(a.subpath(2, 3).toString(), Matchers.equalTo("b")); + assertThat(a.subpath(2, 4).toString(), Matchers.equalTo("b/c")); + assertThat(a.subpath(2, 5).toString(), Matchers.equalTo("b/c/d")); + } + + @Test + public void testGetName() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c/d"); + assertEquals(5, a.getNameCount()); + assertThat(a.getName(0).toString(), Matchers.equalTo("gs://bucket/")); + assertThat(a.getName(1).toString(), Matchers.equalTo("a")); + assertThat(a.getName(2).toString(), Matchers.equalTo("b")); + assertThat(a.getName(3).toString(), Matchers.equalTo("c")); + assertThat(a.getName(4).toString(), Matchers.equalTo("d")); + } + + @Test(expected = IllegalArgumentException.class) + public void testSubPathError() { + GcsPath a = GcsPath.fromComponents("bucket", "a/b/c/d"); + a.subpath(1, 1); // throws IllegalArgumentException + Assert.fail(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternalsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternalsTest.java new file mode 100644 index 000000000000..5bb0f597728b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternalsTest.java @@ -0,0 +1,553 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util.state; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.theInstance; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; + +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link CopyOnAccessInMemoryStateInternals}. + */ +@RunWith(JUnit4.class) +public class CopyOnAccessInMemoryStateInternalsTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + private String key = "foo"; + @Test + public void testGetWithEmpty() { + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = internals.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + assertThat(stringBag.read(), containsInAnyOrder("baz", "bar")); + + BagState reReadStringBag = internals.state(namespace, bagTag); + assertThat(reReadStringBag.read(), containsInAnyOrder("baz", "bar")); + } + + @Test + public void testGetWithAbsentInUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = internals.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + assertThat(stringBag.read(), containsInAnyOrder("baz", "bar")); + + BagState reReadVoidBag = internals.state(namespace, bagTag); + assertThat(reReadVoidBag.read(), containsInAnyOrder("baz", "bar")); + + BagState underlyingState = underlying.state(namespace, bagTag); + assertThat(underlyingState.read(), emptyIterable()); + } + + /** + * Tests that retrieving state with an underlying StateInternals with an existing value returns + * a value that initially has equal value to the provided state but can be modified without + * modifying the existing state. 
+ */ + @Test + public void testGetWithPresentInUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> valueTag = StateTags.value("foo", StringUtf8Coder.of()); + ValueState underlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.read(), nullValue(String.class)); + + underlyingValue.write("bar"); + assertThat(underlyingValue.read(), equalTo("bar")); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + ValueState copyOnAccessState = internals.state(namespace, valueTag); + assertThat(copyOnAccessState.read(), equalTo("bar")); + + copyOnAccessState.write("baz"); + assertThat(copyOnAccessState.read(), equalTo("baz")); + assertThat(underlyingValue.read(), equalTo("bar")); + + ValueState reReadUnderlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read())); + } + + @Test + public void testBagStateWithUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> valueTag = StateTags.bag("foo", VarIntCoder.of()); + BagState underlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.read(), emptyIterable()); + + underlyingValue.add(1); + assertThat(underlyingValue.read(), containsInAnyOrder(1)); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + BagState copyOnAccessState = internals.state(namespace, valueTag); + assertThat(copyOnAccessState.read(), containsInAnyOrder(1)); + + copyOnAccessState.add(4); + assertThat(copyOnAccessState.read(), containsInAnyOrder(4, 1)); + assertThat(underlyingValue.read(), containsInAnyOrder(1)); + + BagState reReadUnderlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read())); + } + + @Test + public void testAccumulatorCombiningStateWithUnderlying() throws CannotProvideCoderException { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CombineFn sumLongFn = new Sum.SumLongFn(); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + CoderRegistry reg = TestPipeline.create().getCoderRegistry(); + StateTag> stateTag = + StateTags.combiningValue("summer", + sumLongFn.getAccumulatorCoder(reg, reg.getDefaultCoder(Long.class)), sumLongFn); + CombiningState underlyingValue = underlying.state(namespace, stateTag); + assertThat(underlyingValue.read(), equalTo(0L)); + + underlyingValue.add(1L); + assertThat(underlyingValue.read(), equalTo(1L)); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + CombiningState copyOnAccessState = internals.state(namespace, stateTag); + assertThat(copyOnAccessState.read(), equalTo(1L)); + + copyOnAccessState.add(4L); + assertThat(copyOnAccessState.read(), equalTo(5L)); + assertThat(underlyingValue.read(), equalTo(1L)); + + CombiningState reReadUnderlyingValue = underlying.state(namespace, stateTag); + assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read())); + } + + @Test + public void testKeyedAccumulatorCombiningStateWithUnderlying() throws Exception { + 
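+    // Mirrors testAccumulatorCombiningStateWithUnderlying above, but exercises
+    // the keyed variant: the combine fn comes from Sum.SumLongFn().asKeyedFn()
+    // and the state tag from StateTags.keyedCombiningValue.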
CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + KeyedCombineFn sumLongFn = new Sum.SumLongFn().asKeyedFn(); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + CoderRegistry reg = TestPipeline.create().getCoderRegistry(); + StateTag> stateTag = + StateTags.keyedCombiningValue( + "summer", + sumLongFn.getAccumulatorCoder( + reg, StringUtf8Coder.of(), reg.getDefaultCoder(Long.class)), + sumLongFn); + CombiningState underlyingValue = underlying.state(namespace, stateTag); + assertThat(underlyingValue.read(), equalTo(0L)); + + underlyingValue.add(1L); + assertThat(underlyingValue.read(), equalTo(1L)); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + CombiningState copyOnAccessState = internals.state(namespace, stateTag); + assertThat(copyOnAccessState.read(), equalTo(1L)); + + copyOnAccessState.add(4L); + assertThat(copyOnAccessState.read(), equalTo(5L)); + assertThat(underlyingValue.read(), equalTo(1L)); + + CombiningState reReadUnderlyingValue = underlying.state(namespace, stateTag); + assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read())); + } + + @Test + public void testWatermarkHoldStateWithUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + @SuppressWarnings("unchecked") + OutputTimeFn outputTimeFn = (OutputTimeFn) + TestPipeline.create().apply(Create.of("foo")).getWindowingStrategy().getOutputTimeFn(); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> stateTag = + StateTags.watermarkStateInternal("wmstate", outputTimeFn); + WatermarkHoldState underlyingValue = underlying.state(namespace, stateTag); + assertThat(underlyingValue.read(), nullValue()); + + underlyingValue.add(new Instant(250L)); + assertThat(underlyingValue.read(), equalTo(new Instant(250L))); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + WatermarkHoldState copyOnAccessState = internals.state(namespace, stateTag); + assertThat(copyOnAccessState.read(), equalTo(new Instant(250L))); + + copyOnAccessState.add(new Instant(100L)); + assertThat(copyOnAccessState.read(), equalTo(new Instant(100L))); + assertThat(underlyingValue.read(), equalTo(new Instant(250L))); + + copyOnAccessState.add(new Instant(500L)); + assertThat(copyOnAccessState.read(), equalTo(new Instant(100L))); + + WatermarkHoldState reReadUnderlyingValue = + underlying.state(namespace, stateTag); + assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read())); + } + + @Test + public void testCommitWithoutUnderlying() { + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = internals.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + assertThat(stringBag.read(), containsInAnyOrder("baz", "bar")); + + internals.commit(); + + BagState reReadStringBag = internals.state(namespace, bagTag); + assertThat(reReadStringBag.read(), containsInAnyOrder("baz", "bar")); + assertThat(internals.isEmpty(), is(false)); + } + + @Test + public void testCommitWithUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + 
CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = underlying.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + + internals.commit(); + BagState reReadStringBag = internals.state(namespace, bagTag); + assertThat(reReadStringBag.read(), containsInAnyOrder("baz", "bar")); + + reReadStringBag.add("spam"); + + BagState underlyingState = underlying.state(namespace, bagTag); + assertThat(underlyingState.read(), containsInAnyOrder("spam", "bar", "baz")); + assertThat(underlyingState, is(theInstance(stringBag))); + assertThat(internals.isEmpty(), is(false)); + } + + @Test + public void testCommitWithClearedInUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CopyOnAccessInMemoryStateInternals secondUnderlying = + spy(CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying)); + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, secondUnderlying); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = underlying.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + stringBag.clear(); + // We should not read through the cleared bag + secondUnderlying.commit(); + + // Should not be visible + stringBag.add("foo"); + + internals.commit(); + BagState internalsStringBag = internals.state(namespace, bagTag); + assertThat(internalsStringBag.read(), emptyIterable()); + verify(secondUnderlying, never()).state(namespace, bagTag); + assertThat(internals.isEmpty(), is(false)); + } + + @Test + public void testCommitWithOverwrittenUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = underlying.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + + BagState internalsState = internals.state(namespace, bagTag); + internalsState.add("eggs"); + internalsState.add("ham"); + internalsState.add("0x00ff00"); + internalsState.add("&"); + + internals.commit(); + + BagState reReadInternalState = internals.state(namespace, bagTag); + assertThat( + reReadInternalState.read(), + containsInAnyOrder("bar", "baz", "0x00ff00", "eggs", "&", "ham")); + BagState reReadUnderlyingState = underlying.state(namespace, bagTag); + assertThat(reReadUnderlyingState.read(), containsInAnyOrder("bar", "baz")); + } + + @Test + public void testCommitWithAddedUnderlying() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + + internals.commit(); + + StateNamespace namespace = new 
StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = underlying.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + + BagState internalState = internals.state(namespace, bagTag); + assertThat(internalState.read(), emptyIterable()); + + BagState reReadUnderlyingState = underlying.state(namespace, bagTag); + assertThat(reReadUnderlyingState.read(), containsInAnyOrder("bar", "baz")); + } + + @Test + public void testCommitWithEmptyTableIsEmpty() { + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + internals.commit(); + + assertThat(internals.isEmpty(), is(true)); + } + + @Test + public void testCommitWithOnlyClearedValuesIsEmpty() { + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = internals.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("foo"); + stringBag.clear(); + + internals.commit(); + + assertThat(internals.isEmpty(), is(true)); + } + + @Test + public void testCommitWithEmptyNewAndFullUnderlyingIsNotEmpty() { + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag> bagTag = StateTags.bag("foo", StringUtf8Coder.of()); + BagState stringBag = underlying.state(namespace, bagTag); + assertThat(stringBag.read(), emptyIterable()); + + stringBag.add("bar"); + stringBag.add("baz"); + + internals.commit(); + assertThat(internals.isEmpty(), is(false)); + } + + @Test + public void testGetEarliestWatermarkHoldAfterCommit() { + BoundedWindow first = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(2048L); + } + }; + BoundedWindow second = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(689743L); + } + }; + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying("foo", null); + + StateTag> firstHoldAddress = + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + WatermarkHoldState firstHold = + internals.state(StateNamespaces.window(null, first), firstHoldAddress); + firstHold.add(new Instant(22L)); + + StateTag> secondHoldAddress = + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + WatermarkHoldState secondHold = + internals.state(StateNamespaces.window(null, second), secondHoldAddress); + secondHold.add(new Instant(2L)); + + internals.commit(); + assertThat(internals.getEarliestWatermarkHold(), equalTo(new Instant(2L))); + } + + @Test + public void testGetEarliestWatermarkHoldWithEarliestInUnderlyingTable() { + BoundedWindow first = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(2048L); + } + }; + BoundedWindow second = new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(689743L); + } + }; + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying("foo", null); + StateTag> 
firstHoldAddress = + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + WatermarkHoldState firstHold = + underlying.state(StateNamespaces.window(null, first), firstHoldAddress); + firstHold.add(new Instant(22L)); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying("foo", underlying.commit()); + + StateTag> secondHoldAddress = + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + WatermarkHoldState secondHold = + internals.state(StateNamespaces.window(null, second), secondHoldAddress); + secondHold.add(new Instant(244L)); + + internals.commit(); + assertThat(internals.getEarliestWatermarkHold(), equalTo(new Instant(22L))); + } + + @Test + public void testGetEarliestWatermarkHoldWithEarliestInNewTable() { + BoundedWindow first = + new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(2048L); + } + }; + BoundedWindow second = + new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(689743L); + } + }; + CopyOnAccessInMemoryStateInternals underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying("foo", null); + StateTag> firstHoldAddress = + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + WatermarkHoldState firstHold = + underlying.state(StateNamespaces.window(null, first), firstHoldAddress); + firstHold.add(new Instant(224L)); + + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying("foo", underlying.commit()); + + StateTag> secondHoldAddress = + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + WatermarkHoldState secondHold = + internals.state(StateNamespaces.window(null, second), secondHoldAddress); + secondHold.add(new Instant(24L)); + + internals.commit(); + assertThat(internals.getEarliestWatermarkHold(), equalTo(new Instant(24L))); + } + + @Test + public void testGetEarliestHoldBeforeCommit() { + CopyOnAccessInMemoryStateInternals internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + internals + .state( + StateNamespaces.global(), + StateTags.watermarkStateInternal("foo", OutputTimeFns.outputAtEarliestInputTimestamp())) + .add(new Instant(1234L)); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage(CopyOnAccessInMemoryStateInternals.class.getSimpleName()); + thrown.expectMessage("Can't get the earliest watermark hold"); + thrown.expectMessage("before it is committed"); + + internals.getEarliestWatermarkHold(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternalsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternalsTest.java new file mode 100644 index 000000000000..0c10560e4abd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternalsTest.java @@ -0,0 +1,348 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Arrays; + +/** + * Tests for {@link InMemoryStateInternals}. + */ +@RunWith(JUnit4.class) +public class InMemoryStateInternalsTest { + private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); + + private static final StateTag> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + private static final StateTag> + SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( + "sumInteger", VarIntCoder.of(), new Sum.SumIntegerFn()); + private static final StateTag> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + private static final StateTag> + WATERMARK_EARLIEST_ADDR = + StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEarliestInputTimestamp()); + private static final StateTag> + WATERMARK_LATEST_ADDR = + StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtLatestInputTimestamp()); + private static final StateTag> WATERMARK_EOW_ADDR = + StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEndOfWindow()); + + InMemoryStateInternals underTest = InMemoryStateInternals.forKey("dummyKey"); + + @Test + public void testValue() throws Exception { + ValueState value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); + + // State instances are cached, but depend on the namespace. + assertThat(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), Matchers.sameInstance(value)); + assertThat( + underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), + Matchers.not(Matchers.sameInstance(value))); + + assertThat(value.read(), Matchers.nullValue()); + value.write("hello"); + assertThat(value.read(), Matchers.equalTo("hello")); + value.write("world"); + assertThat(value.read(), Matchers.equalTo("world")); + + value.clear(); + assertThat(value.read(), Matchers.nullValue()); + assertThat(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), Matchers.sameInstance(value)); + } + + @Test + public void testBag() throws Exception { + BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello")); + + value.add("world"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); + + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertThat(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), Matchers.sameInstance(value)); + } + + @Test + public void testBagIsEmpty() throws Exception { + BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeBagIntoSource() throws Exception { + BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); + + // Reading the merged bag gets both the contents + assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMergeBagIntoNewNamespace() throws Exception { + BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); + + // Reading the merged bag gets both the contents + assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag1.read(), Matchers.emptyIterable()); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testCombiningValue() throws Exception { + CombiningState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); + + assertThat(value.read(), Matchers.equalTo(0)); + value.add(2); + assertThat(value.read(), Matchers.equalTo(2)); + + value.add(3); + assertThat(value.read(), Matchers.equalTo(5)); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(0)); + assertThat(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), Matchers.sameInstance(value)); + } + + @Test + public void testCombiningIsEmpty() throws Exception { + CombiningState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add(5); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeCombiningValueIntoSource() throws Exception { + AccumulatorCombiningState value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + AccumulatorCombiningState value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + assertThat(value1.read(), Matchers.equalTo(11)); + assertThat(value2.read(), Matchers.equalTo(10)); + + // Merging clears the old values and updates the result value. + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); + + assertThat(value1.read(), Matchers.equalTo(21)); + assertThat(value2.read(), Matchers.equalTo(0)); + } + + @Test + public void testMergeCombiningValueIntoNewNamespace() throws Exception { + AccumulatorCombiningState value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + AccumulatorCombiningState value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + AccumulatorCombiningState value3 = + underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); + + // Merging clears the old values and updates the result value. + assertThat(value1.read(), Matchers.equalTo(0)); + assertThat(value2.read(), Matchers.equalTo(0)); + assertThat(value3.read(), Matchers.equalTo(21)); + } + + @Test + public void testWatermarkEarliestState() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.add(new Instant(3000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.add(new Instant(1000)); + assertThat(value.read(), Matchers.equalTo(new Instant(1000))); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(null)); + assertThat(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), Matchers.sameInstance(value)); + } + + @Test + public void testWatermarkLatestState() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.add(new Instant(3000)); + assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + + value.add(new Instant(1000)); + assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(null)); + assertThat(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), Matchers.sameInstance(value)); + } + + @Test + public void testWatermarkEndOfWindowState() throws Exception { + WatermarkHoldState value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(null)); + assertThat(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), Matchers.sameInstance(value)); + } + + @Test + public void testWatermarkStateIsEmpty() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add(new Instant(1000)); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeEarliestWatermarkIntoSource() throws Exception { + WatermarkHoldState value1 = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + WatermarkHoldState value2 = + underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR); + + value1.add(new Instant(3000)); + value2.add(new Instant(5000)); + value1.add(new Instant(4000)); + value2.add(new Instant(2000)); + + // Merging clears the old values and updates the merged value. + StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); + + assertThat(value1.read(), Matchers.equalTo(new Instant(2000))); + assertThat(value2.read(), Matchers.equalTo(null)); + } + + @Test + public void testMergeLatestWatermarkIntoSource() throws Exception { + WatermarkHoldState value1 = + underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); + WatermarkHoldState value2 = + underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR); + WatermarkHoldState value3 = + underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR); + + value1.add(new Instant(3000)); + value2.add(new Instant(5000)); + value1.add(new Instant(4000)); + value2.add(new Instant(2000)); + + // Merging clears the old values and updates the result value. + StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); + + // Merging clears the old values and updates the result value. 
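+    // outputAtLatestInputTimestamp() keeps the maximum hold, so the merged
+    // target reads the latest of the added instants (5000) while the cleared
+    // source holds read null.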
+ assertThat(value3.read(), Matchers.equalTo(new Instant(5000))); + assertThat(value1.read(), Matchers.equalTo(null)); + assertThat(value2.read(), Matchers.equalTo(null)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/StateNamespacesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/StateNamespacesTest.java new file mode 100644 index 000000000000..933383db40fd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/StateNamespacesTest.java @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link StateNamespaces}. + */ +@RunWith(JUnit4.class) +public class StateNamespacesTest { + + private final Coder intervalCoder = IntervalWindow.getCoder(); + + private IntervalWindow intervalWindow(long start, long end) { + return new IntervalWindow(new Instant(start), new Instant(end)); + } + + /** + * This test should not be changed. It verifies that the stringKey matches certain expectations. + * If this changes, the ability to reload any pipeline that has persisted these namespaces will + * be impacted. + */ + @Test + public void testStability() { + StateNamespace global = StateNamespaces.global(); + StateNamespace intervalWindow = + StateNamespaces.window(intervalCoder, intervalWindow(1000, 87392)); + StateNamespace intervalWindowAndTrigger = + StateNamespaces.windowAndTrigger(intervalCoder, intervalWindow(1000, 87392), 57); + StateNamespace globalWindow = StateNamespaces.window( + GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE); + StateNamespace globalWindowAndTrigger = StateNamespaces.windowAndTrigger( + GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, 12); + + assertEquals("/", global.stringKey()); + assertEquals("/gAAAAAABVWD4ogU/", intervalWindow.stringKey()); + assertEquals("/gAAAAAABVWD4ogU/1L/", intervalWindowAndTrigger.stringKey()); + assertEquals("//", globalWindow.stringKey()); + assertEquals("//C/", globalWindowAndTrigger.stringKey()); + } + + /** + * Test that WindowAndTrigger namespaces are prefixed by the related Window namespace. 
+ */ + @Test + public void testIntervalWindowPrefixing() { + StateNamespace window = + StateNamespaces.window(intervalCoder, intervalWindow(1000, 87392)); + StateNamespace windowAndTrigger = StateNamespaces.windowAndTrigger( + intervalCoder, intervalWindow(1000, 87392), 57); + assertThat(windowAndTrigger.stringKey(), Matchers.startsWith(window.stringKey())); + assertThat(StateNamespaces.global().stringKey(), + Matchers.not(Matchers.startsWith(window.stringKey()))); + } + + /** + * Test that WindowAndTrigger namespaces are prefixed by the related Window namespace. + */ + @Test + public void testGlobalWindowPrefixing() { + StateNamespace window = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE); + StateNamespace windowAndTrigger = StateNamespaces.windowAndTrigger( + GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, 57); + assertThat(windowAndTrigger.stringKey(), Matchers.startsWith(window.stringKey())); + assertThat(StateNamespaces.global().stringKey(), + Matchers.not(Matchers.startsWith(window.stringKey()))); + } + + @Test + public void testFromStringGlobal() { + assertStringKeyRoundTrips(intervalCoder, StateNamespaces.global()); + } + + @Test + public void testFromStringIntervalWindow() { + assertStringKeyRoundTrips( + intervalCoder, StateNamespaces.window(intervalCoder, intervalWindow(1000, 8000))); + assertStringKeyRoundTrips( + intervalCoder, StateNamespaces.window(intervalCoder, intervalWindow(1000, 8000))); + + assertStringKeyRoundTrips(intervalCoder, + StateNamespaces.windowAndTrigger(intervalCoder, intervalWindow(1000, 8000), 18)); + assertStringKeyRoundTrips(intervalCoder, + StateNamespaces.windowAndTrigger(intervalCoder, intervalWindow(1000, 8000), 19)); + assertStringKeyRoundTrips(intervalCoder, + StateNamespaces.windowAndTrigger(intervalCoder, intervalWindow(2000, 8000), 19)); + } + + @Test + public void testFromStringGlobalWindow() { + assertStringKeyRoundTrips(GlobalWindow.Coder.INSTANCE, StateNamespaces.global()); + assertStringKeyRoundTrips(GlobalWindow.Coder.INSTANCE, + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE)); + assertStringKeyRoundTrips(GlobalWindow.Coder.INSTANCE, + StateNamespaces.windowAndTrigger(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, 18)); + } + + private void assertStringKeyRoundTrips( + Coder coder, StateNamespace namespace) { + assertEquals(namespace, StateNamespaces.fromString(namespace.stringKey(), coder)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/StateTagTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/StateTagTest.java new file mode 100644 index 000000000000..47f7224e9096 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/state/StateTagTest.java @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.util.state; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.transforms.Max; +import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.Min; +import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link StateTag}. + */ +@RunWith(JUnit4.class) +public class StateTagTest { + @Test + public void testValueEquality() { + StateTag fooVarInt1 = StateTags.value("foo", VarIntCoder.of()); + StateTag fooVarInt2 = StateTags.value("foo", VarIntCoder.of()); + StateTag fooBigEndian = StateTags.value("foo", BigEndianIntegerCoder.of()); + StateTag barVarInt = StateTags.value("bar", VarIntCoder.of()); + + assertEquals(fooVarInt1, fooVarInt2); + assertNotEquals(fooVarInt1, fooBigEndian); + assertNotEquals(fooVarInt1, barVarInt); + } + + @Test + public void testBagEquality() { + StateTag fooVarInt1 = StateTags.bag("foo", VarIntCoder.of()); + StateTag fooVarInt2 = StateTags.bag("foo", VarIntCoder.of()); + StateTag fooBigEndian = StateTags.bag("foo", BigEndianIntegerCoder.of()); + StateTag barVarInt = StateTags.bag("bar", VarIntCoder.of()); + + assertEquals(fooVarInt1, fooVarInt2); + assertNotEquals(fooVarInt1, fooBigEndian); + assertNotEquals(fooVarInt1, barVarInt); + } + + @Test + public void testWatermarkBagEquality() { + StateTag foo1 = StateTags.watermarkStateInternal( + "foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + StateTag foo2 = StateTags.watermarkStateInternal( + "foo", OutputTimeFns.outputAtEarliestInputTimestamp()); + StateTag bar = StateTags.watermarkStateInternal( + "bar", OutputTimeFns.outputAtEarliestInputTimestamp()); + + StateTag bar2 = StateTags.watermarkStateInternal( + "bar", OutputTimeFns.outputAtLatestInputTimestamp()); + + // Same id, same fn. + assertEquals(foo1, foo2); + // Different id, same fn. + assertNotEquals(foo1, bar); + // Same id, different fn. + assertEquals(bar, bar2); + } + + @Test + public void testCombiningValueEquality() { + MaxIntegerFn maxFn = new Max.MaxIntegerFn(); + Coder input1 = VarIntCoder.of(); + Coder input2 = BigEndianIntegerCoder.of(); + MinIntegerFn minFn = new Min.MinIntegerFn(); + + StateTag fooCoder1Max1 = StateTags.combiningValueFromInputInternal("foo", input1, maxFn); + StateTag fooCoder1Max2 = StateTags.combiningValueFromInputInternal("foo", input1, maxFn); + StateTag fooCoder1Min = StateTags.combiningValueFromInputInternal("foo", input1, minFn); + + StateTag fooCoder2Max = StateTags.combiningValueFromInputInternal("foo", input2, maxFn); + StateTag barCoder1Max = StateTags.combiningValueFromInputInternal("bar", input1, maxFn); + + // Same name, coder and combineFn + assertEquals(fooCoder1Max1, fooCoder1Max2); + // Different combineFn, but we treat them as equal since we only serialize the bits. + assertEquals(fooCoder1Max1, fooCoder1Min); + + // Different input coder coder. + assertNotEquals(fooCoder1Max1, fooCoder2Max); + + // These StateTags have different IDs. 
+ assertNotEquals(fooCoder1Max1, barCoder1Max); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/KVTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/KVTest.java new file mode 100644 index 000000000000..75d4fc5a6bf9 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/KVTest.java @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.common.collect.ImmutableList; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Comparator; + +/** + * Tests for KV. + */ +@RunWith(JUnit4.class) +public class KVTest { + private static final Integer TEST_VALUES[] = + {null, Integer.MIN_VALUE, -1, 0, 1, Integer.MAX_VALUE}; + + // Wrapper around Integer.compareTo() to support null values. + private int compareInt(Integer a, Integer b) { + if (a == null) { + return b == null ? 0 : -1; + } else { + return b == null ? 1 : a.compareTo(b); + } + } + + @Test + public void testEquals() { + // Neither position are arrays + assertThat(KV.of(1, 2), equalTo(KV.of(1, 2))); + + // Key is array + assertThat(KV.of(new int[]{1, 2}, 3), equalTo(KV.of(new int[]{1, 2}, 3))); + + // Value is array + assertThat(KV.of(1, new int[]{2, 3}), equalTo(KV.of(1, new int[]{2, 3}))); + + // Both are arrays + assertThat(KV.of(new int[]{1, 2}, new int[]{3, 4}), + equalTo(KV.of(new int[]{1, 2}, new int[]{3, 4}))); + + // Unfortunately, deep equals only goes so far + assertThat(KV.of(ImmutableList.of(new int[]{1, 2}), 3), + not(equalTo(KV.of(ImmutableList.of(new int[]{1, 2}), 3)))); + assertThat(KV.of(1, ImmutableList.of(new int[]{2, 3})), + not(equalTo(KV.of(1, ImmutableList.of(new int[]{2, 3}))))); + + // Key is array and differs + assertThat(KV.of(new int[]{1, 2}, 3), not(equalTo(KV.of(new int[]{1, 37}, 3)))); + + // Key is non-array and differs + assertThat(KV.of(1, new int[]{2, 3}), not(equalTo(KV.of(37, new int[]{1, 2})))); + + // Value is array and differs + assertThat(KV.of(1, new int[]{2, 3}), not(equalTo(KV.of(1, new int[]{37, 3})))); + + // Value is non-array and differs + assertThat(KV.of(new byte[]{1, 2}, 3), not(equalTo(KV.of(new byte[]{1, 2}, 37)))); + } + + @Test + public void testOrderByKey() { + Comparator> orderByKey = new KV.OrderByKey<>(); + for (Integer key1 : TEST_VALUES) { + for (Integer val1 : TEST_VALUES) { + for (Integer key2 : TEST_VALUES) { + for (Integer val2 : TEST_VALUES) { + assertEquals(compareInt(key1, key2), + orderByKey.compare(KV.of(key1, val1), KV.of(key2, val2))); + } + } + } + } + } + + @Test + public void testOrderByValue() { + Comparator> orderByValue = new KV.OrderByValue<>(); + for (Integer key1 : TEST_VALUES) { + for (Integer val1 : TEST_VALUES) { + for (Integer key2 
: TEST_VALUES) { + for (Integer val2 : TEST_VALUES) { + assertEquals(compareInt(val1, val2), + orderByValue.compare(KV.of(key1, val1), KV.of(key2, val2))); + } + } + } + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionListTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionListTest.java new file mode 100644 index 000000000000..fa163dd63419 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionListTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collections; + +/** + * Tests for PCollectionLists. + */ +@RunWith(JUnit4.class) +public class PCollectionListTest { + @Test + public void testEmptyListFailure() { + try { + PCollectionList.of(Collections.>emptyList()); + fail("should have failed"); + } catch (IllegalArgumentException exn) { + assertThat( + exn.toString(), + containsString( + "must either have a non-empty list of PCollections, " + + "or must first call empty(Pipeline)")); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionTupleTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionTupleTest.java new file mode 100644 index 000000000000..1017556e777e --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PCollectionTupleTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; + +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** Unit tests for {@link PCollectionTuple}. */ +@RunWith(JUnit4.class) +public final class PCollectionTupleTest implements Serializable { + @Test + public void testOfThenHas() { + Pipeline pipeline = TestPipeline.create(); + PCollection pCollection = PCollection.createPrimitiveOutputInternal( + pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED); + TupleTag tag = new TupleTag<>(); + + assertTrue(PCollectionTuple.of(tag, pCollection).has(tag)); + } + + @Test + public void testEmpty() { + Pipeline pipeline = TestPipeline.create(); + TupleTag tag = new TupleTag<>(); + assertFalse(PCollectionTuple.empty(pipeline).has(tag)); + } + + @Test + @Category(RunnableOnService.class) + public void testComposePCollectionTuple() { + Pipeline pipeline = TestPipeline.create(); + + List inputs = Arrays.asList(3, -42, 666); + + TupleTag mainOutputTag = new TupleTag("main") {}; + TupleTag emptyOutputTag = new TupleTag("empty") {}; + final TupleTag sideOutputTag = new TupleTag("side") {}; + + PCollection mainInput = pipeline + .apply(Create.of(inputs)); + + PCollectionTuple outputs = mainInput.apply(ParDo + .of(new DoFn() { + @Override + public void processElement(ProcessContext c) { + c.sideOutput(sideOutputTag, c.element()); + }}) + .withOutputTags(emptyOutputTag, TupleTagList.of(sideOutputTag))); + assertNotNull("outputs.getPipeline()", outputs.getPipeline()); + outputs = outputs.and(mainOutputTag, mainInput); + + DataflowAssert.that(outputs.get(mainOutputTag)).containsInAnyOrder(inputs); + DataflowAssert.that(outputs.get(sideOutputTag)).containsInAnyOrder(inputs); + DataflowAssert.that(outputs.get(emptyOutputTag)).empty(); + + pipeline.run(); + } + +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PDoneTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PDoneTest.java new file mode 100644 index 000000000000..4c273672cb0b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/PDoneTest.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.values; + +import static com.google.cloud.dataflow.sdk.TestUtils.LINES; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.File; + +/** + * Tests for PDone. + */ +@RunWith(JUnit4.class) +public class PDoneTest { + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + /** + * A PTransform that just returns a fresh PDone. + */ + static class EmptyTransform extends PTransform { + @Override + public PDone apply(PBegin begin) { + return PDone.in(begin.getPipeline()); + } + } + + /** + * A PTransform that's composed of something that returns a PDone. + */ + static class SimpleTransform extends PTransform { + private final String filename; + + public SimpleTransform(String filename) { + this.filename = filename; + } + + @Override + public PDone apply(PBegin begin) { + return + begin + .apply(Create.of(LINES)) + .apply(TextIO.Write.to(filename)); + } + } + + // TODO: This test doesn't work, because we can't handle composite + // transforms that contain no nested transforms. + @Ignore + @Test + @Category(RunnableOnService.class) + public void testEmptyTransform() { + Pipeline p = TestPipeline.create(); + + p.begin().apply(new EmptyTransform()); + + p.run(); + } + + // Cannot run on the service, unless we allocate a GCS temp file + // instead of a local temp file. Or switch to applying a different + // transform that returns PDone. + @Test + public void testSimpleTransform() throws Exception { + File tmpFile = tmpFolder.newFile("file.txt"); + String filename = tmpFile.getPath(); + + Pipeline p = TestPipeline.create(); + + p.begin().apply(new SimpleTransform(filename)); + + p.run(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TupleTagTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TupleTagTest.java new file mode 100644 index 000000000000..af6743419412 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TupleTagTest.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.startsWith; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link TupleTag}. + */ +@RunWith(JUnit4.class) +public class TupleTagTest { + + private static TupleTag staticTag = new TupleTag<>(); + private static TupleTag staticBlockTag; + private static TupleTag staticMethodTag = createTag(); + private static TupleTag instanceMethodTag = new AnotherClass().createAnotherTag(); + + static { + staticBlockTag = new TupleTag<>(); + } + + private static TupleTag createTag() { + return new TupleTag<>(); + } + + private static class AnotherClass { + private static TupleTag anotherTag = new TupleTag<>(); + private TupleTag createAnotherTag() { + return new TupleTag<>(); + } + } + + @Test + public void testStaticTupleTag() { + assertEquals("com.google.cloud.dataflow.sdk.values.TupleTagTest#0", staticTag.getId()); + assertEquals("com.google.cloud.dataflow.sdk.values.TupleTagTest#3", staticBlockTag.getId()); + assertEquals("com.google.cloud.dataflow.sdk.values.TupleTagTest#1", staticMethodTag.getId()); + assertEquals("com.google.cloud.dataflow.sdk.values.TupleTagTest#2", instanceMethodTag.getId()); + assertEquals( + "com.google.cloud.dataflow.sdk.values.TupleTagTest$AnotherClass#0", + AnotherClass.anotherTag.getId()); + } + + private TupleTag createNonstaticTupleTag() { + return new TupleTag(); + } + + @Test + public void testNonstaticTupleTag() { + assertNotEquals(new TupleTag().getId(), new TupleTag().getId()); + assertNotEquals(createNonstaticTupleTag(), createNonstaticTupleTag()); + + TupleTag tag = createNonstaticTupleTag(); + + // Check that the name is derived from the method it is created in. + assertThat(tag.getId().split("#")[0], + startsWith("com.google.cloud.dataflow.sdk.values.TupleTagTest.createNonstaticTupleTag")); + + // Check that after the name there is a ':' followed by a line number, and just make + // sure the line number is big enough to be reasonable, so superficial changes don't break + // the test. + assertThat(Integer.parseInt(tag.getId().split("#")[0].split(":")[1]), + greaterThan(15)); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TypeDescriptorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TypeDescriptorTest.java new file mode 100644 index 000000000000..a811a7cf6533 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TypeDescriptorTest.java @@ -0,0 +1,193 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.junit.Assert.assertEquals; + +import com.google.common.reflect.TypeToken; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.lang.reflect.Method; +import java.lang.reflect.TypeVariable; +import java.util.List; +import java.util.Set; + +/** + * Tests for TypeDescriptor. + */ +@RunWith(JUnit4.class) +public class TypeDescriptorTest { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + public void testTypeDescriptorOfRawType() throws Exception { + assertEquals( + TypeToken.of(String.class).getRawType(), + TypeDescriptor.of(String.class).getRawType()); + } + + @Test + public void testTypeDescriptorImmediate() throws Exception { + TypeDescriptor descriptor = new TypeDescriptor(){}; + assertEquals(String.class, descriptor.getRawType()); + } + + @Test + public void testTypeDescriptorGeneric() throws Exception { + TypeDescriptor> descriptor = new TypeDescriptor>(){}; + TypeToken> token = new TypeToken>(){}; + assertEquals(token.getType(), descriptor.getType()); + } + + private static class TypeRememberer { + public final TypeDescriptor descriptorByClass; + public final TypeDescriptor descriptorByInstance; + + public TypeRememberer() { + descriptorByClass = new TypeDescriptor(getClass()){}; + descriptorByInstance = new TypeDescriptor(this) {}; + } + } + + @Test + public void testTypeDescriptorNested() throws Exception { + TypeRememberer rememberer = new TypeRememberer(){}; + assertEquals(new TypeToken() {}.getType(), rememberer.descriptorByClass.getType()); + assertEquals(new TypeToken() {}.getType(), rememberer.descriptorByInstance.getType()); + + TypeRememberer> genericRememberer = new TypeRememberer>(){}; + assertEquals(new TypeToken>() {}.getType(), + genericRememberer.descriptorByClass.getType()); + assertEquals(new TypeToken>() {}.getType(), + genericRememberer.descriptorByInstance.getType()); + } + + private static class Id { + @SuppressWarnings("unused") // used via reflection + public T identity(T thingie) { + return thingie; + } + } + + @Test + public void testGetArgumentTypes() throws Exception { + Method identity = Id.class.getDeclaredMethod("identity", Object.class); + + TypeToken> token = new TypeToken >(){}; + TypeDescriptor> descriptor = new TypeDescriptor >(){}; + assertEquals( + token.method(identity).getParameters().get(0).getType().getType(), + descriptor.getArgumentTypes(identity).get(0).getType()); + + TypeToken>> genericToken = new TypeToken >>(){}; + TypeDescriptor>> genericDescriptor = new TypeDescriptor >>(){}; + assertEquals( + genericToken.method(identity).getParameters().get(0).getType().getType(), + genericDescriptor.getArgumentTypes(identity).get(0).getType()); + } + + private static class TypeRemembererer { + public TypeDescriptor descriptor1; + public TypeDescriptor descriptor2; + + public TypeRemembererer() { + descriptor1 = new TypeDescriptor(getClass()){}; + descriptor2 = new TypeDescriptor(getClass()){}; + } + } + + @Test + public void testTypeDescriptorNested2() throws Exception { + TypeRemembererer remembererer = new TypeRemembererer(){}; + assertEquals(new TypeToken() {}.getType(), remembererer.descriptor1.getType()); + assertEquals(new TypeToken() {}.getType(), remembererer.descriptor2.getType()); + + TypeRemembererer, Set> genericRemembererer = + new TypeRemembererer, Set>(){}; + assertEquals(new TypeToken>() 
{}.getType(), + genericRemembererer.descriptor1.getType()); + assertEquals(new TypeToken>() {}.getType(), + genericRemembererer.descriptor2.getType()); + } + + private static class GenericClass { } + + @Test + public void testGetTypeParameterGood() throws Exception { + @SuppressWarnings("rawtypes") + TypeVariable> bizzleT = + TypeDescriptor.of(GenericClass.class).getTypeParameter("BizzleT"); + assertEquals(GenericClass.class.getTypeParameters()[0], bizzleT); + } + + @Test + public void testGetTypeParameterBad() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("MerpleT"); // just check that the message gives actionable details + TypeDescriptor.of(GenericClass.class).getTypeParameter("MerpleT"); + } + + private static class GenericMaker { + public TypeRememberer> getRememberer() { + return new TypeRememberer>() {}; + } + } + + private static class GenericMaker2 { + public GenericMaker> getGenericMaker() { + return new GenericMaker>() {}; + } + } + + @Test + public void testEnclosing() throws Exception { + TypeRememberer> rememberer = new GenericMaker(){}.getRememberer(); + assertEquals( + new TypeToken>() {}.getType(), rememberer.descriptorByInstance.getType()); + + // descriptorByClass *is not* able to find the type of T because it comes from the enclosing + // instance of GenericMaker. + // assertEquals(new TypeToken>() {}.getType(), rememberer.descriptorByClass.getType()); + } + + @Test + public void testEnclosing2() throws Exception { + // If we don't override, the best we can get is List> + // TypeRememberer>> rememberer = + // new GenericMaker2(){}.getGenericMaker().getRememberer(); + // assertNotEquals( + // new TypeToken>>() {}.getType(), + // rememberer.descriptorByInstance.getType()); + + // If we've overridden the getGenericMaker we can determine the types. + TypeRememberer>> rememberer = new GenericMaker2() { + @Override public GenericMaker> getGenericMaker() { + return new GenericMaker>() {}; + } + }.getGenericMaker().getRememberer(); + assertEquals( + new TypeToken>>() {}.getType(), + rememberer.descriptorByInstance.getType()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TypedPValueTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TypedPValueTest.java new file mode 100644 index 000000000000..b0a13ec37d28 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/values/TypedPValueTest.java @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.values; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + +/** + * Tests for {@link TypedPValue}, primarily focusing on Coder inference. + */ +@RunWith(JUnit4.class) +public class TypedPValueTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private static class IdentityDoFn extends DoFn { + private static final long serialVersionUID = 0; + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element()); + } + } + + private static PCollectionTuple buildPCollectionTupleWithTags( + TupleTag mainOutputTag, TupleTag sideOutputTag) { + Pipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(1, 2, 3)); + PCollectionTuple tuple = input.apply( + ParDo + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) + .of(new IdentityDoFn())); + return tuple; + } + + private static TupleTag makeTagStatically() { + return new TupleTag() {}; + } + + @Test + public void testUntypedSideOutputTupleTagGivesActionableMessage() { + TupleTag mainOutputTag = new TupleTag() {}; + // untypedSideOutputTag did not use anonymous subclass. + TupleTag untypedSideOutputTag = new TupleTag(); + PCollectionTuple tuple = buildPCollectionTupleWithTags(mainOutputTag, untypedSideOutputTag); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("No Coder has been manually specified"); + thrown.expectMessage("erasure"); + thrown.expectMessage("see TupleTag Javadoc"); + + tuple.get(untypedSideOutputTag).getCoder(); + } + + @Test + public void testStaticFactorySideOutputTupleTagGivesActionableMessage() { + TupleTag mainOutputTag = new TupleTag() {}; + // untypedSideOutputTag constructed from a static factory method. + TupleTag untypedSideOutputTag = makeTagStatically(); + PCollectionTuple tuple = buildPCollectionTupleWithTags(mainOutputTag, untypedSideOutputTag); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("No Coder has been manually specified"); + thrown.expectMessage("erasure"); + thrown.expectMessage("see TupleTag Javadoc"); + + tuple.get(untypedSideOutputTag).getCoder(); + } + + @Test + public void testTypedSideOutputTupleTag() { + TupleTag mainOutputTag = new TupleTag() {}; + // typedSideOutputTag was constructed with compile-time type information. + TupleTag typedSideOutputTag = new TupleTag() {}; + PCollectionTuple tuple = buildPCollectionTupleWithTags(mainOutputTag, typedSideOutputTag); + + assertThat(tuple.get(typedSideOutputTag).getCoder(), instanceOf(VarIntCoder.class)); + } + + @Test + public void testUntypedMainOutputTagTypedSideOutputTupleTag() { + // mainOutputTag is allowed to be untyped because Coder can be inferred other ways. 
+ TupleTag mainOutputTag = new TupleTag<>(); + TupleTag typedSideOutputTag = new TupleTag() {}; + PCollectionTuple tuple = buildPCollectionTupleWithTags(mainOutputTag, typedSideOutputTag); + + assertThat(tuple.get(typedSideOutputTag).getCoder(), instanceOf(VarIntCoder.class)); + } + + // A simple class for which there should be no obvious Coder. + static class EmptyClass { + } + + private static class EmptyClassDoFn extends DoFn { + private static final long serialVersionUID = 0; + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(new EmptyClass()); + } + } + + @Test + public void testParDoWithNoSideOutputsErrorDoesNotMentionTupleTag() { + Pipeline p = TestPipeline.create(); + PCollection input = + p.apply(Create.of(1, 2, 3)).apply(ParDo.of(new EmptyClassDoFn())); + + thrown.expect(IllegalStateException.class); + + // Output specific to ParDo TupleTag side outputs should not be present. + thrown.expectMessage(not(containsString("erasure"))); + thrown.expectMessage(not(containsString("see TupleTag Javadoc"))); + // Instead, expect output suggesting other possible fixes. + thrown.expectMessage(containsString("Building a Coder using a registered CoderFactory failed")); + thrown.expectMessage( + containsString("Building a Coder from the @DefaultCoder annotation failed")); + thrown.expectMessage(containsString("Building a Coder from the fallback CoderProvider failed")); + + input.getCoder(); + } + + @Test + public void testFinishSpecifyingShouldFailIfNoCoderInferrable() { + Pipeline p = TestPipeline.create(); + PCollection unencodable = + p.apply(Create.of(1, 2, 3)).apply(ParDo.of(new EmptyClassDoFn())); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Unable to return a default Coder"); + thrown.expectMessage("Inferring a Coder from the CoderRegistry failed"); + + unencodable.finishSpecifying(); + } +} + diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/CombineJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/CombineJava8Test.java new file mode 100644 index 000000000000..b569e49c951d --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/CombineJava8Test.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Java 8 Tests for {@link Combine}. 
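+ *
+ * <p>For orientation, a minimal sketch of the shape shared by the tests below (it assumes the
+ * {@link TestPipeline}/{@link DataflowAssert} harness and the {@code Summer} helper defined in
+ * this file):
+ *
+ * <pre>{@code
+ * PCollection<Integer> sum = pipeline
+ *     .apply(Create.of(1, 2, 3, 4))
+ *     .apply(Combine.globally(new Summer()::sum));  // or an equivalent Iterable-summing lambda
+ * DataflowAssert.that(sum).containsInAnyOrder(10);
+ * }</pre>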
+ */ +@RunWith(JUnit4.class) +@SuppressWarnings("serial") +public class CombineJava8Test implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + /** + * Class for use in testing use of Java 8 method references. + */ + private static class Summer implements Serializable { + public int sum(Iterable integers) { + int sum = 0; + for (int i : integers) { + sum += i; + } + return sum; + } + } + + /** + * Tests creation of a global {@link Combine} via Java 8 lambda. + */ + @Test + public void testCombineGloballyLambda() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4)) + .apply(Combine.globally(integers -> { + int sum = 0; + for (int i : integers) { + sum += i; + } + return sum; + })); + + DataflowAssert.that(output).containsInAnyOrder(10); + pipeline.run(); + } + + /** + * Tests creation of a global {@link Combine} via a Java 8 method reference. + */ + @Test + public void testCombineGloballyInstanceMethodReference() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4)) + .apply(Combine.globally(new Summer()::sum)); + + DataflowAssert.that(output).containsInAnyOrder(10); + pipeline.run(); + } + + /** + * Tests creation of a per-key {@link Combine} via a Java 8 lambda. + */ + @Test + public void testCombinePerKeyLambda() { + Pipeline pipeline = TestPipeline.create(); + + PCollection> output = pipeline + .apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3), KV.of("c", 4))) + .apply(Combine.perKey(integers -> { + int sum = 0; + for (int i : integers) { + sum += i; + } + return sum; + })); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", 4), + KV.of("b", 2), + KV.of("c", 4)); + pipeline.run(); + } + + /** + * Tests creation of a per-key {@link Combine} via a Java 8 method reference. + */ + @Test + public void testCombinePerKeyInstanceMethodReference() { + Pipeline pipeline = TestPipeline.create(); + + PCollection> output = pipeline + .apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3), KV.of("c", 4))) + .apply(Combine.perKey(new Summer()::sum)); + + DataflowAssert.that(output).containsInAnyOrder( + KV.of("a", 4), + KV.of("b", 2), + KV.of("c", 4)); + pipeline.run(); + } +} diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/FilterJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/FilterJava8Test.java new file mode 100644 index 000000000000..db65932ccba4 --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/FilterJava8Test.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Java 8 Tests for {@link Filter}. + */ +@RunWith(JUnit4.class) +@SuppressWarnings("serial") +public class FilterJava8Test implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(RunnableOnService.class) + public void testIdentityFilterByPredicate() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(591, 11789, 1257, 24578, 24799, 307)) + .apply(Filter.byPredicate(i -> true)); + + DataflowAssert.that(output).containsInAnyOrder(591, 11789, 1257, 24578, 24799, 307); + pipeline.run(); + } + + @Test + public void testNoFilterByPredicate() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(1, 2, 4, 5)) + .apply(Filter.byPredicate(i -> false)); + + DataflowAssert.that(output).empty(); + pipeline.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testFilterByPredicate() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.byPredicate(i -> i % 2 == 0)); + + DataflowAssert.that(output).containsInAnyOrder(2, 4, 6); + pipeline.run(); + } + + /** + * Confirms that in Java 8 style, where a lambda results in a rawtype, the output type token is + * not useful. If this test ever fails there may be simplifications available to us. + */ + @Test + public void testFilterParDoOutputTypeDescriptorRaw() throws Exception { + Pipeline pipeline = TestPipeline.create(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + PCollection output = pipeline + .apply(Create.of("hello")) + .apply(Filter.by(s -> true)); + + thrown.expect(CannotProvideCoderException.class); + pipeline.getCoderRegistry().getDefaultCoder(output.getTypeDescriptor()); + } + + @Test + @Category(RunnableOnService.class) + public void testFilterByMethodReference() { + Pipeline pipeline = TestPipeline.create(); + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.byPredicate(new EvenFilter()::isEven)); + + DataflowAssert.that(output).containsInAnyOrder(2, 4, 6); + pipeline.run(); + } + + private static class EvenFilter implements Serializable { + public boolean isEven(int i) { + return i % 2 == 0; + } + } +} diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/FlatMapElementsJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/FlatMapElementsJava8Test.java new file mode 100644 index 000000000000..e0b946b77f40 --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/FlatMapElementsJava8Test.java @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.List; + +/** + * Java 8 Tests for {@link FlatMapElements}. + */ +@RunWith(JUnit4.class) +public class FlatMapElementsJava8Test implements Serializable { + + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + + /** + * Basic test of {@link FlatMapElements} with a lambda (which is instantiated as a + * {@link SerializableFunction}). + */ + @Test + public void testFlatMapBasic() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(FlatMapElements + // Note that the input type annotation is required. + .via((Integer i) -> ImmutableList.of(i, -i)) + .withOutputType(new TypeDescriptor() {})); + + DataflowAssert.that(output).containsInAnyOrder(1, 3, -1, -3, 2, -2); + pipeline.run(); + } + + /** + * Basic test of {@link FlatMapElements} with a method reference. + */ + @Test + public void testFlatMapMethodReference() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(FlatMapElements + // Note that the input type annotation is required. + .via(new Negater()::numAndNegation) + .withOutputType(new TypeDescriptor() {})); + + DataflowAssert.that(output).containsInAnyOrder(1, 3, -1, -3, 2, -2); + pipeline.run(); + } + + private static class Negater implements Serializable { + public List numAndNegation(int input) { + return ImmutableList.of(input, -input); + } + } +} diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/MapElementsJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/MapElementsJava8Test.java new file mode 100644 index 000000000000..123e6803876c --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/MapElementsJava8Test.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */
+
+package com.google.cloud.dataflow.sdk.transforms;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+
+/**
+ * Java 8 tests for {@link MapElements}.
+ */
+@RunWith(JUnit4.class)
+public class MapElementsJava8Test implements Serializable {
+
+  /**
+   * Basic test of {@link MapElements} with a lambda (which is instantiated as a
+   * {@link SerializableFunction}).
+   */
+  @Test
+  public void testMapBasic() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output = pipeline
+        .apply(Create.of(1, 2, 3))
+        .apply(MapElements
+            // Note that the type annotation is required (for Java, not for Dataflow)
+            .via((Integer i) -> i * 2)
+            .withOutputType(new TypeDescriptor<Integer>() {}));
+
+    DataflowAssert.that(output).containsInAnyOrder(6, 2, 4);
+    pipeline.run();
+  }
+
+  /**
+   * Basic test of {@link MapElements} with a method reference.
+   */
+  @Test
+  public void testMapMethodReference() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output = pipeline
+        .apply(Create.of(1, 2, 3))
+        .apply(MapElements
+            // Note that the type annotation is required (for Java, not for Dataflow)
+            .via(new Doubler()::doubleIt)
+            .withOutputType(new TypeDescriptor<Integer>() {}));
+
+    DataflowAssert.that(output).containsInAnyOrder(6, 2, 4);
+    pipeline.run();
+  }
+
+  private static class Doubler implements Serializable {
+    public int doubleIt(int val) {
+      return val * 2;
+    }
+  }
+}
diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/PartitionJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/PartitionJava8Test.java
new file mode 100644
index 000000000000..c459ada0cec2
--- /dev/null
+++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/PartitionJava8Test.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.transforms;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+
+/**
+ * Java 8 Tests for {@link Partition}.
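+ *
+ * <p>A minimal sketch of the partitioning pattern exercised below (assumes this file's
+ * {@link TestPipeline} harness):
+ *
+ * <pre>{@code
+ * PCollectionList<Integer> outputs = pipeline
+ *     .apply(Create.of(1, 2, 4, 5))
+ *     .apply(Partition.of(3, (element, numPartitions) -> element % numPartitions));
+ * // outputs.get(0) is empty, outputs.get(1) contains {1, 4}, outputs.get(2) contains {2, 5}.
+ * }</pre>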
+ */
+@RunWith(JUnit4.class)
+@SuppressWarnings("serial")
+public class PartitionJava8Test implements Serializable {
+
+  @Rule
+  public transient ExpectedException thrown = ExpectedException.none();
+
+  @Test
+  public void testModPartition() {
+    Pipeline pipeline = TestPipeline.create();
+
+    PCollectionList<Integer> outputs = pipeline
+        .apply(Create.of(1, 2, 4, 5))
+        .apply(Partition.of(3, (element, numPartitions) -> element % numPartitions));
+    assertEquals(3, outputs.size());
+    DataflowAssert.that(outputs.get(0)).empty();
+    DataflowAssert.that(outputs.get(1)).containsInAnyOrder(1, 4);
+    DataflowAssert.that(outputs.get(2)).containsInAnyOrder(2, 5);
+    pipeline.run();
+  }
+
+  /**
+   * Confirms that in Java 8 style, where a lambda results in a rawtype, the output type token is
+   * not useful. If this test ever fails there may be simplifications available to us.
+   */
+  @Test
+  public void testPartitionFnOutputTypeDescriptorRaw() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+
+    PCollectionList<String> output = pipeline
+        .apply(Create.of("hello"))
+        .apply(Partition.of(1, (element, numPartitions) -> 0));
+
+    thrown.expect(CannotProvideCoderException.class);
+    pipeline.getCoderRegistry().getDefaultCoder(output.get(0).getTypeDescriptor());
+  }
+}
diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesJava8Test.java
new file mode 100644
index 000000000000..d9e2180b7da7
--- /dev/null
+++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/RemoveDuplicatesJava8Test.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.transforms;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Java 8 tests for {@link RemoveDuplicates}.
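+ *
+ * <p>A minimal sketch of the representative-value form tested below (assumes this file's
+ * {@link TestPipeline} harness):
+ *
+ * <pre>{@code
+ * PCollection<String> deduped =
+ *     dupes.apply(RemoveDuplicates.withRepresentativeValueFn((String s) -> s.length())
+ *         .withRepresentativeType(TypeDescriptor.of(Integer.class)));
+ * // One element is kept per distinct length; the explicit TypeDescriptor is needed because
+ * // the lambda's key type is otherwise lost to erasure.
+ * }</pre>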
+ */ +@RunWith(JUnit4.class) +public class RemoveDuplicatesJava8Test { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void withLambdaRepresentativeValuesFnAndTypeDescriptorShouldApplyFn() { + TestPipeline p = TestPipeline.create(); + + Multimap predupedContents = HashMultimap.create(); + predupedContents.put(3, "foo"); + predupedContents.put(4, "foos"); + predupedContents.put(6, "barbaz"); + predupedContents.put(6, "bazbar"); + PCollection dupes = + p.apply(Create.of("foo", "foos", "barbaz", "barbaz", "bazbar", "foo")); + PCollection deduped = + dupes.apply(RemoveDuplicates.withRepresentativeValueFn((String s) -> s.length()) + .withRepresentativeType(TypeDescriptor.of(Integer.class))); + + DataflowAssert.that(deduped).satisfies((Iterable strs) -> { + Set seenLengths = new HashSet<>(); + for (String s : strs) { + assertThat(predupedContents.values(), hasItem(s)); + assertThat(seenLengths, not(contains(s.length()))); + seenLengths.add(s.length()); + } + return null; + }); + + p.run(); + } + + @Test + public void withLambdaRepresentativeValuesFnNoTypeDescriptorShouldThrow() { + TestPipeline p = TestPipeline.create(); + + Multimap predupedContents = HashMultimap.create(); + predupedContents.put(3, "foo"); + predupedContents.put(4, "foos"); + predupedContents.put(6, "barbaz"); + predupedContents.put(6, "bazbar"); + PCollection dupes = + p.apply(Create.of("foo", "foos", "barbaz", "barbaz", "bazbar", "foo")); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Unable to return a default Coder for RemoveRepresentativeDupes"); + thrown.expectMessage("Cannot provide a coder for type variable K"); + thrown.expectMessage("the actual type is unknown due to erasure."); + + // Thrown when applying a transform to the internal WithKeys that withRepresentativeValueFn is + // implemented with + dupes.apply("RemoveRepresentativeDupes", + RemoveDuplicates.withRepresentativeValueFn((String s) -> s.length())); + } +} + diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/WithKeysJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/WithKeysJava8Test.java new file mode 100644 index 000000000000..c10af2903013 --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/WithKeysJava8Test.java @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TypeDescriptor; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Java 8 Tests for {@link WithKeys}. + */ +@RunWith(JUnit4.class) +public class WithKeysJava8Test { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + @Category(RunnableOnService.class) + public void withLambdaAndTypeDescriptorShouldSucceed() { + TestPipeline p = TestPipeline.create(); + + PCollection values = p.apply(Create.of("1234", "3210", "0", "-12")); + PCollection> kvs = values.apply( + WithKeys.of((String s) -> Integer.valueOf(s)) + .withKeyType(TypeDescriptor.of(Integer.class))); + + DataflowAssert.that(kvs).containsInAnyOrder( + KV.of(1234, "1234"), KV.of(0, "0"), KV.of(-12, "-12"), KV.of(3210, "3210")); + + p.run(); + } + + @Test + public void withLambdaAndNoTypeDescriptorShouldThrow() { + TestPipeline p = TestPipeline.create(); + + PCollection values = p.apply(Create.of("1234", "3210", "0", "-12")); + + values.apply("ApplyKeysWithWithKeys", WithKeys.of((String s) -> Integer.valueOf(s))); + + thrown.expect(PipelineExecutionException.class); + thrown.expectMessage("Unable to return a default Coder for ApplyKeysWithWithKeys"); + thrown.expectMessage("Cannot provide a coder for type variable K"); + thrown.expectMessage("the actual type is unknown due to erasure."); + + p.run(); + } +} + diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/WithTimestampsJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/WithTimestampsJava8Test.java new file mode 100644 index 000000000000..50b5ff737f76 --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/transforms/WithTimestampsJava8Test.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms; + +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Java 8 tests for {@link WithTimestamps}. 
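+ *
+ * <p>A minimal sketch of the lambda form exercised below (assumes this file's
+ * {@link TestPipeline} harness and that every element is a numeric string):
+ *
+ * <pre>{@code
+ * PCollection<String> timestamped =
+ *     input.apply(WithTimestamps.of((String value) -> new Instant(Long.valueOf(value))));
+ * }</pre>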
+ */ +@RunWith(JUnit4.class) +public class WithTimestampsJava8Test implements Serializable { + @Test + @Category(RunnableOnService.class) + public void withTimestampsLambdaShouldApplyTimestamps() { + TestPipeline p = TestPipeline.create(); + + String yearTwoThousand = "946684800000"; + PCollection timestamped = + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) + .apply(WithTimestamps.of((String input) -> new Instant(Long.valueOf(yearTwoThousand)))); + + PCollection> timestampedVals = + timestamped.apply(ParDo.of(new DoFn>() { + @Override + public void processElement(DoFn>.ProcessContext c) + throws Exception { + c.output(KV.of(c.element(), c.timestamp())); + } + })); + + DataflowAssert.that(timestamped) + .containsInAnyOrder(yearTwoThousand, "0", "1234", Integer.toString(Integer.MAX_VALUE)); + DataflowAssert.that(timestampedVals) + .containsInAnyOrder( + KV.of("0", new Instant(0)), + KV.of("1234", new Instant("1234")), + KV.of(Integer.toString(Integer.MAX_VALUE), new Instant(Integer.MAX_VALUE)), + KV.of(yearTwoThousand, new Instant(Long.valueOf(yearTwoThousand)))); + } +} + diff --git a/travis/README.md b/travis/README.md new file mode 100644 index 000000000000..f7d89999b3d9 --- /dev/null +++ b/travis/README.md @@ -0,0 +1,4 @@ +# Travis Scripts + +This directory contains scripts used for [Travis CI](https://travis-ci.org/GoogleCloudPlatform/DataflowJavaSDK) +testing. diff --git a/travis/test_wordcount.sh b/travis/test_wordcount.sh new file mode 100755 index 000000000000..2e8a58b8cbe0 --- /dev/null +++ b/travis/test_wordcount.sh @@ -0,0 +1,108 @@ +#!/bin/bash + +# This script runs WordCount example locally in a few different ways. +# Specifically, all combinations of: +# a) using mvn exec, or java -cp with a bundled jar file; +# b) input filename with no directory component, with a relative directory, or +# with an absolute directory; AND +# c) input filename containing wildcards or not. +# +# The one optional parameter is a path from the directory containing the script +# to the directory containing the top-level (parent) pom.xml. If no parameter +# is provided, the script assumes that directory is equal to the directory +# containing the script itself. +# +# The exit-code of the script indicates success or a failure. + +set -e +set -o pipefail + +PASS=1 +JAR_FILE=examples/target/google-cloud-dataflow-java-examples-all-bundled-manual_build.jar + +function check_result_hash { + local name=$1 + local outfile_prefix=$2 + local expected=$3 + + local actual=$(LC_ALL=C sort $outfile_prefix-* | md5sum | awk '{print $1}' \ + || LC_ALL=C sort $outfile_prefix-* | md5 -q) || exit 2 # OSX + if [[ "$actual" != "$expected" ]] + then + echo "FAIL $name: Output hash mismatch. Got $actual, expected $expected." 
+ PASS="" + echo "head hexdump of actual:" + head $outfile_prefix-* | hexdump -c + else + echo "pass $name" + # Output files are left behind in /tmp + fi +} + +function get_outfile_prefix { + local name=$1 + # NOTE: mktemp on OSX doesn't support --tmpdir + mktemp -u "/tmp/$name.out.XXXXXXXXXX" +} + +function run_via_mvn { + local name=$1 + local input=$2 + local expected_hash=$3 + + local outfile_prefix="$(get_outfile_prefix "$name")" || exit 2 + local cmd='mvn exec:java -f pom.xml -pl examples \ + -Dexec.mainClass=com.google.cloud.dataflow.examples.WordCount \ + -Dexec.args="--runner=DirectPipelineRunner --inputFile='"$input"' --output='"$outfile_prefix"'"' + echo "$name: Running $cmd" >&2 + sh -c "$cmd" + check_result_hash "$name" "$outfile_prefix" "$expected_hash" +} + +function run_bundled { + local name=$1 + local input=$2 + local expected_hash=$3 + + local outfile_prefix="$(get_outfile_prefix "$name")" || exit 2 + local cmd='java -cp '"$JAR_FILE"' \ + com.google.cloud.dataflow.examples.WordCount \ + --runner=DirectPipelineRunner \ + --inputFile='"'$input'"' \ + --output='"$outfile_prefix" + echo "$name: Running $cmd" >&2 + sh -c "$cmd" + check_result_hash "$name" "$outfile_prefix" "$expected_hash" +} + +function run_all_ways { + local name=$1 + local input=$2 + local expected_hash=$3 + + run_via_mvn ${name}a "$input" $expected_hash + check_for_jar_file + run_bundled ${name}b "$input" $expected_hash +} + +function check_for_jar_file { + if [[ ! -f $JAR_FILE ]] + then + echo "Jar file $JAR_FILE not created" >&2 + exit 2 + fi +} + +run_all_ways wordcount1 "LICENSE" c5350a5ad4bb51e3e018612b4b044097 +run_all_ways wordcount2 "./LICENSE" c5350a5ad4bb51e3e018612b4b044097 +run_all_ways wordcount3 "$PWD/LICENSE" c5350a5ad4bb51e3e018612b4b044097 +run_all_ways wordcount4 "L*N?E*" c5350a5ad4bb51e3e018612b4b044097 +run_all_ways wordcount5 "./LICE*N?E" c5350a5ad4bb51e3e018612b4b044097 +run_all_ways wordcount6 "$PWD/*LIC?NSE" c5350a5ad4bb51e3e018612b4b044097 + +if [[ ! "$PASS" ]] +then + echo "One or more tests FAILED." + exit 1 +fi +echo "All tests PASS"
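+
+# Example invocations (illustrative; per the header comment above, the optional argument is the
+# relative path from this script's directory to the directory containing the parent pom.xml, and
+# it defaults to the script's own directory):
+#
+#   ./test_wordcount.sh        # parent pom.xml expected next to this script
+#   ./test_wordcount.sh ..     # parent pom.xml one directory above this script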