diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index bd361b773eff..cbfa0377fab4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,10 +1,17 @@ +DESCRIPTION HERE + +------------------------ + Follow this checklist to help us incorporate your contribution quickly and easily: - [ ] Make sure there is a [JIRA issue](https://issues.apache.org/jira/projects/BEAM/issues/) filed for the change (usually before you start working on it). Trivial changes like typos do not require a JIRA issue. Your pull request should address just this issue, without pulling in other changes. - - [ ] Each commit in the pull request should have a meaningful subject line and body. - [ ] Format the pull request title like `[BEAM-XXX] Fixes bug in ApproximateQuantiles`, where you replace `BEAM-XXX` with the appropriate JIRA issue. - - [ ] Write a pull request description that is detailed enough to understand what the pull request does, how, and why. + - [ ] Write a pull request description that is detailed enough to understand: + - [ ] What the pull request does + - [ ] Why it does it + - [ ] How it does it + - [ ] Why this approach + - [ ] Each commit in the pull request should have a meaningful subject line and body. - [ ] Run `mvn clean verify` to make sure basic checks pass. A more thorough check will be performed on your pull request automatically. - [ ] If this contribution is large, please file an Apache [Individual Contributor License Agreement](https://www.apache.org/licenses/icla.pdf). ---- diff --git a/.gitignore b/.gitignore index ff2faad05fd8..e81a4e379c83 100644 --- a/.gitignore +++ b/.gitignore @@ -37,12 +37,14 @@ sdks/python/LICENSE sdks/python/NOTICE sdks/python/README.md sdks/python/apache_beam/portability/api/*pb2*.* +sdks/python/apache_beam/portability/common_urns.py # Ignore IntelliJ files. **/.idea/**/* **/*.iml **/*.ipr **/*.iws +**/out/**/* # Ignore Eclipse files. **/.classpath diff --git a/.test-infra/jenkins/job_beam_PerformanceTests_JDBC.groovy b/.test-infra/jenkins/job_beam_PerformanceTests_JDBC.groovy index ef73a261b0c4..1e5131f3ddf5 100644 --- a/.test-infra/jenkins/job_beam_PerformanceTests_JDBC.groovy +++ b/.test-infra/jenkins/job_beam_PerformanceTests_JDBC.groovy @@ -19,45 +19,63 @@ import common_job_properties // This job runs the Beam performance tests on PerfKit Benchmarker. -job('beam_PerformanceTests_JDBC'){ +job('beam_PerformanceTests_JDBC') { // Set default Beam job properties. common_job_properties.setTopLevelMainJobProperties(delegate) // Run job in postcommit every 6 hours, don't trigger every push, and // don't email individual committers. 
common_job_properties.setPostCommit( - delegate, - '0 */6 * * *', - false, - 'commits@beam.apache.org', - false) + delegate, + '0 */6 * * *', + false, + 'commits@beam.apache.org', + false) + + common_job_properties.enablePhraseTriggeringFromPullRequest( + delegate, + 'Java JdbcIO Performance Test', + 'Run Java JdbcIO Performance Test') def pipelineArgs = [ - tempRoot: 'gs://temp-storage-for-end-to-end-tests', - project: 'apache-beam-testing', - postgresServerName: '10.36.0.11', - postgresUsername: 'postgres', - postgresDatabaseName: 'postgres', - postgresPassword: 'uuinkks', - postgresSsl: 'false' + tempRoot : 'gs://temp-storage-for-perf-tests', + project : 'apache-beam-testing', + postgresPort : '5432', + numberOfRecords: '5000000' ] - def pipelineArgList = [] - pipelineArgs.each({ - key, value -> pipelineArgList.add("--$key=$value") - }) - def pipelineArgsJoined = pipelineArgList.join(',') - - def argMap = [ - benchmarks: 'beam_integration_benchmark', - beam_it_module: 'sdks/java/io/jdbc', - beam_it_args: pipelineArgsJoined, - beam_it_class: 'org.apache.beam.sdk.io.jdbc.JdbcIOIT', - // Profile is located in $BEAM_ROOT/sdks/java/io/pom.xml. - beam_it_profile: 'io-it' + + def testArgs = [ + kubeconfig : '"$HOME/.kube/config"', + beam_it_timeout : '1800', + benchmarks : 'beam_integration_benchmark', + beam_it_profile : 'io-it', + beam_prebuilt : 'true', + beam_sdk : 'java', + beam_it_module : 'sdks/java/io/jdbc', + beam_it_class : 'org.apache.beam.sdk.io.jdbc.JdbcIOIT', + beam_it_options : joinPipelineOptions(pipelineArgs), + beam_kubernetes_scripts : makePathAbsolute('src/.test-infra/kubernetes/postgres/postgres.yml') + + ',' + makePathAbsolute('src/.test-infra/kubernetes/postgres/postgres-service-for-local-dev.yml'), + beam_options_config_file: makePathAbsolute('src/.test-infra/kubernetes/postgres/pkb-config-local.yml'), + bigquery_table : 'beam_performance.jdbcioit_pkb_results' ] - common_job_properties.buildPerformanceTest(delegate, argMap) + steps { + // create .kube/config file for perfkit (if not exists) + shell('gcloud container clusters get-credentials io-datastores --zone=us-central1-a --verbosity=debug') + } - // [BEAM-2141] Perf tests do not pass. - disabled() + common_job_properties.buildPerformanceTest(delegate, testArgs) } + +static String joinPipelineOptions(Map pipelineArgs) { + List pipelineArgList = [] + pipelineArgs.each({ + key, value -> pipelineArgList.add("\"--$key=$value\"") + }) + return "[" + pipelineArgList.join(',') + "]" +} + +static String makePathAbsolute(String path) { + return '"$WORKSPACE/' + path + '"' +} \ No newline at end of file diff --git a/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Flink.groovy b/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Flink.groovy index 5b228bc9cb64..a0a957acbab5 100644 --- a/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Flink.groovy +++ b/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Flink.groovy @@ -39,5 +39,5 @@ mavenJob('beam_PostCommit_Java_ValidatesRunner_Flink') { 'Run Flink ValidatesRunner') // Maven goals for this job. 
- goals('-B -e clean verify -am -pl runners/flink -Plocal-validates-runner-tests -Pvalidates-runner-tests') + goals('-B -e clean verify -am -pl runners/flink -Plocal-validates-runner-tests') } diff --git a/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Spark.groovy b/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Spark.groovy index 2ec4cd54142b..b4a0d029db5e 100644 --- a/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Spark.groovy +++ b/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Spark.groovy @@ -40,5 +40,5 @@ mavenJob('beam_PostCommit_Java_ValidatesRunner_Spark') { 'Run Spark ValidatesRunner') // Maven goals for this job. - goals('-B -e clean verify -am -pl runners/spark -Pvalidates-runner-tests -Plocal-validates-runner-tests -Dspark.ui.enabled=false') + goals('-B -e clean verify -am -pl runners/spark -Plocal-validates-runner-tests -Dspark.ui.enabled=false') } diff --git a/.test-infra/jenkins/job_beam_PostCommit_Python_ValidatesContainer_Dataflow.groovy b/.test-infra/jenkins/job_beam_PostCommit_Python_ValidatesContainer_Dataflow.groovy new file mode 100644 index 000000000000..5a76b03fa171 --- /dev/null +++ b/.test-infra/jenkins/job_beam_PostCommit_Python_ValidatesContainer_Dataflow.groovy @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import common_job_properties + +// This job runs the suite of Python ValidatesContainer tests against the +// Dataflow runner. +job('beam_PostCommit_Python_ValidatesContainer_Dataflow') { + description('Runs Python ValidatesContainer suite on the Dataflow runner.') + + // Set common parameters. + common_job_properties.setTopLevelMainJobProperties(delegate) + + // Sets that this is a PostCommit job. + common_job_properties.setPostCommit(delegate, '30 3 * * *', false) + + // Allows triggering this build against pull requests. + common_job_properties.enablePhraseTriggeringFromPullRequest( + delegate, + 'Google Cloud Dataflow Runner Python ValidatesContainer Tests', + 'Run Python Dataflow ValidatesContainer') + + // Allow the test to only run on particular nodes + // TODO(BEAM-1817): Remove once the tests can run on all nodes + parameters { + nodeParam('TEST_HOST') { + description('select test host as either beam1, 2 or 3') + defaultNodes(['beam1', 'beam2', 'beam3']) + allowedNodes(['beam1', 'beam2', 'beam3']) + trigger('multiSelectionDisallowed') + eligibility('IgnoreOfflineNodeEligibility') + } + } + + // Execute shell command to test Python SDK. 
+ steps { + shell('cd ' + common_job_properties.checkoutDir + ' && bash sdks/python/container/run_validatescontainer.sh') + } +} diff --git a/.test-infra/jenkins/job_beam_PostRelease_NightlySnapshot.groovy b/.test-infra/jenkins/job_beam_PostRelease_NightlySnapshot.groovy index 60abf9e6464c..1da9d2c75984 100644 --- a/.test-infra/jenkins/job_beam_PostRelease_NightlySnapshot.groovy +++ b/.test-infra/jenkins/job_beam_PostRelease_NightlySnapshot.groovy @@ -31,10 +31,10 @@ job('beam_PostRelease_NightlySnapshot') { parameters { stringParam('snapshot_version', - '2.3.0-SNAPSHOT', + '', 'Version of the repository snapshot to install') stringParam('snapshot_url', - 'https://repository.apache.org/content/repositories/snapshots', + '', 'Repository URL to install from') } @@ -42,11 +42,21 @@ job('beam_PostRelease_NightlySnapshot') { common_job_properties.setPostCommit( delegate, '0 11 * * *', - false, - 'dev@beam.apache.org') + false) + + + // Allows triggering this build against pull requests. + common_job_properties.enablePhraseTriggeringFromPullRequest( + delegate, + './gradlew :release:runQuickstartsJava', + 'Run Dataflow PostRelease') steps { - // Run a quickstart from https://beam.apache.org/get-started/quickstart-java/ - shell('cd ' + common_job_properties.checkoutDir + '/release && groovy quickstart-java-direct.groovy') + // Run a quickstart from https://beam.apache.org/get-started/quickstart-java + gradle { + rootBuildScriptDir(common_job_properties.checkoutDir) + tasks(':release:runQuickstartsJava') + switches('-Pver=$snapshot_version -Prepourl=$snapshot_url') + } } } diff --git a/.test-infra/jenkins/job_beam_PreCommit_Go_GradleBuild.groovy b/.test-infra/jenkins/job_beam_PreCommit_Go_GradleBuild.groovy new file mode 100644 index 000000000000..509a8f3857e1 --- /dev/null +++ b/.test-infra/jenkins/job_beam_PreCommit_Go_GradleBuild.groovy @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import common_job_properties + +// This is the Go precommit which runs a gradle build, and the current set +// of precommit tests. +job('beam_PreCommit_Go_GradleBuild') { + description('Runs Go PreCommit tests for the current GitHub Pull Request.') + + // Execute concurrent builds if necessary. + concurrentBuild() + + // Set common parameters. + common_job_properties.setTopLevelMainJobProperties( + delegate, + 'master', + 240) + + def gradle_switches = [ + // Gradle log verbosity enough to diagnose basic build issues + "--info", + // Continue the build even if there is a failure to show as many potential failures as possible. 
+ '--continue', + // Until we verify the build cache is working appropriately, force rerunning all tasks + '--rerun-tasks', + ] + + def gradle_command_line = './gradlew ' + gradle_switches.join(' ') + ' :goPreCommit' + // Sets that this is a PreCommit job. + common_job_properties.setPreCommit(delegate, gradle_command_line, 'Run Go Gradle PreCommit') + steps { + gradle { + rootBuildScriptDir(common_job_properties.checkoutDir) + tasks(':goPreCommit') + for (String gradle_switch : gradle_switches) { + switches(gradle_switch) + } + } + } +} diff --git a/.test-infra/jenkins/job_beam_PreCommit_Java_GradleBuild.groovy b/.test-infra/jenkins/job_beam_PreCommit_Java_GradleBuild.groovy index a9989da477ea..f5deef1b5521 100644 --- a/.test-infra/jenkins/job_beam_PreCommit_Java_GradleBuild.groovy +++ b/.test-infra/jenkins/job_beam_PreCommit_Java_GradleBuild.groovy @@ -21,7 +21,7 @@ import common_job_properties // This is the Java precommit which runs a Gradle build, and the current set // of precommit tests. job('beam_PreCommit_Java_GradleBuild') { - description('Runs a build of the current GitHub Pull Request.') + description('Runs Java PreCommit tests for the current GitHub Pull Request.') // Execute concurrent builds if necessary. concurrentBuild() @@ -38,6 +38,8 @@ job('beam_PreCommit_Java_GradleBuild') { } def gradle_switches = [ + // Gradle log verbosity enough to diagnose basic build issues + "--info", // Continue the build even if there is a failure to show as many potential failures as possible. '--continue', // Until we verify the build cache is working appropriately, force rerunning all tasks diff --git a/assembly.xml b/assembly.xml index 42442b8e90d4..6534e3f1f5ab 100644 --- a/assembly.xml +++ b/assembly.xml @@ -69,6 +69,7 @@ **/.gogradle/** **/build/** **/vendor/** + **/out/** **/.gradletasknamecache diff --git a/build.gradle b/build.gradle index ad7fa7ee7a4e..886da78412cb 100644 --- a/build.gradle +++ b/build.gradle @@ -66,7 +66,7 @@ ext.library = [ bigdataoss_util: "com.google.cloud.bigdataoss:util:$google_cloud_bigdataoss_version", bigtable_client_core: "com.google.cloud.bigtable:bigtable-client-core:$bigtable_version", bigtable_protos: "com.google.cloud.bigtable:bigtable-protos:$bigtable_proto_version", - byte_buddy: "net.bytebuddy:byte-buddy:1.6.8", + byte_buddy: "net.bytebuddy:byte-buddy:1.7.10", commons_compress: "org.apache.commons:commons-compress:1.14", commons_csv: "org.apache.commons:commons-csv:1.4", commons_io_1x: "commons-io:commons-io:1.3.2", @@ -83,7 +83,7 @@ ext.library = [ google_api_client_jackson2: "com.google.api-client:google-api-client-jackson2:$google_clients_version", google_api_client_java6: "com.google.api-client:google-api-client-java6:$google_clients_version", google_api_common: "com.google.api:api-common:1.0.0-rc2", - google_api_services_bigquery: "com.google.apis:google-api-services-bigquery:v2-rev355-$google_clients_version", + google_api_services_bigquery: "com.google.apis:google-api-services-bigquery:v2-rev374-$google_clients_version", google_api_services_clouddebugger: "com.google.apis:google-api-services-clouddebugger:v2-rev8-$google_clients_version", google_api_services_cloudresourcemanager: "com.google.apis:google-api-services-cloudresourcemanager:v1-rev6-$google_clients_version", google_api_services_dataflow: "com.google.apis:google-api-services-dataflow:v1b3-rev221-$google_clients_version", @@ -180,7 +180,7 @@ buildscript { classpath "gradle.plugin.org.nosphere.apache:creadur-rat-gradle:0.3.1" // Enable Apache license enforcement classpath 
"com.commercehub.gradle.plugin:gradle-avro-plugin:0.11.0" // Enable Avro code generation classpath "com.diffplug.spotless:spotless-plugin-gradle:3.6.0" // Enable a code formatting plugin - classpath "gradle.plugin.com.github.blindpirate:gogradle:0.7.0" // Enable Go code compilation + classpath "gradle.plugin.com.github.blindpirate:gogradle:0.8.1" // Enable Go code compilation classpath "gradle.plugin.com.palantir.gradle.docker:gradle-docker:0.13.0" // Enable building Docker containers classpath "cz.malohlava:visteg:1.0.3" // Enable generating Gradle task dependencies as ".dot" files classpath "com.github.jengelman.gradle.plugins:shadow:2.0.1" // Enable shading Java dependencies @@ -208,6 +208,7 @@ rat { // Exclude files generated by the Gradle build process "**/.gradle/**/*", "**/.gogradle/**/*", + "**/gogradle.lock", "**/build/**/*", "**/vendor/**/*", "**/.gradletasknamecache", @@ -241,6 +242,7 @@ rat { "**/*.iml", "**/*.ipr", "**/*.iws", + "**/out/**/*", // .gitignore: Ignore Eclipse files. "**/.classpath", @@ -284,8 +286,11 @@ rat { } check.dependsOn rat -// Define a root Java PreCommit task simplifying what is needed +// Define root PreCommit tasks simplifying what is needed // to be specified on the commandline when executing locally. +// This indirection also makes Jenkins use the branch of the PR +// for the test definitions. + def javaPreCommitRoots = [ ":sdks:java:core", ":runners:direct-java", @@ -298,3 +303,7 @@ task javaPreCommit() { } dependsOn ":examples:java:preCommit" } + +task goPreCommit() { + dependsOn ":sdks:go:test" +} diff --git a/build_rules.gradle b/build_rules.gradle index bf466c87a0c0..14a2d580bc54 100644 --- a/build_rules.gradle +++ b/build_rules.gradle @@ -38,7 +38,7 @@ println "Applying build_rules.gradle to $project.name" // We use the project.path as the group name to make this mapping unique since // we have a few projects with the same name. group = project.path -version = "2.3.0-SNAPSHOT" +version = "2.4.0-SNAPSHOT" // Define the default set of repositories for all builds. repositories { @@ -67,6 +67,9 @@ repositories { maven { url "https://repository.apache.org/content/repositories/releases" } } +// Apply a plugin which enables configuring projects imported into Intellij. +apply plugin: "idea" + // Provide code coverage // TODO: Should this only apply to Java projects? apply plugin: "jacoco" @@ -123,6 +126,7 @@ class JavaNatureConfiguration { double javaVersion = 1.8 // Controls the JDK source language and target compatibility boolean enableFindbugs = true // Controls whether the findbugs plugin is enabled and configured boolean enableShadow = true // Controls whether the shadow plugin is enabled and configured + String artifactId = null // Sets the maven publication artifact id } // Configures a project with a default set of plugins that should apply to all Java projects. @@ -156,16 +160,16 @@ class JavaNatureConfiguration { // Will become a runtime dependency of the generated Maven pom. // * testCompile - Required during compilation or runtime of the test source set. // This must be shaded away in the shaded test jar. -// * testShadow - Required during compilation or runtime of the test source set. +// * shadowTest - Required during compilation or runtime of the test source set. // TODO: Figure out whether this should be a test scope dependency // of the generated Maven pom. // // When creating a cross-project dependency between two Java projects, one should only rely on the shaded configurations. 
// This allows for compilation/test execution to occur against the final artifact that will be provided to users. -// This is by done by referencing the "shadow" or "testShadow" configuration as so: +// This is by done by referencing the "shadow" or "shadowTest" configuration as so: // dependencies { // shadow project(path: "other:java:project1", configuration: "shadow") -// testShadow project(path: "other:java:project2", configuration: "testShadow") +// shadowTest project(path: "other:java:project2", configuration: "shadowTest") // } // This will ensure the correct set of transitive dependencies from those projects are correctly added to the // main and test source set runtimes. @@ -174,7 +178,7 @@ ext.applyJavaNature = { println "applyJavaNature with " + (it ? "$it" : "default configuration") + " for project $project.name" // Use the implicit it parameter of the closure to handle zero argument or one argument map calls. JavaNatureConfiguration configuration = it ? it as JavaNatureConfiguration : new JavaNatureConfiguration() - apply plugin: "maven" + apply plugin: "maven-publish" apply plugin: "java" // Configure the Java compiler source language and target compatibility levels. Also ensure that @@ -215,7 +219,7 @@ ext.applyJavaNature = { // Note that these plugins specifically use the compileOnly and testCompileOnly // configurations because they are never required to be shaded or become a // dependency of the output. - def auto_value = "com.google.auto.value:auto-value:1.5.1" + def auto_value = "com.google.auto.value:auto-value:1.5.3" def auto_service = "com.google.auto.service:auto-service:1.0-rc2" compileOnly auto_value @@ -244,6 +248,9 @@ ext.applyJavaNature = { showViolations = true maxErrors = 0 } + checkstyle { + toolVersion = "8.7" + } // Apply the eclipse and apt-eclipse plugins. This adds the "eclipse" task and // connects the apt-eclipse plugin to update the eclipse project files @@ -255,15 +262,6 @@ ext.applyJavaNature = { // Enables a plugin which can apply code formatting to source. // TODO: Should this plugin be enabled for all projects? apply plugin: "com.diffplug.gradle.spotless" - spotless { - java { - target rootProject.fileTree(rootProject.rootDir) { - include 'sdks/java/**/*.java' - } - // Code formatting disabled because style rules are out of date. - // eclipse().configFile(rootProject.file('sdks/java/build-tools/src/main/resources/beam/beam-codestyle.xml')) - } - } // Enables a plugin which performs code analysis for common bugs. // This plugin is configured to only analyze the "main" source set. @@ -323,7 +321,27 @@ ext.applyJavaNature = { // Ensure that shaded classes are part of the artifact set. artifacts.archives shadowJar - // TODO: Figure out how to create ShadowJar task for testShadowJar here + if (configuration.artifactId) { + // If a publication artifact id is supplied, publish the shadow jar. + publishing { + publications { + mavenJava(MavenPublication) { + artifact(shadowJar) { + groupId "org.apache.beam" + artifactId configuration.artifactId + // Strip the "shaded" classifier. + classifier null + // Set readable name to project description. + pom.withXml { + asNode().appendNode('name', description) + } + } + } + } + } + } + + // TODO: Figure out how to create ShadowJar task for ShadowTestJar here // that is extendable within each sub-project with any additional includes. // This could mirror the "shadowJar" configuration block. 
// Optionally, we could also copy the shading configuration from the main @@ -354,12 +372,40 @@ ext.applyJavaNature = { // which chooses the latest version available. // // TODO: Figure out whether we should force all dependency conflict resolution - // to occur in the "shadow" and "testShadow" configurations. + // to occur in the "shadow" and "shadowTest" configurations. configurations.all { resolutionStrategy { force library.java.values() } } + + // These directories for when build actions are delegated to Gradle + def gradleAptGeneratedMain = "${project.buildDir}/generated/source/apt/main" + def gradleAptGeneratedTest = "${project.buildDir}/generated/source/apt/test" + + // These directories for when build actions are executed by Idea + // IntelliJ does not add these source roots (that it owns!) unless hinted + def ideaRoot = "${project.projectDir}/out" + def ideaAptGeneratedMain = "${ideaRoot}/production/classes/generated" + def ideaAptGeneratedTest = "${ideaRoot}/test/classes/generated_test" + + idea { + module { + sourceDirs += file(gradleAptGeneratedMain) + testSourceDirs += file(gradleAptGeneratedTest) + + sourceDirs += file(ideaAptGeneratedMain) + testSourceDirs += file(ideaAptGeneratedTest) + + generatedSourceDirs += [ + file(gradleAptGeneratedMain), + file(gradleAptGeneratedTest), + file(ideaAptGeneratedMain), + file(ideaAptGeneratedTest) + ] + + } + } } /*************************************************************************************************/ @@ -371,30 +417,26 @@ ext.applyGoNature = { goVersion = '1.9' } - // GoGradle fails in a parallel build during dependency resolution/installation. - // Force a dependency between all GoGradle projects during dependency resolution/installation. - // TODO: Figure out how to do this by automatically figuring out the task dependency DAG - // based upon task type. 
- List goProjects = [ - ":sdks:go", - ":runners:gcp:gcemd", - ":runners:gcp:gcsproxy", - ":sdks:python:container", - ":sdks:java:container", - ] - if (!goProjects.contains(project.path)) { - throw new GradleException(project.path + " has not been defined within the list of well known go projects within build_rules.gradle.") + repositories { + golang { + // Gogradle doesn't like thrift: https://github.com/gogradle/gogradle/issues/183 + root 'git.apache.org/thrift.git' + emptyDir() + } + golang { + root 'github.com/apache/thrift' + emptyDir() + } } - int index = goProjects.indexOf(project.path) - if (index != 0) { - String previous = goProjects.get(index - 1) - println "Forcing: '" + previous + "' to be evaulated before '" + project.path + "'" - evaluationDependsOn(previous) - afterEvaluate { - println "Forcing: '" + project.path + ":resolveBuildDependencies' must run after '" + previous + ":installDependencies'" - tasks.getByPath(project.path + ":resolveBuildDependencies").mustRunAfter tasks.getByPath(previous + ":installDependencies") - println "Forcing: '" + project.path + ":resolveTestDependencies' must run after '" + previous + ":installDependencies'" - tasks.getByPath(project.path + ":resolveTestDependencies").mustRunAfter tasks.getByPath(previous + ":installDependencies") + + idea { + module { + // The gogradle plugin downloads all dependencies into the source tree here, + // which is a path baked into golang + excludeDirs += file("${project.path}/vendor") + + // gogradle's private working directory + excludeDirs += file("${project.path}/.gogradle") } } } @@ -440,6 +482,26 @@ ext.applyGrpcNature = { } } } + + def generatedProtoMainJavaDir = "${project.buildDir}/generated/source/proto/main/java" + def generatedProtoTestJavaDir = "${project.buildDir}/generated/source/proto/test/java" + def generatedGrpcMainJavaDir = "${project.buildDir}/generated/source/proto/main/grpc" + def generatedGrpcTestJavaDir = "${project.buildDir}/generated/source/proto/test/grpc" + idea { + module { + sourceDirs += file(generatedProtoMainJavaDir) + generatedSourceDirs += file(generatedProtoMainJavaDir) + + testSourceDirs += file(generatedProtoTestJavaDir) + generatedSourceDirs += file(generatedProtoTestJavaDir) + + sourceDirs += file(generatedGrpcMainJavaDir) + generatedSourceDirs += file(generatedGrpcMainJavaDir) + + testSourceDirs += file(generatedGrpcTestJavaDir) + generatedSourceDirs += file(generatedGrpcTestJavaDir) + } + } } /*************************************************************************************************/ @@ -451,4 +513,40 @@ ext.applyAvroNature = { apply plugin: "com.commercehub.gradle.plugin.avro" } +// A class defining the set of configurable properties for createJavaQuickstartValidationTask +class JavaQuickstartConfiguration { + // Name for the quickstart is required. + // Used both for the test name runQuickstartJava${name} + // and also for the script name, quickstart-java-${name}.toLowerCase(). + String name + + // gcpProject sets the gcpProject argument when executing the quickstart. + String gcpProject + + // gcsBucket sets the gcsProject argument when executing the quickstart. + String gcsBucket +} +// Creates a task to run the quickstart for a runner. 
+// Releases version and URL, can be overriden for a RC release with +// ./gradlew :release:runQuickstartJava -Pver=2.3.0 -Prepourl=https://repository.apache.org/content/repositories/orgapachebeam-1027 +ext.createJavaQuickstartValidationTask = { + JavaQuickstartConfiguration config = it as JavaQuickstartConfiguration + def taskName = "runQuickstartJava${config.name}" + println "Generating :${taskName}" + def releaseVersion = project.findProperty('ver') ?: version + def releaseRepo = project.findProperty('repourl') ?: 'https://repository.apache.org/content/repositories/snapshots' + def argsNeeded = ["--ver=${releaseVersion}", "--repourl=${releaseRepo}"] + if (config.gcpProject) { + argsNeeded.add("--gcpProject=${config.gcpProject}") + } + if (config.gcsBucket) { + argsNeeded.add("--gcsBucket=${config.gcsBucket}") + } + project.evaluationDependsOn(':release') + task "${taskName}" (dependsOn: ':release:classes', type: JavaExec) { + main = "quickstart-java-${config.name}".toLowerCase() + classpath = project(':release').sourceSets.main.runtimeClasspath + args argsNeeded + } +} diff --git a/examples/java/pom.xml b/examples/java/pom.xml index 750485768dff..46f1f8c6101c 100644 --- a/examples/java/pom.xml +++ b/examples/java/pom.xml @@ -511,7 +511,12 @@ org.hamcrest - hamcrest-all + hamcrest-core + + + + org.hamcrest + hamcrest-library @@ -539,7 +544,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java b/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java index f953b1339da9..ad1bd0cf81f5 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java +++ b/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java @@ -254,13 +254,16 @@ private static class AllPrefixes extends DoFn> { private final int minPrefix; private final int maxPrefix; + public AllPrefixes(int minPrefix) { this(minPrefix, Integer.MAX_VALUE); } + public AllPrefixes(int minPrefix, int maxPrefix) { this.minPrefix = minPrefix; this.maxPrefix = maxPrefix; } + @ProcessElement public void processElement(ProcessContext c) { String word = c.element().value; diff --git a/examples/java/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToBigQuery.java b/examples/java/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToBigQuery.java index 2ec4e5c9a133..c1b3019b0dab 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToBigQuery.java +++ b/examples/java/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToBigQuery.java @@ -91,6 +91,7 @@ FieldFn getFieldFn() { return this.fieldFn; } } + /** Convert each key/score pair into a BigQuery TableRow as specified by fieldFn. 
*/ protected class BuildRowFn extends DoFn { diff --git a/examples/java/src/main/java/org/apache/beam/examples/cookbook/TriggerExample.java b/examples/java/src/main/java/org/apache/beam/examples/cookbook/TriggerExample.java index b62e23f3794e..6ec1702f2994 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/cookbook/TriggerExample.java +++ b/examples/java/src/main/java/org/apache/beam/examples/cookbook/TriggerExample.java @@ -367,6 +367,7 @@ static class FormatTotalFlow extends DoFn, TableRow> { public FormatTotalFlow(String triggerType) { this.triggerType = triggerType; } + @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) throws Exception { String[] values = c.element().getValue().split(","); diff --git a/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java b/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java index e2853da5ae04..55cbde1b728e 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java +++ b/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java @@ -17,10 +17,29 @@ */ package org.apache.beam.examples.snippets; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.AvroCoder; +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; +import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord; +import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.join.CoGbkResult; import org.apache.beam.sdk.transforms.join.CoGroupByKey; @@ -28,13 +47,321 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.sdk.values.ValueInSingleWindow; /** * Code snippets used in webdocs. 
*/ public class Snippets { - /* Helper function to format results in coGroupByKeyTuple */ + @DefaultCoder(AvroCoder.class) + static class Quote { + final String source; + final String quote; + + public Quote() { + this.source = ""; + this.quote = ""; + } + + public Quote(String source, String quote) { + this.source = source; + this.quote = quote; + } + } + + @DefaultCoder(AvroCoder.class) + static class WeatherData { + final long year; + final long month; + final long day; + final double maxTemp; + + public WeatherData() { + this.year = 0; + this.month = 0; + this.day = 0; + this.maxTemp = 0.0f; + } + + public WeatherData(long year, long month, long day, double maxTemp) { + this.year = year; + this.month = month; + this.day = day; + this.maxTemp = maxTemp; + } + } + + /** Using a Read and Write transform to read/write from/to BigQuery. */ + public static void modelBigQueryIO(Pipeline p) { + modelBigQueryIO(p, "", "", ""); + } + + public static void modelBigQueryIO(Pipeline p, + String writeProject, String writeDataset, String writeTable) { + { + // [START BigQueryTableSpec] + String tableSpec = "clouddataflow-readonly:samples.weather_stations"; + // [END BigQueryTableSpec] + } + + { + // [START BigQueryTableSpecWithoutProject] + String tableSpec = "samples.weather_stations"; + // [END BigQueryTableSpecWithoutProject] + } + + { + // [START BigQueryTableSpecObject] + TableReference tableSpec = new TableReference() + .setProjectId("clouddataflow-readonly") + .setDatasetId("samples") + .setTableId("weather_stations"); + // [END BigQueryTableSpecObject] + } + + { + String tableSpec = "clouddataflow-readonly:samples.weather_stations"; + // [START BigQueryReadTable] + PCollection maxTemperatures = p + .apply(BigQueryIO.readTableRows().from(tableSpec)) + // Each row is of type TableRow + .apply(MapElements.into(TypeDescriptors.doubles()).via( + (TableRow row) -> (Double) row.get("max_temperature"))); + // [END BigQueryReadTable] + } + + { + String tableSpec = "clouddataflow-readonly:samples.weather_stations"; + // [START BigQueryReadFunction] + PCollection maxTemperatures = p + .apply(BigQueryIO.read( + (SchemaAndRecord elem) -> (Double) elem.getRecord().get("max_temperature")) + .from(tableSpec) + .withCoder(DoubleCoder.of())); + // [END BigQueryReadFunction] + } + + { + // [START BigQueryReadQuery] + PCollection maxTemperatures = p + .apply(BigQueryIO.read( + (SchemaAndRecord elem) -> (Double) elem.getRecord().get("max_temperature")) + .fromQuery( + "SELECT max_temperature FROM [clouddataflow-readonly:samples.weather_stations]") + .withCoder(DoubleCoder.of())); + // [END BigQueryReadQuery] + } + + { + // [START BigQueryReadQueryStdSQL] + PCollection maxTemperatures = p + .apply(BigQueryIO.read( + (SchemaAndRecord elem) -> (Double) elem.getRecord().get("max_temperature")) + .fromQuery( + "SELECT max_temperature FROM `clouddataflow-readonly.samples.weather_stations`") + .usingStandardSql() + .withCoder(DoubleCoder.of())); + // [END BigQueryReadQueryStdSQL] + } + + // [START BigQuerySchemaJson] + String tableSchemaJson = "" + + "{" + + " \"fields\": [" + + " {" + + " \"name\": \"source\"," + + " \"type\": \"STRING\"," + + " \"mode\": \"NULLABLE\"" + + " }," + + " {" + + " \"name\": \"quote\"," + + " \"type\": \"STRING\"," + + " \"mode\": \"REQUIRED\"" + + " }" + + " ]" + + "}"; + // [END BigQuerySchemaJson] + + { + String tableSpec = "clouddataflow-readonly:samples.weather_stations"; + if (!writeProject.isEmpty() && !writeDataset.isEmpty() && !writeTable.isEmpty()) { + tableSpec = writeProject + 
":" + writeDataset + "." + writeTable; + } + + // [START BigQuerySchemaObject] + TableSchema tableSchema = new TableSchema().setFields(ImmutableList.of( + new TableFieldSchema().setName("source").setType("STRING").setMode("NULLABLE"), + new TableFieldSchema().setName("quote").setType("STRING").setMode("REQUIRED"))); + // [END BigQuerySchemaObject] + + // [START BigQueryWriteInput] + /* + @DefaultCoder(AvroCoder.class) + static class Quote { + final String source; + final String quote; + + public Quote() { + this.source = ""; + this.quote = ""; + } + public Quote(String source, String quote) { + this.source = source; + this.quote = quote; + } + } + */ + + PCollection quotes = p + .apply(Create.of( + new Quote("Mahatma Gandhi", "My life is my message."), + new Quote("Yoda", "Do, or do not. There is no 'try'.") + )); + // [END BigQueryWriteInput] + + // [START BigQueryWriteTable] + quotes + .apply(MapElements.into(TypeDescriptor.of(TableRow.class)).via( + (Quote elem) -> new TableRow().set("source", elem.source).set("quote", elem.quote) + )) + .apply(BigQueryIO.writeTableRows() + .to(tableSpec) + .withSchema(tableSchema) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE)); + // [END BigQueryWriteTable] + + // [START BigQueryWriteFunction] + quotes.apply(BigQueryIO.write() + .to(tableSpec) + .withSchema(tableSchema) + .withFormatFunction( + (Quote elem) -> new TableRow().set("source", elem.source).set("quote", elem.quote)) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE)); + // [END BigQueryWriteFunction] + + // [START BigQueryWriteJsonSchema] + quotes.apply(BigQueryIO.write() + .to(tableSpec) + .withJsonSchema(tableSchemaJson) + .withFormatFunction( + (Quote elem) -> new TableRow().set("source", elem.source).set("quote", elem.quote)) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE)); + // [END BigQueryWriteJsonSchema] + } + + { + // [START BigQueryWriteDynamicDestinations] + /* + @DefaultCoder(AvroCoder.class) + static class WeatherData { + final long year; + final long month; + final long day; + final double maxTemp; + + public WeatherData() { + this.year = 0; + this.month = 0; + this.day = 0; + this.maxTemp = 0.0f; + } + public WeatherData(long year, long month, long day, double maxTemp) { + this.year = year; + this.month = month; + this.day = day; + this.maxTemp = maxTemp; + } + } + */ + + PCollection weatherData = p + .apply(BigQueryIO.read( + (SchemaAndRecord elem) -> { + GenericRecord record = elem.getRecord(); + return new WeatherData( + (Long) record.get("year"), + (Long) record.get("month"), + (Long) record.get("day"), + (Double) record.get("max_temperature")); + }) + .fromQuery("SELECT year, month, day, max_temperature " + + "FROM [clouddataflow-readonly:samples.weather_stations] " + + "WHERE year BETWEEN 2007 AND 2009") + .withCoder(AvroCoder.of(WeatherData.class))); + + // We will send the weather data into different tables for every year. 
+ weatherData.apply(BigQueryIO.write() + .to(new DynamicDestinations() { + @Override + public Long getDestination(ValueInSingleWindow elem) { + return elem.getValue().year; + } + + @Override + public TableDestination getTable(Long destination) { + return new TableDestination( + new TableReference() + .setProjectId(writeProject) + .setDatasetId(writeDataset) + .setTableId(writeTable + "_" + destination), + "Table for year " + destination); + } + + @Override + public TableSchema getSchema(Long destination) { + return new TableSchema().setFields(ImmutableList.of( + new TableFieldSchema().setName("year").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("month").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("day").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("maxTemp").setType("FLOAT").setMode("NULLABLE"))); + } + }) + .withFormatFunction( + (WeatherData elem) -> new TableRow() + .set("year", elem.year) + .set("month", elem.month) + .set("day", elem.day) + .set("maxTemp", elem.maxTemp)) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE)); + // [END BigQueryWriteDynamicDestinations] + + String tableSpec = "clouddataflow-readonly:samples.weather_stations"; + if (!writeProject.isEmpty() && !writeDataset.isEmpty() && !writeTable.isEmpty()) { + tableSpec = writeProject + ":" + writeDataset + "." + writeTable + "_partitioning"; + } + + TableSchema tableSchema = new TableSchema().setFields(ImmutableList.of( + new TableFieldSchema().setName("year").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("month").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("day").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("maxTemp").setType("FLOAT").setMode("NULLABLE"))); + + // [START BigQueryTimePartitioning] + weatherData.apply(BigQueryIO.write() + .to(tableSpec + "_partitioning") + .withSchema(tableSchema) + .withFormatFunction( + (WeatherData elem) -> new TableRow() + .set("year", elem.year) + .set("month", elem.month) + .set("day", elem.day) + .set("maxTemp", elem.maxTemp)) + // NOTE: an existing table without time partitioning set up will not work + .withTimePartitioning(new TimePartitioning().setType("DAY")) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE)); + // [END BigQueryTimePartitioning] + } + } + + /** Helper function to format results in coGroupByKeyTuple. */ public static String formatCoGbkResults( String name, Iterable emails, Iterable phones) { @@ -55,6 +382,7 @@ public static String formatCoGbkResults( return name + "; " + emailsStr + "; " + phonesStr; } + /** Using a CoGroupByKey transform. 
*/ public static PCollection coGroupByKeyTuple( TupleTag emailsTag, TupleTag phonesTag, @@ -63,9 +391,10 @@ public static PCollection coGroupByKeyTuple( // [START CoGroupByKeyTuple] PCollection> results = - KeyedPCollectionTuple.of(emailsTag, emails) - .and(phonesTag, phones) - .apply(CoGroupByKey.create()); + KeyedPCollectionTuple + .of(emailsTag, emails) + .and(phonesTag, phones) + .apply(CoGroupByKey.create()); PCollection contactLines = results.apply(ParDo.of( new DoFn, String>() { diff --git a/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java b/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java index 67fc8e8b3b90..f5074ac95171 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java +++ b/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java @@ -37,7 +37,6 @@ import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.StreamingOptions; -import org.apache.beam.sdk.testing.FileChecksumMatcher; import org.apache.beam.sdk.testing.SerializableMatcher; import org.apache.beam.sdk.testing.StreamingIT; import org.apache.beam.sdk.testing.TestPipeline; @@ -59,8 +58,6 @@ import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** End-to-end integration test of {@link WindowedWordCount}. */ @RunWith(JUnit4.class) @@ -194,8 +191,6 @@ private void testWindowedWordCountPipeline(WindowedWordCountITOptions options) t private static class WordCountsMatcher extends TypeSafeMatcher implements SerializableMatcher { - private static final Logger LOG = LoggerFactory.getLogger(FileChecksumMatcher.class); - private final SortedMap expectedWordCounts; private final List outputFiles; private SortedMap actualCounts; diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest.java index fb86a1ab7cc7..9e5f2095eb60 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest.java @@ -47,8 +47,6 @@ @RunWith(JUnit4.class) public class StatefulTeamScoreTest { - private static final Duration ALLOWED_LATENESS = Duration.standardHours(1); - private static final Duration TEAM_WINDOW_DURATION = Duration.standardMinutes(20); private Instant baseTime = new Instant(0); @Rule diff --git a/examples/java/src/test/java/org/apache/beam/examples/snippets/SnippetsTest.java b/examples/java/src/test/java/org/apache/beam/examples/snippets/SnippetsTest.java index 7ebe1625cbf2..d446d4ee5313 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/snippets/SnippetsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/snippets/SnippetsTest.java @@ -22,6 +22,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -34,6 +37,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; + /** * Tests for 
Snippets. */ @@ -41,7 +45,32 @@ public class SnippetsTest implements Serializable { @Rule - public transient TestPipeline p = TestPipeline.create(); + public final transient TestPipeline p = TestPipeline.create(); + + @Test + public void testModelBigQueryIO() { + // We cannot test BigQueryIO functionality in unit tests, therefore we limit ourselves + // to making sure the pipeline containing BigQuery sources and sinks can be built. + // + // To run locally, set `runLocally` to `true`. You will have to set `project`, `dataset` and + // `table` to the BigQuery table the test will write into. + boolean runLocally = false; + if (runLocally) { + String project = "my-project"; + String dataset = "samples"; // this must already exist + String table = "modelBigQueryIO"; // this will be created if needed + + BigQueryOptions options = PipelineOptionsFactory.create().as(BigQueryOptions.class); + options.setProject(project); + options.setTempLocation("gs://" + project + "/samples/temp/"); + Pipeline p = Pipeline.create(options); + Snippets.modelBigQueryIO(p, project, dataset, table); + p.run(); + } else { + Pipeline p = Pipeline.create(); + Snippets.modelBigQueryIO(p); + } + } /* Tests CoGroupByKeyTuple */ @Test @@ -64,8 +93,8 @@ public void testCoGroupByKeyTuple() throws IOException { // [END CoGroupByKeyTupleInputs] // [START CoGroupByKeyTupleOutputs] - final TupleTag emailsTag = new TupleTag(); - final TupleTag phonesTag = new TupleTag(); + final TupleTag emailsTag = new TupleTag<>(); + final TupleTag phonesTag = new TupleTag<>(); final List> expectedResults = Arrays.asList( KV.of("amy", CoGbkResult @@ -95,7 +124,7 @@ public void testCoGroupByKeyTuple() throws IOException { // Make sure that both 'expectedResults' and 'actualFormattedResults' match with the // 'formattedResults'. 'expectedResults' will have to be formatted before comparing - List expectedFormattedResultsList = new ArrayList<>(expectedResults.size()); + List expectedFormattedResultsList = new ArrayList(expectedResults.size()); for (KV e : expectedResults) { String name = e.getKey(); Iterable emailsIter = e.getValue().getAll(emailsTag); diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index ed88a042a287..c44b679acd3f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 40b39ed67b77..cee0fbd5dbae 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -19,4 +19,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-4.2.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-4.5.1-bin.zip diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto index 16f7709982d1..28c755950249 100644 --- a/model/fn-execution/src/main/proto/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/beam_fn_api.proto @@ -278,6 +278,7 @@ message Metrics { // the actual watermarks. map watermarks = 3; + repeated User user = 4; // TODO: Define other transform level system metrics. } @@ -285,10 +286,7 @@ message Metrics { message User { // A key for identifying a metric at the most granular level. - message MetricKey { - // The step, if any, this metric is associated with. 
- string step = 1; - + message MetricName { // (Required): The namespace of this metric. string namespace = 2; @@ -309,18 +307,24 @@ message Metrics { int64 max = 4; } + // Data associated with a Gauge metric. + message GaugeData { + int64 value = 1; + google.protobuf.Timestamp timestamp = 2; + } + // (Required) The identifier for this metric. - MetricKey key = 1; + MetricName metric_name = 1; // (Required) The data for this metric. oneof data { CounterData counter_data = 1001; DistributionData distribution_data = 1002; + GaugeData gauge_data = 1003; } } map ptransforms = 1; - repeated User user = 2; } message ProcessBundleProgressResponse { diff --git a/model/pipeline/src/main/resources/org/apache/beam/model/common_urns.md b/model/pipeline/src/main/resources/org/apache/beam/model/common_urns.md new file mode 100644 index 000000000000..6e86f4e2af15 --- /dev/null +++ b/model/pipeline/src/main/resources/org/apache/beam/model/common_urns.md @@ -0,0 +1,134 @@ + + +# Apache Beam URNs + +This file serves as a central place to enumerate and document the various +URNs used in the Beam portability APIs. + + +## Core Transforms + +### urn:beam:transform:pardo:v1 + +TODO(BEAM-3595): Change this to beam:transform:pardo:v1. + +Represents Beam's parallel do operation. + +Payload: A serialized ParDoPayload proto. + +### beam:transform:group_by_key:v1 + +Represents Beam's group-by-key operation. + +Payload: None + +### beam:transform:window_into:v1 + +Payload: A windowing strategy id. + +### beam:transform:flatten:v1 + +### beam:transform:read:v1 + +### beam:transform:impulse:v1 + +## Combining + +If any of the combine operations are produced by an SDK, it is assumed that +the SDK understands the last three combine helper operations. + +### beam:transform:combine_globally:v1 + +### beam:transform:combine_per_key:v1 + +### beam:transform:combine_grouped_values:v1 + +### beam:transform:combine_pgbkcv:v1 + +### beam:transform:combine_merge_accumulators:v1 + +### beam:transform:combine_extract_outputs:v1 + + +## Other common transforms + +### beam:transform:reshuffle:v1 + +### beam:transform:map_windows:v1 + +## WindowFns + +### beam:windowfn:global_windows:v0.1 + +TODO(BEAM-3595): Change this to beam:windowfn:global_windows:v1 + +### beam:windowfn:fixed_windows:v0.1 + +TODO(BEAM-3595): Change this to beam:windowfn:fixed_windows:v1 + +### beam:windowfn:sliding_windows:v0.1 + +TODO(BEAM-3595): Change this to beam:windowfn:sliding_windows:v1 + +### beam:windowfn:session_windows:v0.1 + +TODO(BEAM-3595): Change this to beam:windowfn:session_windows:v1 + + +## Coders + +### beam:coder:bytes:v1 + +Components: None + +### beam:coder:varint:v1 + +Components: None + +### beam:coder:kv:v1 + +Components: The key and value coder, in that order. + +### beam:coder:iterable:v1 + +Encodes an iterable of elements. + +Components: Coder for a single element. + +## Internal coders + +The following coders are typically not specified by manually by the user, +but are used at runtime and must be supported by every SDK. 
+ +### beam:coder:length_prefix:v1 + +### beam:coder:global_window:v1 + +### beam:coder:interval_window:v1 + +### beam:coder:windowed_value:v1 + + +## Side input access + +### beam:side_input:iterable:v1 + +### beam:side_input:multimap:v1 + diff --git a/pom.xml b/pom.xml index a5f27d0e464a..004f47057a19 100644 --- a/pom.xml +++ b/pom.xml @@ -109,7 +109,7 @@ 1.0.0-rc2 2.33 1.8.2 - v2-rev355-1.22.0 + v2-rev374-1.22.0 1.0.0 1.0.0-pre3 v1-rev6-1.22.0 @@ -121,7 +121,7 @@ 1.3.0 1.0.0-rc2 1.0-rc2 - 1.5.1 + 1.5.3 0.7.1 1.22.0 1.4.5 @@ -162,12 +162,15 @@ 0.12 1.5.0.Final 2.0 + 1.14 2.20.1 2.20.1 3.7.0 3.0.2 3.0.0-M1 + 1.0-beta-7 1.6.0 + 1.6 3.0.2 3.0.0-M1 1.14 @@ -175,6 +178,8 @@ 3.1.0 0.4 3.1.0 + 2.8.2 + 2.2.0 -Werror -Xpkginfo:always @@ -313,6 +318,7 @@ org.apache.maven.plugins maven-gpg-plugin + ${maven-gpg-plugin.version} sign-release-artifacts @@ -353,6 +359,38 @@ + + errorprone + + + + + org.apache.maven.plugins + maven-compiler-plugin + + javac-with-errorprone + true + true + + + + org.codehaus.plexus + plexus-compiler-javac-errorprone + ${plexus-compiler-java-errorprone.version} + + + + com.google.errorprone + error_prone_core + ${error_prone_core.version} + + + + + + + build-containers @@ -379,6 +417,61 @@ + + java-9 + + 9 + + + + true + true + true + + --add-modules java.base + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 9 + true + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + enforce + + enforce + + + + + 9 + + + + + + + + + + + javax.annotation + javax.annotation-api + 1.3.1 + + + @@ -416,26 +509,32 @@ org.apache.beam - beam-sdks-java-extensions-join-library + beam-sdks-java-extensions-google-cloud-platform-core ${project.version} org.apache.beam - beam-sdks-java-extensions-protobuf + beam-sdks-java-extensions-google-cloud-platform-core + tests ${project.version} org.apache.beam - beam-sdks-java-extensions-google-cloud-platform-core + beam-sdks-java-extensions-join-library ${project.version} org.apache.beam - beam-sdks-java-extensions-google-cloud-platform-core - tests + beam-sdks-java-extensions-json-jackson + ${project.version} + + + + org.apache.beam + beam-sdks-java-extensions-protobuf ${project.version} @@ -515,24 +614,24 @@ org.apache.beam - beam-sdks-java-io-elasticsearch-tests-common + beam-sdks-java-io-elasticsearch-tests-2 ${project.version} test - tests org.apache.beam - beam-sdks-java-io-elasticsearch-tests-2 + beam-sdks-java-io-elasticsearch-tests-5 ${project.version} test org.apache.beam - beam-sdks-java-io-elasticsearch-tests-5 + beam-sdks-java-io-elasticsearch-tests-common ${project.version} test + tests @@ -560,6 +659,12 @@ ${project.version} + + org.apache.beam + beam-sdks-java-io-hadoop-input-format + ${project.version} + + org.apache.beam beam-sdks-java-io-hbase @@ -608,6 +713,12 @@ ${project.version} + + org.apache.beam + beam-sdks-java-io-redis + ${project.version} + + org.apache.beam beam-sdks-java-io-solr @@ -616,7 +727,13 @@ org.apache.beam - beam-sdks-java-io-hadoop-input-format + beam-sdks-java-io-tika + ${project.version} + + + + org.apache.beam + beam-sdks-java-io-xml ${project.version} @@ -1229,7 +1346,7 @@ net.bytebuddy byte-buddy - 1.6.8 + 1.7.10 @@ -1312,7 +1429,7 @@ org.hamcrest - hamcrest-all + hamcrest-library ${hamcrest.version} test @@ -1349,7 +1466,7 @@ org.mockito - mockito-all + mockito-core ${mockito.version} test @@ -1411,12 +1528,12 @@ org.apache.maven.plugins maven-checkstyle-plugin - 2.17 + 3.0.0 com.puppycrawl.tools checkstyle - 6.19 + 8.7 org.apache.beam @@ -1431,6 +1548,13 @@ true false true + + + 
src/main/java + + + src/test/java + @@ -1489,6 +1612,7 @@ ${compiler.default.pkginfo.flag} + -parameters ${compiler.default.exclude} @@ -1589,6 +1713,7 @@ true + **/gogradle.lock .github/**/* @@ -1674,6 +1799,7 @@ false false true + random ${beamSurefireArgline} @@ -1943,6 +2069,7 @@ **/sdks/python/NOTICE **/sdks/python/README.md **/sdks/python/apache_beam/portability/api/*pb2*.* + **/sdks/python/apache_beam/portability/common_urns.py **/sdks/python/**/*.c **/sdks/python/**/*.so **/sdks/python/**/*.egg @@ -2056,6 +2183,14 @@ 1.8 + + + module-info + go -> go.out + // \ + // -> py -> py.out + // read.out can't be fused with both 'go' and 'py', so we should refuse to create this stage + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms( + "read", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .build()) + .putPcollections( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "goTransform", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "go.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections("go.out", PCollection.newBuilder().setUniqueName("go.out").build()) + .putTransforms( + "pyTransform", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "py.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("py.out", PCollection.newBuilder().setUniqueName("py.out").build()) + .putEnvironments("go", Environment.newBuilder().setUrl("go").build()) + .putEnvironments("py", Environment.newBuilder().setUrl("py").build()) + .build()); + Set differentEnvironments = + p.getPerElementConsumers( + PipelineNode.pCollection( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build())); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("go"); + thrown.expectMessage("py"); + thrown.expectMessage("same"); + GreedilyFusedExecutableStage.forGrpcPortRead( + p, + PipelineNode.pCollection( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build()), + differentEnvironments); + } + + @Test + public void noEnvironmentThrows() { + // (impulse.out) -> runnerTransform -> gbk.out + // runnerTransform can't be executed in an environment, so trying to construct it should fail + PTransform gbkTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .setSpec( + FunctionSpec.newBuilder().setUrn(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN)) + .putOutputs("output", "gbk.out") + .build(); + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("runnerTransform", gbkTransform) + .putPcollections( + "gbk.out", PCollection.newBuilder().setUniqueName("gbk.out").build()) + .build()); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Environment must be populated"); + GreedilyFusedExecutableStage.forGrpcPortRead( + p, + impulseOutputNode, + ImmutableSet.of(PipelineNode.pTransform("runnerTransform", gbkTransform))); + } + 
+ @Test + public void fusesCompatibleEnvironments() { + // (impulse.out) -> parDo -> parDo.out -> window -> window.out + // parDo and window both have the environment "common" and can be fused together + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform windowTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "window.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms("window", windowTransform) + .putPcollections( + "window.out", PCollection.newBuilder().setUniqueName("window.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .build()); + + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, + impulseOutputNode, + ImmutableSet.of( + PipelineNode.pTransform("parDo", parDoTransform), + PipelineNode.pTransform("window", windowTransform))); + // Nothing consumes the outputs of ParDo or Window, so they don't have to be materialized + assertThat(subgraph.getOutputPCollections(), emptyIterable()); + assertThat( + subgraph.toPTransform().getSubtransformsList(), containsInAnyOrder("parDo", "window")); + } + + @Test + public void materializesWithStatefulConsumer() { + // (impulse.out) -> parDo -> (parDo.out) + // (parDo.out) -> stateful -> stateful.out + // stateful has a state spec which prevents it from fusing with an upstream ParDo + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform statefulTransform = + PTransform.newBuilder() + .putInputs("input", "parDo.out") + .putOutputs("output", "stateful.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .putStateSpecs("state", StateSpec.getDefaultInstance()) + .build() + .toByteString())) + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms("stateful", statefulTransform) + .putPcollections( + "stateful.out", PCollection.newBuilder().setUniqueName("stateful.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .build()); + + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + 
p, + impulseOutputNode, + ImmutableSet.of(PipelineNode.pTransform("parDo", parDoTransform))); + assertThat( + subgraph.getOutputPCollections(), + contains( + PipelineNode.pCollection( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()))); + assertThat( + subgraph.toPTransform().getSubtransformsList(), containsInAnyOrder("parDo")); + } + + @Test + public void materializesWithConsumerWithTimer() { + // (impulse.out) -> parDo -> (parDo.out) + // (parDo.out) -> timer -> timer.out + // timer has a timer spec which prevents it from fusing with an upstream ParDo + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform timerTransform = + PTransform.newBuilder() + .putInputs("input", "parDo.out") + .putOutputs("output", "timer.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .putTimerSpecs("timer", TimerSpec.getDefaultInstance()) + .build() + .toByteString())) + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms("timer", timerTransform) + .putPcollections( + "timer.out", PCollection.newBuilder().setUniqueName("timer.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .build()); + + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, + impulseOutputNode, + ImmutableSet.of(PipelineNode.pTransform("parDo", parDoTransform))); + assertThat( + subgraph.getOutputPCollections(), + contains( + PipelineNode.pCollection( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()))); + assertThat( + subgraph.toPTransform().getSubtransformsList(), containsInAnyOrder("parDo")); + } + + @Test + public void fusesFlatten() { + // (impulse.out) -> parDo -> parDo.out --> flatten -> flatten.out -> window -> window.out + // \ / + // -> read -> read.out - + // The flatten can be executed within the same environment as any transform; the window can + // execute in the same environment as the rest of the transforms, and can fuse with the stage + PTransform readTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform flattenTransform = + PTransform.newBuilder() + .putInputs("readInput", "read.out") + .putInputs("parDoInput", "parDo.out") 
+ .putOutputs("output", "flatten.out") + .setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)) + .build(); + PTransform windowTransform = + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "window.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("read", readTransform) + .putPcollections( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms("flatten", flattenTransform) + .putPcollections( + "flatten.out", PCollection.newBuilder().setUniqueName("flatten.out").build()) + .putTransforms("window", windowTransform) + .putPcollections( + "window.out", PCollection.newBuilder().setUniqueName("window.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .build()); + + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, p.getPerElementConsumers(impulseOutputNode)); + assertThat(subgraph.getOutputPCollections(), emptyIterable()); + assertThat( + subgraph.toPTransform().getSubtransformsList(), + containsInAnyOrder("read", "parDo", "flatten", "window")); + } + + @Test + public void fusesFlattenWithDifferentEnvironmentInputs() { + // (impulse.out) -> read -> read.out \ -> window -> window.out + // -------> flatten -> flatten.out / + // (impulse.out) -> envRead -> envRead.out / + // fuses into + // read -> read.out -> flatten -> flatten.out -> window -> window.out + // envRead -> envRead.out -> flatten -> (flatten.out) + // (flatten.out) -> window -> window.out + PTransform readTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform otherEnvRead = + PTransform.newBuilder() + .putInputs("impulse", "impulse.out") + .putOutputs("output", "envRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("rare")) + .build() + .toByteString())) + .build(); + PTransform flattenTransform = + PTransform.newBuilder() + .putInputs("readInput", "read.out") + .putInputs("otherEnvInput", "envRead.out") + .putOutputs("output", "flatten.out") + .setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)) + .build(); + PTransform windowTransform = + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "window.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + + Components components = + partialComponents + .toBuilder() + 
.putTransforms("read", readTransform) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms("envRead", otherEnvRead) + .putPcollections( + "envRead.out", PCollection.newBuilder().setUniqueName("envRead.out").build()) + .putTransforms("flatten", flattenTransform) + .putPcollections( + "flatten.out", PCollection.newBuilder().setUniqueName("flatten.out").build()) + .putTransforms("window", windowTransform) + .putPcollections( + "window.out", PCollection.newBuilder().setUniqueName("window.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .putEnvironments("rare", Environment.newBuilder().setUrl("rare").build()) + .build(); + QueryablePipeline p = QueryablePipeline.fromComponents(components); + + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, ImmutableSet.of(PipelineNode.pTransform("read", readTransform))); + assertThat(subgraph.getOutputPCollections(), emptyIterable()); + assertThat( + subgraph.toPTransform().getSubtransformsList(), + containsInAnyOrder("read", "flatten", "window")); + + // Flatten shows up in both of these subgraphs, but elements only go through a path to the + // flatten once. + ExecutableStage readFromOtherEnv = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, + impulseOutputNode, + ImmutableSet.of(PipelineNode.pTransform("envRead", otherEnvRead))); + assertThat( + readFromOtherEnv.getOutputPCollections(), + contains( + PipelineNode.pCollection( + "flatten.out", components.getPcollectionsOrThrow("flatten.out")))); + assertThat( + readFromOtherEnv.toPTransform().getSubtransformsList(), + containsInAnyOrder("envRead", "flatten")); + } + + @Test + public void flattenWithHeterogeneousInputsAndOutputs() { + // (impulse.out) -> pyRead -> pyRead.out \ -> pyParDo -> pyParDo.out + // (impulse.out) -> -> flatten -> flatten.out | + // (impulse.out) -> goRead -> goRead.out / -> goWindow -> goWindow.out + // fuses into + // (impulse.out) -> pyRead -> pyRead.out -> flatten -> (flatten.out) + // (impulse.out) -> goRead -> goRead.out -> flatten -> (flatten.out) + // (flatten.out) -> pyParDo -> pyParDo.out + // (flatten.out) -> goWindow -> goWindow.out + PTransform pyRead = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "pyRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString()) + .build()) + .build(); + PTransform goRead = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "goRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString()) + .build()) + .build(); + + PTransform pyParDo = + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "pyParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString()) + .build()) + .build(); + PTransform goWindow = + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "goWindow.out") + .setSpec( + FunctionSpec.newBuilder() + 
.setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString()) + .build()) + .build(); + + PCollection flattenPc = PCollection.newBuilder().setUniqueName("flatten.out").build(); + Components components = + partialComponents + .toBuilder() + .putTransforms("pyRead", pyRead) + .putPcollections( + "pyRead.out", PCollection.newBuilder().setUniqueName("pyRead.out").build()) + .putTransforms("goRead", goRead) + .putPcollections( + "goRead.out", PCollection.newBuilder().setUniqueName("goRead.out").build()) + .putTransforms( + "flatten", + PTransform.newBuilder() + .putInputs("py_input", "pyRead.out") + .putInputs("go_input", "goRead.out") + .putOutputs("output", "flatten.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN) + .build()) + .build()) + .putPcollections("flatten.out", flattenPc) + .putTransforms("pyParDo", pyParDo) + .putPcollections( + "pyParDo.out", PCollection.newBuilder().setUniqueName("pyParDo.out").build()) + .putTransforms("goWindow", goWindow) + .putPcollections( + "goWindow.out", PCollection.newBuilder().setUniqueName("goWindow.out").build()) + .putEnvironments("go", Environment.newBuilder().setUrl("go").build()) + .putEnvironments("py", Environment.newBuilder().setUrl("py").build()) + .build(); + QueryablePipeline p = QueryablePipeline.fromComponents(components); + + ExecutableStage readFromPy = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, ImmutableSet.of(PipelineNode.pTransform("pyRead", pyRead))); + ExecutableStage readFromGo = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, ImmutableSet.of(PipelineNode.pTransform("goRead", goRead))); + + assertThat( + readFromPy.getOutputPCollections(), + contains(PipelineNode.pCollection("flatten.out", flattenPc))); + // The stage must materialize the flatten, so the `go` stage can read it; this means that this + // parDo can't be in the stage, as it'll be a reader of that materialized PCollection. The same + // is true for the go window. 
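+    // Because flatten.out is materialized, its consumers (pyParDo and goWindow) must root new
+    // stages instead of fusing into the stages that produced it.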
+ assertThat( + readFromPy.getTransforms(), not(hasItem(PipelineNode.pTransform("pyParDo", pyParDo)))); + + assertThat( + readFromGo.getOutputPCollections(), + contains(PipelineNode.pCollection("flatten.out", flattenPc))); + assertThat( + readFromGo.getTransforms(), not(hasItem(PipelineNode.pTransform("goWindow", goWindow)))); + } + + @Test + public void materializesWithDifferentEnvConsumer() { + // (impulse.out) -> parDo -> parDo.out -> window -> window.out + // Fuses into + // (impulse.out) -> parDo -> (parDo.out) + // (parDo.out) -> window -> window.out + Environment env = Environment.newBuilder().setUrl("common").build(); + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .putOutputs("out", "parDo.out") + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms( + "window", + PTransform.newBuilder() + .putInputs("input", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("rare")) + .build() + .toByteString())) + .build()) + .putEnvironments("rare", Environment.newBuilder().setUrl("rare").build()) + .putEnvironments("common", env) + .build()); + + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, p.getPerElementConsumers(impulseOutputNode)); + assertThat(subgraph.getOutputPCollections(), emptyIterable()); + assertThat(subgraph.getInputPCollection(), equalTo(impulseOutputNode)); + assertThat(subgraph.getEnvironment(), equalTo(env)); + assertThat( + subgraph.getTransforms(), contains(PipelineNode.pTransform("parDo", parDoTransform))); + } + + @Test + public void materializesWithDifferentEnvSibling() { + // (impulse.out) -> read -> read.out -> parDo -> parDo.out + // \ + // -> window -> window.out + // Fuses into + // (impulse.out) -> read -> (read.out) + // (read.out) -> parDo -> parDo.out + // (read.out) -> window -> window.out + // The window can't be fused into the stage, which forces the PCollection to be materialized. 
+ // ParDo in this case _could_ be fused into the stage, but is not for simplicity of + // implementation + Environment env = Environment.newBuilder().setUrl("common").build(); + PTransform readTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("read", readTransform) + .putPcollections( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "parDo", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build()) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms( + "window", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "window.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("rare")) + .build() + .toByteString())) + .build()) + .putPcollections( + "window.out", PCollection.newBuilder().setUniqueName("window.out").build()) + .putEnvironments("rare", Environment.newBuilder().setUrl("rare").build()) + .putEnvironments("common", env) + .build()); + + PTransformNode readNode = PipelineNode.pTransform("read", readTransform); + PCollectionNode readOutput = getOnlyElement(p.getOutputPCollections(readNode)); + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, ImmutableSet.of(PipelineNode.pTransform("read", readTransform))); + assertThat(subgraph.getOutputPCollections(), contains(readOutput)); + assertThat(subgraph.getTransforms(), contains(readNode)); + } + + @Test + public void materializesWithSideInputConsumer() { + // (impulse.out) -> read -> read.out -----------> parDo -> parDo.out -> window -> window.out + // (impulse.out) -> side_read -> side_read.out / + // Where parDo takes side_read as a side input, fuses into + // (impulse.out) -> read -> (read.out) + // (impulse.out) -> side_read -> (side_read.out) + // (read.out) -> parDo -> parDo.out -> window -> window.out + // parDo doesn't have a per-element consumer from side_read.out, so it can't root a stage + // which consumes from that materialized collection. Nodes with side inputs must root a stage, + // but do not restrict fusion of consumers. 
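+    // The stage rooted at impulse.out therefore contains only the read transform and
+    // materializes read.out for both downstream consumers.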
+ Environment env = Environment.newBuilder().setUrl("common").build(); + PTransform readTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("read", readTransform) + .putPcollections( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "side_read", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "side_read.out") + .build()) + .putPcollections( + "side_read.out", + PCollection.newBuilder().setUniqueName("side_read.out").build()) + .putTransforms( + "parDo", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putInputs("side_input", "side_read.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .putSideInputs("side_input", SideInput.getDefaultInstance()) + .build() + .toByteString())) + .build()) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms( + "window", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "window.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build()) + .putPcollections( + "window.out", PCollection.newBuilder().setUniqueName("window.out").build()) + .putEnvironments("common", env) + .build()); + + PTransformNode readNode = PipelineNode.pTransform("read", readTransform); + PCollectionNode readOutput = getOnlyElement(p.getOutputPCollections(readNode)); + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, ImmutableSet.of(readNode)); + assertThat(subgraph.getOutputPCollections(), contains(readOutput)); + assertThat(subgraph.toPTransform().getSubtransformsList(), contains(readNode.getId())); + } + + @Test + public void materializesWithGroupByKeyConsumer() { + // (impulse.out) -> read -> read.out -> gbk -> gbk.out + // Fuses to + // (impulse.out) -> read -> (read.out) + // GBK is the responsibility of the runner, so it is not included in a stage. 
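+    // The fused stage ends at read.out, which is materialized so the runner-executed GBK can
+    // consume it.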
+ Environment env = Environment.newBuilder().setUrl("common").build(); + PTransform readTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + + QueryablePipeline p = + QueryablePipeline.fromComponents( + partialComponents + .toBuilder() + .putTransforms("read", readTransform) + .putPcollections( + "read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "gbk", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "gbk.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN)) + .build()) + .putPcollections( + "gbk.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putEnvironments("common", env) + .build()); + + PTransformNode readNode = PipelineNode.pTransform("read", readTransform); + PCollectionNode readOutput = getOnlyElement(p.getOutputPCollections(readNode)); + ExecutableStage subgraph = + GreedilyFusedExecutableStage.forGrpcPortRead( + p, impulseOutputNode, ImmutableSet.of(readNode)); + assertThat(subgraph.getOutputPCollections(), contains(readOutput)); + assertThat(subgraph.toPTransform().getSubtransformsList(), contains(readNode.getId())); + } +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java new file mode 100644 index 000000000000..76bdddedd0bd --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java @@ -0,0 +1,996 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.core.construction.graph; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasSize; +import static org.junit.Assert.assertThat; + +import org.apache.beam.model.pipeline.v1.RunnerApi.Components; +import org.apache.beam.model.pipeline.v1.RunnerApi.Environment; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline; +import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec; +import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput; +import org.apache.beam.model.pipeline.v1.RunnerApi.StateSpec; +import org.apache.beam.model.pipeline.v1.RunnerApi.TimerSpec; +import org.apache.beam.model.pipeline.v1.RunnerApi.WindowIntoPayload; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link GreedyPipelineFuser}. */ +@RunWith(JUnit4.class) +public class GreedyPipelineFuserTest { + // Contains the 'go' and 'py' environments, and a default 'impulse' step and output. + private Components partialComponents; + + @Before + public void setup() { + partialComponents = + Components.newBuilder() + .putTransforms( + "impulse", + PTransform.newBuilder() + .putOutputs("output", "impulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "impulse.out", PCollection.newBuilder().setUniqueName("impulse.out").build()) + .putEnvironments("go", Environment.newBuilder().setUrl("go").build()) + .putEnvironments("py", Environment.newBuilder().setUrl("py").build()) + .build(); + } + + /* + * impulse -> .out -> read -> .out -> parDo -> .out -> window -> .out + * becomes + * (impulse.out) -> read -> read.out -> parDo -> parDo.out -> window + */ + @Test + public void singleEnvironmentBecomesASingleStage() { + Components components = + partialComponents + .toBuilder() + .putTransforms( + "read", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "parDo", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms( + "window", + PTransform.newBuilder() + .putInputs("input", "parDo.out") + .putOutputs("output", "window.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + 
SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "window.out", PCollection.newBuilder().setUniqueName("window.out").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + contains(PipelineNode.pTransform("impulse", components.getTransformsOrThrow("impulse")))); + assertThat( + fused.getFusedStages(), + contains( + ExecutableStageMatcher.withInput("impulse.out") + .withNoOutputs() + .withTransforms("read", "parDo", "window"))); + } + + /* + * impulse -> .out -> read -> .out -> groupByKey -> .out -> parDo -> .out + * becomes + * (impulse.out) -> read -> (read.out) + * (groupByKey.out) -> parDo + */ + @Test + public void singleEnvironmentAcrossGroupByKeyMultipleStages() { + Components components = + partialComponents + .toBuilder() + .putTransforms( + "read", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "groupByKey", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "groupByKey.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN)) + .build()) + .putPcollections( + "groupByKey.out", PCollection.newBuilder().setUniqueName("groupByKey.out").build()) + .putTransforms( + "parDo", + PTransform.newBuilder() + .putInputs("input", "groupByKey.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + contains( + PipelineNode.pTransform("impulse", components.getTransformsOrThrow("impulse")), + PipelineNode.pTransform("groupByKey", components.getTransformsOrThrow("groupByKey")))); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("impulse.out") + .withOutputs("read.out") + .withTransforms("read"), + ExecutableStageMatcher.withInput("groupByKey.out") + .withNoOutputs() + .withTransforms("parDo"))); + } + + /* + * impulse -> .out -> read -> .out --> goTransform -> .out + * \ + * -> pyTransform -> .out + * becomes (impulse.out) -> read -> (read.out) + * (read.out) -> goTransform + * (read.out) -> pyTransform + */ + @Test + public void multipleEnvironmentsBecomesMultipleStages() { + Components components = + partialComponents + .toBuilder() + .putTransforms( + "read", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + 
.toByteString())) + .build()) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "goTransform", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "go.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections("go.out", PCollection.newBuilder().setUniqueName("go.out").build()) + .putTransforms( + "pyTransform", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "py.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("py.out", PCollection.newBuilder().setUniqueName("py.out").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + // Impulse is the runner transform + assertThat(fused.getRunnerExecutedTransforms(), hasSize(1)); + assertThat(fused.getFusedStages(), hasSize(3)); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("impulse.out") + .withOutputs("read.out") + .withTransforms("read"), + ExecutableStageMatcher.withInput("read.out") + .withNoOutputs() + .withTransforms("pyTransform"), + ExecutableStageMatcher.withInput("read.out") + .withNoOutputs() + .withTransforms("goTransform"))); + } + + /* + * goImpulse -> .out -> goRead -> .out \ -> goParDo -> .out + * -> flatten -> .out | + * pyImpulse -> .out -> pyRead -> .out / -> pyParDo -> .out + * + * becomes + * (goImpulse.out) -> goRead -> goRead.out -> flatten -> (flatten.out) + * (pyImpulse.out) -> pyRead -> pyRead.out -> flatten -> (flatten.out) + * (flatten.out) -> goParDo + * (flatten.out) -> pyParDo + */ + @Test + public void flattenWithHeterogenousInputsAndOutputsEntirelyMaterialized() { + Components components = + Components.newBuilder() + .putTransforms( + "pyImpulse", + PTransform.newBuilder() + .putOutputs("output", "pyImpulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "pyImpulse.out", PCollection.newBuilder().setUniqueName("pyImpulse.out").build()) + .putTransforms( + "pyRead", + PTransform.newBuilder() + .putInputs("input", "pyImpulse.out") + .putOutputs("output", "pyRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "pyRead.out", PCollection.newBuilder().setUniqueName("pyRead.out").build()) + .putTransforms( + "goImpulse", + PTransform.newBuilder() + .putOutputs("output", "goImpulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "goImpulse.out", PCollection.newBuilder().setUniqueName("goImpulse.out").build()) + .putTransforms( + "goRead", + PTransform.newBuilder() + .putInputs("input", "goImpulse.out") + .putOutputs("output", "goRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + 
.setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections( + "goRead.out", PCollection.newBuilder().setUniqueName("goRead.out").build()) + .putTransforms( + "flatten", + PTransform.newBuilder() + .putInputs("goReadInput", "goRead.out") + .putInputs("pyReadInput", "pyRead.out") + .putOutputs("output", "flatten.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)) + .build()) + .putPcollections( + "flatten.out", PCollection.newBuilder().setUniqueName("flatten.out").build()) + .putTransforms( + "pyParDo", + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "pyParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "pyParDo.out", PCollection.newBuilder().setUniqueName("pyParDo.out").build()) + .putTransforms( + "goParDo", + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "goParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections( + "goParDo.out", PCollection.newBuilder().setUniqueName("goParDo.out").build()) + .putEnvironments("go", Environment.newBuilder().setUrl("go").build()) + .putEnvironments("py", Environment.newBuilder().setUrl("py").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + containsInAnyOrder( + PipelineNode.pTransform("pyImpulse", components.getTransformsOrThrow("pyImpulse")), + PipelineNode.pTransform("goImpulse", components.getTransformsOrThrow("goImpulse")))); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("goImpulse.out") + .withOutputs("flatten.out") + .withTransforms("goRead", "flatten"), + ExecutableStageMatcher.withInput("pyImpulse.out") + .withOutputs("flatten.out") + .withTransforms("pyRead", "flatten"), + ExecutableStageMatcher.withInput("flatten.out") + .withNoOutputs() + .withTransforms("goParDo"), + ExecutableStageMatcher.withInput("flatten.out") + .withNoOutputs() + .withTransforms("pyParDo"))); + } + + /* + * impulseA -> .out -> goRead -> .out \ + * -> flatten -> .out -> goParDo -> .out + * impulseB -> .out -> pyRead -> .out / + * + * becomes + * (impulseA.out) -> goRead -> goRead.out -> flatten -> flatten.out -> goParDo + * (impulseB.out) -> pyRead -> pyRead.out -> flatten -> (flatten.out) + * (flatten.out) -> goParDo + */ + @Test + public void flattenWithHeterogeneousInputsSingleEnvOutputPartiallyMaterialized() { + Components components = + Components.newBuilder() + .putTransforms( + "pyImpulse", + PTransform.newBuilder() + .putOutputs("output", "pyImpulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "pyImpulse.out", PCollection.newBuilder().setUniqueName("pyImpulse.out").build()) + .putTransforms( + "pyRead", + PTransform.newBuilder() + .putInputs("input", "pyImpulse.out") + .putOutputs("output", "pyRead.out") + .setSpec( + 
FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "pyRead.out", PCollection.newBuilder().setUniqueName("pyRead.out").build()) + .putTransforms( + "goImpulse", + PTransform.newBuilder() + .putOutputs("output", "goImpulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "goImpulse.out", PCollection.newBuilder().setUniqueName("goImpulse.out").build()) + .putTransforms( + "goRead", + PTransform.newBuilder() + .putInputs("input", "goImpulse.out") + .putOutputs("output", "goRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections( + "goRead.out", PCollection.newBuilder().setUniqueName("goRead.out").build()) + .putTransforms( + "flatten", + PTransform.newBuilder() + .putInputs("goReadInput", "goRead.out") + .putInputs("pyReadInput", "pyRead.out") + .putOutputs("output", "flatten.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)) + .build()) + .putPcollections( + "flatten.out", PCollection.newBuilder().setUniqueName("flatten.out").build()) + .putTransforms( + "goParDo", + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "goParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections( + "goParDo.out", PCollection.newBuilder().setUniqueName("goParDo.out").build()) + .putEnvironments("go", Environment.newBuilder().setUrl("go").build()) + .putEnvironments("py", Environment.newBuilder().setUrl("py").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + containsInAnyOrder( + PipelineNode.pTransform("pyImpulse", components.getTransformsOrThrow("pyImpulse")), + PipelineNode.pTransform("goImpulse", components.getTransformsOrThrow("goImpulse")))); + + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("goImpulse.out") + .withNoOutputs() + .withTransforms("goRead", "flatten", "goParDo"), + ExecutableStageMatcher.withInput("pyImpulse.out") + .withOutputs("flatten.out") + .withTransforms("pyRead", "flatten"), + ExecutableStageMatcher.withInput("flatten.out") + .withNoOutputs() + .withTransforms("goParDo"))); + } + + /* + * impulseA -> .out -> flatten -> .out -> read -> .out -> parDo -> .out + * becomes + * (flatten.out) -> read -> parDo + * + * Flatten, specifically, doesn't fuse greedily into downstream environments or act as a sibling + * to any of those nodes, but the routing is instead handled by the Runner. 
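+   * The assertions below therefore expect the flatten to appear among the runner-executed
+   * transforms rather than inside any fused stage.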
+ */ + @Test + public void flattenAfterNoEnvDoesNotFuse() { + Components components = partialComponents.toBuilder() + .putTransforms("flatten", + PTransform.newBuilder() + .putInputs("impulseInput", "impulse.out") + .putOutputs("output", "flatten.out") + .setSpec(FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN) + .build()) + .build()) + .putPcollections("flatten.out", + PCollection.newBuilder().setUniqueName("flatten.out").build()) + .putTransforms("read", + PTransform.newBuilder() + .putInputs("input", "flatten.out") + .putOutputs("output", "read.out") + .setSpec(FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload(ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms("parDo", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "parDo.out") + .setSpec(FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload(ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py").build()) + .build() + .toByteString())) + .build()) + .putPcollections("parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + containsInAnyOrder( + PipelineNode.pTransform("impulse", components.getTransformsOrThrow("impulse")), + PipelineNode.pTransform("flatten", components.getTransformsOrThrow("flatten")))); + assertThat( + fused.getFusedStages(), + contains( + ExecutableStageMatcher.withInput("flatten.out") + .withNoOutputs() + .withTransforms("read", "parDo"))); + } + + /* + * impulseA -> .out -> read -> .out -> leftParDo -> .out + * \ -> rightParDo -> .out + * ------> sideInputParDo -> .out + * / + * impulseB -> .out -> side_read -> .out / + * + * becomes + * (impulseA.out) -> read -> (read.out) + * (read.out) -> leftParDo + * \ + * -> rightParDo + * (read.out) -> sideInputParDo + * (impulseB.out) -> side_read + */ + @Test + public void sideInputRootsNewStage() { + Components components = + Components.newBuilder() + .putTransforms( + "mainImpulse", + PTransform.newBuilder() + .putOutputs("output", "mainImpulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "mainImpulse.out", + PCollection.newBuilder().setUniqueName("mainImpulse.out").build()) + .putTransforms( + "read", + PTransform.newBuilder() + .putInputs("input", "mainImpulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "sideImpulse", + PTransform.newBuilder() + .putOutputs("output", "sideImpulse.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)) + .build()) + .putPcollections( + "sideImpulse.out", + PCollection.newBuilder().setUniqueName("sideImpulse.out").build()) + .putTransforms( + "sideRead", + PTransform.newBuilder() + 
.putInputs("input", "sideImpulse.out") + .putOutputs("output", "sideRead.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections( + "sideRead.out", PCollection.newBuilder().setUniqueName("sideRead.out").build()) + .putTransforms( + "leftParDo", + PTransform.newBuilder() + .putInputs("main", "read.out") + .putOutputs("output", "leftParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString()) + .build()) + .build()) + .putPcollections( + "leftParDo.out", PCollection.newBuilder().setUniqueName("leftParDo.out").build()) + .putTransforms( + "rightParDo", + PTransform.newBuilder() + .putInputs("main", "read.out") + .putOutputs("output", "rightParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString()) + .build()) + .build()) + .putPcollections( + "rightParDo.out", PCollection.newBuilder().setUniqueName("rightParDo.out").build()) + .putTransforms( + "sideParDo", + PTransform.newBuilder() + .putInputs("main", "read.out") + .putInputs("side", "sideRead.out") + .putOutputs("output", "sideParDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .putSideInputs("side", SideInput.getDefaultInstance()) + .build() + .toByteString()) + .build()) + .build()) + .putPcollections( + "sideParDo.out", PCollection.newBuilder().setUniqueName("sideParDo.out").build()) + .putEnvironments("py", Environment.newBuilder().setUrl("py").build()) + .build(); + + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + containsInAnyOrder( + PipelineNode.pTransform("mainImpulse", components.getTransformsOrThrow("mainImpulse")), + PipelineNode.pTransform( + "sideImpulse", components.getTransformsOrThrow("sideImpulse")))); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("mainImpulse.out") + .withOutputs("read.out") + .withTransforms("read"), + ExecutableStageMatcher.withInput("read.out") + .withNoOutputs() + .withTransforms("leftParDo", "rightParDo"), + ExecutableStageMatcher.withInput("read.out") + .withNoOutputs() + .withTransforms("sideParDo"), + ExecutableStageMatcher.withInput("sideImpulse.out") + .withNoOutputs() + .withTransforms("sideRead"))); + } + + /* + * impulse -> .out -> parDo -> .out -> stateful -> .out + * becomes + * (impulse.out) -> parDo -> (parDo.out) + * (parDo.out) -> stateful + */ + @Test + public void statefulParDoRootsStage() { + // (impulse.out) -> parDo -> (parDo.out) + // (parDo.out) -> stateful -> stateful.out + // stateful has a state spec which prevents it from fusing with an upstream ParDo + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + 
.setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform statefulTransform = + PTransform.newBuilder() + .putInputs("input", "parDo.out") + .putOutputs("output", "stateful.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .putStateSpecs("state", StateSpec.getDefaultInstance()) + .build() + .toByteString())) + .build(); + + Components components = + partialComponents + .toBuilder() + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms("stateful", statefulTransform) + .putPcollections( + "stateful.out", PCollection.newBuilder().setUniqueName("stateful.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + containsInAnyOrder( + PipelineNode.pTransform("impulse", components.getTransformsOrThrow("impulse")))); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("impulse.out") + .withOutputs("parDo.out") + .withTransforms("parDo"), + ExecutableStageMatcher.withInput("parDo.out") + .withNoOutputs() + .withTransforms("stateful"))); + } + + /* + * impulse -> .out -> parDo -> .out -> timer -> .out + * becomes + * (impulse.out) -> parDo -> (parDo.out) + * (parDo.out) -> timer + */ + @Test + public void parDoWithTimerRootsStage() { + // (impulse.out) -> parDo -> (parDo.out) + // (parDo.out) -> timer -> timer.out + // timer has a timer spec which prevents it from fusing with an upstream ParDo + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform timerTransform = + PTransform.newBuilder() + .putInputs("input", "parDo.out") + .putOutputs("output", "timer.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .putTimerSpecs("timer", TimerSpec.getDefaultInstance()) + .build() + .toByteString())) + .build(); + + Components components = + partialComponents + .toBuilder() + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putTransforms("timer", timerTransform) + .putPcollections( + "timer.out", PCollection.newBuilder().setUniqueName("timer.out").build()) + .putEnvironments("common", Environment.newBuilder().setUrl("common").build()) + .build(); + + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + assertThat( + fused.getRunnerExecutedTransforms(), + containsInAnyOrder( + PipelineNode.pTransform("impulse", components.getTransformsOrThrow("impulse")))); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + 
ExecutableStageMatcher.withInput("impulse.out") + .withOutputs("parDo.out") + .withTransforms("parDo"), + ExecutableStageMatcher.withInput("parDo.out") + .withNoOutputs() + .withTransforms("timer"))); + } + + /* + * impulse -> .out -> ( read -> .out --> goTransform -> .out ) + * \ + * -> pyTransform -> .out ) + * becomes (impulse.out) -> read -> (read.out) + * (read.out) -> goTransform + * (read.out) -> pyTransform + */ + @Test + public void compositesIgnored() { + Components components = + partialComponents + .toBuilder() + .putTransforms( + "read", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("read.out", PCollection.newBuilder().setUniqueName("read.out").build()) + .putTransforms( + "goTransform", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "go.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("go")) + .build() + .toByteString())) + .build()) + .putPcollections("go.out", PCollection.newBuilder().setUniqueName("go.out").build()) + .putTransforms( + "pyTransform", + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "py.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN) + .setPayload( + WindowIntoPayload.newBuilder() + .setWindowFn( + SdkFunctionSpec.newBuilder().setEnvironmentId("py")) + .build() + .toByteString())) + .build()) + .putPcollections("py.out", PCollection.newBuilder().setUniqueName("py.out").build()) + .putTransforms( + "compositeMultiLang", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("pyOut", "py.out") + .putOutputs("goOut", "go.out") + .addSubtransforms("read") + .addSubtransforms("goTransform") + .addSubtransforms("pyTransform") + .build()) + .build(); + FusedPipeline fused = + GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build()); + + // Impulse is the runner transform + assertThat(fused.getRunnerExecutedTransforms(), hasSize(1)); + assertThat(fused.getFusedStages(), hasSize(3)); + assertThat( + fused.getFusedStages(), + containsInAnyOrder( + ExecutableStageMatcher.withInput("impulse.out") + .withOutputs("read.out") + .withTransforms("read"), + ExecutableStageMatcher.withInput("read.out") + .withNoOutputs() + .withTransforms("pyTransform"), + ExecutableStageMatcher.withInput("read.out") + .withNoOutputs() + .withTransforms("goTransform"))); + } +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/QueryablePipelineTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/QueryablePipelineTest.java new file mode 100644 index 000000000000..32ee01acb238 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/QueryablePipelineTest.java @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction.graph; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.Components; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput; +import org.apache.beam.runners.core.construction.Environments; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode; +import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.joda.time.Duration; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link QueryablePipeline}. */ +@RunWith(JUnit4.class) +public class QueryablePipelineTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + /** + * Constructing a {@link QueryablePipeline} with components that reference absent {@link + * RunnerApi.PCollection PCollections} should fail. + */ + @Test + public void fromEmptyComponents() { + // Not that it's hugely useful, but it shouldn't throw. 
+ QueryablePipeline p = QueryablePipeline.fromComponents(Components.getDefaultInstance()); + assertThat(p.getRootTransforms(), emptyIterable()); + } + + @Test + public void fromComponentsWithMalformedComponents() { + Components components = + Components.newBuilder() + .putTransforms( + "root", PTransform.newBuilder().putOutputs("output", "output.out").build()) + .build(); + + thrown.expect(IllegalArgumentException.class); + QueryablePipeline.fromComponents(components); + } + + @Test + public void rootTransforms() { + Pipeline p = Pipeline.create(); + p.apply("UnboundedRead", Read.from(CountingSource.unbounded())) + .apply(Window.into(FixedWindows.of(Duration.millis(5L)))) + .apply(Count.perElement()); + p.apply("BoundedRead", Read.from(CountingSource.upTo(100L))); + + Components components = PipelineTranslation.toProto(p).getComponents(); + QueryablePipeline qp = QueryablePipeline.fromComponents(components); + + assertThat(qp.getRootTransforms(), hasSize(2)); + for (PTransformNode rootTransform : qp.getRootTransforms()) { + assertThat( + "Root transforms should have no inputs", + rootTransform.getTransform().getInputsCount(), + equalTo(0)); + assertThat( + "Only added source reads to the pipeline", + rootTransform.getTransform().getSpec().getUrn(), + equalTo(PTransformTranslation.READ_TRANSFORM_URN)); + } + } + + /** + * Tests that inputs that are only side inputs are not returned from {@link + * QueryablePipeline#getPerElementConsumers(PCollectionNode)} and are returned from {@link + * QueryablePipeline#getSideInputs(PTransformNode)}. + */ + @Test + public void transformWithSideAndMainInputs() { + Pipeline p = Pipeline.create(); + PCollection longs = p.apply("BoundedRead", Read.from(CountingSource.upTo(100L))); + PCollectionView view = + p.apply("Create", Create.of("foo")).apply("View", View.asSingleton()); + longs.apply( + "par_do", + ParDo.of(new TestFn()) + .withSideInputs(view) + .withOutputTags(new TupleTag<>(), TupleTagList.empty())); + + Components components = PipelineTranslation.toProto(p).getComponents(); + QueryablePipeline qp = QueryablePipeline.fromComponents(components); + + String mainInputName = + getOnlyElement( + PipelineNode.pTransform("BoundedRead", components.getTransformsOrThrow("BoundedRead")) + .getTransform() + .getOutputsMap() + .values()); + PCollectionNode mainInput = + PipelineNode.pCollection(mainInputName, components.getPcollectionsOrThrow(mainInputName)); + String sideInputName = + getOnlyElement( + components + .getTransformsOrThrow("par_do") + .getInputsMap() + .values() + .stream() + .filter(pcollectionName -> !pcollectionName.equals(mainInputName)) + .collect(Collectors.toSet())); + PCollectionNode sideInput = + PipelineNode.pCollection(sideInputName, components.getPcollectionsOrThrow(sideInputName)); + PTransformNode parDoNode = + PipelineNode.pTransform("par_do", components.getTransformsOrThrow("par_do")); + + assertThat(qp.getSideInputs(parDoNode), contains(sideInput)); + assertThat(qp.getPerElementConsumers(mainInput), contains(parDoNode)); + assertThat(qp.getPerElementConsumers(sideInput), not(contains(parDoNode))); + } + + /** + * Tests that inputs that are both side inputs and main inputs are returned from {@link + * QueryablePipeline#getPerElementConsumers(PCollectionNode)} and {@link + * QueryablePipeline#getSideInputs(PTransformNode)}. 
+ */ + @Test + public void transformWithSameSideAndMainInput() { + Components components = + Components.newBuilder() + .putPcollections("read_pc", RunnerApi.PCollection.getDefaultInstance()) + .putPcollections("pardo_out", RunnerApi.PCollection.getDefaultInstance()) + .putTransforms("root", PTransform.newBuilder().putOutputs("out", "read_pc").build()) + .putTransforms( + "multiConsumer", + PTransform.newBuilder() + .putInputs("main_in", "read_pc") + .putInputs("side_in", "read_pc") + .putOutputs("out", "pardo_out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .putSideInputs("side_in", SideInput.getDefaultInstance()) + .build() + .toByteString()) + .build()) + .build()) + .build(); + + QueryablePipeline qp = QueryablePipeline.fromComponents(components); + PCollectionNode multiInputPc = + PipelineNode.pCollection("read_pc", components.getPcollectionsOrThrow("read_pc")); + PTransformNode multiConsumerPT = + PipelineNode.pTransform("multiConsumer", components.getTransformsOrThrow("multiConsumer")); + assertThat(qp.getPerElementConsumers(multiInputPc), contains(multiConsumerPT)); + assertThat(qp.getSideInputs(multiConsumerPT), contains(multiInputPc)); + } + + /** + * Tests that {@link QueryablePipeline#getPerElementConsumers(PCollectionNode)} returns a + * transform that consumes the node more than once. + */ + @Test + public void perElementConsumersWithConsumingMultipleTimes() { + Pipeline p = Pipeline.create(); + PCollection longs = p.apply("BoundedRead", Read.from(CountingSource.upTo(100L))); + PCollectionList.of(longs).and(longs).and(longs).apply("flatten", Flatten.pCollections()); + + Components components = PipelineTranslation.toProto(p).getComponents(); + // This breaks if the way that IDs are assigned to PTransforms changes in PipelineTranslation + String readOutput = + getOnlyElement(components.getTransformsOrThrow("BoundedRead").getOutputsMap().values()); + QueryablePipeline qp = QueryablePipeline.fromComponents(components); + Set consumers = + qp.getPerElementConsumers( + PipelineNode.pCollection(readOutput, components.getPcollectionsOrThrow(readOutput))); + + assertThat(consumers.size(), equalTo(1)); + assertThat( + getOnlyElement(consumers).getTransform().getSpec().getUrn(), + equalTo(PTransformTranslation.FLATTEN_TRANSFORM_URN)); + } + + @Test + public void getProducer() { + Pipeline p = Pipeline.create(); + PCollection longs = p.apply("BoundedRead", Read.from(CountingSource.upTo(100L))); + PCollectionList.of(longs).and(longs).and(longs).apply("flatten", Flatten.pCollections()); + + Components components = PipelineTranslation.toProto(p).getComponents(); + QueryablePipeline qp = QueryablePipeline.fromComponents(components); + + String longsOutputName = + getOnlyElement( + PipelineNode.pTransform("BoundedRead", components.getTransformsOrThrow("BoundedRead")) + .getTransform() + .getOutputsMap() + .values()); + PTransformNode longsProducer = + PipelineNode.pTransform("BoundedRead", components.getTransformsOrThrow("BoundedRead")); + PCollectionNode longsOutput = + PipelineNode.pCollection( + longsOutputName, components.getPcollectionsOrThrow(longsOutputName)); + String flattenOutputName = + getOnlyElement( + PipelineNode.pTransform("flatten", components.getTransformsOrThrow("flatten")) + .getTransform() + .getOutputsMap() + .values()); + PTransformNode flattenProducer = + PipelineNode.pTransform("flatten", components.getTransformsOrThrow("flatten")); + PCollectionNode flattenOutput = + 
PipelineNode.pCollection( + flattenOutputName, components.getPcollectionsOrThrow(flattenOutputName)); + + assertThat(qp.getProducer(longsOutput), equalTo(longsProducer)); + assertThat(qp.getProducer(flattenOutput), equalTo(flattenProducer)); + } + + @Test + public void getEnvironmentWithEnvironment() { + Pipeline p = Pipeline.create(); + PCollection longs = p.apply("BoundedRead", Read.from(CountingSource.upTo(100L))); + PCollectionList.of(longs).and(longs).and(longs).apply("flatten", Flatten.pCollections()); + + Components components = PipelineTranslation.toProto(p).getComponents(); + QueryablePipeline qp = QueryablePipeline.fromComponents(components); + + PTransformNode environmentalRead = + PipelineNode.pTransform("BoundedRead", components.getTransformsOrThrow("BoundedRead")); + PTransformNode nonEnvironmentalTransform = + PipelineNode.pTransform("flatten", components.getTransformsOrThrow("flatten")); + + assertThat(qp.getEnvironment(environmentalRead).isPresent(), is(true)); + assertThat( + qp.getEnvironment(environmentalRead).get(), + equalTo(Environments.JAVA_SDK_HARNESS_ENVIRONMENT)); + assertThat(qp.getEnvironment(nonEnvironmentalTransform).isPresent(), is(false)); + } + + private static class TestFn extends DoFn { + @ProcessElement + public void process(ProcessContext ctxt) {} + } + + @Test + public void retainOnlyPrimitivesWithOnlyPrimitivesUnchanged() { + Pipeline p = Pipeline.create(); + p.apply("Read", Read.from(CountingSource.unbounded())) + .apply( + "multi-do", + ParDo.of(new TestFn()).withOutputTags(new TupleTag<>(), TupleTagList.empty())); + + Components originalComponents = PipelineTranslation.toProto(p).getComponents(); + Components primitiveComponents = QueryablePipeline.retainOnlyPrimitives(originalComponents); + + assertThat(primitiveComponents, equalTo(originalComponents)); + } + + @Test + public void retainOnlyPrimitivesComposites() { + Pipeline p = Pipeline.create(); + p.apply( + new org.apache.beam.sdk.transforms.PTransform>() { + @Override + public PCollection expand(PBegin input) { + return input + .apply(GenerateSequence.from(2L)) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(5L)))) + .apply(MapElements.into(TypeDescriptors.longs()).via(l -> l + 1)); + } + }); + + Components originalComponents = PipelineTranslation.toProto(p).getComponents(); + Components primitiveComponents = QueryablePipeline.retainOnlyPrimitives(originalComponents); + + // Read, Window.Assign, ParDo. This will need to be updated if the expansions change. + assertThat(primitiveComponents.getTransformsCount(), equalTo(3)); + for (Map.Entry transformEntry : + primitiveComponents.getTransformsMap().entrySet()) { + assertThat( + originalComponents.getTransformsMap(), + hasEntry(transformEntry.getKey(), transformEntry.getValue())); + } + + // Other components should be unchanged + assertThat( + primitiveComponents.getPcollectionsCount(), + equalTo(originalComponents.getPcollectionsCount())); + assertThat( + primitiveComponents.getWindowingStrategiesCount(), + equalTo(originalComponents.getWindowingStrategiesCount())); + assertThat(primitiveComponents.getCodersCount(), equalTo(originalComponents.getCodersCount())); + assertThat( + primitiveComponents.getEnvironmentsCount(), + equalTo(originalComponents.getEnvironmentsCount())); + } + + /** This method doesn't do any pruning for reachability, but this may not require a test. 
*/ + @Test + public void retainOnlyPrimitivesIgnoresUnreachableNodes() { + Pipeline p = Pipeline.create(); + p.apply( + new org.apache.beam.sdk.transforms.PTransform>() { + @Override + public PCollection expand(PBegin input) { + return input + .apply(GenerateSequence.from(2L)) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(5L)))) + .apply(MapElements.into(TypeDescriptors.longs()).via(l -> l + 1)); + } + }); + + Components augmentedComponents = + PipelineTranslation.toProto(p) + .getComponents() + .toBuilder() + .putCoders("extra-coder", RunnerApi.Coder.getDefaultInstance()) + .putWindowingStrategies( + "extra-windowing-strategy", RunnerApi.WindowingStrategy.getDefaultInstance()) + .putEnvironments("extra-env", RunnerApi.Environment.getDefaultInstance()) + .putPcollections("extra-pc", RunnerApi.PCollection.getDefaultInstance()) + .build(); + Components primitiveComponents = QueryablePipeline.retainOnlyPrimitives(augmentedComponents); + + // Other components should be unchanged + assertThat( + primitiveComponents.getPcollectionsCount(), + equalTo(augmentedComponents.getPcollectionsCount())); + assertThat( + primitiveComponents.getWindowingStrategiesCount(), + equalTo(augmentedComponents.getWindowingStrategiesCount())); + assertThat(primitiveComponents.getCodersCount(), equalTo(augmentedComponents.getCodersCount())); + assertThat( + primitiveComponents.getEnvironmentsCount(), + equalTo(augmentedComponents.getEnvironmentsCount())); + } +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/metrics/MetricFilteringTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/metrics/MetricFilteringTest.java index 69204fc6b293..72bbc937c8a2 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/metrics/MetricFilteringTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/metrics/MetricFilteringTest.java @@ -35,8 +35,6 @@ */ @RunWith(JUnit4.class) public class MetricFilteringTest { - private static final MetricName NAME1 = MetricName.named("ns1", "name1"); - private boolean matchesSubPath(String actualScope, String subPath) { return MetricFiltering.subPathMatches(actualScope, subPath); diff --git a/runners/core-java/build.gradle b/runners/core-java/build.gradle index 5c8842b6440d..c9295a6eded5 100644 --- a/runners/core-java/build.gradle +++ b/runners/core-java/build.gradle @@ -17,7 +17,7 @@ */ apply from: project(":").file("build_rules.gradle") -applyJavaNature() +applyJavaNature(artifactId: "runners-core-java") description = "Apache Beam :: Runners :: Core Java" diff --git a/runners/core-java/pom.xml b/runners/core-java/pom.xml index 586114bbbccd..215faf41cd34 100644 --- a/runners/core-java/pom.xml +++ b/runners/core-java/pom.xml @@ -112,13 +112,19 @@ org.hamcrest - hamcrest-all + hamcrest-core test - + + + org.hamcrest + hamcrest-library + test + + org.mockito - mockito-all + mockito-core test diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java index 32c561e48134..ebd2a8873e0b 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.annotations.Experimental; import 
org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; @@ -49,6 +50,7 @@ import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.CombineFnUtil; import org.joda.time.Instant; @@ -126,25 +128,25 @@ public InMemoryStateBinder(StateContext c) { @Override public ValueState bindValue( StateTag> address, Coder coder) { - return new InMemoryValue<>(); + return new InMemoryValue<>(coder); } @Override public BagState bindBag( final StateTag> address, Coder elemCoder) { - return new InMemoryBag<>(); + return new InMemoryBag<>(elemCoder); } @Override public SetState bindSet(StateTag> spec, Coder elemCoder) { - return new InMemorySet<>(); + return new InMemorySet<>(elemCoder); } @Override public MapState bindMap( StateTag> spec, Coder mapKeyCoder, Coder mapValueCoder) { - return new InMemoryMap<>(); + return new InMemoryMap<>(mapKeyCoder, mapValueCoder); } @Override @@ -153,7 +155,7 @@ public MapState bindMap( StateTag> address, Coder accumCoder, final CombineFn combineFn) { - return new InMemoryCombiningState<>(combineFn); + return new InMemoryCombiningState<>(combineFn, accumCoder); } @Override @@ -178,9 +180,15 @@ public WatermarkHoldState bindWatermark( */ public static final class InMemoryValue implements ValueState, InMemoryState> { + private final Coder coder; + private boolean isCleared = true; private @Nullable T value = null; + public InMemoryValue(Coder coder) { + this.coder = coder; + } + @Override public void clear() { // Even though we're clearing we can't remove this from the in-memory state map, since @@ -207,10 +215,10 @@ public void write(T input) { @Override public InMemoryValue copy() { - InMemoryValue that = new InMemoryValue<>(); + InMemoryValue that = new InMemoryValue<>(coder); if (!this.isCleared) { that.isCleared = this.isCleared; - that.value = this.value; + that.value = uncheckedClone(coder, this.value); } return that; } @@ -305,14 +313,16 @@ public InMemoryWatermarkHold copy() { public static final class InMemoryCombiningState implements CombiningState, InMemoryState> { - private boolean isCleared = true; private final CombineFn combineFn; + private final Coder accumCoder; + private boolean isCleared = true; private AccumT accum; public InMemoryCombiningState( - CombineFn combineFn) { + CombineFn combineFn, Coder accumCoder) { this.combineFn = combineFn; accum = combineFn.createAccumulator(); + this.accumCoder = accumCoder; } @Override @@ -378,10 +388,10 @@ public boolean isCleared() { @Override public InMemoryCombiningState copy() { InMemoryCombiningState that = - new InMemoryCombiningState<>(combineFn); + new InMemoryCombiningState<>(combineFn, accumCoder); if (!this.isCleared) { that.isCleared = this.isCleared; - that.addAccum(accum); + that.addAccum(uncheckedClone(accumCoder, accum)); } return that; } @@ -391,8 +401,13 @@ public InMemoryCombiningState copy() { * An {@link InMemoryState} implementation of {@link BagState}. 
*/ public static final class InMemoryBag implements BagState, InMemoryState> { + private final Coder elemCoder; private List contents = new ArrayList<>(); + public InMemoryBag(Coder elemCoder) { + this.elemCoder = elemCoder; + } + @Override public void clear() { // Even though we're clearing we can't remove this from the in-memory state map, since @@ -442,8 +457,10 @@ public Boolean read() { @Override public InMemoryBag copy() { - InMemoryBag that = new InMemoryBag<>(); - that.contents.addAll(this.contents); + InMemoryBag that = new InMemoryBag<>(elemCoder); + for (T elem : this.contents) { + that.contents.add(uncheckedClone(elemCoder, elem)); + } return that; } } @@ -452,8 +469,13 @@ public InMemoryBag copy() { * An {@link InMemoryState} implementation of {@link SetState}. */ public static final class InMemorySet implements SetState, InMemoryState> { + private final Coder elemCoder; private Set contents = new HashSet<>(); + public InMemorySet(Coder elemCoder) { + this.elemCoder = elemCoder; + } + @Override public void clear() { contents = new HashSet<>(); @@ -513,8 +535,10 @@ public Boolean read() { @Override public InMemorySet copy() { - InMemorySet that = new InMemorySet<>(); - that.contents.addAll(this.contents); + InMemorySet that = new InMemorySet<>(elemCoder); + for (T elem : this.contents) { + that.contents.add(uncheckedClone(elemCoder, elem)); + } return that; } } @@ -524,8 +548,16 @@ public InMemorySet copy() { */ public static final class InMemoryMap implements MapState, InMemoryState> { + private final Coder keyCoder; + private final Coder valueCoder; + private Map contents = new HashMap<>(); + public InMemoryMap(Coder keyCoder, Coder valueCoder) { + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + } + @Override public void clear() { contents = new HashMap<>(); @@ -600,9 +632,22 @@ public boolean isCleared() { @Override public InMemoryMap copy() { - InMemoryMap that = new InMemoryMap<>(); + InMemoryMap that = new InMemoryMap<>(keyCoder, valueCoder); + for (Map.Entry entry : this.contents.entrySet()) { + that.contents.put( + uncheckedClone(keyCoder, entry.getKey()), uncheckedClone(valueCoder, entry.getValue())); + } that.contents.putAll(this.contents); return that; } } + + /** Like {@link CoderUtils#clone} but without a checked exception. */ + private static T uncheckedClone(Coder coder, T value) { + try { + return CoderUtils.clone(coder, value); + } catch (CoderException e) { + throw new RuntimeException(e); + } + } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index c53efcc23b72..b1a3f3bdb62c 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -50,7 +50,11 @@ * outputs), or runs for the given duration. 
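A minimal standalone sketch of the coder round trip that the new uncheckedClone helper in InMemoryStateInternals above relies on. CoderCloneSketch and the particular coders chosen here are illustrative only; CoderUtils.clone is the SDK utility the patch actually calls, and the point is that copy() now deep-copies state values instead of sharing object references.

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

public class CoderCloneSketch {
  public static void main(String[] args) throws CoderException {
    Coder<List<String>> coder = ListCoder.of(StringUtf8Coder.of());

    List<String> original = new ArrayList<>();
    original.add("a");

    // CoderUtils.clone encodes and decodes the value, so the copy shares no structure
    // with the original - exactly the property the copy() methods above need.
    List<String> copy = CoderUtils.clone(coder, original);
    original.add("b");

    System.out.println(original); // [a, b]
    System.out.println(copy);     // [a]
  }
}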
*/ public class OutputAndTimeBoundedSplittableProcessElementInvoker< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, + OutputT, + RestrictionT, + PositionT, + TrackerT extends RestrictionTracker> extends SplittableProcessElementInvoker { private final DoFn fn; private final PipelineOptions pipelineOptions; @@ -71,9 +75,10 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker< * @param maxNumOutputs Maximum number of outputs, in total over all output tags, after which a * checkpoint will be requested. This is a best-effort request - the {@link DoFn} may output * more after receiving the request. - * @param maxDuration Maximum duration of the {@link DoFn.ProcessElement} call after which a - * checkpoint will be requested. This is a best-effort request - the {@link DoFn} may run for - * longer after receiving the request. + * @param maxDuration Maximum duration of the {@link DoFn.ProcessElement} call (counted from the + * first successful {@link RestrictionTracker#tryClaim} call) after which a checkpoint will be + * requested. This is a best-effort request - the {@link DoFn} may run for longer after + * receiving the request. */ public OutputAndTimeBoundedSplittableProcessElementInvoker( DoFn fn, @@ -98,6 +103,7 @@ public Result invokeProcessElement( final WindowedValue element, final TrackerT tracker) { final ProcessContext processContext = new ProcessContext(element, tracker); + tracker.setClaimObserver(processContext); DoFn.ProcessContinuation cont = invoker.invokeProcessElement( new DoFnInvoker.ArgumentProvider() { @Override @@ -107,7 +113,7 @@ public DoFn.ProcessContext processContext( } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { return tracker; } @@ -157,19 +163,39 @@ public Timer timer(String timerId) { "Access to timers not supported in Splittable DoFn"); } }); - // TODO: verify that if there was a failed tryClaim() call, then cont.shouldResume() is false. - // Currently we can't verify this because there are no hooks into tryClaim(). - // See https://issues.apache.org/jira/browse/BEAM-2607 processContext.cancelScheduledCheckpoint(); @Nullable KV residual = processContext.getTakenCheckpoint(); if (cont.shouldResume()) { + checkState( + !processContext.hasClaimFailed, + "After tryClaim() returned false, @ProcessElement must return stop(), " + + "but returned resume()"); if (residual == null) { // No checkpoint had been taken by the runner while the ProcessElement call ran, however // the call says that not the whole restriction has been processed. So we need to take // a checkpoint now: checkpoint() guarantees that the primary restriction describes exactly // the work that was done in the current ProcessElement call, and returns a residual // restriction that describes exactly the work that wasn't done in the current call. - residual = checkNotNull(processContext.takeCheckpointNow()); + if (processContext.numClaimedBlocks > 0) { + residual = checkNotNull(processContext.takeCheckpointNow()); + tracker.checkDone(); + } else { + // The call returned resume() without trying to claim any blocks, i.e. it is unaware + // of any work to be done at the moment, but more might emerge later. This is a valid + // use case: e.g. a DoFn reading from a streaming source might see that there are + // currently no new elements (hence not claim anything) and return resume() with a delay + // to check again later. 
+ // In this case, we must simply reschedule the original restriction - checkpointing a + // tracker that hasn't claimed any work is not allowed. + // + // Note that the situation "a DoFn repeatedly says that it doesn't have any work to claim + // and asks to try again later with the same restriction" is different from the situation + // "a runner repeatedly checkpoints the DoFn before it has a chance to even attempt + // claiming work": the former is valid, and the latter would be a bug, and is addressed + // by not checkpointing the tracker until it attempts to claim some work. + residual = KV.of(tracker.currentRestriction(), processContext.getLastReportedWatermark()); + // Don't call tracker.checkDone() - it's not done. + } } else { // A checkpoint was taken by the runner, and then the ProcessElement call returned resume() // without making more tryClaim() calls (since no tryClaim() calls can succeed after @@ -180,14 +206,15 @@ public Timer timer(String timerId) { // ProcessElement call. // In other words, if we took a checkpoint *after* ProcessElement completed (like in the // branch above), it would have been equivalent to this one. + tracker.checkDone(); } } else { // The ProcessElement call returned stop() - that means the tracker's current restriction // has been fully processed by the call. A checkpoint may or may not have been taken in // "residual"; if it was, then we'll need to process it; if no, then we don't - nothing // special needs to be done. + tracker.checkDone(); } - tracker.checkDone(); if (residual == null) { // Can only be true if cont.shouldResume() is false and no checkpoint was taken. // This means the restriction has been fully processed. @@ -197,9 +224,12 @@ public Timer timer(String timerId) { return new Result(residual.getKey(), cont, residual.getValue()); } - private class ProcessContext extends DoFn.ProcessContext { + private class ProcessContext extends DoFn.ProcessContext + implements RestrictionTracker.ClaimObserver { private final WindowedValue element; private final TrackerT tracker; + private int numClaimedBlocks; + private boolean hasClaimFailed; private int numOutputs; // Checkpoint may be initiated either when the given number of outputs is reached, @@ -212,20 +242,44 @@ private class ProcessContext extends DoFn.ProcessContext { // on the output from "checkpoint". private @Nullable Instant residualWatermark; // A handle on the scheduled action to take a checkpoint. - private Future scheduledCheckpoint; + private @Nullable Future scheduledCheckpoint; private @Nullable Instant lastReportedWatermark; public ProcessContext(WindowedValue element, TrackerT tracker) { fn.super(); this.element = element; this.tracker = tracker; + } - this.scheduledCheckpoint = - executor.schedule( - (Runnable) this::takeCheckpointNow, maxDuration.getMillis(), TimeUnit.MILLISECONDS); + @Override + public void onClaimed(PositionT position) { + checkState( + !hasClaimFailed, + "Must not call tryClaim() after it has previously returned false"); + if (numClaimedBlocks == 0) { + // Claiming first block: can schedule the checkpoint now. + // We don't schedule it right away to prevent checkpointing before any blocks are claimed, + // in a state where no work has been done yet - because such a checkpoint is equivalent to + // the original restriction, i.e. pointless. 
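As context for the claim-observer checks introduced here, the sketch below shows a @ProcessElement body that satisfies both rules spelled out in the comments: it returns resume() with a delay, without claiming, when no work is available yet, and it stops (rather than outputting or resuming) once tryClaim() returns false. PollingReadFn, hasDataAt and readAt are hypothetical stand-ins for a real source; OffsetRangeTracker and ProcessContinuation are the SDK classes already used elsewhere in this patch.

import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
import org.joda.time.Duration;

class PollingReadFn extends DoFn<String, String> {
  @ProcessElement
  public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) {
    long position = tracker.currentRestriction().getFrom();
    if (!hasDataAt(position)) {
      // Nothing to claim yet: return resume() without calling tryClaim(), so the invoker
      // reschedules the original restriction instead of checkpointing an untouched tracker.
      return ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(10));
    }
    while (tracker.tryClaim(position)) {
      // Output only after a successful claim.
      c.output(readAt(position));
      ++position;
    }
    // tryClaim() returned false: no further output, and the call must return stop().
    return ProcessContinuation.stop();
  }

  @GetInitialRestriction
  public OffsetRange getInitialRestriction(String element) {
    return new OffsetRange(0, Long.MAX_VALUE);
  }

  // Hypothetical helpers standing in for a real source.
  private boolean hasDataAt(long position) { return position < 100; }
  private String readAt(long position) { return "record-" + position; }
}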
+ this.scheduledCheckpoint = + executor.schedule( + (Runnable) this::takeCheckpointNow, maxDuration.getMillis(), TimeUnit.MILLISECONDS); + } + ++numClaimedBlocks; + } + + @Override + public void onClaimFailed(PositionT position) { + checkState( + !hasClaimFailed, + "Must not call tryClaim() after it has previously returned false"); + hasClaimFailed = true; } void cancelScheduledCheckpoint() { + if (scheduledCheckpoint == null) { + return; + } scheduledCheckpoint.cancel(true); try { Futures.getUnchecked(scheduledCheckpoint); @@ -275,9 +329,19 @@ public PaneInfo pane() { @Override public synchronized void updateWatermark(Instant watermark) { + // Updating the watermark without any claimed blocks is allowed. + // The watermark is a promise about the timestamps of output from future claimed blocks. + // Such a promise can be made even if there are no claimed blocks. E.g. imagine reading + // from a streaming source that currently has no new data: there are no blocks to claim, but + // we may still want to advance the watermark if we have information about what timestamps + // of future elements in the source will be like. lastReportedWatermark = watermark; } + synchronized Instant getLastReportedWatermark() { + return lastReportedWatermark; + } + @Override public PipelineOptions getPipelineOptions() { return pipelineOptions; @@ -290,8 +354,8 @@ public void output(OutputT output) { @Override public void outputWithTimestamp(OutputT value, Instant timestamp) { - output.outputWindowedValue(value, timestamp, element.getWindows(), element.getPane()); noteOutput(); + output.outputWindowedValue(value, timestamp, element.getWindows(), element.getPane()); } @Override @@ -301,12 +365,14 @@ public void output(TupleTag tag, T value) { @Override public void outputWithTimestamp(TupleTag tag, T value, Instant timestamp) { + noteOutput(); output.outputWindowedValue( tag, value, timestamp, element.getWindows(), element.getPane()); - noteOutput(); } private void noteOutput() { + checkState(!hasClaimFailed, "Output is not allowed after a failed tryClaim()"); + checkState(numClaimedBlocks > 0, "Output is not allowed before tryClaim()"); ++numOutputs; if (numOutputs >= maxNumOutputs) { takeCheckpointNow(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index 6ae6754812db..d4c5775464b0 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -262,7 +262,7 @@ public DoFn.OnTimerContext onTimerContext(DoFn } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException( "Cannot access RestrictionTracker outside of @ProcessElement method."); } @@ -332,7 +332,7 @@ public DoFn.OnTimerContext onTimerContext(DoFn } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException( "Cannot access RestrictionTracker outside of @ProcessElement method."); } @@ -504,7 +504,7 @@ public DoFn.OnTimerContext onTimerContext(DoFn } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException("RestrictionTracker parameters are not supported."); } @@ -615,7 +615,7 @@ public DoFn.OnTimerContext onTimerContext(DoFn 
} @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException("RestrictionTracker parameters are not supported."); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java index 4e490e27c267..ff238bee2f73 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java @@ -150,7 +150,7 @@ public PCollectionTuple expand(PCollection>> /** A primitive transform wrapping around {@link ProcessFn}. */ public static class ProcessElements< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> extends PTransform< PCollection>>, PCollectionTuple> { private final ProcessKeyedElements original; @@ -211,7 +211,7 @@ public PCollectionTuple expand( */ @VisibleForTesting public static class ProcessFn< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> extends DoFn>, OutputT> { /** * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java index 5b9cbf2bfe71..9d5475a8a33a 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java @@ -31,7 +31,7 @@ * DoFn}, in particular, allowing the runner to access the {@link RestrictionTracker}. */ public abstract class SplittableProcessElementInvoker< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> { + InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> { /** Specifies how to resume a splittable {@link DoFn.ProcessElement} call. */ public class Result { @Nullable diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/TimerInternals.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/TimerInternals.java index f4a12d089116..38fc35242035 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/TimerInternals.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/TimerInternals.java @@ -175,6 +175,8 @@ abstract class TimerData implements Comparable { public abstract TimeDomain getDomain(); + // When adding a new field, make sure to add it to the compareTo() method. + /** * Construct a {@link TimerData} for the given parameters, where the timer ID is automatically * generated. @@ -201,8 +203,9 @@ public static TimerData of(StateNamespace namespace, Instant timestamp, TimeDoma /** * {@inheritDoc}. * - *

<p>The ordering of {@link TimerData} that are not in the same namespace or domain is
- * arbitrary.
+ * <p>
Used for sorting {@link TimerData} by timestamp. Furthermore, we compare timers by all the + * other fields so that {@code compareTo()} only returns 0 when {@code equals()} returns 0. + * This ensures consistent sort order. */ @Override public int compareTo(TimerData that) { @@ -212,7 +215,8 @@ public int compareTo(TimerData that) { ComparisonChain chain = ComparisonChain.start() .compare(this.getTimestamp(), that.getTimestamp()) - .compare(this.getDomain(), that.getDomain()); + .compare(this.getDomain(), that.getDomain()) + .compare(this.getTimerId(), that.getTimerId()); if (chain.result() == 0 && !this.getNamespace().equals(that.getNamespace())) { // Obtaining the stringKey may be expensive; only do so if required chain = chain.compare(getNamespace().stringKey(), that.getNamespace().stringKey()); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java index 959909e6690e..991b9299fb8f 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java @@ -27,8 +27,10 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.common.util.concurrent.Uninterruptibles; import java.util.Collection; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.DoFn; @@ -41,26 +43,38 @@ import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; /** Tests for {@link OutputAndTimeBoundedSplittableProcessElementInvoker}. 
*/ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest { - private static class SomeFn extends DoFn { + @Rule + public transient ExpectedException e = ExpectedException.none(); + + private static class SomeFn extends DoFn { + private final Duration sleepBeforeFirstClaim; private final int numOutputsPerProcessCall; private final Duration sleepBeforeEachOutput; - private SomeFn(int numOutputsPerProcessCall, Duration sleepBeforeEachOutput) { + private SomeFn( + Duration sleepBeforeFirstClaim, + int numOutputsPerProcessCall, + Duration sleepBeforeEachOutput) { + this.sleepBeforeFirstClaim = sleepBeforeFirstClaim; this.numOutputsPerProcessCall = numOutputsPerProcessCall; this.sleepBeforeEachOutput = sleepBeforeEachOutput; } @ProcessElement - public ProcessContinuation process(ProcessContext context, OffsetRangeTracker tracker) - throws Exception { + public ProcessContinuation process(ProcessContext context, OffsetRangeTracker tracker) { + Uninterruptibles.sleepUninterruptibly( + sleepBeforeFirstClaim.getMillis(), TimeUnit.MILLISECONDS); for (long i = tracker.currentRestriction().getFrom(), numIterations = 1; tracker.tryClaim(i); ++i, ++numIterations) { - Thread.sleep(sleepBeforeEachOutput.getMillis()); + Uninterruptibles.sleepUninterruptibly( + sleepBeforeEachOutput.getMillis(), TimeUnit.MILLISECONDS); context.output("" + i); if (numIterations == numOutputsPerProcessCall) { return resume(); @@ -70,15 +84,25 @@ public ProcessContinuation process(ProcessContext context, OffsetRangeTracker tr } @GetInitialRestriction - public OffsetRange getInitialRestriction(Integer element) { + public OffsetRange getInitialRestriction(Void element) { throw new UnsupportedOperationException("Should not be called in this test"); } } - private SplittableProcessElementInvoker.Result - runTest(int totalNumOutputs, int numOutputsPerProcessCall, Duration sleepPerElement) { - SomeFn fn = new SomeFn(numOutputsPerProcessCall, sleepPerElement); - SplittableProcessElementInvoker invoker = + private SplittableProcessElementInvoker.Result + runTest( + int totalNumOutputs, + Duration sleepBeforeFirstClaim, + int numOutputsPerProcessCall, + Duration sleepBeforeEachOutput) { + SomeFn fn = new SomeFn(sleepBeforeFirstClaim, numOutputsPerProcessCall, sleepBeforeEachOutput); + OffsetRange initialRestriction = new OffsetRange(0, totalNumOutputs); + return runTest(fn, initialRestriction); + } + + private SplittableProcessElementInvoker.Result + runTest(DoFn fn, OffsetRange initialRestriction) { + SplittableProcessElementInvoker invoker = new OutputAndTimeBoundedSplittableProcessElementInvoker<>( fn, PipelineOptionsFactory.create(), @@ -105,14 +129,14 @@ public void outputWindowedValue( return invoker.invokeProcessElement( DoFnInvokers.invokerFor(fn), - WindowedValue.of(totalNumOutputs, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING), - new OffsetRangeTracker(new OffsetRange(0, totalNumOutputs))); + WindowedValue.of(null, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING), + new OffsetRangeTracker(initialRestriction)); } @Test public void testInvokeProcessElementOutputBounded() throws Exception { - SplittableProcessElementInvoker.Result res = - runTest(10000, Integer.MAX_VALUE, Duration.ZERO); + SplittableProcessElementInvoker.Result res = + runTest(10000, Duration.ZERO, Integer.MAX_VALUE, Duration.ZERO); assertFalse(res.getContinuation().shouldResume()); OffsetRange residualRange = res.getResidualRestriction(); // Should process the first 100 elements. 
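For readers less familiar with the checkpoint mechanics these assertions rely on: checkpointing a tracker truncates its current (primary) restriction to the work already claimed and hands back the rest as the residual. A standalone sketch with the SDK's OffsetRangeTracker; the specific numbers simply mirror the output-bounded case above and are otherwise arbitrary.

import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;

public class CheckpointSketch {
  public static void main(String[] args) {
    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0, 10000));

    // Claim the first 100 offsets, as the output-bounded test expects the DoFn to do
    // before the invoker requests a checkpoint.
    for (long i = 0; i < 100; ++i) {
      tracker.tryClaim(i);
    }

    // checkpoint() keeps [0, 100) as the primary restriction and returns [100, 10000)
    // as the residual, which the runner schedules for later processing.
    OffsetRange residual = tracker.checkpoint();
    System.out.println(tracker.currentRestriction()); // primary:  [0, 100)
    System.out.println(residual);                     // residual: [100, 10000)
  }
}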
@@ -122,8 +146,8 @@ public void testInvokeProcessElementOutputBounded() throws Exception { @Test public void testInvokeProcessElementTimeBounded() throws Exception { - SplittableProcessElementInvoker.Result res = - runTest(10000, Integer.MAX_VALUE, Duration.millis(100)); + SplittableProcessElementInvoker.Result res = + runTest(10000, Duration.ZERO, Integer.MAX_VALUE, Duration.millis(100)); assertFalse(res.getContinuation().shouldResume()); OffsetRange residualRange = res.getResidualRestriction(); // Should process ideally around 30 elements - but due to timing flakiness, we can't enforce @@ -133,19 +157,66 @@ public void testInvokeProcessElementTimeBounded() throws Exception { assertEquals(10000, residualRange.getTo()); } + @Test + public void testInvokeProcessElementTimeBoundedWithStartupDelay() throws Exception { + SplittableProcessElementInvoker.Result res = + runTest(10000, Duration.standardSeconds(3), Integer.MAX_VALUE, Duration.millis(100)); + assertFalse(res.getContinuation().shouldResume()); + OffsetRange residualRange = res.getResidualRestriction(); + // Same as above, but this time it counts from the time of the first tryClaim() call + assertThat(residualRange.getFrom(), greaterThan(10L)); + assertThat(residualRange.getFrom(), lessThan(100L)); + assertEquals(10000, residualRange.getTo()); + } + @Test public void testInvokeProcessElementVoluntaryReturnStop() throws Exception { - SplittableProcessElementInvoker.Result res = - runTest(5, Integer.MAX_VALUE, Duration.millis(100)); + SplittableProcessElementInvoker.Result res = + runTest(5, Duration.ZERO, Integer.MAX_VALUE, Duration.millis(100)); assertFalse(res.getContinuation().shouldResume()); assertNull(res.getResidualRestriction()); } @Test public void testInvokeProcessElementVoluntaryReturnResume() throws Exception { - SplittableProcessElementInvoker.Result res = - runTest(10, 5, Duration.millis(100)); + SplittableProcessElementInvoker.Result res = + runTest(10, Duration.ZERO, 5, Duration.millis(100)); assertTrue(res.getContinuation().shouldResume()); assertEquals(new OffsetRange(5, 10), res.getResidualRestriction()); } + + @Test + public void testInvokeProcessElementOutputDisallowedBeforeTryClaim() throws Exception { + DoFn brokenFn = new DoFn() { + @ProcessElement + public void process(ProcessContext c, OffsetRangeTracker tracker) { + c.output("foo"); + } + + @GetInitialRestriction + public OffsetRange getInitialRestriction(Void element) { + throw new UnsupportedOperationException("Should not be called in this test"); + } + }; + e.expectMessage("Output is not allowed before tryClaim()"); + runTest(brokenFn, new OffsetRange(0, 5)); + } + + @Test + public void testInvokeProcessElementOutputDisallowedAfterFailedTryClaim() throws Exception { + DoFn brokenFn = new DoFn() { + @ProcessElement + public void process(ProcessContext c, OffsetRangeTracker tracker) { + assertFalse(tracker.tryClaim(6L)); + c.output("foo"); + } + + @GetInitialRestriction + public OffsetRange getInitialRestriction(Void element) { + throw new UnsupportedOperationException("Should not be called in this test"); + } + }; + e.expectMessage("Output is not allowed after a failed tryClaim()"); + runTest(brokenFn, new OffsetRange(0, 5)); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java index c7bee2556fbe..b9fd0ab1f41d 100644 --- 
a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.core; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.hamcrest.Matchers.contains; @@ -83,13 +84,18 @@ public SomeRestrictionTracker newTracker() { } } - private static class SomeRestrictionTracker implements RestrictionTracker { + private static class SomeRestrictionTracker extends RestrictionTracker { private final SomeRestriction someRestriction; public SomeRestrictionTracker(SomeRestriction someRestriction) { this.someRestriction = someRestriction; } + @Override + protected boolean tryClaimImpl(Void position) { + return true; + } + @Override public SomeRestriction currentRestriction() { return someRestriction; @@ -108,11 +114,15 @@ public void checkDone() {} public TestPipeline pipeline = TestPipeline.create(); /** - * A helper for testing {@link ProcessFn} on 1 element (but - * possibly over multiple {@link DoFn.ProcessElement} calls). + * A helper for testing {@link ProcessFn} on 1 element (but possibly over multiple {@link + * DoFn.ProcessElement} calls). */ private static class ProcessFnTester< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, + OutputT, + RestrictionT, + PositionT, + TrackerT extends RestrictionTracker> implements AutoCloseable { private final DoFnTester>, OutputT> tester; private Instant currentProcessingTime; @@ -260,6 +270,7 @@ public void outputWindowedValue( private static class ToStringFn extends DoFn { @ProcessElement public void process(ProcessContext c, SomeRestrictionTracker tracker) { + checkState(tracker.tryClaim(null)); c.output(c.element().toString() + "a"); c.output(c.element().toString() + "b"); c.output(c.element().toString() + "c"); @@ -284,7 +295,7 @@ public void testTrivialProcessFnPropagatesOutputWindowAndTimestamp() throws Exce new IntervalWindow( base.minus(Duration.standardMinutes(1)), base.plus(Duration.standardMinutes(1))); - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( base, fn, @@ -332,7 +343,7 @@ public void testUpdatesWatermark() throws Exception { DoFn fn = new WatermarkUpdateFn(); Instant base = Instant.now(); - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( base, fn, @@ -358,6 +369,7 @@ public void testUpdatesWatermark() throws Exception { private static class SelfInitiatedResumeFn extends DoFn { @ProcessElement public ProcessContinuation process(ProcessContext c, SomeRestrictionTracker tracker) { + checkState(tracker.tryClaim(null)); c.output(c.element().toString()); return resume().withResumeDelay(Duration.standardSeconds(5)); } @@ -372,7 +384,7 @@ public SomeRestriction getInitialRestriction(Integer elem) { public void testResumeSetsTimer() throws Exception { DoFn fn = new SelfInitiatedResumeFn(); Instant base = Instant.now(); - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( base, fn, @@ -430,7 +442,7 @@ public OffsetRange getInitialRestriction(Integer elem) { public void testResumeCarriesOverState() throws Exception { DoFn fn = new CounterFn(1); Instant base = Instant.now(); - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( base, fn, @@ -459,7 +471,7 @@ 
public void testCheckpointsAfterNumOutputs() throws Exception { Instant base = Instant.now(); int baseIndex = 42; - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(OffsetRange.class), max, MAX_BUNDLE_DURATION); @@ -501,7 +513,7 @@ public void testCheckpointsAfterDuration() throws Exception { Instant base = Instant.now(); int baseIndex = 42; - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(OffsetRange.class), max, maxBundleDuration); @@ -566,7 +578,7 @@ public void finishBundle() { @Test public void testInvokesLifecycleMethods() throws Exception { DoFn fn = new LifecycleVerifyingFn(); - try (ProcessFnTester tester = + try (ProcessFnTester tester = new ProcessFnTester<>( Instant.now(), fn, diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java index 446a9f874ca8..36d7a66b539f 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java @@ -235,5 +235,5 @@ public void processElement( Integer currentValue = MoreObjects.firstNonNull(state.read(), 0); state.write(currentValue + 1); } - }; + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/TimerInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/TimerInternalsTest.java index af270d9fd075..471f8b1a32ce 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/TimerInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/TimerInternalsTest.java @@ -19,7 +19,6 @@ import static org.hamcrest.Matchers.comparesEqualTo; import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; import org.apache.beam.runners.core.TimerInternals.TimerData; @@ -63,43 +62,66 @@ public void testCoderIsSerializableWithWellKnownCoderType() { } @Test - public void testCompareTo() { + public void testCompareEqual() { + Instant timestamp = new Instant(100); + StateNamespace namespace = StateNamespaces.global(); + TimerData timer = TimerData.of("id", namespace, timestamp, TimeDomain.EVENT_TIME); + + assertThat(timer, + comparesEqualTo(TimerData.of("id", namespace, timestamp, TimeDomain.EVENT_TIME))); + } + + @Test + public void testCompareByTimestamp() { Instant firstTimestamp = new Instant(100); Instant secondTimestamp = new Instant(200); - IntervalWindow firstWindow = new IntervalWindow(new Instant(0), firstTimestamp); - IntervalWindow secondWindow = new IntervalWindow(firstTimestamp, secondTimestamp); + StateNamespace namespace = StateNamespaces.global(); + + TimerData firstTimer = TimerData.of(namespace, firstTimestamp, TimeDomain.EVENT_TIME); + TimerData secondTimer = TimerData.of(namespace, secondTimestamp, TimeDomain.EVENT_TIME); + + assertThat(firstTimer, lessThan(secondTimer)); + } + + @Test + public void testCompareByDomain() { + Instant timestamp = new Instant(100); + StateNamespace namespace = StateNamespaces.global(); + + TimerData eventTimer = TimerData.of(namespace, timestamp, TimeDomain.EVENT_TIME); + TimerData procTimer = TimerData.of(namespace, timestamp, TimeDomain.PROCESSING_TIME); + TimerData synchronizedProcTimer = + TimerData.of(namespace, 
timestamp, TimeDomain.SYNCHRONIZED_PROCESSING_TIME); + + assertThat(eventTimer, lessThan(procTimer)); + assertThat(eventTimer, lessThan(synchronizedProcTimer)); + assertThat(procTimer, lessThan(synchronizedProcTimer)); + } + + @Test + public void testCompareByNamespace() { + Instant timestamp = new Instant(100); + IntervalWindow firstWindow = new IntervalWindow(new Instant(0), timestamp); + IntervalWindow secondWindow = new IntervalWindow(timestamp, new Instant(200)); Coder windowCoder = IntervalWindow.getCoder(); StateNamespace firstWindowNs = StateNamespaces.window(windowCoder, firstWindow); StateNamespace secondWindowNs = StateNamespaces.window(windowCoder, secondWindow); - TimerData firstEventTime = TimerData.of(firstWindowNs, firstTimestamp, TimeDomain.EVENT_TIME); - TimerData secondEventTime = TimerData.of(firstWindowNs, secondTimestamp, TimeDomain.EVENT_TIME); - TimerData thirdEventTime = TimerData.of(secondWindowNs, secondTimestamp, TimeDomain.EVENT_TIME); - - TimerData firstProcTime = - TimerData.of(firstWindowNs, firstTimestamp, TimeDomain.PROCESSING_TIME); - TimerData secondProcTime = - TimerData.of(firstWindowNs, secondTimestamp, TimeDomain.PROCESSING_TIME); - TimerData thirdProcTime = - TimerData.of(secondWindowNs, secondTimestamp, TimeDomain.PROCESSING_TIME); + TimerData secondEventTime = TimerData.of(firstWindowNs, timestamp, TimeDomain.EVENT_TIME); + TimerData thirdEventTime = TimerData.of(secondWindowNs, timestamp, TimeDomain.EVENT_TIME); - assertThat(firstEventTime, - comparesEqualTo(TimerData.of(firstWindowNs, firstTimestamp, TimeDomain.EVENT_TIME))); - assertThat(firstEventTime, lessThan(secondEventTime)); assertThat(secondEventTime, lessThan(thirdEventTime)); - assertThat(firstEventTime, lessThan(thirdEventTime)); - - assertThat(secondProcTime, - comparesEqualTo(TimerData.of(firstWindowNs, secondTimestamp, TimeDomain.PROCESSING_TIME))); - assertThat(firstProcTime, lessThan(secondProcTime)); - assertThat(secondProcTime, lessThan(thirdProcTime)); - assertThat(firstProcTime, lessThan(thirdProcTime)); - - assertThat(firstEventTime, not(comparesEqualTo(firstProcTime))); - assertThat(firstProcTime, - not(comparesEqualTo(TimerData.of(firstWindowNs, - firstTimestamp, - TimeDomain.SYNCHRONIZED_PROCESSING_TIME)))); + } + + @Test + public void testCompareByTimerId() { + Instant timestamp = new Instant(100); + StateNamespace namespace = StateNamespaces.global(); + + TimerData id0Timer = TimerData.of("id0", namespace, timestamp, TimeDomain.EVENT_TIME); + TimerData id1Timer = TimerData.of("id1", namespace, timestamp, TimeDomain.EVENT_TIME); + + assertThat(id0Timer, lessThan(id1Timer)); } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java index d0fc61b1eecc..f5f7cafedbeb 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java @@ -154,7 +154,8 @@ public void testDistributionCommittedUnsupportedInAttemptedAccumulatedMetricResu thrown.expect(UnsupportedOperationException.class); thrown.expectMessage("This runner does not currently support committed metrics results."); - assertDistribution(DISTRIBUTION_NAME, step1res, STEP1, DistributionResult.ZERO, true); + assertDistribution( + DISTRIBUTION_NAME, step1res, STEP1, DistributionResult.IDENTITY_ELEMENT, 
true); } @Test diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index f258048437ee..a9e553fdd97b 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -17,7 +17,7 @@ */ apply from: project(":").file("build_rules.gradle") -applyJavaNature() +applyJavaNature(artifactId: "beam-runners-direct-java") description = "Apache Beam :: Runners :: Direct Java" @@ -69,3 +69,6 @@ shadowJar { relocate "com.google.protobuf", getJavaRelocatedPath("com.google.protobuf") relocate "javax.annotation", getJavaRelocatedPath("javax.annotation") } + +// Generates :runners:direct-java:runQuickstartJavaDirect +createJavaQuickstartValidationTask(name: 'Direct') diff --git a/runners/direct-java/pom.xml b/runners/direct-java/pom.xml index 99ebd760c86e..a678e0149757 100644 --- a/runners/direct-java/pom.xml +++ b/runners/direct-java/pom.xml @@ -246,10 +246,16 @@ org.hamcrest - hamcrest-all + hamcrest-core provided + + org.hamcrest + hamcrest-library + provided + + junit junit @@ -264,7 +270,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java index 848bf712da80..1747a5372ce4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java @@ -300,7 +300,7 @@ public ValueState bindValue( underlying.get().get(namespace, address, c); return existingState.copy(); } else { - return new InMemoryValue<>(); + return new InMemoryValue<>(coder); } } @@ -317,7 +317,7 @@ CombiningState bindCombiningValue( underlying.get().get(namespace, address, c); return existingState.copy(); } else { - return new InMemoryCombiningState<>(combineFn); + return new InMemoryCombiningState<>(combineFn, accumCoder); } } @@ -331,7 +331,7 @@ public BagState bindBag( underlying.get().get(namespace, address, c); return existingState.copy(); } else { - return new InMemoryBag<>(); + return new InMemoryBag<>(elemCoder); } } @@ -345,7 +345,7 @@ public SetState bindSet( underlying.get().get(namespace, address, c); return existingState.copy(); } else { - return new InMemorySet<>(); + return new InMemorySet<>(elemCoder); } } @@ -361,7 +361,7 @@ public MapState bindMap( underlying.get().get(namespace, address, c); return existingState.copy(); } else { - return new InMemoryMap<>(); + return new InMemoryMap<>(mapKeyCoder, mapValueCoder); } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java index 574ab46fb449..af6730602c3e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.direct; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.options.ApplicationNameOptions; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; @@ -74,4 +76,10 @@ public Integer create(PipelineOptions options) { return Math.max(Runtime.getRuntime().availableProcessors(), MIN_PARALLELISM); } } + + 
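// A minimal, hypothetical usage sketch for the experimental option declared just below,
// assuming the DirectRunner gating that appears later in this patch; apart from DirectOptions,
// setProtoTranslation, and the standard SDK factory classes, the names here are illustrative only.
DirectOptions options = PipelineOptionsFactory.as(DirectOptions.class);
options.setRunner(DirectRunner.class);
options.setProtoTranslation(true); // opt in to the RunnerApi proto round-trip; @Default.Boolean(false) keeps it off otherwise
Pipeline pipeline = Pipeline.create(options);
pipeline.run().waitUntilFinish();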
@Experimental(Kind.CORE_RUNNERS_ONLY) + @Default.Boolean(false) + @Description("Control whether toProto/fromProto translations are applied to original Pipeline") + boolean isProtoTranslation(); + void setProtoTranslation(boolean b); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java index c12ff6763656..0f0caf6d8e26 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java @@ -30,6 +30,7 @@ */ public class DirectRegistrar { private DirectRegistrar() {} + /** * Registers the {@link DirectRunner}. */ diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 89f75a5e12ef..32ba69dc059b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -29,7 +29,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformMatchers; import org.apache.beam.runners.core.construction.PTransformTranslation; @@ -161,11 +160,15 @@ void setClockSupplier(Supplier supplier) { @Override public DirectPipelineResult run(Pipeline originalPipeline) { Pipeline pipeline; - try { - RunnerApi.Pipeline protoPipeline = PipelineTranslation.toProto(originalPipeline); - pipeline = PipelineTranslation.fromProto(protoPipeline); - } catch (IOException exception) { - throw new RuntimeException("Error preparing pipeline for direct execution.", exception); + if (getPipelineOptions().isProtoTranslation()) { + try { + pipeline = PipelineTranslation.fromProto( + PipelineTranslation.toProto(originalPipeline)); + } catch (IOException exception) { + throw new RuntimeException("Error preparing pipeline for direct execution.", exception); + } + } else { + pipeline = originalPipeline; } pipeline.replaceAll(defaultTransformOverrides()); MetricsEnvironment.setMetricsSupported(true); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index 44df845e908b..f4c489544b24 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -42,7 +42,11 @@ import org.joda.time.Instant; class SplittableProcessElementsEvaluatorFactory< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, + OutputT, + RestrictionT, + PositionT, + TrackerT extends RestrictionTracker> implements TransformEvaluatorFactory { private final ParDoEvaluatorFactory>, OutputT> delegateFactory; diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java index 708a9315dcc6..ac5b24a21849 100644 --- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java @@ -19,10 +19,10 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static org.apache.beam.runners.core.construction.PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN; import static org.apache.beam.runners.core.construction.PTransformTranslation.FLATTEN_TRANSFORM_URN; import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN; import static org.apache.beam.runners.core.construction.PTransformTranslation.READ_TRANSFORM_URN; -import static org.apache.beam.runners.core.construction.PTransformTranslation.WINDOW_TRANSFORM_URN; import static org.apache.beam.runners.core.construction.SplittableParDo.SPLITTABLE_PROCESS_URN; import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GABW_URN; import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GBKO_URN; @@ -68,7 +68,7 @@ public static TransformEvaluatorRegistry defaultRegistry(EvaluationContext ctxt) ParDoEvaluator.defaultRunnerFactory(), ParDoEvaluatorFactory.basicDoFnCacheLoader())) .put(FLATTEN_TRANSFORM_URN, new FlattenEvaluatorFactory(ctxt)) - .put(WINDOW_TRANSFORM_URN, new WindowEvaluatorFactory(ctxt)) + .put(ASSIGN_WINDOWS_TRANSFORM_URN, new WindowEvaluatorFactory(ctxt)) // Runner-specific primitives .put(DIRECT_WRITE_VIEW_URN, new ViewEvaluatorFactory(ctxt)) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java index 24f29779e5a5..d3f8136e33c0 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java @@ -45,15 +45,12 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} * for the {@link Unbounded Read.Unbounded} primitive {@link PTransform}. */ class UnboundedReadEvaluatorFactory implements TransformEvaluatorFactory { - private static final Logger LOG = LoggerFactory.getLogger(UnboundedReadEvaluatorFactory.class); // Occasionally close an existing reader and resume from checkpoint, to exercise close-and-resume private static final double DEFAULT_READER_REUSE_CHANCE = 0.95; @@ -140,8 +137,7 @@ public void processElement( } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance()); Instant watermark = reader.getWatermark(); - CheckpointMarkT finishedCheckpoint = finishRead(reader, shard); - UnboundedSourceShard residual; + CheckpointMarkT finishedCheckpoint = finishRead(reader, watermark, shard); // Sometimes resume from a checkpoint even if it's not required if (ThreadLocalRandom.current().nextDouble(1.0) >= readerReuseChance) { UnboundedReader toClose = reader; @@ -150,29 +146,36 @@ public void processElement( // if the call to close throws an IOException. 
reader = null; toClose.close(); - residual = - UnboundedSourceShard.of( - shard.getSource(), shard.getDeduplicator(), null, finishedCheckpoint); - } else { - residual = shard.withCheckpoint(finishedCheckpoint); } + UnboundedSourceShard residual = UnboundedSourceShard.of( + shard.getSource(), shard.getDeduplicator(), reader, finishedCheckpoint); resultBuilder .addOutput(output) .addUnprocessedElements( Collections.singleton( WindowedValue.timestampedValueInGlobalWindow(residual, watermark))); - } else if (reader.getWatermark().isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) { - // If the reader had no elements available, but the shard is not done, reuse it later - resultBuilder.addUnprocessedElements( - Collections.>singleton( - WindowedValue.timestampedValueInGlobalWindow( - UnboundedSourceShard.of( - shard.getSource(), - shard.getDeduplicator(), - reader, - shard.getCheckpoint()), - reader.getWatermark()))); + } else { + Instant watermark = reader.getWatermark(); + if (watermark.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + // If the reader had no elements available, but the shard is not done, reuse it later + // Might be better to finalize old checkpoint. + resultBuilder.addUnprocessedElements( + Collections.>singleton( + WindowedValue.timestampedValueInGlobalWindow( + UnboundedSourceShard.of( + shard.getSource(), + shard.getDeduplicator(), + reader, + shard.getCheckpoint()), + watermark))); + } else { + // End of input. Close the reader after finalizing old checkpoint. + shard.getCheckpoint().finalizeCheckpoint(); + UnboundedReader toClose = reader; + reader = null; // Avoid double close below in case of an exception. + toClose.close(); + } } } catch (IOException e) { if (reader != null) { @@ -209,11 +212,13 @@ private boolean startReader( } /** - * Checkpoint the current reader, finalize the previous checkpoint, and return the residual - * {@link UnboundedSourceShard}. + * Checkpoint the current reader, finalize the previous checkpoint, and return the current + * checkpoint. */ private CheckpointMarkT finishRead( - UnboundedReader reader, UnboundedSourceShard shard) + UnboundedReader reader, + Instant watermark, + UnboundedSourceShard shard) throws IOException { final CheckpointMark oldMark = shard.getCheckpoint(); @SuppressWarnings("unchecked") @@ -224,7 +229,7 @@ private CheckpointMarkT finishRead( // If the watermark is the max value, this source may not be invoked again. Finalize after // committing the output. 
- if (!reader.getWatermark().isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + if (!watermark.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) { PCollection outputPc = (PCollection) Iterables.getOnlyElement(transform.getOutputs().values()); evaluationContext.scheduleAfterOutputWouldBeProduced( @@ -277,10 +282,6 @@ static UnboundedSourceShard withCheckpoint(CheckpointT newCheckpoint) { - return of(getSource(), getDeduplicator(), getExistingReader(), newCheckpoint); - } } static class InputProvider diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java index 968c5ebc2591..566e5b88f2ba 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java @@ -173,6 +173,7 @@ public void keyedBundleDecodeFailsAddFails() { } static class Record {} + static class RecordNoEncodeCoder extends AtomicCoder { @Override diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java index c23b0f03778b..f34bb0c34b68 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java @@ -101,7 +101,7 @@ public void testApplyCommittedNoFilter() { committedMetricsResult("ns1", "name2", "step1", 12L), committedMetricsResult("ns1", "name1", "step2", 7L))); assertThat(results.distributions(), contains( - attemptedMetricsResult("ns1", "name1", "step1", DistributionResult.ZERO))); + attemptedMetricsResult("ns1", "name1", "step1", DistributionResult.IDENTITY_ELEMENT))); assertThat(results.distributions(), contains( committedMetricsResult("ns1", "name1", "step1", DistributionResult.create(12, 3, 3, 5)))); assertThat(results.gauges(), contains( diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java index ab5e8c9c5495..006613801f3d 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java @@ -166,7 +166,6 @@ public void inputBundleProcessesEachElementFinishesAndCompletes() throws Excepti @Override public void processElement(WindowedValue element) throws Exception { elementsProcessed.add(element); - return; } @Override diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java index a9f106442ba3..2aa4ab1e923d 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java @@ -41,15 +41,12 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Tests for {@link FlattenEvaluatorFactory}. - */ +/** Tests for {@link FlattenEvaluatorFactory}. 
*/ @RunWith(JUnit4.class) public class FlattenEvaluatorFactoryTest { private BundleFactory bundleFactory = ImmutableListBundleFactory.create(); - @Rule - public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); + @Rule public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); @Test public void testFlattenInMemoryEvaluator() throws Exception { @@ -59,10 +56,8 @@ public void testFlattenInMemoryEvaluator() throws Exception { PCollection flattened = list.apply(Flatten.pCollections()); - CommittedBundle leftBundle = - bundleFactory.createBundle(left).commit(Instant.now()); - CommittedBundle rightBundle = - bundleFactory.createBundle(right).commit(Instant.now()); + CommittedBundle leftBundle = bundleFactory.createBundle(left).commit(Instant.now()); + CommittedBundle rightBundle = bundleFactory.createBundle(right).commit(Instant.now()); EvaluationContext context = mock(EvaluationContext.class); @@ -82,9 +77,9 @@ public void testFlattenInMemoryEvaluator() throws Exception { rightSideEvaluator.processElement(WindowedValue.valueInGlobalWindow(-1)); leftSideEvaluator.processElement( WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1024))); - leftSideEvaluator.processElement(WindowedValue.valueInEmptyWindows(4, PaneInfo.NO_FIRING)); + leftSideEvaluator.processElement(WindowedValue.valueInGlobalWindow(4, PaneInfo.NO_FIRING)); rightSideEvaluator.processElement( - WindowedValue.valueInEmptyWindows(2, PaneInfo.ON_TIME_AND_ONLY_FIRING)); + WindowedValue.valueInGlobalWindow(2, PaneInfo.ON_TIME_AND_ONLY_FIRING)); rightSideEvaluator.processElement( WindowedValue.timestampedValueInGlobalWindow(-4, new Instant(-4096))); @@ -104,12 +99,12 @@ public void testFlattenInMemoryEvaluator() throws Exception { flattenedLeftBundle.commit(Instant.now()).getElements(), containsInAnyOrder( WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1024)), - WindowedValue.valueInEmptyWindows(4, PaneInfo.NO_FIRING), + WindowedValue.valueInGlobalWindow(4, PaneInfo.NO_FIRING), WindowedValue.valueInGlobalWindow(1))); assertThat( flattenedRightBundle.commit(Instant.now()).getElements(), containsInAnyOrder( - WindowedValue.valueInEmptyWindows(2, PaneInfo.ON_TIME_AND_ONLY_FIRING), + WindowedValue.valueInGlobalWindow(2, PaneInfo.ON_TIME_AND_ONLY_FIRING), WindowedValue.timestampedValueInGlobalWindow(-4, new Instant(-4096)), WindowedValue.valueInGlobalWindow(-1))); } @@ -141,5 +136,4 @@ public void testFlattenInMemoryEvaluatorWithEmptyPCollectionList() throws Except leftSideResult.getTransform(), Matchers.>equalTo(flattendProducer)); } - } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java index adf9bba6dd84..d4cbe64f3587 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java @@ -26,7 +26,10 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ContiguousSet; @@ -272,10 +275,13 @@ public void 
noElementsAvailableReaderIncludedInResidual() throws Exception { } @Test - public void evaluatorReusesReader() throws Exception { - ContiguousSet elems = ContiguousSet.create(Range.closed(0L, 20L), DiscreteDomain.longs()); + public void evaluatorReusesReaderAndClosesAtTheEnd() throws Exception { + int numElements = 1000; + ContiguousSet elems = ContiguousSet.create( + Range.openClosed(0L, Long.valueOf(numElements)), DiscreteDomain.longs()); TestUnboundedSource source = new TestUnboundedSource<>(BigEndianLongCoder.of(), elems.toArray(new Long[0])); + source.advanceWatermarkToInfinity = true; PCollection pcollection = p.apply(Read.from(source)); DirectGraph graph = DirectGraphs.getGraph(p); @@ -283,7 +289,7 @@ public void evaluatorReusesReader() throws Exception { graph.getProducer(pcollection); when(context.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); - UncommittedBundle output = bundleFactory.createBundle(pcollection); + UncommittedBundle output = mock(UncommittedBundle.class); when(context.createBundle(pcollection)).thenReturn(output); WindowedValue> shard = @@ -297,23 +303,23 @@ public void evaluatorReusesReader() throws Exception { UnboundedReadEvaluatorFactory factory = new UnboundedReadEvaluatorFactory(context, 1.0 /* Always reuse */); new UnboundedReadEvaluatorFactory.InputProvider(context).getInitialInputs(sourceTransform, 1); - TransformEvaluator> evaluator = - factory.forApplication(sourceTransform, inputBundle); - evaluator.processElement(shard); - TransformResult> result = - evaluator.finishBundle(); - CommittedBundle> residual = - inputBundle.withElements( - (Iterable>>) - result.getUnprocessedElements()); + CommittedBundle> residual = inputBundle; - TransformEvaluator> secondEvaluator = + do { + TransformEvaluator> evaluator = factory.forApplication(sourceTransform, residual); - secondEvaluator.processElement(Iterables.getOnlyElement(residual.getElements())); - secondEvaluator.finishBundle(); + evaluator.processElement(Iterables.getOnlyElement(residual.getElements())); + TransformResult> result = + evaluator.finishBundle(); + residual = inputBundle.withElements( + (Iterable>>) + result.getUnprocessedElements()); + } while (!Iterables.isEmpty(residual.getElements())); - assertThat(TestUnboundedSource.readerClosedCount, equalTo(0)); + verify(output, times((numElements))).add(any()); + assertThat(TestUnboundedSource.readerCreatedCount, equalTo(1)); + assertThat(TestUnboundedSource.readerClosedCount, equalTo(1)); } @Test @@ -412,20 +418,23 @@ public Instant apply(Long input) { private static class TestUnboundedSource extends UnboundedSource { private static int getWatermarkCalls = 0; + static int readerCreatedCount; static int readerClosedCount; static int readerAdvancedCount; private final Coder coder; private final List elems; private boolean dedupes = false; + private boolean advanceWatermarkToInfinity = false; // After reaching end of input. private boolean throwOnClose; public TestUnboundedSource(Coder coder, T... 
elems) { this(coder, false, Arrays.asList(elems)); } - private TestUnboundedSource(Coder coder, boolean throwOnClose, List elems) { - readerAdvancedCount = 0; + private TestUnboundedSource(Coder coder, boolean throwOnClose, List elems) { + readerCreatedCount = 0; readerClosedCount = 0; + readerAdvancedCount = 0; this.coder = coder; this.elems = elems; this.throwOnClose = throwOnClose; @@ -443,6 +452,7 @@ public UnboundedSource.UnboundedReader createReader( checkState( checkpointMark == null || checkpointMark.decoded, "Cannot resume from a checkpoint that has not been decoded"); + readerCreatedCount++; return new TestUnboundedReader(elems, checkpointMark == null ? -1 : checkpointMark.index); } @@ -494,7 +504,12 @@ public boolean advance() throws IOException { @Override public Instant getWatermark() { getWatermarkCalls++; - return new Instant(index + getWatermarkCalls); + if (index + 1 == elems.size() + && TestUnboundedSource.this.advanceWatermarkToInfinity) { + return BoundedWindow.TIMESTAMP_MAX_VALUE; + } else { + return new Instant(index + getWatermarkCalls); + } } @Override diff --git a/runners/flink/build.gradle b/runners/flink/build.gradle index 7707ffca548b..dc77bcfe5c83 100644 --- a/runners/flink/build.gradle +++ b/runners/flink/build.gradle @@ -19,7 +19,7 @@ import groovy.json.JsonOutput apply from: project(":").file("build_rules.gradle") -applyJavaNature() +applyJavaNature(artifactId: "beam-runners-flink_2.11") description = "Apache Beam :: Runners :: Flink" @@ -117,3 +117,6 @@ task validatesRunner { dependsOn validatesRunnerBatch dependsOn validatesRunnerStreaming } + +// Generates :runners:flink:runQuickstartJavaFlinkLocal +createJavaQuickstartValidationTask(name: 'FlinkLocal') diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml index ae135a331381..46917e9a9ffe 100644 --- a/runners/flink/pom.xml +++ b/runners/flink/pom.xml @@ -299,10 +299,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + junit junit @@ -311,7 +317,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index 126c81e9442e..789ebdec527b 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -115,7 +115,8 @@ class FlinkBatchTransformTranslators { TRANSLATORS.put(PTransformTranslation.FLATTEN_TRANSFORM_URN, new FlattenPCollectionTranslatorBatch()); - TRANSLATORS.put(PTransformTranslation.WINDOW_TRANSFORM_URN, new WindowAssignTranslatorBatch()); + TRANSLATORS.put( + PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, new WindowAssignTranslatorBatch()); TRANSLATORS.put(PTransformTranslation.PAR_DO_TRANSFORM_URN, new ParDoTranslatorBatch()); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java index 7a6c61f8b36d..7f7281e14bd9 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java @@ -93,14 +93,15 @@ public void translate(FlinkRunner flinkRunner, Pipeline pipeline) { throw new RuntimeException(e); } - 
pipeline.replaceAll(FlinkTransformOverrides.getDefaultOverrides(options.isStreaming())); - PipelineTranslationOptimizer optimizer = new PipelineTranslationOptimizer(TranslationMode.BATCH, options); optimizer.translate(pipeline); TranslationMode translationMode = optimizer.getTranslationMode(); + pipeline.replaceAll(FlinkTransformOverrides.getDefaultOverrides( + translationMode == TranslationMode.STREAMING)); + FlinkPipelineTranslator translator; if (translationMode == TranslationMode.STREAMING) { this.flinkStreamEnv = createStreamExecutionEnvironment(); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 811c15940c1f..d39b5c1e9672 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -58,7 +58,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; @@ -130,7 +129,8 @@ class FlinkStreamingTransformTranslators { SPLITTABLE_PROCESS_URN, new SplittableProcessElementsStreamingTranslator()); TRANSLATORS.put(SplittableParDo.SPLITTABLE_GBKIKWI_URN, new GBKIntoKeyedWorkItemsTranslator()); - TRANSLATORS.put(PTransformTranslation.WINDOW_TRANSFORM_URN, new WindowAssignTranslator()); + TRANSLATORS.put( + PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, new WindowAssignTranslator()); TRANSLATORS.put( PTransformTranslation.FLATTEN_TRANSFORM_URN, new FlattenPCollectionTranslator()); TRANSLATORS.put( @@ -253,7 +253,7 @@ void translateNode( if (context.getOutput(transform).isBounded().equals(PCollection.IsBounded.BOUNDED)) { boundedTranslator.translateNode(transform, context); } else { - unboundedTranslator.translateNode((Read.Unbounded) transform, context); + unboundedTranslator.translateNode(transform, context); } } } @@ -625,7 +625,7 @@ public void translateNode( } private static class SplittableProcessElementsStreamingTranslator< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< SplittableParDoViaKeyedWorkItems.ProcessElements> { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java index 3acc3eafca13..8877f1a044ac 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java @@ -17,9 +17,11 @@ */ package org.apache.beam.runners.flink; -import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -60,13 +62,21 @@ 
public void leaveCompositeTransform(TransformHierarchy.Node node) {} @Override public void visitPrimitiveTransform(TransformHierarchy.Node node) { - Class transformClass = node.getTransform().getClass(); - if (transformClass == Read.Unbounded.class) { + AppliedPTransform appliedPTransform = node.toAppliedPTransform(getPipeline()); + if (hasUnboundedOutput(appliedPTransform)) { + Class transformClass = node.getTransform().getClass(); LOG.info("Found {}. Switching to streaming execution.", transformClass); translationMode = TranslationMode.STREAMING; } } + private boolean hasUnboundedOutput(AppliedPTransform transform) { + return transform.getOutputs().values().stream() + .filter(value -> value instanceof PCollection) + .map(value -> (PCollection) value) + .anyMatch(collection -> collection.isBounded() == IsBounded.UNBOUNDED); + } + @Override public void visitValue(PValue value, TransformHierarchy.Node producer) {} } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index de3c0546b494..41a35ce05386 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -276,8 +276,10 @@ public void open() throws Exception { keyedStateInternals = new FlinkStateInternals<>((KeyedStateBackend) getKeyedStateBackend(), keyCoder); - timerService = (HeapInternalTimerService) - getInternalTimerService("beam-timer", new CoderTypeSerializer<>(timerCoder), this); + if (timerService == null) { + timerService = (HeapInternalTimerService) + getInternalTimerService("beam-timer", new CoderTypeSerializer<>(timerCoder), this); + } timerInternals = new FlinkTimerInternals(); @@ -376,12 +378,10 @@ public void close() throws Exception { nonKeyedStateInternals.state(StateNamespaces.global(), pushedBackTag); Iterable> pushedBackContents = pushedBack.read(); - if (pushedBackContents != null) { - if (!Iterables.isEmpty(pushedBackContents)) { - String pushedBackString = Joiner.on(",").join(pushedBackContents); - throw new RuntimeException( - "Leftover pushed-back data: " + pushedBackString + ". This indicates a bug."); - } + if (pushedBackContents != null && !Iterables.isEmpty(pushedBackContents)) { + String pushedBackString = Joiner.on(",").join(pushedBackContents); + throw new RuntimeException( + "Leftover pushed-back data: " + pushedBackString + ". 
This indicates a bug."); } } } @@ -730,11 +730,15 @@ public void initializeState(StateInitializationContext context) throws Exception // We just initialize our timerService if (keyCoder != null) { if (timerService == null) { - timerService = new HeapInternalTimerService<>( - totalKeyGroups, - localKeyGroupRange, - this, - getRuntimeContext().getProcessingTimeService()); + final HeapInternalTimerService localService = + new HeapInternalTimerService<>( + totalKeyGroups, + localKeyGroupRange, + this, + getRuntimeContext().getProcessingTimeService()); + localService.startTimerService(getKeyedStateBackend().getKeySerializer(), + new CoderTypeSerializer<>(timerCoder), this); + timerService = localService; } timerService.restoreTimersForKeyGroup(div, keyGroupIdx, getUserCodeClassloader()); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java index d0d283044ee1..1a418a036986 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java @@ -55,7 +55,7 @@ * the {@code @ProcessElement} method of a splittable {@link DoFn}. */ public class SplittableDoFnOperator< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> extends DoFnOperator>, OutputT> { private transient ScheduledExecutorService executorService; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 2a506e9d07b6..4990d70106c5 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -1067,7 +1067,7 @@ public ReadableState contains(final T t) { namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).get(t); - return ReadableStates.immediate(result != null ? result : false); + return ReadableStates.immediate(result != null && result); } catch (Exception e) { throw new RuntimeException("Error contains value from state.", e); } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java new file mode 100644 index 000000000000..0e5ce144135e --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import java.io.Serializable; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link FlinkPipelineExecutionEnvironment}. + */ +@RunWith(JUnit4.class) +public class FlinkPipelineExecutionEnvironmentTest implements Serializable { + + @Test + public void shouldRecognizeAndTranslateStreamingPipeline() { + FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class); + options.setRunner(TestFlinkRunner.class); + options.setFlinkMaster("[auto]"); + + FlinkRunner flinkRunner = FlinkRunner.fromOptions(options); + FlinkPipelineExecutionEnvironment flinkEnv = new FlinkPipelineExecutionEnvironment(options); + Pipeline pipeline = Pipeline.create(); + + pipeline + .apply(GenerateSequence.from(0).withRate(1, Duration.standardSeconds(1))) + .apply(ParDo.of(new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output(Long.toString(c.element())); + } + })) + .apply(Window.into(FixedWindows.of(Duration.standardHours(1)))) + .apply(TextIO.write().withNumShards(1).withWindowedWrites().to("/dummy/path")); + + flinkEnv.translate(flinkRunner, pipeline); + + // no exception should be thrown + } + +} + + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java index 6c3204767d96..73a0a08f29c6 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java @@ -58,6 +58,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -609,6 +610,110 @@ public void testSideInputs(boolean keyed) throws Exception { } + @Test + public void testTimersRestore() throws Exception { + final Instant timerTimestamp = new Instant(1000); + final String outputMessage = "Timer fired"; + + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(new Duration(10_000))); + + DoFn fn = new DoFn() { + private static final String EVENT_TIMER_ID = "eventTimer"; + + @TimerId(EVENT_TIMER_ID) + private final TimerSpec eventTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + public void processElement(ProcessContext 
context, @TimerId(EVENT_TIMER_ID) Timer timer) { + timer.set(timerTimestamp); + } + + @OnTimer(EVENT_TIMER_ID) + public void onEventTime(OnTimerContext context) { + assertEquals( + "Timer timestamp must match set timestamp.", timerTimestamp, context.timestamp()); + context.outputWithTimestamp(outputMessage, context.timestamp()); + } + }; + + WindowedValue.FullWindowedValueCoder inputCoder = + WindowedValue.getFullCoder( + VarIntCoder.of(), + windowingStrategy.getWindowFn().windowCoder()); + + WindowedValue.FullWindowedValueCoder outputCoder = + WindowedValue.getFullCoder( + StringUtf8Coder.of(), + windowingStrategy.getWindowFn().windowCoder()); + + + TupleTag outputTag = new TupleTag<>("main-output"); + final CoderTypeSerializer> outputSerializer = new CoderTypeSerializer<>( + outputCoder); + + OneInputStreamOperatorTestHarness, WindowedValue> testHarness = + createTestHarness(windowingStrategy, fn, inputCoder, outputCoder, outputTag); + + testHarness.setup(outputSerializer); + + testHarness.open(); + + testHarness.processWatermark(0); + + IntervalWindow window1 = new IntervalWindow(new Instant(0), Duration.millis(10_000)); + + // this should register a timer + testHarness.processElement( + new StreamRecord<>(WindowedValue.of(13, new Instant(0), window1, PaneInfo.NO_FIRING))); + + assertThat( + this.stripStreamRecordFromWindowedValue(testHarness.getOutput()), + emptyIterable()); + + // snapshot and restore + final OperatorStateHandles snapshot = testHarness.snapshot(0, 0); + testHarness.close(); + + testHarness = createTestHarness(windowingStrategy, fn, inputCoder, outputCoder, outputTag); + testHarness.setup(outputSerializer); + testHarness.initializeState(snapshot); + testHarness.open(); + + // this must fire the timer + testHarness.processWatermark(timerTimestamp.getMillis() + 1); + + assertThat( + this.stripStreamRecordFromWindowedValue(testHarness.getOutput()), + contains( + WindowedValue.of( + outputMessage, new Instant(timerTimestamp), window1, PaneInfo.NO_FIRING))); + + testHarness.close(); + } + + private OneInputStreamOperatorTestHarness, WindowedValue> + createTestHarness(WindowingStrategy windowingStrategy, + DoFn fn, FullWindowedValueCoder inputCoder, + FullWindowedValueCoder outputCoder, TupleTag outputTag) throws Exception { + DoFnOperator doFnOperator = + new DoFnOperator<>( + fn, + "stepName", + inputCoder, + outputTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>(outputTag, outputCoder), + windowingStrategy, + new HashMap<>(), /* side-input mapping */ + Collections.emptyList(), /* side inputs */ + PipelineOptionsFactory.as(FlinkPipelineOptions.class), + VarIntCoder.of() /* key coder */); + + return new KeyedOneInputStreamOperatorTestHarness<>( + doFnOperator, WindowedValue::getValue, new CoderTypeInformation<>(VarIntCoder.of())); + } + /** * {@link TwoInputStreamOperatorTestHarness} support OperatorStateBackend, * but don't support KeyedStateBackend. So we just test sideInput of normal ParDo. 
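The new testTimersRestore test above is what motivates the timer-service changes to DoFnOperator earlier in this patch: open() now creates the HeapInternalTimerService only when one is not already present, and initializeState() starts the restored service before timers are re-registered. Below is a compressed sketch of the cycle that test drives. The operator, coders, createTestHarness helper, and fields such as outputSerializer, window1 and timerTimestamp are the ones defined in the hunk above; the harness/snapshot variable names and the generic parameters (inferred from the test's VarIntCoder input and StringUtf8Coder output) are assumptions, not part of the patch.

OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<String>> harness =
    createTestHarness(windowingStrategy, fn, inputCoder, outputCoder, outputTag);
harness.setup(outputSerializer);
harness.open();

// Processing an element sets an event-time timer inside the operator's timer service.
harness.processElement(
    new StreamRecord<>(WindowedValue.of(13, new Instant(0), window1, PaneInfo.NO_FIRING)));

// Snapshot the operator state (including the pending timer), shut the operator down,
// then build a fresh operator and restore it from that snapshot.
OperatorStateHandles snapshot = harness.snapshot(0L, 0L);
harness.close();

harness = createTestHarness(windowingStrategy, fn, inputCoder, outputCoder, outputTag);
harness.setup(outputSerializer);
harness.initializeState(snapshot); // restores and (re)starts the timer service
harness.open();                    // must not create a second timer service

// Advancing the watermark past the timer's timestamp fires the restored timer.
harness.processWatermark(timerTimestamp.getMillis() + 1);
harness.close();

If initializeState() failed to start the restored timer service, or open() replaced it with a fresh one, the watermark advance at the end would not produce the expected "Timer fired" output that the test asserts on.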
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializerTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializerTest.java index b0c40dee79e4..2e4928ca6349 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializerTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializerTest.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; - import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer.CoderTypeSerializerConfigSnapshot; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.Coder; diff --git a/runners/gcp/gcemd/build.gradle b/runners/gcp/gcemd/build.gradle index a32dc770c682..36d5974dac65 100644 --- a/runners/gcp/gcemd/build.gradle +++ b/runners/gcp/gcemd/build.gradle @@ -21,6 +21,20 @@ applyGoNature() description = "Apache Beam :: Runners :: Google Cloud Platform :: GCE metadata provisioning" +// Figure out why the golang plugin does not add a build dependency between projects. +// Without the line below, we get spurious errors about not being able to resolve +// "./github.com/apache/beam/sdks/go" +resolveBuildDependencies.dependsOn ":sdks:go:build" + +dependencies { + golang { + // TODO(herohde): use "./" prefix to prevent gogradle use base github path, for now. + // TODO(herohde): get the pkg subdirectory only, if possible. We spend mins pulling cmd/beamctl deps. + build name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir + test name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir + } +} + golang { packagePath = 'github.com/apache/beam/runners/gcp/gcemd' build { diff --git a/runners/gcp/gcsproxy/build.gradle b/runners/gcp/gcsproxy/build.gradle index 1a02afa522ef..e231e2bcedb5 100644 --- a/runners/gcp/gcsproxy/build.gradle +++ b/runners/gcp/gcsproxy/build.gradle @@ -21,6 +21,20 @@ applyGoNature() description = "Apache Beam :: Runners :: Google Cloud Platform :: GCS artifact proxy" +// Figure out why the golang plugin does not add a build dependency between projects. +// Without the line below, we get spurious errors about not being able to resolve +// "./github.com/apache/beam/sdks/go" +resolveBuildDependencies.dependsOn ":sdks:go:build" + +dependencies { + golang { + // TODO(herohde): use "./" prefix to prevent gogradle use base github path, for now. + // TODO(herohde): get the pkg subdirectory only, if possible. We spend mins pulling cmd/beamctl deps. 
+ build name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir + test name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir + } +} + golang { packagePath = 'github.com/apache/beam/runners/gcp/gcsproxy' build { diff --git a/runners/gearpump/build.gradle b/runners/gearpump/build.gradle index 701a809b5384..cd96dc4cea04 100644 --- a/runners/gearpump/build.gradle +++ b/runners/gearpump/build.gradle @@ -17,7 +17,7 @@ */ apply from: project(":").file("build_rules.gradle") -applyJavaNature() +applyJavaNature(artifactId: "beam-runners-gearpump") description = "Apache Beam :: Runners :: Gearpump" diff --git a/runners/gearpump/pom.xml b/runners/gearpump/pom.xml index 0ea758ea8a0d..54348d93311c 100644 --- a/runners/gearpump/pom.xml +++ b/runners/gearpump/pom.xml @@ -175,7 +175,12 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test @@ -202,7 +207,7 @@ org.mockito - mockito-all + mockito-core test @@ -227,17 +232,6 @@ - - - maven-compiler-plugin - - 1.8 - 1.8 - 1.8 - 1.8 - - - org.apache.maven.plugins diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java index e02cbbc01a8d..aab6a531a54c 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java @@ -19,13 +19,10 @@ package org.apache.beam.runners.gearpump; import com.fasterxml.jackson.annotation.JsonIgnore; - import java.util.Map; - import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; - import org.apache.gearpump.cluster.client.ClientContext; import org.apache.gearpump.cluster.embedded.EmbeddedCluster; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java index 5febf3ccf3c5..b395c37f3854 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java @@ -19,18 +19,14 @@ import com.typesafe.config.Config; import com.typesafe.config.ConfigValueFactory; - import java.util.HashMap; import java.util.Map; - import org.apache.beam.runners.gearpump.translators.GearpumpPipelineTranslator; import org.apache.beam.runners.gearpump.translators.TranslationContext; - import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsValidator; - import org.apache.gearpump.cluster.ClusterConfig; import org.apache.gearpump.cluster.UserConfig; import org.apache.gearpump.cluster.client.ClientContext; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java index 0a8884930ac2..6f3941009ca9 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java @@ -24,7 +24,6 @@ import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.options.PipelineOptions; import 
org.apache.beam.sdk.options.PipelineOptionsValidator; - import org.apache.gearpump.cluster.ClusterConfig; import org.apache.gearpump.cluster.embedded.EmbeddedCluster; import org.apache.gearpump.util.Constants; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java index 559cb28dda4c..62e5e4432dcc 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.gearpump.translators; import java.util.List; - import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollectionView; import org.apache.gearpump.streaming.dsl.javaapi.JavaStream; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java index ca98aac35893..7d9379f4256a 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java @@ -19,11 +19,9 @@ package org.apache.beam.runners.gearpump.translators; import com.google.common.collect.ImmutableList; - import java.util.HashMap; import java.util.List; import java.util.Map; - import org.apache.beam.runners.core.construction.PTransformMatchers; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.Read; @@ -36,7 +34,6 @@ import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.PValue; - import org.apache.gearpump.util.Graph; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java index d92979b91d71..46cd99f842f1 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java @@ -22,7 +22,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; - import org.apache.beam.runners.gearpump.translators.functions.DoFnFunction; import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils; import org.apache.beam.sdk.transforms.DoFn; @@ -32,7 +31,6 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; - import org.apache.gearpump.streaming.dsl.api.functions.FilterFunction; import org.apache.gearpump.streaming.dsl.javaapi.JavaStream; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java index 0462c57e1f01..d0343f545760 100644 --- 
a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java @@ -22,7 +22,6 @@ import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.util.WindowedValue; - import org.apache.gearpump.streaming.dsl.javaapi.JavaStream; import org.apache.gearpump.streaming.source.DataSource; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java index c7becadc9eb1..4ad6cdf81583 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.gearpump.translators; import java.io.Serializable; - import org.apache.beam.sdk.transforms.PTransform; /** diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java index 42b7a536ab8c..255e4dc4ac7b 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java @@ -21,17 +21,14 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.common.collect.Iterables; - import java.util.HashMap; import java.util.Map; - import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.gearpump.GearpumpPipelineOptions; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.values.PValue; - import org.apache.beam.sdk.values.TupleTag; import org.apache.gearpump.cluster.UserConfig; import org.apache.gearpump.streaming.dsl.javaapi.JavaStream; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java index d144b958ae85..5f0cb6107e1e 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java @@ -19,12 +19,10 @@ package org.apache.beam.runners.gearpump.translators; import com.google.common.collect.Iterables; - import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; - import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java index fde265a83d1c..cdb0e054c60d 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java +++ 
b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java @@ -19,9 +19,7 @@ package org.apache.beam.runners.gearpump.translators.functions; import com.google.common.collect.Iterables; - import com.google.common.collect.Lists; - import java.io.Serializable; import java.util.Collection; import java.util.HashSet; @@ -30,7 +28,6 @@ import java.util.List; import java.util.Map; import java.util.Set; - import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.InMemoryStateInternals; import org.apache.beam.runners.core.PushbackSideInputDoFnRunner; @@ -45,7 +42,6 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; - import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java index 2c187355cbe5..ae26769aa6fb 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.gearpump.translators.io; import java.io.IOException; - import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source; import org.apache.beam.sdk.options.PipelineOptions; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java index 3766195f1b2a..dd5c8f6dd4af 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.time.Instant; - import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils; import org.apache.beam.sdk.io.Source; @@ -28,7 +27,6 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; - import org.apache.gearpump.DefaultMessage; import org.apache.gearpump.Message; import org.apache.gearpump.streaming.source.DataSource; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java index cb912c12d9c2..b9aa04b7f815 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.gearpump.translators.io; import java.io.IOException; - import org.apache.beam.sdk.io.Source; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; diff --git 
a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java index 6557c8bb66e1..db4736677c9a 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java @@ -21,7 +21,6 @@ import java.io.Serializable; import java.util.Collection; import java.util.List; - import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.PushbackSideInputDoFnRunner; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java index b795ed989948..23ec2120aa12 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.gearpump.translators.utils; import java.io.Serializable; - import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StepContext; import org.apache.beam.runners.core.TimerInternals; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java index 2dae9555e975..935470153dac 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java @@ -19,20 +19,17 @@ package org.apache.beam.runners.gearpump.translators.utils; import com.google.common.collect.Lists; - import java.time.Instant; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; - import org.apache.beam.runners.gearpump.translators.TranslationContext; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollectionView; - import org.apache.gearpump.streaming.dsl.api.functions.FoldFunction; import org.apache.gearpump.streaming.dsl.api.functions.MapFunction; import org.apache.gearpump.streaming.dsl.javaapi.JavaStream; diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java index 994856be4a58..bf2b527e5453 100644 --- a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java +++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java @@ -19,14 +19,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; + import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Maps; import com.typesafe.config.Config; - import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Map; - import 
org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.gearpump.cluster.ClusterConfig; diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java index 1115fad695a0..7b1a4cfeb15f 100644 --- a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java +++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java @@ -29,7 +29,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; - import org.apache.beam.runners.gearpump.GearpumpPipelineOptions; import org.apache.beam.runners.gearpump.translators.io.UnboundedSourceWrapper; import org.apache.beam.sdk.options.PipelineOptionsFactory; diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java index d5b931b78868..a5248f3b124f 100644 --- a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java +++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java @@ -23,11 +23,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; - import java.time.Instant; import java.util.Collection; import java.util.List; - import org.apache.beam.runners.gearpump.translators.GroupByKeyTranslator.GearpumpWindowFn; import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java index cc4284f54eee..6f1cdf716a12 100644 --- a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java +++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java @@ -19,11 +19,9 @@ package org.apache.beam.runners.gearpump.translators.io; import com.google.common.collect.Lists; - import java.io.IOException; import java.time.Instant; import java.util.List; - import org.apache.beam.runners.gearpump.GearpumpPipelineOptions; import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils; import org.apache.beam.sdk.coders.StringUtf8Coder; diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java index 439e1b18a0b8..e63ba78ee809 100644 --- a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java +++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java @@ -21,12 +21,10 @@ import com.google.common.collect.Sets; import com.typesafe.config.Config; import com.typesafe.config.ConfigValueFactory; - import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; - import org.apache.beam.runners.gearpump.GearpumpPipelineOptions; import 
org.apache.beam.runners.gearpump.GearpumpRunner; import org.apache.beam.sdk.Pipeline; diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java index 6ebe59bfc082..07e6da34dd20 100644 --- a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java +++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java @@ -22,10 +22,8 @@ import static org.junit.Assert.assertThat; import com.google.common.collect.Lists; - import java.time.Instant; import java.util.List; - import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index e0590203b1c2..3bd62c289671 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -17,7 +17,8 @@ */ apply from: project(":").file("build_rules.gradle") -applyJavaNature(enableFindbugs: false /* BEAM-925 */) +applyJavaNature(enableFindbugs: false /* BEAM-925 */, + artifactId: "beam-runners-google-cloud-dataflow-java") description = "Apache Beam :: Runners :: Google Cloud Dataflow" @@ -86,3 +87,8 @@ dependencies { test { systemProperties = [ "beamUseDummyRunner" : "true" ] } + +// Generates :runners:google-cloud-dataflow-java:runQuickstartJavaDataflow +def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' +def gcsBucket = project.findProperty('gcsBucket') ?: 'temp-storage-for-release-validation-tests/quickstart' +createJavaQuickstartValidationTask(name: 'Dataflow', gcpProject: gcpProject, gcsBucket: gcsBucket) diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index 7c427b205fb3..206bb040b9de 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -33,7 +33,7 @@ jar - beam-master-20180122 + beam-master-20180205 1 7 @@ -461,10 +461,16 @@ org.hamcrest - hamcrest-all + hamcrest-core provided - + + + org.hamcrest + hamcrest-library + provided + + junit junit @@ -479,7 +485,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index 1196a0647a80..a10472d1f451 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -120,6 +120,7 @@ public Map mapOutputs( return ReplacementOutputs.singleton(outputs, newOutput); } } + private static class MultiOutputOverrideFactory implements PTransformOverrideFactory< PCollection>, PCollectionTuple, ParDo.MultiOutput, OutputT>> { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java index 7c4df9f91dc3..87b3437ca849 100644 --- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java @@ -18,9 +18,9 @@ package org.apache.beam.runners.dataflow; import static com.google.common.base.Preconditions.checkState; -import static org.apache.beam.sdk.util.WindowedValue.valueInEmptyWindows; import com.google.common.base.Function; +import com.google.common.base.MoreObjects; import com.google.common.base.Optional; import com.google.common.collect.ForwardingMap; import com.google.common.collect.HashMultimap; @@ -34,11 +34,13 @@ import java.io.OutputStream; import java.io.Serializable; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import org.apache.beam.runners.dataflow.internal.IsmFormat; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecord; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecordCoder; @@ -66,6 +68,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.SystemDoFnInternal; @@ -80,74 +83,75 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; +import org.joda.time.Instant; /** * Dataflow batch overrides for {@link CreatePCollectionView}, specialized for different view types. */ class BatchViewOverrides { /** - * Specialized implementation for - * {@link org.apache.beam.sdk.transforms.View.AsMap View.AsMap} for the - * Dataflow runner in batch mode. + * Specialized implementation for {@link org.apache.beam.sdk.transforms.View.AsMap View.AsMap} for + * the Dataflow runner in batch mode. + * + *

<p>Creates a set of {@code Ism} files sharded by the hash of the key's byte representation. + * Each record is structured as follows: *
- * <p>Creates a set of {@code Ism} files sharded by the hash of the key's byte - * representation. Each record is structured as follows: * <ul>
- *   <li>Key 1: User key K</li> - *   <li>Key 2: Window</li> - *   <li>Key 3: 0L (constant)</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: User key K + *   <li>Key 2: Window + *   <li>Key 3: 0L (constant) + *   <li>Value: Windowed value
 * </ul> *
 * <p>Alongside the data records, there are the following metadata records: + *
 * <ul> - *   <li>Key 1: Metadata Key</li> - *   <li>Key 2: Window</li> - *   <li>Key 3: Index [0, size of map]</li> - *   <li>Value: variable length long byte representation of size of map if index is 0, - * otherwise the byte representation of a key</li>
+ *   <li>Key 1: Metadata Key + *   <li>Key 2: Window + *   <li>Key 3: Index [0, size of map] + *   <li>Value: variable length long byte representation of size of map if index is 0, otherwise + * the byte representation of a key * </ul>
- * The {@code [META, Window, 0]} record stores the number of unique keys per window, while - * {@code [META, Window, i]} for {@code i} in {@code [1, size of map]} stores a the users key. - * This allows for one to access the size of the map by looking at {@code [META, Window, 0]} - * and iterate over all the keys by accessing {@code [META, Window, i]} for {@code i} in - * {@code [1, size of map]}. *
- * <p>Note that in the case of a non-deterministic key coder, we fallback to using - * {@link org.apache.beam.sdk.transforms.View.AsSingleton View.AsSingleton} printing - * a warning to users to specify a deterministic key coder.
+ * <p>The {@code [META, Window, 0]} record stores the number of unique keys per window, while + * {@code [META, Window, i]} for {@code i} in {@code [1, size of map]} stores a the users key. + * This allows for one to access the size of the map by looking at {@code [META, Window, 0]} and + * iterate over all the keys by accessing {@code [META, Window, i]} for {@code i} in {@code [1, + * size of map]}. + * + * <p>
Note that in the case of a non-deterministic key coder, we fallback to using {@link + * org.apache.beam.sdk.transforms.View.AsSingleton View.AsSingleton} printing a warning to users + * to specify a deterministic key coder. */ - static class BatchViewAsMap - extends PTransform>, PCollection> { + static class BatchViewAsMap extends PTransform>, PCollection> { /** - * A {@link DoFn} which groups elements by window boundaries. For each group, - * the group of elements is transformed into a {@link TransformedMap}. - * The transformed {@code Map} is backed by a {@code Map>} - * and contains a function {@code WindowedValue -> V}. + * A {@link DoFn} which groups elements by window boundaries. For each group, the group of + * elements is transformed into a {@link TransformedMap}. The transformed {@code Map} is + * backed by a {@code Map>} and contains a function {@code WindowedValue + * -> V}. * *

<p>Outputs {@link IsmRecord}s having: + *
 * <ul> - *   <li>Key 1: Window</li> - *   <li>Value: Transformed map containing a transform that removes the encapsulation - * of the window around each value, - * {@code Map<K, WindowedValue<V>> -> Map<K, V>}.</li>
+ *   <li>Key 1: Window + *   <li>Value: Transformed map containing a transform that removes the encapsulation of the + * window around each value, {@code Map<K, WindowedValue<V>> -> Map<K, V>}.
 * </ul>
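As an illustration of the transformed map described above, a minimal sketch is shown here using a hypothetical Windowed wrapper and plain java.util types; it is not Beam's WindowedValue or the runner's ForwardingMap-based TransformedMap, only a read-only Map view that unwraps values on access.

import java.util.AbstractMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class UnwrappingMapViewSketch {

  /** Hypothetical stand-in for a windowed-value wrapper; only the raw value matters here. */
  static final class Windowed<V> {
    final V value;
    Windowed(V value) {
      this.value = value;
    }
  }

  /** Returns a read-only Map view that strips the Windowed wrapper on access. */
  static <K, V> Map<K, V> unwrappedView(Map<K, Windowed<V>> backing) {
    return new AbstractMap<K, V>() {
      @Override
      public Set<Map.Entry<K, V>> entrySet() {
        // Entries are rebuilt from the backing map on each call; lookups go through entrySet().
        Set<Map.Entry<K, V>> entries = new HashSet<>();
        for (Map.Entry<K, Windowed<V>> e : backing.entrySet()) {
          entries.add(new SimpleImmutableEntry<>(e.getKey(), e.getValue().value));
        }
        return entries;
      }
    };
  }

  public static void main(String[] args) {
    Map<String, Windowed<Integer>> backing = new HashMap<>();
    backing.put("a", new Windowed<>(1));
    backing.put("b", new Windowed<>(2));
    System.out.println(unwrappedView(backing).get("a")); // prints 1
  }
}

In the runner itself the view is backed by the original map plus a transform function, as the javadoc above states, so values are only unwrapped when they are read.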
*/ static class ToMapDoFn - extends DoFn>>>>, - IsmRecord, - V>>>> { + extends DoFn< + KV>>>>, + IsmRecord, V>>>> { private final Coder windowCoder; + ToMapDoFn(Coder windowCoder) { this.windowCoder = windowCoder; } @ProcessElement - public void processElement(ProcessContext c) - throws Exception { + public void processElement(ProcessContext c) throws Exception { Optional previousWindowStructuralValue = Optional.absent(); Optional previousWindow = Optional.absent(); Map> map = new HashMap<>(); @@ -165,12 +169,14 @@ public void processElement(ProcessContext c) } // Verify that the user isn't trying to insert the same key multiple times. - checkState(!map.containsKey(kv.getValue().getValue().getKey()), + checkState( + !map.containsKey(kv.getValue().getValue().getKey()), "Multiple values [%s, %s] found for single key [%s] within window [%s].", map.get(kv.getValue().getValue().getKey()), kv.getValue().getValue().getValue(), kv.getKey()); - map.put(kv.getValue().getValue().getKey(), + map.put( + kv.getValue().getValue().getKey(), kv.getValue().withValue(kv.getValue().getValue().getValue())); previousWindowStructuralValue = Optional.of(currentWindowStructuralValue); previousWindow = Optional.of(kv.getKey()); @@ -201,8 +207,7 @@ public PCollection expand(PCollection> input) { return this.applyInternal(input); } - private PCollection - applyInternal(PCollection> input) { + private PCollection applyInternal(PCollection> input) { try { return BatchViewAsMultimap.applyForMapLike(runner, input, view, true /* unique keys */); } catch (NonDeterministicException e) { @@ -220,11 +225,10 @@ protected String getKindString() { } /** Transforms the input {@link PCollection} into a singleton {@link Map} per window. */ - private PCollection - applyForSingletonFallback(PCollection> input) { + private PCollection applyForSingletonFallback( + PCollection> input) { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); @SuppressWarnings({"rawtypes", "unchecked"}) KvCoder inputCoder = (KvCoder) input.getCoder(); @@ -246,54 +250,57 @@ protected String getKindString() { } /** - * Specialized implementation for - * {@link org.apache.beam.sdk.transforms.View.AsMultimap View.AsMultimap} for the - * Dataflow runner in batch mode. + * Specialized implementation for {@link org.apache.beam.sdk.transforms.View.AsMultimap + * View.AsMultimap} for the Dataflow runner in batch mode. + * + *

<p>Creates a set of {@code Ism} files sharded by the hash of the key's byte representation. + * Each record is structured as follows: *
- * <p>Creates a set of {@code Ism} files sharded by the hash of the key's byte - * representation. Each record is structured as follows: * <ul>
- *   <li>Key 1: User key K</li> - *   <li>Key 2: Window</li> - *   <li>Key 3: Index offset for a given key and window.</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: User key K + *   <li>Key 2: Window + *   <li>Key 3: Index offset for a given key and window. + *   <li>Value: Windowed value
 * </ul> *
 * <p>Alongside the data records, there are the following metadata records: + *
 * <ul> - *   <li>Key 1: Metadata Key</li> - *   <li>Key 2: Window</li> - *   <li>Key 3: Index [0, size of map]</li> - *   <li>Value: variable length long byte representation of size of map if index is 0, - * otherwise the byte representation of a key</li>
+ *   <li>Key 1: Metadata Key + *   <li>Key 2: Window + *   <li>Key 3: Index [0, size of map] + *   <li>Value: variable length long byte representation of size of map if index is 0, otherwise + * the byte representation of a key * </ul>
- * The {@code [META, Window, 0]} record stores the number of unique keys per window, while - * {@code [META, Window, i]} for {@code i} in {@code [1, size of map]} stores a the users key. - * This allows for one to access the size of the map by looking at {@code [META, Window, 0]} - * and iterate over all the keys by accessing {@code [META, Window, i]} for {@code i} in - * {@code [1, size of map]}. *
- * <p>Note that in the case of a non-deterministic key coder, we fallback to using - * {@link org.apache.beam.sdk.transforms.View.AsSingleton View.AsSingleton} printing - * a warning to users to specify a deterministic key coder.
+ * <p>The {@code [META, Window, 0]} record stores the number of unique keys per window, while + * {@code [META, Window, i]} for {@code i} in {@code [1, size of map]} stores a the users key. + * This allows for one to access the size of the map by looking at {@code [META, Window, 0]} and + * iterate over all the keys by accessing {@code [META, Window, i]} for {@code i} in {@code [1, + * size of map]}. + * + * <p>
Note that in the case of a non-deterministic key coder, we fallback to using {@link + * org.apache.beam.sdk.transforms.View.AsSingleton View.AsSingleton} printing a warning to users + * to specify a deterministic key coder. */ - static class BatchViewAsMultimap - extends PTransform>, PCollection> { + static class BatchViewAsMultimap extends PTransform>, PCollection> { /** - * A {@link PTransform} that groups elements by the hash of window's byte representation - * if the input {@link PCollection} is not within the global window. Otherwise by the hash - * of the window and key's byte representation. This {@link PTransform} also sorts - * the values by the combination of the window and key's byte representations. + * A {@link PTransform} that groups elements by the hash of window's byte representation if the + * input {@link PCollection} is not within the global window. Otherwise by the hash of the + * window and key's byte representation. This {@link PTransform} also sorts the values by the + * combination of the window and key's byte representations. */ private static class GroupByKeyHashAndSortByKeyAndWindow - extends PTransform>, - PCollection, WindowedValue>>>>> { + extends PTransform< + PCollection>, + PCollection, WindowedValue>>>>> { @SystemDoFnInternal private static class GroupByKeyHashAndSortByKeyAndWindowDoFn extends DoFn, KV, WindowedValue>>> { private final IsmRecordCoder coder; + private GroupByKeyHashAndSortByKeyAndWindowDoFn(IsmRecordCoder coder) { this.coder = coder; } @@ -304,38 +311,38 @@ public void processElement(ProcessContext c, BoundedWindow untypedWindow) throws W window = (W) untypedWindow; c.output( - KV.of(coder.hash(ImmutableList.of(c.element().getKey())), - KV.of(KV.of(c.element().getKey(), window), + KV.of( + coder.hash(ImmutableList.of(c.element().getKey())), + KV.of( + KV.of(c.element().getKey(), window), WindowedValue.of( - c.element().getValue(), - c.timestamp(), - untypedWindow, - c.pane())))); + c.element().getValue(), c.timestamp(), untypedWindow, c.pane())))); } } private final IsmRecordCoder coder; + public GroupByKeyHashAndSortByKeyAndWindow(IsmRecordCoder coder) { this.coder = coder; } @Override - public PCollection, WindowedValue>>>> - expand(PCollection> input) { + public PCollection, WindowedValue>>>> expand( + PCollection> input) { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); @SuppressWarnings("unchecked") KvCoder inputCoder = (KvCoder) input.getCoder(); PCollection, WindowedValue>>> keyedByHash; - keyedByHash = input.apply( - ParDo.of(new GroupByKeyHashAndSortByKeyAndWindowDoFn(coder))); + keyedByHash = + input.apply(ParDo.of(new GroupByKeyHashAndSortByKeyAndWindowDoFn(coder))); keyedByHash.setCoder( KvCoder.of( VarIntCoder.of(), - KvCoder.of(KvCoder.of(inputCoder.getKeyCoder(), windowCoder), + KvCoder.of( + KvCoder.of(inputCoder.getKeyCoder(), windowCoder), FullWindowedValueCoder.of(inputCoder.getValueCoder(), windowCoder)))); return keyedByHash.apply(new GroupByKeyAndSortValuesOnly<>()); @@ -343,13 +350,14 @@ public GroupByKeyHashAndSortByKeyAndWindow(IsmRecordCoder coder) { } /** - * A {@link DoFn} which creates {@link IsmRecord}s comparing successive elements windows - * and keys to locate window and key boundaries. 
<p>The main output {@link IsmRecord}s have: + * A {@link DoFn} which creates {@link IsmRecord}s comparing successive elements windows and + * keys to locate window and key boundaries. The main output {@link IsmRecord}s have: + *
 * <ul> - *   <li>Key 1: Window</li> - *   <li>Key 2: User key K</li> - *   <li>Key 3: Index offset for a given key and window.</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: Window + *   <li>Key 2: User key K + *   <li>Key 3: Index offset for a given key and window. + *   <li>Value: Windowed value
 * </ul> *
 * <p>
Additionally, we output all the unique keys per window seen to {@code outputForEntrySet} @@ -359,8 +367,8 @@ public GroupByKeyHashAndSortByKeyAndWindow(IsmRecordCoder coder) { * throw an {@link IllegalStateException} if more than one key per window is found. */ static class ToIsmRecordForMapLikeDoFn - extends DoFn, WindowedValue>>>, - IsmRecord>> { + extends DoFn< + KV, WindowedValue>>>, IsmRecord>> { private final TupleTag>> outputForSize; private final TupleTag>> outputForEntrySet; @@ -368,6 +376,7 @@ static class ToIsmRecordForMapLikeDoFn private final Coder keyCoder; private final IsmRecordCoder> ismCoder; private final boolean uniqueKeysExpected; + ToIsmRecordForMapLikeDoFn( TupleTag>> outputForSize, TupleTag>> outputForEntrySet, @@ -391,15 +400,13 @@ public void processElement(ProcessContext c) throws Exception { Iterator, WindowedValue>> iterator = c.element().getValue().iterator(); KV, WindowedValue> currentValue = iterator.next(); - Object currentKeyStructuralValue = - keyCoder.structuralValue(currentValue.getKey().getKey()); + Object currentKeyStructuralValue = keyCoder.structuralValue(currentValue.getKey().getKey()); Object currentWindowStructuralValue = windowCoder.structuralValue(currentValue.getKey().getValue()); while (iterator.hasNext()) { KV, WindowedValue> nextValue = iterator.next(); - Object nextKeyStructuralValue = - keyCoder.structuralValue(nextValue.getKey().getKey()); + Object nextKeyStructuralValue = keyCoder.structuralValue(nextValue.getKey().getKey()); Object nextWindowStructuralValue = windowCoder.structuralValue(nextValue.getKey().getValue()); @@ -418,7 +425,7 @@ public void processElement(ProcessContext c) throws Exception { nextKeyIndex = 0; nextUniqueKeyCounter = 1; - } else if (!currentKeyStructuralValue.equals(nextKeyStructuralValue)){ + } else if (!currentKeyStructuralValue.equals(nextKeyStructuralValue)) { // It is a new key within the same window so output the key for the entry set, // reset the key index and increase the count of unique keys seen within this window. outputMetadataRecordForEntrySet(c, currentValue); @@ -432,12 +439,13 @@ public void processElement(ProcessContext c) throws Exception { nextKeyIndex = currentKeyIndex + 1; nextUniqueKeyCounter = currentUniqueKeyCounter; } else { - throw new IllegalStateException(String.format( - "Unique keys are expected but found key %s with values %s and %s in window %s.", - currentValue.getKey().getKey(), - currentValue.getValue().getValue(), - nextValue.getValue().getValue(), - currentValue.getKey().getValue())); + throw new IllegalStateException( + String.format( + "Unique keys are expected but found key %s with values %s and %s in window %s.", + currentValue.getKey().getKey(), + currentValue.getValue().getValue(), + nextValue.getValue().getValue(), + currentValue.getKey().getValue())); } currentValue = nextValue; @@ -457,12 +465,10 @@ public void processElement(ProcessContext c) throws Exception { /** This outputs the data record. 
*/ private void outputDataRecord( ProcessContext c, KV, WindowedValue> value, long keyIndex) { - IsmRecord> ismRecord = IsmRecord.of( - ImmutableList.of( - value.getKey().getKey(), - value.getKey().getValue(), - keyIndex), - value.getValue()); + IsmRecord> ismRecord = + IsmRecord.of( + ImmutableList.of(value.getKey().getKey(), value.getKey().getValue(), keyIndex), + value.getValue()); c.output(ismRecord); } @@ -471,37 +477,43 @@ private void outputDataRecord( */ private void outputMetadataRecordForSize( ProcessContext c, KV, WindowedValue> value, long uniqueKeyCount) { - c.output(outputForSize, - KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), - value.getKey().getValue())), + c.output( + outputForSize, + KV.of( + ismCoder.hash( + ImmutableList.of(IsmFormat.getMetadataKey(), value.getKey().getValue())), KV.of(value.getKey().getValue(), uniqueKeyCount))); } /** This outputs records which will be used to construct the entry set. */ private void outputMetadataRecordForEntrySet( ProcessContext c, KV, WindowedValue> value) { - c.output(outputForEntrySet, - KV.of(ismCoder.hash(ImmutableList.of(IsmFormat.getMetadataKey(), - value.getKey().getValue())), + c.output( + outputForEntrySet, + KV.of( + ismCoder.hash( + ImmutableList.of(IsmFormat.getMetadataKey(), value.getKey().getValue())), KV.of(value.getKey().getValue(), value.getKey().getKey()))); } } /** * A {@link DoFn} which outputs a metadata {@link IsmRecord} per window of: + * *

 * <ul> - *   <li>Key 1: META key</li> - *   <li>Key 2: window</li> - *   <li>Key 3: 0L (constant)</li> - *   <li>Value: sum of values for window</li>
+ *   <li>Key 1: META key + *   <li>Key 2: window + *   <li>Key 3: 0L (constant) + *   <li>Value: sum of values for window
 * </ul> *
- * <p>This {@link DoFn} is meant to be used to compute the number of unique keys - * per window for map and multimap side inputs.
+ * <p>
This {@link DoFn} is meant to be used to compute the number of unique keys per window for + * map and multimap side inputs. */ static class ToIsmMetadataRecordForSizeDoFn extends DoFn>>, IsmRecord>> { private final Coder windowCoder; + ToIsmMetadataRecordForSizeDoFn(Coder windowCoder) { this.windowCoder = windowCoder; } @@ -540,21 +552,23 @@ public void processElement(ProcessContext c) throws Exception { /** * A {@link DoFn} which outputs a metadata {@link IsmRecord} per window and key pair of: + * *

 * <ul> - *   <li>Key 1: META key</li> - *   <li>Key 2: window</li> - *   <li>Key 3: index offset (1-based index)</li> - *   <li>Value: key</li>
+ *   <li>Key 1: META key + *   <li>Key 2: window + *   <li>Key 3: index offset (1-based index) + *   <li>Value: key
 * </ul> *
- * <p>This {@link DoFn} is meant to be used to output index to key records - * per window for map and multimap side inputs.
+ * <p>
This {@link DoFn} is meant to be used to output index to key records per window for map + * and multimap side inputs. */ static class ToIsmMetadataRecordForKeyDoFn extends DoFn>>, IsmRecord>> { private final Coder keyCoder; private final Coder windowCoder; + ToIsmMetadataRecordForKeyDoFn(Coder keyCoder, Coder windowCoder) { this.keyCoder = keyCoder; this.windowCoder = windowCoder; @@ -595,34 +609,33 @@ public void processElement(ProcessContext c) throws Exception { } /** - * A {@link DoFn} which partitions sets of elements by window boundaries. Within each - * partition, the set of elements is transformed into a {@link TransformedMap}. - * The transformed {@code Map>} is backed by a - * {@code Map>>} and contains a function - * {@code Iterable> -> Iterable}. + * A {@link DoFn} which partitions sets of elements by window boundaries. Within each partition, + * the set of elements is transformed into a {@link TransformedMap}. The transformed {@code + * Map>} is backed by a {@code Map>>} and contains a + * function {@code Iterable> -> Iterable}. * *

<p>Outputs {@link IsmRecord}s having: + *
 * <ul> - *   <li>Key 1: Window</li> - *   <li>Value: Transformed map containing a transform that removes the encapsulation - * of the window around each value, - * {@code Map<K, Iterable<WindowedValue<V>>> -> Map<K, Iterable<V>>}.</li>
+ *   <li>Key 1: Window + *   <li>Value: Transformed map containing a transform that removes the encapsulation of the + * window around each value, {@code Map<K, Iterable<WindowedValue<V>>> -> Map<K, Iterable<V>>}.
 * </ul>
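The window-boundary handling described in this javadoc, where each element's window is compared with the previous one and a new group is started when it changes, can be sketched with plain collections. The String window, key, and value types below are simplified stand-ins rather than Beam's windowing or KV classes, and the method only illustrates the grouping pattern:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class WindowBoundarySketch {

  /** Groups (window, key, value) triples that arrive sorted by window into one multimap per window. */
  static Map<String, Map<String, List<String>>> groupByWindow(List<List<String>> sortedTriples) {
    Map<String, Map<String, List<String>>> perWindow = new LinkedHashMap<>();
    String previousWindow = null;
    Map<String, List<String>> current = null;
    for (List<String> triple : sortedTriples) {
      String window = triple.get(0);
      if (!window.equals(previousWindow)) {
        // Window boundary detected: start a fresh multimap, mirroring the previous-window check above.
        current = new HashMap<>();
        perWindow.put(window, current);
        previousWindow = window;
      }
      current.computeIfAbsent(triple.get(1), k -> new ArrayList<>()).add(triple.get(2));
    }
    return perWindow;
  }

  public static void main(String[] args) {
    List<List<String>> input =
        List.of(List.of("w1", "a", "1"), List.of("w1", "a", "2"), List.of("w2", "b", "3"));
    System.out.println(groupByWindow(input)); // {w1={a=[1, 2]}, w2={b=[3]}}
  }
}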
*/ static class ToMultimapDoFn - extends DoFn>>>>, - IsmRecord>, - Iterable>>>> { + extends DoFn< + KV>>>>, + IsmRecord>, Iterable>>>> { private final Coder windowCoder; + ToMultimapDoFn(Coder windowCoder) { this.windowCoder = windowCoder; } @ProcessElement - public void processElement(ProcessContext c) - throws Exception { + public void processElement(ProcessContext c) throws Exception { Optional previousWindowStructuralValue = Optional.absent(); Optional previousWindow = Optional.absent(); Multimap> multimap = HashMultimap.create(); @@ -643,7 +656,8 @@ public void processElement(ProcessContext c) multimap = HashMultimap.create(); } - multimap.put(kv.getValue().getValue().getKey(), + multimap.put( + kv.getValue().getValue().getKey(), kv.getValue().withValue(kv.getValue().getValue().getValue())); previousWindowStructuralValue = Optional.of(currentWindowStructuralValue); previousWindow = Optional.of(kv.getKey()); @@ -677,8 +691,7 @@ public PCollection expand(PCollection> input) { return this.applyInternal(input); } - private PCollection - applyInternal(PCollection> input) { + private PCollection applyInternal(PCollection> input) { try { return applyForMapLike(runner, input, view, false /* unique keys not expected */); } catch (NonDeterministicException e) { @@ -691,11 +704,10 @@ public PCollection expand(PCollection> input) { } /** Transforms the input {@link PCollection} into a singleton {@link Map} per window. */ - private PCollection - applyForSingletonFallback(PCollection> input) { + private PCollection applyForSingletonFallback( + PCollection> input) { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); @SuppressWarnings({"rawtypes", "unchecked"}) KvCoder inputCoder = (KvCoder) input.getCoder(); @@ -725,11 +737,11 @@ private static PCollection applyForMap DataflowRunner runner, PCollection> input, PCollectionView view, - boolean uniqueKeysExpected) throws NonDeterministicException { + boolean uniqueKeysExpected) + throws NonDeterministicException { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); @SuppressWarnings({"rawtypes", "unchecked"}) KvCoder inputCoder = (KvCoder) input.getCoder(); @@ -776,8 +788,7 @@ private static PCollection applyForMap // for each window. PCollection>> outputForSize = outputTuple.get(outputForSizeTag); outputForSize.setCoder( - KvCoder.of(VarIntCoder.of(), - KvCoder.of(windowCoder, VarLongCoder.of()))); + KvCoder.of(VarIntCoder.of(), KvCoder.of(windowCoder, VarLongCoder.of()))); PCollection>> windowMapSizeMetadata = outputForSize .apply("GBKaSVForSize", new GroupByKeyAndSortValuesOnly<>()) @@ -786,11 +797,9 @@ private static PCollection applyForMap // Set the coder on the metadata output destined to build the entry set and process the // entries producing a [META, Window, Index] record per window key pair storing the key. 
- PCollection>> outputForEntrySet = - outputTuple.get(outputForEntrySetTag); + PCollection>> outputForEntrySet = outputTuple.get(outputForEntrySetTag); outputForEntrySet.setCoder( - KvCoder.of(VarIntCoder.of(), - KvCoder.of(windowCoder, inputCoder.getKeyCoder()))); + KvCoder.of(VarIntCoder.of(), KvCoder.of(windowCoder, inputCoder.getKeyCoder()))); PCollection>> windowMapKeysMetadata = outputForEntrySet .apply("GBKaSVForKeys", new GroupByKeyAndSortValuesOnly<>()) @@ -806,8 +815,9 @@ private static PCollection applyForMap runner.addPCollectionRequiringIndexedFormat(windowMapKeysMetadata); PCollectionList>> outputs = - PCollectionList.of(ImmutableList.of( - perHashWithReifiedWindows, windowMapSizeMetadata, windowMapKeysMetadata)); + PCollectionList.of( + ImmutableList.of( + perHashWithReifiedWindows, windowMapSizeMetadata, windowMapKeysMetadata)); PCollection>> flattenedOutputs = Pipeline.applyTransform(outputs, Flatten.pCollections()); @@ -828,41 +838,38 @@ static IsmRecordCoder> coderForMapLike( 1, // We use only the key for hashing when producing value records 2, // Since the key is not present, we add the window to the hash when // producing metadata records - ImmutableList.of( - MetadataKeyCoder.of(keyCoder), - windowCoder, - BigEndianLongCoder.of()), + ImmutableList.of(MetadataKeyCoder.of(keyCoder), windowCoder, BigEndianLongCoder.of()), FullWindowedValueCoder.of(valueCoder, windowCoder)); } } /** - * Specialized implementation for - * {@link org.apache.beam.sdk.transforms.View.AsSingleton View.AsSingleton} for the - * Dataflow runner in batch mode. + * Specialized implementation for {@link org.apache.beam.sdk.transforms.View.AsSingleton + * View.AsSingleton} for the Dataflow runner in batch mode. + * + *

<p>Creates a set of files in the {@link IsmFormat} sharded by the hash of the windows byte + * representation and with records having: *
- * <p>Creates a set of files in the {@link IsmFormat} sharded by the hash of the windows - * byte representation and with records having: * <ul>
- *   <li>Key 1: Window</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: Window + *   <li>Value: Windowed value
 * </ul>
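To make the record layouts listed in these javadocs concrete, the key components can be modeled as small lists whose leading components decide the shard. The shard function below is an arbitrary placeholder built on Objects.hash rather than the IsmFormat hashing the runner uses, and the String window and key types are stand-ins:

import java.util.List;
import java.util.Objects;

public class IsmKeySketch {

  /** Map/multimap data record key: [user key, window, index]. */
  static List<Object> dataKey(String userKey, String window, long index) {
    return List.of(userKey, window, index);
  }

  /** Metadata record key: [META, window, i]; i == 0 holds the key count for the window. */
  static List<Object> metadataKey(String window, long i) {
    return List.of("META", window, i);
  }

  /** Singleton-style record key: [window]. */
  static List<Object> windowOnlyKey(String window) {
    return List.of(window);
  }

  /** Placeholder shard function: hash only the leading key components, as the javadocs describe. */
  static int shard(List<Object> keyComponents, int componentsToHash, int shardCount) {
    return Math.floorMod(Objects.hash(keyComponents.subList(0, componentsToHash).toArray()), shardCount);
  }

  public static void main(String[] args) {
    // Data records for the map views hash only the user key; metadata records also mix in the window.
    System.out.println(shard(dataKey("user-1", "w1", 0L), 1, 16));
    System.out.println(shard(metadataKey("w1", 0L), 2, 16));
    // Singleton views shard by the window alone.
    System.out.println(shard(windowOnlyKey("w1"), 1, 16));
  }
}

Hashing only the user key for value records keeps all values for one key in the same shard, while metadata records add the window to the hash, matching the coder comments quoted later in this diff.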
*/ - static class BatchViewAsSingleton - extends PTransform, PCollection> { + static class BatchViewAsSingleton extends PTransform, PCollection> { /** * A {@link DoFn} that outputs {@link IsmRecord}s. These records are structured as follows: + * *
 * <ul> *   <li>Key 1: Window *   <li>Value: Windowed value * </ul>
*/ static class IsmRecordForSingularValuePerWindowDoFn - extends DoFn>>>, - IsmRecord>> { + extends DoFn>>>, IsmRecord>> { private final Coder windowCoder; + IsmRecordForSingularValuePerWindowDoFn(Coder windowCoder) { this.windowCoder = windowCoder; } @@ -913,8 +920,8 @@ public BatchViewAsSingleton( public PCollection expand(PCollection input) { input = input.apply(Combine.globally(combineFn).withoutDefaults().withFanout(fanout)); @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = + (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); return BatchViewAsSingleton.applyForSingleton( runner, @@ -924,25 +931,23 @@ public PCollection expand(PCollection input) { view); } - static PCollection - applyForSingleton( + static PCollection applyForSingleton( DataflowRunner runner, PCollection input, - DoFn>>>, - IsmRecord>> doFn, + DoFn>>>, IsmRecord>> doFn, Coder defaultValueCoder, PCollectionView view) { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); IsmRecordCoder> ismCoder = coderForSingleton(windowCoder, defaultValueCoder); - PCollection>> reifiedPerWindowAndSorted = input - .apply(new GroupByWindowHashAsKeyAndWindowAsSortKey(ismCoder)) - .apply(ParDo.of(doFn)); + PCollection>> reifiedPerWindowAndSorted = + input + .apply(new GroupByWindowHashAsKeyAndWindowAsSortKey(ismCoder)) + .apply(ParDo.of(doFn)); reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); @@ -966,34 +971,34 @@ static IsmRecordCoder> coderForSingleton( } /** - * Specialized implementation for - * {@link org.apache.beam.sdk.transforms.View.AsList View.AsList} for the - * Dataflow runner in batch mode. + * Specialized implementation for {@link org.apache.beam.sdk.transforms.View.AsList View.AsList} + * for the Dataflow runner in batch mode. * *

<p>Creates a set of {@code Ism} files sharded by the hash of the window's byte representation * and with records having: + *
 * <ul> - *   <li>Key 1: Window</li> - *   <li>Key 2: Index offset within window</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: Window + *   <li>Key 2: Index offset within window + *   <li>Value: Windowed value
 * </ul>
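As a rough, self-contained illustration of the [window, index offset] layout above, elements can be stored under composite keys so that one window's values occupy a contiguous key range; the java.util.TreeMap below stands in for the Ism files and is not the runner's reader or writer:

import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

public class WindowIndexedListSketch {
  public static void main(String[] args) {
    // Structured keys "window/index" (single-digit indexes keep lexicographic order correct here);
    // a TreeMap keeps them sorted, so one window's elements are contiguous, as in the Ism layout.
    TreeMap<String, String> records = new TreeMap<>();
    records.put("w1/0", "a");
    records.put("w1/1", "b");
    records.put("w2/0", "c");

    // Reading the list for window w1 is a range scan over that window's key prefix.
    List<String> w1 = new ArrayList<>(records.subMap("w1/", "w1/\uffff").values());
    System.out.println(w1); // [a, b]
  }
}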
*/ - static class BatchViewAsList - extends PTransform, PCollection> { + static class BatchViewAsList extends PTransform, PCollection> { /** * A {@link DoFn} which creates {@link IsmRecord}s assuming that each element is within the * global window. Each {@link IsmRecord} has + * *
    - *
  <li>Key 1: Global window</li> - *   <li>Key 2: Index offset within window</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: Global window + *   <li>Key 2: Index offset within window + *   <li>Value: Windowed value
 * </ul>
*/ @SystemDoFnInternal - static class ToIsmRecordForGlobalWindowDoFn - extends DoFn>> { + static class ToIsmRecordForGlobalWindowDoFn extends DoFn>> { long indexInBundle; + @StartBundle public void startBundle() throws Exception { indexInBundle = 0; @@ -1001,32 +1006,30 @@ public void startBundle() throws Exception { @ProcessElement public void processElement(ProcessContext c) throws Exception { - c.output(IsmRecord.of( - ImmutableList.of(GlobalWindow.INSTANCE, indexInBundle), - WindowedValue.of( - c.element(), - c.timestamp(), - GlobalWindow.INSTANCE, - c.pane()))); + c.output( + IsmRecord.of( + ImmutableList.of(GlobalWindow.INSTANCE, indexInBundle), + WindowedValue.of(c.element(), c.timestamp(), GlobalWindow.INSTANCE, c.pane()))); indexInBundle += 1; } } /** - * A {@link DoFn} which creates {@link IsmRecord}s comparing successive elements windows - * to locate the window boundaries. The {@link IsmRecord} has: + * A {@link DoFn} which creates {@link IsmRecord}s comparing successive elements windows to + * locate the window boundaries. The {@link IsmRecord} has: + * *
    - *
  <li>Key 1: Window</li> - *   <li>Key 2: Index offset within window</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: Window + *   <li>Key 2: Index offset within window + *   <li>Value: Windowed value
 * </ul>
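The index bookkeeping described above, where the offset restarts whenever a new window is seen in window-sorted input, can be sketched independently of Beam; the String windows and values below are simplified stand-ins:

import java.util.ArrayList;
import java.util.List;

public class PerWindowIndexSketch {

  /** Pairs each window-sorted value with its offset in that window, restarting at each new window. */
  static List<String> indexWithinWindow(List<List<String>> sortedWindowValuePairs) {
    List<String> records = new ArrayList<>();
    String previousWindow = null;
    long elementsInWindow = 0;
    for (List<String> pair : sortedWindowValuePairs) {
      String window = pair.get(0);
      if (!window.equals(previousWindow)) {
        elementsInWindow = 0; // new window seen: reset the index, as the DoFn above does
        previousWindow = window;
      }
      records.add("[" + window + ", " + elementsInWindow + "] -> " + pair.get(1));
      elementsInWindow++;
    }
    return records;
  }

  public static void main(String[] args) {
    List<List<String>> input = List.of(List.of("w1", "a"), List.of("w1", "b"), List.of("w2", "c"));
    indexWithinWindow(input).forEach(System.out::println);
  }
}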
*/ @SystemDoFnInternal static class ToIsmRecordForNonGlobalWindowDoFn - extends DoFn>>>, - IsmRecord>> { + extends DoFn>>>, IsmRecord>> { private final Coder windowCoder; + ToIsmRecordForNonGlobalWindowDoFn(Coder windowCoder) { this.windowCoder = windowCoder; } @@ -1043,9 +1046,8 @@ public void processElement(ProcessContext c) throws Exception { // Reset i since we have a new window. elementsInWindow = 0; } - c.output(IsmRecord.of( - ImmutableList.of(value.getKey(), elementsInWindow), - value.getValue())); + c.output( + IsmRecord.of(ImmutableList.of(value.getKey(), elementsInWindow), value.getValue())); previousWindowStructuralValue = Optional.of(currentWindowStructuralValue); elementsInWindow += 1; } @@ -1054,9 +1056,7 @@ public void processElement(ProcessContext c) throws Exception { private final DataflowRunner runner; private final PCollectionView> view; - /** - * Builds an instance of this class from the overridden transform. - */ + /** Builds an instance of this class from the overridden transform. */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() public BatchViewAsList(DataflowRunner runner, CreatePCollectionView> transform) { this.runner = runner; @@ -1069,13 +1069,10 @@ public PCollection expand(PCollection input) { } static PCollection applyForIterableLike( - DataflowRunner runner, - PCollection input, - PCollectionView view) { + DataflowRunner runner, PCollection input, PCollectionView view) { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); IsmRecordCoder> ismCoder = coderForListLike(windowCoder, input.getCoder()); @@ -1123,20 +1120,19 @@ static IsmRecordCoder> coderForListLike( } /** - * Specialized implementation for - * {@link org.apache.beam.sdk.transforms.View.AsIterable View.AsIterable} for the - * Dataflow runner in batch mode. + * Specialized implementation for {@link org.apache.beam.sdk.transforms.View.AsIterable + * View.AsIterable} for the Dataflow runner in batch mode. * *

<p>Creates a set of {@code Ism} files sharded by the hash of the windows byte representation * and with records having: + *
 * <ul> - *   <li>Key 1: Window</li> - *   <li>Key 2: Index offset within window</li> - *   <li>Value: Windowed value</li>
+ *   <li>Key 1: Window + *   <li>Key 2: Index offset within window + *   <li>Value: Windowed value
 * </ul>
*/ - static class BatchViewAsIterable - extends PTransform, PCollection> { + static class BatchViewAsIterable extends PTransform, PCollection> { private final DataflowRunner runner; private final PCollectionView> view; @@ -1154,12 +1150,9 @@ public PCollection expand(PCollection input) { } } - - /** - * A {@link Function} which converts {@code WindowedValue} to {@code V}. - */ - private static class WindowedValueToValue implements - Function, V>, Serializable { + /** A {@link Function} which converts {@code WindowedValue} to {@code V}. */ + private static class WindowedValueToValue + implements Function, V>, Serializable { private static final WindowedValueToValue INSTANCE = new WindowedValueToValue<>(); @SuppressWarnings({"unchecked", "rawtypes"}) @@ -1176,8 +1169,8 @@ public V apply(WindowedValue input) { /** * A {@link Function} which converts {@code Iterable>} to {@code Iterable}. */ - private static class IterableWithWindowedValuesToIterable implements - Function>, Iterable>, Serializable { + private static class IterableWithWindowedValuesToIterable + implements Function>, Iterable>, Serializable { private static final IterableWithWindowedValuesToIterable INSTANCE = new IterableWithWindowedValuesToIterable<>(); @@ -1193,22 +1186,24 @@ public Iterable apply(Iterable> input) { } /** - * A {@link PTransform} that groups the values by a hash of the window's byte representation - * and sorts the values using the windows byte representation. + * A {@link PTransform} that groups the values by a hash of the window's byte representation and + * sorts the values using the windows byte representation. */ - private static class GroupByWindowHashAsKeyAndWindowAsSortKey extends - PTransform, PCollection>>>>> { + private static class GroupByWindowHashAsKeyAndWindowAsSortKey + extends PTransform< + PCollection, PCollection>>>>> { /** - * A {@link DoFn} that for each element outputs a {@code KV} structure suitable for - * grouping by the hash of the window's byte representation and sorting the grouped values - * using the window's byte representation. + * A {@link DoFn} that for each element outputs a {@code KV} structure suitable for grouping by + * the hash of the window's byte representation and sorting the grouped values using the + * window's byte representation. 
*/ @SystemDoFnInternal private static class UseWindowHashAsKeyAndWindowAsSortKeyDoFn extends DoFn>>> { private final IsmRecordCoder ismCoderForHash; + private UseWindowHashAsKeyAndWindowAsSortKeyDoFn(IsmRecordCoder ismCoderForHash) { this.ismCoderForHash = ismCoderForHash; } @@ -1218,17 +1213,14 @@ public void processElement(ProcessContext c, BoundedWindow untypedWindow) throws @SuppressWarnings("unchecked") W window = (W) untypedWindow; c.output( - KV.of(ismCoderForHash.hash(ImmutableList.of(window)), - KV.of(window, - WindowedValue.of( - c.element(), - c.timestamp(), - window, - c.pane())))); + KV.of( + ismCoderForHash.hash(ImmutableList.of(window)), + KV.of(window, WindowedValue.of(c.element(), c.timestamp(), window, c.pane())))); } } private final IsmRecordCoder ismCoderForHash; + private GroupByWindowHashAsKeyAndWindowAsSortKey(IsmRecordCoder ismCoderForHash) { this.ismCoderForHash = ismCoderForHash; } @@ -1237,33 +1229,29 @@ private GroupByWindowHashAsKeyAndWindowAsSortKey(IsmRecordCoder ismCoderForHa public PCollection>>>> expand( PCollection input) { @SuppressWarnings("unchecked") - Coder windowCoder = (Coder) - input.getWindowingStrategy().getWindowFn().windowCoder(); + Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); PCollection>>> rval = - input.apply(ParDo.of( - new UseWindowHashAsKeyAndWindowAsSortKeyDoFn(ismCoderForHash))); + input.apply( + ParDo.of(new UseWindowHashAsKeyAndWindowAsSortKeyDoFn(ismCoderForHash))); rval.setCoder( KvCoder.of( VarIntCoder.of(), - KvCoder.of(windowCoder, - FullWindowedValueCoder.of(input.getCoder(), windowCoder)))); + KvCoder.of(windowCoder, FullWindowedValueCoder.of(input.getCoder(), windowCoder)))); return rval.apply(new GroupByKeyAndSortValuesOnly<>()); } } /** - * A {@link GroupByKey} transform for the {@link DataflowRunner} which sorts - * values using the secondary key {@code K2}. + * A {@link GroupByKey} transform for the {@link DataflowRunner} which sorts values using the + * secondary key {@code K2}. * - *

<p>The {@link PCollection} created created by this {@link PTransform} will have values in - * the empty window. Care must be taken *afterwards* to either re-window - * (using {@link Window#into}) or only use {@link PTransform}s that do not depend on the - * values being within a window.
+ * <p>
The {@link PCollection} created created by this {@link PTransform} will have values in the + * empty window. Care must be taken *afterwards* to either re-window (using {@link Window#into}) + * or only use {@link PTransform}s that do not depend on the values being within a window. */ static class GroupByKeyAndSortValuesOnly extends PTransform>>, PCollection>>>> { - GroupByKeyAndSortValuesOnly() { - } + GroupByKeyAndSortValuesOnly() {} @Override public PCollection>>> expand(PCollection>> input) { @@ -1277,13 +1265,11 @@ public PCollection>>> expand(PCollection} backed by a {@code Map} and a function that transforms - * {@code V1 -> V2}. + * A {@code Map} backed by a {@code Map} and a function that transforms {@code V1 -> + * V2}. */ - static class TransformedMap - extends ForwardingMap { + static class TransformedMap extends ForwardingMap { private final Function transform; private final Map originalMap; private final Map transformedMap; @@ -1300,11 +1286,8 @@ protected Map delegate() { } } - /** - * A {@link Coder} for {@link TransformedMap}s. - */ - static class TransformedMapCoder - extends StructuredCoder> { + /** A {@link Coder} for {@link TransformedMap}s. */ + static class TransformedMapCoder extends StructuredCoder> { private final Coder> transformCoder; private final Coder> originalMapCoder; @@ -1330,8 +1313,7 @@ public void encode(TransformedMap value, OutputStream outStream) public TransformedMap decode(InputStream inStream) throws CoderException, IOException { return new TransformedMap<>( - transformCoder.decode(inStream), - originalMapCoder.decode(inStream)); + transformCoder.decode(inStream), originalMapCoder.decode(inStream)); } @Override @@ -1346,4 +1328,64 @@ public void verifyDeterministic() verifyDeterministic(this, "Expected map coder to be deterministic.", originalMapCoder); } } + + /** + * A hack to put a simple value (aka globally windowed) in a place where a WindowedValue is + * expected. + * + *

This is not actually valid for Beam elements, because values in no windows do not really + * exist and may be dropped at any time without further justification. + */ + private static WindowedValue valueInEmptyWindows(T value) { + return new ValueInEmptyWindows<>(value); + } + + private static class ValueInEmptyWindows extends WindowedValue { + + private final T value; + + private ValueInEmptyWindows(T value) { + this.value = value; + } + + @Override + public WindowedValue withValue(NewT value) { + return new ValueInEmptyWindows<>(value); + } + + @Override + public T getValue() { + return value; + } + + @Override + public Instant getTimestamp() { + return BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + @Override + public Collection getWindows() { + return Collections.emptyList(); + } + + @Override + public PaneInfo getPane() { + return PaneInfo.NO_FIRING; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()).add("value", getValue()).toString(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof ValueInEmptyWindows) { + ValueInEmptyWindows that = (ValueInEmptyWindows) o; + return Objects.equals(that.getValue(), this.getValue()); + } else { + return super.equals(o); + } + } + } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java index b5ad6b396cb6..0983674b9831 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java @@ -191,7 +191,7 @@ private Long getCounterValue(com.google.api.services.dataflow.model.MetricUpdate private DistributionResult getDistributionValue( com.google.api.services.dataflow.model.MetricUpdate metricUpdate) { if (metricUpdate.getDistribution() == null) { - return DistributionResult.ZERO; + return DistributionResult.IDENTITY_ELEMENT; } ArrayMap distributionMap = (ArrayMap) metricUpdate.getDistribution(); Long count = ((Number) distributionMap.get("count")).longValue(); diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 9ff8a45aa056..a0fd99437687 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -716,6 +716,13 @@ private void translateTyped( context.addStep(transform, "CollectionToSingleton"); PCollection input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); + WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + stepContext.addInput( + PropertyNames.WINDOWING_STRATEGY, + byteArrayToJsonString(serializeWindowingStrategy(windowingStrategy))); + stepContext.addInput( + PropertyNames.IS_MERGING_WINDOW_FN, + !windowingStrategy.getWindowFn().isNonMerging()); stepContext.addCollectionToSingletonOutput( input, PropertyNames.OUTPUT, transform.getView()); } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 895c7a1a661a..b5874b24f937 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -876,6 +876,7 @@ public void visitValue(PValue value, Node producer) { } } } + BoundednessVisitor visitor = new BoundednessVisitor(); p.traverseTopologically(visitor); return visitor.boundedness == IsBounded.UNBOUNDED; @@ -1540,6 +1541,7 @@ public void process(ProcessContext c) throws Exception { .setCoder(source.getOutputCoder()); } } + /** * A marker {@link DoFn} for writing the contents of a {@link PCollection} to a streaming * {@link PCollectionView} backend implementation. diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TestDataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TestDataflowRunner.java index 1abea99fcb6f..c3108b69a87e 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TestDataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TestDataflowRunner.java @@ -369,12 +369,10 @@ public Void call() throws Exception { State jobState = job.getState(); // If we see an error, cancel and note failure - if (messageHandler.hasSeenError()) { - if (!job.getState().isTerminal()) { - job.cancel(); - LOG.info("Cancelling Dataflow job {}", job.getJobId()); - return null; - } + if (messageHandler.hasSeenError() && !job.getState().isTerminal()) { + job.cancel(); + LOG.info("Cancelling Dataflow job {}", job.getJobId()); + return null; } if (jobState.isTerminal()) { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PackageUtil.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PackageUtil.java index 479db3fcdff8..387b7e3a5900 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PackageUtil.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PackageUtil.java @@ -32,9 +32,6 @@ import com.google.common.io.ByteSource; import com.google.common.io.CountingOutputStream; import com.google.common.io.Files; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import java.io.Closeable; import java.io.File; @@ -47,7 +44,9 @@ import java.util.Collection; import java.util.Comparator; import java.util.List; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -61,6 +60,7 @@ import org.apache.beam.sdk.util.BackOffAdapter; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.sdk.util.MimeTypes; +import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.sdk.util.ZipFiles; import org.joda.time.Duration; import org.slf4j.Logger; @@ -95,9 +95,9 @@ class PackageUtil implements Closeable { */ private static final ApiErrorExtractor ERROR_EXTRACTOR = new ApiErrorExtractor(); - 
private final ListeningExecutorService executorService; + private final ExecutorService executorService; - private PackageUtil(ListeningExecutorService executorService) { + private PackageUtil(ExecutorService executorService) { this.executorService = executorService; } @@ -107,7 +107,7 @@ public static PackageUtil withDefaultThreadPool() { MoreExecutors.platformThreadFactory()))); } - public static PackageUtil withExecutorService(ListeningExecutorService executorService) { + public static PackageUtil withExecutorService(ExecutorService executorService) { return new PackageUtil(executorService); } @@ -134,10 +134,10 @@ public int compare(PackageAttributes o1, PackageAttributes o2) { } /** Asynchronously computes {@link PackageAttributes} for a single staged file. */ - private ListenableFuture computePackageAttributes( + private CompletionStage computePackageAttributes( final DataflowPackage source, final String stagingPath) { - return executorService.submit( + return MoreFutures.supplyAsync( () -> { final File file = new File(source.getLocation()); if (!file.exists()) { @@ -150,7 +150,8 @@ private ListenableFuture computePackageAttributes( attributes = attributes.withPackageName(source.getName()); } return attributes; - }); + }, + executorService); } private boolean alreadyStaged(PackageAttributes attributes) throws IOException { @@ -165,12 +166,12 @@ private boolean alreadyStaged(PackageAttributes attributes) throws IOException { } /** Stages one file ("package") if necessary. */ - public ListenableFuture stagePackage( + public CompletionStage stagePackage( final PackageAttributes attributes, final Sleeper retrySleeper, final CreateOptions createOptions) { - return executorService.submit( - () -> stagePackageSynchronously(attributes, retrySleeper, createOptions)); + return MoreFutures.supplyAsync( + () -> stagePackageSynchronously(attributes, retrySleeper, createOptions), executorService); } /** Synchronously stages a package, with retry and backoff for resiliency. */ @@ -265,7 +266,7 @@ private StagingResult tryStagePackage(PackageAttributes attributes, CreateOption /** * Transfers the classpath elements to the staging location using a default {@link Sleeper}. * - * @see {@link #stageClasspathElements(Collection, String, Sleeper, CreateOptions)} + * @see #stageClasspathElements(Collection, String, Sleeper, CreateOptions) */ List stageClasspathElements( Collection classpathElements, String stagingPath, CreateOptions createOptions) { @@ -275,7 +276,7 @@ List stageClasspathElements( /** * Transfers the classpath elements to the staging location using default settings. 
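The PackageUtil changes above replace Guava future chaining with CompletionStage composition: attribute computation is submitted to the executor and its result is fed into the staging step with thenComposeAsync. The following is a minimal, self-contained sketch of that composition pattern using only java.util.concurrent; computeAttributes and stagePackage are hypothetical stand-ins for the real PackageAttributes/StagingResult steps.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class StagingCompositionSketch {
  // Stand-in for attribute computation: runs off the calling thread on the given executor.
  static CompletionStage<String> computeAttributes(String element, ExecutorService executor) {
    return CompletableFuture.supplyAsync(() -> element + ":attributes", executor);
  }

  // Stand-in for the staging step: also asynchronous, consumes the computed attributes.
  static CompletionStage<String> stagePackage(String attributes, ExecutorService executor) {
    return CompletableFuture.supplyAsync(() -> attributes + ":staged", executor);
  }

  public static void main(String[] args) throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(4);
    CompletionStage<String> staged =
        computeAttributes("gs://bucket/lib.jar", executor)
            .thenComposeAsync(attrs -> stagePackage(attrs, executor), executor);
    System.out.println(staged.toCompletableFuture().get());
    executor.shutdown();
  }
}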
* - * @see {@link #stageClasspathElements(Collection, String, Sleeper, CreateOptions)} + * @see #stageClasspathElements(Collection, String, Sleeper, CreateOptions) */ List stageClasspathElements( Collection classpathElements, String stagingPath) { @@ -286,11 +287,11 @@ List stageClasspathElements( public DataflowPackage stageToFile( byte[] bytes, String target, String stagingPath, CreateOptions createOptions) { try { - return stagePackage( - PackageAttributes.forBytesToStage(bytes, target, stagingPath), - DEFAULT_SLEEPER, - createOptions) - .get() + return MoreFutures.get( + stagePackage( + PackageAttributes.forBytesToStage(bytes, target, stagingPath), + DEFAULT_SLEEPER, + createOptions)) .getPackageAttributes() .getDestination(); } catch (InterruptedException e) { @@ -331,7 +332,7 @@ List stageClasspathElements( final AtomicInteger numUploaded = new AtomicInteger(0); final AtomicInteger numCached = new AtomicInteger(0); - List> destinationPackages = new ArrayList<>(); + List> destinationPackages = new ArrayList<>(); for (String classpathElement : classpathElements) { DataflowPackage sourcePackage = new DataflowPackage(); @@ -350,15 +351,14 @@ List stageClasspathElements( continue; } - // TODO: Java 8 / Guava 23.0: FluentFuture - ListenableFuture stagingResult = - Futures.transformAsync( - computePackageAttributes(sourcePackage, stagingPath), - packageAttributes -> stagePackage(packageAttributes, retrySleeper, createOptions)); + CompletionStage stagingResult = + computePackageAttributes(sourcePackage, stagingPath) + .thenComposeAsync( + packageAttributes -> + stagePackage(packageAttributes, retrySleeper, createOptions)); - ListenableFuture stagedPackage = - Futures.transform( - stagingResult, + CompletionStage stagedPackage = + stagingResult.thenApply( stagingResult1 -> { if (stagingResult1.alreadyStaged()) { numCached.incrementAndGet(); @@ -372,19 +372,19 @@ List stageClasspathElements( } try { - ListenableFuture> stagingFutures = - Futures.allAsList(destinationPackages); + CompletionStage> stagingFutures = + MoreFutures.allAsList(destinationPackages); boolean finished = false; do { try { - stagingFutures.get(3L, TimeUnit.MINUTES); + MoreFutures.get(stagingFutures, 3L, TimeUnit.MINUTES); finished = true; } catch (TimeoutException e) { // finished will still be false LOG.info("Still staging {} files", classpathElements.size()); } } while (!finished); - List stagedPackages = stagingFutures.get(); + List stagedPackages = MoreFutures.get(stagingFutures); LOG.info( "Staging files complete: {} files cached, {} files newly uploaded", numCached.get(), numUploaded.get()); diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java index cdc87bf93434..3fefe383a622 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java @@ -62,6 +62,7 @@ public class PropertyNames { public static final String USER_NAME = "user_name"; public static final String USES_KEYED_STATE = "uses_keyed_state"; public static final String VALUE = "value"; + public static final String WINDOWING_STRATEGY = "windowing_strategy"; public static final String DISPLAY_DATA = "display_data"; public static final String RESTRICTION_CODER = "restriction_coder"; public static final String IMPULSE_ELEMENT = 
"impulse_element"; diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java index 613c30b1c26f..baf02114179a 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java @@ -62,8 +62,6 @@ public class DataflowMetricsTest { private static final String PROJECT_ID = "some-project"; private static final String JOB_ID = "1234"; - private static final String REGION_ID = "some-region"; - private static final String REPLACEMENT_JOB_ID = "4321"; @Mock private Dataflow mockWorkflowClient; diff --git a/runners/java-fn-execution/pom.xml b/runners/java-fn-execution/pom.xml index 7958410d516e..cd637cb5ab47 100644 --- a/runners/java-fn-execution/pom.xml +++ b/runners/java-fn-execution/pom.xml @@ -63,11 +63,6 @@ beam-runners-core-construction-java - - com.google.protobuf - protobuf-java - - io.grpc grpc-core @@ -108,13 +103,13 @@ org.hamcrest - hamcrest-all + hamcrest-library test - + - org.apache.beam - beam-runners-core-construction-java + org.hamcrest + hamcrest-core test @@ -127,7 +122,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactRetrievalService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactRetrievalService.java new file mode 100644 index 000000000000..c04e118f0d4d --- /dev/null +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactRetrievalService.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.fnexecution.artifact; + +import org.apache.beam.runners.fnexecution.FnService; + +/** An implementation of the Beam Artifact Retrieval Service. */ +public interface ArtifactRetrievalService extends FnService {} diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/package-info.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/package-info.java new file mode 100644 index 000000000000..ca148c1a360e --- /dev/null +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Pipeline execution-time artifact-management services, including abstract implementations of the + * Artifact Retrieval Service. + */ +package org.apache.beam.runners.fnexecution.artifact; diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java index 28971895bb42..81eb7281b201 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClient.java @@ -17,15 +17,18 @@ */ package org.apache.beam.runners.fnexecution.control; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; import java.io.Closeable; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.fn.stream.SynchronizedStreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,17 +43,18 @@ * *
<p>
This low-level client is responsible only for correlating requests with responses. */ -class FnApiControlClient implements Closeable { +public class FnApiControlClient implements Closeable, InstructionRequestHandler { private static final Logger LOG = LoggerFactory.getLogger(FnApiControlClient.class); // All writes to this StreamObserver need to be synchronized. private final StreamObserver requestReceiver; private final ResponseStreamObserver responseObserver = new ResponseStreamObserver(); - private final Map> outstandingRequests; - private volatile boolean isClosed; + private final ConcurrentMap> + outstandingRequests; + private AtomicBoolean isClosed = new AtomicBoolean(false); private FnApiControlClient(StreamObserver requestReceiver) { - this.requestReceiver = requestReceiver; + this.requestReceiver = SynchronizedStreamObserver.wrapping(requestReceiver); this.outstandingRequests = new ConcurrentHashMap<>(); } @@ -66,16 +70,16 @@ public static FnApiControlClient forRequestObserver( return new FnApiControlClient(requestObserver); } - public synchronized ListenableFuture handle( + public CompletionStage handle( BeamFnApi.InstructionRequest request) { LOG.debug("Sending InstructionRequest {}", request); - SettableFuture resultFuture = SettableFuture.create(); + CompletableFuture resultFuture = new CompletableFuture<>(); outstandingRequests.put(request.getInstructionId(), resultFuture); requestReceiver.onNext(request); return resultFuture; } - StreamObserver asResponseObserver() { + public StreamObserver asResponseObserver() { return responseObserver; } @@ -85,16 +89,15 @@ public void close() { } /** Closes this client and terminates any outstanding requests exceptionally. */ - private synchronized void closeAndTerminateOutstandingRequests(Throwable cause) { - if (isClosed) { + private void closeAndTerminateOutstandingRequests(Throwable cause) { + if (isClosed.getAndSet(true)) { return; } // Make a copy of the map to make the view of the outstanding requests consistent. 
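FnApiControlClient now stores one CompletableFuture per outstanding instruction id: a matching response completes it, while closing the client (or a stream error) completes every remaining future exceptionally, and the AtomicBoolean makes that shutdown path idempotent without synchronized methods. A reduced sketch of the same correlation pattern, with a hypothetical RequestCorrelatorSketch and String payloads in place of the Fn API protos:

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;

class RequestCorrelatorSketch {
  private final ConcurrentMap<String, CompletableFuture<String>> outstanding =
      new ConcurrentHashMap<>();
  private final AtomicBoolean closed = new AtomicBoolean(false);

  // Register a future for the request before sending it, so a later response can find it.
  CompletionStage<String> send(String instructionId) {
    CompletableFuture<String> result = new CompletableFuture<>();
    outstanding.put(instructionId, result);
    // ... write the request to the outbound stream here ...
    return result;
  }

  // Complete the matching future; responses for unknown ids are simply ignored.
  void onResponse(String instructionId, String response) {
    CompletableFuture<String> result = outstanding.remove(instructionId);
    if (result != null) {
      result.complete(response);
    }
  }

  // Idempotent shutdown: the first caller fails all outstanding futures, later callers return.
  void close(Throwable cause) {
    if (closed.getAndSet(true)) {
      return;
    }
    outstanding.values().forEach(future -> future.completeExceptionally(cause));
    outstanding.clear();
  }
}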
- Map> outstandingRequestsCopy = + Map> outstandingRequestsCopy = new ConcurrentHashMap<>(outstandingRequests); outstandingRequests.clear(); - isClosed = true; if (outstandingRequestsCopy.isEmpty()) { requestReceiver.onCompleted(); @@ -107,9 +110,9 @@ private synchronized void closeAndTerminateOutstandingRequests(Throwable cause) "{} closed, clearing outstanding requests {}", FnApiControlClient.class.getSimpleName(), outstandingRequestsCopy); - for (SettableFuture outstandingRequest : + for (CompletableFuture outstandingRequest : outstandingRequestsCopy.values()) { - outstandingRequest.setException(cause); + outstandingRequest.completeExceptionally(cause); } } @@ -125,13 +128,13 @@ private class ResponseStreamObserver implements StreamObserver completableFuture = + CompletableFuture responseFuture = outstandingRequests.remove(response.getInstructionId()); - if (completableFuture != null) { + if (responseFuture != null) { if (response.getError().isEmpty()) { - completableFuture.set(response); + responseFuture.complete(response); } else { - completableFuture.setException( + responseFuture.completeExceptionally( new RuntimeException(String.format( "Error received from SDK harness for instruction %s: %s", response.getInstructionId(), diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolService.java index 2a037861b410..9d443427dc11 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolService.java @@ -39,6 +39,9 @@ private FnApiControlClientPoolService(BlockingQueue clientPo /** * Creates a new {@link FnApiControlClientPoolService} which will enqueue and vend new SDK harness * connections. + * + *
<p>
Clients placed into the {@code clientPool} are owned by whichever consumer owns the pool. + * That consumer is responsible for closing the clients when they are no longer needed. */ public static FnApiControlClientPoolService offeringClientsToPool( BlockingQueue clientPool) { @@ -68,6 +71,6 @@ public StreamObserver control( @Override public void close() throws Exception { - // TODO: terminate existing clients. + // The clients in the pool are owned by the consumer, which is responsible for closing them } } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java new file mode 100644 index 000000000000..46e2d7b11e91 --- /dev/null +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.fnexecution.control; + +import java.util.concurrent.CompletionStage; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; + +/** Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. */ +@FunctionalInterface +public interface InstructionRequestHandler { + CompletionStage handle(BeamFnApi.InstructionRequest request); +} diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java index 27c940110a49..02ed5cb3292c 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java @@ -20,14 +20,11 @@ import com.google.auto.value.AutoValue; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicLong; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse; @@ -47,10 +44,10 @@ /** * A high-level client for an SDK harness. * - *
<p>
This provides a Java-friendly wrapper around {@link FnApiControlClient} and {@link + *
<p>
This provides a Java-friendly wrapper around {@link InstructionRequestHandler} and {@link * CloseableFnDataReceiver}, which handle lower-level gRPC message wrangling. */ -public class SdkHarnessClient { +public class SdkHarnessClient implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(SdkHarnessClient.class); /** @@ -76,20 +73,20 @@ public String getId() { */ public class BundleProcessor { private final String processBundleDescriptorId; - private final Future registrationFuture; + private final CompletionStage registrationFuture; private final RemoteInputDestination> remoteInput; private BundleProcessor( String processBundleDescriptorId, - Future registrationFuture, + CompletionStage registrationFuture, RemoteInputDestination> remoteInput) { this.processBundleDescriptorId = processBundleDescriptorId; this.registrationFuture = registrationFuture; this.remoteInput = remoteInput; } - public Future getRegistrationFuture() { + public CompletionStage getRegistrationFuture() { return registrationFuture; } @@ -103,7 +100,7 @@ public ActiveBundle newBundle( Map> outputReceivers) { String bundleId = idGenerator.getId(); - final ListenableFuture genericResponse = + final CompletionStage genericResponse = fnApiControlClient.handle( BeamFnApi.InstructionRequest.newBuilder() .setInstructionId(bundleId) @@ -118,8 +115,8 @@ public ActiveBundle newBundle( ProcessBundleDescriptor.class.getSimpleName(), processBundleDescriptorId); - ListenableFuture specificResponse = - Futures.transform(genericResponse, InstructionResponse::getProcessBundle); + CompletionStage specificResponse = + genericResponse.thenApply(InstructionResponse::getProcessBundle); Map outputClients = new HashMap<>(); for (Map.Entry> targetReceiver : outputReceivers.entrySet()) { @@ -152,14 +149,14 @@ private InboundDataClient attachReceiver( public abstract static class ActiveBundle { public abstract String getBundleId(); - public abstract Future getBundleResponse(); + public abstract CompletionStage getBundleResponse(); public abstract CloseableFnDataReceiver> getInputReceiver(); public abstract Map getOutputClients(); public static ActiveBundle create( String bundleId, - Future response, + CompletionStage response, CloseableFnDataReceiver> dataReceiver, Map outputClients) { return new AutoValue_SdkHarnessClient_ActiveBundle<>( @@ -168,14 +165,14 @@ public static ActiveBundle create( } private final IdGenerator idGenerator; - private final FnApiControlClient fnApiControlClient; + private final InstructionRequestHandler fnApiControlClient; private final FnDataService fnApiDataService; private final Cache clientProcessors = CacheBuilder.newBuilder().build(); private SdkHarnessClient( - FnApiControlClient fnApiControlClient, + InstructionRequestHandler fnApiControlClient, FnDataService fnApiDataService, IdGenerator idGenerator) { this.fnApiDataService = fnApiDataService; @@ -189,7 +186,7 @@ private SdkHarnessClient( * correctly associated. 
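Since InstructionRequestHandler is declared as a @FunctionalInterface with a single handle method, a handler can be supplied as a lambda; for instance, a test double that answers every request immediately might look like the sketch below. The CompletionStage type parameter is assumed to be BeamFnApi.InstructionResponse, matching how handle is used elsewhere in this patch.

import java.util.concurrent.CompletableFuture;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler;

class ImmediateHandlerSketch {
  // Hypothetical test double: echoes the instruction id back in an otherwise empty response.
  static InstructionRequestHandler immediate() {
    return request ->
        CompletableFuture.completedFuture(
            BeamFnApi.InstructionResponse.newBuilder()
                .setInstructionId(request.getInstructionId())
                .build());
  }
}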
*/ public static SdkHarnessClient usingFnApiClient( - FnApiControlClient fnApiControlClient, FnDataService fnApiDataService) { + InstructionRequestHandler fnApiControlClient, FnDataService fnApiDataService) { return new SdkHarnessClient(fnApiControlClient, fnApiDataService, new CountingIdGenerator()); } @@ -225,7 +222,7 @@ public Map register( LOG.debug("Registering {}", processBundleDescriptors.keySet()); // TODO: validate that all the necessary data endpoints are known - ListenableFuture genericResponse = + CompletionStage genericResponse = fnApiControlClient.handle( BeamFnApi.InstructionRequest.newBuilder() .setInstructionId(idGenerator.getId()) @@ -235,10 +232,9 @@ public Map register( .build()) .build()); - ListenableFuture registerResponseFuture = - Futures.transform( - genericResponse, InstructionResponse::getRegister, - MoreExecutors.directExecutor()); + CompletionStage registerResponseFuture = + genericResponse.thenApply(InstructionResponse::getRegister); + for (Map.Entry>> descriptorInputEntry : processBundleDescriptors.entrySet()) { clientProcessors.put( @@ -252,6 +248,9 @@ public Map register( return clientProcessors.asMap(); } + @Override + public void close() throws Exception {} + /** * A pair of {@link Coder} and {@link BeamFnApi.Target} which can be handled by the remote SDK * harness to receive elements sent from the runner. diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientControlService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientControlService.java new file mode 100644 index 000000000000..20ad69a97595 --- /dev/null +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientControlService.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.fnexecution.control; + +import io.grpc.ServerServiceDefinition; +import java.util.Collection; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.SynchronousQueue; +import java.util.function.Supplier; +import org.apache.beam.runners.fnexecution.FnService; +import org.apache.beam.runners.fnexecution.data.FnDataService; + +/** + * A service providing {@link SdkHarnessClient} based on an internally managed {@link + * FnApiControlClientPoolService}. 
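SdkHarnessClientControlService pairs a SynchronousQueue with the pool service: getClient() blocks until an SDK harness connects and FnApiControlClientPoolService hands its FnApiControlClient to the queue, at which point the client is wrapped together with the supplied FnDataService. The rendezvous itself reduces to the pattern below (a sketch with String standing in for the client type):

import java.util.concurrent.SynchronousQueue;

class ClientHandoffSketch {
  private final SynchronousQueue<String> pendingClients = new SynchronousQueue<>();

  // Producer side (the control service): hands a newly connected client to a waiting consumer.
  void onClientConnected(String client) throws InterruptedException {
    pendingClients.put(client); // blocks until a consumer is ready to take it
  }

  // Consumer side: blocks until an SDK harness has connected.
  String getClient() {
    try {
      return pendingClients.take();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new RuntimeException("Interrupted while waiting for client", e);
    }
  }
}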
+ */ +public class SdkHarnessClientControlService implements FnService { + private final FnApiControlClientPoolService clientPoolService; + private final BlockingQueue pendingClients; + + private final Supplier dataService; + + private final Collection activeClients; + + public static SdkHarnessClientControlService create(Supplier dataService) { + return new SdkHarnessClientControlService(dataService); + } + + private SdkHarnessClientControlService(Supplier dataService) { + this.dataService = dataService; + activeClients = new ConcurrentLinkedQueue<>(); + pendingClients = new SynchronousQueue<>(); + clientPoolService = FnApiControlClientPoolService.offeringClientsToPool(pendingClients); + } + + public SdkHarnessClient getClient() { + try { + // Block until a client is available. + FnApiControlClient getClient = pendingClients.take(); + return SdkHarnessClient.usingFnApiClient(getClient, dataService.get()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for client", e); + } + } + + @Override + public void close() throws Exception { + for (SdkHarnessClient client : activeClients) { + client.close(); + } + } + + @Override + public ServerServiceDefinition bindService() { + return clientPoolService.bindService(); + } +} diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/graph/LengthPrefixUnknownCoders.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/graph/LengthPrefixUnknownCoders.java index ac7e745b939f..228dad4bcfdb 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/graph/LengthPrefixUnknownCoders.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/graph/LengthPrefixUnknownCoders.java @@ -29,21 +29,22 @@ /** * Utilities for replacing or wrapping unknown coders with {@link LengthPrefixCoder}. * - *
<p>
TODO: Support a dynamic list of well known coders using either registration or manual listing. + *
<p>
TODO: Support a dynamic list of well known coders using either registration or manual listing, + * possibly from ModelCoderRegistrar. */ public class LengthPrefixUnknownCoders { - private static final String BYTES_CODER_TYPE = "urn:beam:coders:bytes:0.1"; - private static final String LENGTH_PREFIX_CODER_TYPE = "urn:beam:coders:length_prefix:0.1"; + private static final String BYTES_CODER_TYPE = "beam:coder:bytes:v1"; + private static final String LENGTH_PREFIX_CODER_TYPE = "beam:coder:length_prefix:v1"; private static final Set WELL_KNOWN_CODER_URNS = ImmutableSet.of( BYTES_CODER_TYPE, - "urn:beam:coders:kv:0.1", - "urn:beam:coders:varint:0.1", - "urn:beam:coders:interval_window:0.1", - "urn:beam:coders:stream:0.1", + "beam:coder:kv:v1", + "beam:coder:varint:v1", + "beam:coder:interval_window:v1", + "beam:coder:iterable:v1", LENGTH_PREFIX_CODER_TYPE, - "urn:beam:coders:global_window:0.1", - "urn:beam:coders:windowed_value:0.1"); + "beam:coder:global_window:v1", + "beam:coder:windowed_value:v1"); /** * Recursively traverse the coder tree and wrap the first unknown coder in every branch with a diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolServiceTest.java index 9392ee0ff335..8f4a09f28fbc 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolServiceTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientPoolServiceTest.java @@ -23,11 +23,12 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import com.google.common.util.concurrent.ListenableFuture; import io.grpc.stub.StreamObserver; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletionStage; import java.util.concurrent.LinkedBlockingQueue; import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.util.MoreFutures; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,14 +53,14 @@ public void testIncomingConnection() throws Exception { // Check that the client is wired up to the request channel String id = "fakeInstruction"; - ListenableFuture responseFuture = + CompletionStage responseFuture = client.handle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId(id).build()); verify(requestObserver).onNext(any(BeamFnApi.InstructionRequest.class)); - assertThat(responseFuture.isDone(), is(false)); + assertThat(MoreFutures.isDone(responseFuture), is(false)); // Check that the response channel really came from the client responseObserver.onNext( BeamFnApi.InstructionResponse.newBuilder().setInstructionId(id).build()); - responseFuture.get(); + MoreFutures.get(responseFuture); } } diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientTest.java index 31a9c0a6825c..e26e426e6058 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/FnApiControlClientTest.java @@ -24,13 +24,13 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.verify; -import 
com.google.common.util.concurrent.ListenableFuture; import io.grpc.stub.StreamObserver; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse; +import org.apache.beam.sdk.util.MoreFutures; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -67,13 +67,13 @@ public void testRequestSent() { public void testRequestSuccess() throws Exception { String id = "successfulInstruction"; - Future responseFuture = + CompletionStage responseFuture = client.handle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId(id).build()); client .asResponseObserver() .onNext(BeamFnApi.InstructionResponse.newBuilder().setInstructionId(id).build()); - BeamFnApi.InstructionResponse response = responseFuture.get(); + BeamFnApi.InstructionResponse response = MoreFutures.get(responseFuture); assertThat(response.getInstructionId(), equalTo(id)); } @@ -81,7 +81,7 @@ public void testRequestSuccess() throws Exception { @Test public void testRequestError() throws Exception { String id = "instructionId"; - ListenableFuture responseFuture = + CompletionStage responseFuture = client.handle(InstructionRequest.newBuilder().setInstructionId(id).build()); String error = "Oh no an error!"; client @@ -94,7 +94,7 @@ public void testRequestError() throws Exception { thrown.expectCause(isA(RuntimeException.class)); thrown.expectMessage(error); - responseFuture.get(); + MoreFutures.get(responseFuture); } @Test @@ -102,22 +102,22 @@ public void testUnknownResponseIgnored() throws Exception { String id = "actualInstruction"; String unknownId = "unknownInstruction"; - ListenableFuture responseFuture = + CompletionStage responseFuture = client.handle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId(id).build()); client .asResponseObserver() .onNext(BeamFnApi.InstructionResponse.newBuilder().setInstructionId(unknownId).build()); - assertThat(responseFuture.isDone(), is(false)); - assertThat(responseFuture.isCancelled(), is(false)); + assertThat(MoreFutures.isDone(responseFuture), is(false)); + assertThat(MoreFutures.isCancelled(responseFuture), is(false)); } @Test public void testOnCompletedCancelsOutstanding() throws Exception { String id = "clientHangUpInstruction"; - Future responseFuture = + CompletionStage responseFuture = client.handle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId(id).build()); client.asResponseObserver().onCompleted(); @@ -125,29 +125,30 @@ public void testOnCompletedCancelsOutstanding() throws Exception { thrown.expect(ExecutionException.class); thrown.expectCause(isA(IllegalStateException.class)); thrown.expectMessage("closed"); - responseFuture.get(); + MoreFutures.get(responseFuture); } @Test public void testOnErrorCancelsOutstanding() throws Exception { String id = "errorInstruction"; - Future responseFuture = + CompletionStage responseFuture = client.handle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId(id).build()); class FrazzleException extends Exception {} + client.asResponseObserver().onError(new FrazzleException()); thrown.expect(ExecutionException.class); thrown.expectCause(isA(FrazzleException.class)); - responseFuture.get(); + MoreFutures.get(responseFuture); } @Test public void testCloseCancelsOutstanding() throws Exception { String id = "serverCloseInstruction"; - 
Future responseFuture = + CompletionStage responseFuture = client.handle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId(id).build()); client.close(); @@ -155,6 +156,6 @@ public void testCloseCancelsOutstanding() throws Exception { thrown.expect(ExecutionException.class); thrown.expectCause(isA(IllegalStateException.class)); thrown.expectMessage("closed"); - responseFuture.get(); + MoreFutures.get(responseFuture); } } diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java index 53aed4164b85..0a18ff6844fe 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java @@ -26,7 +26,6 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; -import com.google.common.util.concurrent.SettableFuture; import io.grpc.ManagedChannel; import io.grpc.inprocess.InProcessChannelBuilder; import java.io.IOException; @@ -34,6 +33,7 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -76,6 +76,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow.Coder; +import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.values.TupleTag; @@ -107,7 +108,7 @@ public void testRegisterDoesNotCrash() throws Exception { String descriptorId1 = "descriptor1"; String descriptorId2 = "descriptor2"; - SettableFuture registerResponseFuture = SettableFuture.create(); + CompletableFuture registerResponseFuture = new CompletableFuture<>(); when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class))) .thenReturn(registerResponseFuture); @@ -146,10 +147,10 @@ public void testNewBundleNoDataDoesNotCrash() throws Exception { ProcessBundleDescriptor descriptor = ProcessBundleDescriptor.newBuilder().setId(descriptorId1).build(); - SettableFuture processBundleResponseFuture = - SettableFuture.create(); + CompletableFuture processBundleResponseFuture = + new CompletableFuture<>(); when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class))) - .thenReturn(SettableFuture.create()) + .thenReturn(new CompletableFuture<>()) .thenReturn(processBundleResponseFuture); FullWindowedValueCoder coder = @@ -169,9 +170,9 @@ public void testNewBundleNoDataDoesNotCrash() throws Exception { // Currently there are no fields so there's nothing to check. This test is formulated // to match the pattern it should have if/when the response is meaningful. 
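In these tests the SDK harness is simulated by completing the CompletableFuture that the mocked handler returned, after which waiting on the bundle response returns immediately. The same technique in isolation, assuming nothing beyond java.util.concurrent:

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

class CompletedResponseSketch {
  public static void main(String[] args) throws Exception {
    CompletableFuture<String> responseFuture = new CompletableFuture<>();

    // The code under test would await this future; the test completes it to simulate a response.
    responseFuture.complete("process-bundle-response");

    // Rough equivalent of awaiting the response, as the tests above do via MoreFutures.get.
    System.out.println(responseFuture.get(1, TimeUnit.SECONDS));
  }
}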
BeamFnApi.ProcessBundleResponse response = BeamFnApi.ProcessBundleResponse.getDefaultInstance(); - processBundleResponseFuture.set( + processBundleResponseFuture.complete( BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response).build()); - activeBundle.getBundleResponse().get(); + MoreFutures.get(activeBundle.getBundleResponse()); } @Test @@ -248,7 +249,7 @@ public ManagedChannel forDescriptor(ApiServiceDescriptor apiServiceDescriptor) { bundleInputReceiver.accept(WindowedValue.valueInGlobalWindow("bar")); bundleInputReceiver.accept(WindowedValue.valueInGlobalWindow("baz")); } - activeBundle.getBundleResponse().get(); + MoreFutures.get(activeBundle.getBundleResponse()); for (InboundDataClient outputClient : activeBundle.getOutputClients().values()) { outputClient.awaitCompletion(); } diff --git a/runners/local-artifact-service-java/build.gradle b/runners/local-artifact-service-java/build.gradle index 3637462322ec..205cd8e494b3 100644 --- a/runners/local-artifact-service-java/build.gradle +++ b/runners/local-artifact-service-java/build.gradle @@ -24,12 +24,14 @@ description = "Apache Beam :: Runners :: Java Local Artifact Service" dependencies { compile library.java.guava shadow project(path: ":model:job-management", configuration: "shadow") + shadow project(path: ":model:pipeline", configuration: "shadow") shadow project(path: ":runners:java-fn-execution", configuration: "shadow") shadow library.java.findbugs_jsr305 shadow library.java.grpc_core shadow library.java.grpc_stub shadow library.java.protobuf_java shadow library.java.slf4j_api + testCompile project(path: ":runners:core-construction-java") testCompile library.java.hamcrest_core testCompile library.java.hamcrest_library testCompile library.java.mockito_core diff --git a/runners/local-artifact-service-java/pom.xml b/runners/local-artifact-service-java/pom.xml index 7e10ad82f2bb..a7174ee35a4d 100644 --- a/runners/local-artifact-service-java/pom.xml +++ b/runners/local-artifact-service-java/pom.xml @@ -56,6 +56,11 @@ beam-model-job-management + + org.apache.beam + beam-model-pipeline + + org.apache.beam beam-runners-java-fn-execution @@ -95,13 +100,19 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + + org.hamcrest + hamcrest-library test org.mockito - mockito-all + mockito-core test @@ -116,5 +127,11 @@ slf4j-jdk14 test + + + org.apache.beam + beam-runners-core-construction-java + test + diff --git a/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalArtifactStagingLocation.java b/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalArtifactStagingLocation.java new file mode 100644 index 000000000000..e11125f75a36 --- /dev/null +++ b/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalArtifactStagingLocation.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.artifact.local; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +import java.io.File; +import java.io.IOException; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.Manifest; + +/** + * A location where the results of an {@link LocalFileSystemArtifactStagerService} are stored and + * where the retrieval service retrieves them from. + */ +public class LocalArtifactStagingLocation { + /** + * Create a new {@link LocalArtifactStagingLocation} rooted at the specified location, creating + * any directories or subdirectories as necessary. + */ + public static LocalArtifactStagingLocation createAt(File rootDirectory) { + return new LocalArtifactStagingLocation(rootDirectory).createDirectories(); + } + + /** + * Create a {@link LocalArtifactStagingLocation} for an existing directory. The directory must + * contain a manifest and an artifact directory. + */ + public static LocalArtifactStagingLocation forExistingDirectory(File rootDirectory) { + return new LocalArtifactStagingLocation(rootDirectory).verifyExistence(); + } + + private final File rootDirectory; + private final File artifactsDirectory; + + private LocalArtifactStagingLocation(File base) { + this.rootDirectory = base; + this.artifactsDirectory = new File(base, "artifacts"); + } + + private LocalArtifactStagingLocation createDirectories() { + if (((rootDirectory.exists() && rootDirectory.isDirectory()) || rootDirectory.mkdirs()) + && rootDirectory.canWrite()) { + checkState( + ((artifactsDirectory.exists() && artifactsDirectory.isDirectory()) + || artifactsDirectory.mkdir()) + && artifactsDirectory.canWrite(), + "Could not create artifact staging directory at %s", + artifactsDirectory); + } else { + throw new IllegalStateException( + String.format("Could not create staging directory structure at root %s", rootDirectory)); + } + return this; + } + + private LocalArtifactStagingLocation verifyExistence() { + checkArgument(rootDirectory.exists(), "Nonexistent staging location root %s", rootDirectory); + checkArgument( + rootDirectory.isDirectory(), "Staging location %s is not a directory", rootDirectory); + checkArgument( + artifactsDirectory.exists(), "Nonexistent artifact directory %s", artifactsDirectory); + checkArgument( + artifactsDirectory.isDirectory(), + "Artifact location %s is not a directory", + artifactsDirectory); + checkArgument(getManifestFile().exists(), "No Manifest in existing location %s", rootDirectory); + return this; + } + + /** + * Returns the {@link File} which contains the artifact with the provided name. + * + *
<p>
The file may not exist. + */ + public File getArtifactFile(String artifactName) { + return new File(artifactsDirectory, artifactName); + } + + /** + * Returns the {@link File} which contains the {@link Manifest}. + * + *
<p>
The file may not exist. + */ + public File getManifestFile() { + return new File(rootDirectory, "MANIFEST"); + } + + /** + * Returns the local location of this {@link LocalArtifactStagingLocation}. + * + *
<p>
This can be used to refer to the staging location when creating a retrieval service. + */ + public String getRootPath() { + try { + return rootDirectory.getCanonicalPath(); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } +} diff --git a/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactRetrievalService.java b/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactRetrievalService.java new file mode 100644 index 000000000000..73afcdcba966 --- /dev/null +++ b/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactRetrievalService.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.artifact.local; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactChunk; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestResponse; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.Manifest; +import org.apache.beam.model.jobmanagement.v1.ArtifactRetrievalServiceGrpc; +import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService; + +/** An {@code ArtifactRetrievalService} which stages files to a local temp directory. 
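LocalArtifactStagingLocation fixes a simple on-disk layout: an artifacts subdirectory for staged files plus a MANIFEST file at the root, shared by the stager and the retrieval service. A hedged usage sketch of the API introduced above (the root path is illustrative):

import java.io.File;
import org.apache.beam.artifact.local.LocalArtifactStagingLocation;

class StagingLocationUsageSketch {
  public static void main(String[] args) {
    // Create the directory structure under an illustrative root.
    File root = new File(System.getProperty("java.io.tmpdir"), "beam-staging-example");
    LocalArtifactStagingLocation location = LocalArtifactStagingLocation.createAt(root);

    // Where an artifact named "lib.jar" would be written, and where the manifest lives.
    File artifact = location.getArtifactFile("lib.jar");
    File manifest = location.getManifestFile();
    System.out.println(artifact + " and " + manifest + " under " + location.getRootPath());

    // Reopening the same root later requires the manifest to already exist:
    // LocalArtifactStagingLocation.forExistingDirectory(root);
  }
}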
*/ +public class LocalFileSystemArtifactRetrievalService + extends ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceImplBase + implements ArtifactRetrievalService { + private static final int DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024; + + public static LocalFileSystemArtifactRetrievalService forRootDirectory(File base) { + return new LocalFileSystemArtifactRetrievalService(base); + } + + private final LocalArtifactStagingLocation location; + private final Manifest manifest; + + private LocalFileSystemArtifactRetrievalService(File rootDirectory) { + this.location = LocalArtifactStagingLocation.forExistingDirectory(rootDirectory); + try (FileInputStream manifestStream = new FileInputStream(location.getManifestFile())) { + this.manifest = ArtifactApi.Manifest.parseFrom(manifestStream); + } catch (FileNotFoundException e) { + throw new IllegalArgumentException( + String.format( + "No %s in root directory %s", Manifest.class.getSimpleName(), rootDirectory), + e); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public final void getManifest( + ArtifactApi.GetManifestRequest request, + StreamObserver responseObserver) { + try { + responseObserver.onNext(GetManifestResponse.newBuilder().setManifest(manifest).build()); + responseObserver.onCompleted(); + } catch (Exception e) { + responseObserver.onError(Status.INTERNAL.withCause(e).asException()); + } + } + + /** Get the artifact with the provided name as a sequence of bytes. */ + private ByteBuffer getArtifact(String name) throws IOException { + File artifact = location.getArtifactFile(name); + if (!artifact.exists()) { + throw new FileNotFoundException(String.format("No such artifact %s", name)); + } + FileChannel input = new FileInputStream(artifact).getChannel(); + return input.map(MapMode.READ_ONLY, 0L, input.size()); + } + + @Override + public void getArtifact( + ArtifactApi.GetArtifactRequest request, + StreamObserver responseObserver) { + try { + ByteBuffer artifact = getArtifact(request.getName()); + do { + responseObserver.onNext( + ArtifactChunk.newBuilder() + .setData( + ByteString.copyFrom( + artifact, Math.min(artifact.remaining(), DEFAULT_CHUNK_SIZE))) + .build()); + } while (artifact.hasRemaining()); + responseObserver.onCompleted(); + } catch (FileNotFoundException e) { + responseObserver.onError( + Status.INVALID_ARGUMENT + .withDescription(String.format("No such artifact %s", request.getName())) + .withCause(e) + .asException()); + } catch (Exception e) { + responseObserver.onError( + Status.INTERNAL + .withDescription( + String.format("Could not retrieve artifact with name %s", request.getName())) + .withCause(e) + .asException()); + } + } + + @Override + public void close() throws Exception {} +} diff --git a/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerService.java b/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerService.java index 03c03276fc58..049d6147633b 100644 --- a/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerService.java +++ b/runners/local-artifact-service-java/src/main/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerService.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; import io.grpc.Status; import io.grpc.StatusException; @@ -48,22 +49,10 
@@ public static LocalFileSystemArtifactStagerService withRootDirectory(File base) return new LocalFileSystemArtifactStagerService(base); } - private final File stagingBase; - private final File artifactsBase; + private final LocalArtifactStagingLocation location; private LocalFileSystemArtifactStagerService(File stagingBase) { - this.stagingBase = stagingBase; - if (((stagingBase.exists() && stagingBase.isDirectory()) || stagingBase.mkdirs()) - && stagingBase.canWrite()) { - artifactsBase = new File(stagingBase, "artifacts"); - checkState( - (artifactsBase.mkdir() || artifactsBase.exists()) && artifactsBase.canWrite(), - "Could not create artifact staging directory at %s", - artifactsBase); - } else { - throw new IllegalStateException( - String.format("Could not create staging directory structure at root %s", stagingBase)); - } + this.location = LocalArtifactStagingLocation.createAt(stagingBase); } @Override @@ -98,7 +87,7 @@ private void commitManifestOrThrow( Collection missing = new ArrayList<>(); for (ArtifactApi.ArtifactMetadata artifact : request.getManifest().getArtifactList()) { // TODO: Validate the checksums on the server side, to fail more aggressively if require - if (!getArtifactFile(artifact.getName()).exists()) { + if (!location.getArtifactFile(artifact.getName()).exists()) { missing.add(artifact); } } @@ -108,27 +97,28 @@ private void commitManifestOrThrow( String.format("Attempted to commit manifest with missing Artifacts: [%s]", missing)) .asRuntimeException(); } - File mf = new File(stagingBase, "MANIFEST"); + File mf = location.getManifestFile(); checkState(mf.createNewFile(), "Could not create file to store manifest"); try (OutputStream mfOut = new FileOutputStream(mf)) { request.getManifest().writeTo(mfOut); } responseObserver.onNext( ArtifactApi.CommitManifestResponse.newBuilder() - .setStagingToken(stagingBase.getCanonicalPath()) + .setStagingToken(location.getRootPath()) .build()); responseObserver.onCompleted(); } - File getArtifactFile(String artifactName) { - return new File(artifactsBase, artifactName); - } - @Override public void close() throws Exception { // TODO: Close all active staging calls, signalling errors to the caller. } + @VisibleForTesting + LocalArtifactStagingLocation getLocation() { + return location; + } + private class CreateAndWriteFileObserver implements StreamObserver { private final StreamObserver responseObserver; @@ -169,7 +159,7 @@ public void onNext(ArtifactApi.PutArtifactRequest value) { private FileWritingObserver createFile(ArtifactApi.ArtifactMetadata metadata) throws IOException { - File destination = getArtifactFile(metadata.getName()); + File destination = location.getArtifactFile(metadata.getName()); if (!destination.createNewFile()) { throw Status.ALREADY_EXISTS .withDescription(String.format("Artifact with name %s already exists", metadata)) diff --git a/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalArtifactStagingLocationTest.java b/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalArtifactStagingLocationTest.java new file mode 100644 index 000000000000..6da704559190 --- /dev/null +++ b/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalArtifactStagingLocationTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.artifact.local; + +import static com.google.common.base.Preconditions.checkState; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +import java.io.File; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link LocalArtifactStagingLocation}. + */ +@RunWith(JUnit4.class) +public class LocalArtifactStagingLocationTest { + @Rule public TemporaryFolder tmp = new TemporaryFolder(); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void createAtWithAbsentDirectory() throws Exception { + File baseFolder = tmp.newFolder(); + File root = new File(baseFolder, "foo"); + + checkState(!root.exists()); + LocalArtifactStagingLocation.createAt(root); + + assertThat(root.exists(), is(true)); + assertThat(root.listFiles().length, equalTo(1)); + } + + @Test + public void createAtWithExistingDirectory() throws Exception { + File root = tmp.newFolder(); + checkState(root.exists(), "root directory must exist"); + + assertThat(root.exists(), is(true)); + assertThat(root.listFiles().length, equalTo(0)); + LocalArtifactStagingLocation.createAt(root); + + assertThat(root.exists(), is(true)); + assertThat(root.listFiles().length, equalTo(1)); + } + + @Test + public void createAtWithUnwritableDirectory() throws Exception { + File baseFolder = tmp.newFolder(); + File root = new File(baseFolder, "foo"); + checkState(root.mkdir(), "Must be able to create the root directory"); + + assertThat(root.exists(), is(true)); + checkState(root.setWritable(false), "Must be able to set the root directory to unwritable"); + + thrown.expect(IllegalStateException.class); + LocalArtifactStagingLocation.createAt(root); + } + + @Test + public void testCreateAtThenForExisting() throws Exception { + File baseFolder = tmp.newFolder(); + LocalArtifactStagingLocation newLocation = LocalArtifactStagingLocation.createAt(baseFolder); + File newManifest = newLocation.getManifestFile(); + checkState(newManifest.createNewFile(), "Manifest creation failed"); + File newArtifact = newLocation.getArtifactFile("my_artifact"); + checkState(newArtifact.createNewFile(), "Artifact creation failed"); + + LocalArtifactStagingLocation forExisting = + LocalArtifactStagingLocation.forExistingDirectory(baseFolder); + assertThat(forExisting.getManifestFile(), equalTo(newManifest)); + assertThat(forExisting.getArtifactFile("my_artifact"), equalTo(newArtifact)); + } + + @Test + public void testForExistingWithoutRoot() throws Exception { + File baseFolder = tmp.newFolder(); + File root = new File(baseFolder, "bar"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("root"); + LocalArtifactStagingLocation.forExistingDirectory(root); + } + + @Test + 
public void testForExistingWithoutManifest() throws Exception { + File baseFolder = tmp.newFolder(); + LocalArtifactStagingLocation newLocation = LocalArtifactStagingLocation.createAt(baseFolder); + File newArtifact = newLocation.getArtifactFile("my_artifact"); + checkState(newArtifact.createNewFile(), "Artifact creation failed"); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Manifest"); + LocalArtifactStagingLocation.forExistingDirectory(baseFolder); + } + + @Test + public void testForExistingWithoutArtifacts() throws Exception { + File baseFolder = tmp.newFolder(); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("artifact directory"); + + LocalArtifactStagingLocation.forExistingDirectory(baseFolder); + } +} diff --git a/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactRetrievalServiceTest.java b/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactRetrievalServiceTest.java new file mode 100644 index 000000000000..82c6f54685e0 --- /dev/null +++ b/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactRetrievalServiceTest.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.artifact.local; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.stub.StreamObserver; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactChunk; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactMetadata; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetArtifactRequest; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestRequest; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestResponse; +import org.apache.beam.model.jobmanagement.v1.ArtifactApi.Manifest; +import org.apache.beam.model.jobmanagement.v1.ArtifactRetrievalServiceGrpc; +import org.apache.beam.runners.core.construction.ArtifactServiceStager; +import org.apache.beam.runners.fnexecution.GrpcFnServer; +import org.apache.beam.runners.fnexecution.InProcessServerFactory; +import org.apache.beam.runners.fnexecution.ServerFactory; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link LocalFileSystemArtifactRetrievalService}. 
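+ * The tests stage artifacts through an in-process {@link LocalFileSystemArtifactStagerService} and then read them back through the retrieval service.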
+ */ +@RunWith(JUnit4.class) +public class LocalFileSystemArtifactRetrievalServiceTest { + @Rule public TemporaryFolder tmp = new TemporaryFolder(); + + private File root; + private ServerFactory serverFactory = InProcessServerFactory.create(); + + private GrpcFnServer stagerServer; + + private GrpcFnServer retrievalServer; + private ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceStub retrievalStub; + + @Before + public void setup() throws Exception { + root = tmp.newFolder(); + stagerServer = + GrpcFnServer.allocatePortAndCreateFor( + LocalFileSystemArtifactStagerService.withRootDirectory(root), serverFactory); + } + + @After + public void teardown() throws Exception { + stagerServer.close(); + retrievalServer.close(); + } + + @Test + public void retrieveManifest() throws Exception { + Map artifacts = new HashMap<>(); + artifacts.put("foo", "bar, baz, quux".getBytes()); + artifacts.put("spam", new byte[] {127, -22, 5}); + stageAndCreateRetrievalService(artifacts); + + final AtomicReference returned = new AtomicReference<>(); + final CountDownLatch completed = new CountDownLatch(1); + retrievalStub.getManifest( + GetManifestRequest.getDefaultInstance(), + new StreamObserver() { + @Override + public void onNext(GetManifestResponse value) { + returned.set(value.getManifest()); + } + + @Override + public void onError(Throwable t) { + completed.countDown(); + } + + @Override + public void onCompleted() { + completed.countDown(); + } + }); + + completed.await(); + assertThat(returned.get(), not(nullValue())); + + List manifestArtifacts = new ArrayList<>(); + for (ArtifactMetadata artifactMetadata : returned.get().getArtifactList()) { + manifestArtifacts.add(artifactMetadata.getName()); + } + assertThat(manifestArtifacts, containsInAnyOrder("foo", "spam")); + } + + @Test + public void retrieveArtifact() throws Exception { + Map artifacts = new HashMap<>(); + byte[] fooContents = "bar, baz, quux".getBytes(); + artifacts.put("foo", fooContents); + byte[] spamContents = {127, -22, 5}; + artifacts.put("spam", spamContents); + stageAndCreateRetrievalService(artifacts); + + final CountDownLatch completed = new CountDownLatch(2); + ByteArrayOutputStream returnedFooBytes = new ByteArrayOutputStream(); + retrievalStub.getArtifact( + GetArtifactRequest.newBuilder().setName("foo").build(), + new MultimapChunkAppender(returnedFooBytes, completed)); + ByteArrayOutputStream returnedSpamBytes = new ByteArrayOutputStream(); + retrievalStub.getArtifact( + GetArtifactRequest.newBuilder().setName("spam").build(), + new MultimapChunkAppender(returnedSpamBytes, completed)); + + completed.await(); + assertArrayEquals(fooContents, returnedFooBytes.toByteArray()); + assertArrayEquals(spamContents, returnedSpamBytes.toByteArray()); + } + + @Test + public void retrieveArtifactNotPresent() throws Exception { + stageAndCreateRetrievalService(Collections.singletonMap("foo", "bar, baz, quux".getBytes())); + + final CountDownLatch completed = new CountDownLatch(1); + final AtomicReference thrown = new AtomicReference<>(); + retrievalStub.getArtifact( + GetArtifactRequest.newBuilder().setName("spam").build(), + new StreamObserver() { + @Override + public void onNext(ArtifactChunk value) { + fail( + "Should never receive an " + + ArtifactChunk.class.getSimpleName() + + " for a nonexistent artifact"); + } + + @Override + public void onError(Throwable t) { + thrown.set(t); + completed.countDown(); + } + + @Override + public void onCompleted() { + completed.countDown(); + } + }); + + completed.await(); + 
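+    // The call for an unstaged artifact name should terminate with an error rather than emit any chunks.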
assertThat(thrown.get(), not(nullValue())); + assertThat(thrown.get().getMessage(), containsString("No such artifact")); + assertThat(thrown.get().getMessage(), containsString("spam")); + } + + private void stageAndCreateRetrievalService(Map artifacts) throws Exception { + List artifactFiles = new ArrayList<>(); + for (Map.Entry artifact : artifacts.entrySet()) { + File artifactFile = tmp.newFile(artifact.getKey()); + new FileOutputStream(artifactFile).getChannel().write(ByteBuffer.wrap(artifact.getValue())); + artifactFiles.add(artifactFile); + } + + ArtifactServiceStager stager = + ArtifactServiceStager.overChannel( + InProcessChannelBuilder.forName(stagerServer.getApiServiceDescriptor().getUrl()) + .build()); + stager.stage(artifactFiles); + + retrievalServer = + GrpcFnServer.allocatePortAndCreateFor( + LocalFileSystemArtifactRetrievalService.forRootDirectory(root), serverFactory); + retrievalStub = + ArtifactRetrievalServiceGrpc.newStub( + InProcessChannelBuilder.forName(retrievalServer.getApiServiceDescriptor().getUrl()) + .build()); + } + + private static class MultimapChunkAppender implements StreamObserver { + private final ByteArrayOutputStream target; + private final CountDownLatch completed; + + private MultimapChunkAppender(ByteArrayOutputStream target, CountDownLatch completed) { + this.target = target; + this.completed = completed; + } + + @Override + public void onNext(ArtifactChunk value) { + try { + target.write(value.getData().toByteArray()); + } catch (IOException e) { + // This should never happen + throw new AssertionError(e); + } + } + + @Override + public void onError(Throwable t) { + completed.countDown(); + } + + @Override + public void onCompleted() { + completed.countDown(); + } + } +} diff --git a/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerServiceTest.java b/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerServiceTest.java index 92146a7c722c..0d8509603609 100644 --- a/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerServiceTest.java +++ b/runners/local-artifact-service-java/src/test/java/org/apache/beam/artifact/local/LocalFileSystemArtifactStagerServiceTest.java @@ -99,7 +99,7 @@ public void singleDataPutArtifactSucceeds() throws Exception { responseObserver.awaitTerminalState(); - File staged = stager.getArtifactFile(name); + File staged = stager.getLocation().getArtifactFile(name); assertThat(staged.exists(), is(true)); ByteBuffer buf = ByteBuffer.allocate(data.length); new FileInputStream(staged).getChannel().read(buf); @@ -146,7 +146,7 @@ public void multiPartPutArtifactSucceeds() throws Exception { responseObserver.awaitTerminalState(); - File staged = stager.getArtifactFile(name); + File staged = stager.getLocation().getArtifactFile(name); assertThat(staged.exists(), is(true)); ByteBuffer buf = ByteBuffer.allocate("foo-bar-baz".length()); new FileInputStream(staged).getChannel().read(buf); diff --git a/runners/local-java/pom.xml b/runners/local-java/pom.xml index 6b2959856101..04ead0b5ade0 100644 --- a/runners/local-java/pom.xml +++ b/runners/local-java/pom.xml @@ -68,7 +68,12 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test diff --git a/runners/reference/java/pom.xml b/runners/reference/java/pom.xml index 2db15a0f575a..d07d7c912531 100644 --- a/runners/reference/java/pom.xml +++ b/runners/reference/java/pom.xml @@ -62,10 
+62,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test - + + + org.hamcrest + hamcrest-library + test + + junit junit diff --git a/runners/reference/job-server/pom.xml b/runners/reference/job-server/pom.xml index a42491d44cd4..cea1e7de3d90 100644 --- a/runners/reference/job-server/pom.xml +++ b/runners/reference/job-server/pom.xml @@ -137,8 +137,9 @@ org.hamcrest - hamcrest-all + hamcrest-library test + diff --git a/runners/reference/job-server/src/main/java/org/apache/beam/runners/reference/job/ReferenceRunnerJobServer.java b/runners/reference/job-server/src/main/java/org/apache/beam/runners/reference/job/ReferenceRunnerJobServer.java index cbb6f5255eba..494eaf22ff77 100644 --- a/runners/reference/job-server/src/main/java/org/apache/beam/runners/reference/job/ReferenceRunnerJobServer.java +++ b/runners/reference/job-server/src/main/java/org/apache/beam/runners/reference/job/ReferenceRunnerJobServer.java @@ -24,12 +24,9 @@ import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** A program that runs a {@link ReferenceRunnerJobService}. */ public class ReferenceRunnerJobServer { - private static final Logger LOG = LoggerFactory.getLogger(ReferenceRunnerJobService.class); public static void main(String[] args) throws Exception { ServerConfiguration configuration = new ServerConfiguration(); diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle index 47892e383690..e1d2c0abdd25 100644 --- a/runners/spark/build.gradle +++ b/runners/spark/build.gradle @@ -89,3 +89,6 @@ configurations.testRuntimeClasspath { // Testing the Spark runner causes a StackOverflowError if slf4j-jdk14 is on the classpath exclude group: "org.slf4j", module: "slf4j-jdk14" } + +// Generates :runners:spark:runQuickstartJavaSpark +createJavaQuickstartValidationTask(name: 'Spark') diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml index fca7b9b228b6..16cec1cd64b4 100644 --- a/runners/spark/pom.xml +++ b/runners/spark/pom.xml @@ -324,12 +324,18 @@ + org.hamcrest - hamcrest-all + hamcrest-core provided + + org.hamcrest + hamcrest-library + provided + org.apache.beam @@ -346,7 +352,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/EmptyCheckpointMark.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/EmptyCheckpointMark.java index a4ab3798424a..6b9e0977ed3d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/EmptyCheckpointMark.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/EmptyCheckpointMark.java @@ -31,7 +31,7 @@ public class EmptyCheckpointMark implements UnboundedSource.CheckpointMark, Seri private static final EmptyCheckpointMark INSTANCE = new EmptyCheckpointMark(); private static final int ID = 2654265; // some constant to serve as identifier. 
- private EmptyCheckpointMark() {}; + private EmptyCheckpointMark() {} public static EmptyCheckpointMark get() { return INSTANCE; diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index fdf1422d55a3..ce7795bb0dad 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -154,14 +154,12 @@ public void putDataset(PValue pvalue, Dataset dataset, boolean forceCache) { } catch (IllegalStateException e) { // name not set, ignore } - if (forceCache || shouldCache(pvalue)) { + if ((forceCache || shouldCache(pvalue)) && pvalue instanceof PCollection) { // we cache only PCollection - if (pvalue instanceof PCollection) { - Coder coder = ((PCollection) pvalue).getCoder(); - Coder wCoder = - ((PCollection) pvalue).getWindowingStrategy().getWindowFn().windowCoder(); - dataset.cache(storageLevel(), WindowedValue.getFullCoder(coder, wCoder)); - } + Coder coder = ((PCollection) pvalue).getCoder(); + Coder wCoder = + ((PCollection) pvalue).getWindowingStrategy().getWindowFn().windowCoder(); + dataset.cache(storageLevel(), WindowedValue.getFullCoder(coder, wCoder)); } datasets.put(pvalue, dataset); leaves.add(dataset); diff --git a/sdks/go/gogradle.lock b/sdks/go/gogradle.lock new file mode 100644 index 000000000000..fdf7b1901cfb --- /dev/null +++ b/sdks/go/gogradle.lock @@ -0,0 +1,697 @@ +# This file is generated by gogradle automatically, you should NEVER modify it manually. +--- +apiVersion: "0.8.1" +dependencies: + build: + - vcs: "git" + name: "cloud.google.com/go" + commit: "4f6c921ec566a33844f4e7879b31cd8575a6982d" + url: "https://code.googlesource.com/gocloud" + transitive: false + - urls: + - "https://github.com/Shopify/sarama.git" + - "git@github.com:Shopify/sarama.git" + vcs: "git" + name: "github.com/Shopify/sarama" + commit: "541689b9f4212043471eb537fa72da507025d3ea" + transitive: false + - urls: + - "https://github.com/armon/consul-api.git" + - "git@github.com:armon/consul-api.git" + vcs: "git" + name: "github.com/armon/consul-api" + commit: "eb2c6b5be1b66bab83016e0b05f01b8d5496ffbd" + transitive: false + - name: "github.com/beorn7/perks" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/beorn7/perks" + transitive: false + - name: "github.com/bgentry/speakeasy" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/bgentry/speakeasy" + transitive: false + - name: "github.com/coreos/bbolt" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/coreos/bbolt" + transitive: false + - urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + transitive: false + - name: "github.com/coreos/go-semver" + host: + name: "github.com/coreos/etcd" + commit: 
"11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/coreos/go-semver" + transitive: false + - name: "github.com/coreos/go-systemd" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/coreos/go-systemd" + transitive: false + - name: "github.com/coreos/pkg" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/coreos/pkg" + transitive: false + - urls: + - "https://github.com/cpuguy83/go-md2man.git" + - "git@github.com:cpuguy83/go-md2man.git" + vcs: "git" + name: "github.com/cpuguy83/go-md2man" + commit: "dc9f53734905c233adfc09fd4f063dce63ce3daf" + transitive: false + - urls: + - "https://github.com/davecgh/go-spew.git" + - "git@github.com:davecgh/go-spew.git" + vcs: "git" + name: "github.com/davecgh/go-spew" + commit: "87df7c60d5820d0f8ae11afede5aa52325c09717" + transitive: false + - name: "github.com/dgrijalva/jwt-go" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/dgrijalva/jwt-go" + transitive: false + - name: "github.com/dustin/go-humanize" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/dustin/go-humanize" + transitive: false + - urls: + - "https://github.com/eapache/go-resiliency.git" + - "git@github.com:eapache/go-resiliency.git" + vcs: "git" + name: "github.com/eapache/go-resiliency" + commit: "ef9aaa7ea8bd2448429af1a77cf41b2b3b34bdd6" + transitive: false + - urls: + - "https://github.com/eapache/go-xerial-snappy.git" + - "git@github.com:eapache/go-xerial-snappy.git" + vcs: "git" + name: "github.com/eapache/go-xerial-snappy" + commit: "bb955e01b9346ac19dc29eb16586c90ded99a98c" + transitive: false + - urls: + - "https://github.com/eapache/queue.git" + - "git@github.com:eapache/queue.git" + vcs: "git" + name: "github.com/eapache/queue" + commit: "44cc805cf13205b55f69e14bcb69867d1ae92f98" + transitive: false + - urls: + - "https://github.com/fsnotify/fsnotify.git" + - "git@github.com:fsnotify/fsnotify.git" + vcs: "git" + name: "github.com/fsnotify/fsnotify" + commit: "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" + transitive: false + - name: "github.com/ghodss/yaml" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/ghodss/yaml" + transitive: false + - name: "github.com/gogo/protobuf" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/gogo/protobuf" + transitive: false + - urls: + - "https://github.com/golang/glog.git" + - "git@github.com:golang/glog.git" + vcs: "git" + name: "github.com/golang/glog" + commit: 
"23def4e6c14b4da8ac2ed8007337bc5eb5007998" + transitive: false + - name: "github.com/golang/groupcache" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/golang/groupcache" + transitive: false + - urls: + - "https://github.com/golang/mock.git" + - "git@github.com:golang/mock.git" + vcs: "git" + name: "github.com/golang/mock" + commit: "b3e60bcdc577185fce3cf625fc96b62857ce5574" + transitive: false + - urls: + - "https://github.com/golang/protobuf.git" + - "git@github.com:golang/protobuf.git" + vcs: "git" + name: "github.com/golang/protobuf" + commit: "bbd03ef6da3a115852eaf24c8a1c46aeb39aa175" + transitive: false + - urls: + - "https://github.com/golang/snappy.git" + - "git@github.com:golang/snappy.git" + vcs: "git" + name: "github.com/golang/snappy" + commit: "553a641470496b2327abcac10b36396bd98e45c9" + transitive: false + - name: "github.com/google/btree" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/google/btree" + transitive: false + - urls: + - "https://github.com/google/go-cmp.git" + - "git@github.com:google/go-cmp.git" + vcs: "git" + name: "github.com/google/go-cmp" + commit: "3af367b6b30c263d47e8895973edcca9a49cf029" + transitive: false + - urls: + - "https://github.com/google/pprof.git" + - "git@github.com:google/pprof.git" + vcs: "git" + name: "github.com/google/pprof" + commit: "a8f279b7952b27edbcb72e5a6c69ee9be4c8ad93" + transitive: false + - urls: + - "https://github.com/googleapis/gax-go.git" + - "git@github.com:googleapis/gax-go.git" + vcs: "git" + name: "github.com/googleapis/gax-go" + commit: "317e0006254c44a0ac427cc52a0e083ff0b9622f" + transitive: false + - name: "github.com/gorilla/websocket" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/gorilla/websocket" + transitive: false + - name: "github.com/grpc-ecosystem/go-grpc-prometheus" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/grpc-ecosystem/go-grpc-prometheus" + transitive: false + - name: "github.com/grpc-ecosystem/grpc-gateway" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/grpc-ecosystem/grpc-gateway" + transitive: false + - urls: + - "https://github.com/hashicorp/hcl.git" + - "git@github.com:hashicorp/hcl.git" + vcs: "git" + name: "github.com/hashicorp/hcl" + commit: "23c074d0eceb2b8a5bfdbb271ab780cde70f05a8" + transitive: false + - urls: + - "https://github.com/ianlancetaylor/demangle.git" + - "git@github.com:ianlancetaylor/demangle.git" + vcs: "git" + name: "github.com/ianlancetaylor/demangle" + commit: "4883227f66371e02c4948937d3e2be1664d9be38" + transitive: false + - name: "github.com/inconshreveable/mousetrap" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - 
"https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/inconshreveable/mousetrap" + transitive: false + - name: "github.com/jonboulle/clockwork" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/jonboulle/clockwork" + transitive: false + - urls: + - "https://github.com/kr/fs.git" + - "git@github.com:kr/fs.git" + vcs: "git" + name: "github.com/kr/fs" + commit: "2788f0dbd16903de03cb8186e5c7d97b69ad387b" + transitive: false + - name: "github.com/kr/pty" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/kr/pty" + transitive: false + - urls: + - "https://github.com/magiconair/properties.git" + - "git@github.com:magiconair/properties.git" + vcs: "git" + name: "github.com/magiconair/properties" + commit: "49d762b9817ba1c2e9d0c69183c2b4a8b8f1d934" + transitive: false + - name: "github.com/mattn/go-runewidth" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/mattn/go-runewidth" + transitive: false + - name: "github.com/matttproud/golang_protobuf_extensions" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/matttproud/golang_protobuf_extensions" + transitive: false + - urls: + - "https://github.com/mitchellh/go-homedir.git" + - "git@github.com:mitchellh/go-homedir.git" + vcs: "git" + name: "github.com/mitchellh/go-homedir" + commit: "b8bc1bf767474819792c23f32d8286a45736f1c6" + transitive: false + - urls: + - "https://github.com/mitchellh/mapstructure.git" + - "git@github.com:mitchellh/mapstructure.git" + vcs: "git" + name: "github.com/mitchellh/mapstructure" + commit: "a4e142e9c047c904fa2f1e144d9a84e6133024bc" + transitive: false + - name: "github.com/olekukonko/tablewriter" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/olekukonko/tablewriter" + transitive: false + - urls: + - "https://github.com/openzipkin/zipkin-go.git" + - "git@github.com:openzipkin/zipkin-go.git" + vcs: "git" + name: "github.com/openzipkin/zipkin-go" + commit: "3741243b287094fda649c7f0fa74bd51f37dc122" + transitive: false + - urls: + - "https://github.com/pelletier/go-toml.git" + - "git@github.com:pelletier/go-toml.git" + vcs: "git" + name: "github.com/pelletier/go-toml" + commit: "acdc4509485b587f5e675510c4f2c63e90ff68a8" + transitive: false + - name: "github.com/petar/GoLLRB" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/petar/GoLLRB" + transitive: false + - urls: + - "https://github.com/pierrec/lz4.git" + - "git@github.com:pierrec/lz4.git" + vcs: "git" + name: "github.com/pierrec/lz4" + commit: 
"ed8d4cc3b461464e69798080a0092bd028910298" + transitive: false + - urls: + - "https://github.com/pierrec/xxHash.git" + - "git@github.com:pierrec/xxHash.git" + vcs: "git" + name: "github.com/pierrec/xxHash" + commit: "a0006b13c722f7f12368c00a3d3c2ae8a999a0c6" + transitive: false + - urls: + - "https://github.com/pkg/errors.git" + - "git@github.com:pkg/errors.git" + vcs: "git" + name: "github.com/pkg/errors" + commit: "30136e27e2ac8d167177e8a583aa4c3fea5be833" + transitive: false + - urls: + - "https://github.com/pkg/sftp.git" + - "git@github.com:pkg/sftp.git" + vcs: "git" + name: "github.com/pkg/sftp" + commit: "22e9c1ccc02fc1b9fa3264572e49109b68a86947" + transitive: false + - urls: + - "https://github.com/prometheus/client_golang.git" + - "git@github.com:prometheus/client_golang.git" + vcs: "git" + name: "github.com/prometheus/client_golang" + commit: "9bb6ab929dcbe1c8393cd9ef70387cb69811bd1c" + transitive: false + - name: "github.com/prometheus/client_model" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/prometheus/client_model" + transitive: false + - name: "github.com/prometheus/common" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/prometheus/common" + transitive: false + - urls: + - "https://github.com/prometheus/procfs.git" + - "git@github.com:prometheus/procfs.git" + vcs: "git" + name: "github.com/prometheus/procfs" + commit: "cb4147076ac75738c9a7d279075a253c0cc5acbd" + transitive: false + - urls: + - "https://github.com/rcrowley/go-metrics.git" + - "git@github.com:rcrowley/go-metrics.git" + vcs: "git" + name: "github.com/rcrowley/go-metrics" + commit: "8732c616f52954686704c8645fe1a9d59e9df7c1" + transitive: false + - name: "github.com/russross/blackfriday" + host: + name: "github.com/cpuguy83/go-md2man" + commit: "dc9f53734905c233adfc09fd4f063dce63ce3daf" + urls: + - "https://github.com/cpuguy83/go-md2man.git" + - "git@github.com:cpuguy83/go-md2man.git" + vcs: "git" + vendorPath: "vendor/github.com/russross/blackfriday" + transitive: false + - name: "github.com/shurcooL/sanitized_anchor_name" + host: + name: "github.com/cpuguy83/go-md2man" + commit: "dc9f53734905c233adfc09fd4f063dce63ce3daf" + urls: + - "https://github.com/cpuguy83/go-md2man.git" + - "git@github.com:cpuguy83/go-md2man.git" + vcs: "git" + vendorPath: "vendor/github.com/shurcooL/sanitized_anchor_name" + transitive: false + - name: "github.com/sirupsen/logrus" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/sirupsen/logrus" + transitive: false + - name: "github.com/soheilhy/cmux" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/soheilhy/cmux" + transitive: false + - urls: + - "https://github.com/spf13/afero.git" + - "git@github.com:spf13/afero.git" + vcs: "git" + name: "github.com/spf13/afero" + commit: "bb8f1927f2a9d3ab41c9340aa034f6b803f4359c" + transitive: false + - urls: + - "https://github.com/spf13/cast.git" + - 
"git@github.com:spf13/cast.git" + vcs: "git" + name: "github.com/spf13/cast" + commit: "acbeb36b902d72a7a4c18e8f3241075e7ab763e4" + transitive: false + - urls: + - "https://github.com/spf13/cobra.git" + - "git@github.com:spf13/cobra.git" + vcs: "git" + name: "github.com/spf13/cobra" + commit: "93959269ad99e80983c9ba742a7e01203a4c0e4f" + transitive: false + - urls: + - "https://github.com/spf13/jwalterweatherman.git" + - "git@github.com:spf13/jwalterweatherman.git" + vcs: "git" + name: "github.com/spf13/jwalterweatherman" + commit: "7c0cea34c8ece3fbeb2b27ab9b59511d360fb394" + transitive: false + - name: "github.com/spf13/pflag" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/spf13/pflag" + transitive: false + - urls: + - "https://github.com/spf13/viper.git" + - "git@github.com:spf13/viper.git" + vcs: "git" + name: "github.com/spf13/viper" + commit: "aafc9e6bc7b7bb53ddaa75a5ef49a17d6e654be5" + transitive: false + - urls: + - "https://github.com/stathat/go.git" + - "git@github.com:stathat/go.git" + vcs: "git" + name: "github.com/stathat/go" + commit: "74669b9f388d9d788c97399a0824adbfee78400e" + transitive: false + - name: "github.com/tmc/grpc-websocket-proxy" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/tmc/grpc-websocket-proxy" + transitive: false + - name: "github.com/ugorji/go" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/ugorji/go" + transitive: false + - name: "github.com/urfave/cli" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/urfave/cli" + transitive: false + - name: "github.com/xiang90/probing" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/github.com/xiang90/probing" + transitive: false + - urls: + - "https://github.com/xordataexchange/crypt.git" + - "git@github.com:xordataexchange/crypt.git" + vcs: "git" + name: "github.com/xordataexchange/crypt" + commit: "b2862e3d0a775f18c7cfe02273500ae307b61218" + transitive: false + - vcs: "git" + name: "go.opencensus.io" + commit: "aa2b39d1618ef56ba156f27cfcdae9042f68f0bc" + url: "https://github.com/census-instrumentation/opencensus-go" + transitive: false + - vcs: "git" + name: "golang.org/x/crypto" + commit: "d9133f5469342136e669e85192a26056b587f503" + url: "https://go.googlesource.com/crypto" + transitive: false + - vcs: "git" + name: "golang.org/x/debug" + commit: "95515998a8a4bd7448134b2cb5971dbeb12e0b77" + url: "https://go.googlesource.com/debug" + transitive: false + - vcs: "git" + name: "golang.org/x/net" + commit: "2fb46b16b8dda405028c50f7c7f0f9dd1fa6bfb1" + url: "https://go.googlesource.com/net" + transitive: false + - vcs: "git" + name: "golang.org/x/oauth2" + commit: "a032972e28060ca4f5644acffae3dfc268cc09db" + url: "https://go.googlesource.com/oauth2" + 
transitive: false + - vcs: "git" + name: "golang.org/x/sync" + commit: "fd80eb99c8f653c847d294a001bdf2a3a6f768f5" + url: "https://go.googlesource.com/sync" + transitive: false + - vcs: "git" + name: "golang.org/x/sys" + commit: "37707fdb30a5b38865cfb95e5aab41707daec7fd" + url: "https://go.googlesource.com/sys" + transitive: false + - name: "golang.org/x/text" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/golang.org/x/text" + transitive: false + - name: "golang.org/x/time" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/golang.org/x/time" + transitive: false + - vcs: "git" + name: "google.golang.org/api" + commit: "386d4e5f4f92f86e6aec85985761bba4b938a2d5" + url: "https://code.googlesource.com/google-api-go-client" + transitive: false + - vcs: "git" + name: "google.golang.org/genproto" + commit: "2b5a72b8730b0b16380010cfe5286c42108d88e7" + url: "https://github.com/google/go-genproto" + transitive: false + - vcs: "git" + name: "google.golang.org/grpc" + commit: "7646b5360d049a7ca31e9133315db43456f39e2e" + url: "https://github.com/grpc/grpc-go" + transitive: false + - name: "gopkg.in/cheggaaa/pb.v1" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/gopkg.in/cheggaaa/pb.v1" + transitive: false + - name: "gopkg.in/yaml.v2" + host: + name: "github.com/coreos/etcd" + commit: "11214aa33bf5a47d3d9d8dafe0f6b97237dfe921" + urls: + - "https://github.com/coreos/etcd.git" + - "git@github.com:coreos/etcd.git" + vcs: "git" + vendorPath: "vendor/gopkg.in/yaml.v2" + transitive: false + test: [] diff --git a/sdks/java/build-tools/src/main/resources/beam/checkstyle.xml b/sdks/java/build-tools/src/main/resources/beam/checkstyle.xml index b2a74a775a84..ec51e44112d9 100644 --- a/sdks/java/build-tools/src/main/resources/beam/checkstyle.xml +++ b/sdks/java/build-tools/src/main/resources/beam/checkstyle.xml @@ -24,6 +24,7 @@ what the following rules do, please see the checkstyle configuration page at http://checkstyle.sourceforge.net/config.html --> + @@ -58,19 +59,13 @@ page at http://checkstyle.sourceforge.net/config.html --> - - - - - - + + + - - - @@ -107,6 +102,7 @@ page at http://checkstyle.sourceforge.net/config.html --> + @@ -383,6 +379,15 @@ page at http://checkstyle.sourceforge.net/config.html --> WHITESPACE CHECKS --> + + + + + + + + - - + + + + + + diff --git a/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml b/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml index d9fd2032f68d..03432670b500 100644 --- a/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml +++ b/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml @@ -336,17 +336,14 @@ - - - - - - - + diff --git a/sdks/java/build-tools/src/main/resources/beam/suppressions.xml b/sdks/java/build-tools/src/main/resources/beam/suppressions.xml index 9f60c2549715..575d3f2eaf36 100644 --- a/sdks/java/build-tools/src/main/resources/beam/suppressions.xml +++ b/sdks/java/build-tools/src/main/resources/beam/suppressions.xml @@ -17,13 +17,13 @@ "http://www.puppycrawl.com/dtds/suppressions_1_1.dtd"> - - - + + + - - 
- + + + diff --git a/sdks/java/container/build.gradle b/sdks/java/container/build.gradle index ca6f8c414a5f..93e57f478de5 100644 --- a/sdks/java/container/build.gradle +++ b/sdks/java/container/build.gradle @@ -36,6 +36,7 @@ dependencies { // TODO(herohde): use "./" prefix to prevent gogradle use base github path, for now. // TODO(herohde): get the pkg subdirectory only, if possible. We spend mins pulling cmd/beamctl deps. build name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir + test name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir } dockerDependency library.java.slf4j_api dockerDependency library.java.slf4j_jdk14 @@ -46,7 +47,7 @@ task copyDockerfileDependencies(type: Copy) { from configurations.dockerDependency rename "slf4j-api.*", "slf4j-api.jar" rename "slf4j-jdk14.*", "slf4j-jdk14.jar" - rename "beam-sdks-java-harness.*", "beam-sdks-java-harness.jar" + rename "harness.*\\.jar", "beam-sdks-java-harness.jar" into "build/target" } @@ -60,8 +61,13 @@ golang { } docker { - // TODO(herohde): make the name easier to generate for releases. - name System.properties['user.name'] + '-docker-apache.bintray.io/beam/java:latest' + String repositoryRoot + if (rootProject.hasProperty(["docker-repository-root"])) { + repositoryRoot = rootProject["docker-repository-root"] + } else { + repositoryRoot = "${System.properties["user.name"]}-docker-apache.bintray.io/beam" + } + name "${repositoryRoot}/java:latest" files "./build/" } // Ensure that we build the required resources and copy and file dependencies from related projects diff --git a/sdks/java/core/pom.xml b/sdks/java/core/pom.xml index 0c49ac7ec46a..bb90f43d980d 100644 --- a/sdks/java/core/pom.xml +++ b/sdks/java/core/pom.xml @@ -49,15 +49,6 @@ - - org.apache.maven.plugins - maven-checkstyle-plugin - - - ${project.basedir}/src/test/ - - - org.apache.maven.plugins maven-shade-plugin @@ -332,7 +323,13 @@ org.hamcrest - hamcrest-all + hamcrest-core + provided + + + + org.hamcrest + hamcrest-library provided @@ -350,10 +347,17 @@ org.mockito - mockito-all + mockito-core test - + + + org.objenesis + objenesis + 1.0 + test + + com.esotericsoftware.kryo kryo diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java similarity index 59% rename from sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java index 70fbf5879171..049bf2f93841 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java @@ -25,72 +25,70 @@ import java.util.Collections; import java.util.List; import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.values.BeamRecord; -import org.apache.beam.sdk.values.BeamRecordType; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** - * A {@link Coder} for {@link BeamRecord}. It wraps the {@link Coder} for each element directly. + * A {@link Coder} for {@link Row}. It wraps the {@link Coder} for each element directly. 
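+ * <p>For example (a sketch, assuming a two-field {@link RowType} with a string field and a long field): {@code RowCoder.of(rowType, ImmutableList.of(StringUtf8Coder.of(), VarLongCoder.of()))} supplies one field {@link Coder} per field, in order.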
*/ @Experimental -public class BeamRecordCoder extends CustomCoder { +public class RowCoder extends CustomCoder { private static final BitSetCoder nullListCoder = BitSetCoder.of(); - private BeamRecordType recordType; + private RowType rowType; private List coders; - private BeamRecordCoder(BeamRecordType recordType, List coders) { - this.recordType = recordType; + private RowCoder(RowType rowType, List coders) { + this.rowType = rowType; this.coders = coders; } - public static BeamRecordCoder of(BeamRecordType recordType, List coderArray){ - if (recordType.getFieldCount() != coderArray.size()) { + public static RowCoder of(RowType rowType, List coderArray){ + if (rowType.getFieldCount() != coderArray.size()) { throw new IllegalArgumentException("Coder size doesn't match with field size"); } - return new BeamRecordCoder(recordType, coderArray); + return new RowCoder(rowType, coderArray); } - public BeamRecordType getRecordType() { - return recordType; + public RowType getRowType() { + return rowType; } @Override - public void encode(BeamRecord value, OutputStream outStream) + public void encode(Row value, OutputStream outStream) throws CoderException, IOException { nullListCoder.encode(scanNullFields(value), outStream); for (int idx = 0; idx < value.getFieldCount(); ++idx) { - if (value.getFieldValue(idx) == null) { + if (value.getValue(idx) == null) { continue; } - coders.get(idx).encode(value.getFieldValue(idx), outStream); + coders.get(idx).encode(value.getValue(idx), outStream); } } @Override - public BeamRecord decode(InputStream inStream) throws CoderException, IOException { + public Row decode(InputStream inStream) throws CoderException, IOException { BitSet nullFields = nullListCoder.decode(inStream); - List fieldValues = new ArrayList<>(recordType.getFieldCount()); - for (int idx = 0; idx < recordType.getFieldCount(); ++idx) { + List fieldValues = new ArrayList<>(rowType.getFieldCount()); + for (int idx = 0; idx < rowType.getFieldCount(); ++idx) { if (nullFields.get(idx)) { fieldValues.add(null); } else { fieldValues.add(coders.get(idx).decode(inStream)); } } - BeamRecord record = new BeamRecord(recordType, fieldValues); - - return record; + return Row.withRowType(rowType).addValues(fieldValues).build(); } /** - * Scan {@link BeamRecord} to find fields with a NULL value. + * Scan {@link Row} to find fields with a NULL value. */ - private BitSet scanNullFields(BeamRecord record){ - BitSet nullFields = new BitSet(record.getFieldCount()); - for (int idx = 0; idx < record.getFieldCount(); ++idx) { - if (record.getFieldValue(idx) == null) { + private BitSet scanNullFields(Row row){ + BitSet nullFields = new BitSet(row.getFieldCount()); + for (int idx = 0; idx < row.getFieldCount(); ++idx) { + if (row.getValue(idx) == null) { nullFields.set(idx); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SnappyCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SnappyCoder.java new file mode 100644 index 000000000000..b3e4698f5151 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SnappyCoder.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.coders; + +import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; +import org.apache.beam.sdk.util.CoderUtils; +import org.xerial.snappy.Snappy; + +/** + * Wraps an existing coder with Snappy compression. It makes sense to use this coder only when it's + * likely that the encoded value is quite large and compressible. + */ +public class SnappyCoder extends StructuredCoder { + private final Coder innerCoder; + + /** Wraps the given coder into a {@link SnappyCoder}. */ + public static SnappyCoder of(Coder innerCoder) { + return new SnappyCoder<>(innerCoder); + } + + private SnappyCoder(Coder innerCoder) { + this.innerCoder = innerCoder; + } + + @Override + public void encode(T value, OutputStream os) throws IOException { + ByteArrayCoder.of() + .encode(Snappy.compress(CoderUtils.encodeToByteArray(innerCoder, value)), os); + } + + @Override + public T decode(InputStream is) throws IOException { + return CoderUtils.decodeFromByteArray( + innerCoder, Snappy.uncompress(ByteArrayCoder.of().decode(is))); + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(innerCoder); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + innerCoder.verifyDeterministic(); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java index d7dfc395c073..412da158acf2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java @@ -33,7 +33,6 @@ /** A {@link FileBasedSink} for Avro files. */ class AvroSink extends FileBasedSink { - private final DynamicAvroDestinations dynamicDestinations; private final boolean genericRecords; AvroSink( @@ -42,7 +41,6 @@ class AvroSink extends FileBasedSink, ResourceId>> resultsToFinalFilenames) throws IOException { int numFiles = resultsToFinalFilenames.size(); - LOG.debug("Copying {} files.", numFiles); - List srcFiles = new ArrayList<>(); - List dstFiles = new ArrayList<>(); - for (KV, ResourceId> entry : resultsToFinalFilenames) { - srcFiles.add(entry.getKey().getTempFilename()); - dstFiles.add(entry.getValue()); - LOG.info( - "Will copy temporary file {} to final location {}", - entry.getKey(), - entry.getValue()); - } - // During a failure case, files may have been deleted in an earlier step. Thus - // we ignore missing files here. - FileSystems.copy(srcFiles, dstFiles, StandardMoveOptions.IGNORE_MISSING_FILES); + LOG.debug("Copying {} files.", numFiles); + List srcFiles = new ArrayList<>(); + List dstFiles = new ArrayList<>(); + for (KV, ResourceId> entry : resultsToFinalFilenames) { + srcFiles.add(entry.getKey().getTempFilename()); + dstFiles.add(entry.getValue()); + LOG.info( + "Will copy temporary file {} to final location {}", entry.getKey(), entry.getValue()); + } + // During a failure case, files may have been deleted in an earlier step. Thus + // we ignore missing files here. 
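+      // (StandardMoveOptions.IGNORE_MISSING_FILES makes the copy below skip source files that no longer exist.)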
+ FileSystems.copy(srcFiles, dstFiles, StandardMoveOptions.IGNORE_MISSING_FILES); removeTemporaryFiles(srcFiles); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java index 156b95011d03..9295981744d5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java @@ -23,6 +23,7 @@ import static org.apache.beam.sdk.transforms.Contextful.fn; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.collect.Lists; import java.io.IOException; @@ -755,7 +756,7 @@ public interface Sink extends Serializable { public abstract static class Write extends PTransform, WriteFilesResult> { /** A policy for generating names for shard files. */ - interface FileNaming extends Serializable { + public interface FileNaming extends Serializable { /** * Generates the filename. MUST use each argument and return different values for * each combination of the arguments. @@ -1147,6 +1148,52 @@ public Write withIgnoreWindowing() { return toBuilder().setIgnoreWindowing(true).build(); } + @VisibleForTesting + Contextful> resolveFileNamingFn() { + if (getDynamic()) { + checkArgument( + getConstantFileNaming() == null, + "when using writeDynamic(), must use versions of .withNaming() " + + "that take functions from DestinationT"); + checkArgument(getFilenamePrefix() == null, ".withPrefix() requires write()"); + checkArgument(getFilenameSuffix() == null, ".withSuffix() requires write()"); + checkArgument( + getFileNamingFn() != null, + "when using writeDynamic(), must specify " + + ".withNaming() taking a function form DestinationT"); + return fn( + (element, c) -> { + FileNaming naming = getFileNamingFn().getClosure().apply(element, c); + return getOutputDirectory() == null + ? 
naming + : relativeFileNaming(getOutputDirectory(), naming); + }, + getFileNamingFn().getRequirements()); + } else { + checkArgument(getFileNamingFn() == null, + ".withNaming() taking a function from DestinationT requires writeDynamic()"); + FileNaming constantFileNaming; + if (getConstantFileNaming() == null) { + constantFileNaming = defaultNaming( + MoreObjects.firstNonNull( + getFilenamePrefix(), StaticValueProvider.of("output")), + MoreObjects.firstNonNull(getFilenameSuffix(), StaticValueProvider.of(""))); + } else { + checkArgument( + getFilenamePrefix() == null, + ".to(FileNaming) is incompatible with .withSuffix()"); + checkArgument( + getFilenameSuffix() == null, + ".to(FileNaming) is incompatible with .withPrefix()"); + constantFileNaming = getConstantFileNaming(); + } + if (getOutputDirectory() != null) { + constantFileNaming = relativeFileNaming(getOutputDirectory(), constantFileNaming); + } + return fn(SerializableFunctions.constant(constantFileNaming)); + } + } + @Override public WriteFilesResult expand(PCollection input) { Write.Builder resolvedSpec = new AutoValue_FileIO_Write.Builder<>(); @@ -1172,52 +1219,7 @@ public WriteFilesResult expand(PCollection input) { resolvedSpec.setDestinationCoder((Coder) VoidCoder.of()); } - // Resolve fileNamingFn - Contextful> fileNamingFn; - if (getDynamic()) { - checkArgument( - getConstantFileNaming() == null, - "when using writeDynamic(), must use versions of .withNaming() " - + "that take functions from DestinationT"); - checkArgument(getFilenamePrefix() == null, ".withPrefix() requires write()"); - checkArgument(getFilenameSuffix() == null, ".withSuffix() requires write()"); - checkArgument( - getFileNamingFn() != null, - "when using writeDynamic(), must specify " - + ".withNaming() taking a function form DestinationT"); - fileNamingFn = - Contextful.fn( - (element, c) -> { - FileNaming naming = getFileNamingFn().getClosure().apply(element, c); - return getOutputDirectory() == null - ? 
naming - : relativeFileNaming(getOutputDirectory(), naming); - }, - getFileNamingFn().getRequirements()); - } else { - checkArgument(getFileNamingFn() == null, - ".withNaming() taking a function from DestinationT requires writeDynamic()"); - FileNaming constantFileNaming; - if (getConstantFileNaming() == null) { - constantFileNaming = defaultNaming( - MoreObjects.firstNonNull( - getFilenamePrefix(), StaticValueProvider.of("output")), - MoreObjects.firstNonNull(getFilenameSuffix(), StaticValueProvider.of(""))); - if (getOutputDirectory() != null) { - constantFileNaming = relativeFileNaming(getOutputDirectory(), constantFileNaming); - } - } else { - checkArgument( - getFilenamePrefix() == null, ".to(FileNaming) is incompatible with .withSuffix()"); - checkArgument( - getFilenameSuffix() == null, ".to(FileNaming) is incompatible with .withPrefix()"); - constantFileNaming = getConstantFileNaming(); - } - fileNamingFn = - fn(SerializableFunctions.constant(constantFileNaming)); - } - - resolvedSpec.setFileNamingFn(fileNamingFn); + resolvedSpec.setFileNamingFn(resolveFileNamingFn()); resolvedSpec.setEmptyWindowDestination(getEmptyWindowDestination()); if (getTempDirectory() == null) { checkArgument( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionResult.java index b01ae4646b35..6da721068fb7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionResult.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionResult.java @@ -37,7 +37,9 @@ public double mean() { return (1.0 * sum()) / count(); } - public static final DistributionResult ZERO = create(0, 0, Long.MAX_VALUE, Long.MIN_VALUE); + /** The IDENTITY_ELEMENT is used to start accumulating distributions. 
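It is the identity for combining distributions: sum and count are zero, and min/max start at Long.MAX_VALUE/Long.MIN_VALUE, so merging it with any distribution leaves that distribution unchanged.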
*/ + public static final DistributionResult IDENTITY_ELEMENT = + create(0, 0, Long.MAX_VALUE, Long.MIN_VALUE); public static DistributionResult create(long sum, long count, long min, long max) { return new AutoValue_DistributionResult(sum, count, min, max); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java index 1dc9d4403f75..4eb461efe8a7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.util.ReleaseInfo; +import org.apache.beam.sdk.util.common.ReflectHelpers; import org.joda.time.DateTimeUtils; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; @@ -285,7 +286,9 @@ public Class> create(PipelineOptions options) { @SuppressWarnings({"unchecked", "rawtypes"}) Class> direct = (Class>) - Class.forName("org.apache.beam.runners.direct.DirectRunner"); + Class.forName( + "org.apache.beam.runners.direct.DirectRunner", true, + ReflectHelpers.findClassLoader()); return direct; } catch (ClassNotFoundException e) { throw new IllegalArgumentException(String.format( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java index 24a0d9d1cd60..50f09e85c98b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static java.util.Locale.ROOT; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.databind.JavaType; @@ -73,10 +74,12 @@ import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.StreamSupport; import javax.annotation.Nonnull; import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.options.Validation.Required; import org.apache.beam.sdk.runners.PipelineRunnerRegistrar; import org.apache.beam.sdk.transforms.display.DisplayData; @@ -122,7 +125,7 @@ public static PipelineOptions create() { * {@link Class#getSimpleName() classes simple name}. * *
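* <p>A typical call (sketch; {@code MyOptions} is a hypothetical sub-interface of {@link PipelineOptions}): {@code MyOptions options = PipelineOptionsFactory.as(MyOptions.class);}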

Note that {@code } must be composable with every registered interface with this factory. - * See {@link PipelineOptionsFactory#validateWellFormed(Class, Set)} for more details. + * See {@link PipelineOptionsFactory.Cache#validateWellFormed(Class, Set)} for more details. * * @return An object that implements {@code }. */ @@ -278,7 +281,7 @@ public PipelineOptions create() { * this builder during construction. * *

Note that {@code } must be composable with every registered interface with this - * factory. See {@link PipelineOptionsFactory#validateWellFormed(Class, Set)} for more + * factory. See {@link PipelineOptionsFactory.Cache#validateWellFormed(Class)} for more * details. * * @return An object that implements {@code }. @@ -347,7 +350,7 @@ static boolean printHelpUsageAndExitIfNeeded(ListMultimap option // Otherwise attempt to print the specific help option. try { - Class klass = Class.forName(helpOption); + Class klass = Class.forName(helpOption, true, ReflectHelpers.findClassLoader()); if (!PipelineOptions.class.isAssignableFrom(klass)) { throw new ClassNotFoundException("PipelineOptions of type " + klass + " not found."); } @@ -408,7 +411,8 @@ private static String findCallersClassName() { StackTraceElement next = elements.next(); if (!PIPELINE_OPTIONS_FACTORY_CLASSES.contains(next.getClassName())) { try { - return Class.forName(next.getClassName()).getSimpleName(); + return Class.forName( + next.getClassName(), true, ReflectHelpers.findClassLoader()).getSimpleName(); } catch (ClassNotFoundException e) { break; } @@ -463,9 +467,6 @@ Class getProxyClass() { private static final Class[] EMPTY_CLASS_ARRAY = new Class[0]; static final ObjectMapper MAPPER = new ObjectMapper().registerModules( ObjectMapper.findModules(ReflectHelpers.findClassLoader())); - private static final ClassLoader CLASS_LOADER; - - private static final Map>> SUPPORTED_PIPELINE_RUNNERS; /** Classes that are used as the boundary in the stack trace to find the callers class name. */ private static final Set PIPELINE_OPTIONS_FACTORY_CLASSES = @@ -477,17 +478,8 @@ Class getProxyClass() { /** A predicate that checks if a method is synthetic via {@link Method#isSynthetic()}. */ private static final Predicate NOT_SYNTHETIC_PREDICATE = input -> !input.isSynthetic(); - /** The set of options that have been registered and visible to the user. */ - private static final Set> REGISTERED_OPTIONS = - Sets.newConcurrentHashSet(); - - /** A cache storing a mapping from a given interface to its registration record. */ - private static final Map, Registration> INTERFACE_CACHE = - Maps.newConcurrentMap(); - - /** A cache storing a mapping from a set of interfaces to its registration record. */ - private static final Map>, Registration> COMBINED_CACHE = - Maps.newConcurrentMap(); + /** Ensure all classloader or volatile data are contained in a single reference. */ + static final AtomicReference CACHE = new AtomicReference<>(); /** The width at which options should be output. */ private static final int TERMINAL_WIDTH = 80; @@ -507,27 +499,7 @@ Class getProxyClass() { LOG.error("Unable to find expected method", e); throw new ExceptionInInitializerError(e); } - - CLASS_LOADER = ReflectHelpers.findClassLoader(); - - Set pipelineRunnerRegistrars = - Sets.newTreeSet(ReflectHelpers.ObjectsClassComparator.INSTANCE); - pipelineRunnerRegistrars.addAll( - Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class, CLASS_LOADER))); - // Store the list of all available pipeline runners. 
- ImmutableMap.Builder>> builder = - ImmutableMap.builder(); - for (PipelineRunnerRegistrar registrar : pipelineRunnerRegistrars) { - for (Class> klass : registrar.getPipelineRunners()) { - String runnerName = klass.getSimpleName().toLowerCase(); - builder.put(runnerName, klass); - if (runnerName.endsWith("runner")) { - builder.put(runnerName.substring(0, runnerName.length() - "Runner".length()), klass); - } - } - } - SUPPORTED_PIPELINE_RUNNERS = builder.build(); - initializeRegistry(); + resetCache(); } /** @@ -546,113 +518,24 @@ Class getProxyClass() { * @param iface The interface object to manually register. */ public static synchronized void register(Class iface) { - checkNotNull(iface); - checkArgument(iface.isInterface(), "Only interface types are supported."); - - if (REGISTERED_OPTIONS.contains(iface)) { - return; - } - validateWellFormed(iface, REGISTERED_OPTIONS); - REGISTERED_OPTIONS.add(iface); + CACHE.get().register(iface); } /** * Resets the set of interfaces registered with this factory to the default state. + *

IMPORTANT: this is marked as experimental because the correct usage of this + * method requires appropriate synchronization beyond the scope of this method.

* * @see PipelineOptionsFactory#register(Class) + * @see Cache#Cache() */ - @VisibleForTesting - static synchronized void resetRegistry() { - REGISTERED_OPTIONS.clear(); - initializeRegistry(); - } - - /** - * Load and register the list of all classes that extend PipelineOptions. - */ - private static void initializeRegistry() { - register(PipelineOptions.class); - Set pipelineOptionsRegistrars = - Sets.newTreeSet(ReflectHelpers.ObjectsClassComparator.INSTANCE); - pipelineOptionsRegistrars.addAll( - Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class, CLASS_LOADER))); - for (PipelineOptionsRegistrar registrar : pipelineOptionsRegistrars) { - for (Class klass : registrar.getPipelineOptions()) { - register(klass); - } - } - } - - /** - * Validates that the interface conforms to the following: - *
    - *
  • Every inherited interface of {@code iface} must extend PipelineOptions except for - * PipelineOptions itself. - *
  • Any property with the same name must have the same return type for all derived - * interfaces of {@link PipelineOptions}. - *
  • Every bean property of any interface derived from {@link PipelineOptions} must have a - * getter and setter method. - *
  • Every method must conform to being a getter or setter for a JavaBean. - *
  • The derived interface of {@link PipelineOptions} must be composable with every interface - * part of allPipelineOptionsClasses. - *
  • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. - *
  • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for - * this property must be annotated with {@link JsonIgnore @JsonIgnore}. - *
- * - * @param iface The interface to validate. - * @param validatedPipelineOptionsInterfaces The set of validated pipeline options interfaces to - * validate against. - * @return A registration record containing the proxy class and bean info for iface. - */ - static synchronized Registration validateWellFormed( - Class iface, Set> validatedPipelineOptionsInterfaces) { - checkArgument(iface.isInterface(), "Only interface types are supported."); - - // Validate that every inherited interface must extend PipelineOptions except for - // PipelineOptions itself. - validateInheritedInterfacesExtendPipelineOptions(iface); - - @SuppressWarnings("unchecked") - Set> combinedPipelineOptionsInterfaces = - FluentIterable.from(validatedPipelineOptionsInterfaces).append(iface).toSet(); - // Validate that the view of all currently passed in options classes is well formed. - if (!COMBINED_CACHE.containsKey(combinedPipelineOptionsInterfaces)) { - @SuppressWarnings("unchecked") - Class allProxyClass = - (Class) Proxy.getProxyClass(PipelineOptionsFactory.class.getClassLoader(), - combinedPipelineOptionsInterfaces.toArray(EMPTY_CLASS_ARRAY)); - try { - List propertyDescriptors = - validateClass(iface, validatedPipelineOptionsInterfaces, allProxyClass); - COMBINED_CACHE.put(combinedPipelineOptionsInterfaces, - new Registration<>(allProxyClass, propertyDescriptors)); - } catch (IntrospectionException e) { - throw new RuntimeException(e); - } - } - - // Validate that the local view of the class is well formed. - if (!INTERFACE_CACHE.containsKey(iface)) { - @SuppressWarnings({"rawtypes", "unchecked"}) - Class proxyClass = (Class) Proxy.getProxyClass( - PipelineOptionsFactory.class.getClassLoader(), new Class[] {iface}); - try { - List propertyDescriptors = - validateClass(iface, validatedPipelineOptionsInterfaces, proxyClass); - INTERFACE_CACHE.put(iface, - new Registration<>(proxyClass, propertyDescriptors)); - } catch (IntrospectionException e) { - throw new RuntimeException(e); - } - } - @SuppressWarnings("unchecked") - Registration result = (Registration) INTERFACE_CACHE.get(iface); - return result; + @Experimental(Experimental.Kind.UNSPECIFIED) + public static synchronized void resetCache() { + CACHE.set(new Cache()); } public static Set> getRegisteredOptions() { - return Collections.unmodifiableSet(REGISTERED_OPTIONS); + return Collections.unmodifiableSet(CACHE.get().registeredOptions); } /** @@ -666,7 +549,7 @@ public static void printHelp(PrintStream out) { out.println("The set of registered options are:"); Set> sortedOptions = new TreeSet<>(ClassNameComparator.INSTANCE); - sortedOptions.addAll(REGISTERED_OPTIONS); + sortedOptions.addAll(CACHE.get().registeredOptions); for (Class kls : sortedOptions) { out.format(" %s%n", kls.getName()); } @@ -696,7 +579,7 @@ public static void printHelp(PrintStream out) { public static void printHelp(PrintStream out, Class iface) { checkNotNull(out); checkNotNull(iface); - validateWellFormed(iface, REGISTERED_OPTIONS); + CACHE.get().validateWellFormed(iface); Set properties = PipelineOptionsReflector.getOptionSpecs(iface); @@ -835,12 +718,7 @@ private static Optional getDefaultValueFromAnnotation(Method method) { } static Map>> getRegisteredRunners() { - return SUPPORTED_PIPELINE_RUNNERS; - } - - static List getPropertyDescriptors( - Set> interfaces) { - return COMBINED_CACHE.get(interfaces).getPropertyDescriptors(); + return CACHE.get().supportedPipelineRunners; } /** @@ -1579,10 +1457,11 @@ private static ListMultimap parseCommandLine( private static Map parseObjects( 
Class klass, ListMultimap options, boolean strictParsing) { Map propertyNamesToGetters = Maps.newHashMap(); - PipelineOptionsFactory.validateWellFormed(klass, REGISTERED_OPTIONS); + Cache cache = CACHE.get(); + cache.validateWellFormed(klass); @SuppressWarnings("unchecked") Iterable propertyDescriptors = - PipelineOptionsFactory.getPropertyDescriptors( + cache.getPropertyDescriptors( FluentIterable.from(getRegisteredOptions()).append(klass).toSet()); for (PropertyDescriptor descriptor : propertyDescriptors) { propertyNamesToGetters.put(descriptor.getName(), descriptor.getReadMethod()); @@ -1619,24 +1498,26 @@ private static Map parseObjects( JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); if ("runner".equals(entry.getKey())) { String runner = Iterables.getOnlyElement(entry.getValue()); - if (SUPPORTED_PIPELINE_RUNNERS.containsKey(runner.toLowerCase())) { - convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner.toLowerCase())); + final Map>> pipelineRunners = cache + .supportedPipelineRunners; + if (pipelineRunners.containsKey(runner.toLowerCase())) { + convertedOptions.put("runner", pipelineRunners.get(runner.toLowerCase(ROOT))); } else { try { - Class runnerClass = Class.forName(runner); + Class runnerClass = Class.forName(runner, true, ReflectHelpers.findClassLoader()); if (!(PipelineRunner.class.isAssignableFrom(runnerClass))) { throw new IllegalArgumentException( String.format( "Class '%s' does not implement PipelineRunner. " + "Supported pipeline runners %s", - runner, getSupportedRunners())); + runner, cache.getSupportedRunners())); } convertedOptions.put("runner", runnerClass); } catch (ClassNotFoundException e) { String msg = String.format( "Unknown 'runner' specified '%s', supported pipeline runners %s", - runner, getSupportedRunners()); + runner, cache.getSupportedRunners()); throw new IllegalArgumentException(msg, e); } } @@ -1753,12 +1634,162 @@ private static void checkEmptyStringAllowed(Class type, JavaType genericType, } } - @VisibleForTesting - static Set getSupportedRunners() { - ImmutableSortedSet.Builder supportedRunners = ImmutableSortedSet.naturalOrder(); - for (Class> runner : SUPPORTED_PIPELINE_RUNNERS.values()) { - supportedRunners.add(runner.getSimpleName()); + /** Hold all data which can change after a classloader change. */ + static final class Cache { + private final Map>> supportedPipelineRunners; + + /** The set of options that have been registered and visible to the user. */ + private final Set> registeredOptions = + Sets.newConcurrentHashSet(); + + /** A cache storing a mapping from a given interface to its registration record. */ + private final Map, Registration> interfaceCache = + Maps.newConcurrentMap(); + + /** A cache storing a mapping from a set of interfaces to its registration record. */ + private final Map>, Registration> combinedCache = + Maps.newConcurrentMap(); + + private Cache() { + final ClassLoader loader = ReflectHelpers.findClassLoader(); + + Set pipelineRunnerRegistrars = + Sets.newTreeSet(ReflectHelpers.ObjectsClassComparator.INSTANCE); + pipelineRunnerRegistrars.addAll( + Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class, loader))); + // Store the list of all available pipeline runners. 
+ ImmutableMap.Builder>> builder = + ImmutableMap.builder(); + for (PipelineRunnerRegistrar registrar : pipelineRunnerRegistrars) { + for (Class> klass : registrar.getPipelineRunners()) { + String runnerName = klass.getSimpleName().toLowerCase(); + builder.put(runnerName, klass); + if (runnerName.endsWith("runner")) { + builder.put(runnerName.substring(0, runnerName.length() - "Runner".length()), klass); + } + } + } + supportedPipelineRunners = builder.build(); + initializeRegistry(loader); + } + + /** + * Load and register the list of all classes that extend PipelineOptions. + */ + private void initializeRegistry(final ClassLoader loader) { + register(PipelineOptions.class); + Set pipelineOptionsRegistrars = + Sets.newTreeSet(ReflectHelpers.ObjectsClassComparator.INSTANCE); + pipelineOptionsRegistrars.addAll( + Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class, loader))); + for (PipelineOptionsRegistrar registrar : pipelineOptionsRegistrars) { + for (Class klass : registrar.getPipelineOptions()) { + register(klass); + } + } + } + + private synchronized void register(Class iface) { + checkNotNull(iface); + checkArgument(iface.isInterface(), "Only interface types are supported."); + + if (registeredOptions.contains(iface)) { + return; + } + validateWellFormed(iface); + registeredOptions.add(iface); + } + + private Registration validateWellFormed(Class iface) { + return validateWellFormed(iface, registeredOptions); + } + + @VisibleForTesting + Set getSupportedRunners() { + ImmutableSortedSet.Builder supportedRunners = ImmutableSortedSet.naturalOrder(); + for (Class> runner : supportedPipelineRunners.values()) { + supportedRunners.add(runner.getSimpleName()); + } + return supportedRunners.build(); + } + + @VisibleForTesting + Map>> getSupportedPipelineRunners() { + return supportedPipelineRunners; + } + + /** + * Validates that the interface conforms to the following: + *
    + *
  • Every inherited interface of {@code iface} must extend PipelineOptions except for + * PipelineOptions itself. + *
  • Any property with the same name must have the same return type for all derived + * interfaces of {@link PipelineOptions}. + *
  • Every bean property of any interface derived from {@link PipelineOptions} must have a + * getter and setter method. + *
  • Every method must conform to being a getter or setter for a JavaBean. + *
  • The derived interface of {@link PipelineOptions} must be composable + * with every interface part of allPipelineOptionsClasses. + *
  • Only getters may be annotated with {@link JsonIgnore @JsonIgnore}. + *
  • If any getter is annotated with {@link JsonIgnore @JsonIgnore}, then all getters for + * this property must be annotated with {@link JsonIgnore @JsonIgnore}. + *
+ * + * @param iface The interface to validate. + * @param validatedPipelineOptionsInterfaces The set of validated pipeline options interfaces to + * validate against. + * @return A registration record containing the proxy class and bean info for iface. + */ + synchronized Registration validateWellFormed( + Class iface, + Set> validatedPipelineOptionsInterfaces) { + checkArgument(iface.isInterface(), "Only interface types are supported."); + + // Validate that every inherited interface must extend PipelineOptions except for + // PipelineOptions itself. + validateInheritedInterfacesExtendPipelineOptions(iface); + + @SuppressWarnings("unchecked") + Set> combinedPipelineOptionsInterfaces = + FluentIterable.from(validatedPipelineOptionsInterfaces).append(iface).toSet(); + // Validate that the view of all currently passed in options classes is well formed. + if (!combinedCache.containsKey(combinedPipelineOptionsInterfaces)) { + @SuppressWarnings("unchecked") + Class allProxyClass = + (Class) Proxy.getProxyClass(ReflectHelpers.findClassLoader(), + combinedPipelineOptionsInterfaces.toArray(EMPTY_CLASS_ARRAY)); + try { + List propertyDescriptors = + validateClass(iface, validatedPipelineOptionsInterfaces, allProxyClass); + combinedCache.put(combinedPipelineOptionsInterfaces, + new Registration<>(allProxyClass, propertyDescriptors)); + } catch (IntrospectionException e) { + throw new RuntimeException(e); + } + } + + // Validate that the local view of the class is well formed. + if (!interfaceCache.containsKey(iface)) { + @SuppressWarnings({"rawtypes", "unchecked"}) + Class proxyClass = (Class) Proxy.getProxyClass( + ReflectHelpers.findClassLoader(), new Class[] {iface}); + try { + List propertyDescriptors = + validateClass(iface, validatedPipelineOptionsInterfaces, proxyClass); + interfaceCache.put(iface, + new Registration<>(proxyClass, propertyDescriptors)); + } catch (IntrospectionException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("unchecked") + Registration result = (Registration) interfaceCache.get(iface); + return result; + } + + List getPropertyDescriptors( + Set> interfaces) { + return combinedCache.get(interfaces).getPropertyDescriptors(); } - return supportedRunners.build(); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java index 454cb23538d8..a127fd15e256 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java @@ -220,7 +220,7 @@ synchronized T as(Class iface) { checkArgument(iface.isInterface(), "Not an interface: %s", iface); if (!interfaceToProxyCache.containsKey(iface)) { Registration registration = - PipelineOptionsFactory.validateWellFormed(iface, knownInterfaces); + PipelineOptionsFactory.CACHE.get().validateWellFormed(iface, knownInterfaces); List propertyDescriptors = registration.getPropertyDescriptors(); Class proxyClass = registration.getProxyClass(); gettersToPropertyNames.putAll(generateGettersToPropertyNames(propertyDescriptors)); @@ -641,8 +641,9 @@ public void serialize(PipelineOptions value, JsonGenerator jgen, SerializerProvi // the last serialization of this PipelineOptions and then verify that // they are all serializable. 
Map filteredOptions = Maps.newHashMap(handler.options); - removeIgnoredOptions(handler.knownInterfaces, filteredOptions); - ensureSerializable(handler.knownInterfaces, filteredOptions); + PipelineOptionsFactory.Cache cache = PipelineOptionsFactory.CACHE.get(); + removeIgnoredOptions(cache, handler.knownInterfaces, filteredOptions); + ensureSerializable(cache, handler.knownInterfaces, filteredOptions); // Now we create the map of serializable options by taking the original // set of serialized options (if any) and updating them with any properties @@ -676,6 +677,7 @@ public void serialize(PipelineOptions value, JsonGenerator jgen, SerializerProvi * {@link JsonIgnore @JsonIgnore} from the passed in options using the passed in interfaces. */ private void removeIgnoredOptions( + PipelineOptionsFactory.Cache cache, Set> interfaces, Map options) { // Find all the method names that are annotated with JSON ignore. Set jsonIgnoreMethodNames = @@ -685,8 +687,7 @@ private void removeIgnoredOptions( .toSet(); // Remove all options that have the same method name as the descriptor. - for (PropertyDescriptor descriptor - : PipelineOptionsFactory.getPropertyDescriptors(interfaces)) { + for (PropertyDescriptor descriptor : cache.getPropertyDescriptors(interfaces)) { if (jsonIgnoreMethodNames.contains(descriptor.getReadMethod().getName())) { options.remove(descriptor.getName()); } @@ -697,12 +698,13 @@ private void removeIgnoredOptions( * We use an {@link ObjectMapper} to verify that the passed in options are serializable * and deserializable. */ - private void ensureSerializable(Set> interfaces, + private void ensureSerializable( + PipelineOptionsFactory.Cache cache, + Set> interfaces, Map options) throws IOException { // Construct a map from property name to the return type of the getter. Map propertyToReturnType = Maps.newHashMap(); - for (PropertyDescriptor descriptor - : PipelineOptionsFactory.getPropertyDescriptors(interfaces)) { + for (PropertyDescriptor descriptor : cache.getPropertyDescriptors(interfaces)) { if (descriptor.getReadMethod() != null) { propertyToReturnType.put(descriptor.getName(), descriptor.getReadMethod().getGenericReturnType()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 15e6df8ba226..246c45041465 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -495,27 +495,25 @@ private void setOutput(POutput output) { for (PValue outputValue : output.expand().values()) { outputProducers.add(getProducer(outputValue)); } - if (outputProducers.contains(this)) { - if (!parts.isEmpty() || outputProducers.size() > 1) { - Set otherProducerNames = new HashSet<>(); - for (Node outputProducer : outputProducers) { - if (outputProducer != this) { - otherProducerNames.add(outputProducer.getFullName()); - } + if (outputProducers.contains(this) && (!parts.isEmpty() || outputProducers.size() > 1)) { + Set otherProducerNames = new HashSet<>(); + for (Node outputProducer : outputProducers) { + if (outputProducer != this) { + otherProducerNames.add(outputProducer.getFullName()); } - throw new IllegalArgumentException( - String.format( - "Output of composite transform [%s] contains a primitive %s produced by it. " - + "Only primitive transforms are permitted to produce primitive outputs." 
- + "%n Outputs: %s" - + "%n Other Producers: %s" - + "%n Components: %s", - getFullName(), - POutput.class.getSimpleName(), - output.expand(), - otherProducerNames, - parts)); } + throw new IllegalArgumentException( + String.format( + "Output of composite transform [%s] contains a primitive %s produced by it. " + + "Only primitive transforms are permitted to produce primitive outputs." + + "%n Outputs: %s" + + "%n Other Producers: %s" + + "%n Components: %s", + getFullName(), + POutput.class.getSimpleName(), + output.expand(), + otherProducerNames, + parts)); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 88c64543484d..0cda9c528134 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -286,10 +286,8 @@ public ResultT match(Cases cases) { @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { - if (this.coder == null) { - if (coders[0] != null) { - this.coder = (Coder) coders[0]; - } + if (this.coder == null && coders[0] != null) { + this.coder = (Coder) coders[0]; } } @@ -355,10 +353,8 @@ public ResultT match(Cases cases) { @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { - if (this.accumCoder == null) { - if (coders[1] != null) { - this.accumCoder = (Coder) coders[1]; - } + if (this.accumCoder == null && coders[1] != null) { + this.accumCoder = (Coder) coders[1]; } } @@ -434,10 +430,8 @@ public ResultT match(Cases cases) { @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { - if (this.accumCoder == null) { - if (coders[2] != null) { - this.accumCoder = (Coder) coders[2]; - } + if (this.accumCoder == null && coders[2] != null) { + this.accumCoder = (Coder) coders[2]; } } @@ -506,10 +500,8 @@ public ResultT match(Cases cases) { @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { - if (this.elemCoder == null) { - if (coders[0] != null) { - this.elemCoder = (Coder) coders[0]; - } + if (this.elemCoder == null && coders[0] != null) { + this.elemCoder = (Coder) coders[0]; } } @@ -567,15 +559,11 @@ public ResultT match(Cases cases) { @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { - if (this.keyCoder == null) { - if (coders[0] != null) { - this.keyCoder = (Coder) coders[0]; - } + if (this.keyCoder == null && coders[0] != null) { + this.keyCoder = (Coder) coders[0]; } - if (this.valueCoder == null) { - if (coders[1] != null) { - this.valueCoder = (Coder) coders[1]; - } + if (this.valueCoder == null && coders[1] != null) { + this.valueCoder = (Coder) coders[1]; } } @@ -636,10 +624,8 @@ public ResultT match(Cases cases) { @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { - if (this.elemCoder == null) { - if (coders[0] != null) { - this.elemCoder = (Coder) coders[0]; - } + if (this.elemCoder == null && coders[0] != null) { + this.elemCoder = (Coder) coders[0]; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java index 55e958099f78..dea492f75e19 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java @@ -142,7 +142,7 @@ public final Builder addElements( */ public 
Builder advanceWatermarkTo(Instant newWatermark) { checkArgument( - newWatermark.isAfter(currentWatermark), "The watermark must monotonically advance"); + !newWatermark.isBefore(currentWatermark), "The watermark must monotonically advance"); checkArgument( newWatermark.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE), "The Watermark cannot progress beyond the maximum. Got: %s. Maximum: %s", diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java index 0e370f11bc35..eeb32a732b54 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java @@ -420,7 +420,8 @@ private QuantileState( this.numQuantiles = numQuantiles; this.numBuffers = numBuffers; this.bufferSize = bufferSize; - this.buffers = new PriorityQueue<>(numBuffers + 1); + this.buffers = new PriorityQueue<>(numBuffers + 1, + (q1, q2) -> Integer.compare(q1.level, q2.level)); this.min = min; this.max = max; this.unbufferedElements.addAll(unbufferedElements); @@ -620,7 +621,7 @@ public List extractOutput() { /** * A single buffer in the sense of the referenced algorithm. */ - private static class QuantileBuffer implements Comparable> { + private static class QuantileBuffer { private int level; private long weight; private List elements; @@ -635,11 +636,6 @@ public QuantileBuffer(int level, long weight, List elements) { this.elements = elements; } - @Override - public int compareTo(QuantileBuffer other) { - return this.level - other.level; - } - @Override public String toString() { return "QuantileBuffer[" diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index 724c98f5cce2..6278d312d46e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -511,7 +511,9 @@ public Coder getDefaultOutputCoder(CoderRegistry registry, Coder inputCode public static class Holder { @Nullable private V value; private boolean present; + private Holder() { } + private Holder(V value) { set(value); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Distinct.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Distinct.java index 0f8207f6aee0..5ab59152ed68 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Distinct.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Distinct.java @@ -27,8 +27,6 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.WindowingStrategy; import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * {@code Distinct} takes a {@code PCollection} and returns a {@code PCollection} that has @@ -64,7 +62,6 @@ * @param the type of the elements of the input and output {@code PCollection}s */ public class Distinct extends PTransform, PCollection> { - private static final Logger LOG = LoggerFactory.getLogger(Distinct.class); /** * Returns a {@code Distinct} {@code PTransform}. 
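For illustration only (this sketch is not part of the patch): the relaxed precondition above lets a TestStream builder advance the watermark to the same instant more than once, and the DoFn changes that follow recommend TestPipeline over DoFnTester. A test combining the two might look roughly like the following; UpperCaseFn and the element values are invented for the example.

    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.testing.PAssert;
    import org.apache.beam.sdk.testing.TestPipeline;
    import org.apache.beam.sdk.testing.TestStream;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.ParDo;
    import org.apache.beam.sdk.values.PCollection;
    import org.joda.time.Instant;
    import org.junit.Rule;
    import org.junit.Test;

    public class UpperCaseFnTest {
      /** Hypothetical DoFn used only for this sketch. */
      static class UpperCaseFn extends DoFn<String, String> {
        @ProcessElement
        public void processElement(ProcessContext c) {
          c.output(c.element().toUpperCase());
        }
      }

      @Rule public final transient TestPipeline p = TestPipeline.create();

      @Test
      public void emitsUpperCasedElements() {
        TestStream<String> input =
            TestStream.create(StringUtf8Coder.of())
                .addElements("foo")
                .advanceWatermarkTo(new Instant(100L))
                // Repeating the same instant is accepted now that the check only
                // rejects watermarks that move backwards.
                .advanceWatermarkTo(new Instant(100L))
                .addElements("bar")
                .advanceWatermarkToInfinity();

        PCollection<String> output = p.apply(input).apply(ParDo.of(new UpperCaseFn()));
        PAssert.that(output).containsInAnyOrder("FOO", "BAR");
        p.run();
      }
    }

A test in this style runs the DoFn through a real runner (here the DirectRunner), which is the behavior the LOG.warn added to DoFnTester.of later in this change points users toward.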
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 9f8dd45d1105..a3cdc88d46dc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -33,6 +33,8 @@ import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; @@ -48,24 +50,22 @@ import org.joda.time.Instant; /** - * The argument to {@link ParDo} providing the code to use to process - * elements of the input - * {@link org.apache.beam.sdk.values.PCollection}. + * The argument to {@link ParDo} providing the code to use to process elements of the input {@link + * org.apache.beam.sdk.values.PCollection}. * - *

See {@link ParDo} for more explanation, examples of use, and - * discussion of constraints on {@code DoFn}s, including their - * serializability, lack of access to global shared mutable state, + *

See {@link ParDo} for more explanation, examples of use, and discussion of constraints on + * {@code DoFn}s, including their serializability, lack of access to global shared mutable state, * requirements for failure tolerance, and benefits of optimization. * - *

{@code DoFn}s can be tested in a particular - * {@code Pipeline} by running that {@code Pipeline} on sample input - * and then checking its output. Unit testing of a {@code DoFn}, - * separately from any {@code ParDo} transform or {@code Pipeline}, - * can be done via the {@link DoFnTester} harness. + *

{@link DoFn DoFns} can be tested by using {@link TestPipeline}. You can verify their + * functional correctness in a local test using the {@code DirectRunner} as well as running + * integration tests with your production runner of choice. Typically, you can generate the input + * data using {@link Create#of} or other transforms. However, if you need to test the behavior of + * {@link StartBundle} and {@link FinishBundle} with particular bundle boundaries, you can use + * {@link TestStream}. * - *

Implementations must define a method annotated with {@link ProcessElement} - * that satisfies the requirements described there. See the {@link ProcessElement} - * for details. + *

Implementations must define a method annotated with {@link ProcessElement} that satisfies the + * requirements described there. See the {@link ProcessElement} for details. * *

Example usage: * @@ -89,8 +89,7 @@ public abstract class DoFn implements Serializable, HasDisplayD public abstract class StartBundleContext { /** * Returns the {@code PipelineOptions} specified with the {@link - * org.apache.beam.sdk.PipelineRunner} invoking this {@code DoFn}. The {@code - * PipelineOptions} will be the default running via {@link DoFnTester}. + * org.apache.beam.sdk.PipelineRunner} invoking this {@code DoFn}. */ public abstract PipelineOptions getPipelineOptions(); } @@ -101,8 +100,7 @@ public abstract class StartBundleContext { public abstract class FinishBundleContext { /** * Returns the {@code PipelineOptions} specified with the {@link - * org.apache.beam.sdk.PipelineRunner} invoking this {@code DoFn}. The {@code - * PipelineOptions} will be the default running via {@link DoFnTester}. + * org.apache.beam.sdk.PipelineRunner} invoking this {@code DoFn}. */ public abstract PipelineOptions getPipelineOptions(); @@ -137,8 +135,7 @@ public abstract void output( public abstract class WindowedContext { /** * Returns the {@code PipelineOptions} specified with the {@link - * org.apache.beam.sdk.PipelineRunner} invoking this {@code DoFn}. The {@code - * PipelineOptions} will be the default running via {@link DoFnTester}. + * org.apache.beam.sdk.PipelineRunner} invoking this {@code DoFn}. */ public abstract PipelineOptions getPipelineOptions(); @@ -480,8 +477,19 @@ public interface OutputReceiver { } /** - * Annotation for the method to use to prepare an instance for processing bundles of elements. The - * method annotated with this must satisfy the following constraints + * Annotation for the method to use to prepare an instance for processing bundles of elements. + * + *

This is a good place to initialize transient in-memory resources, such as network + * connections. The resources can then be disposed in {@link Teardown}. + * + *

This is not a good place to perform external side-effects that later need cleanup, + * e.g. creating temporary files on distributed filesystems, starting VMs, or initiating data + * export jobs. Such logic must be instead implemented purely via {@link StartBundle}, + * {@link ProcessElement} and {@link FinishBundle} methods, references to the objects + * requiring cleanup must be passed as {@link PCollection} elements, and they must be cleaned + * up via regular Beam transforms, e.g. see the {@link Wait} transform. + * + *

The method annotated with this must satisfy the following constraints: *

    *
  • It must have zero arguments. *
@@ -603,11 +611,37 @@ public interface OutputReceiver { @Target(ElementType.METHOD) public @interface FinishBundle {} - /** - * Annotation for the method to use to clean up this instance after processing bundles of - * elements. No other method will be called after a call to the annotated method is made. - * The method annotated with this must satisfy the following constraint: + * Annotation for the method to use to clean up this instance before it is discarded. No other + * method will be called after a call to the annotated method is made. + * + *

A runner will do its best to call this method on any given instance to prevent leaks of + * transient resources, however, there may be situations where this is impossible (e.g. process + * crash, hardware failure, etc.) or unnecessary (e.g. the pipeline is shutting down and the + * process is about to be killed anyway, so all transient resources will be released + * automatically by the OS). In these cases, the call may not happen. It will also not be retried, + * because in such situations the DoFn instance no longer exists, so there's no instance to + * retry it on. + * + *

Thus, all work that depends on input elements, and all externally important side effects, + * must be performed in the {@link ProcessElement} or {@link FinishBundle} methods. + * + *

Example things that are a good idea to do in this method: + *

    + *
  • Close a network connection that was opened in {@link Setup} + *
  • Shut down a helper process that was started in {@link Setup} + *
+ * + *

Example things that MUST NOT be done in this method: + *

    + *
  • Flushing a batch of buffered records to a database: this must be done in + * {@link FinishBundle}. + *
  • Deleting temporary files on a distributed filesystem: this must be done + * using the pipeline structure, e.g. using the {@link Wait} transform. + *
+ * + *
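As a rough sketch of the lifecycle rules above (FakeDatabaseClient is a made-up stand-in for a transient resource, not a Beam API): open the connection in Setup, perform the externally visible writes in FinishBundle so they are replayed if the bundle is retried, and use Teardown only to release the transient resource.

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.beam.sdk.transforms.DoFn;

    /** Minimal stand-in for a real client; purely illustrative. */
    class FakeDatabaseClient {
      static FakeDatabaseClient connect(String url) { return new FakeDatabaseClient(); }
      void writeAll(List<String> rows) { /* pretend to write the batch */ }
      void close() { /* pretend to release the connection */ }
    }

    class BufferingWriteFn extends DoFn<String, Void> {
      private final String url;
      private transient FakeDatabaseClient client;
      private transient List<String> buffer;

      BufferingWriteFn(String url) { this.url = url; }

      @Setup
      public void setup() {
        // Good: initialize transient in-memory resources such as connections here.
        client = FakeDatabaseClient.connect(url);
      }

      @StartBundle
      public void startBundle() {
        buffer = new ArrayList<>();
      }

      @ProcessElement
      public void processElement(ProcessContext c) {
        buffer.add(c.element());
      }

      @FinishBundle
      public void finishBundle() {
        // Good: the externally visible side effect happens here, because a failed
        // bundle is retried and FinishBundle runs again for the retry.
        client.writeAll(buffer);
      }

      @Teardown
      public void teardown() {
        // Good: release the resource opened in Setup.
        // Bad: flushing buffered records here would lose data, since Teardown is
        // best effort and may never run after a crash.
        if (client != null) {
          client.close();
        }
      }
    }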

The method annotated with this must satisfy the following constraint: + * *

    *
  • It must have zero arguments. *
@@ -615,8 +649,7 @@ public interface OutputReceiver { @Documented @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) - public @interface Teardown { - } + public @interface Teardown {} /** * Annotation for the method that maps an element to an initial restriction for a For example: - * - *
 {@code
- * DoFn<InputT, OutputT> fn = ...;
- *
- * DoFnTester<InputT, OutputT> fnTester = DoFnTester.of(fn);
- *
- * // Set arguments shared across all bundles:
- * fnTester.setSideInputs(...);      // If fn takes side inputs.
- * fnTester.setOutputTags(...);  // If fn writes to more than one output.
- *
- * // Process a bundle containing a single input element:
- * Input testInput = ...;
- * List<OutputT> testOutputs = fnTester.processBundle(testInput);
- * Assert.assertThat(testOutputs, Matchers.hasItems(...));
- *
- * // Process a bigger bundle:
- * Assert.assertThat(fnTester.processBundle(i1, i2, ...), Matchers.hasItems(...));
- * } 
- * - * @param the type of the {@link DoFn}'s (main) input elements - * @param the type of the {@link DoFn}'s (main) output elements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ +@Deprecated public class DoFnTester implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(DoFnTester.class); + /** - * Returns a {@code DoFnTester} supporting unit-testing of the given - * {@link DoFn}. By default, uses {@link CloningBehavior#CLONE_ONCE}. - * - *

The only supported extra parameter of the {@link DoFn.ProcessElement} method is - * {@link BoundedWindow}. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ @SuppressWarnings("unchecked") + @Deprecated public static DoFnTester of(DoFn fn) { checkNotNull(fn, "fn can't be null"); + LOG.warn( + "Your tests use DoFnTester, which may not exercise DoFns correctly. " + + "Please use TestPipeline instead."); return new DoFnTester<>(fn); } /** - * Registers the tuple of values of the side input {@link PCollectionView}s to - * pass to the {@link DoFn} under test. - * - *

Resets the state of this {@link DoFnTester}. - * - *

If this isn't called, {@code DoFnTester} assumes the - * {@link DoFn} takes no side inputs. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void setSideInputs(Map, Map> sideInputs) { checkState( state == State.UNINITIALIZED, @@ -110,15 +88,9 @@ public void setSideInputs(Map, Map> sideInp } /** - * Registers the values of a side input {@link PCollectionView} to pass to the {@link DoFn} - * under test. - * - *

The provided value is the final value of the side input in the specified window, not - * the value of the input PCollection in that window. - * - *

If this isn't called, {@code DoFnTester} will return the default value for any side input - * that is used. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void setSideInput(PCollectionView sideInput, BoundedWindow window, T value) { checkState( state == State.UNINITIALIZED, @@ -132,14 +104,18 @@ public void setSideInput(PCollectionView sideInput, BoundedWindow window, windowValues.put(window, value); } + /** + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. + */ + @Deprecated public PipelineOptions getPipelineOptions() { return options; } /** - * When a {@link DoFnTester} should clone the {@link DoFn} under test and how it should manage - * the lifecycle of the {@link DoFn}. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public enum CloningBehavior { /** * Clone the {@link DoFn} and call {@link DoFn.Setup} every time a bundle starts; call {@link @@ -159,26 +135,26 @@ public enum CloningBehavior { } /** - * Instruct this {@link DoFnTester} whether or not to clone the {@link DoFn} under test. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void setCloningBehavior(CloningBehavior newValue) { checkState(state == State.UNINITIALIZED, "Wrong state: %s", state); this.cloningBehavior = newValue; } /** - * Indicates whether this {@link DoFnTester} will clone the {@link DoFn} under test. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public CloningBehavior getCloningBehavior() { return cloningBehavior; } /** - * A convenience operation that first calls {@link #startBundle}, - * then calls {@link #processElement} on each of the input elements, then - * calls {@link #finishBundle}, then returns the result of - * {@link #takeOutputElements}. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List processBundle(Iterable inputElements) throws Exception { startBundle(); for (InputT inputElement : inputElements) { @@ -189,27 +165,18 @@ public List processBundle(Iterable inputElements) th } /** - * A convenience method for testing {@link DoFn DoFns} with bundles of elements. - * Logic proceeds as follows: - * - *

    - *
    - *   1. Calls {@link #startBundle}.
    - *   2. Calls {@link #processElement} on each of the arguments.
    - *   3. Calls {@link #finishBundle}.
    - *   4. Returns the result of {@link #takeOutputElements}.
    - *
+ * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated @SafeVarargs public final List processBundle(InputT... inputElements) throws Exception { return processBundle(Arrays.asList(inputElements)); } /** - * Calls the {@link DoFn.StartBundle} method on the {@link DoFn} under test. - * - *

If needed, first creates a fresh instance of the {@link DoFn} under test and calls - * {@link DoFn.Setup}. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void startBundle() throws Exception { checkState( state == State.UNINITIALIZED || state == State.BUNDLE_FINISHED, @@ -237,28 +204,17 @@ private static void unwrapUserCodeException(UserCodeException e) throws Exceptio } /** - * Calls the {@link DoFn.ProcessElement} method on the {@link DoFn} under test, in a - * context where {@link DoFn.ProcessContext#element} returns the - * given element and the element is in the global window. - * - *

Will call {@link #startBundle} automatically, if it hasn't - * already been called. - * - * @throws IllegalStateException if the {@code DoFn} under test has already - * been finished + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void processElement(InputT element) throws Exception { processTimestampedElement(TimestampedValue.atMinimumTimestamp(element)); } /** - * Calls {@link DoFn.ProcessElement} on the {@code DoFn} under test, in a - * context where {@link DoFn.ProcessContext#element} returns the - * given element and timestamp and the element is in the global window. - * - *

Will call {@link #startBundle} automatically, if it hasn't - * already been called. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void processTimestampedElement(TimestampedValue element) throws Exception { checkNotNull(element, "Timestamped element cannot be null"); processWindowedElement( @@ -266,13 +222,9 @@ public void processTimestampedElement(TimestampedValue element) throws E } /** - * Calls {@link DoFn.ProcessElement} on the {@code DoFn} under test, in a - * context where {@link DoFn.ProcessContext#element} returns the - * given element and timestamp and the element is in the given window. - * - *

Will call {@link #startBundle} automatically, if it hasn't - * already been called. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void processWindowedElement( InputT element, Instant timestamp, final BoundedWindow window) throws Exception { if (state != State.BUNDLE_STARTED) { @@ -319,7 +271,7 @@ public OnTimerContext onTimerContext(DoFn doFn) { } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException( "Not expected to access RestrictionTracker from a regular DoFn in DoFnTester"); } @@ -340,15 +292,9 @@ public Timer timer(String timerId) { } /** - * Calls the {@link DoFn.FinishBundle} method of the {@link DoFn} under test. - * - *

If {@link #setCloningBehavior} was called with {@link CloningBehavior#CLONE_PER_BUNDLE}, - * then also calls {@link DoFn.Teardown} on the {@link DoFn}, and it will be cloned and - * {@link DoFn.Setup} again when processing the next bundle. - * - * @throws IllegalStateException if {@link DoFn.FinishBundle} has already been called - * for this bundle. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void finishBundle() throws Exception { checkState( state == State.BUNDLE_STARTED, @@ -370,14 +316,9 @@ public void finishBundle() throws Exception { } /** - * Returns the elements output so far to the main output. Does not - * clear them, so subsequent calls will continue to include these - * elements. - * - * @see #takeOutputElements - * @see #clearOutputElements - * + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List peekOutputElements() { return peekOutputElementsWithTimestamp() .stream() @@ -386,14 +327,9 @@ public List peekOutputElements() { } /** - * Returns the elements output so far to the main output with associated timestamps. Does not - * clear them, so subsequent calls will continue to include these. - * elements. - * - * @see #takeOutputElementsWithTimestamp - * @see #clearOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ - @Experimental + @Deprecated public List> peekOutputElementsWithTimestamp() { // TODO: Should we return an unmodifiable list? return getImmutableOutput(mainOutputTag) @@ -403,17 +339,17 @@ public List> peekOutputElementsWithTimestamp() { } /** - * Returns the elements output so far to the main output in the provided window with associated - * timestamps. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List> peekOutputElementsInWindow(BoundedWindow window) { return peekOutputElementsInWindow(mainOutputTag, window); } /** - * Returns the elements output so far to the specified output in the provided window with - * associated timestamps. + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List> peekOutputElementsInWindow( TupleTag tag, BoundedWindow window) { @@ -427,20 +363,17 @@ public List> peekOutputElementsInWindow( } /** - * Clears the record of the elements output so far to the main output. - * - * @see #peekOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void clearOutputElements() { getMutableOutput(mainOutputTag).clear(); } /** - * Returns the elements output so far to the main output. - * Clears the list so these elements don't appear in future calls. - * - * @see #peekOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List takeOutputElements() { List resultElems = new ArrayList<>(peekOutputElements()); clearOutputElements(); @@ -448,14 +381,9 @@ public List takeOutputElements() { } /** - * Returns the elements output so far to the main output with associated timestamps. - * Clears the list so these elements don't appear in future calls. - * - * @see #peekOutputElementsWithTimestamp - * @see #takeOutputElements - * @see #clearOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. 
*/ - @Experimental + @Deprecated public List> takeOutputElementsWithTimestamp() { List> resultElems = new ArrayList<>(peekOutputElementsWithTimestamp()); @@ -464,13 +392,9 @@ public List> takeOutputElementsWithTimestamp() { } /** - * Returns the elements output so far to the output with the - * given tag. Does not clear them, so subsequent calls will - * continue to include these elements. - * - * @see #takeOutputElements - * @see #clearOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List peekOutputElements(TupleTag tag) { // TODO: Should we return an unmodifiable list? return getImmutableOutput(tag) @@ -480,20 +404,17 @@ public List peekOutputElements(TupleTag tag) { } /** - * Clears the record of the elements output so far to the output with the given tag. - * - * @see #peekOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public void clearOutputElements(TupleTag tag) { getMutableOutput(tag).clear(); } /** - * Returns the elements output so far to the output with the given tag. - * Clears the list so these elements don't appear in future calls. - * - * @see #peekOutputElements + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. */ + @Deprecated public List takeOutputElements(TupleTag tag) { List resultElems = new ArrayList<>(peekOutputElements(tag)); clearOutputElements(tag); @@ -506,6 +427,10 @@ private List> getImmutableOutput(TupleTag tag) { return ImmutableList.copyOf(MoreObjects.firstNonNull(elems, Collections.emptyList())); } + /** + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. + */ + @Deprecated @SuppressWarnings({"unchecked", "rawtypes"}) public List> getMutableOutput(TupleTag tag) { List> outputList = (List) getOutputs().get(tag); @@ -516,6 +441,10 @@ public List> getMutableOutput(TupleTag tag) { return outputList; } + /** + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. + */ + @Deprecated public TupleTag getMainOutputTag() { return mainOutputTag; } @@ -556,6 +485,10 @@ public void output(TupleTag tag, T output, Instant timestamp, BoundedWind } } + /** + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. + */ + @Deprecated public DoFn.ProcessContext createProcessContext( ValueInSingleWindow element) { return new TestProcessContext(element); @@ -644,6 +577,10 @@ public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp } } + /** + * @deprecated Use {@link TestPipeline} with the {@code DirectRunner}. + */ + @Deprecated @Override public void close() throws Exception { if (state == State.BUNDLE_STARTED) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Impulse.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Impulse.java new file mode 100644 index 000000000000..bef2b803afb2 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Impulse.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.transforms; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.apache.beam.sdk.values.WindowingStrategy; + +/** + * For internal use only; no backwards-compatibility guarantees. + * + *

A {@link PTransform} which produces a single empty byte array at the minimum timestamp in the + * {@link GlobalWindow}. + * + *

Users should instead use {@link Create} or another {@link Read} transform to begin consuming + * elements. + */ +@Internal +public class Impulse extends PTransform> { + /** + * Create a new {@link Impulse} {@link PTransform}. + */ + // TODO: Make public and implement the default expansion of Read with Impulse -> ParDo + static Impulse create() { + return new Impulse(); + } + + private Impulse() {} + + @Override + public PCollection expand(PBegin input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), + WindowingStrategy.globalDefault(), + IsBounded.BOUNDED, + ByteArrayCoder.of()); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 5e0e3abe698e..c31c4950e4fb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -101,6 +101,11 @@ * provided, will be called on the discarded instance. * * + *

Note also that calls to {@link DoFn.Teardown} are best effort, and may not be called before a + * {@link DoFn} is discarded in the general case. As a result, use of the {@link DoFn.Teardown} + * method to perform side effects is not appropriate, because the elements that produced the side + * effect will not be replayed in case of failure, and those side effects are permanently lost. + * *

Each of the calls to any of the {@link DoFn DoFn's} processing * methods can produce zero or more output elements. All of the * of output elements from all of the {@link DoFn} instances diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Wait.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Wait.java new file mode 100644 index 000000000000..7e05514084e0 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Wait.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static org.apache.beam.sdk.transforms.Contextful.fn; +import static org.apache.beam.sdk.transforms.Requirements.requiresSideInputs; + +import com.google.common.collect.Lists; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.transforms.windowing.Never; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; + +/** + * Delays processing of each window in a {@link PCollection} until signaled. + * + *

Given a main {@link PCollection} and a signal {@link PCollection}, produces output identical + * to its main input, but all elements for a window are produced only once that window is closed + * in the signal {@link PCollection}. + * + *

To express the pattern "apply T to X after Y is ready", use {@code + * X.apply(Wait.on(Y)).apply(T)}. + * + *

In particular: returns a {@link PCollection} with contents identical to the input, but delays + * producing elements of the output in window W until the signal's window W closes (i.e. signal's + * watermark passes W.end + signal.allowedLateness). + * + *

In other words, an element of the output at timestamp "t" will be produced only after no more + * elements of the signal can appear with a timestamp below "t". + * + *

Example usage: write a {@link PCollection} to one database and then to another database, + * making sure that writing a window of data to the second database starts only after the respective + * window has been fully written to the first database. + * + *

{@code
+ * PCollection firstWriteResults = data.apply(ParDo.of(...write to first database...));
+ * data.apply(Wait.on(firstWriteResults))
+ *     // Windows of this intermediate PCollection will be processed no earlier than when
+ *     // the respective window of firstWriteResults closes.
+ *     .apply(ParDo.of(...write to second database...));
+ * }
+ * + *

Notes: + * + *

    + *
  • If the signal is globally windowed, the main input must also be. This is typically useful + * only in a batch pipeline, because the global window of an infinite PCollection never + * closes, so the wait signal will never be ready.
  • Beware that if the signal has a large allowed lateness, the wait signal will fire only after + * that lateness elapses, i.e. after the watermark of the signal passes the end of the window plus + * the allowed lateness. In other words: do not use this with signals that specify a large allowed + * lateness.
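To make the usage above concrete, a self-contained sketch with trivial stand-in DoFns in place of real database writes (the Create source and both stand-in fns are assumptions for illustration only):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Wait;
import org.apache.beam.sdk.values.PCollection;

public class WaitOnExample {
  /** Hypothetical stand-in for a write to the first database; produces no output elements. */
  static class WriteToFirstDbFn extends DoFn<String, Void> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      // ... write c.element() to the first database ...
    }
  }

  /** Hypothetical stand-in for a write to the second database. */
  static class WriteToSecondDbFn extends DoFn<String, Void> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      // ... write c.element() to the second database ...
    }
  }

  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
    PCollection<String> data = p.apply(Create.of("a", "b", "c"));

    // The element type of the signal does not matter; only its windows do.
    PCollection<Void> firstWriteResults =
        data.apply("FirstWrite", ParDo.of(new WriteToFirstDbFn()));

    // Each window of 'data' is held back until the corresponding window of
    // firstWriteResults closes; only then is the second write applied.
    data.apply(Wait.on(firstWriteResults))
        .apply("SecondWrite", ParDo.of(new WriteToSecondDbFn()));

    p.run().waitUntilFinish();
  }
}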
+ */ +@Experimental +public class Wait { + /** Waits on the given signal collections. */ + public static OnSignal on(PCollection... signals) { + return on(Arrays.asList(signals)); + } + + /** Waits on the given signal collections. */ + public static OnSignal on(List> signals) { + return new OnSignal<>(signals); + } + + /** Implementation of {@link #on}. */ + public static class OnSignal extends PTransform, PCollection> { + private final transient List> signals; + + private OnSignal(List> signals) { + this.signals = signals; + } + + @Override + public PCollection expand(PCollection input) { + List> views = Lists.newArrayList(); + for (int i = 0; i < signals.size(); ++i) { + views.add(signals.get(i).apply("To wait view " + i, new ToWaitView())); + } + + return input.apply( + "Wait", + MapElements.into(input.getCoder().getEncodedTypeDescriptor()) + .via(fn((t, c) -> t, requiresSideInputs(views)))); + } + } + + private static class ToWaitView extends PTransform, PCollectionView> { + @Override + public PCollectionView expand(PCollection input) { + return expandTyped(input); + } + + private PCollectionView expandTyped(PCollection input) { + return input + .apply(Window.configure().triggering(Never.ever())) + .apply(Sample.any(1)) + .apply(View.asList()); + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java index a71d9bbe7340..9eca32f1efc4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java @@ -26,7 +26,7 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Function; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Ordering; @@ -39,8 +39,6 @@ import java.io.OutputStream; import java.io.Serializable; import java.util.Arrays; -import java.util.Collections; -import java.util.LinkedList; import java.util.List; import java.util.Map; import javax.annotation.Nullable; @@ -52,9 +50,9 @@ import org.apache.beam.sdk.coders.DurationCoder; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.SnappyCoder; import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.Contextful.Fn; @@ -715,7 +713,7 @@ public ProcessContinuation process( ProcessContext c, final GrowthTracker tracker) throws Exception { if (!tracker.hasPending() && !tracker.currentRestriction().isOutputComplete) { - LOG.debug("{} - polling input", c.element()); + Instant now = Instant.now(); Growth.PollResult res = spec.getPollFn().getClosure().apply(c.element(), wrapProcessContext(c)); // TODO (https://issues.apache.org/jira/browse/BEAM-2680): @@ -726,25 +724,39 @@ public ProcessContinuation process( int numPending = tracker.addNewAsPending(res); if (numPending > 0) { LOG.info( - "{} - polling returned {} results, of which {} were new. The output is {}.", + "{} - current round of polling took {} ms and returned {} results, " + + "of which {} were new. 
The output is {}.", c.element(), + new Duration(now, Instant.now()).getMillis(), res.getOutputs().size(), numPending, BoundedWindow.TIMESTAMP_MAX_VALUE.equals(res.getWatermark()) - ? "complete" - : "incomplete"); + ? "final" + : "not yet final"); } } - while (tracker.hasPending()) { + int numEmittedInThisRound = 0; + int numTotalPending = tracker.getNumPending(); + int numPreviouslyEmitted = tracker.currentRestriction().completed.size(); + int numTotalKnown = numPreviouslyEmitted + numTotalPending; + while (true) { c.updateWatermark(tracker.getWatermark()); - - TimestampedValue nextPending = tracker.tryClaimNextPending(); - if (nextPending == null) { - return stop(); + Map.Entry> entry = tracker.getNextPending(); + if (entry == null || !tracker.tryClaim(entry.getKey())) { + break; } + TimestampedValue nextPending = entry.getValue(); c.outputWithTimestamp( KV.of(c.element(), nextPending.getValue()), nextPending.getTimestamp()); + ++numEmittedInThisRound; } + LOG.info( + "{} - emitted {} new results (of {} total known: {} emitted so far, {} more to emit).", + c.element(), + numEmittedInThisRound, + numTotalKnown, + numEmittedInThisRound + numPreviouslyEmitted, + numTotalPending - numEmittedInThisRound); Instant watermark = tracker.getWatermark(); if (watermark != null) { // Null means the poll result did not provide a watermark and there were no new elements, @@ -754,14 +766,18 @@ public ProcessContinuation process( // No more pending outputs - future output will come from more polling, // unless output is complete or termination condition is reached. if (tracker.shouldPollMore()) { + LOG.info( + "{} - emitted all {} known results so far; will resume polling in {} ms", + c.element(), + numTotalKnown, + spec.getPollInterval().getMillis()); return resume().withResumeDelay(spec.getPollInterval()); } return stop(); } private Growth.TerminationCondition getTerminationCondition() { - return ((Growth.TerminationCondition) - spec.getTerminationPerInput()); + return (Growth.TerminationCondition) spec.getTerminationPerInput(); } @GetInitialRestriction @@ -779,8 +795,8 @@ public GrowthTracker newTracker( @GetRestrictionCoder @SuppressWarnings({"unchecked", "rawtypes"}) public Coder> getRestrictionCoder() { - return GrowthStateCoder.of( - outputCoder, (Coder) spec.getTerminationPerInput().getStateCoder()); + return SnappyCoder.of(GrowthStateCoder.of( + outputCoder, (Coder) spec.getTerminationPerInput().getStateCoder())); } } @@ -792,10 +808,10 @@ static class GrowthState { // timestamp is more than X behind the watermark. // As of writing, we don't do this, but preserve the information for forward compatibility // in case of pipeline update. TODO: do this. - private final Map completed; + private final ImmutableMap completed; // Outputs that are known to be present in a poll result, but have not yet been returned // from a ProcessElement call, sorted by timestamp to help smooth watermark progress. - private final List> pending; + private final ImmutableMap> pending; // If true, processing of this restriction should only output "pending". Otherwise, it should // also continue polling. 
private final boolean isOutputComplete; @@ -805,24 +821,24 @@ static class GrowthState { @Nullable private final Instant pollWatermark; GrowthState(TerminationStateT terminationState) { - this.completed = Collections.emptyMap(); - this.pending = Collections.emptyList(); + this.completed = ImmutableMap.of(); + this.pending = ImmutableMap.of(); this.isOutputComplete = false; this.terminationState = checkNotNull(terminationState); this.pollWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; } GrowthState( - Map completed, - List> pending, + ImmutableMap completed, + ImmutableMap> pending, boolean isOutputComplete, @Nullable TerminationStateT terminationState, @Nullable Instant pollWatermark) { if (!isOutputComplete) { checkNotNull(terminationState); } - this.completed = Collections.unmodifiableMap(completed); - this.pending = Collections.unmodifiableList(pending); + this.completed = completed; + this.pending = pending; this.isOutputComplete = isOutputComplete; this.terminationState = terminationState; this.pollWatermark = pollWatermark; @@ -835,7 +851,7 @@ public String toString(Growth.TerminationCondition termina + " elements>, pending=<" + pending.size() + " elements" - + (pending.isEmpty() ? "" : (", earliest " + pending.get(0))) + + (pending.isEmpty() ? "" : (", earliest " + pending.values().iterator().next())) + ">, isOutputComplete=" + isOutputComplete + ", terminationState=" @@ -848,7 +864,7 @@ public String toString(Growth.TerminationCondition termina @VisibleForTesting static class GrowthTracker - implements RestrictionTracker> { + extends RestrictionTracker, HashCode> { private final Funnel coderFunnel; private final Growth.TerminationCondition terminationCondition; @@ -861,9 +877,9 @@ static class GrowthTracker // Remaining pending outputs; initialized from state.pending (if non-empty) or in // addNewAsPending(); drained via tryClaimNextPending(). - private LinkedList> pending; + private Map> pending; // Outputs that have been claimed in the current ProcessElement call. A prefix of "pending". - private List> claimed = Lists.newArrayList(); + private Map> claimed = Maps.newLinkedHashMap(); private boolean isOutputComplete; @Nullable private TerminationStateT terminationState; @Nullable private Instant pollWatermark; @@ -889,7 +905,7 @@ static class GrowthTracker this.isOutputComplete = state.isOutputComplete; this.pollWatermark = state.pollWatermark; this.terminationState = state.terminationState; - this.pending = Lists.newLinkedList(state.pending); + this.pending = Maps.newLinkedHashMap(state.pending); } @Override @@ -899,27 +915,31 @@ public synchronized GrowthState currentRestric @Override public synchronized GrowthState checkpoint() { + checkState( + !claimed.isEmpty(), "Can't checkpoint before any element was successfully claimed"); + // primary should contain exactly the work claimed in the current ProcessElement call - i.e. // claimed outputs become pending, and it shouldn't poll again. GrowthState primary = new GrowthState<>( state.completed /* completed */, - claimed /* pending */, + ImmutableMap.copyOf(claimed) /* pending */, true /* isOutputComplete */, null /* terminationState */, BoundedWindow.TIMESTAMP_MAX_VALUE /* pollWatermark */); // residual should contain exactly the work *not* claimed in the current ProcessElement call - // unclaimed pending outputs plus future polling outputs. 
- Map newCompleted = Maps.newHashMap(state.completed); - for (TimestampedValue claimedOutput : claimed) { + ImmutableMap.Builder newCompleted = ImmutableMap.builder(); + newCompleted.putAll(state.completed); + for (Map.Entry> claimedOutput : claimed.entrySet()) { newCompleted.put( - hash128(claimedOutput.getValue()), claimedOutput.getTimestamp()); + claimedOutput.getKey(), claimedOutput.getValue().getTimestamp()); } GrowthState residual = new GrowthState<>( - newCompleted /* completed */, - pending /* pending */, + newCompleted.build() /* completed */, + ImmutableMap.copyOf(pending) /* pending */, isOutputComplete /* isOutputComplete */, terminationState, pollWatermark); @@ -930,7 +950,7 @@ public synchronized GrowthState checkpoint() { this.isOutputComplete = primary.isOutputComplete; this.pollWatermark = primary.pollWatermark; this.terminationState = null; - this.pending = Lists.newLinkedList(); + this.pending = Maps.newLinkedHashMap(); this.shouldStop = true; return residual; @@ -954,16 +974,29 @@ synchronized boolean hasPending() { return !pending.isEmpty(); } + private synchronized int getNumPending() { + return pending.size(); + } + @VisibleForTesting @Nullable - synchronized TimestampedValue tryClaimNextPending() { - if (shouldStop) { + synchronized Map.Entry> getNextPending() { + if (pending.isEmpty()) { return null; } + return pending.entrySet().iterator().next(); + } + + @Override + protected synchronized boolean tryClaimImpl(HashCode hash) { + if (shouldStop) { + return false; + } checkState(!pending.isEmpty(), "No more unclaimed pending outputs"); - TimestampedValue value = pending.removeFirst(); - claimed.add(value); - return value; + TimestampedValue value = pending.remove(hash); + checkArgument(value != null, "Attempted to claim unknown hash %s", hash); + claimed.put(hash, value); + return true; } @VisibleForTesting @@ -999,19 +1032,23 @@ synchronized int addNewAsPending(Growth.PollResult pollResult) { if (!newPending.isEmpty()) { terminationState = terminationCondition.onSeenNewOutput(Instant.now(), terminationState); } - this.pending = - Lists.newLinkedList( - Ordering.natural() - .onResultOf( - (Function, Comparable>) - TimestampedValue::getTimestamp) - .sortedCopy(newPending.values())); + + List>> sortedPending = + Ordering.natural() + .onResultOf( + (Map.Entry> entry) -> + entry.getValue().getTimestamp()) + .sortedCopy(newPending.entrySet()); + this.pending = Maps.newLinkedHashMap(); + for (Map.Entry> entry : sortedPending) { + this.pending.put(entry.getKey(), entry.getValue()); + } // If poll result doesn't provide a watermark, assume that future new outputs may // arrive with about the same timestamps as the current new outputs. if (pollResult.getWatermark() != null) { this.pollWatermark = pollResult.getWatermark(); } else if (!pending.isEmpty()) { - this.pollWatermark = pending.getFirst().getTimestamp(); + this.pollWatermark = pending.values().iterator().next().getTimestamp(); } if (BoundedWindow.TIMESTAMP_MAX_VALUE.equals(pollWatermark)) { isOutputComplete = true; @@ -1026,7 +1063,9 @@ synchronized Instant getWatermark() { // min(watermark for future polling, earliest remaining pending element) return Ordering.natural() .nullsLast() - .min(pollWatermark, pending.isEmpty() ? null : pending.getFirst().getTimestamp()); + .min( + pollWatermark, + pending.isEmpty() ? null : pending.values().iterator().next().getTimestamp()); } @Override @@ -1037,7 +1076,7 @@ public synchronized String toString() { + ", pending=<" + pending.size() + " elements" - + (pending.isEmpty() ? 
"" : (", earliest " + pending.get(0))) + + (pending.isEmpty() ? "" : (", earliest " + pending.values().iterator().next())) + ">, claimed=<" + claimed.size() + " elements>, isOutputComplete=" @@ -1091,7 +1130,7 @@ GrowthStateCoder of( private final Coder outputCoder; private final Coder> completedCoder; - private final Coder>> pendingCoder; + private final Coder> timestampedOutputCoder; private final Coder terminationStateCoder; private GrowthStateCoder( @@ -1099,14 +1138,18 @@ private GrowthStateCoder( this.outputCoder = outputCoder; this.terminationStateCoder = terminationStateCoder; this.completedCoder = MapCoder.of(HASH_CODE_CODER, INSTANT_CODER); - this.pendingCoder = ListCoder.of(TimestampedValue.TimestampedValueCoder.of(outputCoder)); + this.timestampedOutputCoder = TimestampedValue.TimestampedValueCoder.of(outputCoder); } @Override public void encode(GrowthState value, OutputStream os) throws IOException { completedCoder.encode(value.completed, os); - pendingCoder.encode(value.pending, os); + VarIntCoder.of().encode(value.pending.size(), os); + for (Map.Entry> entry : value.pending.entrySet()) { + HASH_CODE_CODER.encode(entry.getKey(), os); + timestampedOutputCoder.encode(entry.getValue(), os); + } BOOLEAN_CODER.encode(value.isOutputComplete, os); terminationStateCoder.encode(value.terminationState, os); INSTANT_CODER.encode(value.pollWatermark, os); @@ -1115,12 +1158,22 @@ public void encode(GrowthState value, OutputSt @Override public GrowthState decode(InputStream is) throws IOException { Map completed = completedCoder.decode(is); - List> pending = pendingCoder.decode(is); + int numPending = VarIntCoder.of().decode(is); + ImmutableMap.Builder> pending = ImmutableMap.builder(); + for (int i = 0; i < numPending; ++i) { + HashCode hash = HASH_CODE_CODER.decode(is); + TimestampedValue output = timestampedOutputCoder.decode(is); + pending.put(hash, output); + } boolean isOutputComplete = BOOLEAN_CODER.decode(is); TerminationStateT terminationState = terminationStateCoder.decode(is); Instant pollWatermark = INSTANT_CODER.decode(is); return new GrowthState<>( - completed, pending, isOutputComplete, terminationState, pollWatermark); + ImmutableMap.copyOf(completed), + pending.build(), + isOutputComplete, + terminationState, + pollWatermark); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index ec2bf342a590..ddd2c3f1d0a9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -78,7 +78,7 @@ void invokeSplitRestriction( DoFn.OutputReceiver restrictionReceiver); /** Invoke the {@link DoFn.NewTracker} method on the bound {@link DoFn}. */ - > TrackerT invokeNewTracker( + > TrackerT invokeNewTracker( RestrictionT restriction); /** Get the bound {@link DoFn}. */ @@ -124,7 +124,7 @@ interface ArgumentProvider { * If this is a splittable {@link DoFn}, returns the {@link RestrictionTracker} associated with * the current {@link ProcessElement} call. */ - RestrictionTracker restrictionTracker(); + RestrictionTracker restrictionTracker(); /** Returns the state cell for the given {@link StateId}. 
*/ State state(String stateId); @@ -203,7 +203,7 @@ public Timer timer(String timerId) { FakeArgumentProvider.class.getSimpleName())); } - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException( String.format( "Should never call non-overridden methods of %s", diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 96cb43b79189..219e0584403f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -366,6 +366,7 @@ public abstract static class ProcessContextParameter extends Parameter { public abstract static class OnTimerContextParameter extends Parameter { OnTimerContextParameter() {} } + /** * Descriptor for a {@link Parameter} of type {@link BoundedWindow}. * @@ -374,6 +375,7 @@ public abstract static class OnTimerContextParameter extends Parameter { @AutoValue public abstract static class WindowParameter extends Parameter { WindowParameter() {} + public abstract TypeDescriptor windowT(); } @@ -386,6 +388,7 @@ public abstract static class WindowParameter extends Parameter { public abstract static class RestrictionTrackerParameter extends Parameter { // Package visible for AutoValue RestrictionTrackerParameter() {} + public abstract TypeDescriptor trackerT(); } @@ -399,6 +402,7 @@ public abstract static class RestrictionTrackerParameter extends Parameter { public abstract static class StateParameter extends Parameter { // Package visible for AutoValue StateParameter() {} + public abstract StateDeclaration referent(); } @@ -410,6 +414,7 @@ public abstract static class StateParameter extends Parameter { public abstract static class TimerParameter extends Parameter { // Package visible for AutoValue TimerParameter() {} + public abstract TimerDeclaration referent(); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 6d6ed8a4186b..de3acfd67003 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -908,7 +908,7 @@ private static Parameter analyzeExtraParameter( List allowedParamTypes = Arrays.asList( formatType(new TypeDescriptor() {}), - formatType(new TypeDescriptor>() {})); + formatType(new TypeDescriptor>() {})); paramErrors.throwIllegalArgument( "%s is not a valid context parameter. Should be one of %s", formatType(paramT), allowedParamTypes); @@ -1131,9 +1131,9 @@ static DoFnSignature.GetRestrictionCoderMethod analyzeGetRestrictionCoderMethod( * RestrictionT}. 
*/ private static - TypeDescriptor> restrictionTrackerTypeOf( + TypeDescriptor> restrictionTrackerTypeOf( TypeDescriptor restrictionT) { - return new TypeDescriptor>() {}.where( + return new TypeDescriptor>() {}.where( new TypeParameter() {}, restrictionT); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java index 3366dfecaacd..8badd5cc2a37 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java @@ -24,7 +24,7 @@ */ public interface HasDefaultTracker< RestrictionT extends HasDefaultTracker, - TrackerT extends RestrictionTracker> { + TrackerT extends RestrictionTracker> { /** Creates a new tracker for {@code this}. */ TrackerT newTracker(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java index 8ec2c6b6a875..f2d9e5cfd819 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java @@ -30,7 +30,7 @@ * A {@link RestrictionTracker} for claiming offsets in an {@link OffsetRange} in a monotonically * increasing fashion. */ -public class OffsetRangeTracker implements RestrictionTracker { +public class OffsetRangeTracker extends RestrictionTracker { private OffsetRange range; @Nullable private Long lastClaimedOffset = null; @Nullable private Long lastAttemptedOffset = null; @@ -46,11 +46,8 @@ public synchronized OffsetRange currentRestriction() { @Override public synchronized OffsetRange checkpoint() { - if (lastClaimedOffset == null) { - OffsetRange res = range; - range = new OffsetRange(range.getFrom(), range.getFrom()); - return res; - } + checkState( + lastClaimedOffset != null, "Can't checkpoint before any offset was successfully claimed"); OffsetRange res = new OffsetRange(lastClaimedOffset + 1, range.getTo()); this.range = new OffsetRange(range.getFrom(), lastClaimedOffset + 1); return res; @@ -64,7 +61,8 @@ public synchronized OffsetRange checkpoint() { * @return {@code true} if the offset was successfully claimed, {@code false} if it is outside the * current {@link OffsetRange} of this tracker (in that case this operation is a no-op). 
*/ - public synchronized boolean tryClaim(long i) { + @Override + protected synchronized boolean tryClaimImpl(Long i) { checkArgument( lastAttemptedOffset == null || i > lastAttemptedOffset, "Trying to claim offset %s while last attempted was %s", diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java index 8cb0a6bd4baa..8b59f054b96e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java @@ -17,18 +17,81 @@ */ package org.apache.beam.sdk.transforms.splittabledofn; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.transforms.DoFn; /** * Manages concurrent access to the restriction and keeps track of its claimed part for a
splittable {@link DoFn}. */ -public interface RestrictionTracker { +public abstract class RestrictionTracker { + /** Internal interface allowing a runner to observe the calls to {@link #tryClaim}. */ + @Internal + public interface ClaimObserver { + /** Called when {@link #tryClaim} returns true. */ + void onClaimed(PositionT position); + + /** Called when {@link #tryClaim} returns false. */ + void onClaimFailed(PositionT position); + } + + @Nullable private ClaimObserver claimObserver; + + /** + * Sets a {@link ClaimObserver} to be invoked on every call to {@link #tryClaim}. Internal: + * intended only for runner authors. + */ + @Internal + public void setClaimObserver(ClaimObserver claimObserver) { + checkNotNull(claimObserver, "claimObserver"); + checkState(this.claimObserver == null, "A claim observer has already been set"); + this.claimObserver = claimObserver; + } + + /** + * Attempts to claim the block of work in the current restriction identified by the given + * position. + * + *

If this succeeds, the DoFn MUST execute the entire block of work. If this fails: + * + *

    + *
  • {@link DoFn.ProcessElement} MUST return {@link DoFn.ProcessContinuation#stop} without + * performing any additional work or emitting output (note that emitting output or + * performing work from {@link DoFn.ProcessElement} is also not allowed before the first + * call to this method). + *
  • {@link RestrictionTracker#checkDone} MUST succeed. + *
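For illustration, a toy splittable DoFn (an assumed example, not part of this change) that follows this contract using the OffsetRangeTracker from this patch: every offset is claimed via tryClaim before its block of work is performed, and processing stops as soon as a claim is rejected.

import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;

// Emits 0..n-1 for each input n, claiming every offset before producing it.
class CountToNFn extends DoFn<Long, Long> {
  @ProcessElement
  public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) {
    for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) {
      // The block at offset i is claimed, so it MUST be processed in full.
      c.output(i);
    }
    // A failed claim means the remainder was split off; stop without further output.
    return ProcessContinuation.stop();
  }

  @GetInitialRestriction
  public OffsetRange getInitialRestriction(Long n) {
    return new OffsetRange(0, n);
  }
}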
+ * + *

Under the hood, calls {@link #tryClaimImpl} and notifies {@link ClaimObserver} of the + * result. + */ + public final boolean tryClaim(PositionT position) { + if (tryClaimImpl(position)) { + if (claimObserver != null) { + claimObserver.onClaimed(position); + } + return true; + } else { + if (claimObserver != null) { + claimObserver.onClaimFailed(position); + } + return false; + } + } + + /** Tracker-specific implementation of {@link #tryClaim}. */ + @Internal + protected abstract boolean tryClaimImpl(PositionT position); + /** * Returns a restriction accurately describing the full range of work the current {@link * DoFn.ProcessElement} call will do, including already completed work. */ - RestrictionT currentRestriction(); + public abstract RestrictionT currentRestriction(); /** * Signals that the current {@link DoFn.ProcessElement} call should terminate as soon as possible: @@ -37,9 +100,12 @@ public interface RestrictionTracker { * *

Modifies {@link #currentRestriction}. Returns a restriction representing the rest of the * work: the old value of {@link #currentRestriction} is equivalent to the new value and the - * return value of this method combined. Must be called at most once on a given object. + * return value of this method combined. + * + *

Must be called at most once on a given object. Must not be called before the first + * successful {@link #tryClaim} call. */ - RestrictionT checkpoint(); + public abstract RestrictionT checkpoint(); /** * Called by the runner after {@link DoFn.ProcessElement} returns. @@ -47,7 +113,7 @@ public interface RestrictionTracker { *

Must throw an exception with an informative error message, if there is still any unclaimed * work remaining in the restriction. */ - void checkDone() throws IllegalStateException; + public abstract void checkDone() throws IllegalStateException; // TODO: Add the more general splitRemainderAfterFraction() and other methods. } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java index 04218682d6c5..ca410b3de5bd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/MergeOverlappingIntervalWindows.java @@ -64,14 +64,17 @@ public static void mergeWindows(WindowFn.MergeContext c) thro private static class MergeCandidate { @Nullable private IntervalWindow union; private final List parts; + public MergeCandidate() { union = null; parts = new ArrayList<>(); } + public MergeCandidate(IntervalWindow window) { union = window; parts = new ArrayList<>(Arrays.asList(window)); } + public boolean intersects(IntervalWindow window) { return union == null || union.intersects(window); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStream.java index b0784cad2bc0..c8d3911f923c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStream.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStream.java @@ -17,9 +17,11 @@ */ package org.apache.beam.sdk.util; +import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.util.concurrent.ArrayBlockingQueue; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.sdk.coders.Coder.Context; @@ -65,6 +67,9 @@ */ @NotThreadSafe public class BufferedElementCountingOutputStream extends OutputStream { + private static final int MAX_POOLED = 12; + @VisibleForTesting static final ArrayBlockingQueue BUFFER_POOL = + new ArrayBlockingQueue<>(MAX_POOLED); public static final int DEFAULT_BUFFER_SIZE = 64 * 1024; private final ByteBuffer buffer; private final OutputStream os; @@ -84,12 +89,17 @@ public BufferedElementCountingOutputStream(OutputStream os) { * manner with the given {@code bufferSize}. */ BufferedElementCountingOutputStream(OutputStream os, int bufferSize) { - this.buffer = ByteBuffer.allocate(bufferSize); this.os = os; this.finished = false; this.count = 0; + ByteBuffer buffer = BUFFER_POOL.poll(); + if (buffer == null) { + buffer = ByteBuffer.allocate(bufferSize); + } + this.buffer = buffer; } + /** * Finishes the encoding by flushing any buffered data, * and outputting a final count of 0. @@ -101,6 +111,9 @@ public void finish() throws IOException { flush(); // Finish the stream by stating that there are 0 elements that follow. VarInt.encode(0, os); + if (!BUFFER_POOL.offer(buffer)) { + // The pool is full, we can't store the buffer. We just drop the buffer. 
+ } finished = true; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java index 1e1ab286741b..0da99c3d3183 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java @@ -19,7 +19,6 @@ import com.google.common.base.Throwables; import com.google.common.io.BaseEncoding; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -27,7 +26,6 @@ import java.io.OutputStream; import java.lang.ref.SoftReference; import java.lang.reflect.ParameterizedType; - import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.values.TypeDescriptor; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java new file mode 100644 index 000000000000..8275fadca93b --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressWarnings; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import javax.annotation.Nullable; + +/** + * Utilities to do future programming with Java 8. + * + *

Standards for these utilities: + * + *

  • Always allow exceptions to be thrown; a thrown exception should cause the future to complete exceptionally.
  • Always return {@link CompletionStage} as a future value.
  • Return {@link CompletableFuture} + * only to the producer of a future value.
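A short usage sketch under these conventions (the parseOrThrow helper is an assumed stand-in for any computation that declares a checked exception): producers create stages with supplyAsync, and consumers combine and read them through allAsList and get without ever converting to CompletableFuture.

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletionStage;
import org.apache.beam.sdk.util.MoreFutures;

public class MoreFuturesExample {
  public static void main(String[] args) throws Exception {
    // If the supplier throws, the stage completes exceptionally instead of
    // forcing a try/catch at the call site.
    CompletionStage<Integer> a = MoreFutures.supplyAsync(() -> parseOrThrow("1"));
    CompletionStage<Integer> b = MoreFutures.supplyAsync(() -> parseOrThrow("2"));

    // Consumers stay on CompletionStage; gathering and blocking go through MoreFutures.
    List<Integer> results = MoreFutures.get(MoreFutures.allAsList(Arrays.asList(a, b)));
    System.out.println(results); // prints [1, 2]
  }

  // Stand-in for any computation that may throw a checked exception.
  private static int parseOrThrow(String s) throws Exception {
    return Integer.parseInt(s);
  }
}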
+ */ +public class MoreFutures { + + /** + * Gets the result of the given future. + * + *

This utility is provided so consumers of futures need not even convert to {@link + * CompletableFuture}, a class that is only suitable for producers of futures. + */ + public static T get(CompletionStage future) + throws InterruptedException, ExecutionException { + return future.toCompletableFuture().get(); + } + + /** + * Gets the result of the given future. + * + *

This utility is provided so consumers of futures need not even convert to {@link + * CompletableFuture}, a class that is only suitable for producers of futures. + */ + public static T get(CompletionStage future, long duration, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return future.toCompletableFuture().get(duration, unit); + } + + /** + * Indicates whether the future is done. + * + *

This utility is provided so consumers of futures need not even convert to {@link + * CompletableFuture}, a class that is only suitable for producers of futures. + */ + public static boolean isDone(CompletionStage future) { + return future.toCompletableFuture().isDone(); + } + + /** + * Indicates whether the future is cancelled. + * + *

This utility is provided so consumers of futures need not even convert to {@link + * CompletableFuture}, a class that is only suitable for producers of futures. + */ + public static boolean isCancelled(CompletionStage future) { + return future.toCompletableFuture().isCancelled(); + } + + /** + * Like {@link CompletableFuture#supplyAsync(Supplier)} but for {@link ThrowingSupplier}. + * + *

If the {@link ThrowingSupplier} throws an exception, the future completes exceptionally. + */ + public static CompletionStage supplyAsync( + ThrowingSupplier supplier, ExecutorService executorService) { + CompletableFuture result = new CompletableFuture<>(); + + CompletionStage wrapper = CompletableFuture.runAsync( + () -> { + try { + result.complete(supplier.get()); + } catch (InterruptedException e) { + result.completeExceptionally(e); + Thread.currentThread().interrupt(); + } catch (Throwable t) { + result.completeExceptionally(t); + } + }, + executorService); + return wrapper.thenCompose(nothing -> result); + } + + /** + * Shorthand for {@link #supplyAsync(ThrowingSupplier, ExecutorService)} using {@link + * ForkJoinPool#commonPool()}. + */ + public static CompletionStage supplyAsync(ThrowingSupplier supplier) { + return supplyAsync(supplier, ForkJoinPool.commonPool()); + } + + /** + * Like {@link CompletableFuture#runAsync} but for {@link ThrowingRunnable}. + * + *

If the {@link ThrowingRunnable} throws an exception, the future completes exceptionally. + */ + public static CompletionStage runAsync( + ThrowingRunnable runnable, ExecutorService executorService) { + CompletableFuture result = new CompletableFuture<>(); + + CompletionStage wrapper = + CompletableFuture.runAsync( + () -> { + try { + runnable.run(); + result.complete(null); + } catch (InterruptedException e) { + result.completeExceptionally(e); + Thread.currentThread().interrupt(); + } catch (Throwable t) { + result.completeExceptionally(t); + } + }, + executorService); + return wrapper.thenCompose(nothing -> result); + } + + /** + * Shorthand for {@link #runAsync(ThrowingRunnable, ExecutorService)} using {@link + * ForkJoinPool#commonPool()}. + */ + public static CompletionStage runAsync(ThrowingRunnable runnable) { + return runAsync(runnable, ForkJoinPool.commonPool()); + } + + /** + * Like {@link CompletableFuture#allOf} but returning the result of constituent futures. + */ + public static CompletionStage> allAsList( + Collection> futures) { + + // CompletableFuture.allOf completes exceptionally if any of the futures do. + // We have to gather the results separately. + CompletionStage blockAndDiscard = + CompletableFuture.allOf(futuresToCompletableFutures(futures)); + + return blockAndDiscard.thenApply( + nothing -> + futures + .stream() + .map(future -> future.toCompletableFuture().join()) + .collect(Collectors.toList())); + } + + /** + * An object that represents either a result or an exceptional termination. + * + *

This is used, for example, in aggregating the results of many future values in {@link + * #allAsList(Collection)}. + */ + @SuppressWarnings(value = "NM_CLASS_NOT_EXCEPTION", + justification = "The class does hold an exception; its name is accurate.") + @AutoValue + public abstract static class ExceptionOrResult { + + /** + * Describes whether the result was an exception. + */ + public enum IsException { + EXCEPTION, + RESULT + } + + public abstract IsException isException(); + + public abstract @Nullable + T getResult(); + + public abstract @Nullable + Throwable getException(); + + public static ExceptionOrResult exception(Throwable throwable) { + return new AutoValue_MoreFutures_ExceptionOrResult(IsException.EXCEPTION, null, throwable); + } + + public static ExceptionOrResult result(T result) { + return new AutoValue_MoreFutures_ExceptionOrResult(IsException.EXCEPTION, result, null); + } + } + + /** + * Like {@link #allAsList} but return a list . + */ + public static CompletionStage>> allAsListWithExceptions( + Collection> futures) { + + // CompletableFuture.allOf completes exceptionally if any of the futures do. + // We have to gather the results separately. + CompletionStage blockAndDiscard = + CompletableFuture.allOf(futuresToCompletableFutures(futures)) + .whenComplete((ignoredValues, arbitraryException) -> { + }); + + return blockAndDiscard.thenApply( + nothing -> + futures + .stream() + .map( + future -> { + // The limited scope of the exceptions wrapped allows CancellationException + // to still be thrown. + try { + return ExceptionOrResult.result(future.toCompletableFuture().join()); + } catch (CompletionException exc) { + return ExceptionOrResult.exception(exc); + } + }) + .collect(Collectors.toList())); + } + + /** + * Helper to convert a list of futures into an array for use in {@link CompletableFuture} vararg + * combinators. + */ + private static CompletableFuture[] futuresToCompletableFutures( + Collection> futures) { + CompletableFuture[] completableFutures = new CompletableFuture[futures.size()]; + int i = 0; + for (CompletionStage future : futures) { + completableFutures[i] = future.toCompletableFuture(); + ++i; + } + return completableFutures; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ThrowingRunnable.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ThrowingRunnable.java new file mode 100644 index 000000000000..7b65de3dd238 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ThrowingRunnable.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +/** Like {@link Runnable} but allowed to throw any exception. 
*/ +@FunctionalInterface +public interface ThrowingRunnable { + void run() throws Exception; +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ThrowingSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ThrowingSupplier.java new file mode 100644 index 000000000000..4d8c43521dbc --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ThrowingSupplier.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import java.util.function.Supplier; + +/** Like {@link Supplier} but allowed to throw any exception. */ +@FunctionalInterface +public interface ThrowingSupplier { + T get() throws Exception; +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/WindowedValue.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/WindowedValue.java index 66ffcc2a534f..53776d31a6c6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/WindowedValue.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/WindowedValue.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.util; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; @@ -50,35 +51,34 @@ */ public abstract class WindowedValue { - /** - * Returns a {@code WindowedValue} with the given value, timestamp, - * and windows. - */ + /** Returns a {@code WindowedValue} with the given value, timestamp, and windows. */ public static WindowedValue of( - T value, - Instant timestamp, - Collection windows, - PaneInfo pane) { - checkNotNull(pane); - - if (windows.size() == 0 && BoundedWindow.TIMESTAMP_MIN_VALUE.equals(timestamp)) { - return valueInEmptyWindows(value, pane); - } else if (windows.size() == 1) { + T value, Instant timestamp, Collection windows, PaneInfo pane) { + checkArgument(pane != null, "WindowedValue requires PaneInfo, but it was null"); + checkArgument(windows.size() > 0, "WindowedValue requires windows, but there were none"); + + if (windows.size() == 1) { return of(value, timestamp, windows.iterator().next(), pane); } else { return new TimestampedValueInMultipleWindows<>(value, timestamp, windows, pane); } } - /** - * Returns a {@code WindowedValue} with the given value, timestamp, and window. 
- */ + /** @deprecated for use only in compatibility with old broken code */ + @Deprecated + static WindowedValue createWithoutValidation( + T value, Instant timestamp, Collection windows, PaneInfo pane) { + if (windows.size() == 1) { + return of(value, timestamp, windows.iterator().next(), pane); + } else { + return new TimestampedValueInMultipleWindows<>(value, timestamp, windows, pane); + } + } + + /** Returns a {@code WindowedValue} with the given value, timestamp, and window. */ public static WindowedValue of( - T value, - Instant timestamp, - BoundedWindow window, - PaneInfo pane) { - checkNotNull(pane); + T value, Instant timestamp, BoundedWindow window, PaneInfo pane) { + checkArgument(pane != null, "WindowedValue requires PaneInfo, but it was null"); boolean isGlobal = GlobalWindow.INSTANCE.equals(window); if (isGlobal && BoundedWindow.TIMESTAMP_MIN_VALUE.equals(timestamp)) { @@ -107,8 +107,8 @@ public static WindowedValue valueInGlobalWindow(T value, PaneInfo pane) { } /** - * Returns a {@code WindowedValue} with the given value and timestamp, - * {@code GlobalWindow} and default pane. + * Returns a {@code WindowedValue} with the given value and timestamp, {@code GlobalWindow} and + * default pane. */ public static WindowedValue timestampedValueInGlobalWindow(T value, Instant timestamp) { if (BoundedWindow.TIMESTAMP_MIN_VALUE.equals(timestamp)) { @@ -118,54 +118,22 @@ public static WindowedValue timestampedValueInGlobalWindow(T value, Insta } } - /** - * Returns a {@code WindowedValue} with the given value in no windows, and the default timestamp - * and pane. - * - * @deprecated a value in no windows technically is not "in" a PCollection. It is allowed to drop - * it at any point, and benign runner implementation details could cause silent data loss. - */ - @Deprecated - public static WindowedValue valueInEmptyWindows(T value) { - return new ValueInEmptyWindows<>(value, PaneInfo.NO_FIRING); - } - - /** - * Returns a {@code WindowedValue} with the given value in no windows, and the default timestamp - * and the specified pane. - * - * @deprecated a value in no windows technically is not "in" a PCollection. It is allowed to drop - * it at any point, and benign runner implementation details could cause silent data loss. - */ - @Deprecated - public static WindowedValue valueInEmptyWindows(T value, PaneInfo pane) { - return new ValueInEmptyWindows<>(value, pane); - } - /** * Returns a new {@code WindowedValue} that is a copy of this one, but with a different value, * which may have a new type {@code NewT}. */ public abstract WindowedValue withValue(NewT value); - /** - * Returns the value of this {@code WindowedValue}. - */ + /** Returns the value of this {@code WindowedValue}. */ public abstract T getValue(); - /** - * Returns the timestamp of this {@code WindowedValue}. - */ + /** Returns the timestamp of this {@code WindowedValue}. */ public abstract Instant getTimestamp(); - /** - * Returns the windows of this {@code WindowedValue}. - */ + /** Returns the windows of this {@code WindowedValue}. */ public abstract Collection getWindows(); - /** - * Returns the pane of this {@code WindowedValue} in its window. - */ + /** Returns the pane of this {@code WindowedValue} in its window. */ public abstract PaneInfo getPane(); /** @@ -210,8 +178,8 @@ public int hashCode() { Collections.singletonList(GlobalWindow.INSTANCE); /** - * An abstract superclass for implementations of {@link WindowedValue} that stores the value - * and pane info. 
+ * An abstract superclass for implementations of {@link WindowedValue} that stores the value and + * pane info. */ private abstract static class SimpleWindowedValue extends WindowedValue { private final T value; @@ -226,18 +194,15 @@ protected SimpleWindowedValue(T value, PaneInfo pane) { public PaneInfo getPane() { return pane; } + @Override public T getValue() { return value; } } - /** - * The abstract superclass of WindowedValue representations where - * timestamp == MIN. - */ - private abstract static class MinTimestampWindowedValue - extends SimpleWindowedValue { + /** The abstract superclass of WindowedValue representations where timestamp == MIN. */ + private abstract static class MinTimestampWindowedValue extends SimpleWindowedValue { public MinTimestampWindowedValue(T value, PaneInfo pane) { super(value, pane); } @@ -248,12 +213,8 @@ public Instant getTimestamp() { } } - /** - * The representation of a WindowedValue where timestamp == MIN and - * windows == {GlobalWindow}. - */ - private static class ValueInGlobalWindow - extends MinTimestampWindowedValue { + /** The representation of a WindowedValue where timestamp == MIN and windows == {GlobalWindow}. */ + private static class ValueInGlobalWindow extends MinTimestampWindowedValue { public ValueInGlobalWindow(T value, PaneInfo pane) { super(value, pane); } @@ -293,64 +254,11 @@ public String toString() { } } - /** - * The representation of a WindowedValue where timestamp == MIN and windows == {}. - * - * @deprecated a value in no windows technically is not "in" a PCollection. It is allowed to drop - * it at any point, and benign runner implementation details could cause silent data loss. - */ - @Deprecated - private static class ValueInEmptyWindows extends MinTimestampWindowedValue { - public ValueInEmptyWindows(T value, PaneInfo pane) { - super(value, pane); - } - - @Override - public WindowedValue withValue(NewT newValue) { - return new ValueInEmptyWindows<>(newValue, getPane()); - } - - @Override - public Collection getWindows() { - return Collections.emptyList(); - } - - @Override - public boolean equals(Object o) { - if (o instanceof ValueInEmptyWindows) { - ValueInEmptyWindows that = (ValueInEmptyWindows) o; - return Objects.equals(that.getPane(), this.getPane()) - && Objects.equals(that.getValue(), this.getValue()); - } else { - return super.equals(o); - } - } - - @Override - public int hashCode() { - return Objects.hash(getValue(), getPane()); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(getClass()) - .add("value", getValue()) - .add("pane", getPane()) - .toString(); - } - } - - /** - * The abstract superclass of WindowedValue representations where - * timestamp is arbitrary. - */ - private abstract static class TimestampedWindowedValue - extends SimpleWindowedValue { + /** The abstract superclass of WindowedValue representations where timestamp is arbitrary. */ + private abstract static class TimestampedWindowedValue extends SimpleWindowedValue { private final Instant timestamp; - public TimestampedWindowedValue(T value, - Instant timestamp, - PaneInfo pane) { + public TimestampedWindowedValue(T value, Instant timestamp, PaneInfo pane) { super(value, pane); this.timestamp = checkNotNull(timestamp); } @@ -362,14 +270,11 @@ public Instant getTimestamp() { } /** - * The representation of a WindowedValue where timestamp {@code >} - * MIN and windows == {GlobalWindow}. + * The representation of a WindowedValue where timestamp {@code >} MIN and windows == + * {GlobalWindow}. 
*/ - private static class TimestampedValueInGlobalWindow - extends TimestampedWindowedValue { - public TimestampedValueInGlobalWindow(T value, - Instant timestamp, - PaneInfo pane) { + private static class TimestampedValueInGlobalWindow extends TimestampedWindowedValue { + public TimestampedValueInGlobalWindow(T value, Instant timestamp, PaneInfo pane) { super(value, timestamp, pane); } @@ -386,8 +291,7 @@ public Collection getWindows() { @Override public boolean equals(Object o) { if (o instanceof TimestampedValueInGlobalWindow) { - TimestampedValueInGlobalWindow that = - (TimestampedValueInGlobalWindow) o; + TimestampedValueInGlobalWindow that = (TimestampedValueInGlobalWindow) o; // Compare timestamps first as they are most likely to differ. // Also compare timestamps according to millis-since-epoch because otherwise expensive // comparisons are made on their Chronology objects. @@ -416,17 +320,14 @@ public String toString() { } /** - * The representation of a WindowedValue where timestamp is arbitrary and - * windows == a single non-Global window. + * The representation of a WindowedValue where timestamp is arbitrary and windows == a single + * non-Global window. */ - private static class TimestampedValueInSingleWindow - extends TimestampedWindowedValue { + private static class TimestampedValueInSingleWindow extends TimestampedWindowedValue { private final BoundedWindow window; - public TimestampedValueInSingleWindow(T value, - Instant timestamp, - BoundedWindow window, - PaneInfo pane) { + public TimestampedValueInSingleWindow( + T value, Instant timestamp, BoundedWindow window, PaneInfo pane) { super(value, timestamp, pane); this.window = checkNotNull(window); } @@ -444,8 +345,7 @@ public Collection getWindows() { @Override public boolean equals(Object o) { if (o instanceof TimestampedValueInSingleWindow) { - TimestampedValueInSingleWindow that = - (TimestampedValueInSingleWindow) o; + TimestampedValueInSingleWindow that = (TimestampedValueInSingleWindow) o; // Compare timestamps first as they are most likely to differ. // Also compare timestamps according to millis-since-epoch because otherwise expensive // comparisons are made on their Chronology objects. @@ -475,19 +375,12 @@ public String toString() { } } - /** - * The representation of a WindowedValue, excluding the special - * cases captured above. - */ - private static class TimestampedValueInMultipleWindows - extends TimestampedWindowedValue { + /** The representation of a WindowedValue, excluding the special cases captured above. */ + private static class TimestampedValueInMultipleWindows extends TimestampedWindowedValue { private Collection windows; public TimestampedValueInMultipleWindows( - T value, - Instant timestamp, - Collection windows, - PaneInfo pane) { + T value, Instant timestamp, Collection windows, PaneInfo pane) { super(value, timestamp, pane); this.windows = checkNotNull(windows); } @@ -505,8 +398,7 @@ public Collection getWindows() { @Override public boolean equals(Object o) { if (o instanceof TimestampedValueInMultipleWindows) { - TimestampedValueInMultipleWindows that = - (TimestampedValueInMultipleWindows) o; + TimestampedValueInMultipleWindows that = (TimestampedValueInMultipleWindows) o; // Compare timestamps first as they are most likely to differ. // Also compare timestamps according to millis-since-epoch because otherwise expensive // comparisons are made on their Chronology objects. 
@@ -548,67 +440,54 @@ private void ensureWindowsAreASet() { } } - ///////////////////////////////////////////////////////////////////////////// /** - * Returns the {@code Coder} to use for a {@code WindowedValue}, - * using the given valueCoder and windowCoder. + * Returns the {@code Coder} to use for a {@code WindowedValue}, using the given valueCoder and + * windowCoder. */ public static FullWindowedValueCoder getFullCoder( - Coder valueCoder, - Coder windowCoder) { + Coder valueCoder, Coder windowCoder) { return FullWindowedValueCoder.of(valueCoder, windowCoder); } - /** - * Returns the {@code ValueOnlyCoder} from the given valueCoder. - */ + /** Returns the {@code ValueOnlyCoder} from the given valueCoder. */ public static ValueOnlyWindowedValueCoder getValueOnlyCoder(Coder valueCoder) { return ValueOnlyWindowedValueCoder.of(valueCoder); } - /** - * Abstract class for {@code WindowedValue} coder. - */ - public abstract static class WindowedValueCoder - extends StructuredCoder> { + /** Abstract class for {@code WindowedValue} coder. */ + public abstract static class WindowedValueCoder extends StructuredCoder> { final Coder valueCoder; WindowedValueCoder(Coder valueCoder) { this.valueCoder = checkNotNull(valueCoder); } - /** - * Returns the value coder. - */ + /** Returns the value coder. */ public Coder getValueCoder() { return valueCoder; } /** - * Returns a new {@code WindowedValueCoder} that is a copy of this one, - * but with a different value coder. + * Returns a new {@code WindowedValueCoder} that is a copy of this one, but with a different + * value coder. */ public abstract WindowedValueCoder withValueCoder(Coder valueCoder); } - /** - * Coder for {@code WindowedValue}. - */ + /** Coder for {@code WindowedValue}. */ public static class FullWindowedValueCoder extends WindowedValueCoder { private final Coder windowCoder; // Precompute and cache the coder for a list of windows. 
private final Coder> windowsCoder; public static FullWindowedValueCoder of( - Coder valueCoder, - Coder windowCoder) { + Coder valueCoder, Coder windowCoder) { return new FullWindowedValueCoder<>(valueCoder, windowCoder); } - FullWindowedValueCoder(Coder valueCoder, - Coder windowCoder) { + FullWindowedValueCoder(Coder valueCoder, Coder windowCoder) { super(valueCoder); this.windowCoder = checkNotNull(windowCoder); // It's not possible to statically type-check correct use of the @@ -642,9 +521,7 @@ public void encode(WindowedValue windowedElem, OutputStream outStream) } @Override - public void encode(WindowedValue windowedElem, - OutputStream outStream, - Context context) + public void encode(WindowedValue windowedElem, OutputStream outStream, Context context) throws CoderException, IOException { InstantCoder.of().encode(windowedElem.getTimestamp(), outStream); windowsCoder.encode(windowedElem.getWindows(), outStream); @@ -661,11 +538,13 @@ public WindowedValue decode(InputStream inStream) throws CoderException, IOEx public WindowedValue decode(InputStream inStream, Context context) throws CoderException, IOException { Instant timestamp = InstantCoder.of().decode(inStream); - Collection windows = - windowsCoder.decode(inStream); + Collection windows = windowsCoder.decode(inStream); PaneInfo pane = PaneInfoCoder.INSTANCE.decode(inStream); T value = valueCoder.decode(inStream, context); - return WindowedValue.of(value, timestamp, windows, pane); + + // Because there are some remaining (incorrect) uses of WindowedValue with no windows, + // we call this deprecated no-validation path when decoding + return WindowedValue.createWithoutValidation(value, timestamp, windows, pane); } @Override @@ -677,8 +556,8 @@ public void verifyDeterministic() throws NonDeterministicException { } @Override - public void registerByteSizeObserver(WindowedValue value, - ElementByteSizeObserver observer) throws Exception { + public void registerByteSizeObserver(WindowedValue value, ElementByteSizeObserver observer) + throws Exception { InstantCoder.of().registerByteSizeObserver(value.getTimestamp(), observer); windowsCoder.registerByteSizeObserver(value.getWindows(), observer); PaneInfoCoder.INSTANCE.registerByteSizeObserver(value.getPane(), observer); @@ -688,8 +567,8 @@ public void registerByteSizeObserver(WindowedValue value, /** * {@inheritDoc}. * - * @return a singleton list containing the {@code valueCoder} of this - * {@link FullWindowedValueCoder}. + * @return a singleton list containing the {@code valueCoder} of this {@link + * FullWindowedValueCoder}. */ @Override public List> getCoderArguments() { @@ -707,12 +586,11 @@ public List> getComponents() { /** * Coder for {@code WindowedValue}. * - *

A {@code ValueOnlyWindowedValueCoder} only encodes and decodes the value. It drops - * timestamp and windows for encoding, and uses defaults timestamp, and windows for decoding. + *

A {@code ValueOnlyWindowedValueCoder} only encodes and decodes the value. It drops timestamp + * and windows for encoding, and uses defaults timestamp, and windows for decoding. */ public static class ValueOnlyWindowedValueCoder extends WindowedValueCoder { - public static ValueOnlyWindowedValueCoder of( - Coder valueCoder) { + public static ValueOnlyWindowedValueCoder of(Coder valueCoder) { return new ValueOnlyWindowedValueCoder<>(valueCoder); } @@ -752,14 +630,11 @@ public WindowedValue decode(InputStream inStream, Context context) @Override public void verifyDeterministic() throws NonDeterministicException { verifyDeterministic( - this, - "ValueOnlyWindowedValueCoder requires a deterministic valueCoder", - valueCoder); + this, "ValueOnlyWindowedValueCoder requires a deterministic valueCoder", valueCoder); } @Override - public void registerByteSizeObserver( - WindowedValue value, ElementByteSizeObserver observer) + public void registerByteSizeObserver(WindowedValue value, ElementByteSizeObserver observer) throws Exception { valueCoder.registerByteSizeObserver(value.getValue(), observer); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ZipFiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ZipFiles.java index 6836a4bbecbb..e6215e29c555 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ZipFiles.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ZipFiles.java @@ -136,12 +136,10 @@ static void unzipFile( } } else { File parentFile = targetFile.getParentFile(); - if (!parentFile.isDirectory()) { - if (!parentFile.mkdirs()) { - throw new IOException( - "Failed to create directory: " - + parentFile.getAbsolutePath()); - } + if (!parentFile.isDirectory() && !parentFile.mkdirs()) { + throw new IOException( + "Failed to create directory: " + + parentFile.getAbsolutePath()); } // Write the file to the destination. asByteSource(zipFileObj, entry).copyTo(Files.asByteSink(targetFile)); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/BeamRecordType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/BeamRecordType.java deleted file mode 100644 index 620361c52ab6..000000000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/BeamRecordType.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.values; - -import com.google.common.collect.ImmutableList; -import java.io.Serializable; -import java.util.List; -import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.coders.BeamRecordCoder; -import org.apache.beam.sdk.coders.Coder; - -/** - * {@link BeamRecordType} describes the fields in {@link BeamRecord}, extra checking can be added - * by overwriting {@link BeamRecordType#validateValueType(int, Object)}. - */ -@Experimental -public class BeamRecordType implements Serializable{ - private List fieldNames; - private List fieldCoders; - - /** - * Create a {@link BeamRecordType} with a name and Coder for each field. - */ - public BeamRecordType(List fieldNames, List fieldCoders) { - if (fieldNames.size() != fieldCoders.size()) { - throw new IllegalStateException( - "the size of fieldNames and fieldCoders need to be the same."); - } - this.fieldNames = fieldNames; - this.fieldCoders = fieldCoders; - } - - /** - * Validate input fieldValue for a field. - * @throws IllegalArgumentException throw exception when the validation fails. - */ - public void validateValueType(int index, Object fieldValue) - throws IllegalArgumentException{ - //do nothing by default. - } - - /** - * Return the coder for {@link BeamRecord}, which wraps {@link #fieldCoders} for each field. - */ - public BeamRecordCoder getRecordCoder(){ - return BeamRecordCoder.of(this, fieldCoders); - } - - /** - * Returns an immutable list of field names. - */ - public List getFieldNames(){ - return ImmutableList.copyOf(fieldNames); - } - - /** - * Return the name of field by index. - */ - public String getFieldNameByIndex(int index){ - return fieldNames.get(index); - } - - /** - * Find the index of a given field. - */ - public int findIndexOfField(String fieldName){ - return fieldNames.indexOf(fieldName); - } - - /** - * Return the count of fields. 
- */ - public int getFieldCount(){ - return fieldNames.size(); - } - - @Override - public String toString() { - return "BeamRecordType [fieldsName=" + fieldNames + "]"; - } -} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java index 3e86935ab07a..b8b08771948f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java @@ -156,7 +156,7 @@ public Map, PCollection> getAll() { * @return the output of the applied {@link PTransform} */ public OutputT apply( - PTransform t) { + PTransform t) { return Pipeline.applyTransform(this, t); } @@ -169,7 +169,7 @@ public OutputT apply( * @return the output of the applied {@link PTransform} */ public OutputT apply( - String name, PTransform t) { + String name, PTransform t) { return Pipeline.applyTransform(name, this, t); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/BeamRecord.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java similarity index 57% rename from sdks/java/core/src/main/java/org/apache/beam/sdk/values/BeamRecord.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index c79d1f8c1d19..c638eb1fd463 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/BeamRecord.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -17,6 +17,9 @@ */ package org.apache.beam.sdk.values; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; import java.io.Serializable; import java.math.BigDecimal; import java.util.ArrayList; @@ -25,69 +28,63 @@ import java.util.Date; import java.util.GregorianCalendar; import java.util.List; +import java.util.stream.Collector; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.coders.BeamRecordCoder; /** - * {@link BeamRecord} is an immutable tuple-like type to represent one element in a - * {@link PCollection}. The fields are described with a {@link BeamRecordType}. - * - *

By default, {@link BeamRecordType} only contains the name for each field. It - * can be extended to support more sophisticated validation by overwriting - * {@link BeamRecordType#validateValueType(int, Object)}. + * {@link Row} is an immutable tuple-like type to represent one element in a + * {@link PCollection}. The fields are described with a {@link RowType}. * - *

A Coder {@link BeamRecordCoder} is provided, which wraps the Coder for each data field. + *
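As a rough, self-contained sketch (not part of this patch) of how the new Row and RowType APIs fit together; the "name"/"age" fields, the sample values, and the choice of StringUtf8Coder/VarIntCoder below are invented for illustration only:

import java.util.Arrays;
import java.util.stream.Stream;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.RowType;

public class RowUsageSketch {
  public static void main(String[] args) {
    // Describe two fields and the coder used for each one.
    RowType personType =
        RowType.fromNamesAndCoders(
            Arrays.asList("name", "age"),
            Arrays.<Coder>asList(StringUtf8Coder.of(), VarIntCoder.of()));

    // Build a Row through the builder introduced in this change.
    Row person = Row.withRowType(personType).addValues("jane", 42).build();
    String name = person.getString("name"); // typed accessor by field name
    Integer age = person.getInteger(1);     // typed accessor by field index

    // Or collect a stream of values into a Row with the toRow collector.
    Row collected = Stream.of("john", 27).collect(Row.toRow(personType));

    System.out.println(name + " " + age + " " + collected.getValues());
  }
}

Builder#build() would throw an IllegalArgumentException here if the number of values did not match personType.getFieldCount(), matching the check in the builder code that follows.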

{@link RowType} contains the names for each field and the coder for the whole + * record, {see @link RowType#getRowCoder()}. */ @Experimental -public class BeamRecord implements Serializable { - //immutable list of field values. - private List dataValues; - private BeamRecordType dataType; +@AutoValue +public abstract class Row implements Serializable { /** - * Creates a BeamRecord. - * @param dataType type of the record - * @param rawDataValues values of the record, record's size must match size of - * the {@code BeamRecordType}, or can be null, if it is null - * then every field is null. + * Creates a {@link Row} from the list of values and {@link #getRowType()}. */ - public BeamRecord(BeamRecordType dataType, List rawDataValues) { - if (dataType.getFieldNames().size() != rawDataValues.size()) { - throw new IllegalArgumentException( - "Field count in BeamRecordType(" + dataType.getFieldNames().size() - + ") and rawDataValues(" + rawDataValues.size() + ") must match!"); - } - - this.dataType = dataType; - this.dataValues = new ArrayList<>(dataType.getFieldCount()); + public static Collector, Row> toRow( + RowType rowType) { - for (int idx = 0; idx < dataType.getFieldCount(); ++idx) { - dataValues.add(null); - } - - for (int idx = 0; idx < dataType.getFieldCount(); ++idx) { - addField(idx, rawDataValues.get(idx)); - } + return Collector.of( + () -> new ArrayList<>(rowType.getFieldCount()), + List::add, + (left, right) -> { + left.addAll(right); + return left; + }, + values -> Row.withRowType(rowType).addValues(values).build()); } /** - * see {@link #BeamRecord(BeamRecordType, List)}. + * Creates a new record filled with nulls. */ - public BeamRecord(BeamRecordType dataType, Object... rawdataValues) { - this(dataType, Arrays.asList(rawdataValues)); + public static Row nullRow(RowType rowType) { + return + Row + .withRowType(rowType) + .addValues(Collections.nCopies(rowType.getFieldCount(), null)) + .build(); } - private void addField(int index, Object fieldValue) { - dataType.validateValueType(index, fieldValue); - dataValues.set(index, fieldValue); + /** + * Get value by field name, {@link ClassCastException} is thrown + * if type doesn't match. + */ + public T getValue(String fieldName) { + return getValue(getRowType().indexOf(fieldName)); } /** - * Get value by field name. + * Get value by field index, {@link ClassCastException} is thrown + * if type doesn't match. */ - public Object getFieldValue(String fieldName) { - return getFieldValue(dataType.getFieldNames().indexOf(fieldName)); + @Nullable + public T getValue(int fieldIdx) { + return (T) getValues().get(fieldIdx); } /** @@ -95,7 +92,7 @@ public Object getFieldValue(String fieldName) { * if type doesn't match. */ public Byte getByte(String fieldName) { - return (Byte) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -103,7 +100,7 @@ public Byte getByte(String fieldName) { * if type doesn't match. */ public Short getShort(String fieldName) { - return (Short) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -111,7 +108,7 @@ public Short getShort(String fieldName) { * if type doesn't match. */ public Integer getInteger(String fieldName) { - return (Integer) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -119,7 +116,7 @@ public Integer getInteger(String fieldName) { * if type doesn't match. 
*/ public Float getFloat(String fieldName) { - return (Float) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -127,7 +124,7 @@ public Float getFloat(String fieldName) { * if type doesn't match. */ public Double getDouble(String fieldName) { - return (Double) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -135,7 +132,7 @@ public Double getDouble(String fieldName) { * if type doesn't match. */ public Long getLong(String fieldName) { - return (Long) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -143,7 +140,7 @@ public Long getLong(String fieldName) { * if type doesn't match. */ public String getString(String fieldName) { - return (String) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -151,7 +148,7 @@ public String getString(String fieldName) { * if type doesn't match. */ public Date getDate(String fieldName) { - return (Date) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -159,7 +156,7 @@ public Date getDate(String fieldName) { * if type doesn't match. */ public GregorianCalendar getGregorianCalendar(String fieldName) { - return (GregorianCalendar) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -167,7 +164,7 @@ public GregorianCalendar getGregorianCalendar(String fieldName) { * if type doesn't match. */ public BigDecimal getBigDecimal(String fieldName) { - return (BigDecimal) getFieldValue(fieldName); + return getValue(fieldName); } /** @@ -175,13 +172,7 @@ public BigDecimal getBigDecimal(String fieldName) { * if type doesn't match. */ public Boolean getBoolean(String fieldName) { - return (Boolean) getFieldValue(fieldName); - } - - /** Get value by field index. */ - @Nullable - public Object getFieldValue(int fieldIdx) { - return dataValues.get(fieldIdx); + return getValue(fieldName); } /** @@ -189,7 +180,7 @@ public Object getFieldValue(int fieldIdx) { * if type doesn't match. */ public Byte getByte(int idx) { - return (Byte) getFieldValue(idx); + return getValue(idx); } /** @@ -197,7 +188,7 @@ public Byte getByte(int idx) { * if type doesn't match. */ public Short getShort(int idx) { - return (Short) getFieldValue(idx); + return getValue(idx); } /** @@ -205,7 +196,7 @@ public Short getShort(int idx) { * if type doesn't match. */ public Integer getInteger(int idx) { - return (Integer) getFieldValue(idx); + return getValue(idx); } /** @@ -213,7 +204,7 @@ public Integer getInteger(int idx) { * if type doesn't match. */ public Float getFloat(int idx) { - return (Float) getFieldValue(idx); + return getValue(idx); } /** @@ -221,7 +212,7 @@ public Float getFloat(int idx) { * if type doesn't match. */ public Double getDouble(int idx) { - return (Double) getFieldValue(idx); + return getValue(idx); } /** @@ -229,7 +220,7 @@ public Double getDouble(int idx) { * if type doesn't match. */ public Long getLong(int idx) { - return (Long) getFieldValue(idx); + return getValue(idx); } /** @@ -237,7 +228,7 @@ public Long getLong(int idx) { * if type doesn't match. */ public String getString(int idx) { - return (String) getFieldValue(idx); + return getValue(idx); } /** @@ -245,7 +236,7 @@ public String getString(int idx) { * if type doesn't match. */ public Date getDate(int idx) { - return (Date) getFieldValue(idx); + return getValue(idx); } /** @@ -253,7 +244,7 @@ public Date getDate(int idx) { * if type doesn't match. 
*/ public GregorianCalendar getGregorianCalendar(int idx) { - return (GregorianCalendar) getFieldValue(idx); + return getValue(idx); } /** @@ -261,7 +252,7 @@ public GregorianCalendar getGregorianCalendar(int idx) { * if type doesn't match. */ public BigDecimal getBigDecimal(int idx) { - return (BigDecimal) getFieldValue(idx); + return getValue(idx); } /** @@ -269,51 +260,66 @@ public BigDecimal getBigDecimal(int idx) { * if type doesn't match. */ public Boolean getBoolean(int idx) { - return (Boolean) getFieldValue(idx); + return getValue(idx); } /** * Return the size of data fields. */ public int getFieldCount() { - return dataValues.size(); + return getValues().size(); } /** * Return the list of data values. */ - public List getDataValues() { - return Collections.unmodifiableList(dataValues); - } + public abstract List getValues(); /** - * Return {@link BeamRecordType} which describes the fields. + * Return {@link RowType} which describes the fields. */ - public BeamRecordType getDataType() { - return dataType; - } + public abstract RowType getRowType(); - @Override - public String toString() { - return "BeamRecord [dataValues=" + dataValues + ", dataType=" + dataType + "]"; + /** + * Creates a record builder with specified {@link #getRowType()}. + * {@link Builder#build()} will throw an {@link IllegalArgumentException} if number of fields + * in {@link #getRowType()} does not match the number of fields specified. + */ + public static Builder withRowType(RowType rowType) { + return + new AutoValue_Row.Builder(rowType); } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; + /** + * Builder for {@link Row}. + */ + public static class Builder { + private List values = new ArrayList<>(); + private RowType type; + + Builder(RowType type) { + this.type = type; } - if (obj == null) { - return false; + + public Builder addValues(List values) { + this.values.addAll(values); + return this; } - if (getClass() != obj.getClass()) { - return false; + + public Builder addValues(Object ... values) { + return addValues(Arrays.asList(values)); } - BeamRecord other = (BeamRecord) obj; - return toString().equals(other.toString()); - } - @Override public int hashCode() { - return 31 * getDataType().hashCode() + getDataValues().hashCode(); + public Row build() { + checkNotNull(type); + + if (type.getFieldCount() != values.size()) { + throw new IllegalArgumentException( + String.format( + "Field count in RowType (%s) and values (%s) must match", + type.fieldNames(), values)); + } + return new AutoValue_Row(values, type); + } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowType.java new file mode 100644 index 000000000000..6189b05ef13c --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowType.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collector; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.RowCoder; + +/** + * {@link RowType} describes the fields in {@link Row}. + */ +@Experimental +@AutoValue +public abstract class RowType implements Serializable{ + abstract List fieldNames(); + abstract List fieldCoders(); + + /** + * Field of a row. + * + *

Contains field name and its coder. + */ + @AutoValue + public abstract static class Field { + abstract String name(); + abstract Coder coder(); + + public static Field of(String name, Coder coder) { + return new AutoValue_RowType_Field(name, coder); + } + } + + /** + * Collects a stream of {@link Field}s into a {@link RowType}. + */ + public static Collector, RowType> toRowType() { + return Collector.of( + ArrayList::new, + List::add, + (left, right) -> { + left.addAll(right); + return left; + }, + RowType::fromFields); + } + + private static RowType fromFields(List fields) { + ImmutableList.Builder names = ImmutableList.builder(); + ImmutableList.Builder coders = ImmutableList.builder(); + + for (Field field : fields) { + names.add(field.name()); + coders.add(field.coder()); + } + + return fromNamesAndCoders(names.build(), coders.build()); + } + + /** + * Creates a new {@link Field} with specified name and coder. + */ + public static Field newField(String name, Coder coder) { + return Field.of(name, coder); + } + + public static RowType fromNamesAndCoders( + List fieldNames, + List fieldCoders) { + + if (fieldNames.size() != fieldCoders.size()) { + throw new IllegalStateException( + "the size of fieldNames and fieldCoders need to be the same."); + } + + return new AutoValue_RowType(fieldNames, fieldCoders); + } + + /** + * Return the coder for {@link Row}, which wraps {@link #fieldCoders} for each field. + */ + public RowCoder getRowCoder(){ + return RowCoder.of(this, fieldCoders()); + } + + /** + * Return the field coder for {@code index}. + */ + public Coder getFieldCoder(int index){ + return fieldCoders().get(index); + } + + /** + * Returns an immutable list of field names. + */ + public List getFieldNames(){ + return ImmutableList.copyOf(fieldNames()); + } + + /** + * Return the name of field by index. + */ + public String getFieldName(int index){ + return fieldNames().get(index); + } + + /** + * Find the index of a given field. + */ + public int indexOf(String fieldName){ + return fieldNames().indexOf(fieldName); + } + + /** + * Return the count of fields. + */ + public int getFieldCount(){ + return fieldNames().size(); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ByteBuddyUtils.java new file mode 100644 index 000000000000..8bc75282319d --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ByteBuddyUtils.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.sdk.values.reflect; + +import static net.bytebuddy.matcher.ElementMatchers.named; + +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.DynamicType; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.FixedValue; +import net.bytebuddy.implementation.Implementation; + +/** + * Utilities to help with code generation for implementing {@link FieldValueGetter}s. + */ +class ByteBuddyUtils { + + /** + * Creates an instance of the {@link DynamicType.Builder} + * to start implementation of the {@link FieldValueGetter}. + */ + static DynamicType.Builder subclassGetterInterface( + ByteBuddy byteBuddy, Class clazz) { + + TypeDescription.Generic getterGenericType = + TypeDescription.Generic.Builder.parameterizedType(FieldValueGetter.class, clazz).build(); + + return (DynamicType.Builder) byteBuddy.subclass(getterGenericType); + } + + /** + * Implements {@link FieldValueGetter#name()}. + */ + static DynamicType.Builder implementNameGetter( + DynamicType.Builder getterClassBuilder, + String fieldName) { + + return getterClassBuilder + .method(named("name")) + .intercept(FixedValue.reference(fieldName)); + } + + /** + * Implements {@link FieldValueGetter#type()}. + */ + static DynamicType.Builder implementTypeGetter( + DynamicType.Builder getterClassBuilder, + Class fieldType) { + + return getterClassBuilder + .method(named("type")) + .intercept(FixedValue.reference(fieldType)); + } + + /** + * Implements {@link FieldValueGetter#get(Object)} for getting public fields from pojos. + */ + static DynamicType.Builder implementValueGetter( + DynamicType.Builder getterClassBuilder, + Implementation fieldAccessImplementation) { + + return getterClassBuilder + .method(named("get")) + .intercept(fieldAccessImplementation); + } + + /** + * Finish the {@link FieldValueGetter} implementation and return its new instance. + * + *

Wraps underlying {@link InstantiationException} and {@link IllegalAccessException} + * into {@link RuntimeException}. + * + *

Does no validations of whether everything has been implemented correctly. + */ + static FieldValueGetter makeNewGetterInstance( + String fieldName, + DynamicType.Builder getterBuilder) { + + try { + return getterBuilder + .make() + .load( + ByteBuddyUtils.class.getClassLoader(), + ClassLoadingStrategy.Default.INJECTION) + .getLoaded() + .newInstance(); + } catch (InstantiationException | IllegalAccessException e) { + throw new RuntimeException( + "Unable to generate a getter for field '" + fieldName + "'.", e); + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/DefaultRowTypeFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/DefaultRowTypeFactory.java new file mode 100644 index 000000000000..355500d64618 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/DefaultRowTypeFactory.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.values.RowType; + +/** + * A default implementation of the {@link RowTypeFactory} interface. The purpose of + * the factory is to create a row types given a list of getters. + * + *

Row type is represented by {@link RowType}, which is essentially a + * list of (field name, coder) pairs. + * + *

Getters (e.g. pojo field getters) are represented by {@link FieldValueGetter} interface, + * which exposes the field's name (see {@link FieldValueGetter#name()}) + * and java type (see {@link FieldValueGetter#type()}). + * + *

This factory then uses the default {@link CoderRegistry} to map java types of + * the getters to coders, and then creates an instance of {@link RowType} using those coders. + * + *
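A small sketch of the java-type-to-coder lookup described here, using only the public CoderRegistry API; the concrete coders named in the comments are what the default registry typically resolves to, not a guarantee:

import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;

public class CoderLookupSketch {
  public static void main(String[] args) throws CannotProvideCoderException {
    CoderRegistry registry = CoderRegistry.createDefault();

    // One lookup per getter return type, mirroring the per-getter loop in this factory.
    Coder<String> nameCoder = registry.getCoder(String.class);  // typically StringUtf8Coder
    Coder<Integer> ageCoder = registry.getCoder(Integer.class); // typically VarIntCoder

    System.out.println(nameCoder + " / " + ageCoder);
    // A type with no registered coder makes getCoder throw CannotProvideCoderException instead.
  }
}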

If there is no coder in the default {@link CoderRegistry} for the java type of the getter, + * then the factory throws {@link UnsupportedOperationException}. + * + *

This is the default factory implementation used in {@link RowFactory}. + * It should work for regular java pipelines where coder mapping via default {@link CoderRegistry} + * is enough. + * + *

In other cases, when mapping requires extra logic, another implementation of the + * {@link RowTypeFactory} should be used instead of this class. + * + *

For example, Beam SQL uses {@link java.sql.Types} as an intermediate type representation + * instead of using plain java types. The mapping between {@link java.sql.Types} and coders + * is not available in the default {@link CoderRegistry}, so custom SQL-specific implementation of + * {@link RowTypeFactory} is used with SQL infrastructure instead of this class. + * See {@code SqlRecordTypeFactory}. + */ +class DefaultRowTypeFactory implements RowTypeFactory { + + private static final CoderRegistry CODER_REGISTRY = CoderRegistry.createDefault(); + + /** + * Uses {@link FieldValueGetter#name()} as field names. + * Uses {@link CoderRegistry#createDefault()} to get coders for {@link FieldValueGetter#type()}. + */ + @Override + public RowType createRowType(Iterable fieldValueGetters) { + return + RowType + .fromNamesAndCoders( + getFieldNames(fieldValueGetters), + getFieldCoders(fieldValueGetters)); + } + + private static List getFieldNames(Iterable fieldValueGetters) { + ImmutableList.Builder names = ImmutableList.builder(); + + for (FieldValueGetter fieldValueGetter : fieldValueGetters) { + names.add(fieldValueGetter.name()); + } + + return names.build(); + } + + private static List getFieldCoders(Iterable fieldValueGetters) { + ImmutableList.Builder coders = ImmutableList.builder(); + + for (FieldValueGetter fieldValueGetter : fieldValueGetters) { + try { + coders.add(CODER_REGISTRY.getCoder(fieldValueGetter.type())); + } catch (CannotProvideCoderException e) { + throw new UnsupportedOperationException( + "Fields of type " + + fieldValueGetter.type().getSimpleName() + " are not supported yet", e); + } + } + + return coders.build(); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/FieldValueGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/FieldValueGetter.java new file mode 100644 index 000000000000..4337012d279a --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/FieldValueGetter.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import org.apache.beam.sdk.annotations.Internal; + +/** + * For internal use only; no backwards-compatibility guarantees. + * + *

An interface to access a field of a class. + * + *
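To make the contract concrete, here is a hypothetical hand-written implementation for an imaginary Person pojo (raw types are used to match the interface as shown; in practice implementations are generated rather than written by hand):

import org.apache.beam.sdk.values.reflect.FieldValueGetter;

public class PersonNameGetterSketch {
  // Imaginary pojo used only for this illustration.
  public static class Person {
    private final String name;
    public Person(String name) { this.name = name; }
    public String getName() { return name; }
  }

  // What a getter for Person#getName boils down to: a name, a type, and an accessor.
  public static class PersonNameGetter implements FieldValueGetter {
    @Override
    public Object get(Object person) {
      return ((Person) person).getName();
    }

    @Override
    public String name() {
      return "name";
    }

    @Override
    public Class type() {
      return String.class;
    }
  }

  public static void main(String[] args) {
    FieldValueGetter getter = new PersonNameGetter();
    System.out.println(getter.name() + " = " + getter.get(new Person("jane")));
  }
}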

Implementations of this interface are generated at runtime by {@link RowFactory} + * to map pojo fields to BeamRecord fields. + */ +@Internal +public interface FieldValueGetter { + Object get(T object); + String name(); + Class type(); +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/GeneratedGetterFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/GeneratedGetterFactory.java new file mode 100644 index 000000000000..4295dcaf34cf --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/GeneratedGetterFactory.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static net.bytebuddy.implementation.MethodCall.invoke; +import static org.apache.beam.sdk.values.reflect.ByteBuddyUtils.implementNameGetter; +import static org.apache.beam.sdk.values.reflect.ByteBuddyUtils.implementTypeGetter; +import static org.apache.beam.sdk.values.reflect.ByteBuddyUtils.implementValueGetter; +import static org.apache.beam.sdk.values.reflect.ByteBuddyUtils.makeNewGetterInstance; +import static org.apache.beam.sdk.values.reflect.ByteBuddyUtils.subclassGetterInterface; +import static org.apache.beam.sdk.values.reflect.ReflectionUtils.getPublicGetters; +import static org.apache.beam.sdk.values.reflect.ReflectionUtils.tryStripGetPrefix; + +import com.google.common.collect.ImmutableList; +import java.lang.reflect.Method; +import java.util.List; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.dynamic.DynamicType; + +/** + * Implements and creates an instance of the {@link FieldValueGetter} for each public + * getter method of the pojo class. + * + *

Generated {@link FieldValueGetter#get(Object)} calls the corresponding + * getter method of the pojo. + * + *

Generated {@link FieldValueGetter#name()} strips the 'get' from the getter method name. + * + *

For example, if a pojo looks like + *

{@code
+ * public class PojoClass {
+ *   public String getPojoNameField() { ... }
+ * }
+ * }
+ * + *

Then, class name aside, the generated {@link FieldValueGetter} will look like: + *

{@code
+ * public class FieldValueGetterGenerated implements FieldValueGetter {
+ *   public String name() {
+ *     return "pojoNameField";
+ *   }
+ *
+ *   public Class type() {
+ *     return String.class;
+ *   }
+ *
+ *   public Object get(PojoClass pojo) {
+ *     return pojo.getPojoNameField();
+ *   }
+ * }
+ * }
+ * + *

ByteBuddy is used to generate the code. Class naming is left to ByteBuddy's defaults. + * + *
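A minimal, standalone ByteBuddy sketch of the same subclass-and-intercept pattern this factory relies on; the NamedGetter interface and the fixed "pojoNameField" value are made up for the example:

import static net.bytebuddy.matcher.ElementMatchers.named;

import net.bytebuddy.ByteBuddy;
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
import net.bytebuddy.implementation.FixedValue;

public class ByteBuddyGetterSketch {
  // Stand-in for the name() part of FieldValueGetter.
  public interface NamedGetter {
    String name();
  }

  public static void main(String[] args) throws Exception {
    NamedGetter getter =
        new ByteBuddy()
            .subclass(NamedGetter.class)
            .method(named("name"))
            // Same pattern as implementNameGetter: intercept the method with a fixed value.
            .intercept(FixedValue.reference("pojoNameField"))
            .make()
            .load(ByteBuddyGetterSketch.class.getClassLoader(), ClassLoadingStrategy.Default.INJECTION)
            .getLoaded()
            .newInstance();

    System.out.println(getter.name()); // "pojoNameField"
  }
}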

Class is injected into ByteBuddyUtils.class.getClassLoader(). + * See {@link ByteBuddyUtils#makeNewGetterInstance(String, DynamicType.Builder)} + * and ByteBuddy documentation for details. + */ +class GeneratedGetterFactory implements GetterFactory { + + private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); + + /** + * Returns the list of the getters, one for each public getter of the pojoClass. + */ + @Override + public List generateGetters(Class pojoClass) { + ImmutableList.Builder getters = ImmutableList.builder(); + + List getterMethods = getPublicGetters(pojoClass); + + for (Method getterMethod : getterMethods) { + getters.add(createFieldGetterInstance(pojoClass, getterMethod)); + } + + return getters.build(); + } + + private static FieldValueGetter createFieldGetterInstance(Class clazz, Method getterMethod) { + + DynamicType.Builder getterBuilder = + subclassGetterInterface(BYTE_BUDDY, clazz); + + getterBuilder = implementNameGetter(getterBuilder, tryStripGetPrefix(getterMethod)); + getterBuilder = implementTypeGetter(getterBuilder, getterMethod.getReturnType()); + getterBuilder = implementValueGetter(getterBuilder, invoke(getterMethod).onArgument(0)); + + return makeNewGetterInstance(getterMethod.getName(), getterBuilder); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/GetterFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/GetterFactory.java new file mode 100644 index 000000000000..f4e104304164 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/GetterFactory.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values.reflect; + +import java.util.List; + +/** + * Interface for factories generating getter wrappers. + * See {@link GeneratedGetterFactory} or {@link ReflectionGetterFactory}. + */ +interface GetterFactory { + + /** + * Generates getters for {@code clazz}. + */ + List generateGetters(Class clazz); +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionGetter.java new file mode 100644 index 000000000000..374a460a6c33 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionGetter.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.apache.beam.sdk.values.reflect.ReflectionUtils.tryStripGetPrefix; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** + * Implementation of {@link FieldValueGetter} backed by relfection-based getter invocation, + * as opposed to a code-generated version produced by {@link GeneratedGetterFactory}. + */ +class ReflectionGetter implements FieldValueGetter { + private String name; + private Class type; + private Method getter; + + ReflectionGetter(Method getter) { + this.getter = getter; + this.name = tryStripGetPrefix(getter); + this.type = getter.getReturnType(); + } + + @Override + public Object get(Object object) { + try { + return getter.invoke(object); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalArgumentException("Unable to invoke " + getter, e); + } + } + + @Override + public String name() { + return name; + } + + @Override + public Class type() { + return type; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionGetterFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionGetterFactory.java new file mode 100644 index 000000000000..20cf6740d244 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionGetterFactory.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.apache.beam.sdk.values.reflect.ReflectionUtils.getPublicGetters; + +import com.google.common.collect.ImmutableList; +import java.lang.reflect.Method; +import java.util.List; +import org.apache.beam.sdk.values.RowType; + +/** + * Factory to wrap calls to pojo getters into instances of {@link FieldValueGetter} + * using reflection. + * + *

Returns instances of {@link FieldValueGetter}s backed by getter methods of a pojo class. + * Getters are invoked using {@link java.lang.reflect.Method#invoke(Object, Object...)} + * from {@link FieldValueGetter#get(Object)}. + * + *
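The reflective call path described here is essentially the following; Person and getName are placeholders for whatever pojo is being wrapped:

import java.lang.reflect.Method;

public class ReflectiveGetterSketch {
  public static class Person {
    public String getName() { return "jane"; }
  }

  public static void main(String[] args) throws Exception {
    // ReflectionGetter holds on to the Method and replays this call for each field access.
    Method getName = Person.class.getMethod("getName");
    Object value = getName.invoke(new Person());
    System.out.println(value); // "jane"
  }
}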

Caching is not handled at this level, {@link RowFactory} should cache getters + * for each {@link RowType}. + */ +class ReflectionGetterFactory implements GetterFactory { + + /** + * Returns a list of {@link FieldValueGetter}s. + * One for each public getter of the {@code pojoClass}. + */ + @Override + public List generateGetters(Class pojoClass) { + ImmutableList.Builder getters = ImmutableList.builder(); + + for (Method getterMethod : getPublicGetters(pojoClass)) { + getters.add(new ReflectionGetter(getterMethod)); + } + + return getters.build(); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionUtils.java new file mode 100644 index 000000000000..0107f9a75df1 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/ReflectionUtils.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.List; + +/** + * Helpers to get the information about a class. + */ +class ReflectionUtils { + + /** + * Returns a list of non-void public methods with names prefixed with 'get'. + */ + static List getPublicGetters(Class clazz) { + List getters = new ArrayList<>(); + for (Method method : clazz.getDeclaredMethods()) { + if (isGetter(method) && isPublic(method)) { + getters.add(method); + } + } + + return getters; + } + + /** + * Tries to remove a 'get' prefix from a method name. + * + *

Converts method names like 'getSomeField' into 'someField' if they start with 'get'. + * Returns names unchanged if they don't start with 'get'. + */ + static String tryStripGetPrefix(Method method) { + String name = method.getName(); + + if (name.length() <= 3 || !name.startsWith("get")) { + return name; + } + + String firstLetter = name.substring(3, 4).toLowerCase(); + + return (name.length() == 4) + ? firstLetter + : (firstLetter + name.substring(4, name.length())); + } + + private static boolean isGetter(Method method) { + return method.getName().startsWith("get") + && !Void.TYPE.equals(method.getReturnType()); + } + + private static boolean isPublic(Method method) { + return Modifier.isPublic(method.getModifiers()); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowFactory.java new file mode 100644 index 000000000000..b684d6182447 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowFactory.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import com.google.common.collect.ImmutableList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; + +/** + * For internal use only; no backwards-compatibility guarantees. + * + *

Generates the code to create {@link RowType}s and {@link Row}s based on pojos. + * + *

Generated record types are cached in the instance of this factory. + * + *

At the moment a single pojo class corresponds to a single {@link RowType}. + * + *
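As a usage sketch under those constraints; Person, getName and getAge are invented for the example, and the field order follows reflection order, so it is not guaranteed:

import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.reflect.RowFactory;

public class RowFactorySketch {
  // Imaginary pojo; every public non-void getX() method becomes a field.
  public static class Person {
    public String getName() { return "jane"; }
    public Integer getAge() { return 42; }
  }

  public static void main(String[] args) {
    RowFactory factory = RowFactory.createDefault();

    // The first call per pojo class generates the getters and the RowType; both are cached.
    Row row = factory.create(new Person());

    System.out.println(row.getRowType().getFieldNames()); // e.g. [name, age]
    System.out.println(row.getValues());                  // e.g. [jane, 42]
  }
}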

Supported pojo getter types depend on types supported by the {@link RowTypeFactory}. + * See {@link DefaultRowTypeFactory} for default implementation. + */ +@Internal +public class RowFactory { + + private RowTypeFactory rowTypeFactory; + private final Map rowTypesCache = new HashMap<>(); + private final List getterFactories; + + /** + * Creates an instance of {@link RowFactory} using {@link DefaultRowTypeFactory} + * and {@link GeneratedGetterFactory}. + */ + public static RowFactory createDefault() { + return new RowFactory(); + } + + /** + * Create new instance based on default record type factory. + * + *

Use this to create instances of {@link RowType}. + */ + private RowFactory() { + this(new DefaultRowTypeFactory(), new GeneratedGetterFactory()); + } + + /** + * Create new instance with custom record type factory. + * + *

For example this can be used to create BeamRecordSqlTypes instead of {@link RowType}. + */ + RowFactory(RowTypeFactory rowTypeFactory, GetterFactory ... getterFactories) { + this.rowTypeFactory = rowTypeFactory; + this.getterFactories = Arrays.asList(getterFactories); + } + + /** + * Create a {@link Row} of the pojo. + * + *

This implementation copies the return values of the pojo getters into + * the record fields on creation. + * + *

Currently all public getters are used to populate the record type and instance. + * + *

Field names for getters are stripped of the 'get' prefix. + * For example record field 'name' will be generated for 'getName()' pojo method. + */ + public Row create(Object pojo) { + RowTypeGetters getters = getRecordType(pojo.getClass()); + List fieldValues = getFieldValues(getters.valueGetters(), pojo); + return Row.withRowType(getters.rowType()).addValues(fieldValues).build(); + } + + private synchronized RowTypeGetters getRecordType(Class pojoClass) { + if (rowTypesCache.containsKey(pojoClass)) { + return rowTypesCache.get(pojoClass); + } + + List fieldValueGetters = createGetters(pojoClass); + RowType rowType = rowTypeFactory.createRowType(fieldValueGetters); + rowTypesCache.put(pojoClass, new RowTypeGetters(rowType, fieldValueGetters)); + + return rowTypesCache.get(pojoClass); + } + + private List createGetters(Class pojoClass) { + ImmutableList.Builder getters = ImmutableList.builder(); + + for (GetterFactory getterFactory : getterFactories) { + getters.addAll(getterFactory.generateGetters(pojoClass)); + } + + return getters.build(); + } + + private List getFieldValues(List fieldValueGetters, Object pojo) { + ImmutableList.Builder builder = ImmutableList.builder(); + + for (FieldValueGetter fieldValueGetter : fieldValueGetters) { + builder.add(fieldValueGetter.get(pojo)); + } + + return builder.build(); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowTypeFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowTypeFactory.java new file mode 100644 index 000000000000..f91f6a018baf --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowTypeFactory.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.values.RowType; + +/** + * For internal use only; no backwards-compatibility guarantees. + * + *

Interface for factories used to create record types based on getters. + * + *

Different implementations can have different ways of mapping getter types to coders. + * For example, Beam SQL uses custom mapping via java.sql.Types. + * + *
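For instance, a deliberately simplistic custom factory could ignore the getter types altogether and code every field as a string. This is only an illustration of the extension point, not anything shipped in this change, and it assumes the Iterable parameter is the getter list shown above:

import com.google.common.collect.ImmutableList;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.values.RowType;
import org.apache.beam.sdk.values.reflect.FieldValueGetter;
import org.apache.beam.sdk.values.reflect.RowTypeFactory;

public class AllStringsRowTypeFactory implements RowTypeFactory {
  @Override
  public RowType createRowType(Iterable<FieldValueGetter> getters) {
    ImmutableList.Builder<String> names = ImmutableList.builder();
    ImmutableList.Builder<Coder> coders = ImmutableList.builder();
    for (FieldValueGetter getter : getters) {
      names.add(getter.name());
      coders.add(StringUtf8Coder.of()); // every field coded as a string
    }
    return RowType.fromNamesAndCoders(names.build(), coders.build());
  }
}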

Default implementation is {@link DefaultRowTypeFactory}. + * It returns instances of {@link RowType}, mapping {@link FieldValueGetter#type()} + * to known coders. + */ +@Internal +public interface RowTypeFactory { + + /** + * Create a {@link RowType} for the list of the pojo field getters. + */ + RowType createRowType(Iterable getters); +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowTypeGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowTypeGetters.java new file mode 100644 index 000000000000..e6f8998d9c7e --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/RowTypeGetters.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import java.util.List; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; + +/** + * Helper class to hold {@link RowType} and {@link FieldValueGetter}s which were used to + * create it. + * + *

This is used in {@link RowFactory} to create instances of {@link Row}s. + */ +class RowTypeGetters { + private RowType rowType; + private List fieldValueGetters; + + RowTypeGetters(RowType rowType, List fieldValueGetters) { + this.rowType = rowType; + this.fieldValueGetters = fieldValueGetters; + } + + /** + * Returns a {@link RowType}. + */ + RowType rowType() { + return rowType; + } + + /** + * Returns the list of {@link FieldValueGetter}s which + * were used to create {@link RowTypeGetters#rowType()}. + */ + List valueGetters() { + return fieldValueGetters; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/package-info.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/package-info.java new file mode 100644 index 000000000000..c7c549221476 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/reflect/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Classes to generate BeamRecords from pojos. 
+ */ +package org.apache.beam.sdk.values.reflect; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java index 5c0c0e9e857b..3d29af5bfde0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java @@ -436,12 +436,14 @@ public PCollection expand(PCollection input) { return input.apply("custom_name", Sum.integersGlobally()); } } + class ReplacementTransform extends PTransform, PCollection> { @Override public PCollection expand(PCollection input) { return input.apply("custom_name", Max.integersGlobally()); } } + class ReplacementOverrideFactory implements PTransformOverrideFactory< PCollection, PCollection, OriginalTransform> { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java index fdcef3d4e74a..9c8c5bf67fee 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java @@ -442,6 +442,7 @@ private static class NonDeterministicArray { @SuppressWarnings("unused") private UnorderedMapClass[] arrayField; } + @Test public void testDeterministicNonDeterministicArray() { assertNonDeterministic(AvroCoder.of(NonDeterministicArray.class), @@ -714,7 +715,9 @@ private abstract static class DeterministicUnionBase {} @Union({ UnionCase1.class, UnionCase2.class, UnionCase3.class }) private abstract static class NonDeterministicUnionBase {} + private static class UnionCase1 extends DeterministicUnionBase {} + private static class UnionCase2 extends DeterministicUnionBase { @SuppressWarnings("unused") String field; @@ -903,6 +906,7 @@ private static class SomeGeneric { @SuppressWarnings("unused") private T foo; } + private static class Foo { @SuppressWarnings("unused") String id; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/DelegateCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/DelegateCoderTest.java index 5a2add480d17..c7dfba21a8cb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/DelegateCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/DelegateCoderTest.java @@ -104,9 +104,6 @@ public void testSerializable() throws Exception { CoderProperties.coderSerializable(TEST_CODER); } - private static final String TEST_ENCODING_ID = "test-encoding-id"; - private static final String TEST_ALLOWED_ENCODING = "test-allowed-encoding"; - @Test public void testCoderEquals() throws Exception { DelegateCoder.CodingFunction identityFn = input -> input; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/StructuredCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/StructuredCoderTest.java index 7aa2080cf101..2980a1456f09 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/StructuredCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/StructuredCoderTest.java @@ -96,6 +96,7 @@ private static class ObjectIdentityBoolean { public ObjectIdentityBoolean(boolean value) { this.value = value; } + public boolean getValue() { return value; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java index dece483cf03d..44f3db4d9c0a 100644 --- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java @@ -467,9 +467,9 @@ public void testCompressionUNCOMPRESSED() throws FileNotFoundException, IOExcept } private void assertReadValues(final BufferedReader br, String... values) throws IOException { - try (final BufferedReader _br = br) { + try (final BufferedReader lbr = br) { for (String value : values) { - assertEquals(String.format("Line should read '%s'", value), value, _br.readLine()); + assertEquals(String.format("Line should read '%s'", value), value, lbr.readLine()); } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java index 608fd0a8b1c6..36d0928b3e88 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io; +import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE; import static org.hamcrest.Matchers.isA; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -42,6 +43,8 @@ import org.apache.beam.sdk.testing.UsesSplittableParDo; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Watch; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Duration; import org.junit.Rule; @@ -301,4 +304,76 @@ private static MatchResult.Metadata metadata(Path path, int size) { .setSizeBytes(size) .build(); } + + private static FileIO.Write.FileNaming resolveFileNaming(FileIO.Write write) + throws Exception { + return write.resolveFileNamingFn().getClosure().apply(null, null); + } + + private static String getDefaultFileName(FileIO.Write write) throws Exception { + return resolveFileNaming(write).getFilename(null, null, 0, 0, null); + } + + @Test + public void testFilenameFnResolution() throws Exception { + FileIO.Write.FileNaming foo = (window, pane, numShards, shardIndex, compression) -> "foo"; + + String expected = + FileSystems.matchNewResource("test", true).resolve("foo", RESOLVE_FILE).toString(); + assertEquals( + "Filenames should be resolved within a relative directory if '.to' is invoked", + expected, + getDefaultFileName(FileIO.writeDynamic().to("test").withNaming(o -> foo))); + assertEquals( + "Filenames should be resolved within a relative directory if '.to' is invoked", + expected, + getDefaultFileName(FileIO.write().to("test").withNaming(foo))); + + assertEquals( + "Filenames should be resolved as the direct result of the filenaming function if '.to' " + + "is not invoked", + "foo", + getDefaultFileName(FileIO.writeDynamic().withNaming(o -> foo))); + assertEquals( + "Filenames should be resolved as the direct result of the filenaming function if '.to' " + + "is not invoked", + "foo", + getDefaultFileName(FileIO.write().withNaming(foo))); + + assertEquals( + "Default to the defaultNaming if a filenaming isn't provided for a non-dynamic write", + "output-00000-of-00000", + resolveFileNaming(FileIO.write()) + .getFilename( + GlobalWindow.INSTANCE, + PaneInfo.ON_TIME_AND_ONLY_FIRING, + 0, + 0, + Compression.UNCOMPRESSED)); + + assertEquals( + "Default Naming should take prefix and suffix into account if provided", + "foo-00000-of-00000.bar", + 
resolveFileNaming(FileIO.write().withPrefix("foo").withSuffix(".bar")) + .getFilename( + GlobalWindow.INSTANCE, + PaneInfo.ON_TIME_AND_ONLY_FIRING, + 0, + 0, + Compression.UNCOMPRESSED)); + + assertEquals( + "Filenames should be resolved within a relative directory if '.to' is invoked, " + + "even with default naming", + FileSystems.matchNewResource("test", true) + .resolve("output-00000-of-00000", RESOLVE_FILE) + .toString(), + resolveFileNaming(FileIO.write().to("test")) + .getFilename( + GlobalWindow.INSTANCE, + PaneInfo.ON_TIME_AND_ONLY_FIRING, + 0, + 0, + Compression.UNCOMPRESSED)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java index 8658782440ca..e4e7790906fc 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java @@ -73,7 +73,6 @@ import org.apache.beam.sdk.transforms.Top; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.DisplayData.Builder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; @@ -430,10 +429,6 @@ public FilenamePolicy getFilenamePolicy(Integer destination) { "simple"); } - @Override - public void populateDisplayData(Builder builder) { - super.populateDisplayData(builder); - } } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java index b1b08fea2e96..e4c4102d6b64 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java @@ -17,13 +17,17 @@ */ package org.apache.beam.sdk.options; +import static java.util.Locale.ROOT; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -58,6 +62,7 @@ import org.apache.beam.sdk.runners.PipelineRunnerRegistrar; import org.apache.beam.sdk.testing.CrashingRunner; import org.apache.beam.sdk.testing.ExpectedLogs; +import org.apache.beam.sdk.testing.InterceptingUrlClassLoader; import org.apache.beam.sdk.testing.RestoreSystemProperties; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.hamcrest.Matchers; @@ -98,7 +103,7 @@ public void testAutomaticRegistrationInculdesWithoutRunnerSuffix() { REGISTERED_RUNNER.getSimpleName() .substring(0, REGISTERED_RUNNER.getSimpleName().length() - "Runner".length())); Map>> registered = - PipelineOptionsFactory.getRegisteredRunners(); + PipelineOptionsFactory.CACHE.get().getSupportedPipelineRunners(); assertEquals(REGISTERED_RUNNER, registered.get(REGISTERED_RUNNER.getSimpleName() .toLowerCase() @@ -1365,7 +1370,9 @@ public void testSettingUnknownRunner() { "Unknown 
'runner' specified 'UnknownRunner', supported " + "pipeline runners"); Set registeredRunners = PipelineOptionsFactory.getRegisteredRunners().keySet(); assertThat(registeredRunners, hasItem(REGISTERED_RUNNER.getSimpleName().toLowerCase())); - expectedException.expectMessage(PipelineOptionsFactory.getSupportedRunners().toString()); + + expectedException.expectMessage(PipelineOptionsFactory.CACHE.get() + .getSupportedRunners().toString()); PipelineOptionsFactory.fromArgs(args).create(); } @@ -1790,4 +1797,67 @@ public void serialize(JacksonIncompatible jacksonIncompatible, JsonGenerator jso } } + /** Used to test that the thread context class loader is used when creating proxies. */ + public interface ClassLoaderTestOptions extends PipelineOptions { + @Default.Boolean(true) + @Description("A test option.") + boolean isOption(); + void setOption(boolean b); + } + + @Test + public void testPipelineOptionsFactoryUsesTccl() throws Exception { + final Thread thread = Thread.currentThread(); + final ClassLoader testClassLoader = thread.getContextClassLoader(); + final ClassLoader caseLoader = new InterceptingUrlClassLoader( + testClassLoader, + name -> name.toLowerCase(ROOT).contains("test")); + thread.setContextClassLoader(caseLoader); + PipelineOptionsFactory.resetCache(); + try { + final PipelineOptions pipelineOptions = PipelineOptionsFactory.create(); + final Class optionType = caseLoader.loadClass( + "org.apache.beam.sdk.options.PipelineOptionsFactoryTest$ClassLoaderTestOptions"); + final Object options = pipelineOptions.as(optionType); + assertSame(caseLoader, options.getClass().getClassLoader()); + assertSame(optionType.getClassLoader(), options.getClass().getClassLoader()); + assertSame(testClassLoader, optionType.getInterfaces()[0].getClassLoader()); + assertTrue(Boolean.class.cast(optionType.getMethod("isOption").invoke(options))); + } finally { + thread.setContextClassLoader(testClassLoader); + PipelineOptionsFactory.resetCache(); + } + } + + @Test + public void testDefaultMethodIgnoresDefaultImplementation() { + OptionsWithDefaultMethod optsWithDefault = + PipelineOptionsFactory.as(OptionsWithDefaultMethod.class); + assertThat(optsWithDefault.getValue(), nullValue()); + + optsWithDefault.setValue(12.25); + assertThat(optsWithDefault.getValue(), equalTo(12.25)); + } + + private interface ExtendedOptionsWithDefault extends OptionsWithDefaultMethod {} + + @Test + public void testDefaultMethodInExtendedClassIgnoresDefaultImplementation() { + OptionsWithDefaultMethod extendedOptsWithDefault = + PipelineOptionsFactory.as(ExtendedOptionsWithDefault.class); + assertThat(extendedOptsWithDefault.getValue(), nullValue()); + + extendedOptsWithDefault.setValue(Double.NEGATIVE_INFINITY); + assertThat(extendedOptsWithDefault.getValue(), equalTo(Double.NEGATIVE_INFINITY)); + } + + private interface OptionsWithDefaultMethod extends PipelineOptions { + default Number getValue() { + return 1024; + } + + void setValue(Number value); + } + + } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java index 9e6dffc077b2..0cd372a1ea43 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java @@ -74,7 +74,7 @@ public class ProxyInvocationHandlerTest { @Rule public TestRule resetPipelineOptionsRegistry = new ExternalResource() { @Override 
protected void before() { - PipelineOptionsFactory.resetRegistry(); + PipelineOptionsFactory.resetCache(); } }; @@ -468,7 +468,7 @@ public void testResetRegistry() { PipelineOptionsFactory.register(FooOptions.class); assertThat(PipelineOptionsFactory.getRegisteredOptions(), hasItem(FooOptions.class)); - PipelineOptionsFactory.resetRegistry(); + PipelineOptionsFactory.resetCache(); assertEquals(defaultRegistry, PipelineOptionsFactory.getRegisteredOptions()); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java index b5adcb577d68..6a79f7ab5b73 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java @@ -17,10 +17,11 @@ */ package org.apache.beam.sdk.testing; +import com.google.common.base.Predicates; import com.google.common.collect.Sets; import com.google.common.io.ByteStreams; import java.io.IOException; -import java.util.Set; +import java.util.function.Predicate; /** * A classloader that intercepts loading of specifically named classes. This classloader copies @@ -28,11 +29,15 @@ * with multiple classloaders.. */ public class InterceptingUrlClassLoader extends ClassLoader { - private final Set ownedClasses; + private final Predicate test; public InterceptingUrlClassLoader(final ClassLoader parent, final String... ownedClasses) { + this(parent, Predicates.in(Sets.newHashSet(ownedClasses))::apply); + } + + public InterceptingUrlClassLoader(final ClassLoader parent, final Predicate test) { super(parent); - this.ownedClasses = Sets.newHashSet(ownedClasses); + this.test = test; } @Override @@ -42,7 +47,7 @@ public Class loadClass(final String name) throws ClassNotFoundException { return alreadyLoaded; } - if (name != null && ownedClasses.contains(name)) { + if (name != null && test.test(name)) { try { final String classAsResource = name.replace('.', '/') + ".class"; final byte[] classBytes = diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java index d3c1004de17e..a5b4c1b00715 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java @@ -90,6 +90,7 @@ public int hashCode() { private static class NotSerializableObjectCoder extends AtomicCoder { private NotSerializableObjectCoder() { } + private static final NotSerializableObjectCoder INSTANCE = new NotSerializableObjectCoder(); @JsonCreator diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java index 11e40d987a72..0200b0887987 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java @@ -23,14 +23,17 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasNamespace; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; +import static 
org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -72,15 +75,18 @@ import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.Window.ClosingBehavior; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.TimestampedValue; +import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.Assume; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -88,7 +94,7 @@ import org.junit.runners.JUnit4; /** - * Tests for Combine transforms. + * Tests for {@link Combine} transforms. */ @RunWith(JUnit4.class) public class CombineTest implements Serializable { @@ -1235,4 +1241,128 @@ public void processElement(ProcessContext c) throws Exception { } })); } + + /** + * Class for use in testing use of Java 8 method references. + */ + private static class Summer implements Serializable { + public int sum(Iterable integers) { + int sum = 0; + for (int i : integers) { + sum += i; + } + return sum; + } + } + + /** + * Tests creation of a global {@link Combine} via Java 8 lambda. + */ + @Test + @Category(ValidatesRunner.class) + public void testCombineGloballyLambda() { + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4)) + .apply(Combine.globally(integers -> { + int sum = 0; + for (int i : integers) { + sum += i; + } + return sum; + })); + + PAssert.that(output).containsInAnyOrder(10); + pipeline.run(); + } + + /** + * Tests creation of a global {@link Combine} via a Java 8 method reference. + */ + @Test + @Category(ValidatesRunner.class) + public void testCombineGloballyInstanceMethodReference() { + + PCollection output = pipeline + .apply(Create.of(1, 2, 3, 4)) + .apply(Combine.globally(new Summer()::sum)); + + PAssert.that(output).containsInAnyOrder(10); + pipeline.run(); + } + + /** + * Tests creation of a per-key {@link Combine} via a Java 8 lambda. + */ + @Test + @Category(ValidatesRunner.class) + public void testCombinePerKeyLambda() { + + PCollection> output = pipeline + .apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3), KV.of("c", 4))) + .apply(Combine.perKey(integers -> { + int sum = 0; + for (int i : integers) { + sum += i; + } + return sum; + })); + + PAssert.that(output).containsInAnyOrder( + KV.of("a", 4), + KV.of("b", 2), + KV.of("c", 4)); + pipeline.run(); + } + + /** + * Tests creation of a per-key {@link Combine} via a Java 8 method reference. 
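For context, the lambda and method-reference forms exercised by these tests are interchangeable with the built-in Sum transforms for this particular aggregation. A minimal sketch outside the test harness (CombineSketch and the step names are illustrative; the built-ins also sidestep the lambda-serialization caveat discussed further down in this file):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

public class CombineSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // Lambda form, as exercised by the tests above.
    PCollection<Integer> lambdaSum =
        p.apply("CreateA", Create.of(1, 2, 3, 4))
         .apply("SumViaLambda", Combine.globally(xs -> {
           int sum = 0;
           for (int x : xs) {
             sum += x;
           }
           return sum;
         }));

    // Equivalent built-in transforms for the same aggregation.
    PCollection<Integer> builtinSum =
        p.apply("CreateB", Create.of(1, 2, 3, 4)).apply(Sum.integersGlobally());
    PCollection<KV<String, Integer>> perKeySum =
        p.apply("CreateC", Create.of(KV.of("a", 1), KV.of("a", 3), KV.of("b", 2)))
         .apply(Sum.<String>integersPerKey());

    p.run().waitUntilFinish();
  }
}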
+ */ + @Test + @Category(ValidatesRunner.class) + public void testCombinePerKeyInstanceMethodReference() { + + PCollection> output = pipeline + .apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3), KV.of("c", 4))) + .apply(Combine.perKey(new Summer()::sum)); + + PAssert.that(output).containsInAnyOrder( + KV.of("a", 4), + KV.of("b", 2), + KV.of("c", 4)); + pipeline.run(); + } + + /** + * Tests that we can serialize {@link Combine.CombineFn CombineFns} constructed from a lambda. + * Lambdas can be problematic because the {@link Class} object is synthetic and cannot be + * deserialized. + */ + @Test + public void testLambdaSerialization() { + SerializableFunction, Object> combiner = xs -> Iterables.getFirst(xs, 0); + + boolean lambdaClassSerializationThrows; + try { + SerializableUtils.clone(combiner.getClass()); + lambdaClassSerializationThrows = false; + } catch (IllegalArgumentException e) { + // Expected + lambdaClassSerializationThrows = true; + } + Assume.assumeTrue("Expected lambda class serialization to fail. " + + "If it's fixed, we can remove special behavior in Combine.", + lambdaClassSerializationThrows); + + + Combine.Globally combine = Combine.globally(combiner); + SerializableUtils.clone(combine); // should not throw. + } + + @Test + public void testLambdaDisplayData() { + Combine.Globally combine = Combine.globally(xs -> Iterables.getFirst(xs, 0)); + DisplayData displayData = DisplayData.from(combine); + MatcherAssert.assertThat(displayData.items(), not(empty())); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DistinctTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DistinctTest.java index 802b937cd6b1..5835c1e50503 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DistinctTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DistinctTest.java @@ -17,13 +17,21 @@ */ package org.apache.beam.sdk.transforms; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; @@ -48,6 +56,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -57,6 +66,9 @@ public class DistinctTest { @Rule public final TestPipeline p = TestPipeline.create(); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test @Category(ValidatesRunner.class) public void testDistinct() { @@ -273,4 +285,53 @@ public void testTriggeredDistinctRepresentativeValuesEmpty() { PAssert.that(distinctValues).containsInAnyOrder(KV.of(1, "k1")); triggeredDistinctRepresentativePipeline.run(); } + + @Test + public void withLambdaRepresentativeValuesFnNoTypeDescriptorShouldThrow() { + + Multimap predupedContents = HashMultimap.create(); + predupedContents.put(3, "foo"); + predupedContents.put(4, "foos"); + predupedContents.put(6, "barbaz"); + predupedContents.put(6, 
"bazbar"); + PCollection dupes = + p.apply(Create.of("foo", "foos", "barbaz", "barbaz", "bazbar", "foo")); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Unable to return a default Coder for RemoveRepresentativeDupes"); + + // Thrown when applying a transform to the internal WithKeys that withRepresentativeValueFn is + // implemented with + dupes.apply("RemoveRepresentativeDupes", Distinct.withRepresentativeValueFn(String::length)); + } + + @Test + @Category(NeedsRunner.class) + public void withLambdaRepresentativeValuesFnAndTypeDescriptorShouldApplyFn() { + + PCollection dupes = + p.apply(Create.of("foo", "foos", "barbaz", "barbaz", "bazbar", "foo")); + PCollection deduped = + dupes.apply( + Distinct.withRepresentativeValueFn(String::length) + .withRepresentativeType(TypeDescriptor.of(Integer.class))); + + PAssert.that(deduped).satisfies((Iterable strs) -> { + Multimap predupedContents = HashMultimap.create(); + predupedContents.put(3, "foo"); + predupedContents.put(4, "foos"); + predupedContents.put(6, "barbaz"); + predupedContents.put(6, "bazbar"); + + Set seenLengths = new HashSet<>(); + for (String s : strs) { + assertThat(predupedContents.values(), hasItem(s)); + assertThat(seenLengths, not(contains(s.length()))); + seenLengths.add(s.length()); + } + return null; + }); + + p.run(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java index 8e70dcba555d..88ef81f49b18 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java @@ -416,6 +416,7 @@ private enum LifecycleState { INSIDE_BUNDLE, TORN_DOWN } + private LifecycleState state = LifecycleState.UNINITIALIZED; @Setup diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java index a2c5ad532609..afe5b7cfeee1 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import java.io.Serializable; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; @@ -29,6 +30,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -61,6 +63,9 @@ public Boolean apply(Integer elem) { @Rule public final TestPipeline p = TestPipeline.create(); + @Rule + public transient ExpectedException thrown = ExpectedException.none(); + @Test @Category(ValidatesRunner.class) public void testIdentityFilterByPredicate() { @@ -161,4 +166,74 @@ public void testDisplayData() { assertThat(DisplayData.from(Filter.equal(567)), hasDisplayItem("predicate", "x == 567")); } + + @Test + @Category(ValidatesRunner.class) + public void testIdentityFilterByPredicateWithLambda() { + + PCollection output = p + .apply(Create.of(591, 11789, 1257, 24578, 24799, 307)) + .apply(Filter.by(i -> true)); + + PAssert.that(output).containsInAnyOrder(591, 11789, 1257, 24578, 24799, 307); + p.run(); + } + + @Test + 
@Category(ValidatesRunner.class) + public void testNoFilterByPredicateWithLambda() { + + PCollection output = p + .apply(Create.of(1, 2, 4, 5)) + .apply(Filter.by(i -> false)); + + PAssert.that(output).empty(); + p.run(); + } + + @Test + @Category(ValidatesRunner.class) + public void testFilterByPredicateWithLambda() { + + PCollection output = p + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.by(i -> i % 2 == 0)); + + PAssert.that(output).containsInAnyOrder(2, 4, 6); + p.run(); + } + + /** + * Confirms that in Java 8 style, where a lambda results in a rawtype, the output type token is + * not useful. If this test ever fails there may be simplifications available to us. + */ + @Test + public void testFilterParDoOutputTypeDescriptorRawWithLambda() throws Exception { + + @SuppressWarnings({"unchecked", "rawtypes"}) + PCollection output = p + .apply(Create.of("hello")) + .apply(Filter.by(s -> true)); + + thrown.expect(CannotProvideCoderException.class); + p.getCoderRegistry().getCoder(output.getTypeDescriptor()); + } + + @Test + @Category(ValidatesRunner.class) + public void testFilterByMethodReferenceWithLambda() { + + PCollection output = p + .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) + .apply(Filter.by(new EvenFilter()::isEven)); + + PAssert.that(output).containsInAnyOrder(2, 4, 6); + p.run(); + } + + private static class EvenFilter implements Serializable { + public boolean isEven(int i) { + return i % 2 == 0; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java index 18bd8413d94b..7c15c0965ab0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java @@ -39,6 +39,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -224,4 +225,46 @@ public Iterable> apply(KV input) { })); } } + + /** + * Basic test of {@link FlatMapElements} with a lambda (which is instantiated as a + * {@link SerializableFunction}). + */ + @Test + @Category(NeedsRunner.class) + public void testFlatMapBasicWithLambda() throws Exception { + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(FlatMapElements + // Note that the input type annotation is required. + .into(TypeDescriptors.integers()) + .via((Integer i) -> ImmutableList.of(i, -i))); + + PAssert.that(output).containsInAnyOrder(1, 3, -1, -3, 2, -2); + pipeline.run(); + } + + /** + * Basic test of {@link FlatMapElements} with a method reference. + */ + @Test + @Category(NeedsRunner.class) + public void testFlatMapMethodReference() throws Exception { + + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(FlatMapElements + // Note that the input type annotation is required. 
+ .into(TypeDescriptors.integers()) + .via(new Negater()::numAndNegation)); + + PAssert.that(output).containsInAnyOrder(1, 3, -1, -3, 2, -2); + pipeline.run(); + } + + private static class Negater implements Serializable { + public List numAndNegation(int input) { + return ImmutableList.of(input, -input); + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java index 39a65d15366e..96a4cc8b8580 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -296,4 +297,66 @@ public KV apply(KV input) { })); } } + + /** + * Basic test of {@link MapElements} with a lambda (which is instantiated as a {@link + * SerializableFunction}). + */ + @Test + @Category(NeedsRunner.class) + public void testMapLambda() throws Exception { + + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(MapElements + // Note that the type annotation is required. + .into(TypeDescriptors.integers()) + .via((Integer i) -> i * 2)); + + PAssert.that(output).containsInAnyOrder(6, 2, 4); + pipeline.run(); + } + + /** + * Basic test of {@link MapElements} with a lambda wrapped into a {@link SimpleFunction} to + * remember its type. + */ + @Test + @Category(NeedsRunner.class) + public void testMapWrappedLambda() throws Exception { + + PCollection output = + pipeline + .apply(Create.of(1, 2, 3)) + .apply( + MapElements + .via(new SimpleFunction((Integer i) -> i * 2) {})); + + PAssert.that(output).containsInAnyOrder(6, 2, 4); + pipeline.run(); + } + + /** + * Basic test of {@link MapElements} with a method reference. + */ + @Test + @Category(NeedsRunner.class) + public void testMapMethodReference() throws Exception { + + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + .apply(MapElements + // Note that the type annotation is required. 
+ .into(TypeDescriptors.integers()) + .via(new Doubler()::doubleIt)); + + PAssert.that(output).containsInAnyOrder(6, 2, 4); + pipeline.run(); + } + + private static class Doubler implements Serializable { + public int doubleIt(int val) { + return val * 2; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 4d9776ec0ee0..464ab187d14a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -144,6 +144,7 @@ public void processElement(ProcessContext c, BoundedWindow window) { + ":" + window.maxTimestamp().getMillis()); } } + static class TestNoOutputDoFn extends DoFn { @ProcessElement public void processElement(DoFn.ProcessContext c) throws Exception {} @@ -964,6 +965,7 @@ private static class TestDummy { } private static class TestDummyCoder extends AtomicCoder { private TestDummyCoder() { } + private static final TestDummyCoder INSTANCE = new TestDummyCoder(); @JsonCreator @@ -1016,6 +1018,7 @@ private static class MainOutputDummyFn extends DoFn { public MainOutputDummyFn(TupleTag intOutputTag) { this.intOutputTag = intOutputTag; } + @ProcessElement public void processElement(ProcessContext c) { c.output(new TestDummy()); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java index f43c162c7232..4977d0e8b2c8 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java @@ -25,6 +25,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -148,4 +149,34 @@ public void testDisplayData() { assertThat(displayData, hasDisplayItem("numPartitions", 123)); assertThat(displayData, hasDisplayItem("partitionFn", IdentityFn.class)); } + + @Test + @Category(NeedsRunner.class) + public void testModPartitionWithLambda() { + + PCollectionList outputs = pipeline + .apply(Create.of(1, 2, 4, 5)) + .apply(Partition.of(3, (element, numPartitions) -> element % numPartitions)); + assertEquals(3, outputs.size()); + PAssert.that(outputs.get(0)).empty(); + PAssert.that(outputs.get(1)).containsInAnyOrder(1, 4); + PAssert.that(outputs.get(2)).containsInAnyOrder(2, 5); + pipeline.run(); + } + + /** + * Confirms that in Java 8 style, where a lambda results in a rawtype, the output type token is + * not useful. If this test ever fails there may be simplifications available to us. 
+ */ + @Test + @Category(NeedsRunner.class) + public void testPartitionFnOutputTypeDescriptorRaw() throws Exception { + + PCollectionList output = pipeline + .apply(Create.of("hello")) + .apply(Partition.of(1, (element, numPartitions) -> 0)); + + thrown.expect(CannotProvideCoderException.class); + pipeline.getCoderRegistry().getCoder(output.get(0).getTypeDescriptor()); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionTest.java index bcfb5588396a..4fcaea303f1e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionTest.java @@ -17,6 +17,10 @@ */ package org.apache.beam.sdk.transforms; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import org.apache.beam.sdk.values.TypeDescriptors; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -40,4 +44,33 @@ public void testFailureIfNotOverridden() { SimpleFunction broken = new SimpleFunction() {}; } + + /** + * Basic test of {@link MapElements} with a lambda (which is instantiated as a {@link + * SerializableFunction}). + */ + @Test + public void testGoodTypeForLambda() throws Exception { + SimpleFunction fn = new SimpleFunction(Object::toString) {}; + + assertThat(fn.getInputTypeDescriptor(), equalTo(TypeDescriptors.integers())); + assertThat(fn.getOutputTypeDescriptor(), equalTo(TypeDescriptors.strings())); + } + + /** + * Basic test of {@link MapElements} with a lambda wrapped into a {@link SimpleFunction} to + * remember its type. + */ + @Test + public void testGoodTypeForMethodRef() throws Exception { + SimpleFunction fn = + new SimpleFunction(SimpleFunctionTest::toStringThisThing) {}; + + assertThat(fn.getInputTypeDescriptor(), equalTo(TypeDescriptors.integers())); + assertThat(fn.getOutputTypeDescriptor(), equalTo(TypeDescriptors.strings())); + } + + private static String toStringThisThing(Integer i) { + return i.toString(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index 596a335cb484..5c2281891df4 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -223,7 +223,7 @@ public ProcessContinuation processElement(ProcessContext c, OffsetRangeTracker t int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); for (int i = trueStart, numIterations = 1; - tracker.tryClaim(blockStarts[i]); + tracker.tryClaim((long) blockStarts[i]); ++i, ++numIterations) { for (int index = blockStarts[i]; index < blockStarts[i + 1]; ++index) { c.output(index); @@ -351,7 +351,7 @@ public ProcessContinuation processElement(ProcessContext c, OffsetRangeTracker t int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); for (int i = trueStart, numIterations = 1; - tracker.tryClaim(blockStarts[i]); + tracker.tryClaim((long) blockStarts[i]); ++i, ++numIterations) { for (int index = blockStarts[i]; index < blockStarts[i + 1]; ++index) { 
c.output(KV.of(c.sideInput(sideInput) + ":" + c.element(), index)); @@ -516,7 +516,7 @@ private enum State { @ProcessElement public void processElement(ProcessContext c, OffsetRangeTracker tracker) { assertEquals(State.INSIDE_BUNDLE, state); - assertTrue(tracker.tryClaim(0)); + assertTrue(tracker.tryClaim(0L)); c.output(c.element()); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WaitTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WaitTest.java new file mode 100644 index 000000000000..a34b667b0cfd --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WaitTest.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static org.junit.Assert.assertFalse; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.Lists; +import java.io.Serializable; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testing.UsesTestStream; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link Wait}. 
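A minimal usage sketch of the transform under test, assuming the typical pattern of gating one step on the completion of another (WaitSketch and the step names are illustrative; in practice the signal is usually the output of a side-effecting write rather than a Create):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.Wait;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;

public class WaitSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // Stand-ins for "results of a side effect" (signal) and the data that must not be
    // processed until that side effect has completed (main input).
    PCollection<String> signal = p.apply("Signal", Create.of("side effect done"));
    PCollection<Long> main = p.apply("Main", Create.of(1L, 2L, 3L));

    // Wait.on holds back each window of 'main' until the corresponding window of
    // 'signal' is complete, then emits the main elements unchanged.
    PCollection<Long> gated = main.apply(Wait.on(signal));

    gated.apply(
        "UseGated",
        MapElements.into(TypeDescriptors.strings()).via((Long x) -> "processed " + x));

    p.run().waitUntilFinish();
  }
}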
*/ +@RunWith(JUnit4.class) +public class WaitTest implements Serializable { + @Rule public transient TestPipeline p = TestPipeline.create(); + + private static class Event { + private final Instant processingTime; + private final TimestampedValue element; + private final Instant watermarkUpdate; + + private Event(Instant processingTime, TimestampedValue element) { + this.processingTime = processingTime; + this.element = element; + this.watermarkUpdate = null; + } + + private Event(Instant processingTime, Instant watermarkUpdate) { + this.processingTime = processingTime; + this.element = null; + this.watermarkUpdate = watermarkUpdate; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("processingTime", processingTime) + .add("element", element) + .add("watermarkUpdate", watermarkUpdate) + .toString(); + } + } + + /** + * Generates a {@link TestStream} of the given duration containing the values [0, numElements) and + * the same number of random but monotonic watermark updates, with each element within + * allowedLateness of the respective watermark update. + * + *
+   * <p>
TODO: Consider moving this into TestStream if it's useful enough. + */ + private PCollection generateStreamWithBoundedDisorder( + String name, + Instant base, + Duration totalDuration, + int numElements, + Duration allowedLateness) { + TestStream.Builder stream = TestStream.create(VarLongCoder.of()); + + // Generate numElements random watermark updates. After each one also generate an element within + // allowedLateness of it. + List watermarks = Lists.newArrayList(); + for (int i = 0; i < numElements; ++i) { + watermarks.add(base.plus(new Duration((long) (totalDuration.getMillis() * Math.random())))); + } + Collections.sort(watermarks); + + List> events = Lists.newArrayList(); + for (int i = 0; i < numElements; ++i) { + Instant processingTimestamp = + base.plus((long) (1.0 * i * totalDuration.getMillis() / (numElements + 1))); + Instant watermark = watermarks.get(i); + Instant elementTimestamp = + watermark.minus((long) (Math.random() * allowedLateness.getMillis())); + events.add(new Event<>(processingTimestamp, watermark)); + events.add(new Event<>(processingTimestamp, TimestampedValue.of((long) i, elementTimestamp))); + } + + Instant lastProcessingTime = base; + for (Event event : events) { + Duration processingTimeDelta = new Duration(lastProcessingTime, event.processingTime); + if (processingTimeDelta.getMillis() > 0) { + stream = stream.advanceProcessingTime(processingTimeDelta); + } + lastProcessingTime = event.processingTime; + + if (event.element != null) { + stream = stream.addElements(event.element); + } else { + stream = stream.advanceWatermarkTo(event.watermarkUpdate); + } + } + return p.apply(name, stream.advanceWatermarkToInfinity()); + } + + private static final AtomicReference TEST_WAIT_MAX_MAIN_TIMESTAMP = + new AtomicReference<>(); + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWaitWithSameFixedWindows() { + testWaitWithParameters( + Duration.standardMinutes(1) /* duration */, + Duration.standardSeconds(15) /* lateness */, + 20 /* numMainElements */, + FixedWindows.of(Duration.standardSeconds(15)), + 20 /* numSignalElements */, + FixedWindows.of(Duration.standardSeconds(15))); + } + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWaitWithDifferentFixedWindows() { + testWaitWithParameters( + Duration.standardMinutes(1) /* duration */, + Duration.standardSeconds(15) /* lateness */, + 20 /* numMainElements */, + FixedWindows.of(Duration.standardSeconds(15)), + 20 /* numSignalElements */, + FixedWindows.of(Duration.standardSeconds(7))); + } + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWaitWithSignalInSlidingWindows() { + testWaitWithParameters( + Duration.standardMinutes(1) /* duration */, + Duration.standardSeconds(15) /* lateness */, + 20 /* numMainElements */, + FixedWindows.of(Duration.standardSeconds(15)), + 20 /* numSignalElements */, + SlidingWindows.of(Duration.standardSeconds(7)).every(Duration.standardSeconds(1))); + } + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWaitInGlobalWindow() { + testWaitWithParameters( + Duration.standardMinutes(1) /* duration */, + Duration.standardSeconds(15) /* lateness */, + 20 /* numMainElements */, + new GlobalWindows(), + 20 /* numSignalElements */, + new GlobalWindows()); + } + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWaitWithSomeSignalWindowsEmpty() { + testWaitWithParameters( + Duration.standardMinutes(1) /* duration */, + 
Duration.standardSeconds(0) /* lateness */, + 20 /* numMainElements */, + FixedWindows.of(Duration.standardSeconds(1)), + 10 /* numSignalElements */, + FixedWindows.of(Duration.standardSeconds(1))); + } + + /** + * Tests the {@link Wait} transform with a given configuration of the main input and the signal + * input. Specifically, generates random streams with bounded lateness for main and signal inputs + * and tests the property that, after observing a main input element with timestamp Tmain, no + * signal elements are observed with timestamp Tsig < Tmain. + * + * @param duration event-time duration of both inputs + * @param lateness bound on the lateness of elements in both inputs + * @param numMainElements number of elements in the main input + * @param mainWindowFn windowing function of the main input + * @param numSignalElements number of elements in the signal input + * @param signalWindowFn windowing function of the signal input. + */ + private void testWaitWithParameters( + Duration duration, + Duration lateness, + int numMainElements, + WindowFn mainWindowFn, + int numSignalElements, + WindowFn signalWindowFn) { + TEST_WAIT_MAX_MAIN_TIMESTAMP.set(null); + + Instant base = Instant.now(); + + PCollection input = + generateStreamWithBoundedDisorder("main", base, duration, numMainElements, lateness) + .apply( + "Window main", + Window.into(mainWindowFn) + .discardingFiredPanes() + // Use an aggressive trigger for main input and signal to get more + // frequent / aggressive verification. + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .withAllowedLateness(lateness)) + .apply("Fire main", new Fire<>()); + + PCollection signal = + generateStreamWithBoundedDisorder("signal", base, duration, numSignalElements, lateness) + .apply( + "Window signal", + Window.into(signalWindowFn) + .discardingFiredPanes() + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .withAllowedLateness(lateness)) + .apply("Fire signal", new Fire<>()) + .apply( + "Check sequencing", + ParDo.of( + new DoFn() { + @ProcessElement + public void process(ProcessContext c) { + Instant maxMainTimestamp = TEST_WAIT_MAX_MAIN_TIMESTAMP.get(); + if (maxMainTimestamp != null) { + assertFalse( + "Signal at timestamp " + + c.timestamp() + + " generated after main timestamp progressed to " + + maxMainTimestamp, + c.timestamp().isBefore(maxMainTimestamp)); + } + c.output(c.element()); + } + })); + + PCollection output = input.apply(Wait.on(signal)); + + output.apply( + "Update main timestamp", + ParDo.of( + new DoFn() { + @ProcessElement + public void process(ProcessContext c, BoundedWindow w) { + while (true) { + Instant maxMainTimestamp = TEST_WAIT_MAX_MAIN_TIMESTAMP.get(); + Instant newMaxTimestamp = + (maxMainTimestamp == null || c.timestamp().isAfter(maxMainTimestamp)) + ? 
c.timestamp() + : maxMainTimestamp; + if (TEST_WAIT_MAX_MAIN_TIMESTAMP.compareAndSet( + maxMainTimestamp, newMaxTimestamp)) { + break; + } + } + c.output(c.element()); + } + })); + + List expectedOutput = Lists.newArrayList(); + for (int i = 0; i < numMainElements; ++i) { + expectedOutput.add((long) i); + } + PAssert.that(output).containsInAnyOrder(expectedOutput); + + p.run(); + } + + private static class Fire extends PTransform, PCollection> { + @Override + public PCollection expand(PCollection input) { + return input + .apply(WithKeys.of("")) + .apply(GroupByKey.create()) + .apply(Values.create()) + .apply(Flatten.iterables()); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java index 2d0e6e32f098..fcece90dece3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java @@ -27,7 +27,6 @@ import static org.joda.time.Duration.standardSeconds; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -36,10 +35,12 @@ import com.google.common.collect.Lists; import com.google.common.collect.Ordering; import com.google.common.collect.Sets; +import com.google.common.hash.HashCode; import java.io.Serializable; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -496,31 +497,11 @@ private static GrowthTracker newTracker() { return newTracker(new GrowthState<>(never().forNewInput(Instant.now(), null))); } - @Test - public void testGrowthTrackerCheckpointEmpty() { - // Checkpoint an empty tracker. 
- GrowthTracker tracker = newTracker(); - GrowthState residual = tracker.checkpoint(); - GrowthState primary = tracker.currentRestriction(); - Watch.Growth.Never condition = never(); - assertEquals( - primary.toString(condition), - new GrowthState<>( - Collections.emptyMap() /* completed */, - Collections.>emptyList() /* pending */, - true /* isOutputFinal */, - (Integer) null /* terminationState */, - BoundedWindow.TIMESTAMP_MAX_VALUE /* pollWatermark */) - .toString(condition)); - assertEquals( - residual.toString(condition), - new GrowthState<>( - Collections.emptyMap() /* completed */, - Collections.>emptyList() /* pending */, - false /* isOutputFinal */, - 0 /* terminationState */, - BoundedWindow.TIMESTAMP_MIN_VALUE /* pollWatermark */) - .toString(condition)); + private String tryClaimNextPending(GrowthTracker tracker) { + assertTrue(tracker.hasPending()); + Map.Entry> entry = tracker.getNextPending(); + tracker.tryClaim(entry.getKey()); + return entry.getValue().getValue(); } @Test @@ -537,10 +518,8 @@ public void testGrowthTrackerCheckpointNonEmpty() { .withWatermark(now.plus(standardSeconds(7)))); assertEquals(now.plus(standardSeconds(1)), tracker.getWatermark()); - assertTrue(tracker.hasPending()); - assertEquals("a", tracker.tryClaimNextPending().getValue()); - assertTrue(tracker.hasPending()); - assertEquals("b", tracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(tracker)); + assertEquals("b", tryClaimNextPending(tracker)); assertTrue(tracker.hasPending()); assertEquals(now.plus(standardSeconds(3)), tracker.getWatermark()); @@ -550,10 +529,8 @@ public void testGrowthTrackerCheckpointNonEmpty() { // Verify primary: should contain what the current tracker claimed, and nothing else. assertEquals(now.plus(standardSeconds(1)), primaryTracker.getWatermark()); - assertTrue(primaryTracker.hasPending()); - assertEquals("a", primaryTracker.tryClaimNextPending().getValue()); - assertTrue(primaryTracker.hasPending()); - assertEquals("b", primaryTracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(primaryTracker)); + assertEquals("b", tryClaimNextPending(primaryTracker)); assertFalse(primaryTracker.hasPending()); assertFalse(primaryTracker.shouldPollMore()); // No more pending elements in primary restriction, and no polling. @@ -562,19 +539,16 @@ public void testGrowthTrackerCheckpointNonEmpty() { // Verify residual: should contain what the current tracker didn't claim. assertEquals(now.plus(standardSeconds(3)), residualTracker.getWatermark()); - assertTrue(residualTracker.hasPending()); - assertEquals("c", residualTracker.tryClaimNextPending().getValue()); - assertTrue(residualTracker.hasPending()); - assertEquals("d", residualTracker.tryClaimNextPending().getValue()); + assertEquals("c", tryClaimNextPending(residualTracker)); + assertEquals("d", tryClaimNextPending(residualTracker)); assertFalse(residualTracker.hasPending()); assertTrue(residualTracker.shouldPollMore()); // No more pending elements in residual restriction, but poll watermark still holds. assertEquals(now.plus(standardSeconds(7)), residualTracker.getWatermark()); // Verify current tracker: it was checkpointed, so should contain nothing else. 
- assertNull(tracker.tryClaimNextPending()); - tracker.checkDone(); assertFalse(tracker.hasPending()); + tracker.checkDone(); assertFalse(tracker.shouldPollMore()); assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark()); } @@ -592,10 +566,10 @@ public void testGrowthTrackerOutputFullyBeforeCheckpointIncomplete() { TimestampedValue.of("b", now.plus(standardSeconds(2))))) .withWatermark(now.plus(standardSeconds(7)))); - assertEquals("a", tracker.tryClaimNextPending().getValue()); - assertEquals("b", tracker.tryClaimNextPending().getValue()); - assertEquals("c", tracker.tryClaimNextPending().getValue()); - assertEquals("d", tracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(tracker)); + assertEquals("b", tryClaimNextPending(tracker)); + assertEquals("c", tryClaimNextPending(tracker)); + assertEquals("d", tryClaimNextPending(tracker)); assertFalse(tracker.hasPending()); assertEquals(now.plus(standardSeconds(7)), tracker.getWatermark()); @@ -605,14 +579,10 @@ public void testGrowthTrackerOutputFullyBeforeCheckpointIncomplete() { // Verify primary: should contain what the current tracker claimed, and nothing else. assertEquals(now.plus(standardSeconds(1)), primaryTracker.getWatermark()); - assertTrue(primaryTracker.hasPending()); - assertEquals("a", primaryTracker.tryClaimNextPending().getValue()); - assertTrue(primaryTracker.hasPending()); - assertEquals("b", primaryTracker.tryClaimNextPending().getValue()); - assertTrue(primaryTracker.hasPending()); - assertEquals("c", primaryTracker.tryClaimNextPending().getValue()); - assertTrue(primaryTracker.hasPending()); - assertEquals("d", primaryTracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(primaryTracker)); + assertEquals("b", tryClaimNextPending(primaryTracker)); + assertEquals("c", tryClaimNextPending(primaryTracker)); + assertEquals("d", tryClaimNextPending(primaryTracker)); assertFalse(primaryTracker.hasPending()); assertFalse(primaryTracker.shouldPollMore()); // No more pending elements in primary restriction, and no polling. @@ -645,10 +615,10 @@ public void testGrowthTrackerPollAfterCheckpointIncompleteWithNewOutputs() { TimestampedValue.of("b", now.plus(standardSeconds(2))))) .withWatermark(now.plus(standardSeconds(7)))); - assertEquals("a", tracker.tryClaimNextPending().getValue()); - assertEquals("b", tracker.tryClaimNextPending().getValue()); - assertEquals("c", tracker.tryClaimNextPending().getValue()); - assertEquals("d", tracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(tracker)); + assertEquals("b", tryClaimNextPending(tracker)); + assertEquals("c", tryClaimNextPending(tracker)); + assertEquals("d", tryClaimNextPending(tracker)); GrowthState checkpoint = tracker.checkpoint(); // Simulate resuming from the checkpoint and adding more elements. 
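The tryClaimNextPending helper introduced above splits the old single call into a peek (getNextPending) followed by an explicit claim of the element's key. A self-contained sketch of that shape, with FakeTracker standing in for the Beam-internal GrowthTracker (whose real API is not reproduced here):

import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Map;

public class ClaimSketch {
  /** Illustrative stand-in for a tracker with separate "peek" and "claim" steps. */
  static class FakeTracker {
    private final Deque<Map.Entry<Long, String>> pending = new ArrayDeque<>();

    FakeTracker add(long position, String value) {
      pending.add(new SimpleEntry<>(position, value));
      return this;
    }

    boolean hasPending() { return !pending.isEmpty(); }

    Map.Entry<Long, String> getNextPending() { return pending.peek(); }

    boolean tryClaim(long position) {
      // A real tracker may refuse the claim (e.g. after a checkpoint); this sketch always accepts.
      return pending.poll() != null;
    }
  }

  /** Mirrors the shape of the test helper: check pending, read the entry, claim its key. */
  static String tryClaimNextPending(FakeTracker tracker) {
    if (!tracker.hasPending()) {
      throw new IllegalStateException("no pending element");
    }
    Map.Entry<Long, String> entry = tracker.getNextPending();
    tracker.tryClaim(entry.getKey());
    return entry.getValue();
  }

  public static void main(String[] args) {
    FakeTracker tracker = new FakeTracker().add(1L, "a").add(2L, "b");
    System.out.println(tryClaimNextPending(tracker)); // a
    System.out.println(tryClaimNextPending(tracker)); // b
    System.out.println(tracker.hasPending());         // false
  }
}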
@@ -666,9 +636,9 @@ public void testGrowthTrackerPollAfterCheckpointIncompleteWithNewOutputs() { .withWatermark(now.plus(standardSeconds(12)))); assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark()); - assertEquals("e", residualTracker.tryClaimNextPending().getValue()); + assertEquals("e", tryClaimNextPending(residualTracker)); assertEquals(now.plus(standardSeconds(8)), residualTracker.getWatermark()); - assertEquals("f", residualTracker.tryClaimNextPending().getValue()); + assertEquals("f", tryClaimNextPending(residualTracker)); assertFalse(residualTracker.hasPending()); assertTrue(residualTracker.shouldPollMore()); @@ -688,9 +658,9 @@ public void testGrowthTrackerPollAfterCheckpointIncompleteWithNewOutputs() { TimestampedValue.of("f", now.plus(standardSeconds(8)))))); assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark()); - assertEquals("e", residualTracker.tryClaimNextPending().getValue()); + assertEquals("e", tryClaimNextPending(residualTracker)); assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark()); - assertEquals("f", residualTracker.tryClaimNextPending().getValue()); + assertEquals("f", tryClaimNextPending(residualTracker)); assertFalse(residualTracker.hasPending()); assertTrue(residualTracker.shouldPollMore()); @@ -711,10 +681,10 @@ public void testGrowthTrackerPollAfterCheckpointWithoutNewOutputs() { TimestampedValue.of("b", now.plus(standardSeconds(2))))) .withWatermark(now.plus(standardSeconds(7)))); - assertEquals("a", tracker.tryClaimNextPending().getValue()); - assertEquals("b", tracker.tryClaimNextPending().getValue()); - assertEquals("c", tracker.tryClaimNextPending().getValue()); - assertEquals("d", tracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(tracker)); + assertEquals("b", tryClaimNextPending(tracker)); + assertEquals("c", tryClaimNextPending(tracker)); + assertEquals("d", tryClaimNextPending(tracker)); // Simulate resuming from the checkpoint but there are no new elements. GrowthState checkpoint = tracker.checkpoint(); @@ -759,10 +729,10 @@ public void testGrowthTrackerPollAfterCheckpointWithoutNewOutputsNoWatermark() { TimestampedValue.of("c", now.plus(standardSeconds(3))), TimestampedValue.of("a", now.plus(standardSeconds(1))), TimestampedValue.of("b", now.plus(standardSeconds(2)))))); - assertEquals("a", tracker.tryClaimNextPending().getValue()); - assertEquals("b", tracker.tryClaimNextPending().getValue()); - assertEquals("c", tracker.tryClaimNextPending().getValue()); - assertEquals("d", tracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(tracker)); + assertEquals("b", tryClaimNextPending(tracker)); + assertEquals("c", tryClaimNextPending(tracker)); + assertEquals("d", tryClaimNextPending(tracker)); assertEquals(now.plus(standardSeconds(1)), tracker.getWatermark()); // Simulate resuming from the checkpoint but there are no new elements. @@ -786,12 +756,6 @@ public void testGrowthTrackerRepeatedEmptyPollWatermark() { GrowthTracker tracker = newTracker(); tracker.addNewAsPending(PollResult.incomplete(Collections.emptyList())); assertEquals(BoundedWindow.TIMESTAMP_MIN_VALUE, tracker.getWatermark()); - - // Simulate resuming from the checkpoint but there are still no new elements. - GrowthTracker residualTracker = newTracker(tracker.checkpoint()); - tracker.addNewAsPending(PollResult.incomplete(Collections.emptyList())); - // No new elements and no explicit watermark supplied - still no watermark. 
- assertEquals(BoundedWindow.TIMESTAMP_MIN_VALUE, residualTracker.getWatermark()); } // Empty poll result with watermark { @@ -801,12 +765,6 @@ public void testGrowthTrackerRepeatedEmptyPollWatermark() { PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()) .withWatermark(now)); assertEquals(now, tracker.getWatermark()); - - // Simulate resuming from the checkpoint but there are still no new elements. - GrowthTracker residualTracker = newTracker(tracker.checkpoint()); - tracker.addNewAsPending(PollResult.incomplete(Collections.emptyList())); - // No new elements and no explicit watermark supplied - should keep old watermark. - assertEquals(now, residualTracker.getWatermark()); } } @@ -822,10 +780,10 @@ public void testGrowthTrackerOutputFullyBeforeCheckpointComplete() { TimestampedValue.of("a", now.plus(standardSeconds(1))), TimestampedValue.of("b", now.plus(standardSeconds(2)))))); - assertEquals("a", tracker.tryClaimNextPending().getValue()); - assertEquals("b", tracker.tryClaimNextPending().getValue()); - assertEquals("c", tracker.tryClaimNextPending().getValue()); - assertEquals("d", tracker.tryClaimNextPending().getValue()); + assertEquals("a", tryClaimNextPending(tracker)); + assertEquals("b", tryClaimNextPending(tracker)); + assertEquals("c", tryClaimNextPending(tracker)); + assertEquals("d", tryClaimNextPending(tracker)); assertFalse(tracker.hasPending()); assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark()); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java index 57b887ff32a3..97614d68633f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java @@ -25,12 +25,14 @@ import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -74,6 +76,9 @@ public class WithKeysTest { @Rule public final TestPipeline p = TestPipeline.create(); + @Rule + public ExpectedException thrown = ExpectedException.none(); + @Test @Category(NeedsRunner.class) public void testExtractKeys() { @@ -150,4 +155,33 @@ public Integer apply(String value) { return value.length(); } } + + @Test + @Category(ValidatesRunner.class) + public void withLambdaAndTypeDescriptorShouldSucceed() { + + PCollection<String> values = p.apply(Create.of("1234", "3210", "0", "-12")); + PCollection<KV<Integer, String>> kvs = values.apply( + WithKeys.of((SerializableFunction<String, Integer>) Integer::valueOf) + .withKeyType(TypeDescriptor.of(Integer.class))); + + PAssert.that(kvs).containsInAnyOrder( + KV.of(1234, "1234"), KV.of(0, "0"), KV.of(-12, "-12"), KV.of(3210, "3210")); + + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void withLambdaAndNoTypeDescriptorShouldThrow() { + + PCollection<String> values = p.apply(Create.of("1234", "3210", "0", "-12")); + + values.apply("ApplyKeysWithWithKeys", WithKeys.of(Integer::valueOf)); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Unable to return a default Coder for ApplyKeysWithWithKeys");
+ + p.run(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsTest.java index 02ce55968dd5..77cac64fde7a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsTest.java @@ -173,4 +173,34 @@ public void withTimestampsWithNullFnShouldThrowOnConstruction() { p.run(); } + + @Test + @Category(ValidatesRunner.class) + public void withTimestampsLambdaShouldApplyTimestamps() { + + final String yearTwoThousand = "946684800000"; + PCollection<String> timestamped = + p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) + .apply(WithTimestamps.of((String input) -> new Instant(Long.valueOf(input)))); + + PCollection<KV<String, Instant>> timestampedVals = + timestamped.apply(ParDo.of(new DoFn<String, KV<String, Instant>>() { + @ProcessElement + public void processElement(ProcessContext c) + throws Exception { + c.output(KV.of(c.element(), c.timestamp())); + } + })); + + PAssert.that(timestamped) + .containsInAnyOrder(yearTwoThousand, "0", "1234", Integer.toString(Integer.MAX_VALUE)); + PAssert.that(timestampedVals) + .containsInAnyOrder( + KV.of("0", new Instant(0)), + KV.of("1234", new Instant(Long.valueOf("1234"))), + KV.of(Integer.toString(Integer.MAX_VALUE), new Instant(Integer.MAX_VALUE)), + KV.of(yearTwoThousand, new Instant(Long.valueOf(yearTwoThousand)))); + + p.run(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java index 9b24b698549a..939c29b41a25 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java @@ -366,8 +366,8 @@ public void populateDisplayData(Builder builder) { .addIfNotDefault(DisplayData.item("Double", Double.valueOf(1.23)), Double.valueOf(1.23)) .addIfNotDefault(DisplayData.item("boolean", true), true) .addIfNotDefault( - DisplayData.item("Boolean", Boolean.valueOf(true)), - Boolean.valueOf(true)); + DisplayData.item("Boolean", Boolean.TRUE), + Boolean.TRUE); } }); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGroupByKeyTest.java index 2ec43c5497ea..2b34c03b5dfe 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGroupByKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGroupByKeyTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.io.Serializable; import java.util.ArrayList; @@ -30,12 +31,13 @@ import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFnTester;
import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; @@ -398,6 +400,7 @@ public void processElement(ProcessContext c) { */ @SuppressWarnings("unchecked") @Test + @Category(NeedsRunner.class) public void testConsumingDoFn() throws Exception { TupleTag purchasesTag = new TupleTag<>(); TupleTag addressesTag = new TupleTag<>(); @@ -424,19 +427,30 @@ public void testConsumingDoFn() throws Exception { .and(addressesTag, Arrays.asList("8a")) .and(namesTag, new ArrayList<>()); - List> results = - DoFnTester.of( - new CorrelatePurchaseCountForAddressesWithoutNamesFn( - purchasesTag, - addressesTag, - namesTag)) - .processBundle( - KV.of(1, result1), - KV.of(2, result2), - KV.of(3, result3), - KV.of(4, result4)); - - assertThat(results, containsInAnyOrder(KV.of("4a", 2), KV.of("8a", 0))); + KvCoder coder = KvCoder.of( + VarIntCoder.of(), + CoGbkResult.CoGbkResultCoder.of( + CoGbkResultSchema.of( + ImmutableList.of(purchasesTag, addressesTag, namesTag)), + UnionCoder.of( + ImmutableList.of( + StringUtf8Coder.of(), + StringUtf8Coder.of(), + StringUtf8Coder.of())))); + + PCollection> results = + p.apply( + Create.of( + KV.of(1, result1), KV.of(2, result2), KV.of(3, result3), KV.of(4, result4)) + .withCoder(coder)) + .apply( + ParDo.of( + new CorrelatePurchaseCountForAddressesWithoutNamesFn( + purchasesTag, addressesTag, namesTag))); + + PAssert.that(results).containsInAnyOrder(KV.of("4a", 2), KV.of("8a", 0)); + + p.run(); } /** diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index cdd7f6007ff5..2c1575ab64c5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -120,6 +120,7 @@ class MockFn extends DoFn { @ProcessElement public void processElement(ProcessContext c) throws Exception {} } + MockFn mockFn = mock(MockFn.class); assertEquals(stop(), invokeProcessElement(mockFn)); verify(mockFn).processElement(mockProcessContext); @@ -180,6 +181,7 @@ class MockFn extends DoFn { @DoFn.ProcessElement public void processElement(ProcessContext c, IntervalWindow w) throws Exception {} } + MockFn fn = mock(MockFn.class); assertEquals(stop(), invokeProcessElement(fn)); verify(fn).processElement(mockProcessContext, mockWindow); @@ -204,6 +206,7 @@ class MockFn extends DoFn { public void processElement(ProcessContext c, @StateId(stateId) ValueState valueState) throws Exception {} } + MockFn fn = mock(MockFn.class); assertEquals(stop(), invokeProcessElement(fn)); verify(fn).processElement(mockProcessContext, mockState); @@ -230,6 +233,7 @@ public void processElement(ProcessContext c, @TimerId(timerId) Timer timer) @OnTimer(timerId) public void onTimer() {} } + MockFn fn = mock(MockFn.class); assertEquals(stop(), invokeProcessElement(fn)); verify(fn).processElement(mockProcessContext, mockTimer); @@ -254,6 +258,7 @@ public SomeRestrictionTracker newTracker(SomeRestriction restriction) { return null; } } + MockFn fn = mock(MockFn.class); when(fn.processElement(mockProcessContext, null)).thenReturn(resume()); assertEquals(resume(), invokeProcessElement(fn)); @@ -277,6 +282,7 @@ public void before() {} @Teardown public void after() {} } + MockFn fn = mock(MockFn.class); DoFnInvoker invoker = 
DoFnInvokers.invokerFor(fn); invoker.invokeSetup(); @@ -295,7 +301,7 @@ public void after() {} private static class SomeRestriction {} private abstract static class SomeRestrictionTracker - implements RestrictionTracker {} + extends RestrictionTracker {} private static class SomeRestrictionCoder extends AtomicCoder { public static SomeRestrictionCoder of() { @@ -385,7 +391,7 @@ public DoFn.ProcessContext processContext(DoFn f } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { return tracker; } })); @@ -399,7 +405,13 @@ public DefaultTracker newTracker() { } } - private static class DefaultTracker implements RestrictionTracker { + private static class DefaultTracker + extends RestrictionTracker { + @Override + protected boolean tryClaimImpl(Void position) { + throw new UnsupportedOperationException(); + } + @Override public RestrictionWithDefaultTracker currentRestriction() { throw new UnsupportedOperationException(); @@ -440,6 +452,7 @@ public RestrictionWithDefaultTracker getInitialRestriction(String element) { return null; } } + MockFn fn = mock(MockFn.class); DoFnInvoker invoker = DoFnInvokers.invokerFor(fn); @@ -653,7 +666,7 @@ public DoFn.ProcessContext processContext(DoFn doFn) { } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { return null; // will not be touched } }); @@ -754,7 +767,7 @@ public void onMyTimer(IntervalWindow w) { static class StableNameTestDoFn extends DoFn { @ProcessElement public void process() {} - }; + } /** * This is a change-detector test that the generated name is stable across runs. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java index 44ae5c4f2425..239cb2e2c257 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java @@ -39,7 +39,7 @@ public void testBadExtraProcessContextType() throws Exception { thrown.expect(IllegalArgumentException.class); thrown.expectMessage( "Integer is not a valid context parameter. 
" - + "Should be one of [BoundedWindow, RestrictionTracker]"); + + "Should be one of [BoundedWindow, RestrictionTracker]"); analyzeProcessElementMethod( new AnonymousMethod() { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 50621c1bca83..a76201bea13b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -55,7 +55,7 @@ private abstract static class SomeRestriction implements HasDefaultTracker {} private abstract static class SomeRestrictionTracker - implements RestrictionTracker {} + extends RestrictionTracker {} private abstract static class SomeRestrictionCoder extends StructuredCoder {} @@ -199,6 +199,7 @@ class UnsplittableFn extends DoFn { @ProcessElement public void process(ProcessContext context) {} } + assertEquals( PCollection.IsBounded.BOUNDED, DoFnSignatures @@ -332,7 +333,9 @@ public void process(ProcessContext context, SomeRestrictionTracker tracker) {} DoFnSignatures.getSignature(BadFn.class); } - abstract class SomeDefaultTracker implements RestrictionTracker {} + abstract class SomeDefaultTracker + extends RestrictionTracker {} + abstract class RestrictionWithDefaultTracker implements HasDefaultTracker {} @@ -390,7 +393,7 @@ public SomeRestriction getInitialRestriction(Integer element) { } thrown.expectMessage( - "Returns void, but must return a subtype of RestrictionTracker"); + "Returns void, but must return a subtype of RestrictionTracker"); DoFnSignatures.getSignature(BadFn.class); } @@ -578,7 +581,8 @@ private SomeRestrictionTracker method(SomeRestriction restriction, Object extra) @Test public void testNewTrackerInconsistent() throws Exception { thrown.expectMessage( - "Returns SomeRestrictionTracker, but must return a subtype of RestrictionTracker"); + "Returns SomeRestrictionTracker, " + + "but must return a subtype of RestrictionTracker"); DoFnSignatures.analyzeNewTrackerMethod( errors(), TypeDescriptor.of(FakeDoFn.class), diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/OnTimerInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/OnTimerInvokersTest.java index 0cc67c63b662..af6b6ce2e4f6 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/OnTimerInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/OnTimerInvokersTest.java @@ -117,7 +117,7 @@ public void process() {} @OnTimer(TIMER_ID) public void onMyTimer() {} - }; + } /** * This is a change-detector test that the generated name is stable across runs. 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java index 8aed6b9c01ca..b723dd886603 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java @@ -38,24 +38,31 @@ public void testTryClaim() throws Exception { OffsetRange range = new OffsetRange(100, 200); OffsetRangeTracker tracker = new OffsetRangeTracker(range); assertEquals(range, tracker.currentRestriction()); - assertTrue(tracker.tryClaim(100)); - assertTrue(tracker.tryClaim(150)); - assertTrue(tracker.tryClaim(199)); - assertFalse(tracker.tryClaim(200)); + assertTrue(tracker.tryClaim(100L)); + assertTrue(tracker.tryClaim(150L)); + assertTrue(tracker.tryClaim(199L)); + assertFalse(tracker.tryClaim(200L)); } @Test public void testCheckpointUnstarted() throws Exception { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + expected.expect(IllegalStateException.class); + tracker.checkpoint(); + } + + @Test + public void testCheckpointOnlyFailedClaim() throws Exception { + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + assertFalse(tracker.tryClaim(250L)); + expected.expect(IllegalStateException.class); OffsetRange checkpoint = tracker.checkpoint(); - assertEquals(new OffsetRange(100, 100), tracker.currentRestriction()); - assertEquals(new OffsetRange(100, 200), checkpoint); } @Test public void testCheckpointJustStarted() throws Exception { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(100)); + assertTrue(tracker.tryClaim(100L)); OffsetRange checkpoint = tracker.checkpoint(); assertEquals(new OffsetRange(100, 101), tracker.currentRestriction()); assertEquals(new OffsetRange(101, 200), checkpoint); @@ -64,8 +71,8 @@ public void testCheckpointJustStarted() throws Exception { @Test public void testCheckpointRegular() throws Exception { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(105)); - assertTrue(tracker.tryClaim(110)); + assertTrue(tracker.tryClaim(105L)); + assertTrue(tracker.tryClaim(110L)); OffsetRange checkpoint = tracker.checkpoint(); assertEquals(new OffsetRange(100, 111), tracker.currentRestriction()); assertEquals(new OffsetRange(111, 200), checkpoint); @@ -74,9 +81,9 @@ public void testCheckpointRegular() throws Exception { @Test public void testCheckpointClaimedLast() throws Exception { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(105)); - assertTrue(tracker.tryClaim(110)); - assertTrue(tracker.tryClaim(199)); + assertTrue(tracker.tryClaim(105L)); + assertTrue(tracker.tryClaim(110L)); + assertTrue(tracker.tryClaim(199L)); OffsetRange checkpoint = tracker.checkpoint(); assertEquals(new OffsetRange(100, 200), tracker.currentRestriction()); assertEquals(new OffsetRange(200, 200), checkpoint); @@ -85,10 +92,10 @@ public void testCheckpointClaimedLast() throws Exception { @Test public void testCheckpointAfterFailedClaim() throws Exception { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(105)); - assertTrue(tracker.tryClaim(110)); - assertTrue(tracker.tryClaim(160)); - 
assertFalse(tracker.tryClaim(240)); + assertTrue(tracker.tryClaim(105L)); + assertTrue(tracker.tryClaim(110L)); + assertTrue(tracker.tryClaim(160L)); + assertFalse(tracker.tryClaim(240L)); OffsetRange checkpoint = tracker.checkpoint(); assertEquals(new OffsetRange(100, 161), tracker.currentRestriction()); assertEquals(new OffsetRange(161, 200), checkpoint); @@ -98,50 +105,50 @@ public void testCheckpointAfterFailedClaim() throws Exception { public void testNonMonotonicClaim() throws Exception { expected.expectMessage("Trying to claim offset 103 while last attempted was 110"); OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(105)); - assertTrue(tracker.tryClaim(110)); - tracker.tryClaim(103); + assertTrue(tracker.tryClaim(105L)); + assertTrue(tracker.tryClaim(110L)); + tracker.tryClaim(103L); } @Test public void testClaimBeforeStartOfRange() throws Exception { expected.expectMessage("Trying to claim offset 90 before start of the range [100, 200)"); OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - tracker.tryClaim(90); + tracker.tryClaim(90L); } @Test public void testCheckDoneAfterTryClaimPastEndOfRange() { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(150)); - assertTrue(tracker.tryClaim(175)); - assertFalse(tracker.tryClaim(220)); + assertTrue(tracker.tryClaim(150L)); + assertTrue(tracker.tryClaim(175L)); + assertFalse(tracker.tryClaim(220L)); tracker.checkDone(); } @Test public void testCheckDoneAfterTryClaimAtEndOfRange() { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(150)); - assertTrue(tracker.tryClaim(175)); - assertFalse(tracker.tryClaim(200)); + assertTrue(tracker.tryClaim(150L)); + assertTrue(tracker.tryClaim(175L)); + assertFalse(tracker.tryClaim(200L)); tracker.checkDone(); } @Test public void testCheckDoneAfterTryClaimRightBeforeEndOfRange() { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(150)); - assertTrue(tracker.tryClaim(175)); - assertTrue(tracker.tryClaim(199)); + assertTrue(tracker.tryClaim(150L)); + assertTrue(tracker.tryClaim(175L)); + assertTrue(tracker.tryClaim(199L)); tracker.checkDone(); } @Test public void testCheckDoneWhenNotDone() { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(150)); - assertTrue(tracker.tryClaim(175)); + assertTrue(tracker.tryClaim(150L)); + assertTrue(tracker.tryClaim(175L)); expected.expectMessage( "Last attempted offset was 175 in range [100, 200), " + "claiming work in [176, 200) was not attempted"); @@ -151,8 +158,8 @@ public void testCheckDoneWhenNotDone() { @Test public void testCheckDoneWhenExplicitlyMarkedDone() { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); - assertTrue(tracker.tryClaim(150)); - assertTrue(tracker.tryClaim(175)); + assertTrue(tracker.tryClaim(150L)); + assertTrue(tracker.tryClaim(175L)); tracker.markDone(); tracker.checkDone(); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowingTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowingTest.java index 09cedc49bd95..33c25e305506 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowingTest.java +++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowingTest.java @@ -68,10 +68,12 @@ public void processElement(ProcessContext c, BoundedWindow window) { + ":" + window); } } + private WindowFn windowFn; public WindowedCount(WindowFn windowFn) { this.windowFn = windowFn; } + @Override public PCollection expand(PCollection in) { return in.apply( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStreamTest.java index 53798da1e864..1e4ebb9898b1 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStreamTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/BufferedElementCountingOutputStreamTest.java @@ -17,7 +17,9 @@ */ package org.apache.beam.sdk.util; +import static org.apache.beam.sdk.util.BufferedElementCountingOutputStream.BUFFER_POOL; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -27,6 +29,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -156,6 +159,39 @@ public void testWritingBytesWhenFinishedThrows() throws Exception { testValues(toBytes("a")).write("b".getBytes()); } + @Test + public void testBuffersAreTakenAndReturned() throws Exception { + BUFFER_POOL.clear(); + BUFFER_POOL.offer(ByteBuffer.allocate(256)); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + BufferedElementCountingOutputStream os = createAndWriteValues(toBytes("abcdefghij"), baos); + assertEquals(0, BUFFER_POOL.size()); + os.finish(); + assertEquals(1, BUFFER_POOL.size()); + + } + + @Test + public void testBehaviorWhenBufferPoolFull() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + while (BUFFER_POOL.remainingCapacity() > 0) { + BUFFER_POOL.offer(ByteBuffer.allocate(256)); + } + BufferedElementCountingOutputStream os = createAndWriteValues(toBytes("abcdefghij"), baos); + os.finish(); + assertEquals(0, BUFFER_POOL.remainingCapacity()); + } + + @Test + public void testBehaviorWhenBufferPoolEmpty() throws Exception { + BUFFER_POOL.clear(); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + BufferedElementCountingOutputStream os = createAndWriteValues(toBytes("abcdefghij"), baos); + assertEquals(0, BUFFER_POOL.size()); + os.finish(); + assertEquals(1, BUFFER_POOL.size()); + } + private List toBytes(String ... values) { ImmutableList.Builder builder = ImmutableList.builder(); for (String value : values) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MoreFuturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MoreFuturesTest.java new file mode 100644 index 000000000000..22ab4c092768 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MoreFuturesTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertThat; + +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link MoreFutures}. */ +@RunWith(JUnit4.class) +public class MoreFuturesTest { + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void supplyAsyncSuccess() throws Exception { + CompletionStage future = MoreFutures.supplyAsync(() -> 42); + assertThat(MoreFutures.get(future), equalTo(42)); + } + + @Test + public void supplyAsyncFailure() throws Exception { + final String testMessage = "this is just a test"; + CompletionStage future = MoreFutures.supplyAsync(() -> { + throw new IllegalStateException(testMessage); + }); + + thrown.expect(ExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + thrown.expectMessage(testMessage); + MoreFutures.get(future); + } + + @Test + public void runAsyncSuccess() throws Exception { + AtomicInteger result = new AtomicInteger(0); + CompletionStage sideEffectFuture = MoreFutures.runAsync(() -> { + result.set(42); + }); + + MoreFutures.get(sideEffectFuture); + assertThat(result.get(), equalTo(42)); + } + + @Test + public void runAsyncFailure() throws Exception { + final String testMessage = "this is just a test"; + CompletionStage sideEffectFuture = MoreFutures.runAsync(() -> { + throw new IllegalStateException(testMessage); + }); + + thrown.expect(ExecutionException.class); + thrown.expectCause(isA(IllegalStateException.class)); + thrown.expectMessage(testMessage); + MoreFutures.get(sideEffectFuture); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MovingFunctionTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MovingFunctionTest.java index 8d57bf4d24b6..b337a1c9984f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MovingFunctionTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MovingFunctionTest.java @@ -95,10 +95,8 @@ public void movingSum() { int lost = 0; for (int i = 0; i < SAMPLE_PERIOD * 2; i++) { f.add(i , 1); - if (i >= SAMPLE_PERIOD) { - if (i % SAMPLE_UPDATE == 0) { - lost += SAMPLE_UPDATE; - } + if (i >= SAMPLE_PERIOD && i % SAMPLE_UPDATE == 0) { + lost += SAMPLE_UPDATE; } assertEquals(i + 1 - lost, f.get(i)); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java index c3b01711c4b9..d119a25038a0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertNotSame; import 
com.google.common.collect.ImmutableList; - import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/WindowedValueTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/WindowedValueTest.java index b2bb818907df..6c2333863cdd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/WindowedValueTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/WindowedValueTest.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.util; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; @@ -36,22 +35,30 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing; import org.joda.time.Instant; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Test case for {@link WindowedValue}. */ @RunWith(JUnit4.class) public class WindowedValueTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + @Test public void testWindowedValueCoder() throws CoderException { Instant timestamp = new Instant(1234); - WindowedValue value = WindowedValue.of( - "abc", - new Instant(1234), - Arrays.asList(new IntervalWindow(timestamp, timestamp.plus(1000)), - new IntervalWindow(timestamp.plus(1000), timestamp.plus(2000))), - PaneInfo.NO_FIRING); + WindowedValue value = + WindowedValue.of( + "abc", + new Instant(1234), + Arrays.asList( + new IntervalWindow(timestamp, timestamp.plus(1000)), + new IntervalWindow(timestamp.plus(1000), timestamp.plus(2000))), + PaneInfo.NO_FIRING); Coder> windowedValueCoder = WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder()); @@ -67,8 +74,8 @@ public void testWindowedValueCoder() throws CoderException { @Test public void testFullWindowedValueCoderIsSerializableWithWellKnownCoderType() { - CoderProperties.coderSerializable(WindowedValue.getFullCoder( - GlobalWindow.Coder.INSTANCE, GlobalWindow.Coder.INSTANCE)); + CoderProperties.coderSerializable( + WindowedValue.getFullCoder(GlobalWindow.Coder.INSTANCE, GlobalWindow.Coder.INSTANCE)); } @Test @@ -77,11 +84,9 @@ public void testValueOnlyWindowedValueCoderIsSerializableWithWellKnownCoderType( } @Test - public void testExplodeWindowsInNoWindowsEmptyIterable() { - WindowedValue value = - WindowedValue.of("foo", Instant.now(), ImmutableList.of(), PaneInfo.NO_FIRING); - - assertThat(value.explodeWindows(), emptyIterable()); + public void testExplodeWindowsInNoWindowsCrash() { + thrown.expect(IllegalArgumentException.class); + WindowedValue.of("foo", Instant.now(), ImmutableList.of(), PaneInfo.NO_FIRING); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java new file mode 100644 index 000000000000..d94d005b2a71 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values; + +import static org.apache.beam.sdk.values.Row.toRow; +import static org.apache.beam.sdk.values.RowType.toRowType; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.util.stream.Stream; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * Unit tests for {@link Row}. + */ +public class RowTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testCreatesNullRecord() { + RowType type = + Stream + .of( + RowType.newField("f_int", VarIntCoder.of()), + RowType.newField("f_str", StringUtf8Coder.of()), + RowType.newField("f_double", DoubleCoder.of())) + .collect(toRowType()); + + Row row = Row.nullRow(type); + + assertNull(row.getValue("f_int")); + assertNull(row.getValue("f_str")); + assertNull(row.getValue("f_double")); + } + + @Test + public void testCreatesRecord() { + RowType type = + Stream + .of( + RowType.newField("f_int", VarIntCoder.of()), + RowType.newField("f_str", StringUtf8Coder.of()), + RowType.newField("f_double", DoubleCoder.of())) + .collect(toRowType()); + + Row row = + Row + .withRowType(type) + .addValues(1, "2", 3.0d) + .build(); + + assertEquals(1, row. getValue("f_int")); + assertEquals("2", row.getValue("f_str")); + assertEquals(3.0d, row. getValue("f_double")); + } + + @Test + public void testCollector() { + RowType type = + Stream + .of( + RowType.newField("f_int", VarIntCoder.of()), + RowType.newField("f_str", StringUtf8Coder.of()), + RowType.newField("f_double", DoubleCoder.of())) + .collect(toRowType()); + + Row row = + Stream + .of(1, "2", 3.0d) + .collect(toRow(type)); + + assertEquals(1, row.getValue("f_int")); + assertEquals("2", row.getValue("f_str")); + assertEquals(3.0d, row.getValue("f_double")); + } + + @Test + public void testThrowsForIncorrectNumberOfFields() { + RowType type = + Stream + .of( + RowType.newField("f_int", VarIntCoder.of()), + RowType.newField("f_str", StringUtf8Coder.of()), + RowType.newField("f_double", DoubleCoder.of())) + .collect(toRowType()); + + thrown.expect(IllegalArgumentException.class); + Row.withRowType(type).addValues(1, "2").build(); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTypeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTypeTest.java new file mode 100644 index 000000000000..35af16a3542b --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTypeTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values; + +import static org.apache.beam.sdk.values.RowType.toRowType; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * Unit tests for {@link RowType}. + */ +public class RowTypeTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testCreatesFromNamesAndCoders() { + List names = Arrays.asList("f_int", "f_string"); + List coders = Arrays.asList(VarIntCoder.of(), StringUtf8Coder.of()); + + RowType rowType = RowType.fromNamesAndCoders(names, coders); + + assertEquals(2, rowType.getFieldCount()); + + assertEquals("f_int", rowType.getFieldName(0)); + assertEquals("f_string", rowType.getFieldName(1)); + + assertEquals(VarIntCoder.of(), rowType.getFieldCoder(0)); + assertEquals(StringUtf8Coder.of(), rowType.getFieldCoder(1)); + } + + @Test + public void testThrowsForWrongFieldCount() { + List names = Arrays.asList("f_int", "f_string"); + List coders = Arrays.asList(VarIntCoder.of(), StringUtf8Coder.of(), VarLongCoder.of()); + + thrown.expect(IllegalStateException.class); + RowType.fromNamesAndCoders(names, coders); + } + + @Test + public void testCollector() { + RowType rowType = + Stream + .of( + RowType.newField("f_int", VarIntCoder.of()), + RowType.newField("f_string", StringUtf8Coder.of())) + .collect(toRowType()); + + assertEquals(2, rowType.getFieldCount()); + + assertEquals("f_int", rowType.getFieldName(0)); + assertEquals("f_string", rowType.getFieldName(1)); + + assertEquals(VarIntCoder.of(), rowType.getFieldCoder(0)); + assertEquals(StringUtf8Coder.of(), rowType.getFieldCoder(1)); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/DefaultRowTypeFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/DefaultRowTypeFactoryTest.java new file mode 100644 index 000000000000..ad3d728c5d7f --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/DefaultRowTypeFactoryTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableList; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.values.RowType; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * Unit tests for {@link DefaultRowTypeFactory}. + */ +public class DefaultRowTypeFactoryTest { + + /** + * Test class without supported coder. + */ + private static class UnsupportedClass { + } + + private static final List GETTERS = ImmutableList + .builder() + .add(getter("byteGetter", Byte.class)) + .add(getter("integerGetter", Integer.class)) + .add(getter("longGetter", Long.class)) + .add(getter("doubleGetter", Double.class)) + .add(getter("booleanGetter", Boolean.class)) + .add(getter("stringGetter", String.class)) + .build(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testContainsCorrectFields() throws Exception { + DefaultRowTypeFactory factory = new DefaultRowTypeFactory(); + + RowType rowType = factory.createRowType(GETTERS); + + assertEquals(GETTERS.size(), rowType.getFieldCount()); + assertEquals( + Arrays.asList( + "byteGetter", + "integerGetter", + "longGetter", + "doubleGetter", + "booleanGetter", + "stringGetter"), + rowType.getFieldNames()); + } + + @Test + public void testContainsCorrectCoders() throws Exception { + DefaultRowTypeFactory factory = new DefaultRowTypeFactory(); + + RowType recordType = factory.createRowType(GETTERS); + + assertEquals(GETTERS.size(), recordType.getFieldCount()); + assertEquals( + Arrays.asList( + ByteCoder.of(), + VarIntCoder.of(), + VarLongCoder.of(), + DoubleCoder.of(), + BooleanCoder.of(), + StringUtf8Coder.of()), + recordType.getRowCoder().getCoders()); + } + + @Test + public void testThrowsForUnsupportedTypes() throws Exception { + thrown.expect(UnsupportedOperationException.class); + + DefaultRowTypeFactory factory = new DefaultRowTypeFactory(); + + factory.createRowType( + Arrays.asList(getter("unsupportedGetter", UnsupportedClass.class))); + } + + private static FieldValueGetter getter(final String fieldName, final Class fieldType) { + return new FieldValueGetter() { + @Override + public Object get(Object object) { + return null; + } + + @Override + public String name() { + return fieldName; + } + + @Override + public Class type() { + return fieldType; + } + }; + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/GeneratedGetterFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/GeneratedGetterFactoryTest.java new file mode 100644 index 000000000000..e5f55af42789 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/GeneratedGetterFactoryTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableSet; +import java.util.List; +import java.util.Set; +import org.junit.Test; + +/** + * Unit tests for {@link GeneratedGetterFactory}. + */ +public class GeneratedGetterFactoryTest { + + /** + * Test pojo. + */ + private static class Pojo { + private String privateStringField = "privateStringValue"; + private Integer privateIntegerField = 15; + + public String publicStringField = "publicStringField"; + + public String getPrivateStringField() { + return privateStringField; + } + + public Integer getPrivateIntegerField() { + return privateIntegerField; + } + } + + @Test + public void testGettersHaveCorrectNames() throws Exception { + List getters = new GeneratedGetterFactory().generateGetters(Pojo.class); + + assertEquals( + ImmutableSet.of("privateStringField", "privateIntegerField"), + getNames(getters)); + } + + @Test + public void testGettersHaveCorrectTypes() throws Exception { + List getters = new GeneratedGetterFactory().generateGetters(Pojo.class); + + assertEquals( + ImmutableSet.of(String.class, Integer.class), + getTypes(getters)); + } + + @Test + public void testGettersReturnCorrectValues() throws Exception { + List getters = new GeneratedGetterFactory().generateGetters(Pojo.class); + + assertEquals( + ImmutableSet.of("privateStringValue", 15), + getValues(getters, new Pojo())); + } + + private Set getNames(List getters) { + ImmutableSet.Builder names = ImmutableSet.builder(); + + for (FieldValueGetter getter : getters) { + names.add(getter.name()); + } + + return names.build(); + } + + private Set getTypes(List getters) { + ImmutableSet.Builder types = ImmutableSet.builder(); + + for (FieldValueGetter getter : getters) { + types.add(getter.type()); + } + + return types.build(); + } + + private ImmutableSet getValues(List getters, Pojo pojo) { + ImmutableSet.Builder values = ImmutableSet.builder(); + + for (FieldValueGetter getter : getters) { + values.add(getter.get(pojo)); + } + + return values.build(); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/ReflectionGetterFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/ReflectionGetterFactoryTest.java new file mode 100644 index 000000000000..a6cd41161376 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/ReflectionGetterFactoryTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableSet; +import java.util.List; +import java.util.Set; +import org.junit.Test; + +/** + * Unit tests for {@link ReflectionGetterFactory}. + */ +public class ReflectionGetterFactoryTest { + + /** + * Test pojo. + */ + private static class Pojo { + private String privateStringField = "privateStringValue"; + private Integer privateIntegerField = 15; + + public String publicStringField = "publicStringField"; + + public String getPrivateStringField() { + return privateStringField; + } + + public Integer getPrivateIntegerField() { + return privateIntegerField; + } + } + + @Test + public void testGettersHaveCorrectNames() throws Exception { + List getters = new ReflectionGetterFactory().generateGetters(Pojo.class); + + assertEquals( + ImmutableSet.of("privateStringField", "privateIntegerField"), + getNames(getters)); + } + + @Test + public void testGettersHaveCorrectTypes() throws Exception { + List getters = new ReflectionGetterFactory().generateGetters(Pojo.class); + + assertEquals( + ImmutableSet.of(String.class, Integer.class), + getTypes(getters)); + } + + @Test + public void testGettersReturnCorrectValues() throws Exception { + List getters = new ReflectionGetterFactory().generateGetters(Pojo.class); + + assertEquals( + ImmutableSet.of("privateStringValue", 15), + getValues(getters, new Pojo())); + } + + private Set getNames(List getters) { + ImmutableSet.Builder names = ImmutableSet.builder(); + + for (FieldValueGetter getter : getters) { + names.add(getter.name()); + } + + return names.build(); + } + + private Set getTypes(List getters) { + ImmutableSet.Builder types = ImmutableSet.builder(); + + for (FieldValueGetter getter : getters) { + types.add(getter.type()); + } + + return types.build(); + } + + private ImmutableSet getValues(List getters, Pojo pojo) { + ImmutableSet.Builder values = ImmutableSet.builder(); + + for (FieldValueGetter getter : getters) { + values.add(getter.get(pojo)); + } + + return values.build(); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/ReflectionGetterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/ReflectionGetterTest.java new file mode 100644 index 000000000000..9d5de00f4792 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/ReflectionGetterTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.junit.Assert.assertEquals; + +import java.lang.reflect.Method; +import org.junit.Test; + +/** + * Unit tests for {@link ReflectionGetter}. + */ +public class ReflectionGetterTest { + + /** + * Test pojo. + */ + private static class Pojo { + public String getStringField() { + return "test"; + } + + public Integer getIntField() { + return 3421; + } + + public Integer notGetter() { + return 542; + } + } + + private static final Method STRING_GETTER = method("getStringField"); + private static final Method INT_GETTER = method("getIntField"); + private static final Method NOT_GETTER = method("notGetter"); + + private static Method method(String methodName) { + try { + return Pojo.class.getDeclaredMethod(methodName); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Unable to find method '" + methodName + "'"); + } + } + + @Test + public void testInitializedWithCorrectNames() { + ReflectionGetter stringGetter = new ReflectionGetter(STRING_GETTER); + ReflectionGetter intGetter = new ReflectionGetter(INT_GETTER); + ReflectionGetter notGetter = new ReflectionGetter(NOT_GETTER); + + assertEquals("stringField", stringGetter.name()); + assertEquals("intField", intGetter.name()); + assertEquals("notGetter", notGetter.name()); + } + + + @Test + public void testInitializedWithCorrectTypes() { + ReflectionGetter stringGetter = new ReflectionGetter(STRING_GETTER); + ReflectionGetter intGetter = new ReflectionGetter(INT_GETTER); + ReflectionGetter notGetter = new ReflectionGetter(NOT_GETTER); + + assertEquals(String.class, stringGetter.type()); + assertEquals(Integer.class, intGetter.type()); + assertEquals(Integer.class, notGetter.type()); + } + + @Test + public void testInvokesCorrectGetter() { + Pojo pojo = new Pojo(); + + ReflectionGetter stringGetter = new ReflectionGetter(STRING_GETTER); + ReflectionGetter intGetter = new ReflectionGetter(INT_GETTER); + ReflectionGetter notGetter = new ReflectionGetter(NOT_GETTER); + + assertEquals("test", stringGetter.get(pojo)); + assertEquals(3421, intGetter.get(pojo)); + assertEquals(542, notGetter.get(pojo)); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/RowFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/RowFactoryTest.java new file mode 100644 index 000000000000..8e125b0417ea --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/RowFactoryTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; + +import com.google.common.collect.ImmutableList; +import org.apache.beam.sdk.values.Row; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Unit tests for {@link RowFactory}. + */ +@RunWith(Parameterized.class) +public class RowFactoryTest { + + /** + * Test pojo. + */ + public static final class SomePojo { + private String someStringField; + private Integer someIntegerField; + + public SomePojo(String someStringField, Integer someIntegerField) { + this.someStringField = someStringField; + this.someIntegerField = someIntegerField; + } + + public String getSomeStringField() { + return someStringField; + } + + public Integer getSomeIntegerField() { + return someIntegerField; + } + } + + /** + * Getters factories to test the record factory with. + */ + @Parameterized.Parameters + public static Iterable gettersFactories() { + return ImmutableList.of(new GeneratedGetterFactory(), new ReflectionGetterFactory()); + } + + private GetterFactory getterFactory; + + public RowFactoryTest(GetterFactory getterFactory) { + this.getterFactory = getterFactory; + } + + @Test + public void testNewRecordFieldValues() throws Exception { + SomePojo pojo = new SomePojo("someString", 42); + RowFactory factory = newFactory(); + + Row row = factory.create(pojo); + + assertEquals(2, row.getFieldCount()); + assertThat( + row.getValues(), + containsInAnyOrder((Object) "someString", Integer.valueOf(42))); + } + + @Test + public void testNewRecordFieldNames() throws Exception { + SomePojo pojo = new SomePojo("someString", 42); + RowFactory factory = newFactory(); + + Row row = factory.create(pojo); + + assertThat(row.getRowType().getFieldNames(), + containsInAnyOrder("someStringField", "someIntegerField")); + } + + @Test + public void testCreatesNewInstanceEachTime() throws Exception { + SomePojo pojo = new SomePojo("someString", 42); + RowFactory factory = newFactory(); + + Row row1 = factory.create(pojo); + Row row2 = factory.create(pojo); + + assertNotSame(row1, row2); + } + + @Test + public void testCachesRecordType() throws Exception { + SomePojo pojo = new SomePojo("someString", 42); + RowFactory factory = newFactory(); + + Row row1 = factory.create(pojo); + Row row2 = factory.create(pojo); + + assertSame(row1.getRowType(), row2.getRowType()); + } + + @Test + public void testCopiesValues() throws Exception { + SomePojo pojo = new SomePojo("someString", 42); + RowFactory factory = newFactory(); + + Row row = factory.create(pojo); + + assertThat( + row.getValues(), + containsInAnyOrder((Object) "someString", Integer.valueOf(42))); + + pojo.someIntegerField = 23; + pojo.someStringField = "hello"; + + assertThat( + row.getValues(), + containsInAnyOrder((Object) "someString", Integer.valueOf(42))); + } + + private RowFactory newFactory() { + return new 
RowFactory(new DefaultRowTypeFactory(), getterFactory); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/RowTypeGettersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/RowTypeGettersTest.java new file mode 100644 index 000000000000..fc9ba417ceaf --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/reflect/RowTypeGettersTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.values.reflect; + +import static java.util.Collections.emptyList; +import static org.junit.Assert.assertSame; + +import java.util.List; +import org.apache.beam.sdk.values.RowType; +import org.junit.Test; + +/** + * Unit tests for {@link RowTypeGetters}. + */ +public class RowTypeGettersTest { + + @Test + public void testGetters() { + RowType rowType = RowType.fromNamesAndCoders(emptyList(), emptyList()); + List fieldValueGetters = emptyList(); + + RowTypeGetters getters = new RowTypeGetters(rowType, fieldValueGetters); + + assertSame(rowType, getters.rowType()); + assertSame(fieldValueGetters, getters.valueGetters()); + } +} diff --git a/sdks/java/extensions/google-cloud-platform-core/pom.xml b/sdks/java/extensions/google-cloud-platform-core/pom.xml index d6632f089437..29ab5352d568 100644 --- a/sdks/java/extensions/google-cloud-platform-core/pom.xml +++ b/sdks/java/extensions/google-cloud-platform-core/pom.xml @@ -184,13 +184,19 @@ org.hamcrest - hamcrest-all + hamcrest-core + provided + + + + org.hamcrest + hamcrest-library provided org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java index 3d035aab40dd..cd35374dba50 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java @@ -44,9 +44,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import java.io.FileNotFoundException; import java.io.IOException; @@ -59,6 +56,7 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import 
java.util.concurrent.LinkedBlockingQueue; @@ -578,7 +576,7 @@ public boolean shouldRetry(IOException e) { } private static void executeBatches(List batches) throws IOException { - ListeningExecutorService executor = + ExecutorService executor = MoreExecutors.listeningDecorator( MoreExecutors.getExitingExecutorService( new ThreadPoolExecutor( @@ -588,18 +586,17 @@ private static void executeBatches(List batches) throws IOExceptio TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>()))); - List> futures = new LinkedList<>(); + List> futures = new LinkedList<>(); for (final BatchRequest batch : batches) { - futures.add( - executor.submit( - () -> { - batch.execute(); - return null; - })); + futures.add(MoreFutures.runAsync( + () -> { + batch.execute(); + }, + executor)); } try { - Futures.allAsList(futures).get(); + MoreFutures.get(MoreFutures.allAsList(futures)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted while executing batch GCS request", e); diff --git a/sdks/java/extensions/jackson/pom.xml b/sdks/java/extensions/jackson/pom.xml index 23de0af4e3e9..1b3fcc09ba12 100644 --- a/sdks/java/extensions/jackson/pom.xml +++ b/sdks/java/extensions/jackson/pom.xml @@ -57,10 +57,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + junit junit diff --git a/sdks/java/extensions/join-library/pom.xml b/sdks/java/extensions/join-library/pom.xml index 838233bfcd91..4f3892e3f898 100644 --- a/sdks/java/extensions/join-library/pom.xml +++ b/sdks/java/extensions/join-library/pom.xml @@ -49,10 +49,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + junit junit diff --git a/sdks/java/extensions/join-library/src/test/java/org/apache/beam/sdk/extensions/joinlibrary/OuterFullJoinTest.java b/sdks/java/extensions/join-library/src/test/java/org/apache/beam/sdk/extensions/joinlibrary/OuterFullJoinTest.java index cdf4f4f77936..3076f68558e6 100644 --- a/sdks/java/extensions/join-library/src/test/java/org/apache/beam/sdk/extensions/joinlibrary/OuterFullJoinTest.java +++ b/sdks/java/extensions/join-library/src/test/java/org/apache/beam/sdk/extensions/joinlibrary/OuterFullJoinTest.java @@ -19,7 +19,6 @@ import java.util.ArrayList; import java.util.List; - import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; diff --git a/sdks/java/extensions/protobuf/pom.xml b/sdks/java/extensions/protobuf/pom.xml index 394fd573e966..099a7f0e3b2e 100644 --- a/sdks/java/extensions/protobuf/pom.xml +++ b/sdks/java/extensions/protobuf/pom.xml @@ -96,13 +96,19 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/extensions/sketching/build.gradle b/sdks/java/extensions/sketching/build.gradle new file mode 100644 index 000000000000..cf8e57c0604d --- /dev/null +++ b/sdks/java/extensions/sketching/build.gradle @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +apply from: project(":").file("build_rules.gradle") +applyJavaNature() + +description = "Apache Beam :: SDKs :: Java :: Extensions :: Sketching" + +def streamlib_version = "2.9.5" +def tdigest_version = "3.2" + +dependencies { + compile library.java.guava + shadow project(path: ":sdks:java:core", configuration: "shadow") + shadow "com.clearspring.analytics:stream:$streamlib_version" + shadow "com.tdunning:t-digest:$tdigest_version" + shadow library.java.slf4j_api + shadowTest library.java.avro + shadowTest library.java.commons_lang3 + shadowTest project(path: ":sdks:java:core", configuration: "shadowTest") + shadowTest project(path: ":runners:direct-java", configuration: "shadow") + shadowTest library.java.hamcrest_core + shadowTest library.java.junit +} diff --git a/sdks/java/extensions/sketching/pom.xml b/sdks/java/extensions/sketching/pom.xml index b3856528b209..d5ca7afd7b01 100755 --- a/sdks/java/extensions/sketching/pom.xml +++ b/sdks/java/extensions/sketching/pom.xml @@ -31,6 +31,7 @@ 2.9.5 + 3.2 @@ -39,16 +40,18 @@ beam-sdks-java-core - + com.clearspring.analytics stream ${streamlib.version} + - org.slf4j - slf4j-api + com.tdunning + t-digest + ${t-digest.version} @@ -91,10 +94,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + junit junit diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinct.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinct.java index 3fea951f4718..e8d85f7c0578 100644 --- a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinct.java +++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinct.java @@ -258,10 +258,32 @@ abstract static class Builder { abstract GloballyDistinct build(); } + /** + * Sets the precision {@code p}. + * + *

Keep in mind that {@code p} cannot be lower than 4, because the estimation would be too + * inaccurate. + * + *
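+ * A quick illustration of the tradeoff (an informal sketch, assuming the usual
+ * HyperLogLog+ error bound of roughly 1.04 / sqrt(2^p)):
+ *
+ *  p = 12  // 2^12 = 4096 registers, about 1.6% relative error
+ *  p = 20  // 2^20 registers, about 0.1% relative error, at a higher memory cost
+ *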

See {@link ApproximateDistinct#precisionForRelativeError(double)} and {@link + * ApproximateDistinct#relativeErrorForPrecision(int)} for more information about the + * relationship between precision and relative error. + * + * @param p the precision value for the normal representation + */ public GloballyDistinct withPrecision(int p) { return toBuilder().setPrecision(p).build(); } + + /** + * Sets the sparse representation's precision {@code sp}. + * + *

Values above 32 are not yet supported by the AddThis version of HyperLogLog+. + * + *

For more information about the sparse representation, read Google's paper available here. + * + * @param sp the precision of HyperLogLog+'s sparse representation + */ public GloballyDistinct withSparsePrecision(int sp) { return toBuilder().setSparsePrecision(sp).build(); } @@ -310,10 +332,32 @@ abstract static class Builder { abstract PerKeyDistinct build(); } + /** + * Sets the precision {@code p}. + * + *

Keep in mind that {@code p} cannot be lower than 4, because the estimation would be too + * inaccurate. + * + *

See {@link ApproximateDistinct#precisionForRelativeError(double)} and {@link + * ApproximateDistinct#relativeErrorForPrecision(int)} for more information about the + * relationship between precision and relative error. + * + * @param p the precision value for the normal representation + */ public PerKeyDistinct withPrecision(int p) { return toBuilder().setPrecision(p).build(); } + + /** + * Sets the sparse representation's precision {@code sp}. + * + *

Values above 32 are not yet supported by the AddThis version of HyperLogLog+. + * + *

For more information about the sparse representation, read Google's paper available here. + * + * @param sp the precision of HyperLogLog+'s sparse representation + */ public PerKeyDistinct withSparsePrecision(int sp) { return toBuilder().setSparsePrecision(sp).build(); } @@ -367,7 +411,7 @@ public static ApproximateDistinctFn create(Coder coder) } /** - * Returns a new {@link ApproximateDistinctFn} combiner with a new precision {@code p}. + * Returns an {@link ApproximateDistinctFn} combiner with a new precision {@code p}. * *

Keep in mind that {@code p} cannot be lower than 4, because the estimation would be too * inaccurate. @@ -384,8 +428,8 @@ public ApproximateDistinctFn withPrecision(int p) { } /** - * Returns a new {@link ApproximateDistinctFn} combiner with a sparse representation of - * precision {@code sp}. + * Returns an {@link ApproximateDistinctFn} combiner with a new + * sparse representation's precision {@code sp}. * *

Values above 32 are not yet supported by the AddThis version of HyperLogLog+. * diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java index b1aa8aec3674..5d06ec00ac5c 100644 --- a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java +++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java @@ -21,13 +21,11 @@ import com.clearspring.analytics.stream.frequency.FrequencyMergeException; import com.google.auto.value.AutoValue; import com.google.common.hash.Hashing; - import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; import java.util.Iterator; - import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; @@ -100,16 +98,15 @@ * advanced processing involving the Count-Min sketch. * * - *

Example 1: simple default use

+ *

Example 1: default use

* - *

The simplest use is simply to call the {@link #globally()} or {@link #perKey()} method in + *

The simplest use is to call the {@link #globally()} or {@link #perKey()} method in * order to retrieve the sketch with an estimate number of hits for each element in the stream. * *


  * {@literal PCollection} pc = ...;
  * {@literal PCollection} countMinSketch = pc.apply(SketchFrequencies
  * {@literal        .}globally()); //{@literal .}perKey();
- * }
  * 
* *

Example 2: tune accuracy parameters

@@ -126,7 +123,6 @@ * {@literal .}globally() //{@literal .}perKey() * .withRelativeError(eps) * .withConfidence(conf)); - * } * * *

Example 3: query the resulting sketch

@@ -155,9 +151,8 @@ * public void procesElement(ProcessContext c) { * Long elem = c.element(); * CountMinSketch sketch = c.sideInput(sketchView); - * sketch.estimateCount(elem, coder); + * c.output(sketch.estimateCount(elem, coder)); * }}).withSideInputs(sketchView)); - * } * * *

Example 4: Using the CombineFn

@@ -177,7 +172,6 @@ * {@literal PCollection} output = input.apply(Combine.globally(CountMinSketchFn * {@literal .}create(new MyObjectCoder()) * .withAccuracy(eps, conf))); - * } * * *

Warning: this class is experimental.
@@ -241,10 +235,25 @@ abstract static class Builder { abstract GlobalSketch build(); } + /** + * Sets the relative error {@code epsilon}. + * + *

Keep in mind that the lower the {@code epsilon} value, the greater the width. + * + * @param eps the error relative to the total number of distinct elements + */ public GlobalSketch withRelativeError(double eps) { return toBuilder().setRelativeError(eps).build(); } + /** + * Sets the {@code confidence} value, i.e. + * the probability that the relative error is lower or equal to {@code epsilon}. + * + *

Keep in mind that the greater the confidence, the greater the depth. + * + * @param conf the confidence in the result to not exceed the relative error + */ public GlobalSketch withConfidence(double conf) { return toBuilder().setConfidence(conf).build(); } @@ -289,10 +298,25 @@ abstract static class Builder { abstract PerKeySketch build(); } + /** + * Sets the relative error {@code epsilon}. + * + *

Keep in mind that the lower the {@code epsilon} value, the greater the width. + * + * @param eps the error relative to the total number of distinct elements + */ public PerKeySketch withRelativeError(double eps) { return toBuilder().setRelativeError(eps).build(); } + /** + * Sets the {@code confidence} value, i.e. + * the probability that the relative error is lower or equal to {@code epsilon}. + * + *

Keep in mind that the greater the confidence, the greater the depth. + * + * @param conf the confidence in the result to not exceed the relative error + */ public PerKeySketch withConfidence(double conf) { return toBuilder().setConfidence(conf).build(); } @@ -330,7 +354,7 @@ private CountMinSketchFn(final Coder coder, double eps, double confidenc } /** - * Returns an {@link CountMinSketchFn} combiner with the given input coder.
+ * Returns a {@link CountMinSketchFn} combiner with the given input coder.
* Warning : the coder must be deterministic. * * @param coder the coder that encodes the elements' type diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/TDigestQuantiles.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/TDigestQuantiles.java new file mode 100644 index 000000000000..28e95ac4a845 --- /dev/null +++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/TDigestQuantiles.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sketching; + +import com.google.auto.value.AutoValue; +import com.tdunning.math.stats.MergingDigest; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.Iterator; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; + + +/** + * {@code PTransform}s for getting information about quantiles in a stream. + * + *

This class uses the T-Digest structure introduced by Ted Dunning, and more precisely + * the {@link MergingDigest} implementation. + * + *

References

+ * + *

The paper and implementation are available on Ted Dunning's + * Github profile + * + *

Parameters

+ * + *

Only one parameter can be tuned in order to control the tradeoff between + * the estimation accuracy and the memory use.
+ * + *

Stream elements are compressed into a linked list of centroids. + * The compression factor {@code cf} is used to limit the number of elements represented by + * each centroid as well as the total number of centroids.
+ * The relative error will always be a small fraction of 1% for values at extreme quantiles + * and always be less than 3/cf at middle quantiles.
+ * + *

By default the compression factor is set to 100, + * which guarantees a relative error less than 3%. + * + *
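+ * As an informal rule of thumb (following the 3/cf bound above), a target mid-quantile
+ * relative error eps maps to a compression factor of roughly 3 / eps:
+ *
+ *  double eps = 0.01;      // tolerate about 1% relative error at mid quantiles
+ *  double cf = 3d / eps;   // compression factor of 300, to be passed to withCompression(cf)
+ *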

Examples

+ * + *

There are 2 ways of using this class: + * + *

    + *
  • Use the {@link PTransform}s that return a {@link PCollection} which contains + * a {@link MergingDigest} for querying the value at a given quantile or + * the approximate quantile position of an element. + *
  • Use the {@link TDigestQuantilesFn} {@code CombineFn} that is exposed in order + * to make advanced processing involving the {@link MergingDigest}. + *
+ * + *

Example 1: Default use

+ * + *

The simplest use is to call the {@link #globally()} or {@link #perKey()} method in + * order to retrieve the digest, and then to query the structure. + * + *


+ * {@literal PCollection} pc = ...;
+ * {@literal PCollection} digests = pc.apply(TDigestQuantiles
+ *         .globally()); // .perKey()
+ * 
+ * + *

Example 2: Tune accuracy parameters

+ * + *

One can tune the compression factor {@code cf} in order to control accuracy and memory.
+ * This tuning works exactly the same for {@link #globally()} and {@link #perKey()}. + * + *


+ *  double cf = 500;
+ * {@literal PCollection} pc = ...;
+ * {@literal PCollection} digests = pc.apply(TDigestQuantiles
+ *         .globally() // .perKey()
+ *         .withCompression(cf));
+ * 
+ * + *

Example 3: Query the resulting structure

+ * + *

This example shows how to query the resulting structure, for example to + * build a {@code PCollection} of {@link KV}s, each one corresponding to + * a (quantile, value) pair. + * + *


+ * {@literal PCollection} pc = ...;
+ * {@literal PCollection<KV<Double, Double>>} quantiles = pc.apply(ParDo.of(
+ *        {@literal new DoFn<MergingDigest, KV<Double, Double>>()} {
+ *          {@literal @ProcessElement}
+ *           public void processElement(ProcessContext c) {
+ *             double[] quantiles = {0.01, 0.25, 0.5, 0.75, 0.99};
+ *             for (double q : quantiles) {
+ *                c.output(KV.of(q, c.element().quantile(q)));
+ *             }
+ *           }}));
+ * 
+ * + *

One can also retrieve the approximate quantile position of a given element in the stream + * using the {@code cdf(double)} method instead of {@code quantile(double)}. + * + *
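+ * For instance (an illustrative sketch; values are approximate and assume a digest built
+ * over the doubles 1.0 to 999.0):
+ *
+ *  MergingDigest digest = c.element();    // one digest of the resulting PCollection
+ *  double median = digest.quantile(0.5);  // roughly 500
+ *  double rank = digest.cdf(250.0);       // roughly 0.25
+ *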

Example 4: Using the CombineFn

+ * + *

The {@code CombineFn} does the same thing as the {@code PTransform}s, but + * it can also be used for stateful processing or in + * {@link org.apache.beam.sdk.transforms.CombineFns.ComposedCombineFn}. + * + *

This minimal example simply shows how to properly + * create a {@link TDigestQuantilesFn}. + * + *


+ *  double cf = 250;
+ * {@literal PCollection} input = ...;
+ * {@literal PCollection} output = input.apply(Combine
+ *         .globally(TDigestQuantilesFn.create(cf)));
+ * 
+ * + *

Warning: this class is experimental.
+ * Its API is subject to change in future versions of Beam. + * */ +@Experimental +public final class TDigestQuantiles { + + /** + * Compute the stream in order to build a T-Digest structure (MergingDigest) + * for keeping track of the stream distribution and returns a {@code PCollection}. + *
The resulting structure can be queried in order to retrieve the approximate value + * at a given quantile or the approximate quantile position of a given element. + */ + public static GlobalDigest globally() { + return GlobalDigest.builder().build(); + } + + /** + * Like {@link #globally()}, but builds a digest for each key in the stream. + * + * @param the type of the keys + */ + public static PerKeyDigest perKey() { + return PerKeyDigest.builder().build(); + } + + /** Implementation of {@link #globally()}. */ + @AutoValue + public abstract static class GlobalDigest + extends PTransform, PCollection> { + + abstract double compression(); + + abstract Builder toBuilder(); + + static Builder builder() { + return new AutoValue_TDigestQuantiles_GlobalDigest.Builder() + .setCompression(100); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setCompression(double cf); + + abstract GlobalDigest build(); + } + + /** + * Sets the compression factor {@code cf}. + * + *

Keep in mind that a compression factor of {@code cf} guarantees + * a relative error less than 3/cf at mid quantiles.
+ * The accuracy will always be significantly less than 1% at extreme quantiles. + * + * @param cf the bound value for centroid and digest sizes. + */ + public GlobalDigest withCompression(double cf) { + return toBuilder().setCompression(cf).build(); + } + + @Override + public PCollection expand(PCollection input) { + return input.apply( + "Compute T-Digest Structure", + Combine.globally(TDigestQuantilesFn.create(this.compression()))); + } + } + + /** Implementation of {@link #perKey()}. */ + @AutoValue + public abstract static class PerKeyDigest + extends PTransform>, PCollection>> { + + abstract double compression(); + abstract Builder toBuilder(); + + static Builder builder() { + return new AutoValue_TDigestQuantiles_PerKeyDigest.Builder() + .setCompression(100); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setCompression(double cf); + + abstract PerKeyDigest build(); + } + + /** + * Sets the compression factor {@code cf}. + * + *

Keep in mind that a compression factor of {@code cf} guarantees + * a relative error less than 3/cf at mid quantiles.
+ * The accuracy will always be significantly less than 1% at extreme quantiles. + * + * @param cf the bound value for centroid and digest sizes. + */ + public PerKeyDigest withCompression(double cf) { + return toBuilder().setCompression(cf).build(); + } + + @Override + public PCollection> expand(PCollection> input) { + return input.apply( + "Compute T-Digest Structure", + Combine.perKey(TDigestQuantilesFn.create(this.compression()))); + } + } + + /** Implements the {@link Combine.CombineFn} of {@link TDigestQuantiles} transforms. */ + public static class TDigestQuantilesFn + extends Combine.CombineFn { + + private final double compression; + + private TDigestQuantilesFn(double compression) { + this.compression = compression; + } + + /** + * Returns {@link TDigestQuantilesFn} combiner with the given compression factor. + * + *

Keep in mind that a compression factor of {@code cf} guarantees + * a relative error less than 3/cf at mid quantiles.
+ * The accuracy will always be significantly less than 1% at extreme quantiles. + * + * @param compression the bound value for centroid and digest sizes. + */ + public static TDigestQuantilesFn create(double compression) { + if (compression > 0) { + return new TDigestQuantilesFn(compression); + } + throw new IllegalArgumentException("Compression factor should be greater than 0."); + } + + @Override public MergingDigest createAccumulator() { + return new MergingDigest(compression); + } + + @Override public MergingDigest addInput(MergingDigest accum, Double value) { + accum.add(value); + return accum; + } + + /** Output the whole structure so it can be queried, reused or stored easily. */ + @Override public MergingDigest extractOutput(MergingDigest accum) { + return accum; + } + + @Override public MergingDigest mergeAccumulators( + Iterable accumulators) { + Iterator it = accumulators.iterator(); + MergingDigest merged = it.next(); + while (it.hasNext()) { + merged.add(it.next()); + } + return merged; + } + + @Override public Coder getAccumulatorCoder(CoderRegistry registry, + Coder inputCoder) { + return new MergingDigestCoder(); + } + + @Override public Coder getDefaultOutputCoder(CoderRegistry registry, + Coder inputCoder) { + return new MergingDigestCoder(); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(DisplayData + .item("compression", compression) + .withLabel("Compression factor")); + } + } + + /** Coder for {@link MergingDigest} class. */ + static class MergingDigestCoder extends CustomCoder { + + private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of(); + + @Override public void encode(MergingDigest value, OutputStream outStream) + throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null T-Digest sketch"); + } + ByteBuffer buf = ByteBuffer.allocate(value.byteSize()); + value.asBytes(buf); + BYTE_ARRAY_CODER.encode(buf.array(), outStream); + } + + @Override public MergingDigest decode(InputStream inStream) throws IOException { + byte[] bytes = BYTE_ARRAY_CODER.decode(inStream); + ByteBuffer buf = ByteBuffer.wrap(bytes); + return MergingDigest.fromBytes(buf); + } + + @Override public boolean isRegisterByteSizeObserverCheap(MergingDigest value) { + return true; + } + + @Override protected long getEncodedElementByteSize(MergingDigest value) + throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null T-Digest sketch"); + } + return value.byteSize(); + } + } +} diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinctTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinctTest.java index 11221343f8ac..27655f8fcda5 100644 --- a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinctTest.java +++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/ApproximateDistinctTest.java @@ -46,15 +46,11 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** Tests for {@link ApproximateDistinct}. 
*/ @RunWith(JUnit4.class) public class ApproximateDistinctTest implements Serializable { - private static final Logger LOG = LoggerFactory.getLogger(ApproximateDistinctTest.class); - @Rule public final transient TestPipeline tp = TestPipeline.create(); @Test diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java index fc1260ea9635..bb3389b9c75f 100644 --- a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java +++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java @@ -25,7 +25,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; - import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericData; diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/TDigestQuantilesTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/TDigestQuantilesTest.java new file mode 100644 index 000000000000..c01bd987ca34 --- /dev/null +++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/TDigestQuantilesTest.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.sketching; + +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.junit.Assert.assertThat; + +import com.tdunning.math.stats.Centroid; +import com.tdunning.math.stats.MergingDigest; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Random; +import org.apache.beam.sdk.extensions.sketching.TDigestQuantiles.MergingDigestCoder; +import org.apache.beam.sdk.extensions.sketching.TDigestQuantiles.TDigestQuantilesFn; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.Values; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** Tests for {@link TDigestQuantiles}. */ +public class TDigestQuantilesTest { + + @Rule public final transient TestPipeline tp = TestPipeline.create(); + + private static final List stream = generateStream(); + + private static final int size = 999; + + private static final int compression = 100; + + private static final double[] quantiles = {0.25, 0.5, 0.75, 0.99}; + + private static List generateStream() { + List li = new ArrayList<>(); + for (double i = 1D; i <= size; i++) { + li.add(i); + } + Collections.shuffle(li); + return li; + } + + @Test + public void globally() { + PCollection> col = tp.apply(Create.of(stream)) + .apply(TDigestQuantiles.globally().withCompression(compression)) + .apply(ParDo.of(new RetrieveQuantiles(quantiles))); + + PAssert.that("Verify Accuracy", col).satisfies(new VerifyAccuracy()); + tp.run(); + } + + @Test + public void perKey() { + PCollection> col = tp.apply(Create.of(stream)) + .apply(WithKeys.of(1)) + .apply(TDigestQuantiles.perKey().withCompression(compression)) + .apply(Values.create()) + .apply(ParDo.of(new RetrieveQuantiles(quantiles))); + + PAssert.that("Verify Accuracy", col).satisfies(new VerifyAccuracy()); + + tp.run(); + } + + @Test + public void testCoder() throws Exception { + MergingDigest tDigest = new MergingDigest(1000); + for (int i = 0; i < 10; i++) { + tDigest.add(2.4 + i); + } + + Assert.assertTrue("Encode and Decode", encodeDecodeEquals(tDigest)); + } + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testMergeAccum() { + Random rd = new Random(1234); + List accums = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + MergingDigest std = new MergingDigest(100); + for (int j = 0; j < 1000; j++) { + std.add(rd.nextDouble()); + } + accums.add(std); + } + TDigestQuantilesFn fn = TDigestQuantilesFn.create(100); + MergingDigest res = fn.mergeAccumulators(accums); + } + + private boolean encodeDecodeEquals(MergingDigest tDigest) throws IOException { + MergingDigest decoded = CoderUtils.clone(new MergingDigestCoder(), tDigest); + + boolean equal = true; + // the only way to compare the two sketches is to compare them centroid by centroid. 
+ // Indeed, the means are doubles but are encoded as float and cast during decoding. + // This entails a small approximation that makes the centroids different after decoding. + Iterator it1 = decoded.centroids().iterator(); + Iterator it2 = tDigest.centroids().iterator(); + + for (int i = 0; i < decoded.centroids().size(); i++) { + Centroid c1 = it1.next(); + Centroid c2 = it2.next(); + if ((float) c1.mean() != (float) c2.mean() || c1.count() != c2.count()) { + equal = false; + break; + } + } + return equal; + } + + @Test + public void testDisplayData() { + final TDigestQuantilesFn fn = TDigestQuantilesFn.create(155D); + assertThat(DisplayData.from(fn), hasDisplayItem("compression", 155D)); + } + + static class RetrieveQuantiles extends DoFn> { + private final double[] quantiles; + + public RetrieveQuantiles(double[] quantiles) { + this.quantiles = quantiles; + } + + @ProcessElement public void processElement(ProcessContext c) { + for (double q : quantiles) { + c.output(KV.of(q, c.element().quantile(q))); + } + } + } + + static class VerifyAccuracy implements SerializableFunction>, Void> { + + double expectedError = 3D / compression; + + public Void apply(Iterable> input) { + for (KV pair : input) { + double expectedValue = pair.getKey() * (size + 1); + boolean isAccurate = Math.abs(pair.getValue() - expectedValue) + / size <= expectedError; + Assert.assertTrue("not accurate enough : \nQuantile " + pair.getKey() + + " is " + pair.getValue() + " and not " + expectedValue, + isAccurate); + } + return null; + } + } +} diff --git a/sdks/java/extensions/sorter/pom.xml b/sdks/java/extensions/sorter/pom.xml index 977b30342d06..86601c3d10a4 100644 --- a/sdks/java/extensions/sorter/pom.xml +++ b/sdks/java/extensions/sorter/pom.xml @@ -61,13 +61,19 @@ org.hamcrest - hamcrest-all + hamcrest-core test - + + + org.hamcrest + hamcrest-library + test + + org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle index bcca0cf66ba5..402ec9530067 100644 --- a/sdks/java/extensions/sql/build.gradle +++ b/sdks/java/extensions/sql/build.gradle @@ -24,6 +24,14 @@ description = "Apache Beam :: SDKs :: Java :: Extensions :: SQL" test { jvmArgs "-da" + + // charset that calcite will use for the tables in the tests + // need to setup as system property prior to running any tests + // or some of the tests may fail as calcite will init itself + // with it's default of ISO-8859-1 + systemProperty 'saffron.default.charset', 'UTF-16LE' + systemProperty 'saffron.default.nationalcharset', 'UTF-16LE' + systemProperty 'saffron.default.collation.name', 'UTF-16LE$en_US' } configurations { @@ -35,7 +43,7 @@ configurations { } -def calcite_version = "1.13.0" +def calcite_version = "1.15.0" def avatica_version = "1.10.0" dependencies { @@ -80,7 +88,7 @@ task copyFmppTemplatesFromCalciteCore(type: Copy) { } // Generate the FMPP sources from the FMPP templates. 
-def generateFmppOutputDir = "${project.buildDir}/generated-fmpp" +def generateFmppOutputDir = "${project.buildDir}/generated/fmpp" task generateFmppSources { dependsOn configurations.fmppTask dependsOn copyFmppTemplatesFromSrc @@ -91,9 +99,14 @@ task generateFmppSources { } } +// Match the output directory for generated code with the package, to be more tool-friendly +def generateFmppJavaccRoot = "${generateFmppOutputDir}/javacc" +def generatedJavaccSourceDir = "${project.buildDir}/generated/javacc" +def generatedJavaccPackageDir = "${generatedJavaccSourceDir}/org/apache/beam/sdk/extensions/sql/impl/parser/impl" compileJavacc { dependsOn generateFmppSources - inputDirectory = file(generateFmppOutputDir) + inputDirectory = file(generateFmppJavaccRoot) + outputDirectory = file(generatedJavaccPackageDir) arguments = [grammar_encoding: "UTF-8", static: "false", lookahead: "2"] } @@ -116,3 +129,14 @@ shadowJar { // module. relocate "org.codehaus", "org.apache.beam.sdks.java.extensions.sql.repackaged.org.codehaus" } + +// Help IntelliJ find the fmpp bits +idea { + module { + sourceDirs += file(generateFmppOutputDir) + generatedSourceDirs += file(generateFmppOutputDir) + + sourceDirs += file(generatedJavaccSourceDir) + generatedSourceDirs += file(generatedJavaccSourceDir) + } +} diff --git a/sdks/java/extensions/sql/pom.xml b/sdks/java/extensions/sql/pom.xml index 28c7b5f10f56..2de6f0ee380a 100644 --- a/sdks/java/extensions/sql/pom.xml +++ b/sdks/java/extensions/sql/pom.xml @@ -35,9 +35,15 @@ ${maven.build.timestamp} yyyy-MM-dd HH:mm - 1.13.0 + 1.15.0 1.10.0 1.9.5 + + + + UTF-16LE @@ -68,27 +74,11 @@ - - - - org.apache.maven.plugins - maven-checkstyle-plugin - - - ${project.basedir}/src/test/ - ${project.build.sourceDirectory} - - - - - org.apache.maven.plugins maven-compiler-plugin - 1.8 - 1.8 false @@ -219,7 +209,12 @@ org.apache.maven.plugins maven-surefire-plugin - -da + -da + + ${calcite.charset} + ${calcite.charset} + ${calcite.charset}$en_US + @@ -276,18 +271,21 @@ com.google.protobuf + org.apache.${renderedArtifactId}.repackaged.com.google.protobuf org.apache.calcite + org.apache.${renderedArtifactId}.repackaged.org.apache.calcite org.codehaus + org.apache.${renderedArtifactId}.repackaged.org.codehaus @@ -407,12 +405,17 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test org.mockito - mockito-all + mockito-core ${mockito.version} test diff --git a/sdks/java/extensions/sql/src/main/codegen/data/Parser.tdd b/sdks/java/extensions/sql/src/main/codegen/data/Parser.tdd index 09a53799042f..1afa73d255bb 100644 --- a/sdks/java/extensions/sql/src/main/codegen/data/Parser.tdd +++ b/sdks/java/extensions/sql/src/main/codegen/data/Parser.tdd @@ -36,7 +36,8 @@ # List of methods for parsing custom SQL statements. statementParserMethods: [ - "SqlCreateTable()" + "SqlCreateTable()", + "SqlDropTable()" ] # List of methods for parsing custom literals. 
diff --git a/sdks/java/extensions/sql/src/main/codegen/includes/parserImpls.ftl b/sdks/java/extensions/sql/src/main/codegen/includes/parserImpls.ftl index 136c7283845c..ce1d2ae7ce20 100644 --- a/sdks/java/extensions/sql/src/main/codegen/includes/parserImpls.ftl +++ b/sdks/java/extensions/sql/src/main/codegen/includes/parserImpls.ftl @@ -87,3 +87,20 @@ SqlNode SqlCreateTable() : location, tbl_properties, select); } } + +/** + * DROP TABLE table_name + */ +SqlNode SqlDropTable() : +{ + SqlParserPos pos; + SqlIdentifier tblName; +} +{ + { pos = getPos(); } + + tblName = SimpleIdentifier() { + return new SqlDropTable(pos, tblName); + } +} + diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamRecordSqlType.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamRecordSqlType.java deleted file mode 100644 index 9cb60c9c72a7..000000000000 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamRecordSqlType.java +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.extensions.sql; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import java.math.BigDecimal; -import java.sql.Types; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Date; -import java.util.GregorianCalendar; -import java.util.List; -import java.util.Map; -import org.apache.beam.sdk.coders.BigDecimalCoder; -import org.apache.beam.sdk.coders.BigEndianIntegerCoder; -import org.apache.beam.sdk.coders.BigEndianLongCoder; -import org.apache.beam.sdk.coders.ByteCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper.BooleanCoder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper.DateCoder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper.DoubleCoder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper.FloatCoder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper.ShortCoder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper.TimeCoder; -import org.apache.beam.sdk.values.BeamRecord; -import org.apache.beam.sdk.values.BeamRecordType; - -/** - * Type provider for {@link BeamRecord} with SQL types. - * - *

Limited SQL types are supported now, visit - * data types - * for more details. - * - */ -public class BeamRecordSqlType extends BeamRecordType { - private static final Map JAVA_CLASSES = ImmutableMap - .builder() - .put(Types.TINYINT, Byte.class) - .put(Types.SMALLINT, Short.class) - .put(Types.INTEGER, Integer.class) - .put(Types.BIGINT, Long.class) - .put(Types.FLOAT, Float.class) - .put(Types.DOUBLE, Double.class) - .put(Types.DECIMAL, BigDecimal.class) - .put(Types.BOOLEAN, Boolean.class) - .put(Types.CHAR, String.class) - .put(Types.VARCHAR, String.class) - .put(Types.TIME, GregorianCalendar.class) - .put(Types.DATE, Date.class) - .put(Types.TIMESTAMP, Date.class) - .build(); - - private static final Map CODERS = ImmutableMap - .builder() - .put(Types.TINYINT, ByteCoder.of()) - .put(Types.SMALLINT, ShortCoder.of()) - .put(Types.INTEGER, BigEndianIntegerCoder.of()) - .put(Types.BIGINT, BigEndianLongCoder.of()) - .put(Types.FLOAT, FloatCoder.of()) - .put(Types.DOUBLE, DoubleCoder.of()) - .put(Types.DECIMAL, BigDecimalCoder.of()) - .put(Types.BOOLEAN, BooleanCoder.of()) - .put(Types.CHAR, StringUtf8Coder.of()) - .put(Types.VARCHAR, StringUtf8Coder.of()) - .put(Types.TIME, TimeCoder.of()) - .put(Types.DATE, DateCoder.of()) - .put(Types.TIMESTAMP, DateCoder.of()) - .build(); - - public List fieldTypes; - - protected BeamRecordSqlType(List fieldsName, List fieldsCoder) { - super(fieldsName, fieldsCoder); - } - - private BeamRecordSqlType(List fieldsName, List fieldTypes - , List fieldsCoder) { - super(fieldsName, fieldsCoder); - this.fieldTypes = fieldTypes; - } - - public static BeamRecordSqlType create(List fieldNames, - List fieldTypes) { - if (fieldNames.size() != fieldTypes.size()) { - throw new IllegalStateException("the sizes of 'dataType' and 'fieldTypes' must match."); - } - - List fieldCoders = new ArrayList<>(fieldTypes.size()); - - for (Integer fieldType : fieldTypes) { - if (!CODERS.containsKey(fieldType)) { - throw new UnsupportedOperationException( - "Data type: " + fieldType + " not supported yet!"); - } - - fieldCoders.add(CODERS.get(fieldType)); - } - - return new BeamRecordSqlType(fieldNames, fieldTypes, fieldCoders); - } - - @Override - public void validateValueType(int index, Object fieldValue) throws IllegalArgumentException { - if (null == fieldValue) {// no need to do type check for NULL value - return; - } - - int fieldType = fieldTypes.get(index); - Class javaClazz = JAVA_CLASSES.get(fieldType); - if (javaClazz == null) { - throw new IllegalArgumentException("Data type: " + fieldType + " not supported yet!"); - } - - if (!fieldValue.getClass().equals(javaClazz)) { - throw new IllegalArgumentException( - String.format("[%s](%s) doesn't match type [%s]", - fieldValue, fieldValue.getClass(), fieldType) - ); - } - } - - public List getFieldTypes() { - return Collections.unmodifiableList(fieldTypes); - } - - public Integer getFieldTypeByIndex(int index) { - return fieldTypes.get(index); - } - - @Override - public boolean equals(Object obj) { - if (obj != null && obj instanceof BeamRecordSqlType) { - BeamRecordSqlType ins = (BeamRecordSqlType) obj; - return fieldTypes.equals(ins.getFieldTypes()) && getFieldNames().equals(ins.getFieldNames()); - } else { - return false; - } - } - - @Override - public int hashCode() { - return 31 * getFieldNames().hashCode() + getFieldTypes().hashCode(); - } - - @Override - public String toString() { - return "BeamRecordSqlType [fieldNames=" + getFieldNames() - + ", fieldTypes=" + fieldTypes + "]"; - } - - public static Builder builder() 
{ - return new Builder(); - } - - /** - * Builder class to construct {@link BeamRecordSqlType}. - */ - public static class Builder { - - private ImmutableList.Builder fieldNames; - private ImmutableList.Builder fieldTypes; - - public Builder withField(String fieldName, Integer fieldType) { - fieldNames.add(fieldName); - fieldTypes.add(fieldType); - return this; - } - - public Builder withTinyIntField(String fieldName) { - return withField(fieldName, Types.TINYINT); - } - - public Builder withSmallIntField(String fieldName) { - return withField(fieldName, Types.SMALLINT); - } - - public Builder withIntegerField(String fieldName) { - return withField(fieldName, Types.INTEGER); - } - - public Builder withBigIntField(String fieldName) { - return withField(fieldName, Types.BIGINT); - } - - public Builder withFloatField(String fieldName) { - return withField(fieldName, Types.FLOAT); - } - - public Builder withDoubleField(String fieldName) { - return withField(fieldName, Types.DOUBLE); - } - - public Builder withDecimalField(String fieldName) { - return withField(fieldName, Types.DECIMAL); - } - - public Builder withBooleanField(String fieldName) { - return withField(fieldName, Types.BOOLEAN); - } - - public Builder withCharField(String fieldName) { - return withField(fieldName, Types.CHAR); - } - - public Builder withVarcharField(String fieldName) { - return withField(fieldName, Types.VARCHAR); - } - - public Builder withTimeField(String fieldName) { - return withField(fieldName, Types.TIME); - } - - public Builder withDateField(String fieldName) { - return withField(fieldName, Types.DATE); - } - - public Builder withTimestampField(String fieldName) { - return withField(fieldName, Types.TIMESTAMP); - } - - private Builder() { - this.fieldNames = ImmutableList.builder(); - this.fieldTypes = ImmutableList.builder(); - } - - public BeamRecordSqlType build() { - return create(fieldNames.build(), fieldTypes.build()); - } - } -} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java index 78b914d179ac..a12673693f4a 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java @@ -18,22 +18,11 @@ package org.apache.beam.sdk.extensions.sql; import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.coders.BeamRecordCoder; -import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; -import org.apache.beam.sdk.extensions.sql.impl.schema.BeamPCollectionTable; -import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlSelect; -import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.tools.RelConversionException; -import org.apache.calcite.tools.ValidationException; /** * {@code BeamSql} is the DSL interface of BeamSQL. 
It translates a SQL query as a @@ -47,22 +36,22 @@ Pipeline p = Pipeline.create(options); //create table from TextIO; -PCollection inputTableA = p.apply(TextIO.read().from("/my/input/patha")) - .apply(...); -PCollection inputTableB = p.apply(TextIO.read().from("/my/input/pathb")) - .apply(...); +PCollection inputTableA = p.apply(TextIO.read().from("/my/input/patha")).apply(...); +PCollection inputTableB = p.apply(TextIO.read().from("/my/input/pathb")).apply(...); //run a simple query, and register the output as a table in BeamSql; String sql1 = "select MY_FUNC(c1), c2 from PCOLLECTION"; -PCollection outputTableA = inputTableA.apply( - BeamSql.query(sql1) - .withUdf("MY_FUNC", MY_FUNC.class, "FUNC")); +PCollection outputTableA = inputTableA.apply( + BeamSql + .query(sql1) + .registerUdf("MY_FUNC", MY_FUNC.class, "FUNC"); //run a JOIN with one table from TextIO, and one table from another query -PCollection outputTableB = PCollectionTuple.of( - new TupleTag("TABLE_O_A"), outputTableA) - .and(new TupleTag("TABLE_B"), inputTableB) - .apply(BeamSql.queryMulti("select * from TABLE_O_A JOIN TABLE_B where ...")); +PCollection outputTableB = + PCollectionTuple + .of(new TupleTag<>("TABLE_O_A"), outputTableA) + .and(new TupleTag<>("TABLE_B"), inputTableB) + .apply(BeamSql.query("select * from TABLE_O_A JOIN TABLE_B where ...")); //output the final result with TextIO outputTableB.apply(...).apply(TextIO.write().to("/my/output/path")); @@ -74,13 +63,20 @@ */ @Experimental public class BeamSql { + /** - * Transforms a SQL query into a {@link PTransform} representing an equivalent execution plan. + * Returns a {@link QueryTransform} representing an equivalent execution plan. + * + *

The {@link QueryTransform} can be applied to a {@link PCollection} + * or {@link PCollectionTuple} representing all the input tables. + * + *

The {@link PTransform} outputs a {@link PCollection} of {@link Row}. * - *

The returned {@link PTransform} can be applied to a {@link PCollectionTuple} representing - * all the input tables and results in a {@code PCollection} representing the output - * table. The {@link PCollectionTuple} contains the mapping from {@code table names} to - * {@code PCollection}, each representing an input table. + *

If the {@link PTransform} is applied to a {@link PCollection}, then it gets registered under + * the name PCOLLECTION. + * + *
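+ * For example (a minimal sketch; the field name f_int is illustrative only):
+ *
+ *  PCollection<Row> filtered = rows.apply(
+ *      BeamSql.query("SELECT f_int FROM PCOLLECTION WHERE f_int > 0"));
+ *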

If the {@link PTransform} is applied to a {@link PCollectionTuple}, then + * {@link TupleTag#getId()} is used as the corresponding {@link PCollection}'s name. * *

    *
  • If the sql query only uses a subset of tables from the upstream {@link PCollectionTuple}, @@ -91,159 +87,7 @@ public class BeamSql { * of the current query call.
  • *
*/ - public static QueryTransform queryMulti(String sqlQuery) { - return new QueryTransform(sqlQuery); - } - - /** - * Transforms a SQL query into a {@link PTransform} representing an equivalent execution plan. - * - *

This is a simplified form of {@link #queryMulti(String)} where the query must reference - * a single input table. - * - *

Make sure to query it from a static table name PCOLLECTION. - */ - public static SimpleQueryTransform query(String sqlQuery) { - return new SimpleQueryTransform(sqlQuery); - } - - /** - * A {@link PTransform} representing an execution plan for a SQL query. - * - *

The table names in the input {@code PCollectionTuple} are only valid during the current - * query. - */ - public static class QueryTransform extends - PTransform> { - private BeamSqlEnv beamSqlEnv = new BeamSqlEnv(); - private String sqlQuery; - - public QueryTransform(String sqlQuery) { - this.sqlQuery = sqlQuery; - } - - /** - * register a UDF function used in this query. - * - *

Refer to {@link BeamSqlUdf} for more about how to implement a UDF in BeamSql. - */ - public QueryTransform withUdf(String functionName, Class clazz){ - beamSqlEnv.registerUdf(functionName, clazz); - return this; - } - /** - * register {@link SerializableFunction} as a UDF function used in this query. - * Note, {@link SerializableFunction} must have a constructor without arguments. - */ - public QueryTransform withUdf(String functionName, SerializableFunction sfn){ - beamSqlEnv.registerUdf(functionName, sfn); - return this; - } - - /** - * register a {@link CombineFn} as UDAF function used in this query. - */ - public QueryTransform withUdaf(String functionName, CombineFn combineFn){ - beamSqlEnv.registerUdaf(functionName, combineFn); - return this; - } - - @Override - public PCollection expand(PCollectionTuple input) { - registerTables(input); - - BeamRelNode beamRelNode = null; - try { - beamRelNode = beamSqlEnv.getPlanner().convertToBeamRel(sqlQuery); - } catch (ValidationException | RelConversionException | SqlParseException e) { - throw new IllegalStateException(e); - } - - try { - return beamRelNode.buildBeamPipeline(input, beamSqlEnv); - } catch (Exception e) { - throw new IllegalStateException(e); - } - } - - //register tables, related with input PCollections. - private void registerTables(PCollectionTuple input){ - for (TupleTag sourceTag : input.getAll().keySet()) { - PCollection sourceStream = (PCollection) input.get(sourceTag); - BeamRecordCoder sourceCoder = (BeamRecordCoder) sourceStream.getCoder(); - - beamSqlEnv.registerTable(sourceTag.getId(), - new BeamPCollectionTable(sourceStream, - (BeamRecordSqlType) sourceCoder.getRecordType())); - } - } - } - - /** - * A {@link PTransform} representing an execution plan for a SQL query referencing - * a single table. - */ - public static class SimpleQueryTransform - extends PTransform, PCollection> { - private static final String PCOLLECTION_TABLE_NAME = "PCOLLECTION"; - private QueryTransform delegate; - - public SimpleQueryTransform(String sqlQuery) { - this.delegate = new QueryTransform(sqlQuery); - } - - /** - * register a UDF function used in this query. - * - *

Refer to {@link BeamSqlUdf} for more about how to implement a UDAF in BeamSql. - */ - public SimpleQueryTransform withUdf(String functionName, Class clazz){ - delegate.withUdf(functionName, clazz); - return this; - } - - /** - * register {@link SerializableFunction} as a UDF function used in this query. - * Note, {@link SerializableFunction} must have a constructor without arguments. - */ - public SimpleQueryTransform withUdf(String functionName, SerializableFunction sfn){ - delegate.withUdf(functionName, sfn); - return this; - } - - /** - * register a {@link CombineFn} as UDAF function used in this query. - */ - public SimpleQueryTransform withUdaf(String functionName, CombineFn combineFn){ - delegate.withUdaf(functionName, combineFn); - return this; - } - - private void validateQuery() { - SqlNode sqlNode; - try { - sqlNode = delegate.beamSqlEnv.getPlanner().parseQuery(delegate.sqlQuery); - delegate.beamSqlEnv.getPlanner().getPlanner().close(); - } catch (SqlParseException e) { - throw new IllegalStateException(e); - } - - if (sqlNode instanceof SqlSelect) { - SqlSelect select = (SqlSelect) sqlNode; - String tableName = select.getFrom().toString(); - if (!tableName.equalsIgnoreCase(PCOLLECTION_TABLE_NAME)) { - throw new IllegalStateException("Use fixed table name " + PCOLLECTION_TABLE_NAME); - } - } else { - throw new UnsupportedOperationException( - "Sql operation: " + sqlNode.toString() + " is not supported!"); - } - } - - @Override - public PCollection expand(PCollection input) { - validateQuery(); - return PCollectionTuple.of(new TupleTag<>(PCOLLECTION_TABLE_NAME), input).apply(delegate); - } + public static QueryTransform query(String sqlQuery) { + return QueryTransform.withQueryString(sqlQuery); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlCli.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlCli.java index 2cadb0eca984..eadda35fd5f6 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlCli.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlCli.java @@ -24,13 +24,14 @@ import org.apache.beam.sdk.extensions.sql.impl.parser.BeamSqlParser; import org.apache.beam.sdk.extensions.sql.impl.parser.ParserUtils; import org.apache.beam.sdk.extensions.sql.impl.parser.SqlCreateTable; +import org.apache.beam.sdk.extensions.sql.impl.parser.SqlDropTable; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; import org.apache.beam.sdk.extensions.sql.meta.Table; import org.apache.beam.sdk.extensions.sql.meta.store.MetaStore; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.sql.SqlNode; @@ -80,6 +81,8 @@ public void execute(String sqlString) throws Exception { if (sqlNode instanceof SqlCreateTable) { handleCreateTable((SqlCreateTable) sqlNode, metaStore); + } else if (sqlNode instanceof SqlDropTable) { + handleDropTable((SqlDropTable) sqlNode); } else { PipelineOptions options = PipelineOptionsFactory.fromArgs(new String[] {}).withValidation() .as(PipelineOptions.class); @@ -103,12 +106,17 @@ private void handleCreateTable(SqlCreateTable stmt, MetaStore store) { env.registerTable(table.getName(), metaStore.buildBeamSqlTable(table.getName())); } + private void 
handleDropTable(SqlDropTable stmt) { + metaStore.dropTable(stmt.tableName()); + env.deregisterTable(stmt.tableName()); + } + /** * compile SQL, and return a {@link Pipeline}. */ - private static PCollection compilePipeline(String sqlStatement, Pipeline basePipeline, - BeamSqlEnv sqlEnv) throws Exception { - PCollection resultStream = + private static PCollection compilePipeline(String sqlStatement, Pipeline basePipeline, + BeamSqlEnv sqlEnv) throws Exception { + PCollection resultStream = sqlEnv.getPlanner().compileBeamPipeline(sqlStatement, basePipeline, sqlEnv); return resultStream; } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlSeekableTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlSeekableTable.java index dbfe119ccc5d..d274dd98dd9c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlSeekableTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlSeekableTable.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.List; import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** * A seekable table converts a JOIN operator to an inline lookup. @@ -29,7 +29,7 @@ @Experimental public interface BeamSqlSeekableTable extends Serializable{ /** - * return a list of {@code BeamRecord} with given key set. + * return a list of {@code Row} with given key set. */ - List seekRecord(BeamRecord lookupSubRecord); + List seekRow(Row lookupSubRow); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlTable.java index df1e1627f526..efcff21d16f3 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlTable.java @@ -21,9 +21,10 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.sql.impl.schema.BeamIOType; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** * This interface defines a Beam Sql Table. @@ -39,16 +40,16 @@ public interface BeamSqlTable { * create a {@code PCollection} from source. * */ - PCollection buildIOReader(Pipeline pipeline); + PCollection buildIOReader(Pipeline pipeline); /** * create a {@code IO.write()} instance to write to target. * */ - PTransform, PDone> buildIOWriter(); + PTransform, PDone> buildIOWriter(); /** * Get the schema info of the table. */ - BeamRecordSqlType getRowType(); + RowType getRowType(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/QueryTransform.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/QueryTransform.java new file mode 100644 index 000000000000..8a7335f7d37f --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/QueryTransform.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
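The hunks above migrate BeamSqlSeekableTable and BeamSqlTable from BeamRecord to Row. As a rough, hypothetical sketch of the new seekRow contract (not part of this diff; the in-memory map and its first-field key handling are invented purely for illustration):

    import java.util.Collections;
    import java.util.List;
    import java.util.Map;
    import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
    import org.apache.beam.sdk.values.Row;

    /** Hypothetical in-memory lookup table using the Row-based signature above. */
    public class InMemorySeekableTable implements BeamSqlSeekableTable {
      private final Map<Object, List<Row>> rowsByKey;

      public InMemorySeekableTable(Map<Object, List<Row>> rowsByKey) {
        this.rowsByKey = rowsByKey;
      }

      @Override
      public List<Row> seekRow(Row lookupSubRow) {
        // Keyed on the first field of the lookup sub-row purely for illustration;
        // a real table applies whatever key semantics the JOIN requires.
        return rowsByKey.getOrDefault(lookupSubRow.getValue(0), Collections.emptyList());
      }
    }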
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.extensions.sql; + + +import static org.apache.beam.sdk.extensions.sql.QueryValidationHelper.validateQuery; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; + +/** + * A {@link PTransform} representing an execution plan for a SQL query. + * + *
The table names in the input {@code PCollectionTuple} are only valid during the current + * query. + */ +@AutoValue +public abstract class QueryTransform extends PTransform> { + static final String PCOLLECTION_NAME = "PCOLLECTION"; + + abstract String queryString(); + abstract List udfDefinitions(); + abstract List udafDefinitions(); + + @Override + public PCollection expand(PInput input) { + PCollectionTuple inputTuple = toPCollectionTuple(input); + + BeamSqlEnv sqlEnv = new BeamSqlEnv(); + + if (input instanceof PCollection) { + validateQuery(sqlEnv, queryString()); + } + + sqlEnv.registerPCollectionTuple(inputTuple); + registerFunctions(sqlEnv); + + try { + return + sqlEnv + .getPlanner() + .convertToBeamRel(queryString()) + .buildBeamPipeline(inputTuple, sqlEnv); + } catch (Exception e) { + throw new IllegalStateException(e); + } + } + + private PCollectionTuple toPCollectionTuple(PInput inputs) { + return (inputs instanceof PCollection) + ? PCollectionTuple.of(new TupleTag<>(PCOLLECTION_NAME), (PCollection) inputs) + : tupleOfAllInputs(inputs.getPipeline(), inputs.expand()); + } + + private PCollectionTuple tupleOfAllInputs( + Pipeline pipeline, + Map, PValue> taggedInputs) { + + PCollectionTuple tuple = PCollectionTuple.empty(pipeline); + + for (Map.Entry, PValue> input : taggedInputs.entrySet()) { + tuple = tuple.and( + new TupleTag<>(input.getKey().getId()), + (PCollection) input.getValue()); + } + + return tuple; + } + + private void registerFunctions(BeamSqlEnv sqlEnv) { + udfDefinitions() + .forEach(udf -> sqlEnv.registerUdf(udf.udfName(), udf.clazz(), udf.methodName())); + + udafDefinitions() + .forEach(udaf -> sqlEnv.registerUdaf(udaf.udafName(), udaf.combineFn())); + } + + /** + * Creates a {@link QueryTransform} with SQL {@code queryString}. + */ + public static QueryTransform withQueryString(String queryString) { + return + builder() + .setQueryString(queryString) + .setUdafDefinitions(Collections.emptyList()) + .setUdfDefinitions(Collections.emptyList()) + .build(); + } + + /** + * register a UDF function used in this query. + * + *
Refer to {@link BeamSqlUdf} for more about how to implement a UDF in BeamSql. + */ + public QueryTransform registerUdf(String functionName, Class clazz) { + return registerUdf(functionName, clazz, BeamSqlUdf.UDF_METHOD); + } + + /** + * Register {@link SerializableFunction} as a UDF function used in this query. + * Note, {@link SerializableFunction} must have a constructor without arguments. + */ + public QueryTransform registerUdf(String functionName, SerializableFunction sfn) { + return registerUdf(functionName, sfn.getClass(), "apply"); + } + + private QueryTransform registerUdf(String functionName, Class clazz, String method) { + ImmutableList newUdfDefinitions = + ImmutableList + .builder() + .addAll(udfDefinitions()) + .add(UdfDefinition.of(functionName, clazz, method)) + .build(); + + return toBuilder().setUdfDefinitions(newUdfDefinitions).build(); + } + + /** + * register a {@link Combine.CombineFn} as UDAF function used in this query. + */ + public QueryTransform registerUdaf(String functionName, Combine.CombineFn combineFn) { + ImmutableList newUdafs = + ImmutableList + .builder() + .addAll(udafDefinitions()) + .add(UdafDefinition.of(functionName, combineFn)) + .build(); + + return toBuilder().setUdafDefinitions(newUdafs).build(); + } + + abstract Builder toBuilder(); + + static Builder builder() { + return new AutoValue_QueryTransform.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setQueryString(String queryString); + abstract Builder setUdfDefinitions(List udfDefinitions); + abstract Builder setUdafDefinitions(List udafDefinitions); + + abstract QueryTransform build(); + } + + @AutoValue + abstract static class UdfDefinition { + abstract String udfName(); + abstract Class clazz(); + abstract String methodName(); + + static UdfDefinition of(String udfName, Class clazz, String methodName) { + return new AutoValue_QueryTransform_UdfDefinition(udfName, clazz, methodName); + } + } + + @AutoValue + abstract static class UdafDefinition { + abstract String udafName(); + abstract Combine.CombineFn combineFn(); + + static UdafDefinition of(String udafName, Combine.CombineFn combineFn) { + return new AutoValue_QueryTransform_UdafDefinition(udafName, combineFn); + } + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/QueryValidationHelper.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/QueryValidationHelper.java new file mode 100644 index 000000000000..ed893daa4127 --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/QueryValidationHelper.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
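For orientation, here is a hedged usage sketch of the QueryTransform API added above, combining the new BeamSql.query(...) entry point with the fluent registerUdf/registerUdaf calls. The inputs rows and orders are assumed to be existing PCollection<Row>s, and CubicInteger (a BeamSqlUdf implementation) and SquareSum (a Combine.CombineFn) are hypothetical user functions; only the BeamSql and QueryTransform calls themselves come from this diff.

    // Single input: the PCollection<Row> is exposed under the fixed table name PCOLLECTION.
    PCollection<Row> scored =
        rows.apply(
            "cubicScore",
            BeamSql.query("SELECT c1, cubic(c1) AS c1_cubed FROM PCOLLECTION")
                .registerUdf("cubic", CubicInteger.class));

    // Multiple inputs: name each table explicitly via a PCollectionTuple and register a UDAF.
    PCollection<Row> totals =
        PCollectionTuple.of(new TupleTag<>("ORDERS"), orders)
            .apply(
                "squareSumPerCustomer",
                BeamSql.query("SELECT customer, squaresum(amount) FROM ORDERS GROUP BY customer")
                    .registerUdaf("squaresum", new SquareSum()));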
+ */ + +package org.apache.beam.sdk.extensions.sql; + +import static org.apache.beam.sdk.extensions.sql.QueryTransform.PCOLLECTION_NAME; + +import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.parser.SqlParseException; + +/** + * QueryValidationHelper. + */ +class QueryValidationHelper { + + static void validateQuery(BeamSqlEnv sqlEnv, String queryString) { + SqlNode sqlNode; + + try { + sqlNode = sqlEnv.getPlanner().parseQuery(queryString); + sqlEnv.getPlanner().getPlanner().close(); + } catch (SqlParseException e) { + throw new IllegalStateException(e); + } + + if (!(sqlNode instanceof SqlSelect)) { + throw new UnsupportedOperationException( + "Sql operation " + sqlNode.toString() + " is not supported"); + } + + if (!PCOLLECTION_NAME.equalsIgnoreCase(((SqlSelect) sqlNode).getFrom().toString())) { + throw new IllegalStateException("Use " + PCOLLECTION_NAME + " as table name" + + " when selecting from single PCollection." + + " Use PCollectionTuple to explicitly " + + "name the input PCollections"); + } + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlRecordHelper.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/RowHelper.java similarity index 81% rename from sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlRecordHelper.java rename to sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/RowHelper.java index 870165d70fa2..d4be511bf97e 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlRecordHelper.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/RowHelper.java @@ -26,27 +26,23 @@ import java.util.Date; import java.util.GregorianCalendar; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.BigDecimalCoder; import org.apache.beam.sdk.coders.BigEndianLongCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.CustomCoder; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** - * A {@link Coder} encodes {@link BeamRecord}. + * Atomic {@link Coder}s for {@link Row} fields for SQL types. */ @Experimental -public class BeamSqlRecordHelper { - - public static BeamRecordSqlType getSqlRecordType(BeamRecord record) { - return (BeamRecordSqlType) record.getDataType(); - } +public class RowHelper { /** * {@link Coder} for Java type {@link Short}. */ - public static class ShortCoder extends CustomCoder { + public static class ShortCoder extends AtomicCoder { private static final ShortCoder INSTANCE = new ShortCoder(); public static ShortCoder of() { @@ -65,15 +61,12 @@ public void encode(Short value, OutputStream outStream) throws CoderException, I public Short decode(InputStream inStream) throws CoderException, IOException { return new DataInputStream(inStream).readShort(); } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - } } + /** * {@link Coder} for Java type {@link Float}, it's stored as {@link BigDecimal}. 
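QueryValidationHelper above only applies to single-PCollection inputs. A short, hedged sketch of what it accepts and rejects (the query strings and the input variable are illustrative):

    // Accepted: a single PCollection<Row> input must be referenced as PCOLLECTION.
    input.apply(BeamSql.query("SELECT c1, c2 FROM PCOLLECTION WHERE c1 > 1"));

    // Rejected with IllegalStateException: any other table name; wrap the inputs in a
    // PCollectionTuple instead so the table names are explicit.
    // input.apply(BeamSql.query("SELECT c1, c2 FROM ORDERS WHERE c1 > 1"));

    // Rejected with UnsupportedOperationException: anything that does not parse to a SELECT.
    // input.apply(BeamSql.query("INSERT INTO PCOLLECTION VALUES (1, 2)"));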
*/ - public static class FloatCoder extends CustomCoder { + public static class FloatCoder extends AtomicCoder { private static final FloatCoder INSTANCE = new FloatCoder(); private static final BigDecimalCoder CODER = BigDecimalCoder.of(); @@ -93,15 +86,12 @@ public void encode(Float value, OutputStream outStream) throws CoderException, I public Float decode(InputStream inStream) throws CoderException, IOException { return CODER.decode(inStream).floatValue(); } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - } } + /** * {@link Coder} for Java type {@link Double}, it's stored as {@link BigDecimal}. */ - public static class DoubleCoder extends CustomCoder { + public static class DoubleCoder extends AtomicCoder { private static final DoubleCoder INSTANCE = new DoubleCoder(); private static final BigDecimalCoder CODER = BigDecimalCoder.of(); @@ -121,16 +111,12 @@ public void encode(Double value, OutputStream outStream) throws CoderException, public Double decode(InputStream inStream) throws CoderException, IOException { return CODER.decode(inStream).doubleValue(); } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - } } /** * {@link Coder} for Java type {@link GregorianCalendar}, it's stored as {@link Long}. */ - public static class TimeCoder extends CustomCoder { + public static class TimeCoder extends AtomicCoder { private static final BigEndianLongCoder longCoder = BigEndianLongCoder.of(); private static final TimeCoder INSTANCE = new TimeCoder(); @@ -153,15 +139,12 @@ public GregorianCalendar decode(InputStream inStream) throws CoderException, IOE calendar.setTime(new Date(longCoder.decode(inStream))); return calendar; } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - } } + /** * {@link Coder} for Java type {@link Date}, it's stored as {@link Long}. */ - public static class DateCoder extends CustomCoder { + public static class DateCoder extends AtomicCoder { private static final BigEndianLongCoder longCoder = BigEndianLongCoder.of(); private static final DateCoder INSTANCE = new DateCoder(); @@ -181,16 +164,12 @@ public void encode(Date value, OutputStream outStream) throws CoderException, IO public Date decode(InputStream inStream) throws CoderException, IOException { return new Date(longCoder.decode(inStream)); } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - } } /** * {@link Coder} for Java type {@link Boolean}. */ - public static class BooleanCoder extends CustomCoder { + public static class BooleanCoder extends AtomicCoder { private static final BooleanCoder INSTANCE = new BooleanCoder(); public static BooleanCoder of() { @@ -209,9 +188,5 @@ public void encode(Boolean value, OutputStream outStream) throws CoderException, public Boolean decode(InputStream inStream) throws CoderException, IOException { return new DataInputStream(inStream).readBoolean(); } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - } } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/RowSqlType.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/RowSqlType.java new file mode 100644 index 000000000000..1208c42f6acf --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/RowSqlType.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql; + +import static org.apache.beam.sdk.values.RowType.toRowType; + +import com.google.common.collect.ImmutableList; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; + + +/** + * Type builder for {@link Row} with SQL types. + * + *
Limited SQL types are supported now, visit + * data types + * for more details. + * + *
SQL types are represented by instances of {@link SqlTypeCoder}, see {@link SqlTypeCoders}. + */ +public class RowSqlType { + public static Builder builder() { + return new Builder(); + } + + /** + * Builder class to construct {@link RowType}. + */ + public static class Builder { + + private ImmutableList.Builder fields; + + public Builder withField(String fieldName, SqlTypeCoder fieldCoder) { + fields.add(RowType.newField(fieldName, fieldCoder)); + return this; + } + + public Builder withTinyIntField(String fieldName) { + return withField(fieldName, SqlTypeCoders.TINYINT); + } + + public Builder withSmallIntField(String fieldName) { + return withField(fieldName, SqlTypeCoders.SMALLINT); + } + + public Builder withIntegerField(String fieldName) { + return withField(fieldName, SqlTypeCoders.INTEGER); + } + + public Builder withBigIntField(String fieldName) { + return withField(fieldName, SqlTypeCoders.BIGINT); + } + + public Builder withFloatField(String fieldName) { + return withField(fieldName, SqlTypeCoders.FLOAT); + } + + public Builder withDoubleField(String fieldName) { + return withField(fieldName, SqlTypeCoders.DOUBLE); + } + + public Builder withDecimalField(String fieldName) { + return withField(fieldName, SqlTypeCoders.DECIMAL); + } + + public Builder withBooleanField(String fieldName) { + return withField(fieldName, SqlTypeCoders.BOOLEAN); + } + + public Builder withCharField(String fieldName) { + return withField(fieldName, SqlTypeCoders.CHAR); + } + + public Builder withVarcharField(String fieldName) { + return withField(fieldName, SqlTypeCoders.VARCHAR); + } + + public Builder withTimeField(String fieldName) { + return withField(fieldName, SqlTypeCoders.TIME); + } + + public Builder withDateField(String fieldName) { + return withField(fieldName, SqlTypeCoders.DATE); + } + + public Builder withTimestampField(String fieldName) { + return withField(fieldName, SqlTypeCoders.TIMESTAMP); + } + + private Builder() { + this.fields = ImmutableList.builder(); + } + + public RowType build() { + return fields.build().stream().collect(toRowType()); + } + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlRowTypeFactory.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlRowTypeFactory.java new file mode 100644 index 000000000000..d4d93087efd1 --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlRowTypeFactory.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
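A small sketch of the RowSqlType builder introduced above, paired with the Row construction and coder lookup used later in this diff's BeamSqlExample; the field names and values are invented:

    RowType orderType =
        RowSqlType.builder()
            .withIntegerField("orderId")
            .withVarcharField("customer")
            .withDoubleField("amount")
            .build();

    Row order =
        Row.withRowType(orderType)
            .addValues(1, "alice", 12.5)
            .build();

    // The row type also supplies the coder when turning rows into a PCollection:
    // Create.of(order).withCoder(orderType.getRowCoder())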
+ */ + +package org.apache.beam.sdk.extensions.sql; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.math.BigDecimal; +import java.util.Date; +import java.util.GregorianCalendar; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.values.RowType; +import org.apache.beam.sdk.values.reflect.FieldValueGetter; +import org.apache.beam.sdk.values.reflect.RowTypeFactory; + +/** + * For internal use only; no backwards-compatibility guarantees. + * + *
Implementation of the {@link RowTypeFactory} to return instances + * of {@link RowType} with coders specific for SQL types, see {@link SqlTypeCoders}. + */ +@Internal +public class SqlRowTypeFactory implements RowTypeFactory { + + static final ImmutableMap SQL_CODERS = ImmutableMap + .builder() + .put(Byte.class, SqlTypeCoders.TINYINT) + .put(Short.class, SqlTypeCoders.SMALLINT) + .put(Integer.class, SqlTypeCoders.INTEGER) + .put(Long.class, SqlTypeCoders.BIGINT) + .put(Float.class, SqlTypeCoders.FLOAT) + .put(Double.class, SqlTypeCoders.DOUBLE) + .put(BigDecimal.class, SqlTypeCoders.DECIMAL) + .put(Boolean.class, SqlTypeCoders.BOOLEAN) + .put(String.class, SqlTypeCoders.VARCHAR) + .put(GregorianCalendar.class, SqlTypeCoders.TIME) + .put(Date.class, SqlTypeCoders.TIMESTAMP) + .build(); + + @Override + public RowType createRowType(Iterable getters) { + return + RowType + .fromNamesAndCoders( + fieldNames(getters), + sqlCoders(getters)); + } + + private List fieldNames(Iterable getters) { + ImmutableList.Builder names = ImmutableList.builder(); + + for (FieldValueGetter fieldValueGetter : getters) { + names.add(fieldValueGetter.name()); + } + + return names.build(); + } + + private List sqlCoders(Iterable getters) { + ImmutableList.Builder sqlCoders = ImmutableList.builder(); + + for (FieldValueGetter fieldValueGetter : getters) { + if (!SQL_CODERS.containsKey(fieldValueGetter.type())) { + throw new UnsupportedOperationException( + "Field type " + fieldValueGetter.type().getSimpleName() + " is not supported yet"); + } + + sqlCoders.add(SQL_CODERS.get(fieldValueGetter.type())); + } + + return sqlCoders.build(); + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlTypeCoder.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlTypeCoder.java new file mode 100644 index 000000000000..8c311f548282 --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlTypeCoder.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.extensions.sql; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; + +/** + * Base class for coders for supported SQL types. 
+ */ +public abstract class SqlTypeCoder extends CustomCoder { + + @Override + public void encode(Object value, OutputStream outStream) throws CoderException, IOException { + delegateCoder().encode(value, outStream); + } + + @Override + public Object decode(InputStream inStream) throws CoderException, IOException { + return delegateCoder().decode(inStream); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + delegateCoder().verifyDeterministic(); + } + + protected abstract Coder delegateCoder(); + + @Override + public boolean equals(Object other) { + return other != null && this.getClass().equals(other.getClass()); + + } + + @Override + public int hashCode() { + return this.getClass().hashCode(); + } + + static class SqlTinyIntCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return ByteCoder.of(); + } + } + + static class SqlSmallIntCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.ShortCoder.of(); + } + } + + static class SqlIntegerCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return BigEndianIntegerCoder.of(); + } + } + + static class SqlBigIntCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return BigEndianLongCoder.of(); + } + } + + static class SqlFloatCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.FloatCoder.of(); + } + } + + static class SqlDoubleCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.DoubleCoder.of(); + } + } + + static class SqlDecimalCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return BigDecimalCoder.of(); + } + } + + static class SqlBooleanCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.BooleanCoder.of(); + } + } + + static class SqlCharCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return StringUtf8Coder.of(); + } + } + + static class SqlVarCharCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return StringUtf8Coder.of(); + } + } + + static class SqlTimeCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.TimeCoder.of(); + } + } + + static class SqlDateCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.DateCoder.of(); + } + } + + static class SqlTimestampCoder extends SqlTypeCoder { + @Override + protected Coder delegateCoder() { + return RowHelper.DateCoder.of(); + } + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlTypeCoders.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlTypeCoders.java new file mode 100644 index 000000000000..b18d42a5b93e --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/SqlTypeCoders.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
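Each SqlTypeCoder above simply forwards to an existing Beam coder. A quick round-trip sketch of that delegation, using the SqlTypeCoders.INTEGER constant defined in the next file (the value and the byte-array streams are invented; encode/decode declare IOException/CoderException, so this would sit in a method that handles them):

    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    SqlTypeCoders.INTEGER.encode(42, bytes);  // delegates to BigEndianIntegerCoder
    Object decoded =
        SqlTypeCoders.INTEGER.decode(new ByteArrayInputStream(bytes.toByteArray()));
    // decoded is the Integer 42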
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.extensions.sql; + +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlBigIntCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlBooleanCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlCharCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlDateCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlDecimalCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlDoubleCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlFloatCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlSmallIntCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlTimeCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlTimestampCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlTinyIntCoder; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlVarCharCoder; + +import com.google.common.collect.ImmutableSet; +import java.util.Set; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder.SqlIntegerCoder; + +/** + * Coders for SQL types supported in Beam. + * + *
Currently SQL coders are subclasses of {@link SqlTypeCoder}. + */ +public class SqlTypeCoders { + public static final SqlTypeCoder TINYINT = new SqlTinyIntCoder(); + public static final SqlTypeCoder SMALLINT = new SqlSmallIntCoder(); + public static final SqlTypeCoder INTEGER = new SqlIntegerCoder(); + public static final SqlTypeCoder BIGINT = new SqlBigIntCoder(); + public static final SqlTypeCoder FLOAT = new SqlFloatCoder(); + public static final SqlTypeCoder DOUBLE = new SqlDoubleCoder(); + public static final SqlTypeCoder DECIMAL = new SqlDecimalCoder(); + public static final SqlTypeCoder BOOLEAN = new SqlBooleanCoder(); + public static final SqlTypeCoder CHAR = new SqlCharCoder(); + public static final SqlTypeCoder VARCHAR = new SqlVarCharCoder(); + public static final SqlTypeCoder TIME = new SqlTimeCoder(); + public static final SqlTypeCoder DATE = new SqlDateCoder(); + public static final SqlTypeCoder TIMESTAMP = new SqlTimestampCoder(); + + public static final Set NUMERIC_TYPES = + ImmutableSet.of( + SqlTypeCoders.TINYINT, + SqlTypeCoders.SMALLINT, + SqlTypeCoders.INTEGER, + SqlTypeCoders.BIGINT, + SqlTypeCoders.FLOAT, + SqlTypeCoders.DOUBLE, + SqlTypeCoders.DECIMAL); +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlExample.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlExample.java index 8c6ad98d2375..078ed7e2526f 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlExample.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlExample.java @@ -17,22 +17,20 @@ */ package org.apache.beam.sdk.extensions.sql.example; -import java.sql.Types; -import java.util.Arrays; -import java.util.List; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSql; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.beam.sdk.values.TupleTag; /** @@ -51,50 +49,56 @@ public static void main(String[] args) throws Exception { Pipeline p = Pipeline.create(options); //define the input row format - List fieldNames = Arrays.asList("c1", "c2", "c3"); - List fieldTypes = Arrays.asList(Types.INTEGER, Types.VARCHAR, Types.DOUBLE); - BeamRecordSqlType type = BeamRecordSqlType.create(fieldNames, fieldTypes); - BeamRecord row1 = new BeamRecord(type, 1, "row", 1.0); - BeamRecord row2 = new BeamRecord(type, 2, "row", 2.0); - BeamRecord row3 = new BeamRecord(type, 3, "row", 3.0); + RowType type = RowSqlType + .builder() + .withIntegerField("c1") + .withVarcharField("c2") + .withDoubleField("c3") + .build(); + + Row row1 = Row.withRowType(type).addValues(1, "row", 1.0).build(); + Row row2 = Row.withRowType(type).addValues(2, "row", 2.0).build(); + Row row3 = Row.withRowType(type).addValues(3, "row", 3.0).build(); //create a source PCollection with 
Create.of(); - PCollection inputTable = PBegin.in(p).apply(Create.of(row1, row2, row3) - .withCoder(type.getRecordCoder())); + PCollection inputTable = PBegin.in(p).apply(Create.of(row1, row2, row3) + .withCoder(type.getRowCoder())); //Case 1. run a simple SQL query over input PCollection with BeamSql.simpleQuery; - PCollection outputStream = inputTable.apply( + PCollection outputStream = inputTable.apply( BeamSql.query("select c1, c2, c3 from PCOLLECTION where c1 > 1")); // print the output record of case 1; outputStream.apply( "log_result", MapElements.via( - new SimpleFunction() { - public @Nullable Void apply(BeamRecord input) { + new SimpleFunction() { + public @Nullable + Void apply(Row input) { // expect output: // PCOLLECTION: [3, row, 3.0] // PCOLLECTION: [2, row, 2.0] - System.out.println("PCOLLECTION: " + input.getDataValues()); + System.out.println("PCOLLECTION: " + input.getValues()); return null; } })); // Case 2. run the query with BeamSql.query over result PCollection of case 1. - PCollection outputStream2 = + PCollection outputStream2 = PCollectionTuple.of(new TupleTag<>("CASE1_RESULT"), outputStream) - .apply(BeamSql.queryMulti("select c2, sum(c3) from CASE1_RESULT group by c2")); + .apply(BeamSql.query("select c2, sum(c3) from CASE1_RESULT group by c2")); // print the output record of case 2; outputStream2.apply( "log_result", MapElements.via( - new SimpleFunction() { + new SimpleFunction() { @Override - public @Nullable Void apply(BeamRecord input) { + public @Nullable + Void apply(Row input) { // expect output: // CASE1_RESULT: [row, 5.0] - System.out.println("CASE1_RESULT: " + input.getDataValues()); + System.out.println("CASE1_RESULT: " + input.getValues()); return null; } })); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/BeamSqlEnv.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/BeamSqlEnv.java index 405bedffe3f8..87548bdc38f3 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/BeamSqlEnv.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/BeamSqlEnv.java @@ -17,8 +17,12 @@ */ package org.apache.beam.sdk.extensions.sql.impl; +import java.io.IOException; +import java.io.ObjectInputStream; import java.io.Serializable; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.BeamSql; import org.apache.beam.sdk.extensions.sql.BeamSqlCli; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; @@ -26,10 +30,17 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner; import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable; +import org.apache.beam.sdk.extensions.sql.impl.schema.BeamPCollectionTable; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.apache.beam.sdk.values.TupleTag; import org.apache.calcite.DataContext; +import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.linq4j.Enumerable; import 
org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -39,6 +50,8 @@ import org.apache.calcite.schema.Statistic; import org.apache.calcite.schema.Statistics; import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.tools.Frameworks; /** @@ -48,20 +61,29 @@ *
It contains a {@link SchemaPlus} which holds the metadata of tables/UDF functions, * and a {@link BeamQueryPlanner} which parse/validate/optimize/translate input SQL queries. */ -public class BeamSqlEnv implements Serializable{ +public class BeamSqlEnv implements Serializable { transient SchemaPlus schema; transient BeamQueryPlanner planner; + transient Map tables; public BeamSqlEnv() { + tables = new HashMap<>(16); schema = Frameworks.createRootSchema(true); planner = new BeamQueryPlanner(schema); } + /** + * Register a UDF function which can be used in SQL expression. + */ + public void registerUdf(String functionName, Class clazz, String method) { + schema.add(functionName, ScalarFunctionImpl.create(clazz, method)); + } + /** * Register a UDF function which can be used in SQL expression. */ public void registerUdf(String functionName, Class clazz) { - schema.add(functionName, ScalarFunctionImpl.create(clazz, BeamSqlUdf.UDF_METHOD)); + registerUdf(functionName, clazz, BeamSqlUdf.UDF_METHOD); } /** @@ -69,7 +91,7 @@ public void registerUdf(String functionName, Class clazz) * Note, {@link SerializableFunction} must have a constructor without arguments. */ public void registerUdf(String functionName, SerializableFunction sfn) { - schema.add(functionName, ScalarFunctionImpl.create(sfn.getClass(), "apply")); + registerUdf(functionName, sfn.getClass(), "apply"); } /** @@ -81,14 +103,57 @@ public void registerUdaf(String functionName, Combine.CombineFn combineFn) { } /** - * Registers a {@link BaseBeamTable} which can be used for all subsequent queries. + * Registers {@link PCollection}s in {@link PCollectionTuple} as a tables. + * + *
Assumes that {@link PCollection} elements are {@link Row}s. * + *
{@link TupleTag#getId()}s are used as table names. + */ + public void registerPCollectionTuple(PCollectionTuple pCollectionTuple) { + pCollectionTuple + .getAll() + .forEach((tag, pCollection) -> + registerPCollection(tag.getId(), (PCollection) pCollection)); + } + + /** + * Registers {@link PCollection} of {@link Row}s as a table. + * + *
Assumes that {@link PCollection#getCoder()} returns an instance of {@link RowCoder}. + */ + public void registerPCollection(String name, PCollection pCollection) { + registerTable(name, pCollection, ((RowCoder) pCollection.getCoder()).getRowType()); + } + + /** + * Registers {@link PCollection} as a table. + */ + public void registerTable(String tableName, PCollection pCollection, RowType rowType) { + registerTable(tableName, new BeamPCollectionTable(pCollection, rowType)); + } + + /** + * Registers a {@link BaseBeamTable} which can be used for all subsequent queries. */ public void registerTable(String tableName, BeamSqlTable table) { + tables.put(tableName, table); schema.add(tableName, new BeamCalciteTable(table.getRowType())); planner.getSourceTables().put(tableName, table); } + public void deregisterTable(String targetTableName) { + // reconstruct the schema + schema = Frameworks.createRootSchema(true); + for (Map.Entry entry : tables.entrySet()) { + String tableName = entry.getKey(); + BeamSqlTable table = entry.getValue(); + if (!tableName.equals(targetTableName)) { + schema.add(tableName, new BeamCalciteTable(table.getRowType())); + } + } + planner = new BeamQueryPlanner(schema); + } + /** * Find {@link BaseBeamTable} by table name. */ @@ -97,13 +162,15 @@ public BeamSqlTable findTable(String tableName){ } private static class BeamCalciteTable implements ScannableTable, Serializable { - private BeamRecordSqlType beamSqlRowType; - public BeamCalciteTable(BeamRecordSqlType beamSqlRowType) { - this.beamSqlRowType = beamSqlRowType; + private RowType beamRowType; + + public BeamCalciteTable(RowType beamRowType) { + this.beamRowType = beamRowType; } + @Override public RelDataType getRowType(RelDataTypeFactory typeFactory) { - return CalciteUtils.toCalciteRowType(this.beamSqlRowType) + return CalciteUtils.toCalciteRowType(this.beamRowType) .apply(BeamQueryPlanner.TYPE_FACTORY); } @@ -128,9 +195,27 @@ public Statistic getStatistic() { public Schema.TableType getJdbcTableType() { return Schema.TableType.TABLE; } + + @Override public boolean isRolledUp(String column) { + return false; + } + + @Override public boolean rolledUpColumnValidInsideAgg(String column, + SqlCall call, SqlNode parent, + CalciteConnectionConfig config) { + return false; + } } public BeamQueryPlanner getPlanner() { return planner; } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + + tables = new HashMap(16); + schema = Frameworks.createRootSchema(true); + planner = new BeamQueryPlanner(schema); + } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlExpressionExecutor.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlExpressionExecutor.java index 3aaf50506acf..19928bb80101 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlExpressionExecutor.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlExpressionExecutor.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.List; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** * {@code BeamSqlExpressionExecutor} fills the gap between relational @@ -35,10 +35,10 @@ public interface BeamSqlExpressionExecutor extends Serializable { void prepare(); /** - * apply 
transformation to input record {@link BeamRecord} with {@link BoundedWindow}. + * apply transformation to input record {@link Row} with {@link BoundedWindow}. * */ - List execute(BeamRecord inputRow, BoundedWindow window); + List execute(Row inputRow, BoundedWindow window); void close(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutor.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutor.java index 31d5022dd372..ae65c2b0c256 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutor.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutor.java @@ -21,7 +21,6 @@ import java.util.ArrayList; import java.util.Calendar; import java.util.List; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlCaseExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlCastExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; @@ -93,7 +92,7 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.BeamProjectRel; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; @@ -107,7 +106,7 @@ /** * Executor based on {@link BeamSqlExpression} and {@link BeamSqlPrimitive}. * {@code BeamSqlFnExecutor} converts a {@link BeamRelNode} to a {@link BeamSqlExpression}, - * which can be evaluated against the {@link BeamRecord}. + * which can be evaluated against the {@link Row}. 
* */ public class BeamSqlFnExecutor implements BeamSqlExpressionExecutor { @@ -448,7 +447,7 @@ public void prepare() { } @Override - public List execute(BeamRecord inputRow, BoundedWindow window) { + public List execute(Row inputRow, BoundedWindow window) { List results = new ArrayList<>(); for (BeamSqlExpression exp : exps) { results.add(exp.evaluate(inputRow, window).getValue()); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpression.java index c7eb1568e567..e3c6190d45d0 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpression.java @@ -20,7 +20,7 @@ import java.util.List; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -49,7 +49,7 @@ public BeamSqlCaseExpression(List operands) { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { for (int i = 0; i < operands.size() - 1; i += 2) { Boolean wasOpEvaluated = opValueEvaluated(i, inputRow, window); if (wasOpEvaluated != null && wasOpEvaluated) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpression.java index 5ca7a6940e50..19cd04d3372b 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpression.java @@ -22,7 +22,7 @@ import java.sql.Timestamp; import java.util.List; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.runtime.SqlFunctions; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.format.DateTimeFormat; @@ -72,7 +72,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { SqlTypeName castOutputType = getOutputType(); switch (castOutputType) { case INTEGER: diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlExpression.java index d18b1415fac4..6686e1a6cd8b 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlExpression.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.List; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import 
org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.type.SqlTypeName; @@ -50,7 +50,7 @@ public SqlTypeName opType(int idx) { return op(idx).getOutputType(); } - public T opValueEvaluated(int idx, BeamRecord row, BoundedWindow window) { + public T opValueEvaluated(int idx, Row row, BoundedWindow window) { return (T) op(idx).evaluate(row, window).getValue(); } @@ -60,10 +60,10 @@ public T opValueEvaluated(int idx, BeamRecord row, BoundedWindow window) { public abstract boolean accept(); /** - * Apply input record {@link BeamRecord} with {@link BoundedWindow} to this expression, + * Apply input record {@link Row} with {@link BoundedWindow} to this expression, * the output value is wrapped with {@link BeamSqlPrimitive}. */ - public abstract BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window); + public abstract BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window); public List getOperands() { return operands; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java index 2c321f7f5201..80c694a26bbc 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java @@ -18,7 +18,7 @@ package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -38,8 +38,8 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { - return BeamSqlPrimitive.of(outputType, inputRow.getFieldValue(inputRef)); + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { + return BeamSqlPrimitive.of(outputType, inputRow.getValue(inputRef)); } public int getInputRef() { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitive.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitive.java index 21cbc809afc5..96b56fc1da9c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitive.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitive.java @@ -21,15 +21,14 @@ import java.util.Date; import java.util.GregorianCalendar; import java.util.List; - import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.NlsString; /** * {@link BeamSqlPrimitive} is a special, self-reference {@link BeamSqlExpression}. - * It holds the value, and return it directly during {@link #evaluate(BeamRecord, BoundedWindow)}. + * It holds the value, and return it directly during {@link #evaluate(Row, BoundedWindow)}. 
* */ public class BeamSqlPrimitive extends BeamSqlExpression { @@ -150,7 +149,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { return this; } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpression.java index 625de2c0e7f9..3b421710062c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpression.java @@ -22,7 +22,7 @@ import java.util.ArrayList; import java.util.List; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -54,7 +54,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { if (method == null) { reConstructMethod(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowEndExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowEndExpression.java index 919612eab546..d5543c681dcb 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowEndExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowEndExpression.java @@ -20,7 +20,7 @@ import java.util.Date; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -36,7 +36,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { if (window instanceof IntervalWindow) { return BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, ((IntervalWindow) window).end().toDate()); } else { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowExpression.java index 0298f2671bf1..5fd518d1d0f4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowExpression.java @@ -20,7 +20,7 @@ import java.util.Date; import java.util.List; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -43,7 +43,7 @@ public boolean accept() { } @Override - public 
BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { return BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, (Date) operands.get(0).evaluate(inputRow, window).getValue()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowStartExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowStartExpression.java index 4b250a5a1d41..905a6e4df2d7 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowStartExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlWindowStartExpression.java @@ -20,7 +20,7 @@ import java.util.Date; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -37,7 +37,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { if (window instanceof IntervalWindow) { return BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, ((IntervalWindow) window).start().toDate()); } else { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpression.java index cc15ff5c26d1..1d56be9189d4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpression.java @@ -24,7 +24,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -51,7 +51,7 @@ protected BeamSqlArithmeticExpression(List operands, SqlTypeN super(operands, outputType); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { BigDecimal left = BigDecimal.valueOf( Double.valueOf(opValueEvaluated(0, inputRow, window).toString())); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlCompareExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlCompareExpression.java index df8bd618dbb5..8a24235b70b5 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlCompareExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlCompareExpression.java 
@@ -21,7 +21,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -52,7 +52,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { Object leftValue = operands.get(0).evaluate(inputRow, window).getValue(); Object rightValue = operands.get(1).evaluate(inputRow, window).getValue(); switch (operands.get(0).getOutputType()) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNotNullExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNotNullExpression.java index 9a9739eb8ec6..43b9af11a80d 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNotNullExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNotNullExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -47,7 +47,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { Object leftValue = operands.get(0).evaluate(inputRow, window).getValue(); return BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, leftValue != null); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNullExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNullExpression.java index 6034344fd1c8..93b029ee049f 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNullExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/comparison/BeamSqlIsNullExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -47,7 +47,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { Object leftValue = operands.get(0).evaluate(inputRow, window).getValue(); return BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, leftValue == null); } 
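[Reviewer note, not part of the patch] The recurring change in the expression operators above and below is a mechanical rename of the SQL row type, org.apache.beam.sdk.values.BeamRecord -> org.apache.beam.sdk.values.Row, in every BeamSqlExpression.evaluate(...) signature. The following is a minimal sketch of what a typical operator looks like after the rename; the base-class helpers used here (operands, opType, opValueEvaluated, BeamSqlPrimitive.of) are the ones visible in the surrounding diff, while the class name itself is hypothetical and only for illustration.

import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.Row;   // was: org.apache.beam.sdk.values.BeamRecord
import org.apache.calcite.sql.type.SqlTypeName;

/** Illustrative sketch of the post-rename operator shape (hypothetical class). */
public class BeamSqlExampleUpperExpression extends BeamSqlExpression {
  public BeamSqlExampleUpperExpression(List<BeamSqlExpression> operands) {
    super(operands, SqlTypeName.VARCHAR);
  }

  @Override public boolean accept() {
    // one VARCHAR operand, as in the string operators touched by this patch
    return operands.size() == 1 && opType(0) == SqlTypeName.VARCHAR;
  }

  // The only change callers see: the first parameter is now Row instead of BeamRecord.
  @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) {
    String str = opValueEvaluated(0, inputRow, window);
    return BeamSqlPrimitive.of(SqlTypeName.VARCHAR, str.toUpperCase());
  }
}

The BoundedWindow parameter is unchanged, which is why window-aware expressions such as BeamSqlWindowStartExpression (below) can still inspect the IntervalWindow bounds after the rename.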
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpression.java index fe2dd98ee9f4..1b08dc02f4e4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpression.java @@ -23,7 +23,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -35,11 +35,12 @@ public class BeamSqlCurrentDateExpression extends BeamSqlExpression { public BeamSqlCurrentDateExpression() { super(Collections.emptyList(), SqlTypeName.DATE); } + @Override public boolean accept() { return getOperands().size() == 0; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { return BeamSqlPrimitive.of(outputType, new Date()); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpression.java index fe3feb895115..5824b876dcc6 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpression.java @@ -25,7 +25,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -40,12 +40,13 @@ public class BeamSqlCurrentTimeExpression extends BeamSqlExpression { public BeamSqlCurrentTimeExpression(List operands) { super(operands, SqlTypeName.TIME); } + @Override public boolean accept() { int opCount = getOperands().size(); return opCount <= 1; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { GregorianCalendar ret = new GregorianCalendar(TimeZone.getDefault()); ret.setTime(new Date()); return BeamSqlPrimitive.of(outputType, ret); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpression.java index ca4b3ced628b..13d0fd813d92 100644 --- 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpression.java @@ -23,7 +23,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -38,12 +38,13 @@ public class BeamSqlCurrentTimestampExpression extends BeamSqlExpression { public BeamSqlCurrentTimestampExpression(List operands) { super(operands, SqlTypeName.TIMESTAMP); } + @Override public boolean accept() { int opCount = getOperands().size(); return opCount <= 1; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { return BeamSqlPrimitive.of(outputType, new Date()); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpression.java index 0e1d3db17c78..c0acb20e30f8 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpression.java @@ -23,7 +23,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.DateTimeUtils; import org.apache.calcite.avatica.util.TimeUnitRange; import org.apache.calcite.sql.type.SqlTypeName; @@ -37,12 +37,13 @@ public class BeamSqlDateCeilExpression extends BeamSqlExpression { public BeamSqlDateCeilExpression(List operands) { super(operands, SqlTypeName.TIMESTAMP); } + @Override public boolean accept() { return operands.size() == 2 && opType(1) == SqlTypeName.SYMBOL; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { Date date = opValueEvaluated(0, inputRow, window); long time = date.getTime(); TimeUnitRange unit = ((BeamSqlPrimitive) op(1)).getValue(); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpression.java index 2593629bf95a..d142a56833d4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpression.java @@ -23,7 +23,7 @@ import 
org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.DateTimeUtils; import org.apache.calcite.avatica.util.TimeUnitRange; import org.apache.calcite.sql.type.SqlTypeName; @@ -37,12 +37,13 @@ public class BeamSqlDateFloorExpression extends BeamSqlExpression { public BeamSqlDateFloorExpression(List operands) { super(operands, SqlTypeName.DATE); } + @Override public boolean accept() { return operands.size() == 2 && opType(1) == SqlTypeName.SYMBOL; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { Date date = opValueEvaluated(0, inputRow, window); long time = date.getTime(); TimeUnitRange unit = ((BeamSqlPrimitive) op(1)).getValue(); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpression.java index 6948ba188d55..7920ef9eb9a4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpression.java @@ -19,14 +19,12 @@ package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.date; import com.google.common.collect.ImmutableMap; - import java.util.List; import java.util.Map; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DurationFieldType; @@ -96,7 +94,7 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { if (delegateExpression == null) { throw new IllegalStateException("Unable to execute unsupported 'datetime minus' expression"); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpression.java index 426cda006e4e..bc76dc46371f 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpression.java @@ -22,15 +22,13 @@ import static org.apache.beam.sdk.extensions.sql.impl.utils.SqlTypeUtils.findExpressionOfType; import com.google.common.collect.ImmutableSet; - import java.math.BigDecimal; import java.util.List; import java.util.Set; - import 
org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; @@ -76,7 +74,7 @@ public boolean accept() { * the same way. */ @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { DateTime timestamp = getTimestampOperand(inputRow, window); BeamSqlPrimitive intervalOperandPrimitive = getIntervalOperand(inputRow, window); SqlTypeName intervalOperandType = intervalOperandPrimitive.getOutputType(); @@ -94,12 +92,12 @@ private int getIntervalMultiplier(BeamSqlPrimitive intervalOperandPrimitive) { return multiplier.intValueExact(); } - private BeamSqlPrimitive getIntervalOperand(BeamRecord inputRow, BoundedWindow window) { + private BeamSqlPrimitive getIntervalOperand(Row inputRow, BoundedWindow window) { return findExpressionOfType(operands, SUPPORTED_INTERVAL_TYPES).get() .evaluate(inputRow, window); } - private DateTime getTimestampOperand(BeamRecord inputRow, BoundedWindow window) { + private DateTime getTimestampOperand(Row inputRow, BoundedWindow window) { BeamSqlPrimitive timestampOperandPrimitive = findExpressionOfType(operands, SqlTypeName.TIMESTAMP).get().evaluate(inputRow, window); return new DateTime(timestampOperandPrimitive.getDate()); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpression.java index 38afd0ad1f3a..28a292101b47 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpression.java @@ -18,15 +18,12 @@ package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.date; -import java.util.Calendar; import java.util.Date; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.DateTimeUtils; import org.apache.calcite.avatica.util.TimeUnitRange; import org.apache.calcite.sql.type.SqlTypeName; @@ -44,55 +41,52 @@ *
 *   DAYOFYEAR(date) => EXTRACT(DOY FROM date)
 *   DAYOFMONTH(date) => EXTRACT(DAY FROM date)
 *   DAYOFWEEK(date) => EXTRACT(DOW FROM date)
+ *   HOUR(date) => EXTRACT(HOUR FROM date)
+ *   MINUTE(date) => EXTRACT(MINUTE FROM date)
+ *   SECOND(date) => EXTRACT(SECOND FROM date)
  • * */ public class BeamSqlExtractExpression extends BeamSqlExpression { - private static final Map typeMapping = new HashMap<>(); - static { - typeMapping.put(TimeUnitRange.DOW, Calendar.DAY_OF_WEEK); - typeMapping.put(TimeUnitRange.DOY, Calendar.DAY_OF_YEAR); - typeMapping.put(TimeUnitRange.WEEK, Calendar.WEEK_OF_YEAR); - } - public BeamSqlExtractExpression(List operands) { super(operands, SqlTypeName.BIGINT); } + @Override public boolean accept() { return operands.size() == 2 - && opType(1) == SqlTypeName.BIGINT; + && opType(1) == SqlTypeName.TIMESTAMP; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { - Long time = opValueEvaluated(1, inputRow, window); + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { + Date time = opValueEvaluated(1, inputRow, window); TimeUnitRange unit = ((BeamSqlPrimitive) op(0)).getValue(); switch (unit) { case YEAR: + case QUARTER: case MONTH: case DAY: - Long timeByDay = time / 1000 / 3600 / 24; + case DOW: + case WEEK: + case DOY: + case CENTURY: + case MILLENNIUM: + Long timeByDay = time.getTime() / DateTimeUtils.MILLIS_PER_DAY; Long extracted = DateTimeUtils.unixDateExtract( unit, timeByDay ); return BeamSqlPrimitive.of(outputType, extracted); - case DOY: - case DOW: - case WEEK: - Calendar calendar = Calendar.getInstance(); - calendar.setTime(new Date(time)); - return BeamSqlPrimitive.of(outputType, (long) calendar.get(typeMapping.get(unit))); - - case QUARTER: - calendar = Calendar.getInstance(); - calendar.setTime(new Date(time)); - long ret = calendar.get(Calendar.MONTH) / 3; - if (ret * 3 < calendar.get(Calendar.MONTH)) { - ret += 1; - } - return BeamSqlPrimitive.of(outputType, ret); + case HOUR: + case MINUTE: + case SECOND: + int timeInDay = (int) (time.getTime() % DateTimeUtils.MILLIS_PER_DAY); + extracted = (long) DateTimeUtils.unixTimeExtract( + unit, + timeInDay + ); + return BeamSqlPrimitive.of(outputType, extracted); default: throw new UnsupportedOperationException( diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpression.java index f4ddf710e9d3..166fe9839fcf 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpression.java @@ -22,14 +22,12 @@ import static org.apache.beam.sdk.extensions.sql.impl.utils.SqlTypeUtils.findExpressionOfType; import com.google.common.base.Optional; - import java.math.BigDecimal; import java.util.List; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -85,7 +83,7 @@ && findExpressionOfType(operands, SqlTypeName.INTEGER).isPresent() * "TIMESTAMPADD(YEAR, 2, TIMESTAMP '1984-04-19 01:02:03')" */ @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow 
window) { BeamSqlPrimitive intervalOperandPrimitive = findExpressionOfType(operands, SqlTypeName.INTERVAL_TYPES).get().evaluate(inputRow, window); SqlTypeName intervalOperandType = intervalOperandPrimitive.getOutputType(); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpression.java index 43b2d5ae22d1..4d812a9d9d9b 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpression.java @@ -23,11 +23,10 @@ import java.math.BigDecimal; import java.util.Date; import java.util.List; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; import org.joda.time.DurationFieldType; @@ -58,7 +57,7 @@ static boolean accept(List operands, SqlTypeName outputType) } @Override - public BeamSqlPrimitive evaluate(BeamRecord row, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row row, BoundedWindow window) { DateTime date = new DateTime((Object) opValueEvaluated(0, row, window)); Period period = intervalToPeriod(op(1).evaluate(row, window)); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpression.java index bcdfa925e915..eb0b59e18fbe 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpression.java @@ -21,11 +21,10 @@ import static org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.date.BeamSqlDatetimeMinusExpression.INTERVALS_DURATIONS_TYPES; import java.util.List; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; import org.joda.time.DurationFieldType; @@ -70,7 +69,7 @@ static boolean accept(List operands, SqlTypeName intervalType * Calcite deals with all intervals this way. 
Whenever there is an interval, its value is always * multiplied by the corresponding TimeUnit.multiplier */ - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { DateTime timestampStart = new DateTime((Object) opValueEvaluated(1, inputRow, window)); DateTime timestampEnd = new DateTime((Object) opValueEvaluated(0, inputRow, window)); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtils.java index b432d2022988..6874befe3aa6 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtils.java @@ -19,7 +19,6 @@ package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.date; import java.math.BigDecimal; - import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.sql.type.SqlTypeName; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlAndExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlAndExpression.java index 2cae22bc9e97..5045cbb9d9d6 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlAndExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlAndExpression.java @@ -21,7 +21,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -33,7 +33,7 @@ public BeamSqlAndExpression(List operands) { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { boolean result = true; for (BeamSqlExpression exp : operands) { BeamSqlPrimitive expOut = exp.evaluate(inputRow, window); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlLogicalExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlLogicalExpression.java index 5691e3336ea1..9904ff8f2afa 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlLogicalExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlLogicalExpression.java @@ -29,6 +29,7 @@ public abstract class BeamSqlLogicalExpression extends BeamSqlExpression { private BeamSqlLogicalExpression(List operands, SqlTypeName outputType) { super(operands, outputType); } + public BeamSqlLogicalExpression(List operands) { this(operands, SqlTypeName.BOOLEAN); } diff --git 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpression.java index 72a698281405..6f533fb3deee 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -43,7 +43,7 @@ public boolean accept() { return super.accept(); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { Boolean value = opValueEvaluated(0, inputRow, window); if (value == null) { return BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, window); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlOrExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlOrExpression.java index 74dde7a405b3..bfb0aa66ede0 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlOrExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlOrExpression.java @@ -21,7 +21,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -33,7 +33,7 @@ public BeamSqlOrExpression(List operands) { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { boolean result = false; for (BeamSqlExpression exp : operands) { BeamSqlPrimitive expOut = exp.evaluate(inputRow, window); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpression.java index ed0aac018baf..df042ca23abf 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import 
org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -39,7 +39,7 @@ public BeamSqlMathBinaryExpression(List operands, SqlTypeName return numberOfOperands() == 2 && isOperandNumeric(opType(0)) && isOperandNumeric(opType(1)); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { BeamSqlExpression leftOp = op(0); BeamSqlExpression rightOp = op(1); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpression.java index b1a210ed352e..e002ac008378 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; @@ -46,7 +46,7 @@ public BeamSqlMathUnaryExpression(List operands, SqlTypeName return acceptance; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { BeamSqlExpression operand = op(0); return calculate(operand.evaluate(inputRow, window)); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlPiExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlPiExpression.java index 3072ea0267a7..a9d5a357d6d6 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlPiExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlPiExpression.java @@ -21,7 +21,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -37,7 +37,7 @@ public BeamSqlPiExpression() { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { return BeamSqlPrimitive.of(SqlTypeName.DOUBLE, Math.PI); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandExpression.java index 00f2693ec6ae..918cccc5ff18 100644 --- 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandExpression.java @@ -23,7 +23,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -43,9 +43,9 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRecord, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { if (operands.size() == 1) { - int rowSeed = opValueEvaluated(0, inputRecord, window); + int rowSeed = opValueEvaluated(0, inputRow, window); if (seed == null || seed != rowSeed) { rand.setSeed(rowSeed); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandIntegerExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandIntegerExpression.java index d055de66acfa..dabb9a1f31d1 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandIntegerExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlRandIntegerExpression.java @@ -23,7 +23,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -44,16 +44,16 @@ public boolean accept() { } @Override - public BeamSqlPrimitive evaluate(BeamRecord inputRecord, BoundedWindow window) { + public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { int numericIdx = 0; if (operands.size() == 2) { - int rowSeed = opValueEvaluated(0, inputRecord, window); + int rowSeed = opValueEvaluated(0, inputRow, window); if (seed == null || seed != rowSeed) { rand.setSeed(rowSeed); } numericIdx = 1; } return BeamSqlPrimitive.of(SqlTypeName.INTEGER, - rand.nextInt((int) opValueEvaluated(numericIdx, inputRecord, window))); + rand.nextInt((int) opValueEvaluated(numericIdx, inputRow, window))); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/BeamSqlReinterpretExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/BeamSqlReinterpretExpression.java index b22fd09aa2fb..0d94fa30be4b 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/BeamSqlReinterpretExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/BeamSqlReinterpretExpression.java @@ -19,11 +19,10 @@ package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.reinterpret; import java.util.List; - import 
org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -50,7 +49,7 @@ public BeamSqlReinterpretExpression(List operands, SqlTypeNam && REINTERPRETER.canConvert(opType(0), SqlTypeName.BIGINT); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { return REINTERPRETER.convert( SqlTypeName.BIGINT, operands.get(0).evaluate(inputRow, window)); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversion.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversion.java index df2996272b8a..b8a681516d85 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversion.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversion.java @@ -20,12 +20,10 @@ import com.google.common.base.Function; import com.google.common.collect.ImmutableSet; - import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.Set; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.calcite.sql.type.SqlTypeName; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpression.java index 5146b14b5db1..91828239e332 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -33,7 +33,7 @@ public BeamSqlCharLengthExpression(List operands) { super(operands, SqlTypeName.INTEGER); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String str = opValueEvaluated(0, inputRow, window); return BeamSqlPrimitive.of(SqlTypeName.INTEGER, str.length()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpression.java index c2f317fa043f..ffa7ad2409e1 100644 --- 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -52,7 +52,7 @@ public BeamSqlConcatExpression(List operands) { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String left = opValueEvaluated(0, inputRow, window); String right = opValueEvaluated(1, inputRow, window); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpression.java index bf0b8f572f81..2468dec99cb1 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -33,7 +33,7 @@ public BeamSqlInitCapExpression(List operands) { super(operands, SqlTypeName.VARCHAR); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String str = opValueEvaluated(0, inputRow, window); StringBuilder ret = new StringBuilder(str); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpression.java index 55f8d6de7865..3c2d139db80c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -33,7 +33,7 @@ public BeamSqlLowerExpression(List operands) { super(operands, SqlTypeName.VARCHAR); } - @Override public 
BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String str = opValueEvaluated(0, inputRow, window); return BeamSqlPrimitive.of(SqlTypeName.VARCHAR, str.toLowerCase()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpression.java index 62d5a64fbeb6..63128f973e69 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -55,7 +55,7 @@ public BeamSqlOverlayExpression(List operands) { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String str = opValueEvaluated(0, inputRow, window); String replaceStr = opValueEvaluated(1, inputRow, window); int idx = opValueEvaluated(2, inputRow, window); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpression.java index f97547eb6f42..2b0452c033e3 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -57,7 +57,7 @@ public BeamSqlPositionExpression(List operands) { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String targetStr = opValueEvaluated(0, inputRow, window); String containingStr = opValueEvaluated(1, inputRow, window); int from = -1; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpression.java index a521ef089750..b2194b2ad2d2 100644 --- 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -55,7 +55,7 @@ public BeamSqlSubstringExpression(List operands) { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String str = opValueEvaluated(0, inputRow, window); int idx = opValueEvaluated(1, inputRow, window); int startIdx = idx; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpression.java index 3c3083c0008e..2244dd4a7f94 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.fun.SqlTrimFunction; import org.apache.calcite.sql.type.SqlTypeName; @@ -59,7 +59,7 @@ public BeamSqlTrimExpression(List operands) { return true; } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { if (operands.size() == 1) { return BeamSqlPrimitive.of(SqlTypeName.VARCHAR, opValueEvaluated(0, inputRow, window).toString().trim()); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpression.java index bc29ec841cfe..b38ff9661c97 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpression.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpression.java @@ -22,7 +22,7 @@ import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; /** @@ -33,7 +33,7 @@ public BeamSqlUpperExpression(List operands) { 
super(operands, SqlTypeName.VARCHAR); } - @Override public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) { + @Override public BeamSqlPrimitive evaluate(Row inputRow, BoundedWindow window) { String str = opValueEvaluated(0, inputRow, window); return BeamSqlPrimitive.of(SqlTypeName.VARCHAR, str.toUpperCase()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/ParserUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/ParserUtils.java index dae82a666a31..09796096468c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/ParserUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/ParserUtils.java @@ -39,8 +39,8 @@ public static Table convertCreateTableStmtToTable(SqlCreateTable stmt) { for (ColumnDefinition columnDef : stmt.fieldList()) { Column column = Column.builder() .name(columnDef.name().toLowerCase()) - .type( - CalciteUtils.toJavaType( + .coder( + CalciteUtils.toCoder( columnDef.type().deriveType(BeamQueryPlanner.TYPE_FACTORY).getSqlTypeName() ) ) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/SqlDropTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/SqlDropTable.java new file mode 100644 index 000000000000..6f703c98fd6a --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/parser/SqlDropTable.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.impl.parser; + +import java.util.List; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.util.ImmutableNullableList; + +/** + * A Calcite {@code SqlCall} which represents a drop table statement. + */ +public class SqlDropTable extends SqlCall { + private final SqlIdentifier tableName; + + public static final SqlSpecialOperator OPERATOR = new SqlSpecialOperator( + "DROP_TABLE", SqlKind.OTHER) { + @Override + public SqlCall createCall( + SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... 
o) { + assert functionQualifier == null; + return new SqlDropTable(pos, (SqlIdentifier) o[0]); + } + + @Override + public void unparse( + SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + SqlDropTable t = (SqlDropTable) call; + UnparseUtil u = new UnparseUtil(writer, leftPrec, rightPrec); + u.keyword("DROP", "TABLE").node(t.tableName); + } + }; + + public SqlDropTable(SqlParserPos pos, SqlIdentifier tableName) { + super(pos); + this.tableName = tableName; + } + + @Override + public SqlOperator getOperator() { + return OPERATOR; + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + getOperator().unparse(writer, this, leftPrec, rightPrec); + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(tableName); + } + + public String tableName() { + return tableName.toString().toLowerCase(); + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamQueryPlanner.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamQueryPlanner.java index ebfeffa860d3..b9b0fdbd9692 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamQueryPlanner.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/planner/BeamQueryPlanner.java @@ -28,9 +28,9 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.adapter.java.JavaTypeFactory; import org.apache.calcite.config.Lex; import org.apache.calcite.jdbc.CalciteSchema; @@ -92,7 +92,7 @@ public BeamQueryPlanner(SchemaPlus schema) { sqlOperatorTables.add(SqlStdOperatorTable.instance()); sqlOperatorTables.add( new CalciteCatalogReader( - CalciteSchema.from(schema), false, Collections.emptyList(), TYPE_FACTORY)); + CalciteSchema.from(schema), Collections.emptyList(), TYPE_FACTORY, null)); FrameworkConfig config = Frameworks.newConfigBuilder() .parserConfig(SqlParser.configBuilder().setLex(Lex.MYSQL).build()).defaultSchema(schema) @@ -119,7 +119,7 @@ public SqlNode parseQuery(String sqlQuery) throws SqlParseException{ * which is linked with the given {@code pipeline}. The final output stream is returned as * {@code PCollection} so more operations can be applied. 
*/ - public PCollection compileBeamPipeline(String sqlStatement, Pipeline basePipeline + public PCollection compileBeamPipeline(String sqlStatement, Pipeline basePipeline , BeamSqlEnv sqlEnv) throws Exception { BeamRelNode relNode = convertToBeamRel(sqlStatement); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java index 092ef2b56f85..f1fb12dec2b5 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java @@ -17,26 +17,31 @@ */ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.util.ArrayList; +import static org.apache.beam.sdk.values.PCollection.IsBounded.BOUNDED; +import static org.apache.beam.sdk.values.RowType.toRowType; + import java.util.List; -import org.apache.beam.sdk.coders.BeamRecordCoder; +import java.util.Optional; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; +import org.apache.beam.sdk.extensions.sql.impl.rule.AggregateWindowField; import org.apache.beam.sdk.extensions.sql.impl.transform.BeamAggregationTransforms; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.WithTimestamps; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.Trigger; +import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.calcite.linq4j.Ord; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; @@ -51,58 +56,61 @@ /** * {@link BeamRelNode} to replace a {@link Aggregate} node. 
- * */ public class BeamAggregationRel extends Aggregate implements BeamRelNode { - private int windowFieldIdx = -1; - private WindowFn windowFn; - private Trigger trigger; - private Duration allowedLatence = Duration.ZERO; - - public BeamAggregationRel(RelOptCluster cluster, RelTraitSet traits - , RelNode child, boolean indicator, - ImmutableBitSet groupSet, List groupSets, List aggCalls - , WindowFn windowFn, Trigger trigger, int windowFieldIdx, Duration allowedLatence) { + private final int windowFieldIndex; + private Optional windowField; + + public BeamAggregationRel( + RelOptCluster cluster, + RelTraitSet traits, + RelNode child, + boolean indicator, + ImmutableBitSet groupSet, + List groupSets, + List aggCalls, + Optional windowField) { + super(cluster, traits, child, indicator, groupSet, groupSets, aggCalls); - this.windowFn = windowFn; - this.trigger = trigger; - this.windowFieldIdx = windowFieldIdx; - this.allowedLatence = allowedLatence; + this.windowField = windowField; + this.windowFieldIndex = windowField.map(AggregateWindowField::fieldIndex).orElse(-1); } @Override - public PCollection buildBeamPipeline(PCollectionTuple inputPCollections - , BeamSqlEnv sqlEnv) throws Exception { + public PCollection buildBeamPipeline( + PCollectionTuple inputPCollections, + BeamSqlEnv sqlEnv) throws Exception { + RelNode input = getInput(); String stageName = BeamSqlRelUtils.getStageName(this) + "_"; - PCollection upstream = + PCollection upstream = BeamSqlRelUtils.getBeamRelInput(input).buildBeamPipeline(inputPCollections, sqlEnv); - if (windowFieldIdx != -1) { + if (windowField.isPresent()) { upstream = upstream.apply(stageName + "assignEventTimestamp", WithTimestamps - .of(new BeamAggregationTransforms.WindowTimestampFn(windowFieldIdx)) + .of(new BeamAggregationTransforms.WindowTimestampFn(windowFieldIndex)) .withAllowedTimestampSkew(new Duration(Long.MAX_VALUE))) .setCoder(upstream.getCoder()); } - PCollection windowStream = upstream.apply(stageName + "window", - Window.into(windowFn) - .triggering(trigger) - .withAllowedLateness(allowedLatence) - .accumulatingFiredPanes()); + PCollection windowedStream = + windowField.isPresent() + ? 
upstream.apply(stageName + "window", Window.into(windowField.get().windowFn())) + : upstream; + + validateWindowIsSupported(windowedStream); - BeamRecordCoder keyCoder = exKeyFieldsSchema(input.getRowType()).getRecordCoder(); - PCollection> exCombineByStream = windowStream.apply( + RowCoder keyCoder = exKeyFieldsSchema(input.getRowType()).getRowCoder(); + PCollection> exCombineByStream = windowedStream.apply( stageName + "exCombineBy", WithKeys - .of(new BeamAggregationTransforms.AggregationGroupByKeyFn( - windowFieldIdx, groupSet))) + .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(windowFieldIndex, groupSet))) .setCoder(KvCoder.of(keyCoder, upstream.getCoder())); - BeamRecordCoder aggCoder = exAggFieldsSchema().getRecordCoder(); + RowCoder aggCoder = exAggFieldsSchema().getRowCoder(); - PCollection> aggregatedStream = + PCollection> aggregatedStream = exCombineByStream .apply( stageName + "combineBy", @@ -111,66 +119,89 @@ public PCollection buildBeamPipeline(PCollectionTuple inputPCollecti getAggCallList(), CalciteUtils.toBeamRowType(input.getRowType())))) .setCoder(KvCoder.of(keyCoder, aggCoder)); - PCollection mergedStream = aggregatedStream.apply(stageName + "mergeRecord", - ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord( - CalciteUtils.toBeamRowType(getRowType()), getAggCallList(), windowFieldIdx))); - mergedStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + PCollection mergedStream = aggregatedStream.apply( + stageName + "mergeRecord", + ParDo.of( + new BeamAggregationTransforms.MergeAggregationRecord( + CalciteUtils.toBeamRowType(getRowType()), + getAggCallList(), + windowFieldIndex))); + mergedStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); return mergedStream; } + /** - * Type of sub-rowrecord used as Group-By keys. + * Performs the same check as {@link GroupByKey}, provides more context in exception. + * + *

<p>Verifies that the input PCollection is bounded, or that there is windowing/triggering being
+   * used. Without this, the watermark (at end of global window) will never be reached.
+   *
+   * <p>
    Throws {@link UnsupportedOperationException} if validation fails. */ - private BeamRecordSqlType exKeyFieldsSchema(RelDataType relDataType) { - BeamRecordSqlType inputRowType = CalciteUtils.toBeamRowType(relDataType); - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (int i : groupSet.asList()) { - if (i != windowFieldIdx) { - fieldNames.add(inputRowType.getFieldNameByIndex(i)); - fieldTypes.add(inputRowType.getFieldTypeByIndex(i)); - } + private void validateWindowIsSupported(PCollection upstream) { + WindowingStrategy windowingStrategy = upstream.getWindowingStrategy(); + if (windowingStrategy.getWindowFn() instanceof GlobalWindows + && windowingStrategy.getTrigger() instanceof DefaultTrigger + && upstream.isBounded() != BOUNDED) { + + throw new UnsupportedOperationException( + "Please explicitly specify windowing in SQL query using HOP/TUMBLE/SESSION functions " + + "(default trigger will be used in this case). " + + "Unbounded input with global windowing and default trigger is not supported " + + "in Beam SQL aggregations. " + + "See GroupByKey section in Beam Programming Guide"); } - return BeamRecordSqlType.create(fieldNames, fieldTypes); + } + + /** + * Type of sub-rowrecord used as Group-By keys. + */ + private RowType exKeyFieldsSchema(RelDataType relDataType) { + RowType inputRowType = CalciteUtils.toBeamRowType(relDataType); + return groupSet + .asList() + .stream() + .filter(i -> i != windowFieldIndex) + .map(i -> newRowField(inputRowType, i)) + .collect(toRowType()); + } + + private RowType.Field newRowField(RowType rowType, int i) { + return RowType.newField(rowType.getFieldName(i), rowType.getFieldCoder(i)); } /** * Type of sub-rowrecord, that represents the list of aggregation fields. */ - private BeamRecordSqlType exAggFieldsSchema() { - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (AggregateCall ac : getAggCallList()) { - fieldNames.add(ac.name); - fieldTypes.add(CalciteUtils.toJavaType(ac.type.getSqlTypeName())); - } + private RowType exAggFieldsSchema() { + return + getAggCallList() + .stream() + .map(this::newRowField) + .collect(toRowType()); + } - return BeamRecordSqlType.create(fieldNames, fieldTypes); + private RowType.Field newRowField(AggregateCall aggCall) { + return + RowType + .newField(aggCall.name, CalciteUtils.toCoder(aggCall.type.getSqlTypeName())); } @Override - public Aggregate copy(RelTraitSet traitSet, RelNode input, boolean indicator + public Aggregate copy( + RelTraitSet traitSet, RelNode input, boolean indicator , ImmutableBitSet groupSet, List groupSets, List aggCalls) { return new BeamAggregationRel(getCluster(), traitSet, input, indicator - , groupSet, groupSets, aggCalls, windowFn, trigger, windowFieldIdx, allowedLatence); - } - - public void setWindowFn(WindowFn windowFn) { - this.windowFn = windowFn; - } - - public void setTrigger(Trigger trigger) { - this.trigger = trigger; + , groupSet, groupSets, aggCalls, windowField); } public RelWriter explainTerms(RelWriter pw) { // We skip the "groups" element if it is a singleton of "group". 
pw.item("group", groupSet) - .itemIf("window", windowFn, windowFn != null) - .itemIf("trigger", trigger, trigger != null) - .itemIf("event_time", windowFieldIdx, windowFieldIdx != -1) + .itemIf("window", windowField.orElse(null), windowField.isPresent()) .itemIf("groups", groupSets, getGroupType() != Group.SIMPLE) .itemIf("indicator", indicator, indicator) .itemIf("aggs", aggCalls, pw.nest()); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamFilterRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamFilterRel.java index 9d36a47ea601..d68463606b0b 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamFilterRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamFilterRel.java @@ -23,9 +23,9 @@ import org.apache.beam.sdk.extensions.sql.impl.transform.BeamSqlFilterFn; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -49,19 +49,21 @@ public Filter copy(RelTraitSet traitSet, RelNode input, RexNode condition) { } @Override - public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { RelNode input = getInput(); String stageName = BeamSqlRelUtils.getStageName(this); - PCollection upstream = + PCollection upstream = BeamSqlRelUtils.getBeamRelInput(input).buildBeamPipeline(inputPCollections, sqlEnv); BeamSqlExpressionExecutor executor = new BeamSqlFnExecutor(this); - PCollection filterStream = upstream.apply(stageName, - ParDo.of(new BeamSqlFilterFn(getRelTypeName(), executor))); - filterStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + PCollection filterStream = upstream + .apply( + stageName, + ParDo.of(new BeamSqlFilterFn(getRelTypeName(), executor))); + filterStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); return filterStream; } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSinkRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSinkRel.java index d3de0fbc0fb3..d38b2acf2860 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSinkRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSinkRel.java @@ -21,9 +21,9 @@ import java.util.List; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelTraitSet; @@ -55,12 +55,12 @@ public RelNode copy(RelTraitSet traitSet, List inputs) { * which is the persisted PCollection. 
*/ @Override - public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { RelNode input = getInput(); String stageName = BeamSqlRelUtils.getStageName(this); - PCollection upstream = + PCollection upstream = BeamSqlRelUtils.getBeamRelInput(input).buildBeamPipeline(inputPCollections, sqlEnv); String sourceName = Joiner.on('.').join(getTable().getQualifiedName()); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSourceRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSourceRel.java index 2d6e46f40305..23f6a4fb9078 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSourceRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIOSourceRel.java @@ -21,9 +21,9 @@ import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptTable; @@ -41,21 +41,21 @@ public BeamIOSourceRel(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable } @Override - public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { String sourceName = Joiner.on('.').join(getTable().getQualifiedName()); - TupleTag sourceTupleTag = new TupleTag<>(sourceName); + TupleTag sourceTupleTag = new TupleTag<>(sourceName); if (inputPCollections.has(sourceTupleTag)) { //choose PCollection from input PCollectionTuple if exists there. - PCollection sourceStream = inputPCollections - .get(new TupleTag(sourceName)); + PCollection sourceStream = inputPCollections + .get(new TupleTag(sourceName)); return sourceStream; } else { //If not, the source PColection is provided with BaseBeamTable.buildIOReader(). 
BeamSqlTable sourceTable = sqlEnv.findTable(sourceName); return sourceTable.buildIOReader(inputPCollections.getPipeline()) - .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRel.java index 1ffb6366259c..7c28ea7612f4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRel.java @@ -20,9 +20,9 @@ import java.util.List; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -51,7 +51,7 @@ public BeamIntersectRel( return new BeamIntersectRel(getCluster(), traitSet, inputs, all); } - @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { return delegate.buildBeamPipeline(inputPCollections, sqlEnv); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java index d0c141b6b742..89196efe162c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRel.java @@ -18,15 +18,17 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; +import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED; +import static org.apache.beam.sdk.values.RowType.toRowType; +import static org.joda.time.Duration.ZERO; + import com.google.common.base.Joiner; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; @@ -35,13 +37,18 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.IncompatibleWindowException; +import org.apache.beam.sdk.transforms.windowing.Trigger; import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.Row; +import 
org.apache.beam.sdk.values.RowType; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -96,20 +103,23 @@ public BeamJoinRel(RelOptCluster cluster, RelTraitSet traits, RelNode left, RelN joinType); } - @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections, - BeamSqlEnv sqlEnv) + @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections, + BeamSqlEnv sqlEnv) throws Exception { BeamRelNode leftRelNode = BeamSqlRelUtils.getBeamRelInput(left); - BeamRecordSqlType leftRowType = CalciteUtils.toBeamRowType(left.getRowType()); + RowType leftRowType = CalciteUtils.toBeamRowType(left.getRowType()); final BeamRelNode rightRelNode = BeamSqlRelUtils.getBeamRelInput(right); if (!seekable(leftRelNode, sqlEnv) && seekable(rightRelNode, sqlEnv)) { return joinAsLookup(leftRelNode, rightRelNode, inputPCollections, sqlEnv) - .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); } - PCollection leftRows = leftRelNode.buildBeamPipeline(inputPCollections, sqlEnv); - PCollection rightRows = rightRelNode.buildBeamPipeline(inputPCollections, sqlEnv); + PCollection leftRows = leftRelNode.buildBeamPipeline(inputPCollections, sqlEnv); + PCollection rightRows = rightRelNode.buildBeamPipeline(inputPCollections, sqlEnv); + + verifySupportedTrigger(leftRows); + verifySupportedTrigger(rightRows); String stageName = BeamSqlRelUtils.getStageName(this); WindowFn leftWinFn = leftRows.getWindowingStrategy().getWindowFn(); @@ -121,36 +131,37 @@ public BeamJoinRel(RelOptCluster cluster, RelTraitSet traits, RelNode left, RelN // build the extract key type // the name of the join field is not important - List names = new ArrayList<>(pairs.size()); - List types = new ArrayList<>(pairs.size()); - for (int i = 0; i < pairs.size(); i++) { - names.add("c" + i); - types.add(leftRowType.getFieldTypeByIndex(pairs.get(i).getKey())); - } - BeamRecordSqlType extractKeyRowType = BeamRecordSqlType.create(names, types); + RowType extractKeyRowType = + pairs + .stream() + .map(pair -> + RowType.newField( + leftRowType.getFieldName(pair.getKey()), + leftRowType.getFieldCoder(pair.getKey()))) + .collect(toRowType()); - Coder extractKeyRowCoder = extractKeyRowType.getRecordCoder(); + Coder extractKeyRowCoder = extractKeyRowType.getRowCoder(); // BeamSqlRow -> KV - PCollection> extractedLeftRows = leftRows + PCollection> extractedLeftRows = leftRows .apply(stageName + "_left_ExtractJoinFields", MapElements.via(new BeamJoinTransforms.ExtractJoinFields(true, pairs))) .setCoder(KvCoder.of(extractKeyRowCoder, leftRows.getCoder())); - PCollection> extractedRightRows = rightRows + PCollection> extractedRightRows = rightRows .apply(stageName + "_right_ExtractJoinFields", MapElements.via(new BeamJoinTransforms.ExtractJoinFields(false, pairs))) .setCoder(KvCoder.of(extractKeyRowCoder, rightRows.getCoder())); // prepare the NullRows - BeamRecord leftNullRow = buildNullRow(leftRelNode); - BeamRecord rightNullRow = buildNullRow(rightRelNode); + Row leftNullRow = buildNullRow(leftRelNode); + Row rightNullRow = buildNullRow(rightRelNode); // a regular join if ((leftRows.isBounded() == PCollection.IsBounded.BOUNDED && rightRows.isBounded() == PCollection.IsBounded.BOUNDED) - || (leftRows.isBounded() == PCollection.IsBounded.UNBOUNDED - && rightRows.isBounded() == PCollection.IsBounded.UNBOUNDED)) { 
+ || (leftRows.isBounded() == UNBOUNDED + && rightRows.isBounded() == UNBOUNDED)) { try { leftWinFn.verifyCompatibility(rightWinFn); } catch (IncompatibleWindowException e) { @@ -162,8 +173,8 @@ public BeamJoinRel(RelOptCluster cluster, RelTraitSet traits, RelNode left, RelN leftNullRow, rightNullRow, stageName); } else if ( (leftRows.isBounded() == PCollection.IsBounded.BOUNDED - && rightRows.isBounded() == PCollection.IsBounded.UNBOUNDED) - || (leftRows.isBounded() == PCollection.IsBounded.UNBOUNDED + && rightRows.isBounded() == UNBOUNDED) + || (leftRows.isBounded() == UNBOUNDED && rightRows.isBounded() == PCollection.IsBounded.BOUNDED) ) { // if one of the sides is Bounded & the other is Unbounded @@ -192,11 +203,33 @@ public BeamJoinRel(RelOptCluster cluster, RelTraitSet traits, RelNode left, RelN } } - private PCollection standardJoin( - PCollection> extractedLeftRows, - PCollection> extractedRightRows, - BeamRecord leftNullRow, BeamRecord rightNullRow, String stageName) { - PCollection>> joinedRows = null; + private void verifySupportedTrigger(PCollection pCollection) { + WindowingStrategy windowingStrategy = pCollection.getWindowingStrategy(); + + if (UNBOUNDED.equals(pCollection.isBounded()) + && !triggersOncePerWindow(windowingStrategy)) { + throw new UnsupportedOperationException( + "Joining unbounded PCollections is currently only supported for " + + "non-global windows with triggers that are known to produce output once per window," + + "such as the default trigger with zero allowed lateness. " + + "In these cases Beam can guarantee it joins all input elements once per window. " + + windowingStrategy + " is not supported"); + } + } + + private boolean triggersOncePerWindow(WindowingStrategy windowingStrategy) { + Trigger trigger = windowingStrategy.getTrigger(); + + return !(windowingStrategy.getWindowFn() instanceof GlobalWindows) + && trigger instanceof DefaultTrigger + && ZERO.equals(windowingStrategy.getAllowedLateness()); + } + + private PCollection standardJoin( + PCollection> extractedLeftRows, + PCollection> extractedRightRows, + Row leftNullRow, Row rightNullRow, String stageName) { + PCollection>> joinedRows = null; switch (joinType) { case LEFT: joinedRows = org.apache.beam.sdk.extensions.joinlibrary.Join @@ -218,53 +251,53 @@ private PCollection standardJoin( break; } - PCollection ret = joinedRows + PCollection ret = joinedRows .apply(stageName + "_JoinParts2WholeRow", MapElements.via(new BeamJoinTransforms.JoinParts2WholeRow())) - .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); return ret; } - public PCollection sideInputJoin( - PCollection> extractedLeftRows, - PCollection> extractedRightRows, - BeamRecord leftNullRow, BeamRecord rightNullRow) { + public PCollection sideInputJoin( + PCollection> extractedLeftRows, + PCollection> extractedRightRows, + Row leftNullRow, Row rightNullRow) { // we always make the Unbounded table on the left to do the sideInput join // (will convert the result accordingly before return) boolean swapped = (extractedLeftRows.isBounded() == PCollection.IsBounded.BOUNDED); JoinRelType realJoinType = (swapped && joinType != JoinRelType.INNER) ? JoinRelType.LEFT : joinType; - PCollection> realLeftRows = + PCollection> realLeftRows = swapped ? extractedRightRows : extractedLeftRows; - PCollection> realRightRows = + PCollection> realRightRows = swapped ? extractedLeftRows : extractedRightRows; - BeamRecord realRightNullRow = swapped ? 
leftNullRow : rightNullRow; + Row realRightNullRow = swapped ? leftNullRow : rightNullRow; // swapped still need to pass down because, we need to swap the result back. return sideInputJoinHelper(realJoinType, realLeftRows, realRightRows, realRightNullRow, swapped); } - private PCollection sideInputJoinHelper( + private PCollection sideInputJoinHelper( JoinRelType joinType, - PCollection> leftRows, - PCollection> rightRows, - BeamRecord rightNullRow, boolean swapped) { - final PCollectionView>> rowsView = + PCollection> leftRows, + PCollection> rightRows, + Row rightNullRow, boolean swapped) { + final PCollectionView>> rowsView = rightRows.apply(View.asMultimap()); - PCollection ret = leftRows + PCollection ret = leftRows .apply(ParDo.of(new BeamJoinTransforms.SideInputJoinDoFn( joinType, rightNullRow, rowsView, swapped)).withSideInputs(rowsView)) - .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + .setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); return ret; } - private BeamRecord buildNullRow(BeamRelNode relNode) { - BeamRecordSqlType leftType = CalciteUtils.toBeamRowType(relNode.getRowType()); - return new BeamRecord(leftType, Collections.nCopies(leftType.getFieldCount(), null)); + private Row buildNullRow(BeamRelNode relNode) { + RowType leftType = CalciteUtils.toBeamRowType(relNode.getRowType()); + return Row.nullRow(leftType); } private List> extractJoinColumns(int leftRowColumnCount) { @@ -304,9 +337,11 @@ private Pair extractOneJoinColumn(RexCall oneCondition, return new Pair<>(leftIndex, rightIndex); } - private PCollection joinAsLookup(BeamRelNode leftRelNode, BeamRelNode rightRelNode, - PCollectionTuple inputPCollections, BeamSqlEnv sqlEnv) throws Exception { - PCollection factStream = leftRelNode.buildBeamPipeline(inputPCollections, sqlEnv); + private PCollection joinAsLookup(BeamRelNode leftRelNode, + BeamRelNode rightRelNode, + PCollectionTuple inputPCollections, + BeamSqlEnv sqlEnv) throws Exception { + PCollection factStream = leftRelNode.buildBeamPipeline(inputPCollections, sqlEnv); BeamSqlSeekableTable seekableTable = getSeekableTableFromRelNode(rightRelNode, sqlEnv); return factStream.apply("join_as_lookup", diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRel.java index 6f5dff2c8847..9fdafda37204 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRel.java @@ -20,9 +20,9 @@ import java.util.List; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -49,7 +49,7 @@ public BeamMinusRel(RelOptCluster cluster, RelTraitSet traits, List inp return new BeamMinusRel(getCluster(), traitSet, inputs, all); } - @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { return delegate.buildBeamPipeline(inputPCollections, sqlEnv); } diff --git 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamProjectRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamProjectRel.java index 501feb304916..ea8987478f4f 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamProjectRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamProjectRel.java @@ -24,9 +24,9 @@ import org.apache.beam.sdk.extensions.sql.impl.transform.BeamSqlProjectFn; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -59,20 +59,20 @@ public Project copy(RelTraitSet traitSet, RelNode input, List projects, } @Override - public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { RelNode input = getInput(); String stageName = BeamSqlRelUtils.getStageName(this); - PCollection upstream = + PCollection upstream = BeamSqlRelUtils.getBeamRelInput(input).buildBeamPipeline(inputPCollections, sqlEnv); BeamSqlExpressionExecutor executor = new BeamSqlFnExecutor(this); - PCollection projectStream = upstream.apply(stageName, ParDo + PCollection projectStream = upstream.apply(stageName, ParDo .of(new BeamSqlProjectFn(getRelTypeName(), executor, CalciteUtils.toBeamRowType(rowType)))); - projectStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + projectStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); return projectStream; } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamRelNode.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamRelNode.java index 9e8d46de4983..aa56745d64b5 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamRelNode.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamRelNode.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.rel.RelNode; /** @@ -33,7 +33,7 @@ public interface BeamRelNode extends RelNode { * {@code BeamQueryPlanner} visits it with a DFS(Depth-First-Search) * algorithm. 
*/ - PCollection buildBeamPipeline( + PCollection buildBeamPipeline( PCollectionTuple inputPCollections, BeamSqlEnv sqlEnv) throws Exception; } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java index 8b47d8e57af5..0a9af42b97f9 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBase.java @@ -28,10 +28,10 @@ import org.apache.beam.sdk.transforms.join.CoGroupByKey; import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.apache.calcite.rel.RelNode; @@ -62,11 +62,11 @@ public BeamSetOperatorRelBase(BeamRelNode beamRelNode, OpType opType, this.all = all; } - public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { - PCollection leftRows = BeamSqlRelUtils.getBeamRelInput(inputs.get(0)) + PCollection leftRows = BeamSqlRelUtils.getBeamRelInput(inputs.get(0)) .buildBeamPipeline(inputPCollections, sqlEnv); - PCollection rightRows = BeamSqlRelUtils.getBeamRelInput(inputs.get(1)) + PCollection rightRows = BeamSqlRelUtils.getBeamRelInput(inputs.get(1)) .buildBeamPipeline(inputPCollections, sqlEnv); WindowFn leftWindow = leftRows.getWindowingStrategy().getWindowFn(); @@ -77,12 +77,12 @@ public PCollection buildBeamPipeline(PCollectionTuple inputPCollecti + leftWindow + " VS " + rightWindow); } - final TupleTag leftTag = new TupleTag<>(); - final TupleTag rightTag = new TupleTag<>(); + final TupleTag leftTag = new TupleTag<>(); + final TupleTag rightTag = new TupleTag<>(); // co-group String stageName = BeamSqlRelUtils.getStageName(beamRelNode); - PCollection> coGbkResultCollection = + PCollection> coGbkResultCollection = KeyedPCollectionTuple.of( leftTag, leftRows.apply( @@ -94,7 +94,7 @@ public PCollection buildBeamPipeline(PCollectionTuple inputPCollecti stageName + "_CreateRightIndex", MapElements.via(new BeamSetOperatorsTransforms.BeamSqlRow2KvFn()))) .apply(CoGroupByKey.create()); - PCollection ret = coGbkResultCollection + PCollection ret = coGbkResultCollection .apply(ParDo.of(new BeamSetOperatorsTransforms.SetOperatorFilteringDoFn(leftTag, rightTag, opType, all))); return ret; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java index 2a57267f2a1b..16cdc7e10208 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java @@ -25,7 +25,6 @@ import java.util.Comparator; import java.util.List; import org.apache.beam.sdk.coders.ListCoder; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; 
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.DoFn; @@ -33,9 +32,9 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Top; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelCollation; @@ -120,10 +119,10 @@ public BeamSortRel( } } - @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { RelNode input = getInput(); - PCollection upstream = BeamSqlRelUtils.getBeamRelInput(input) + PCollection upstream = BeamSqlRelUtils.getBeamRelInput(input) .buildBeamPipeline(inputPCollections, sqlEnv); Type windowType = upstream.getWindowingStrategy().getWindowFn() .getWindowTypeDescriptor().getType(); @@ -135,7 +134,7 @@ public BeamSortRel( BeamSqlRowComparator comparator = new BeamSqlRowComparator(fieldIndices, orientation, nullsFirst); // first find the top (offset + count) - PCollection> rawStream = + PCollection> rawStream = upstream .apply( "extractTopOffsetAndFetch", @@ -151,8 +150,8 @@ public BeamSortRel( .setCoder(ListCoder.of(upstream.getCoder())); } - PCollection orderedStream = rawStream.apply("flatten", Flatten.iterables()); - orderedStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRecordCoder()); + PCollection orderedStream = rawStream.apply("flatten", Flatten.iterables()); + orderedStream.setCoder(CalciteUtils.toBeamRowType(getRowType()).getRowCoder()); return orderedStream; } @@ -177,7 +176,7 @@ public void processElement(ProcessContext ctx) { return new BeamSortRel(getCluster(), traitSet, newInput, newCollation, offset, fetch); } - private static class BeamSqlRowComparator implements Comparator, Serializable { + private static class BeamSqlRowComparator implements Comparator, Serializable { private List fieldsIndices; private List orientation; private List nullsFirst; @@ -190,16 +189,15 @@ public BeamSqlRowComparator(List fieldsIndices, this.nullsFirst = nullsFirst; } - @Override public int compare(BeamRecord row1, BeamRecord row2) { + @Override public int compare(Row row1, Row row2) { for (int i = 0; i < fieldsIndices.size(); i++) { int fieldIndex = fieldsIndices.get(i); int fieldRet = 0; - SqlTypeName fieldType = CalciteUtils.getFieldType( - BeamSqlRecordHelper.getSqlRecordType(row1), fieldIndex); + SqlTypeName fieldType = CalciteUtils.getFieldCalciteType(row1.getRowType(), fieldIndex); // whether NULL should be ordered first or last(compared to non-null values) depends on // what user specified in SQL(NULLS FIRST/NULLS LAST) - boolean isValue1Null = (row1.getFieldValue(fieldIndex) == null); - boolean isValue2Null = (row2.getFieldValue(fieldIndex) == null); + boolean isValue1Null = (row1.getValue(fieldIndex) == null); + boolean isValue2Null = (row2.getValue(fieldIndex) == null); if (isValue1Null && isValue2Null) { continue; } else if (isValue1Null && !isValue2Null) { @@ -217,8 +215,8 @@ public BeamSqlRowComparator(List fieldsIndices, case VARCHAR: case DATE: case TIMESTAMP: - Comparable v1 = (Comparable) row1.getFieldValue(fieldIndex); - Comparable v2 = (Comparable) row2.getFieldValue(fieldIndex); + Comparable v1 = 
(Comparable) row1.getValue(fieldIndex); + Comparable v2 = (Comparable) row2.getValue(fieldIndex); fieldRet = v1.compareTo(v2); break; default: diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRel.java index 85d676e90ba9..f8d34c2d6986 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRel.java @@ -21,12 +21,11 @@ import java.util.List; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; -import org.apache.calcite.rel.RelInput; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.SetOp; import org.apache.calcite.rel.core.Union; @@ -73,15 +72,11 @@ public BeamUnionRel(RelOptCluster cluster, inputs, all); } - public BeamUnionRel(RelInput input) { - super(input); - } - @Override public SetOp copy(RelTraitSet traitSet, List inputs, boolean all) { return new BeamUnionRel(getCluster(), traitSet, inputs, all); } - @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections + @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections , BeamSqlEnv sqlEnv) throws Exception { return delegate.buildBeamPipeline(inputPCollections, sqlEnv); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java index d684294b6500..1e98968f0836 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRel.java @@ -18,17 +18,20 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; +import static java.util.stream.Collectors.toList; +import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.autoCastField; +import static org.apache.beam.sdk.values.Row.toRow; + import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import java.util.stream.IntStream; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.core.Values; @@ -55,25 +58,35 @@ public BeamValuesRel( } - @Override public PCollection buildBeamPipeline(PCollectionTuple inputPCollections - , BeamSqlEnv sqlEnv) throws Exception { - List rows = new ArrayList<>(tuples.size()); + @Override public 
PCollection buildBeamPipeline( + PCollectionTuple inputPCollections, + BeamSqlEnv sqlEnv) throws Exception { + String stageName = BeamSqlRelUtils.getStageName(this); if (tuples.isEmpty()) { throw new IllegalStateException("Values with empty tuples!"); } - BeamRecordSqlType beamSQLRowType = CalciteUtils.toBeamRowType(this.getRowType()); - for (ImmutableList tuple : tuples) { - List fieldsValue = new ArrayList<>(beamSQLRowType.getFieldCount()); - for (int i = 0; i < tuple.size(); i++) { - fieldsValue.add(BeamTableUtils.autoCastField( - beamSQLRowType.getFieldTypeByIndex(i), tuple.get(i).getValue())); - } - rows.add(new BeamRecord(beamSQLRowType, fieldsValue)); - } + RowType rowType = CalciteUtils.toBeamRowType(this.getRowType()); + + List rows = + tuples + .stream() + .map(tuple -> tupleToRow(rowType, tuple)) + .collect(toList()); + + return + inputPCollections + .getPipeline() + .apply(stageName, Create.of(rows)) + .setCoder(rowType.getRowCoder()); + } - return inputPCollections.getPipeline().apply(stageName, Create.of(rows)) - .setCoder(beamSQLRowType.getRecordCoder()); + private Row tupleToRow(RowType rowType, ImmutableList tuple) { + return + IntStream + .range(0, tuple.size()) + .mapToObj(i -> autoCastField(rowType.getFieldCoder(i), tuple.get(i).getValue())) + .collect(toRow(rowType)); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/package-info.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/package-info.java index 76b335dda8fa..43f9194763a5 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/package-info.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/package-info.java @@ -20,4 +20,8 @@ * BeamSQL specified nodes, to replace {@link org.apache.calcite.rel.RelNode}. * */ +@DefaultAnnotation(NonNull.class) package org.apache.beam.sdk.extensions.sql.impl.rel; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import edu.umd.cs.findbugs.annotations.NonNull; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/AggregateWindowFactory.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/AggregateWindowFactory.java new file mode 100644 index 000000000000..8448fc80f375 --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/AggregateWindowFactory.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.sdk.extensions.sql.impl.rule; + +import java.util.List; +import java.util.Optional; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.joda.time.Duration; + +/** + * Creates {@link WindowFn} wrapper based on HOP/TUMBLE/SESSION call in a query. + */ +class AggregateWindowFactory { + + /** + * Returns optional of {@link AggregateWindowField} which represents a + * windowing function specified by HOP/TUMBLE/SESSION in the SQL query. + * + *

<p>If no known windowing function is specified in the query, then {@link Optional#empty()}
+   * is returned.
+   *
+   * <p>Throws {@link UnsupportedOperationException} if it cannot convert SQL windowing function
+   * call to Beam model, see {@link #getWindowFieldAt(RexCall, int)} for details.
+   */
+  static Optional getWindowFieldAt(RexCall call, int groupField) {
+
+    Optional windowFnOptional = createWindowFn(call.operands, call.op.kind);
+
+    return
+        windowFnOptional
+            .map(windowFn ->
+                AggregateWindowField
+                    .builder()
+                    .setFieldIndex(groupField)
+                    .setWindowFn(windowFn)
+                    .build());
+  }
+
+  /**
+   * Returns a {@link WindowFn} based on the SQL windowing function defined by {@code operatorKind}.
+   * Supported {@link SqlKind}s:
+   *
+   * <ul>
+   *   <li>{@link SqlKind#TUMBLE}, mapped to {@link FixedWindows};</li>
+   *   <li>{@link SqlKind#HOP}, mapped to {@link SlidingWindows};</li>
+   *   <li>{@link SqlKind#SESSION}, mapped to {@link Sessions};</li>
+   * </ul>
+   *
+   * <p>For example:
+   * <pre>{@code
+   *   SELECT event_timestamp, COUNT(*)
+   *   FROM PCOLLECTION
+   *   GROUP BY TUMBLE(event_timestamp, INTERVAL '1' HOUR)
+   * }</pre>
+   *
+   * <p>SQL window functions support an optional window_offset parameter which indicates
+   * how the window definition is offset from the event time. Offset is zero if not specified.
+   *
+   * <p>
    Beam model does not support offset for session windows, so this method will throw + * {@link UnsupportedOperationException} if offset is specified + * in SQL query for {@link SqlKind#SESSION}. + */ + private static Optional createWindowFn(List parameters, SqlKind operatorKind) { + switch (operatorKind) { + case TUMBLE: + + // Fixed-size, non-intersecting time-based windows, for example: + // every hour aggregate elements from the previous hour; + // + // SQL Syntax: + // TUMBLE(monotonic_field, window_size [, window_offset]) + // + // Example: + // TUMBLE(event_timestamp_field, INTERVAL '1' HOUR) + + FixedWindows fixedWindows = FixedWindows.of(durationParameter(parameters, 1)); + if (parameters.size() == 3) { + fixedWindows = fixedWindows.withOffset(durationParameter(parameters, 2)); + } + + return Optional.of(fixedWindows); + case HOP: + + // Sliding, fixed-size, intersecting time-based windows, for example: + // every minute aggregate elements from the previous hour; + // + // SQL Syntax: + // HOP(monotonic_field, emit_frequency, window_size [, window_offset]) + // + // Example: + // HOP(event_timestamp_field, INTERVAL '1' MINUTE, INTERVAL '1' HOUR) + + SlidingWindows slidingWindows = SlidingWindows + .of(durationParameter(parameters, 2)) + .every(durationParameter(parameters, 1)); + + if (parameters.size() == 4) { + slidingWindows = slidingWindows.withOffset(durationParameter(parameters, 3)); + } + + return Optional.of(slidingWindows); + case SESSION: + + // Session windows, for example: + // aggregate events after a gap of 1 minute of no events; + // + // SQL Syntax: + // SESSION(monotonic_field, session_gap) + // + // Example: + // SESSION(event_timestamp_field, INTERVAL '1' MINUTE) + + Sessions sessions = Sessions.withGapDuration(durationParameter(parameters, 1)); + if (parameters.size() == 3) { + throw new UnsupportedOperationException( + "Specifying alignment (offset) is not supported for session windows"); + } + + return Optional.of(sessions); + default: + return Optional.empty(); + } + } + + private static Duration durationParameter(List parameters, int parameterIndex) { + return Duration.millis(intValue(parameters.get(parameterIndex))); + } + + private static long intValue(RexNode operand) { + if (operand instanceof RexLiteral) { + return RexLiteral.intValue(operand); + } else { + throw new IllegalArgumentException(String.format("[%s] is not valid.", operand)); + } + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/AggregateWindowField.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/AggregateWindowField.java new file mode 100644 index 000000000000..40157945dbab --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/AggregateWindowField.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.extensions.sql.impl.rule; + +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.Row; + +/** + * For internal use only; no backwards compatibility guarantees. + * + *

    Represents a field with a window function call in a SQL expression. + */ +@Internal +@AutoValue +public abstract class AggregateWindowField { + public abstract int fieldIndex(); + public abstract WindowFn windowFn(); + + static Builder builder() { + return new AutoValue_AggregateWindowField.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setFieldIndex(int fieldIndex); + abstract Builder setWindowFn(WindowFn window); + abstract AggregateWindowField build(); + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamAggregationRule.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamAggregationRule.java index cdf6712524df..e622d3feb56b 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamAggregationRule.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamAggregationRule.java @@ -17,20 +17,10 @@ */ package org.apache.beam.sdk.extensions.sql.impl.rule; -import com.google.common.collect.ImmutableList; -import java.util.GregorianCalendar; import java.util.List; +import java.util.Optional; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamAggregationRel; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention; -import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; -import org.apache.beam.sdk.transforms.windowing.AfterWatermark; -import org.apache.beam.sdk.transforms.windowing.FixedWindows; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; -import org.apache.beam.sdk.transforms.windowing.Repeatedly; -import org.apache.beam.sdk.transforms.windowing.Sessions; -import org.apache.beam.sdk.transforms.windowing.SlidingWindows; -import org.apache.beam.sdk.transforms.windowing.Trigger; -import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptRuleOperand; @@ -38,12 +28,9 @@ import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; -import org.joda.time.Duration; /** * Rule to detect the window/trigger settings. 
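// Editor's note: an illustrative sketch, not part of this patch. It spells out the mapping that
// the new AggregateWindowFactory (above) derives from the GROUP BY clause and that this rule now
// passes to BeamAggregationRel: TUMBLE becomes FixedWindows, HOP becomes SlidingWindows, SESSION
// becomes Sessions. The interval sizes below are hypothetical examples, not values from the diff.
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.joda.time.Duration;

class SqlWindowMappingSketch {
  // GROUP BY TUMBLE(event_timestamp, INTERVAL '1' HOUR)
  static final WindowFn<Object, ?> TUMBLE = FixedWindows.of(Duration.standardHours(1));

  // GROUP BY HOP(event_timestamp, INTERVAL '1' MINUTE, INTERVAL '1' HOUR)
  static final WindowFn<Object, ?> HOP =
      SlidingWindows.of(Duration.standardHours(1)).every(Duration.standardMinutes(1));

  // GROUP BY SESSION(event_timestamp, INTERVAL '1' MINUTE)
  static final WindowFn<Object, ?> SESSION = Sessions.withGapDuration(Duration.standardMinutes(1));
}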
@@ -71,64 +58,23 @@ public BeamAggregationRule(RelOptRuleOperand operand, String description) { public void onMatch(RelOptRuleCall call) { final Aggregate aggregate = call.rel(0); final Project project = call.rel(1); - updateWindowTrigger(call, aggregate, project); + updateWindow(call, aggregate, project); } - private void updateWindowTrigger(RelOptRuleCall call, Aggregate aggregate, - Project project) { + private void updateWindow(RelOptRuleCall call, Aggregate aggregate, + Project project) { ImmutableBitSet groupByFields = aggregate.getGroupSet(); List projectMapping = project.getProjects(); - WindowFn windowFn = new GlobalWindows(); - Trigger triggerFn = Repeatedly.forever(AfterWatermark.pastEndOfWindow()); - int windowFieldIdx = -1; - Duration allowedLatence = Duration.ZERO; + Optional windowField = Optional.empty(); - for (int groupField : groupByFields.asList()) { - RexNode projNode = projectMapping.get(groupField); - if (projNode instanceof RexCall) { - SqlOperator op = ((RexCall) projNode).op; - ImmutableList parameters = ((RexCall) projNode).operands; - String functionName = op.getName(); - switch (functionName) { - case "TUMBLE": - windowFieldIdx = groupField; - windowFn = FixedWindows - .of(Duration.millis(getWindowParameterAsMillis(parameters.get(1)))); - if (parameters.size() == 3) { - GregorianCalendar delayTime = (GregorianCalendar) ((RexLiteral) parameters.get(2)) - .getValue(); - triggerFn = createTriggerWithDelay(delayTime); - allowedLatence = (Duration.millis(delayTime.getTimeInMillis())); - } - break; - case "HOP": - windowFieldIdx = groupField; - windowFn = SlidingWindows - .of(Duration.millis(getWindowParameterAsMillis(parameters.get(1)))) - .every(Duration.millis(getWindowParameterAsMillis(parameters.get(2)))); - if (parameters.size() == 4) { - GregorianCalendar delayTime = (GregorianCalendar) ((RexLiteral) parameters.get(3)) - .getValue(); - triggerFn = createTriggerWithDelay(delayTime); - allowedLatence = (Duration.millis(delayTime.getTimeInMillis())); - } - break; - case "SESSION": - windowFieldIdx = groupField; - windowFn = Sessions - .withGapDuration(Duration.millis(getWindowParameterAsMillis(parameters.get(1)))); - if (parameters.size() == 3) { - GregorianCalendar delayTime = (GregorianCalendar) ((RexLiteral) parameters.get(2)) - .getValue(); - triggerFn = createTriggerWithDelay(delayTime); - allowedLatence = (Duration.millis(delayTime.getTimeInMillis())); - } - break; - default: - break; - } + for (int groupFieldIndex : groupByFields.asList()) { + RexNode projNode = projectMapping.get(groupFieldIndex); + if (!(projNode instanceof RexCall)) { + continue; } + + windowField = AggregateWindowFactory.getWindowFieldAt((RexCall) projNode, groupFieldIndex); } BeamAggregationRel newAggregator = new BeamAggregationRel(aggregate.getCluster(), @@ -139,24 +85,8 @@ private void updateWindowTrigger(RelOptRuleCall call, Aggregate aggregate, aggregate.getGroupSet(), aggregate.getGroupSets(), aggregate.getAggCallList(), - windowFn, - triggerFn, - windowFieldIdx, - allowedLatence); + windowField); call.transformTo(newAggregator); } - private Trigger createTriggerWithDelay(GregorianCalendar delayTime) { - return Repeatedly.forever(AfterWatermark.pastEndOfWindow().withLateFirings(AfterProcessingTime - .pastFirstElementInPane().plusDelayOf(Duration.millis(delayTime.getTimeInMillis())))); - } - - private long getWindowParameterAsMillis(RexNode parameterNode) { - if (parameterNode instanceof RexLiteral) { - return RexLiteral.intValue(parameterNode); - } else { - throw new 
IllegalArgumentException(String.format("[%s] is not valid.", parameterNode)); - } - } - } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/package-info.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/package-info.java index fa32b44a0bfd..84cdcd85d3e9 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/package-info.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/package-info.java @@ -20,4 +20,8 @@ * {@link org.apache.calcite.plan.RelOptRule} to generate * {@link org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode}. */ +@DefaultAnnotation(NonNull.class) package org.apache.beam.sdk.extensions.sql.impl.rule; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import edu.umd.cs.findbugs.annotations.NonNull; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BaseBeamTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BaseBeamTable.java index 7f99e128a5be..9d85b2c8b752 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BaseBeamTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BaseBeamTable.java @@ -18,19 +18,19 @@ package org.apache.beam.sdk.extensions.sql.impl.schema; import java.io.Serializable; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; +import org.apache.beam.sdk.values.RowType; /** * Each IO in Beam has one table schema, by extending {@link BaseBeamTable}. */ public abstract class BaseBeamTable implements BeamSqlTable, Serializable { - protected BeamRecordSqlType beamRecordSqlType; - public BaseBeamTable(BeamRecordSqlType beamRecordSqlType) { - this.beamRecordSqlType = beamRecordSqlType; + protected RowType rowType; + public BaseBeamTable(RowType rowType) { + this.rowType = rowType; } - @Override public BeamRecordSqlType getRowType() { - return beamRecordSqlType; + @Override public RowType getRowType() { + return rowType; } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamPCollectionTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamPCollectionTable.java index 31e60e01704b..076593096694 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamPCollectionTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamPCollectionTable.java @@ -18,12 +18,12 @@ package org.apache.beam.sdk.extensions.sql.impl.schema; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** * {@code BeamPCollectionTable} converts a {@code PCollection} as a virtual table, @@ -31,15 +31,15 @@ */ public class BeamPCollectionTable extends BaseBeamTable { private BeamIOType ioType; - private transient PCollection upstream; + private transient PCollection upstream; - protected 
BeamPCollectionTable(BeamRecordSqlType beamSqlRowType) { - super(beamSqlRowType); + protected BeamPCollectionTable(RowType beamRowType) { + super(beamRowType); } - public BeamPCollectionTable(PCollection upstream, - BeamRecordSqlType beamSqlRowType){ - this(beamSqlRowType); + public BeamPCollectionTable(PCollection upstream, + RowType beamRowType){ + this(beamRowType); ioType = upstream.isBounded().equals(IsBounded.BOUNDED) ? BeamIOType.BOUNDED : BeamIOType.UNBOUNDED; this.upstream = upstream; @@ -51,12 +51,12 @@ public BeamIOType getSourceType() { } @Override - public PCollection buildIOReader(Pipeline pipeline) { + public PCollection buildIOReader(Pipeline pipeline) { return upstream; } @Override - public PTransform, PDone> buildIOWriter() { + public PTransform, PDone> buildIOWriter() { throw new IllegalArgumentException("cannot use [BeamPCollectionTable] as target"); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java index e9f3c7660d2e..e971b2861b9e 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java @@ -18,15 +18,19 @@ package org.apache.beam.sdk.extensions.sql.impl.schema; +import static com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.sdk.values.Row.toRow; + import java.io.IOException; import java.io.StringReader; import java.io.StringWriter; import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import java.util.stream.IntStream; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.NlsString; import org.apache.commons.csv.CSVFormat; @@ -38,37 +42,38 @@ * Utility methods for working with {@code BeamTable}. 
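// Illustrative sketch (not part of the patch): wrapping an existing PCollection of Rows as a
// virtual table via the new RowType-based BeamPCollectionTable constructor shown above. The
// field names, the SqlTypeCoders constants and the toRowType() collector are borrowed from
// elsewhere in this patch; "users" stands for any PCollection of matching Rows.
//
// assumed imports:
// import static org.apache.beam.sdk.values.RowType.toRowType;
// import java.util.stream.Stream;
RowType userRowType =
    Stream.of(
            RowType.newField("id", SqlTypeCoders.INTEGER),
            RowType.newField("name", SqlTypeCoders.VARCHAR))
        .collect(toRowType());

BeamPCollectionTable usersTable = new BeamPCollectionTable(users, userRowType);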
*/ public final class BeamTableUtils { - public static BeamRecord csvLine2BeamRecord( + public static Row csvLine2BeamRow( CSVFormat csvFormat, String line, - BeamRecordSqlType beamRecordSqlType) { - List fieldsValue = new ArrayList<>(beamRecordSqlType.getFieldCount()); + RowType rowType) { + try (StringReader reader = new StringReader(line)) { CSVParser parser = csvFormat.parse(reader); CSVRecord rawRecord = parser.getRecords().get(0); - if (rawRecord.size() != beamRecordSqlType.getFieldCount()) { + if (rawRecord.size() != rowType.getFieldCount()) { throw new IllegalArgumentException(String.format( "Expect %d fields, but actually %d", - beamRecordSqlType.getFieldCount(), rawRecord.size() + rowType.getFieldCount(), rawRecord.size() )); - } else { - for (int idx = 0; idx < beamRecordSqlType.getFieldCount(); idx++) { - String raw = rawRecord.get(idx); - fieldsValue.add(autoCastField(beamRecordSqlType.getFieldTypeByIndex(idx), raw)); - } } + + return + IntStream + .range(0, rowType.getFieldCount()) + .mapToObj(idx -> autoCastField(rowType.getFieldCoder(idx), rawRecord.get(idx))) + .collect(toRow(rowType)); + } catch (IOException e) { throw new IllegalArgumentException("decodeRecord failed!", e); } - return new BeamRecord(beamRecordSqlType, fieldsValue); } - public static String beamRecord2CsvLine(BeamRecord row, CSVFormat csvFormat) { + public static String beamRow2CsvLine(Row row, CSVFormat csvFormat) { StringWriter writer = new StringWriter(); try (CSVPrinter printer = csvFormat.print(writer)) { for (int i = 0; i < row.getFieldCount(); i++) { - printer.print(row.getFieldValue(i).toString()); + printer.print(row.getValue(i).toString()); } printer.println(); } catch (IOException e) { @@ -77,12 +82,14 @@ public static String beamRecord2CsvLine(BeamRecord row, CSVFormat csvFormat) { return writer.toString(); } - public static Object autoCastField(int fieldType, Object rawObj) { + public static Object autoCastField(Coder coder, Object rawObj) { + checkArgument(coder instanceof SqlTypeCoder); + if (rawObj == null) { return null; } - SqlTypeName columnType = CalciteUtils.toCalciteType(fieldType); + SqlTypeName columnType = CalciteUtils.toCalciteType((SqlTypeCoder) coder); // auto-casting for numberics if ((rawObj instanceof String && SqlTypeName.NUMERIC_TYPES.contains(columnType)) || (rawObj instanceof BigDecimal && columnType != SqlTypeName.DECIMAL)) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java index 4c6511b4cde9..f2e515b53780 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java @@ -17,6 +17,10 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform; +import static org.apache.beam.sdk.values.Row.toRow; +import static org.apache.beam.sdk.values.RowType.toRowType; + +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -24,20 +28,22 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; -import org.apache.beam.sdk.coders.BeamRecordCoder; +import java.util.stream.IntStream; import org.apache.beam.sdk.coders.BigDecimalCoder; import 
org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlInputRefExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl; -import org.apache.beam.sdk.extensions.sql.impl.transform.agg.BigDecimalConverter; +import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn; import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn; +import org.apache.beam.sdk.extensions.sql.impl.utils.BigDecimalConverter; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.Count; @@ -45,10 +51,10 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; import org.apache.calcite.util.ImmutableBitSet; import org.joda.time.Instant; @@ -60,12 +66,12 @@ public class BeamAggregationTransforms implements Serializable{ /** * Merge KV to single record. */ - public static class MergeAggregationRecord extends DoFn, BeamRecord> { - private BeamRecordSqlType outRowType; + public static class MergeAggregationRecord extends DoFn, Row> { + private RowType outRowType; private List aggFieldNames; private int windowStartFieldIdx; - public MergeAggregationRecord(BeamRecordSqlType outRowType, List aggList + public MergeAggregationRecord(RowType outRowType, List aggList , int windowStartFieldIdx) { this.outRowType = outRowType; this.aggFieldNames = new ArrayList<>(); @@ -77,17 +83,19 @@ public MergeAggregationRecord(BeamRecordSqlType outRowType, List @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) { - KV kvRecord = c.element(); + KV kvRow = c.element(); List fieldValues = new ArrayList<>(); - fieldValues.addAll(kvRecord.getKey().getDataValues()); - fieldValues.addAll(kvRecord.getValue().getDataValues()); + fieldValues.addAll(kvRow.getKey().getValues()); + fieldValues.addAll(kvRow.getValue().getValues()); if (windowStartFieldIdx != -1) { fieldValues.add(windowStartFieldIdx, ((IntervalWindow) window).start().toDate()); } - BeamRecord outRecord = new BeamRecord(outRowType, fieldValues); - c.output(outRecord); + c.output(Row + .withRowType(outRowType) + .addValues(fieldValues) + .build()); } } @@ -95,7 +103,7 @@ public void processElement(ProcessContext c, BoundedWindow window) { * extract group-by fields. 
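// Illustrative sketch (not part of the patch): the two Row-building styles this file now relies
// on, as in MergeAggregationRecord above and the group-by key extraction below. rowType,
// keyRowType and inputRow are placeholders; the values and group-key indexes are example data.
//
// assumed imports:
// import static org.apache.beam.sdk.values.Row.toRow;
// import com.google.common.collect.ImmutableList;
// import java.util.Arrays;
Row builtExplicitly =
    Row.withRowType(rowType)
        .addValues(Arrays.<Object>asList(1, "a"))   // values must line up with the RowType's field coders
        .build();

Row groupKey =
    ImmutableList.of(0, 2)                          // indexes of the group-by fields
        .stream()
        .map(inputRow::getValue)
        .collect(toRow(keyRowType));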
*/ public static class AggregationGroupByKeyFn - implements SerializableFunction { + implements SerializableFunction { private List groupByKeys; public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) { @@ -108,33 +116,30 @@ public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) { } @Override - public BeamRecord apply(BeamRecord input) { - BeamRecordSqlType typeOfKey = exTypeOfKeyRecord(BeamSqlRecordHelper.getSqlRecordType(input)); - - List fieldValues = new ArrayList<>(groupByKeys.size()); - for (Integer groupByKey : groupByKeys) { - fieldValues.add(input.getFieldValue(groupByKey)); - } + public Row apply(Row input) { + RowType typeOfKey = exTypeOfKeyRow(input.getRowType()); - BeamRecord keyOfRecord = new BeamRecord(typeOfKey, fieldValues); - return keyOfRecord; + return groupByKeys + .stream() + .map(input::getValue) + .collect(toRow(typeOfKey)); } - private BeamRecordSqlType exTypeOfKeyRecord(BeamRecordSqlType dataType) { - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (int idx : groupByKeys) { - fieldNames.add(dataType.getFieldNameByIndex(idx)); - fieldTypes.add(dataType.getFieldTypeByIndex(idx)); - } - return BeamRecordSqlType.create(fieldNames, fieldTypes); + private RowType exTypeOfKeyRow(RowType dataType) { + return + groupByKeys + .stream() + .map(i -> RowType.newField( + dataType.getFieldName(i), + dataType.getFieldCoder(i))) + .collect(toRowType()); } } /** * Assign event timestamp. */ - public static class WindowTimestampFn implements SerializableFunction { + public static class WindowTimestampFn implements SerializableFunction { private int windowFieldIdx = -1; public WindowTimestampFn(int windowFieldIdx) { @@ -143,7 +148,7 @@ public WindowTimestampFn(int windowFieldIdx) { } @Override - public Instant apply(BeamRecord input) { + public Instant apply(Row input) { return new Instant(input.getDate(windowFieldIdx).getTime()); } } @@ -152,51 +157,73 @@ public Instant apply(BeamRecord input) { * An adaptor class to invoke Calcite UDAF instances in Beam {@code CombineFn}. */ public static class AggregationAdaptor - extends CombineFn { + extends CombineFn { private List aggregators; - private List sourceFieldExps; - private BeamRecordSqlType finalRowType; + private List sourceFieldExps; + private RowType finalRowType; - public AggregationAdaptor(List aggregationCalls, - BeamRecordSqlType sourceRowType) { + public AggregationAdaptor(List aggregationCalls, RowType sourceRowType) { aggregators = new ArrayList<>(); sourceFieldExps = new ArrayList<>(); - List outFieldsName = new ArrayList<>(); - List outFieldsType = new ArrayList<>(); + ImmutableList.Builder fields = ImmutableList.builder(); + for (AggregateCall call : aggregationCalls) { - int refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0; - BeamSqlInputRefExpression sourceExp = new BeamSqlInputRefExpression( - CalciteUtils.getFieldType(sourceRowType, refIndex), refIndex); - sourceFieldExps.add(sourceExp); + if (call.getArgList().size() == 2) { + /** + * handle the case of aggregation function has two parameters and + * use KV pair to bundle two corresponding expressions. 
+ */ + + int refIndexKey = call.getArgList().get(0); + int refIndexValue = call.getArgList().get(1); + + BeamSqlInputRefExpression sourceExpKey = new BeamSqlInputRefExpression( + CalciteUtils.getFieldCalciteType(sourceRowType, refIndexKey), refIndexKey); + BeamSqlInputRefExpression sourceExpValue = new BeamSqlInputRefExpression( + CalciteUtils.getFieldCalciteType(sourceRowType, refIndexValue), refIndexValue); - outFieldsName.add(call.name); - SqlTypeName outFieldSqlType = call.type.getSqlTypeName(); - int outFieldType = CalciteUtils.toJavaType(outFieldSqlType); - outFieldsType.add(outFieldType); + sourceFieldExps.add(KV.of(sourceExpKey, sourceExpValue)); + } else { + int refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0; + BeamSqlInputRefExpression sourceExp = new BeamSqlInputRefExpression( + CalciteUtils.getFieldCalciteType(sourceRowType, refIndex), refIndex); + sourceFieldExps.add(sourceExp); + } + + SqlTypeCoder outFieldType = CalciteUtils.toCoder(call.type.getSqlTypeName()); + fields.add(RowType.newField(call.name, outFieldType)); switch (call.getAggregation().getName()) { case "COUNT": aggregators.add(Count.combineFn()); break; case "MAX": - aggregators.add(BeamBuiltinAggregations.createMax(outFieldSqlType)); + aggregators.add(BeamBuiltinAggregations.createMax(call.type.getSqlTypeName())); break; case "MIN": - aggregators.add(BeamBuiltinAggregations.createMin(outFieldSqlType)); + aggregators.add(BeamBuiltinAggregations.createMin(call.type.getSqlTypeName())); break; case "SUM": - aggregators.add(BeamBuiltinAggregations.createSum(outFieldSqlType)); + aggregators.add(BeamBuiltinAggregations.createSum(call.type.getSqlTypeName())); break; case "AVG": - aggregators.add(BeamBuiltinAggregations.createAvg(outFieldSqlType)); + aggregators.add(BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName())); break; case "VAR_POP": aggregators.add( - VarianceFn.newPopulation(BigDecimalConverter.forSqlType(outFieldSqlType))); + VarianceFn.newPopulation(BigDecimalConverter.forSqlType(outFieldType))); break; case "VAR_SAMP": aggregators.add( - VarianceFn.newSample(BigDecimalConverter.forSqlType(outFieldSqlType))); + VarianceFn.newSample(BigDecimalConverter.forSqlType(outFieldType))); + break; + case "COVAR_POP": + aggregators.add( + CovarianceFn.newPopulation(BigDecimalConverter.forSqlType(outFieldType))); + break; + case "COVAR_SAMP": + aggregators.add( + CovarianceFn.newSample(BigDecimalConverter.forSqlType(outFieldType))); break; default: if (call.getAggregation() instanceof SqlUserDefinedAggFunction) { @@ -216,8 +243,9 @@ public AggregationAdaptor(List aggregationCalls, break; } } - finalRowType = BeamRecordSqlType.create(outFieldsName, outFieldsType); + finalRowType = fields.build().stream().collect(toRowType()); } + @Override public AggregationAccumulator createAccumulator() { AggregationAccumulator initialAccu = new AggregationAccumulator(); @@ -226,13 +254,31 @@ public AggregationAccumulator createAccumulator() { } return initialAccu; } + @Override - public AggregationAccumulator addInput(AggregationAccumulator accumulator, BeamRecord input) { + public AggregationAccumulator addInput(AggregationAccumulator accumulator, Row input) { AggregationAccumulator deltaAcc = new AggregationAccumulator(); for (int idx = 0; idx < aggregators.size(); ++idx) { - deltaAcc.accumulatorElements.add( - aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx), - sourceFieldExps.get(idx).evaluate(input, null).getValue())); + if (sourceFieldExps.get(idx) instanceof 
BeamSqlInputRefExpression) { + BeamSqlInputRefExpression exp = (BeamSqlInputRefExpression) sourceFieldExps.get(idx); + deltaAcc.accumulatorElements.add( + aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx), + exp.evaluate(input, null).getValue() + ) + ); + } else if (sourceFieldExps.get(idx) instanceof KV){ + /** + * If source expression is type of KV pair, we bundle the value of two expressions into + * KV pair and pass it to aggregator's addInput method. + */ + + KV exp = + (KV) sourceFieldExps.get(idx); + deltaAcc.accumulatorElements.add( + aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx), + KV.of(exp.getKey().evaluate(input, null).getValue(), + exp.getValue().evaluate(input, null).getValue()))); + } } return deltaAcc; } @@ -248,26 +294,49 @@ public AggregationAccumulator mergeAccumulators(Iterable } return deltaAcc; } + @Override - public BeamRecord extractOutput(AggregationAccumulator accumulator) { - List fieldValues = new ArrayList<>(aggregators.size()); - for (int idx = 0; idx < aggregators.size(); ++idx) { - fieldValues - .add(aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx))); - } - return new BeamRecord(finalRowType, fieldValues); + public Row extractOutput(AggregationAccumulator accumulator) { + return + IntStream + .range(0, aggregators.size()) + .mapToObj(idx -> getAggregatorOutput(accumulator, idx)) + .collect(toRow(finalRowType)); + } + + private Object getAggregatorOutput(AggregationAccumulator accumulator, int idx) { + return aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx)); } + @Override public Coder getAccumulatorCoder( - CoderRegistry registry, Coder inputCoder) + CoderRegistry registry, Coder inputCoder) throws CannotProvideCoderException { - BeamRecordCoder beamRecordCoder = (BeamRecordCoder) inputCoder; + RowCoder rowCoder = (RowCoder) inputCoder; registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of()); List aggAccuCoderList = new ArrayList<>(); for (int idx = 0; idx < aggregators.size(); ++idx) { - int srcFieldIndex = sourceFieldExps.get(idx).getInputRef(); - Coder srcFieldCoder = beamRecordCoder.getCoders().get(srcFieldIndex); - aggAccuCoderList.add(aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder)); + if (sourceFieldExps.get(idx) instanceof BeamSqlInputRefExpression) { + BeamSqlInputRefExpression exp = (BeamSqlInputRefExpression) sourceFieldExps.get(idx); + int srcFieldIndex = exp.getInputRef(); + + Coder srcFieldCoder = rowCoder.getCoders().get(srcFieldIndex); + aggAccuCoderList.add(aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder)); + } else if (sourceFieldExps.get(idx) instanceof KV) { + // extract coder of two expressions separately. 
+ KV exp = + (KV) sourceFieldExps.get(idx); + + int srcFieldIndexKey = exp.getKey().getInputRef(); + int srcFieldIndexValue = exp.getValue().getInputRef(); + + Coder srcFieldCoderKey = rowCoder.getCoders().get(srcFieldIndexKey); + Coder srcFieldCoderValue = rowCoder.getCoders().get(srcFieldIndexValue); + + aggAccuCoderList.add(aggregators.get(idx).getAccumulatorCoder(registry, KvCoder.of( + srcFieldCoderKey, srcFieldCoderValue)) + ); + } } return new AggregationAccumulatorCoder(aggAccuCoderList); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java index f789e31e379a..3f708253cd6d 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java @@ -18,20 +18,24 @@ package org.apache.beam.sdk.extensions.sql.impl.transform; +import static java.util.stream.Collectors.toList; +import static org.apache.beam.sdk.values.Row.toRow; +import static org.apache.beam.sdk.values.RowType.toRowType; + import java.util.ArrayList; import java.util.List; import java.util.Map; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; -import org.apache.beam.sdk.extensions.sql.BeamSqlRecordHelper; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; @@ -47,52 +51,53 @@ public class BeamJoinTransforms { * A {@code SimpleFunction} to extract join fields from the specified row. */ public static class ExtractJoinFields - extends SimpleFunction> { - private final boolean isLeft; - private final List> joinColumns; + extends SimpleFunction> { + private final List joinColumns; public ExtractJoinFields(boolean isLeft, List> joinColumns) { - this.isLeft = isLeft; - this.joinColumns = joinColumns; + this.joinColumns = + joinColumns + .stream() + .map(pair -> isLeft ? pair.left : pair.right) + .collect(toList()); } - @Override public KV apply(BeamRecord input) { - // build the type - // the name of the join field is not important - List names = new ArrayList<>(joinColumns.size()); - List types = new ArrayList<>(joinColumns.size()); - for (int i = 0; i < joinColumns.size(); i++) { - names.add("c" + i); - types.add(isLeft - ? 
BeamSqlRecordHelper.getSqlRecordType(input).getFieldTypeByIndex( - joinColumns.get(i).getKey()) - : BeamSqlRecordHelper.getSqlRecordType(input).getFieldTypeByIndex( - joinColumns.get(i).getValue())); - } - BeamRecordSqlType type = BeamRecordSqlType.create(names, types); + @Override + public KV apply(Row input) { + RowType rowType = + joinColumns + .stream() + .map(fieldIndex -> toField(input.getRowType(), fieldIndex)) + .collect(toRowType()); + + Row row = + joinColumns + .stream() + .map(input::getValue) + .collect(toRow(rowType)); + + return KV.of(row, input); + } - // build the row - List fieldValues = new ArrayList<>(joinColumns.size()); - for (Pair joinColumn : joinColumns) { - fieldValues.add(input - .getFieldValue(isLeft ? joinColumn.getKey() : joinColumn.getValue())); - } - return KV.of(new BeamRecord(type, fieldValues), input); + private RowType.Field toField(RowType rowType, Integer fieldIndex) { + return RowType.newField( + "c" + fieldIndex, + //rowType.getFieldName(fieldIndex), + rowType.getFieldCoder(fieldIndex)); } } - /** * A {@code DoFn} which implement the sideInput-JOIN. */ - public static class SideInputJoinDoFn extends DoFn, BeamRecord> { - private final PCollectionView>> sideInputView; + public static class SideInputJoinDoFn extends DoFn, Row> { + private final PCollectionView>> sideInputView; private final JoinRelType joinType; - private final BeamRecord rightNullRow; + private final Row rightNullRow; private final boolean swap; - public SideInputJoinDoFn(JoinRelType joinType, BeamRecord rightNullRow, - PCollectionView>> sideInputView, + public SideInputJoinDoFn(JoinRelType joinType, Row rightNullRow, + PCollectionView>> sideInputView, boolean swap) { this.joinType = joinType; this.rightNullRow = rightNullRow; @@ -101,13 +106,13 @@ public SideInputJoinDoFn(JoinRelType joinType, BeamRecord rightNullRow, } @ProcessElement public void processElement(ProcessContext context) { - BeamRecord key = context.element().getKey(); - BeamRecord leftRow = context.element().getValue(); - Map> key2Rows = context.sideInput(sideInputView); - Iterable rightRowsIterable = key2Rows.get(key); + Row key = context.element().getKey(); + Row leftRow = context.element().getValue(); + Map> key2Rows = context.sideInput(sideInputView); + Iterable rightRowsIterable = key2Rows.get(key); if (rightRowsIterable != null && rightRowsIterable.iterator().hasNext()) { - for (BeamRecord aRightRowsIterable : rightRowsIterable) { + for (Row aRightRowsIterable : rightRowsIterable) { context.output(combineTwoRowsIntoOne(leftRow, aRightRowsIterable, swap)); } } else { @@ -123,11 +128,11 @@ public SideInputJoinDoFn(JoinRelType joinType, BeamRecord rightNullRow, * A {@code SimpleFunction} to combine two rows into one. */ public static class JoinParts2WholeRow - extends SimpleFunction>, BeamRecord> { - @Override public BeamRecord apply(KV> input) { - KV parts = input.getValue(); - BeamRecord leftRow = parts.getKey(); - BeamRecord rightRow = parts.getValue(); + extends SimpleFunction>, Row> { + @Override public Row apply(KV> input) { + KV parts = input.getValue(); + Row leftRow = parts.getKey(); + Row rightRow = parts.getValue(); return combineTwoRowsIntoOne(leftRow, rightRow, false); } } @@ -135,8 +140,8 @@ public static class JoinParts2WholeRow /** * As the method name suggests: combine two rows into one wide row. 
*/ - private static BeamRecord combineTwoRowsIntoOne(BeamRecord leftRow, - BeamRecord rightRow, boolean swap) { + private static Row combineTwoRowsIntoOne(Row leftRow, + Row rightRow, boolean swap) { if (swap) { return combineTwoRowsIntoOneHelper(rightRow, leftRow); } else { @@ -147,37 +152,37 @@ private static BeamRecord combineTwoRowsIntoOne(BeamRecord leftRow, /** * As the method name suggests: combine two rows into one wide row. */ - private static BeamRecord combineTwoRowsIntoOneHelper(BeamRecord leftRow, - BeamRecord rightRow) { + private static Row combineTwoRowsIntoOneHelper(Row leftRow, Row rightRow) { // build the type List names = new ArrayList<>(leftRow.getFieldCount() + rightRow.getFieldCount()); - names.addAll(leftRow.getDataType().getFieldNames()); - names.addAll(rightRow.getDataType().getFieldNames()); - - List types = new ArrayList<>(leftRow.getFieldCount() + rightRow.getFieldCount()); - types.addAll(BeamSqlRecordHelper.getSqlRecordType(leftRow).getFieldTypes()); - types.addAll(BeamSqlRecordHelper.getSqlRecordType(rightRow).getFieldTypes()); - BeamRecordSqlType type = BeamRecordSqlType.create(names, types); - - List fieldValues = new ArrayList<>(leftRow.getDataValues()); - fieldValues.addAll(rightRow.getDataValues()); - return new BeamRecord(type, fieldValues); + names.addAll(leftRow.getRowType().getFieldNames()); + names.addAll(rightRow.getRowType().getFieldNames()); + + List types = new ArrayList<>(leftRow.getFieldCount() + rightRow.getFieldCount()); + types.addAll(leftRow.getRowType().getRowCoder().getCoders()); + types.addAll(rightRow.getRowType().getRowCoder().getCoders()); + RowType type = RowType.fromNamesAndCoders(names, types); + + return Row + .withRowType(type) + .addValues(leftRow.getValues()) + .addValues(rightRow.getValues()) + .build(); } /** * Transform to execute Join as Lookup. 
*/ public static class JoinAsLookup - extends PTransform, PCollection> { -// private RexNode joinCondition; + extends PTransform, PCollection> { + BeamSqlSeekableTable seekableTable; - BeamRecordSqlType lkpRowType; -// int factTableColSize = 0; // TODO - BeamRecordSqlType joinSubsetType; + RowType lkpRowType; + RowType joinSubsetType; List factJoinIdx; public JoinAsLookup(RexNode joinCondition, BeamSqlSeekableTable seekableTable, - BeamRecordSqlType lkpRowType, int factTableColSize) { + RowType lkpRowType, int factTableColSize) { this.seekableTable = seekableTable; this.lkpRowType = lkpRowType; joinFieldsMapping(joinCondition, factTableColSize); @@ -186,7 +191,7 @@ public JoinAsLookup(RexNode joinCondition, BeamSqlSeekableTable seekableTable, private void joinFieldsMapping(RexNode joinCondition, int factTableColSize) { factJoinIdx = new ArrayList<>(); List lkpJoinFieldsName = new ArrayList<>(); - List lkpJoinFieldsType = new ArrayList<>(); + List lkpJoinFieldsType = new ArrayList<>(); RexCall call = (RexCall) joinCondition; if ("AND".equals(call.getOperator().getName())) { @@ -195,42 +200,48 @@ private void joinFieldsMapping(RexNode joinCondition, int factTableColSize) { factJoinIdx.add(((RexInputRef) ((RexCall) rexNode).getOperands().get(0)).getIndex()); int lkpJoinIdx = ((RexInputRef) ((RexCall) rexNode).getOperands().get(1)).getIndex() - factTableColSize; - lkpJoinFieldsName.add(lkpRowType.getFieldNameByIndex(lkpJoinIdx)); - lkpJoinFieldsType.add(lkpRowType.getFieldTypeByIndex(lkpJoinIdx)); + lkpJoinFieldsName.add(lkpRowType.getFieldName(lkpJoinIdx)); + lkpJoinFieldsType.add(lkpRowType.getFieldCoder(lkpJoinIdx)); } } else if ("=".equals(call.getOperator().getName())) { factJoinIdx.add(((RexInputRef) call.getOperands().get(0)).getIndex()); int lkpJoinIdx = ((RexInputRef) call.getOperands().get(1)).getIndex() - factTableColSize; - lkpJoinFieldsName.add(lkpRowType.getFieldNameByIndex(lkpJoinIdx)); - lkpJoinFieldsType.add(lkpRowType.getFieldTypeByIndex(lkpJoinIdx)); + lkpJoinFieldsName.add(lkpRowType.getFieldName(lkpJoinIdx)); + lkpJoinFieldsType.add(lkpRowType.getFieldCoder(lkpJoinIdx)); } else { throw new UnsupportedOperationException( "Operator " + call.getOperator().getName() + " is not supported in join condition"); } - joinSubsetType = BeamRecordSqlType.create(lkpJoinFieldsName, lkpJoinFieldsType); + joinSubsetType = RowType.fromNamesAndCoders(lkpJoinFieldsName, lkpJoinFieldsType); } @Override - public PCollection expand(PCollection input) { - return input.apply("join_as_lookup", ParDo.of(new DoFn(){ + public PCollection expand(PCollection input) { + return input.apply("join_as_lookup", ParDo.of(new DoFn(){ @ProcessElement public void processElement(ProcessContext context) { - BeamRecord factRow = context.element(); - BeamRecord joinSubRow = extractJoinSubRow(factRow); - List lookupRows = seekableTable.seekRecord(joinSubRow); - for (BeamRecord lr : lookupRows) { + Row factRow = context.element(); + Row joinSubRow = extractJoinSubRow(factRow); + List lookupRows = seekableTable.seekRow(joinSubRow); + for (Row lr : lookupRows) { context.output(combineTwoRowsIntoOneHelper(factRow, lr)); } } - private BeamRecord extractJoinSubRow(BeamRecord factRow) { - List joinSubsetValues = new ArrayList<>(); - for (int i : factJoinIdx) { - joinSubsetValues.add(factRow.getFieldValue(i)); - } - return new BeamRecord(joinSubsetType, joinSubsetValues); + private Row extractJoinSubRow(Row factRow) { + List joinSubsetValues = + factJoinIdx + .stream() + .map(factRow::getValue) + .collect(toList()); + + 
return + Row + .withRowType(joinSubsetType) + .addValues(joinSubsetValues) + .build(); } })); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java index 33ac807279ab..dfe0ae4a5f76 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSetOperatorsTransforms.java @@ -23,8 +23,8 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.join.CoGbkResult; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; /** @@ -35,8 +35,8 @@ public abstract class BeamSetOperatorsTransforms { * Transform a {@code BeamSqlRow} to a {@code KV}. */ public static class BeamSqlRow2KvFn extends - SimpleFunction> { - @Override public KV apply(BeamRecord input) { + SimpleFunction> { + @Override public KV apply(Row input) { return KV.of(input, input); } } @@ -45,15 +45,15 @@ public static class BeamSqlRow2KvFn extends * Filter function used for Set operators. */ public static class SetOperatorFilteringDoFn extends - DoFn, BeamRecord> { - private TupleTag leftTag; - private TupleTag rightTag; + DoFn, Row> { + private TupleTag leftTag; + private TupleTag rightTag; private BeamSetOperatorRelBase.OpType opType; // ALL? private boolean all; - public SetOperatorFilteringDoFn(TupleTag leftTag, TupleTag rightTag, - BeamSetOperatorRelBase.OpType opType, boolean all) { + public SetOperatorFilteringDoFn(TupleTag leftTag, TupleTag rightTag, + BeamSetOperatorRelBase.OpType opType, boolean all) { this.leftTag = leftTag; this.rightTag = rightTag; this.opType = opType; @@ -62,13 +62,13 @@ public SetOperatorFilteringDoFn(TupleTag leftTag, TupleTag leftRows = coGbkResult.getAll(leftTag); - Iterable rightRows = coGbkResult.getAll(rightTag); + Iterable leftRows = coGbkResult.getAll(leftTag); + Iterable rightRows = coGbkResult.getAll(rightTag); switch (opType) { case UNION: if (all) { // output both left & right - Iterator iter = leftRows.iterator(); + Iterator iter = leftRows.iterator(); while (iter.hasNext()) { ctx.output(iter.next()); } @@ -84,7 +84,7 @@ public SetOperatorFilteringDoFn(TupleTag leftTag, TupleTag leftTag, TupleTag iter = leftRows.iterator(); + Iterator iter = leftRows.iterator(); if (all) { // output all while (iter.hasNext()) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlFilterFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlFilterFn.java index d3a3f7b9dff1..b92947b0b399 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlFilterFn.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlFilterFn.java @@ -22,13 +22,13 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.BeamFilterRel; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** * {@code BeamSqlFilterFn} is the executor for a 
{@link BeamFilterRel} step. * */ -public class BeamSqlFilterFn extends DoFn { +public class BeamSqlFilterFn extends DoFn { private String stepName; private BeamSqlExpressionExecutor executor; @@ -46,7 +46,7 @@ public void setup() { @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) { - BeamRecord in = c.element(); + Row in = c.element(); List result = executor.execute(in, window); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlOutputToConsoleFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlOutputToConsoleFn.java index f97a90a168d2..98d1d30ed094 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlOutputToConsoleFn.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlOutputToConsoleFn.java @@ -18,13 +18,13 @@ package org.apache.beam.sdk.extensions.sql.impl.transform; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** * A test PTransform to display output in console. * */ -public class BeamSqlOutputToConsoleFn extends DoFn { +public class BeamSqlOutputToConsoleFn extends DoFn { private String stepName; @@ -35,7 +35,7 @@ public BeamSqlOutputToConsoleFn(String stepName) { @ProcessElement public void processElement(ProcessContext c) { - System.out.println("Output: " + c.element().getDataValues()); + System.out.println("Output: " + c.element().getValues()); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlProjectFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlProjectFn.java index 719fbf3d6ce6..3ae91b633826 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlProjectFn.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamSqlProjectFn.java @@ -17,28 +17,29 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform; -import java.util.ArrayList; +import static java.util.stream.Collectors.toList; + import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import java.util.stream.IntStream; import org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlExpressionExecutor; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamProjectRel; import org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** - * * {@code BeamSqlProjectFn} is the executor for a {@link BeamProjectRel} step. 
- * */ -public class BeamSqlProjectFn extends DoFn { +public class BeamSqlProjectFn extends DoFn { private String stepName; private BeamSqlExpressionExecutor executor; - private BeamRecordSqlType outputRowType; + private RowType outputRowType; - public BeamSqlProjectFn(String stepName, BeamSqlExpressionExecutor executor, - BeamRecordSqlType outputRowType) { + public BeamSqlProjectFn( + String stepName, BeamSqlExpressionExecutor executor, + RowType outputRowType) { super(); this.stepName = stepName; this.executor = executor; @@ -52,16 +53,25 @@ public void setup() { @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) { - BeamRecord inputRow = c.element(); - List results = executor.execute(inputRow, window); - List fieldsValue = new ArrayList<>(results.size()); - for (int idx = 0; idx < results.size(); ++idx) { - fieldsValue.add( - BeamTableUtils.autoCastField(outputRowType.getFieldTypeByIndex(idx), results.get(idx))); - } - BeamRecord outRow = new BeamRecord(outputRowType, fieldsValue); + Row inputRow = c.element(); + List rawResultValues = executor.execute(inputRow, window); + + List castResultValues = + IntStream + .range(0, outputRowType.getFieldCount()) + .mapToObj(i -> castField(rawResultValues, i)) + .collect(toList()); + + c.output( + Row + .withRowType(outputRowType) + .addValues(castResultValues) + .build()); + } - c.output(outRow); + private Object castField(List resultValues, int i) { + return BeamTableUtils + .autoCastField(outputRowType.getFieldCoder(i), resultValues.get(i)); } @Teardown diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceAccumulator.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceAccumulator.java new file mode 100644 index 000000000000..6e567eb7ea83 --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceAccumulator.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.extensions.sql.impl.transform.agg; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import java.math.BigDecimal; + +/** + * Accumulates current covariance of a sample, means of two elements, and number of elements. 
+ */ +@AutoValue +abstract class CovarianceAccumulator implements Serializable { + static final CovarianceAccumulator EMPTY = + newCovarianceAccumulator(BigDecimal.ZERO, BigDecimal.ZERO, + BigDecimal.ZERO, BigDecimal.ZERO); + + abstract BigDecimal covariance(); + abstract BigDecimal count(); + abstract BigDecimal xavg(); + abstract BigDecimal yavg(); + + static CovarianceAccumulator newCovarianceAccumulator( + BigDecimal covariance, + BigDecimal count, + BigDecimal xavg, + BigDecimal yavg) { + + return new AutoValue_CovarianceAccumulator(covariance, count, xavg, yavg); + } + + static CovarianceAccumulator ofZeroElements() { + return EMPTY; + } + + static CovarianceAccumulator ofSingleElement( + BigDecimal inputElementX, BigDecimal inputElementY) { + return newCovarianceAccumulator(BigDecimal.ZERO, + BigDecimal.ONE, + inputElementX, + inputElementY); + } + + /** + * See {@link CovarianceFn} doc above for explanation. + */ + CovarianceAccumulator combineWith(CovarianceAccumulator otherCovariance) { + if (EMPTY.equals(this)) { + return otherCovariance; + } + + if (EMPTY.equals(otherCovariance)) { + return this; + } + + BigDecimal increment = calculateIncrement(this, otherCovariance); + BigDecimal combinedCovariance = + this.covariance() + .add(otherCovariance.covariance()) + .add(increment); + + return newCovarianceAccumulator( + combinedCovariance, + this.count().add(otherCovariance.count()), + calculateXavg(this, otherCovariance), + calculateYavg(this, otherCovariance) + ); + } + + /** + * Implements this part: {@code increment = (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X }. + */ + private BigDecimal calculateIncrement( + CovarianceAccumulator covarA, CovarianceAccumulator covarB) { + + BigDecimal countA = covarA.count(); + BigDecimal countB = covarB.count(); + + BigDecimal totalCount = countA.add(countB); + + BigDecimal avgXA = covarA.xavg(); + BigDecimal avgYA = covarA.yavg(); + + BigDecimal avgXB = covarB.xavg(); + BigDecimal avgYB = covarB.yavg(); + + BigDecimal inc = + avgXA.subtract(avgXB) + .multiply(avgYA.subtract(avgYB)) + .multiply(countA).multiply(countB) + .divide(totalCount, CovarianceFn.MATH_CTX); + + return inc; + } + + /** + * Implements this part: {@code avg_x = (avgx_A * n_A) + (avgx_B * n_B) / n_X }. + */ + private BigDecimal calculateXavg( + CovarianceAccumulator covarA, CovarianceAccumulator covarB) { + + BigDecimal countA = covarA.count(); + BigDecimal countB = covarB.count(); + + BigDecimal totalCount = countA.add(countB); + + BigDecimal avgXA = covarA.xavg(); + BigDecimal avgXB = covarB.xavg(); + + BigDecimal newXavg = avgXA.multiply(countA).add(avgXB.multiply(countB)) + .divide(totalCount, CovarianceFn.MATH_CTX); + + return newXavg; + } + + /** + * Implements this part: {@code avg_y = (avgy_A * n_A) + (avgy_B * n_B) / n_Y }. 
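// Illustrative worked example (not part of the patch), checking the incremental formulas above
// on the two points (1, 2) and (3, 6); it only compiles inside this package since the
// accumulator class is package-private.
CovarianceAccumulator first =
    CovarianceAccumulator.ofSingleElement(BigDecimal.ONE, BigDecimal.valueOf(2));
CovarianceAccumulator second =
    CovarianceAccumulator.ofSingleElement(BigDecimal.valueOf(3), BigDecimal.valueOf(6));

CovarianceAccumulator combined = first.combineWith(second);
// increment     = (1 - 3) * (2 - 6) * 1 * 1 / 2 = 4
// co-moment     = 0 + 0 + 4 = 4, count = 2, xavg = 2, yavg = 4
// so COVAR_SAMP = 4 / (2 - 1) = 4 and COVAR_POP = 4 / 2 = 2,
// which matches the covariance of {(1, 2), (3, 6)} computed directly.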
+ */ + private BigDecimal calculateYavg( + CovarianceAccumulator covarA, CovarianceAccumulator covarB) { + + BigDecimal countA = covarA.count(); + BigDecimal countB = covarB.count(); + + BigDecimal totalCount = countA.add(countB); + + BigDecimal avgYA = covarA.yavg(); + BigDecimal avgYB = covarB.yavg(); + + BigDecimal newYavg = avgYA.multiply(countA).add(avgYB.multiply(countB)) + .divide(totalCount, CovarianceFn.MATH_CTX); + + return newYavg; + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java new file mode 100644 index 000000000000..e5f6463259dc --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.extensions.sql.impl.transform.agg; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.math.RoundingMode; +import java.util.stream.StreamSupport; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.KV; + +/** + * {@link Combine.CombineFn} for Covariance on {@link Number} types. + * + *
    Calculates Population Covariance and Sample Covariance using incremental + * formulas described in http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance, + * presumably by Pébay, Philippe (2008), in "Formulas for Robust, + * One-Pass Parallel Computation of Covariances and Arbitrary-Order + * Statistical Moments". + *
    + * + */ +@Internal +public class CovarianceFn + extends Combine.CombineFn, CovarianceAccumulator, T> { + + static final MathContext MATH_CTX = new MathContext(10, RoundingMode.HALF_UP); + + private static final boolean SAMPLE = true; + private static final boolean POP = false; + + private boolean isSample; // flag to determine return value should be Covariance Pop or Sample + private SerializableFunction decimalConverter; + + public static CovarianceFn newPopulation( + SerializableFunction decimalConverter) { + + return new CovarianceFn<>(POP, decimalConverter); + } + + public static CovarianceFn newSample( + SerializableFunction decimalConverter) { + + return new CovarianceFn<>(SAMPLE, decimalConverter); + } + + private CovarianceFn(boolean isSample, SerializableFunction decimalConverter){ + this.isSample = isSample; + this.decimalConverter = decimalConverter; + } + + @Override + public CovarianceAccumulator createAccumulator() { + return CovarianceAccumulator.ofZeroElements(); + } + + @Override + public CovarianceAccumulator addInput( + CovarianceAccumulator currentVariance, KV rawInput) { + if (rawInput == null) { + return currentVariance; + } + + return currentVariance.combineWith(CovarianceAccumulator.ofSingleElement( + toBigDecimal(rawInput.getKey()), toBigDecimal(rawInput.getValue()))); + } + + @Override + public CovarianceAccumulator mergeAccumulators(Iterable covariances) { + return StreamSupport + .stream(covariances.spliterator(), false) + .reduce(CovarianceAccumulator.ofZeroElements(), + CovarianceAccumulator::combineWith); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, + Coder> inputCoder) { + return SerializableCoder.of(CovarianceAccumulator.class); + } + + @Override + public T extractOutput(CovarianceAccumulator accumulator) { + return decimalConverter.apply(getCovariance(accumulator)); + } + + private BigDecimal getCovariance(CovarianceAccumulator covariance) { + + BigDecimal adjustedCount = this.isSample + ? covariance.count().subtract(BigDecimal.ONE) + : covariance.count(); + + return covariance.covariance().divide(adjustedCount, MATH_CTX); + } + + private BigDecimal toBigDecimal(T rawInput) { + return new BigDecimal(rawInput.toString()); + } +} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/BigDecimalConverter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/BigDecimalConverter.java similarity index 56% rename from sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/BigDecimalConverter.java rename to sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/BigDecimalConverter.java index 4f3e338f92b0..24a3f7a3b743 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/BigDecimalConverter.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/BigDecimalConverter.java @@ -16,40 +16,41 @@ * limitations under the License. 
*/ -package org.apache.beam.sdk.extensions.sql.impl.transform.agg; +package org.apache.beam.sdk.extensions.sql.impl.utils; import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; import java.util.Map; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.calcite.sql.type.SqlTypeName; /** * Provides converters from {@link BigDecimal} to other numeric types based on - * the input {@link SqlTypeName}. + * the input {@link SqlTypeCoder}. */ public class BigDecimalConverter { - private static final Map> + private static final Map> CONVERTER_MAP = ImmutableMap - .>builder() - .put(SqlTypeName.INTEGER, BigDecimal::intValue) - .put(SqlTypeName.SMALLINT, BigDecimal::shortValue) - .put(SqlTypeName.TINYINT, BigDecimal::byteValue) - .put(SqlTypeName.BIGINT, BigDecimal::longValue) - .put(SqlTypeName.FLOAT, BigDecimal::floatValue) - .put(SqlTypeName.DOUBLE, BigDecimal::doubleValue) - .put(SqlTypeName.DECIMAL, v -> v) + .>builder() + .put(SqlTypeCoders.INTEGER, BigDecimal::intValue) + .put(SqlTypeCoders.SMALLINT, BigDecimal::shortValue) + .put(SqlTypeCoders.TINYINT, BigDecimal::byteValue) + .put(SqlTypeCoders.BIGINT, BigDecimal::longValue) + .put(SqlTypeCoders.FLOAT, BigDecimal::floatValue) + .put(SqlTypeCoders.DOUBLE, BigDecimal::doubleValue) + .put(SqlTypeCoders.DECIMAL, v -> v) .build(); public static SerializableFunction forSqlType( - SqlTypeName sqlTypeName) { + SqlTypeCoder sqlTypeCoder) { - if (!CONVERTER_MAP.containsKey(sqlTypeName)) { + if (!CONVERTER_MAP.containsKey(sqlTypeCoder)) { throw new UnsupportedOperationException( - "Conversion from " + sqlTypeName + " to BigDecimal is not supported"); + "Conversion from " + sqlTypeCoder + " to BigDecimal is not supported"); } - return CONVERTER_MAP.get(sqlTypeName); + return CONVERTER_MAP.get(sqlTypeCoder); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java index 9d992309a4aa..ad21c2227e7e 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java @@ -18,12 +18,13 @@ package org.apache.beam.sdk.extensions.sql.impl.utils; -import java.sql.Types; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import static org.apache.beam.sdk.values.RowType.toRowType; + +import com.google.common.collect.BiMap; +import com.google.common.collect.ImmutableBiMap; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -34,75 +35,82 @@ * Utility methods for Calcite related operations. 
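// Illustrative sketch (not part of the patch): how the relocated BigDecimalConverter feeds the
// new covariance aggregators, mirroring the AggregationAdaptor wiring earlier in this patch and
// assuming SqlTypeCoders.DOUBLE as the output field's coder. Generics are left raw, as in the
// surrounding patch text.
CovarianceFn covarSamp =
    CovarianceFn.newSample(BigDecimalConverter.forSqlType(SqlTypeCoders.DOUBLE));
CovarianceFn covarPop =
    CovarianceFn.newPopulation(BigDecimalConverter.forSqlType(SqlTypeCoders.DOUBLE));
// Each resulting CombineFn consumes KV pairs holding the two aggregated column values.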
*/ public class CalciteUtils { - private static final Map JAVA_TO_CALCITE_MAPPING = new HashMap<>(); - private static final Map CALCITE_TO_JAVA_MAPPING = new HashMap<>(); - static { - JAVA_TO_CALCITE_MAPPING.put(Types.TINYINT, SqlTypeName.TINYINT); - JAVA_TO_CALCITE_MAPPING.put(Types.SMALLINT, SqlTypeName.SMALLINT); - JAVA_TO_CALCITE_MAPPING.put(Types.INTEGER, SqlTypeName.INTEGER); - JAVA_TO_CALCITE_MAPPING.put(Types.BIGINT, SqlTypeName.BIGINT); + private static final BiMap BEAM_TO_CALCITE_TYPE_MAPPING = + ImmutableBiMap.builder() + .put(SqlTypeCoders.TINYINT, SqlTypeName.TINYINT) + .put(SqlTypeCoders.SMALLINT, SqlTypeName.SMALLINT) + .put(SqlTypeCoders.INTEGER, SqlTypeName.INTEGER) + .put(SqlTypeCoders.BIGINT, SqlTypeName.BIGINT) - JAVA_TO_CALCITE_MAPPING.put(Types.FLOAT, SqlTypeName.FLOAT); - JAVA_TO_CALCITE_MAPPING.put(Types.DOUBLE, SqlTypeName.DOUBLE); + .put(SqlTypeCoders.FLOAT, SqlTypeName.FLOAT) + .put(SqlTypeCoders.DOUBLE, SqlTypeName.DOUBLE) - JAVA_TO_CALCITE_MAPPING.put(Types.DECIMAL, SqlTypeName.DECIMAL); + .put(SqlTypeCoders.DECIMAL, SqlTypeName.DECIMAL) - JAVA_TO_CALCITE_MAPPING.put(Types.CHAR, SqlTypeName.CHAR); - JAVA_TO_CALCITE_MAPPING.put(Types.VARCHAR, SqlTypeName.VARCHAR); + .put(SqlTypeCoders.CHAR, SqlTypeName.CHAR) + .put(SqlTypeCoders.VARCHAR, SqlTypeName.VARCHAR) - JAVA_TO_CALCITE_MAPPING.put(Types.DATE, SqlTypeName.DATE); - JAVA_TO_CALCITE_MAPPING.put(Types.TIME, SqlTypeName.TIME); - JAVA_TO_CALCITE_MAPPING.put(Types.TIMESTAMP, SqlTypeName.TIMESTAMP); + .put(SqlTypeCoders.DATE, SqlTypeName.DATE) + .put(SqlTypeCoders.TIME, SqlTypeName.TIME) + .put(SqlTypeCoders.TIMESTAMP, SqlTypeName.TIMESTAMP) - JAVA_TO_CALCITE_MAPPING.put(Types.BOOLEAN, SqlTypeName.BOOLEAN); + .put(SqlTypeCoders.BOOLEAN, SqlTypeName.BOOLEAN) + .build(); - for (Map.Entry pair : JAVA_TO_CALCITE_MAPPING.entrySet()) { - CALCITE_TO_JAVA_MAPPING.put(pair.getValue(), pair.getKey()); - } - } + private static final BiMap CALCITE_TO_BEAM_TYPE_MAPPING = + BEAM_TO_CALCITE_TYPE_MAPPING.inverse(); /** - * Get the corresponding {@code SqlTypeName} for an integer sql type. + * Get the corresponding Calcite's {@link SqlTypeName} + * for supported Beam SQL type coder, see {@link SqlTypeCoder}. */ - public static SqlTypeName toCalciteType(int type) { - return JAVA_TO_CALCITE_MAPPING.get(type); + public static SqlTypeName toCalciteType(SqlTypeCoder coder) { + return BEAM_TO_CALCITE_TYPE_MAPPING.get(coder); } /** - * Get the integer sql type from Calcite {@code SqlTypeName}. + * Get the Beam SQL type coder ({@link SqlTypeCoder}) from Calcite's {@link SqlTypeName}. */ - public static Integer toJavaType(SqlTypeName typeName) { - return CALCITE_TO_JAVA_MAPPING.get(typeName); + public static SqlTypeCoder toCoder(SqlTypeName typeName) { + return CALCITE_TO_BEAM_TYPE_MAPPING.get(typeName); } /** * Get the {@code SqlTypeName} for the specified column of a table. */ - public static SqlTypeName getFieldType(BeamRecordSqlType schema, int index) { - return toCalciteType(schema.getFieldTypeByIndex(index)); + public static SqlTypeName getFieldCalciteType(RowType schema, int index) { + return toCalciteType((SqlTypeCoder) schema.getFieldCoder(index)); } /** * Generate {@code BeamSqlRowType} from {@code RelDataType} which is used to create table. 
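// Illustrative sketch (not part of the patch): with the BiMap above, the coder/Calcite type
// mapping becomes a symmetric two-way lookup.
SqlTypeName calciteType = CalciteUtils.toCalciteType(SqlTypeCoders.INTEGER); // SqlTypeName.INTEGER
SqlTypeCoder beamCoder  = CalciteUtils.toCoder(SqlTypeName.VARCHAR);         // SqlTypeCoders.VARCHAR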
*/ - public static BeamRecordSqlType toBeamRowType(RelDataType tableInfo) { - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (RelDataTypeField f : tableInfo.getFieldList()) { - fieldNames.add(f.getName()); - fieldTypes.add(toJavaType(f.getType().getSqlTypeName())); - } - return BeamRecordSqlType.create(fieldNames, fieldTypes); + public static RowType toBeamRowType(RelDataType tableInfo) { + return + tableInfo + .getFieldList() + .stream() + .map(CalciteUtils::toBeamRowField) + .collect(toRowType()); + } + + private static RowType.Field toBeamRowField(RelDataTypeField calciteField) { + return + RowType.newField( + calciteField.getName(), + toCoder(calciteField.getType().getSqlTypeName())); } /** * Create an instance of {@code RelDataType} so it can be used to create a table. */ - public static RelProtoDataType toCalciteRowType(final BeamRecordSqlType that) { - return a -> { - RelDataTypeFactory.FieldInfoBuilder builder = a.builder(); - for (int idx = 0; idx < that.getFieldNames().size(); ++idx) { - builder.add(that.getFieldNameByIndex(idx), toCalciteType(that.getFieldTypeByIndex(idx))); + public static RelProtoDataType toCalciteRowType(final RowType rowType) { + return fieldInfo -> { + RelDataTypeFactory.FieldInfoBuilder builder = fieldInfo.builder(); + for (int idx = 0; idx < rowType.getFieldNames().size(); ++idx) { + builder.add( + rowType.getFieldName(idx), + toCalciteType((SqlTypeCoder) rowType.getFieldCoder(idx))); } return builder.build(); }; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/SqlTypeUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/SqlTypeUtils.java index 9658bab81421..e1be591ba852 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/SqlTypeUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/SqlTypeUtils.java @@ -18,10 +18,8 @@ package org.apache.beam.sdk.extensions.sql.impl.utils; import com.google.common.base.Optional; - import java.util.Collection; import java.util.List; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.calcite.sql.type.SqlTypeName; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/Column.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/Column.java index 9bcc16af7f56..f5d7d5f4a6cc 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/Column.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/Column.java @@ -21,6 +21,7 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.Coder; /** * Metadata class for a {@code BeamSqlTable} column. 
@@ -28,7 +29,7 @@ @AutoValue public abstract class Column implements Serializable { public abstract String getName(); - public abstract Integer getType(); + public abstract Coder getCoder(); @Nullable public abstract String getComment(); public abstract boolean isPrimaryKey(); @@ -43,7 +44,7 @@ public static Builder builder() { @AutoValue.Builder public abstract static class Builder { public abstract Builder name(String name); - public abstract Builder type(Integer type); + public abstract Builder coder(Coder coder); public abstract Builder comment(String comment); public abstract Builder primaryKey(boolean isPrimaryKey); public abstract Column build(); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/MetaUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/MetaUtils.java index 35ecdce76ab0..edbc64f401f7 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/MetaUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/MetaUtils.java @@ -18,23 +18,26 @@ package org.apache.beam.sdk.extensions.sql.meta.provider; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import static org.apache.beam.sdk.values.RowType.toRowType; + import org.apache.beam.sdk.extensions.sql.meta.Column; import org.apache.beam.sdk.extensions.sql.meta.Table; +import org.apache.beam.sdk.values.RowType; /** * Utility methods for metadata. */ public class MetaUtils { - public static BeamRecordSqlType getBeamSqlRecordTypeFromTable(Table table) { - List columnNames = new ArrayList<>(table.getColumns().size()); - List columnTypes = new ArrayList<>(table.getColumns().size()); - for (Column column : table.getColumns()) { - columnNames.add(column.getName()); - columnTypes.add(column.getType()); - } - return BeamRecordSqlType.create(columnNames, columnTypes); + public static RowType getRowTypeFromTable(Table table) { + return + table + .getColumns() + .stream() + .map(MetaUtils::toRecordField) + .collect(toRowType()); + } + + private static RowType.Field toRecordField(Column column) { + return RowType.newField(column.getName(), column.getCoder()); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/TableProvider.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/TableProvider.java index d57f7034cf00..5bbadb13eeb0 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/TableProvider.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/TableProvider.java @@ -44,6 +44,13 @@ public interface TableProvider { */ void createTable(Table table); + /** + * Drops a table. + * + * @param tableName + */ + void dropTable(String tableName); + /** * List all tables from this provider. 
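The Column metadata above now carries a Coder instead of an integer java.sql.Types constant, and MetaUtils.getRowTypeFromTable folds those coders straight into a RowType. A short sketch of the new builder, illustrative only and assuming the SqlTypeCoders constants are themselves Coder instances (the cast in CalciteUtils suggests SqlTypeCoder extends Coder):

    import org.apache.beam.sdk.extensions.sql.SqlTypeCoders;
    import org.apache.beam.sdk.extensions.sql.meta.Column;
    import org.apache.beam.sdk.values.RowType;

    public class ColumnCoderSketch {
      public static void main(String[] args) {
        // Describe a column by name and coder rather than a JDBC type code.
        Column id = Column.builder()
            .name("id")
            .coder(SqlTypeCoders.INTEGER)
            .primaryKey(true)
            .build();

        // MetaUtils pairs each column's name with its coder in exactly this way.
        RowType.Field field = RowType.newField(id.getName(), id.getCoder());
        System.out.println(field);
      }
    }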
*/ diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTable.java index a8c8a3099f29..c3912a4c0ce0 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTable.java @@ -17,17 +17,17 @@ */ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka; -import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.beamRecord2CsvLine; -import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.csvLine2BeamRecord; +import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.beamRow2CsvLine; +import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.csvLine2BeamRow; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.commons.csv.CSVFormat; /** @@ -36,74 +36,74 @@ */ public class BeamKafkaCSVTable extends BeamKafkaTable { private CSVFormat csvFormat; - public BeamKafkaCSVTable(BeamRecordSqlType beamSqlRowType, String bootstrapServers, - List topics) { - this(beamSqlRowType, bootstrapServers, topics, CSVFormat.DEFAULT); + public BeamKafkaCSVTable(RowType beamRowType, String bootstrapServers, + List topics) { + this(beamRowType, bootstrapServers, topics, CSVFormat.DEFAULT); } - public BeamKafkaCSVTable(BeamRecordSqlType beamSqlRowType, String bootstrapServers, - List topics, CSVFormat format) { - super(beamSqlRowType, bootstrapServers, topics); + public BeamKafkaCSVTable(RowType beamRowType, String bootstrapServers, + List topics, CSVFormat format) { + super(beamRowType, bootstrapServers, topics); this.csvFormat = format; } @Override - public PTransform>, PCollection> + public PTransform>, PCollection> getPTransformForInput() { - return new CsvRecorderDecoder(beamRecordSqlType, csvFormat); + return new CsvRecorderDecoder(rowType, csvFormat); } @Override - public PTransform, PCollection>> + public PTransform, PCollection>> getPTransformForOutput() { - return new CsvRecorderEncoder(beamRecordSqlType, csvFormat); + return new CsvRecorderEncoder(rowType, csvFormat); } /** - * A PTransform to convert {@code KV} to {@link BeamRecord}. + * A PTransform to convert {@code KV} to {@link Row}. 
* */ public static class CsvRecorderDecoder - extends PTransform>, PCollection> { - private BeamRecordSqlType rowType; + extends PTransform>, PCollection> { + private RowType rowType; private CSVFormat format; - public CsvRecorderDecoder(BeamRecordSqlType rowType, CSVFormat format) { + public CsvRecorderDecoder(RowType rowType, CSVFormat format) { this.rowType = rowType; this.format = format; } @Override - public PCollection expand(PCollection> input) { - return input.apply("decodeRecord", ParDo.of(new DoFn, BeamRecord>() { + public PCollection expand(PCollection> input) { + return input.apply("decodeRecord", ParDo.of(new DoFn, Row>() { @ProcessElement public void processElement(ProcessContext c) { String rowInString = new String(c.element().getValue()); - c.output(csvLine2BeamRecord(format, rowInString, rowType)); + c.output(csvLine2BeamRow(format, rowInString, rowType)); } })); } } /** - * A PTransform to convert {@link BeamRecord} to {@code KV}. + * A PTransform to convert {@link Row} to {@code KV}. * */ public static class CsvRecorderEncoder - extends PTransform, PCollection>> { - private BeamRecordSqlType rowType; + extends PTransform, PCollection>> { + private RowType rowType; private CSVFormat format; - public CsvRecorderEncoder(BeamRecordSqlType rowType, CSVFormat format) { + public CsvRecorderEncoder(RowType rowType, CSVFormat format) { this.rowType = rowType; this.format = format; } @Override - public PCollection> expand(PCollection input) { - return input.apply("encodeRecord", ParDo.of(new DoFn>() { + public PCollection> expand(PCollection input) { + return input.apply("encodeRecord", ParDo.of(new DoFn>() { @ProcessElement public void processElement(ProcessContext c) { - BeamRecord in = c.element(); - c.output(KV.of(new byte[] {}, beamRecord2CsvLine(in, format).getBytes())); + Row in = c.element(); + c.output(KV.of(new byte[] {}, beamRow2CsvLine(in, format).getBytes())); } })); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java index 8f663a30d41f..aab597183a7d 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java @@ -24,16 +24,16 @@ import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable; import org.apache.beam.sdk.extensions.sql.impl.schema.BeamIOType; import org.apache.beam.sdk.io.kafka.KafkaIO; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.serialization.ByteArrayDeserializer; import org.apache.kafka.common.serialization.ByteArraySerializer; @@ -49,20 +49,20 @@ public abstract class BeamKafkaTable extends BaseBeamTable implements Serializab private List topicPartitions; private Map configUpdates; - protected BeamKafkaTable(BeamRecordSqlType 
beamSqlRowType) { - super(beamSqlRowType); + protected BeamKafkaTable(RowType beamRowType) { + super(beamRowType); } - public BeamKafkaTable(BeamRecordSqlType beamSqlRowType, String bootstrapServers, - List topics) { - super(beamSqlRowType); + public BeamKafkaTable(RowType beamRowType, String bootstrapServers, + List topics) { + super(beamRowType); this.bootstrapServers = bootstrapServers; this.topics = topics; } - public BeamKafkaTable(BeamRecordSqlType beamSqlRowType, + public BeamKafkaTable(RowType beamRowType, List topicPartitions, String bootstrapServers) { - super(beamSqlRowType); + super(beamRowType); this.bootstrapServers = bootstrapServers; this.topicPartitions = topicPartitions; } @@ -77,14 +77,14 @@ public BeamIOType getSourceType() { return BeamIOType.UNBOUNDED; } - public abstract PTransform>, PCollection> + public abstract PTransform>, PCollection> getPTransformForInput(); - public abstract PTransform, PCollection>> + public abstract PTransform, PCollection>> getPTransformForOutput(); @Override - public PCollection buildIOReader(Pipeline pipeline) { + public PCollection buildIOReader(Pipeline pipeline) { KafkaIO.Read kafkaRead = null; if (topics != null) { kafkaRead = KafkaIO.read() @@ -109,13 +109,13 @@ public PCollection buildIOReader(Pipeline pipeline) { } @Override - public PTransform, PDone> buildIOWriter() { + public PTransform, PDone> buildIOWriter() { checkArgument(topics != null && topics.size() == 1, "Only one topic can be acceptable as output."); - return new PTransform, PDone>() { + return new PTransform, PDone>() { @Override - public PDone expand(PCollection input) { + public PDone expand(PCollection input) { return input.apply("out_reformat", getPTransformForOutput()).apply("persistent", KafkaIO.write() .withBootstrapServers(bootstrapServers) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProvider.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProvider.java index 8c37d46d3f60..c143828b4f3a 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProvider.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProvider.java @@ -18,17 +18,17 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka; -import static org.apache.beam.sdk.extensions.sql.meta.provider.MetaUtils.getBeamSqlRecordTypeFromTable; +import static org.apache.beam.sdk.extensions.sql.meta.provider.MetaUtils.getRowTypeFromTable; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.meta.Table; import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider; +import org.apache.beam.sdk.values.RowType; /** * Kafka table provider. 
@@ -47,7 +47,7 @@ */ public class KafkaTableProvider implements TableProvider { @Override public BeamSqlTable buildBeamSqlTable(Table table) { - BeamRecordSqlType recordType = getBeamSqlRecordTypeFromTable(table); + RowType rowType = getRowTypeFromTable(table); JSONObject properties = table.getProperties(); String bootstrapServers = properties.getString("bootstrap.servers"); @@ -56,7 +56,7 @@ public class KafkaTableProvider implements TableProvider { for (Object topic : topicsArr) { topics.add(topic.toString()); } - BeamKafkaCSVTable txtTable = new BeamKafkaCSVTable(recordType, bootstrapServers, topics); + BeamKafkaCSVTable txtTable = new BeamKafkaCSVTable(rowType, bootstrapServers, topics); return txtTable; } @@ -68,6 +68,10 @@ public class KafkaTableProvider implements TableProvider { // empty } + @Override public void dropTable(String tableName) { + // empty + } + @Override public List

    listTables() { return Collections.emptyList(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTable.java index 78cec745aaad..24b1406e26fc 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTable.java @@ -19,16 +19,14 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.text; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.commons.csv.CSVFormat; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * {@code BeamTextCSVTable} is a {@code BeamTextTable} which formatted in CSV. @@ -38,8 +36,6 @@ *

    */ public class BeamTextCSVTable extends BeamTextTable { - private static final Logger LOG = LoggerFactory - .getLogger(BeamTextCSVTable.class); private String filePattern; private CSVFormat csvFormat; @@ -47,27 +43,27 @@ public class BeamTextCSVTable extends BeamTextTable { /** * CSV table with {@link CSVFormat#DEFAULT DEFAULT} format. */ - public BeamTextCSVTable(BeamRecordSqlType beamSqlRowType, String filePattern) { - this(beamSqlRowType, filePattern, CSVFormat.DEFAULT); + public BeamTextCSVTable(RowType beamRowType, String filePattern) { + this(beamRowType, filePattern, CSVFormat.DEFAULT); } - public BeamTextCSVTable(BeamRecordSqlType beamRecordSqlType, String filePattern, - CSVFormat csvFormat) { - super(beamRecordSqlType, filePattern); + public BeamTextCSVTable(RowType rowType, String filePattern, + CSVFormat csvFormat) { + super(rowType, filePattern); this.filePattern = filePattern; this.csvFormat = csvFormat; } @Override - public PCollection buildIOReader(Pipeline pipeline) { + public PCollection buildIOReader(Pipeline pipeline) { return PBegin.in(pipeline).apply("decodeRecord", TextIO.read().from(filePattern)) .apply("parseCSVLine", - new BeamTextCSVTableIOReader(beamRecordSqlType, filePattern, csvFormat)); + new BeamTextCSVTableIOReader(rowType, filePattern, csvFormat)); } @Override - public PTransform, PDone> buildIOWriter() { - return new BeamTextCSVTableIOWriter(beamRecordSqlType, filePattern, csvFormat); + public PTransform, PDone> buildIOWriter() { + return new BeamTextCSVTableIOWriter(rowType, filePattern, csvFormat); } public CSVFormat getCsvFormat() { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOReader.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOReader.java index 953ac0333f17..4342aa0708b8 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOReader.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOReader.java @@ -18,41 +18,41 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.text; -import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.csvLine2BeamRecord; +import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.csvLine2BeamRow; import java.io.Serializable; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.commons.csv.CSVFormat; /** * IOReader for {@code BeamTextCSVTable}. 
*/ public class BeamTextCSVTableIOReader - extends PTransform, PCollection> + extends PTransform, PCollection> implements Serializable { private String filePattern; - protected BeamRecordSqlType beamRecordSqlType; + protected RowType rowType; protected CSVFormat csvFormat; - public BeamTextCSVTableIOReader(BeamRecordSqlType beamRecordSqlType, String filePattern, - CSVFormat csvFormat) { + public BeamTextCSVTableIOReader(RowType rowType, String filePattern, + CSVFormat csvFormat) { this.filePattern = filePattern; - this.beamRecordSqlType = beamRecordSqlType; + this.rowType = rowType; this.csvFormat = csvFormat; } @Override - public PCollection expand(PCollection input) { - return input.apply(ParDo.of(new DoFn() { + public PCollection expand(PCollection input) { + return input.apply(ParDo.of(new DoFn() { @ProcessElement public void processElement(ProcessContext ctx) { String str = ctx.element(); - ctx.output(csvLine2BeamRecord(csvFormat, str, beamRecordSqlType)); + ctx.output(csvLine2BeamRow(csvFormat, str, rowType)); } })); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOWriter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOWriter.java index 80481d21772c..3854370f2730 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOWriter.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableIOWriter.java @@ -18,41 +18,44 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.text; -import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.beamRecord2CsvLine; +import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.beamRow2CsvLine; import java.io.Serializable; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.commons.csv.CSVFormat; /** * IOWriter for {@code BeamTextCSVTable}. 
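Both the Kafka and text tables now route through the renamed BeamTableUtils helpers, csvLine2BeamRow and beamRow2CsvLine. A small round-trip sketch, illustrative only; the schema uses RowSqlType.builder(), and the helper signatures are inferred from the call sites in this patch:

    import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.beamRow2CsvLine;
    import static org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils.csvLine2BeamRow;

    import org.apache.beam.sdk.extensions.sql.RowSqlType;
    import org.apache.beam.sdk.values.Row;
    import org.apache.beam.sdk.values.RowType;
    import org.apache.commons.csv.CSVFormat;

    public class CsvRowRoundTripSketch {
      public static void main(String[] args) {
        RowType rowType = RowSqlType.builder()
            .withIntegerField("f_int")
            .withDoubleField("f_double")
            .build();

        // Parse one CSV line into a Row, then format the Row back to CSV.
        Row row = csvLine2BeamRow(CSVFormat.DEFAULT, "1,2.5", rowType);
        String line = beamRow2CsvLine(row, CSVFormat.DEFAULT);
        System.out.println(line);  // expected: 1,2.5
      }
    }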
*/ -public class BeamTextCSVTableIOWriter extends PTransform, PDone> +public class BeamTextCSVTableIOWriter extends PTransform, PDone> implements Serializable { private String filePattern; - protected BeamRecordSqlType beamRecordSqlType; + protected RowType rowType; protected CSVFormat csvFormat; - public BeamTextCSVTableIOWriter(BeamRecordSqlType beamRecordSqlType, String filePattern, - CSVFormat csvFormat) { + public BeamTextCSVTableIOWriter(RowType rowType, + String filePattern, + CSVFormat csvFormat) { this.filePattern = filePattern; - this.beamRecordSqlType = beamRecordSqlType; + this.rowType = rowType; this.csvFormat = csvFormat; } - @Override public PDone expand(PCollection input) { - return input.apply("encodeRecord", ParDo.of(new DoFn() { + @Override + public PDone expand(PCollection input) { + return input.apply("encodeRecord", ParDo.of(new DoFn() { - @ProcessElement public void processElement(ProcessContext ctx) { - BeamRecord row = ctx.element(); - ctx.output(beamRecord2CsvLine(row, csvFormat)); + @ProcessElement + public void processElement(ProcessContext ctx) { + Row row = ctx.element(); + ctx.output(beamRow2CsvLine(row, csvFormat)); } })).apply(TextIO.write().to(filePattern)); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextTable.java index 76616ef995c2..db53073a27d5 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextTable.java @@ -19,9 +19,9 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.text; import java.io.Serializable; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable; import org.apache.beam.sdk.extensions.sql.impl.schema.BeamIOType; +import org.apache.beam.sdk.values.RowType; /** * {@code BeamTextTable} represents a text file/directory(backed by {@code TextIO}). 
@@ -29,8 +29,8 @@ public abstract class BeamTextTable extends BaseBeamTable implements Serializable { protected String filePattern; - protected BeamTextTable(BeamRecordSqlType beamRecordSqlType, String filePattern) { - super(beamRecordSqlType); + protected BeamTextTable(RowType rowType, String filePattern) { + super(rowType); this.filePattern = filePattern; } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProvider.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProvider.java index bc9f03f10322..b87194928fcd 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProvider.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProvider.java @@ -18,15 +18,15 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.text; -import static org.apache.beam.sdk.extensions.sql.meta.provider.MetaUtils.getBeamSqlRecordTypeFromTable; +import static org.apache.beam.sdk.extensions.sql.meta.provider.MetaUtils.getRowTypeFromTable; import com.alibaba.fastjson.JSONObject; import java.util.Collections; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.meta.Table; import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider; +import org.apache.beam.sdk.values.RowType; import org.apache.commons.csv.CSVFormat; /** @@ -51,7 +51,7 @@ public class TextTableProvider implements TableProvider { } @Override public BeamSqlTable buildBeamSqlTable(Table table) { - BeamRecordSqlType recordType = getBeamSqlRecordTypeFromTable(table); + RowType rowType = getRowTypeFromTable(table); String filePattern = table.getLocationAsString(); CSVFormat format = CSVFormat.DEFAULT; @@ -61,7 +61,7 @@ public class TextTableProvider implements TableProvider { format = CSVFormat.valueOf(csvFormatStr); } - BeamTextCSVTable txtTable = new BeamTextCSVTable(recordType, filePattern, format); + BeamTextCSVTable txtTable = new BeamTextCSVTable(rowType, filePattern, format); return txtTable; } @@ -69,6 +69,10 @@ public class TextTableProvider implements TableProvider { // empty } + @Override public void dropTable(String tableName) { + // empty + } + @Override public List
    listTables() { return Collections.emptyList(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStore.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStore.java index bacfbff90cb9..53eeb7e4e59c 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStore.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStore.java @@ -55,6 +55,16 @@ public InMemoryMetaStore() { tables.put(table.getName(), table); } + @Override public void dropTable(String tableName) { + if (!tables.containsKey(tableName)) { + throw new IllegalArgumentException("No such table: " + tableName); + } + + Table table = tables.get(tableName); + providers.get(table.getType()).dropTable(tableName); + tables.remove(tableName); + } + @Override public Table getTable(String tableName) { if (tableName == null) { return null; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/MetaStore.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/MetaStore.java index 2f395f028678..ac5b739aec6d 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/MetaStore.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/store/MetaStore.java @@ -27,12 +27,16 @@ * The interface to handle CRUD of {@code BeamSql} table metadata. */ public interface MetaStore { - /** * create a table. */ void createTable(Table table); + /** + * drop a table. + */ + void dropTable(String tableName); + /** * Get table with the specified name. */ diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamRecordSqlTypeTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamRecordSqlTypeTest.java deleted file mode 100644 index 78ff221e0d0c..000000000000 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamRecordSqlTypeTest.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.sdk.extensions.sql; - -import static org.junit.Assert.assertEquals; - -import com.google.common.collect.ImmutableList; -import java.sql.Types; -import java.util.List; -import org.junit.Test; - -/** - * Unit tests for {@link BeamRecordSqlType}. 
- */ -public class BeamRecordSqlTypeTest { - - private static final List TYPES = ImmutableList.of( - Types.TINYINT, - Types.SMALLINT, - Types.INTEGER, - Types.BIGINT, - Types.FLOAT, - Types.DOUBLE, - Types.DECIMAL, - Types.BOOLEAN, - Types.CHAR, - Types.VARCHAR, - Types.TIME, - Types.DATE, - Types.TIMESTAMP); - - private static final List NAMES = ImmutableList.of( - "TINYINT_FIELD", - "SMALLINT_FIELD", - "INTEGER_FIELD", - "BIGINT_FIELD", - "FLOAT_FIELD", - "DOUBLE_FIELD", - "DECIMAL_FIELD", - "BOOLEAN_FIELD", - "CHAR_FIELD", - "VARCHAR_FIELD", - "TIME_FIELD", - "DATE_FIELD", - "TIMESTAMP_FIELD"); - - private static final List MORE_NAMES = ImmutableList.of( - "ANOTHER_TINYINT_FIELD", - "ANOTHER_SMALLINT_FIELD", - "ANOTHER_INTEGER_FIELD", - "ANOTHER_BIGINT_FIELD", - "ANOTHER_FLOAT_FIELD", - "ANOTHER_DOUBLE_FIELD", - "ANOTHER_DECIMAL_FIELD", - "ANOTHER_BOOLEAN_FIELD", - "ANOTHER_CHAR_FIELD", - "ANOTHER_VARCHAR_FIELD", - "ANOTHER_TIME_FIELD", - "ANOTHER_DATE_FIELD", - "ANOTHER_TIMESTAMP_FIELD"); - - @Test - public void testBuildsWithCorrectFields() throws Exception { - BeamRecordSqlType.Builder recordTypeBuilder = BeamRecordSqlType.builder(); - - for (int i = 0; i < TYPES.size(); i++) { - recordTypeBuilder.withField(NAMES.get(i), TYPES.get(i)); - } - - recordTypeBuilder.withTinyIntField(MORE_NAMES.get(0)); - recordTypeBuilder.withSmallIntField(MORE_NAMES.get(1)); - recordTypeBuilder.withIntegerField(MORE_NAMES.get(2)); - recordTypeBuilder.withBigIntField(MORE_NAMES.get(3)); - recordTypeBuilder.withFloatField(MORE_NAMES.get(4)); - recordTypeBuilder.withDoubleField(MORE_NAMES.get(5)); - recordTypeBuilder.withDecimalField(MORE_NAMES.get(6)); - recordTypeBuilder.withBooleanField(MORE_NAMES.get(7)); - recordTypeBuilder.withCharField(MORE_NAMES.get(8)); - recordTypeBuilder.withVarcharField(MORE_NAMES.get(9)); - recordTypeBuilder.withTimeField(MORE_NAMES.get(10)); - recordTypeBuilder.withDateField(MORE_NAMES.get(11)); - recordTypeBuilder.withTimestampField(MORE_NAMES.get(12)); - - BeamRecordSqlType recordSqlType = recordTypeBuilder.build(); - - List expectedNames = ImmutableList.builder() - .addAll(NAMES) - .addAll(MORE_NAMES) - .build(); - - List expectedTypes = ImmutableList.builder() - .addAll(TYPES) - .addAll(TYPES) - .build(); - - assertEquals(expectedNames, recordSqlType.getFieldNames()); - assertEquals(expectedTypes, recordSqlType.getFieldTypes()); - } -} diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlApiSurfaceTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlApiSurfaceTest.java index 0cd1a2a95dbc..c40ab79250e8 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlApiSurfaceTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlApiSurfaceTest.java @@ -50,8 +50,8 @@ public void testSdkApiSurface() throws Exception { .ofClass(BeamSql.class) .includingClass(BeamSqlCli.class) .includingClass(BeamSqlUdf.class) - .includingClass(BeamRecordSqlType.class) - .includingClass(BeamSqlRecordHelper.class) + .includingClass(RowSqlType.class) + .includingClass(RowHelper.class) .includingClass(BeamSqlSeekableTable.class) .pruningPrefix("java") .pruningPattern("org[.]apache[.]beam[.]sdk[.]extensions[.]sql[.].*Test") diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlCliTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlCliTest.java index 
62d693310a8f..9bf724d797f0 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlCliTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlCliTest.java @@ -19,10 +19,12 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import org.apache.beam.sdk.extensions.sql.meta.Table; import org.apache.beam.sdk.extensions.sql.meta.provider.text.TextTableProvider; import org.apache.beam.sdk.extensions.sql.meta.store.InMemoryMetaStore; +import org.apache.calcite.tools.ValidationException; import org.junit.Test; /** @@ -48,6 +50,49 @@ public void testExecute_createTextTable() throws Exception { assertNotNull(table); } + @Test + public void testExecute_dropTable() throws Exception { + InMemoryMetaStore metaStore = new InMemoryMetaStore(); + metaStore.registerProvider(new TextTableProvider()); + + BeamSqlCli cli = new BeamSqlCli() + .metaStore(metaStore); + cli.execute( + "create table person (\n" + + "id int COMMENT 'id', \n" + + "name varchar(31) COMMENT 'name', \n" + + "age int COMMENT 'age') \n" + + "TYPE 'text' \n" + + "COMMENT '' LOCATION 'text://home/admin/orders'" + ); + Table table = metaStore.getTable("person"); + assertNotNull(table); + + cli.execute("drop table person"); + table = metaStore.getTable("person"); + assertNull(table); + } + + @Test(expected = ValidationException.class) + public void testExecute_dropTable_assertTableRemovedFromPlanner() throws Exception { + InMemoryMetaStore metaStore = new InMemoryMetaStore(); + metaStore.registerProvider(new TextTableProvider()); + + BeamSqlCli cli = new BeamSqlCli() + .metaStore(metaStore); + cli.execute( + "create table person (\n" + + "id int COMMENT 'id', \n" + + "name varchar(31) COMMENT 'name', \n" + + "age int COMMENT 'age') \n" + + "TYPE 'text' \n" + + "COMMENT '' LOCATION 'text://home/admin/orders'" + ); + cli.execute("drop table person"); + cli.explainQuery("select * from person"); + } + + @Test public void testExplainQuery() throws Exception { InMemoryMetaStore metaStore = new InMemoryMetaStore(); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationCovarianceTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationCovarianceTest.java new file mode 100644 index 000000000000..6fd52d49baa1 --- /dev/null +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationCovarianceTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.sql; + +import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesScalar; + +import java.util.List; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +/** + * Integration tests for {@code COVAR_POP} and {@code COVAR_SAMP}. + */ +public class BeamSqlDslAggregationCovarianceTest { + + private static final double PRECISION = 1e-7; + + @Rule + public TestPipeline pipeline = TestPipeline.create(); + + private PCollection boundedInput; + + @Before + public void setUp() { + RowType rowType = RowSqlType.builder() + .withDoubleField("f_double1") + .withDoubleField("f_double2") + .withDoubleField("f_double3") + .withIntegerField("f_int1") + .withIntegerField("f_int2") + .withIntegerField("f_int3") + .build(); + + List rowsInTableB = + TestUtils.RowsBuilder + .of(rowType) + .addRows( + 3.0, 1.0, 1.0, 3, 1, 0, + 4.0, 2.0, 2.0, 4, 2, 0, + 5.0, 3.0, 1.0, 5, 3, 0, + 6.0, 4.0, 2.0, 6, 4, 0, + 8.0, 4.0, 1.0, 8, 4, 0) + .getRows(); + + boundedInput = PBegin + .in(pipeline) + .apply(Create.of(rowsInTableB).withCoder(rowType.getRowCoder())); + } + + @Test + public void testPopulationVarianceDouble() { + String sql = "SELECT COVAR_POP(f_double1, f_double2) FROM PCOLLECTION GROUP BY f_int3"; + + PAssert + .that(boundedInput.apply(BeamSql.query(sql))) + .satisfies(matchesScalar(1.84, PRECISION)); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testPopulationVarianceInt() { + String sql = "SELECT COVAR_POP(f_int1, f_int2) FROM PCOLLECTION GROUP BY f_int3"; + + PAssert + .that(boundedInput.apply(BeamSql.query(sql))) + .satisfies(matchesScalar(1)); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testSampleVarianceDouble() { + String sql = "SELECT COVAR_SAMP(f_double1, f_double2) FROM PCOLLECTION GROUP BY f_int3"; + + PAssert + .that(boundedInput.apply(BeamSql.query(sql))) + .satisfies(matchesScalar(2.3, PRECISION)); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testSampleVarianceInt() { + String sql = "SELECT COVAR_SAMP(f_int1, f_int2) FROM PCOLLECTION GROUP BY f_int3"; + + PAssert + .that(boundedInput.apply(BeamSql.query(sql))) + .satisfies(matchesScalar(2)); + + pipeline.run().waitUntilFinish(); + } +} diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java index 22d14b05b8f6..cdea0f8ce1f4 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java @@ -17,24 +17,32 @@ */ package org.apache.beam.sdk.extensions.sql; +import static org.hamcrest.Matchers.isA; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.math.BigDecimal; -import java.sql.Types; -import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Create; 
import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.DateTime; +import org.joda.time.Duration; import org.junit.Before; import org.junit.Test; @@ -43,45 +51,33 @@ * with BOUNDED PCollection. */ public class BeamSqlDslAggregationTest extends BeamSqlDslBase { - public PCollection boundedInput3; + public PCollection boundedInput3; @Before - public void setUp(){ - BeamRecordSqlType rowTypeInTableB = BeamRecordSqlType.create( - Arrays.asList("f_int", "f_double", "f_int2", "f_decimal"), - Arrays.asList(Types.INTEGER, Types.DOUBLE, Types.INTEGER, Types.DECIMAL)); - - List recordsInTableB = new ArrayList<>(); - BeamRecord row1 = new BeamRecord(rowTypeInTableB - , 1, 1.0, 0, BigDecimal.ONE); - recordsInTableB.add(row1); - - BeamRecord row2 = new BeamRecord(rowTypeInTableB - , 4, 4.0, 0, new BigDecimal(4)); - recordsInTableB.add(row2); - - BeamRecord row3 = new BeamRecord(rowTypeInTableB - , 7, 7.0, 0, new BigDecimal(7)); - recordsInTableB.add(row3); - - BeamRecord row4 = new BeamRecord(rowTypeInTableB - , 13, 13.0, 0, new BigDecimal(13)); - recordsInTableB.add(row4); - - BeamRecord row5 = new BeamRecord(rowTypeInTableB - , 5, 5.0, 0, new BigDecimal(5)); - recordsInTableB.add(row5); - - BeamRecord row6 = new BeamRecord(rowTypeInTableB - , 10, 10.0, 0, BigDecimal.TEN); - recordsInTableB.add(row6); - - BeamRecord row7 = new BeamRecord(rowTypeInTableB - , 17, 17.0, 0, new BigDecimal(17)); - recordsInTableB.add(row7); - - boundedInput3 = PBegin.in(pipeline).apply("boundedInput3", - Create.of(recordsInTableB).withCoder(rowTypeInTableB.getRecordCoder())); + public void setUp() { + RowType rowTypeInTableB = + RowSqlType.builder() + .withIntegerField("f_int") + .withDoubleField("f_double") + .withIntegerField("f_int2") + .withDecimalField("f_decimal") + .build(); + + List rowsInTableB = + TestUtils.RowsBuilder.of(rowTypeInTableB) + .addRows( + 1, 1.0, 0, new BigDecimal(1), + 4, 4.0, 0, new BigDecimal(4), + 7, 7.0, 0, new BigDecimal(7), + 13, 13.0, 0, new BigDecimal(13), + 5, 5.0, 0, new BigDecimal(5), + 10, 10.0, 0, new BigDecimal(10), + 17, 17.0, 0, new BigDecimal(17) + ).getRows(); + + boundedInput3 = PBegin.in(pipeline).apply( + "boundedInput3", + Create.of(rowsInTableB).withCoder(rowTypeInTableB.getRowCoder())); } /** @@ -100,19 +96,20 @@ public void testAggregationWithoutWindowWithUnbounded() throws Exception { runAggregationWithoutWindow(unboundedInput1); } - private void runAggregationWithoutWindow(PCollection input) throws Exception { + private void runAggregationWithoutWindow(PCollection input) throws Exception { String sql = "SELECT f_int2, COUNT(*) AS `getFieldCount` FROM PCOLLECTION GROUP BY f_int2"; - PCollection result = + PCollection result = input.apply("testAggregationWithoutWindow", BeamSql.query(sql)); - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int2", 
"size"), - Arrays.asList(Types.INTEGER, Types.BIGINT)); + RowType resultType = RowSqlType.builder() + .withIntegerField("f_int2") + .withBigIntField("size") + .build(); + Row row = Row.withRowType(resultType).addValues(0, 4L).build(); - BeamRecord record = new BeamRecord(resultType, 0, 4L); - - PAssert.that(result).containsInAnyOrder(record); + PAssert.that(result).containsInAnyOrder(row); pipeline.run().waitUntilFinish(); } @@ -121,7 +118,7 @@ private void runAggregationWithoutWindow(PCollection input) throws E * GROUP-BY with multiple aggregation functions with bounded PCollection. */ @Test - public void testAggregationFunctionsWithBounded() throws Exception{ + public void testAggregationFunctionsWithBounded() throws Exception { runAggregationFunctions(boundedInput1); } @@ -129,11 +126,11 @@ public void testAggregationFunctionsWithBounded() throws Exception{ * GROUP-BY with multiple aggregation functions with unbounded PCollection. */ @Test - public void testAggregationFunctionsWithUnbounded() throws Exception{ + public void testAggregationFunctionsWithUnbounded() throws Exception { runAggregationFunctions(unboundedInput1); } - private void runAggregationFunctions(PCollection input) throws Exception{ + private void runAggregationFunctions(PCollection input) throws Exception { String sql = "select f_int2, count(*) as getFieldCount, " + "sum(f_long) as sum1, avg(f_long) as avg1, max(f_long) as max1, min(f_long) as min1, " + "sum(f_short) as sum2, avg(f_short) as avg2, max(f_short) as max2, min(f_short) as min2, " @@ -146,43 +143,71 @@ private void runAggregationFunctions(PCollection input) throws Excep + "var_pop(f_int) as varpop2, var_samp(f_int) as varsamp2 " + "FROM TABLE_A group by f_int2"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testAggregationFunctions", BeamSql.queryMulti(sql)); - - BeamRecordSqlType resultType = BeamRecordSqlType.create( - Arrays.asList("f_int2", "size", "sum1", "avg1", "max1", "min1", "sum2", "avg2", "max2", - "min2", "sum3", "avg3", "max3", "min3", "sum4", "avg4", "max4", "min4", "sum5", "avg5", - "max5", "min5", "max6", "min6", - "varpop1", "varsamp1", "varpop2", "varsamp2"), - Arrays.asList(Types.INTEGER, Types.BIGINT, Types.BIGINT, Types.BIGINT, Types.BIGINT, - Types.BIGINT, Types.SMALLINT, Types.SMALLINT, Types.SMALLINT, Types.SMALLINT, - Types.TINYINT, Types.TINYINT, Types.TINYINT, Types.TINYINT, Types.FLOAT, Types.FLOAT, - Types.FLOAT, Types.FLOAT, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, - Types.TIMESTAMP, Types.TIMESTAMP, - Types.DOUBLE, Types.DOUBLE, Types.INTEGER, Types.INTEGER)); - - BeamRecord record = new BeamRecord(resultType - , 0, 4L - , 10000L, 2500L, 4000L, 1000L - , (short) 10, (short) 2, (short) 4, (short) 1 - , (byte) 10, (byte) 2, (byte) 4, (byte) 1 - , 10.0F, 2.5F, 4.0F, 1.0F - , 10.0, 2.5, 4.0, 1.0 - , FORMAT.parse("2017-01-01 02:04:03"), FORMAT.parse("2017-01-01 01:01:03") - , 1.25, 1.666666667, 1, 1); - - PAssert.that(result).containsInAnyOrder(record); + .apply("testAggregationFunctions", BeamSql.query(sql)); + + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int2") + .withBigIntField("size") + .withBigIntField("sum1") + .withBigIntField("avg1") + .withBigIntField("max1") + .withBigIntField("min1") + .withSmallIntField("sum2") + .withSmallIntField("avg2") + .withSmallIntField("max2") + .withSmallIntField("min2") + .withTinyIntField("sum3") + .withTinyIntField("avg3") + .withTinyIntField("max3") + .withTinyIntField("min3") + 
.withFloatField("sum4") + .withFloatField("avg4") + .withFloatField("max4") + .withFloatField("min4") + .withDoubleField("sum5") + .withDoubleField("avg5") + .withDoubleField("max5") + .withDoubleField("min5") + .withTimestampField("max6") + .withTimestampField("min6") + .withDoubleField("varpop1") + .withDoubleField("varsamp1") + .withIntegerField("varpop2") + .withIntegerField("varsamp2") + .build(); + + Row row = + Row + .withRowType(resultType) + .addValues( + 0, 4L, + 10000L, 2500L, 4000L, 1000L, + (short) 10, (short) 2, (short) 4, (short) 1, + (byte) 10, (byte) 2, (byte) 4, (byte) 1, + 10.0F, 2.5F, 4.0F, 1.0F, + 10.0, 2.5, 4.0, 1.0, + FORMAT.parse("2017-01-01 02:04:03"), + FORMAT.parse("2017-01-01 01:01:03"), + 1.25, 1.666666667, 1, 1) + .build(); + + PAssert.that(result).containsInAnyOrder(row); pipeline.run().waitUntilFinish(); } private static class CheckerBigDecimalDivide - implements SerializableFunction, Void> { - @Override public Void apply(Iterable input) { - Iterator iter = input.iterator(); + implements SerializableFunction, Void> { + + @Override + public Void apply(Iterable input) { + Iterator iter = input.iterator(); assertTrue(iter.hasNext()); - BeamRecord row = iter.next(); + Row row = iter.next(); assertEquals(row.getDouble("avg1"), 8.142857143, 1e-7); assertTrue(row.getInteger("avg2") == 8); assertEquals(row.getDouble("varpop1"), 26.40816326, 1e-7); @@ -200,20 +225,13 @@ private static class CheckerBigDecimalDivide @Test public void testAggregationFunctionsWithBoundedOnBigDecimalDivide() throws Exception { String sql = "SELECT AVG(f_double) as avg1, AVG(f_int) as avg2, " - + "VAR_POP(f_double) as varpop1, VAR_POP(f_int) as varpop2, " - + "VAR_SAMP(f_double) as varsamp1, VAR_SAMP(f_int) as varsamp2 " - + "FROM PCOLLECTION GROUP BY f_int2"; - - PCollection result = - boundedInput3.apply("testAggregationWithDecimalValue", BeamSql.query(sql)); + + "VAR_POP(f_double) as varpop1, VAR_POP(f_int) as varpop2, " + + "VAR_SAMP(f_double) as varsamp1, VAR_SAMP(f_int) as varsamp2 " + + "FROM PCOLLECTION GROUP BY f_int2"; - BeamRecordSqlType resultType = BeamRecordSqlType.create( - Arrays.asList("avg1", "avg2", "avg3", - "varpop1", "varpop2", - "varsamp1", "varsamp2"), - Arrays.asList(Types.DOUBLE, Types.INTEGER, Types.DECIMAL, - Types.DOUBLE, Types.INTEGER, - Types.DOUBLE, Types.INTEGER)); + PCollection result = + boundedInput3.apply("testAggregationWithDecimalValue", + BeamSql.query(sql)); PAssert.that(result).satisfies(new CheckerBigDecimalDivide()); @@ -236,21 +254,30 @@ public void testDistinctWithUnbounded() throws Exception { runDistinct(unboundedInput1); } - private void runDistinct(PCollection input) throws Exception { + private void runDistinct(PCollection input) throws Exception { String sql = "SELECT distinct f_int, f_long FROM PCOLLECTION "; - PCollection result = + PCollection result = input.apply("testDistinct", BeamSql.query(sql)); - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int", "f_long"), - Arrays.asList(Types.INTEGER, Types.BIGINT)); - - BeamRecord record1 = new BeamRecord(resultType, 1, 1000L); - BeamRecord record2 = new BeamRecord(resultType, 2, 2000L); - BeamRecord record3 = new BeamRecord(resultType, 3, 3000L); - BeamRecord record4 = new BeamRecord(resultType, 4, 4000L); - - PAssert.that(result).containsInAnyOrder(record1, record2, record3, record4); + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int") + .withBigIntField("f_long") + .build(); + + List expectedRows = + TestUtils.RowsBuilder + 
.of(resultType) + .addRows( + 1, 1000L, + 2, 2000L, + 3, 3000L, + 4, 4000L) + .getRows(); + + PAssert.that(result).containsInAnyOrder(expectedRows); pipeline.run().waitUntilFinish(); } @@ -271,23 +298,32 @@ public void testTumbleWindowWithUnbounded() throws Exception { runTumbleWindow(unboundedInput1); } - private void runTumbleWindow(PCollection input) throws Exception { + private void runTumbleWindow(PCollection input) throws Exception { String sql = "SELECT f_int2, COUNT(*) AS `getFieldCount`," + " TUMBLE_START(f_timestamp, INTERVAL '1' HOUR) AS `window_start`" + " FROM TABLE_A" + " GROUP BY f_int2, TUMBLE(f_timestamp, INTERVAL '1' HOUR)"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testTumbleWindow", BeamSql.queryMulti(sql)); - - BeamRecordSqlType resultType = BeamRecordSqlType.create( - Arrays.asList("f_int2", "size", "window_start"), - Arrays.asList(Types.INTEGER, Types.BIGINT, Types.TIMESTAMP)); - - BeamRecord record1 = new BeamRecord(resultType, 0, 3L, FORMAT.parse("2017-01-01 01:00:00")); - BeamRecord record2 = new BeamRecord(resultType, 0, 1L, FORMAT.parse("2017-01-01 02:00:00")); - - PAssert.that(result).containsInAnyOrder(record1, record2); + .apply("testTumbleWindow", BeamSql.query(sql)); + + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int2") + .withBigIntField("size") + .withTimestampField("window_start") + .build(); + + List expectedRows = + TestUtils.RowsBuilder + .of(resultType) + .addRows( + 0, 3L, FORMAT.parse("2017-01-01 01:00:00"), + 0, 1L, FORMAT.parse("2017-01-01 02:00:00")) + .getRows(); + + PAssert.that(result).containsInAnyOrder(expectedRows); pipeline.run().waitUntilFinish(); } @@ -308,24 +344,33 @@ public void testHopWindowWithUnbounded() throws Exception { runHopWindow(unboundedInput1); } - private void runHopWindow(PCollection input) throws Exception { + private void runHopWindow(PCollection input) throws Exception { String sql = "SELECT f_int2, COUNT(*) AS `getFieldCount`," - + " HOP_START(f_timestamp, INTERVAL '1' HOUR, INTERVAL '30' MINUTE) AS `window_start`" + + " HOP_START(f_timestamp, INTERVAL '30' MINUTE, INTERVAL '1' HOUR) AS `window_start`" + " FROM PCOLLECTION" - + " GROUP BY f_int2, HOP(f_timestamp, INTERVAL '1' HOUR, INTERVAL '30' MINUTE)"; - PCollection result = + + " GROUP BY f_int2, HOP(f_timestamp, INTERVAL '30' MINUTE, INTERVAL '1' HOUR)"; + PCollection result = input.apply("testHopWindow", BeamSql.query(sql)); - BeamRecordSqlType resultType = BeamRecordSqlType.create( - Arrays.asList("f_int2", "size", "window_start"), - Arrays.asList(Types.INTEGER, Types.BIGINT, Types.TIMESTAMP)); - - BeamRecord record1 = new BeamRecord(resultType, 0, 3L, FORMAT.parse("2017-01-01 00:30:00")); - BeamRecord record2 = new BeamRecord(resultType, 0, 3L, FORMAT.parse("2017-01-01 01:00:00")); - BeamRecord record3 = new BeamRecord(resultType, 0, 1L, FORMAT.parse("2017-01-01 01:30:00")); - BeamRecord record4 = new BeamRecord(resultType, 0, 1L, FORMAT.parse("2017-01-01 02:00:00")); - - PAssert.that(result).containsInAnyOrder(record1, record2, record3, record4); + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int2") + .withBigIntField("size") + .withTimestampField("window_start") + .build(); + + List expectedRows = + TestUtils.RowsBuilder + .of(resultType) + .addRows( + 0, 3L, FORMAT.parse("2017-01-01 00:30:00"), + 0, 3L, FORMAT.parse("2017-01-01 01:00:00"), + 0, 1L, FORMAT.parse("2017-01-01 01:30:00"), + 0, 1L, FORMAT.parse("2017-01-01 02:00:00")) + 
.getRows(); + + PAssert.that(result).containsInAnyOrder(expectedRows); pipeline.run().waitUntilFinish(); } @@ -346,23 +391,32 @@ public void testSessionWindowWithUnbounded() throws Exception { runSessionWindow(unboundedInput1); } - private void runSessionWindow(PCollection input) throws Exception { + private void runSessionWindow(PCollection input) throws Exception { String sql = "SELECT f_int2, COUNT(*) AS `getFieldCount`," + " SESSION_START(f_timestamp, INTERVAL '5' MINUTE) AS `window_start`" + " FROM TABLE_A" + " GROUP BY f_int2, SESSION(f_timestamp, INTERVAL '5' MINUTE)"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testSessionWindow", BeamSql.queryMulti(sql)); - - BeamRecordSqlType resultType = BeamRecordSqlType.create( - Arrays.asList("f_int2", "size", "window_start"), - Arrays.asList(Types.INTEGER, Types.BIGINT, Types.TIMESTAMP)); - - BeamRecord record1 = new BeamRecord(resultType, 0, 3L, FORMAT.parse("2017-01-01 01:01:03")); - BeamRecord record2 = new BeamRecord(resultType, 0, 1L, FORMAT.parse("2017-01-01 02:04:03")); - - PAssert.that(result).containsInAnyOrder(record1, record2); + .apply("testSessionWindow", BeamSql.query(sql)); + + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int2") + .withBigIntField("size") + .withTimestampField("window_start") + .build(); + + List expectedRows = + TestUtils.RowsBuilder + .of(resultType) + .addRows( + 0, 3L, FORMAT.parse("2017-01-01 01:01:03"), + 0, 1L, FORMAT.parse("2017-01-01 02:04:03")) + .getRows(); + + PAssert.that(result).containsInAnyOrder(expectedRows); pipeline.run().waitUntilFinish(); } @@ -376,9 +430,9 @@ public void testWindowOnNonTimestampField() throws Exception { String sql = "SELECT f_int2, COUNT(*) AS `getFieldCount` FROM TABLE_A " + "GROUP BY f_int2, TUMBLE(f_long, INTERVAL '1' HOUR)"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), boundedInput1) - .apply("testWindowOnNonTimestampField", BeamSql.queryMulti(sql)); + .apply("testWindowOnNonTimestampField", BeamSql.query(sql)); pipeline.run().waitUntilFinish(); } @@ -392,9 +446,140 @@ public void testUnsupportedDistinct() throws Exception { String sql = "SELECT f_int2, COUNT(DISTINCT *) AS `size` " + "FROM PCOLLECTION GROUP BY f_int2"; - PCollection result = + PCollection result = boundedInput1.apply("testUnsupportedDistinct", BeamSql.query(sql)); pipeline.run().waitUntilFinish(); } + + @Test + public void testUnsupportedGlobalWindowWithDefaultTrigger() { + exceptions.expect(IllegalStateException.class); + exceptions.expectCause(isA(UnsupportedOperationException.class)); + + pipeline.enableAbandonedNodeEnforcement(false); + + PCollection input = unboundedInput1 + .apply("unboundedInput1.globalWindow", + Window. 
into(new GlobalWindows()).triggering(DefaultTrigger.of())); + + String sql = "SELECT f_int2, COUNT(*) AS `size` FROM PCOLLECTION GROUP BY f_int2"; + + input.apply("testUnsupportedGlobalWindows", BeamSql.query(sql)); + } + + @Test + public void testSupportsGlobalWindowWithCustomTrigger() throws Exception { + pipeline.enableAbandonedNodeEnforcement(false); + + DateTime startTime = new DateTime(2017, 1, 1, 0, 0, 0, 0); + + RowType type = + RowSqlType + .builder() + .withIntegerField("f_intGroupingKey") + .withIntegerField("f_intValue") + .withTimestampField("f_timestamp") + .build(); + + Object[] rows = new Object[]{ + 0, 1, startTime.plusSeconds(0).toDate(), + 0, 2, startTime.plusSeconds(1).toDate(), + 0, 3, startTime.plusSeconds(2).toDate(), + 0, 4, startTime.plusSeconds(3).toDate(), + 0, 5, startTime.plusSeconds(4).toDate(), + 0, 6, startTime.plusSeconds(6).toDate() + }; + + PCollection input = + createTestPCollection(type, rows, "f_timestamp") + .apply(Window + .into(new GlobalWindows()) + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(2))) + .discardingFiredPanes() + .withOnTimeBehavior(Window.OnTimeBehavior.FIRE_IF_NON_EMPTY)); + + String sql = + "SELECT SUM(f_intValue) AS `sum` FROM PCOLLECTION GROUP BY f_intGroupingKey"; + + PCollection result = input.apply("sql", BeamSql.query(sql)); + + assertEquals(new GlobalWindows(), result.getWindowingStrategy().getWindowFn()); + PAssert + .that(result) + .containsInAnyOrder( + rowsWithSingleIntField("sum", Arrays.asList(3, 7, 11))); + + pipeline.run(); + } + + @Test + public void testSupportsNonGlobalWindowWithCustomTrigger() { + DateTime startTime = new DateTime(2017, 1, 1, 0, 0, 0, 0); + + RowType type = + RowSqlType + .builder() + .withIntegerField("f_intGroupingKey") + .withIntegerField("f_intValue") + .withTimestampField("f_timestamp") + .build(); + + Object[] rows = new Object[]{ + 0, 1, startTime.plusSeconds(0).toDate(), + 0, 2, startTime.plusSeconds(1).toDate(), + 0, 3, startTime.plusSeconds(2).toDate(), + 0, 4, startTime.plusSeconds(3).toDate(), + 0, 5, startTime.plusSeconds(4).toDate(), + 0, 6, startTime.plusSeconds(6).toDate() + }; + + PCollection input = + createTestPCollection(type, rows, "f_timestamp") + .apply(Window + .into( + FixedWindows.of(Duration.standardSeconds(3))) + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(2))) + .discardingFiredPanes() + .withAllowedLateness(Duration.ZERO) + .withOnTimeBehavior(Window.OnTimeBehavior.FIRE_IF_NON_EMPTY)); + + String sql = + "SELECT SUM(f_intValue) AS `sum` FROM PCOLLECTION GROUP BY f_intGroupingKey"; + + PCollection result = input.apply("sql", BeamSql.query(sql)); + + assertEquals( + FixedWindows.of(Duration.standardSeconds(3)), + result.getWindowingStrategy().getWindowFn()); + + PAssert + .that(result) + .containsInAnyOrder( + rowsWithSingleIntField("sum", Arrays.asList(3, 3, 9, 6))); + + pipeline.run(); + } + + private List rowsWithSingleIntField(String fieldName, List values) { + return + TestUtils + .rowsBuilderOf(RowSqlType.builder().withIntegerField(fieldName).build()) + .addRows(values) + .getRows(); + } + + private PCollection createTestPCollection( + RowType type, + Object[] rows, + String timestampField) { + return + TestUtils + .rowsBuilderOf(type) + .addRows(rows) + .getPCollectionBuilder() + .inPipeline(pipeline) + .withTimestampField(timestampField) + .buildUnbounded(); + } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java index d70b0378e10c..2930d13e8f2c 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java @@ -17,15 +17,16 @@ */ package org.apache.beam.sdk.extensions.sql; -import static org.apache.beam.sdk.extensions.sql.utils.BeamRecordAsserts.matchesScalar; +import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesScalar; import java.util.List; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -40,17 +41,19 @@ public class BeamSqlDslAggregationVarianceTest { @Rule public TestPipeline pipeline = TestPipeline.create(); - private PCollection boundedInput; + private PCollection boundedInput; @Before public void setUp() { - BeamRecordSqlType rowType = BeamRecordSqlType.builder() - .withIntegerField("f_int") - .withDoubleField("f_double") - .withIntegerField("f_int2") - .build(); - - List recordsInTableB = + RowType rowType = + RowSqlType + .builder() + .withIntegerField("f_int") + .withDoubleField("f_double") + .withIntegerField("f_int2") + .build(); + + List rowsInTableB = TestUtils.RowsBuilder .of(rowType) .addRows( @@ -65,7 +68,7 @@ public void setUp() { boundedInput = PBegin .in(pipeline) - .apply(Create.of(recordsInTableB).withCoder(rowType.getRecordCoder())); + .apply(Create.of(rowsInTableB).withCoder(rowType.getRowCoder())); } @Test diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java index 7f456892354b..629531be4c3c 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java @@ -18,19 +18,20 @@ package org.apache.beam.sdk.extensions.sql; import java.math.BigDecimal; -import java.sql.Types; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; import org.junit.BeforeClass; @@ -51,86 +52,91 @@ public class BeamSqlDslBase { @Rule public ExpectedException exceptions = ExpectedException.none(); - public static BeamRecordSqlType rowTypeInTableA; - public static List recordsInTableA; + static RowType rowTypeInTableA; + static List rowsInTableA; 
//bounded PCollections - public PCollection boundedInput1; - public PCollection boundedInput2; + PCollection boundedInput1; + PCollection boundedInput2; //unbounded PCollections - public PCollection unboundedInput1; - public PCollection unboundedInput2; + PCollection unboundedInput1; + PCollection unboundedInput2; @BeforeClass public static void prepareClass() throws ParseException { - rowTypeInTableA = BeamRecordSqlType.create( - Arrays.asList("f_int", "f_long", "f_short", "f_byte", "f_float", "f_double", "f_string", - "f_timestamp", "f_int2", "f_decimal"), - Arrays.asList(Types.INTEGER, Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.FLOAT, - Types.DOUBLE, Types.VARCHAR, Types.TIMESTAMP, Types.INTEGER, Types.DECIMAL)); - - recordsInTableA = prepareInputRowsInTableA(); + rowTypeInTableA = + RowSqlType + .builder() + .withIntegerField("f_int") + .withBigIntField("f_long") + .withSmallIntField("f_short") + .withTinyIntField("f_byte") + .withFloatField("f_float") + .withDoubleField("f_double") + .withVarcharField("f_string") + .withTimestampField("f_timestamp") + .withIntegerField("f_int2") + .withDecimalField("f_decimal") + .build(); + + rowsInTableA = + TestUtils.RowsBuilder.of(rowTypeInTableA) + .addRows( + 1, 1000L, (short) 1, (byte) 1, 1.0f, 1.0d, "string_row1", + FORMAT.parse("2017-01-01 01:01:03"), 0, new BigDecimal(1)) + .addRows( + 2, 2000L, (short) 2, (byte) 2, 2.0f, 2.0d, "string_row2", + FORMAT.parse("2017-01-01 01:02:03"), 0, new BigDecimal(2)) + .addRows( + 3, 3000L, (short) 3, (byte) 3, 3.0f, 3.0d, "string_row3", + FORMAT.parse("2017-01-01 01:06:03"), 0, new BigDecimal(3)) + .addRows( + 4, 4000L, (short) 4, (byte) 4, 4.0f, 4.0d, "第四行", + FORMAT.parse("2017-01-01 02:04:03"), 0, new BigDecimal(4)) + .getRows(); } @Before public void preparePCollections(){ boundedInput1 = PBegin.in(pipeline).apply("boundedInput1", - Create.of(recordsInTableA).withCoder(rowTypeInTableA.getRecordCoder())); + Create.of(rowsInTableA).withCoder(rowTypeInTableA.getRowCoder())); boundedInput2 = PBegin.in(pipeline).apply("boundedInput2", - Create.of(recordsInTableA.get(0)).withCoder(rowTypeInTableA.getRecordCoder())); + Create.of(rowsInTableA.get(0)).withCoder(rowTypeInTableA.getRowCoder())); unboundedInput1 = prepareUnboundedPCollection1(); unboundedInput2 = prepareUnboundedPCollection2(); } - private PCollection prepareUnboundedPCollection1() { - TestStream.Builder values = TestStream - .create(rowTypeInTableA.getRecordCoder()); + private PCollection prepareUnboundedPCollection1() { + TestStream.Builder values = TestStream + .create(rowTypeInTableA.getRowCoder()); - for (BeamRecord row : recordsInTableA) { + for (Row row : rowsInTableA) { values = values.advanceWatermarkTo(new Instant(row.getDate("f_timestamp"))); values = values.addElements(row); } - return PBegin.in(pipeline).apply("unboundedInput1", values.advanceWatermarkToInfinity()); + return PBegin + .in(pipeline) + .apply("unboundedInput1", values.advanceWatermarkToInfinity()) + .apply("unboundedInput1.fixedWindow1year", + Window.into(FixedWindows.of(Duration.standardDays(365)))); } - private PCollection prepareUnboundedPCollection2() { - TestStream.Builder values = TestStream - .create(rowTypeInTableA.getRecordCoder()); + private PCollection prepareUnboundedPCollection2() { + TestStream.Builder values = TestStream + .create(rowTypeInTableA.getRowCoder()); - BeamRecord row = recordsInTableA.get(0); + Row row = rowsInTableA.get(0); values = values.advanceWatermarkTo(new Instant(row.getDate("f_timestamp"))); values = values.addElements(row); - 
return PBegin.in(pipeline).apply("unboundedInput2", values.advanceWatermarkToInfinity()); - } - - private static List prepareInputRowsInTableA() throws ParseException{ - List rows = new ArrayList<>(); - - BeamRecord row1 = new BeamRecord(rowTypeInTableA - , 1, 1000L, Short.valueOf("1"), Byte.valueOf("1"), 1.0f, 1.0, "string_row1" - , FORMAT.parse("2017-01-01 01:01:03"), 0, BigDecimal.ONE); - rows.add(row1); - - BeamRecord row2 = new BeamRecord(rowTypeInTableA - , 2, 2000L, Short.valueOf("2"), Byte.valueOf("2"), 2.0f, 2.0, "string_row2" - , FORMAT.parse("2017-01-01 01:02:03"), 0, new BigDecimal(2)); - rows.add(row2); - - BeamRecord row3 = new BeamRecord(rowTypeInTableA - , 3, 3000L, Short.valueOf("3"), Byte.valueOf("3"), 3.0f, 3.0, "string_row3" - , FORMAT.parse("2017-01-01 01:06:03"), 0, new BigDecimal(3)); - rows.add(row3); - - BeamRecord row4 = new BeamRecord(rowTypeInTableA - , 4, 4000L, Short.valueOf("4"), Byte.valueOf("4"), 4.0f, 4.0, "第四行" - , FORMAT.parse("2017-01-01 02:04:03"), 0, new BigDecimal(4)); - rows.add(row4); - - return rows; + return PBegin + .in(pipeline) + .apply("unboundedInput2", values.advanceWatermarkToInfinity()) + .apply("unboundedInput2.fixedWindow1year", + Window.into(FixedWindows.of(Duration.standardDays(365)))); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslFilterTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslFilterTest.java index 0dd1cf93e335..66bd1f1a05e8 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslFilterTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslFilterTest.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.extensions.sql; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.junit.Test; @@ -44,13 +44,13 @@ public void testSingleFilterWithUnbounded() throws Exception { runSingleFilter(unboundedInput1); } - private void runSingleFilter(PCollection input) throws Exception { + private void runSingleFilter(PCollection input) throws Exception { String sql = "SELECT * FROM PCOLLECTION WHERE f_int = 1"; - PCollection result = + PCollection result = input.apply("testSingleFilter", BeamSql.query(sql)); - PAssert.that(result).containsInAnyOrder(recordsInTableA.get(0)); + PAssert.that(result).containsInAnyOrder(rowsInTableA.get(0)); pipeline.run().waitUntilFinish(); } @@ -71,15 +71,15 @@ public void testCompositeFilterWithUnbounded() throws Exception { runCompositeFilter(unboundedInput1); } - private void runCompositeFilter(PCollection input) throws Exception { + private void runCompositeFilter(PCollection input) throws Exception { String sql = "SELECT * FROM TABLE_A" + " WHERE f_int > 1 AND (f_long < 3000 OR f_string = 'string_row3')"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testCompositeFilter", BeamSql.queryMulti(sql)); + .apply("testCompositeFilter", BeamSql.query(sql)); - PAssert.that(result).containsInAnyOrder(recordsInTableA.get(1), recordsInTableA.get(2)); + PAssert.that(result).containsInAnyOrder(rowsInTableA.get(1), rowsInTableA.get(2)); pipeline.run().waitUntilFinish(); } @@ -100,12 +100,12 @@ public void testNoReturnFilterWithUnbounded() throws 
Exception { runNoReturnFilter(unboundedInput1); } - private void runNoReturnFilter(PCollection input) throws Exception { + private void runNoReturnFilter(PCollection input) throws Exception { String sql = "SELECT * FROM TABLE_A WHERE f_int < 1"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testNoReturnFilter", BeamSql.queryMulti(sql)); + .apply("testNoReturnFilter", BeamSql.query(sql)); PAssert.that(result).empty(); @@ -120,9 +120,9 @@ public void testFromInvalidTableName1() throws Exception { String sql = "SELECT * FROM TABLE_B WHERE f_int < 1"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), boundedInput1) - .apply("testFromInvalidTableName1", BeamSql.queryMulti(sql)); + .apply("testFromInvalidTableName1", BeamSql.query(sql)); pipeline.run().waitUntilFinish(); } @@ -130,12 +130,15 @@ public void testFromInvalidTableName1() throws Exception { @Test public void testFromInvalidTableName2() throws Exception { exceptions.expect(IllegalStateException.class); - exceptions.expectMessage("Use fixed table name PCOLLECTION"); + exceptions.expectMessage("Use PCOLLECTION as table name" + + " when selecting from single PCollection." + + " Use PCollectionTuple to explicitly " + + "name the input PCollections"); pipeline.enableAbandonedNodeEnforcement(false); String sql = "SELECT * FROM PCOLLECTION_NA"; - PCollection result = boundedInput1.apply(BeamSql.query(sql)); + PCollection result = boundedInput1.apply(BeamSql.query(sql)); pipeline.run().waitUntilFinish(); } @@ -148,7 +151,7 @@ public void testInvalidFilter() throws Exception { String sql = "SELECT * FROM PCOLLECTION WHERE f_int_na = 0"; - PCollection result = boundedInput1.apply(BeamSql.query(sql)); + PCollection result = boundedInput1.apply(BeamSql.query(sql)); pipeline.run().waitUntilFinish(); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslJoinTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslJoinTest.java index 40cfe35c6a11..adbed07de6a9 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslJoinTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslJoinTest.java @@ -18,52 +18,64 @@ package org.apache.beam.sdk.extensions.sql; -import static org.apache.beam.sdk.extensions.sql.impl.rel.BeamJoinRelBoundedVsBoundedTest.ORDER_DETAILS1; -import static org.apache.beam.sdk.extensions.sql.impl.rel.BeamJoinRelBoundedVsBoundedTest.ORDER_DETAILS2; +import static org.apache.beam.sdk.extensions.sql.TestUtils.tuple; +import static org.apache.beam.sdk.extensions.sql.impl.rel.BeamJoinRelBoundedVsBoundedTest + .ORDER_DETAILS1; +import static org.apache.beam.sdk.extensions.sql.impl.rel.BeamJoinRelBoundedVsBoundedTest + .ORDER_DETAILS2; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.hasProperty; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.stringContainsInOrder; -import java.sql.Types; import java.util.Arrays; -import org.apache.beam.sdk.coders.BeamRecordCoder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.transforms.windowing.AfterWatermark; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import 
org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.hamcrest.Matcher; +import org.joda.time.DateTime; +import org.joda.time.Duration; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; /** * Tests for joins in queries. */ public class BeamSqlDslJoinTest { - @Rule - public final TestPipeline pipeline = TestPipeline.create(); - - private static final BeamRecordSqlType SOURCE_RECORD_TYPE = - BeamRecordSqlType.create( - Arrays.asList( - "order_id", "site_id", "price" - ), - Arrays.asList( - Types.INTEGER, Types.INTEGER, Types.INTEGER - ) - ); - - private static final BeamRecordCoder SOURCE_CODER = SOURCE_RECORD_TYPE.getRecordCoder(); - - private static final BeamRecordSqlType RESULT_RECORD_TYPE = - BeamRecordSqlType.create( - Arrays.asList( - "order_id", "site_id", "price", "order_id0", "site_id0", "price0" - ), - Arrays.asList( - Types.INTEGER, Types.INTEGER, Types.INTEGER, Types.INTEGER - , Types.INTEGER, Types.INTEGER - ) - ); - - private static final BeamRecordCoder RESULT_CODER = RESULT_RECORD_TYPE.getRecordCoder(); + + @Rule public final ExpectedException thrown = ExpectedException.none(); + @Rule public final TestPipeline pipeline = TestPipeline.create(); + + private static final RowType SOURCE_ROW_TYPE = + RowSqlType.builder() + .withIntegerField("order_id") + .withIntegerField("site_id") + .withIntegerField("price") + .build(); + + private static final RowCoder SOURCE_CODER = SOURCE_ROW_TYPE.getRowCoder(); + + private static final RowType RESULT_ROW_TYPE = + RowSqlType.builder() + .withIntegerField("order_id") + .withIntegerField("site_id") + .withIntegerField("price") + .withIntegerField("order_id0") + .withIntegerField("site_id0") + .withIntegerField("price0") + .build(); + + private static final RowCoder RESULT_CODER = RESULT_ROW_TYPE.getRowCoder(); @Test public void testInnerJoin() throws Exception { @@ -72,12 +84,11 @@ public void testInnerJoin() throws Exception { + "FROM ORDER_DETAILS1 o1" + " JOIN ORDER_DETAILS2 o2" + " on " - + " o1.order_id=o2.site_id AND o2.price=o1.site_id" - ; + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; PAssert.that(queryFromOrderTables(sql)).containsInAnyOrder( TestUtils.RowsBuilder.of( - RESULT_RECORD_TYPE + RESULT_ROW_TYPE ).addRows( 2, 3, 3, 1, 2, 3 ).getRows()); @@ -91,12 +102,11 @@ public void testLeftOuterJoin() throws Exception { + "FROM ORDER_DETAILS1 o1" + " LEFT OUTER JOIN ORDER_DETAILS2 o2" + " on " - + " o1.order_id=o2.site_id AND o2.price=o1.site_id" - ; + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; PAssert.that(queryFromOrderTables(sql)).containsInAnyOrder( TestUtils.RowsBuilder.of( - RESULT_RECORD_TYPE + RESULT_ROW_TYPE ).addRows( 1, 2, 3, null, null, null, 2, 3, 3, 1, 2, 3, @@ -112,12 +122,11 @@ public void testRightOuterJoin() throws Exception { + "FROM ORDER_DETAILS1 o1" + " RIGHT OUTER JOIN ORDER_DETAILS2 o2" + " on " - + " o1.order_id=o2.site_id AND o2.price=o1.site_id" - ; + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; PAssert.that(queryFromOrderTables(sql)).containsInAnyOrder( TestUtils.RowsBuilder.of( - RESULT_RECORD_TYPE + RESULT_ROW_TYPE ).addRows( 2, 3, 3, 1, 2, 3, null, null, null, 2, 3, 3, @@ -133,12 
+142,11 @@ public void testFullOuterJoin() throws Exception { + "FROM ORDER_DETAILS1 o1" + " FULL OUTER JOIN ORDER_DETAILS2 o2" + " on " - + " o1.order_id=o2.site_id AND o2.price=o1.site_id" - ; + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; PAssert.that(queryFromOrderTables(sql)).containsInAnyOrder( TestUtils.RowsBuilder.of( - RESULT_RECORD_TYPE + RESULT_ROW_TYPE ).addRows( 2, 3, 3, 1, 2, 3, 1, 2, 3, null, null, null, @@ -156,16 +164,17 @@ public void testException_nonEqualJoin() throws Exception { + "FROM ORDER_DETAILS1 o1" + " JOIN ORDER_DETAILS2 o2" + " on " - + " o1.order_id>o2.site_id" - ; + + " o1.order_id>o2.site_id"; pipeline.enableAbandonedNodeEnforcement(false); queryFromOrderTables(sql); pipeline.run(); } - @Test(expected = IllegalStateException.class) + @Test public void testException_crossJoin() throws Exception { + thrown.expect(IllegalStateException.class); + String sql = "SELECT * " + "FROM ORDER_DETAILS1 o1, ORDER_DETAILS2 o2"; @@ -175,14 +184,172 @@ public void testException_crossJoin() throws Exception { pipeline.run(); } - private PCollection queryFromOrderTables(String sql) { - return PCollectionTuple.of( - new TupleTag<>("ORDER_DETAILS1"), - ORDER_DETAILS1.buildIOReader(pipeline).setCoder(SOURCE_CODER)) - .and( - new TupleTag<>("ORDER_DETAILS2"), - ORDER_DETAILS2.buildIOReader(pipeline).setCoder(SOURCE_CODER)) - .apply("join", BeamSql.queryMulti(sql)) + @Test + public void testJoinsUnboundedWithinWindowsWithDefaultTrigger() throws Exception { + + String sql = + "SELECT o1.order_id, o1.price, o1.site_id, o2.order_id, o2.price, o2.site_id " + + "FROM ORDER_DETAILS1 o1" + + " JOIN ORDER_DETAILS2 o2" + + " on " + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; + + PCollection orders = ordersUnbounded() + .apply("window", Window.into(FixedWindows.of(Duration.standardSeconds(50)))); + PCollectionTuple inputs = tuple("ORDER_DETAILS1", orders, "ORDER_DETAILS2", orders); + + PAssert + .that( + inputs.apply("sql", BeamSql.query(sql))) + .containsInAnyOrder( + TestUtils.RowsBuilder + .of( + RESULT_ROW_TYPE + ).addRows( + 1, 2, 2, 2, 2, 1, + 1, 4, 3, 3, 3, 1 + ).getRows()); + + pipeline.run(); + } + + @Test + public void testRejectsUnboundedWithinWindowsWithEndOfWindowTrigger() throws Exception { + + String sql = + "SELECT o1.order_id, o1.price, o1.site_id, o2.order_id, o2.price, o2.site_id " + + "FROM ORDER_DETAILS1 o1" + + " JOIN ORDER_DETAILS2 o2" + + " on " + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; + + PCollection orders = ordersUnbounded() + .apply("window", + Window + .into(FixedWindows.of(Duration.standardSeconds(50))) + .triggering(AfterWatermark.pastEndOfWindow()) + .withAllowedLateness(Duration.ZERO) + .accumulatingFiredPanes()); + PCollectionTuple inputs = tuple("ORDER_DETAILS1", orders, "ORDER_DETAILS2", orders); + + thrown.expectCause(expectedSingleFireTrigger()); + + inputs.apply("sql", BeamSql.query(sql)); + + pipeline.run(); + } + + @Test + public void testRejectsGlobalWindowsWithDefaultTriggerInUnboundedInput() throws Exception { + + String sql = + "SELECT * " + + "FROM ORDER_DETAILS1 o1" + + " JOIN ORDER_DETAILS2 o2" + + " on " + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; + + PCollection orders = ordersUnbounded(); + PCollectionTuple inputs = tuple("ORDER_DETAILS1", orders, "ORDER_DETAILS2", orders); + + thrown.expectCause(expectedSingleFireTrigger()); + + inputs.apply("sql", BeamSql.query(sql)); + + pipeline.run(); + } + + @Test + public void testRejectsGlobalWindowsWithEndOfWindowTrigger() throws Exception { + + String sql 
= + "SELECT o1.order_id, o1.price, o1.site_id, o2.order_id, o2.price, o2.site_id " + + "FROM ORDER_DETAILS1 o1" + + " JOIN ORDER_DETAILS2 o2" + + " on " + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; + + PCollection orders = ordersUnbounded() + .apply("window", + Window + .into(new GlobalWindows()) + .triggering(AfterWatermark.pastEndOfWindow()) + .withAllowedLateness(Duration.ZERO) + .accumulatingFiredPanes()); + PCollectionTuple inputs = tuple("ORDER_DETAILS1", orders, "ORDER_DETAILS2", orders); + + thrown.expectCause(expectedSingleFireTrigger()); + + inputs.apply("sql", BeamSql.query(sql)); + + pipeline.run(); + } + + @Test + public void testRejectsNonGlobalWindowsWithRepeatingTrigger() throws Exception { + + String sql = + "SELECT o1.order_id, o1.price, o1.site_id, o2.order_id, o2.price, o2.site_id " + + "FROM ORDER_DETAILS1 o1" + + " JOIN ORDER_DETAILS2 o2" + + " on " + + " o1.order_id=o2.site_id AND o2.price=o1.site_id"; + + PCollection orders = ordersUnbounded() + .apply( + "window", + Window + .into(FixedWindows.of(Duration.standardSeconds(203))) + .triggering(Repeatedly.forever(AfterWatermark.pastEndOfWindow())) + .withAllowedLateness(Duration.standardMinutes(2)) + .accumulatingFiredPanes()); + PCollectionTuple inputs = tuple("ORDER_DETAILS1", orders, "ORDER_DETAILS2", orders); + + thrown.expectCause(expectedSingleFireTrigger()); + + inputs.apply("sql", BeamSql.query(sql)); + + pipeline.run(); + } + + private PCollection ordersUnbounded() { + DateTime ts = new DateTime(2017, 1, 1, 1, 0, 0); + + return + TestUtils + .rowsBuilderOf( + RowSqlType + .builder() + .withIntegerField("order_id") + .withIntegerField("price") + .withIntegerField("site_id") + .withTimestampField("timestamp") + .build()) + .addRows( + 1, 2, 2, ts.plusSeconds(0).toDate(), + 2, 2, 1, ts.plusSeconds(40).toDate(), + 1, 4, 3, ts.plusSeconds(60).toDate(), + 3, 2, 1, ts.plusSeconds(65).toDate(), + 3, 3, 1, ts.plusSeconds(70).toDate()) + .getPCollectionBuilder() + .withTimestampField("timestamp") + .inPipeline(pipeline) + .buildUnbounded(); + } + + private PCollection queryFromOrderTables(String sql) { + return tuple( + "ORDER_DETAILS1", ORDER_DETAILS1.buildIOReader(pipeline).setCoder(SOURCE_CODER), + "ORDER_DETAILS2", ORDER_DETAILS2.buildIOReader(pipeline).setCoder(SOURCE_CODER)) + .apply("join", BeamSql.query(sql)) .setCoder(RESULT_CODER); } + + private Matcher expectedSingleFireTrigger() { + return allOf( + isA(UnsupportedOperationException.class), + hasProperty("message", + stringContainsInOrder( + Arrays.asList("once per window", "default trigger")))); + } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslProjectTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslProjectTest.java index e4e0e61e4a94..69885272ad08 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslProjectTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslProjectTest.java @@ -17,12 +17,15 @@ */ package org.apache.beam.sdk.extensions.sql; -import java.sql.Types; -import java.util.Arrays; +import static java.util.stream.Collectors.toList; + +import java.util.List; +import java.util.stream.IntStream; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import 
org.apache.beam.sdk.values.RowType; import org.apache.beam.sdk.values.TupleTag; import org.junit.Test; @@ -46,13 +49,13 @@ public void testSelectAllWithUnbounded() throws Exception { runSelectAll(unboundedInput2); } - private void runSelectAll(PCollection input) throws Exception { + private void runSelectAll(PCollection input) throws Exception { String sql = "SELECT * FROM PCOLLECTION"; - PCollection result = + PCollection result = input.apply("testSelectAll", BeamSql.query(sql)); - PAssert.that(result).containsInAnyOrder(recordsInTableA.get(0)); + PAssert.that(result).containsInAnyOrder(rowsInTableA.get(0)); pipeline.run().waitUntilFinish(); } @@ -73,20 +76,21 @@ public void testPartialFieldsWithUnbounded() throws Exception { runPartialFields(unboundedInput2); } - private void runPartialFields(PCollection input) throws Exception { + private void runPartialFields(PCollection input) throws Exception { String sql = "SELECT f_int, f_long FROM TABLE_A"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testPartialFields", BeamSql.queryMulti(sql)); + .apply("testPartialFields", BeamSql.query(sql)); - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int", "f_long"), - Arrays.asList(Types.INTEGER, Types.BIGINT)); + RowType resultType = RowSqlType.builder() + .withIntegerField("f_int") + .withBigIntField("f_long") + .build(); - BeamRecord record = new BeamRecord(resultType - , recordsInTableA.get(0).getFieldValue(0), recordsInTableA.get(0).getFieldValue(1)); + Row row = rowAtIndex(resultType, 0); - PAssert.that(result).containsInAnyOrder(record); + PAssert.that(result).containsInAnyOrder(row); pipeline.run().waitUntilFinish(); } @@ -107,33 +111,42 @@ public void testPartialFieldsInMultipleRowWithUnbounded() throws Exception { runPartialFieldsInMultipleRow(unboundedInput1); } - private void runPartialFieldsInMultipleRow(PCollection input) throws Exception { + private void runPartialFieldsInMultipleRow(PCollection input) throws Exception { String sql = "SELECT f_int, f_long FROM TABLE_A"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testPartialFieldsInMultipleRow", BeamSql.queryMulti(sql)); - - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int", "f_long"), - Arrays.asList(Types.INTEGER, Types.BIGINT)); - - BeamRecord record1 = new BeamRecord(resultType - , recordsInTableA.get(0).getFieldValue(0), recordsInTableA.get(0).getFieldValue(1)); + .apply("testPartialFieldsInMultipleRow", BeamSql.query(sql)); - BeamRecord record2 = new BeamRecord(resultType - , recordsInTableA.get(1).getFieldValue(0), recordsInTableA.get(1).getFieldValue(1)); + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int") + .withBigIntField("f_long") + .build(); - BeamRecord record3 = new BeamRecord(resultType - , recordsInTableA.get(2).getFieldValue(0), recordsInTableA.get(2).getFieldValue(1)); + List expectedRows = + IntStream + .range(0, 4) + .mapToObj(i -> rowAtIndex(resultType, i)) + .collect(toList()); - BeamRecord record4 = new BeamRecord(resultType - , recordsInTableA.get(3).getFieldValue(0), recordsInTableA.get(3).getFieldValue(1)); - - PAssert.that(result).containsInAnyOrder(record1, record2, record3, record4); + PAssert + .that(result) + .containsInAnyOrder(expectedRows); pipeline.run().waitUntilFinish(); } + private Row rowAtIndex(RowType rowType, int index) { + return Row + .withRowType(rowType) + .addValues( + 
rowsInTableA.get(index).getValue(0), + rowsInTableA.get(index).getValue(1)) + .build(); + } + /** * select partial fields with bounded PCollection. */ @@ -150,29 +163,29 @@ public void testPartialFieldsInRowsWithUnbounded() throws Exception { runPartialFieldsInRows(unboundedInput1); } - private void runPartialFieldsInRows(PCollection input) throws Exception { + private void runPartialFieldsInRows(PCollection input) throws Exception { String sql = "SELECT f_int, f_long FROM TABLE_A"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testPartialFieldsInRows", BeamSql.queryMulti(sql)); - - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int", "f_long"), - Arrays.asList(Types.INTEGER, Types.BIGINT)); - - BeamRecord record1 = new BeamRecord(resultType - , recordsInTableA.get(0).getFieldValue(0), recordsInTableA.get(0).getFieldValue(1)); - - BeamRecord record2 = new BeamRecord(resultType - , recordsInTableA.get(1).getFieldValue(0), recordsInTableA.get(1).getFieldValue(1)); + .apply("testPartialFieldsInRows", BeamSql.query(sql)); - BeamRecord record3 = new BeamRecord(resultType - , recordsInTableA.get(2).getFieldValue(0), recordsInTableA.get(2).getFieldValue(1)); + RowType resultType = + RowSqlType + .builder() + .withIntegerField("f_int") + .withBigIntField("f_long") + .build(); - BeamRecord record4 = new BeamRecord(resultType - , recordsInTableA.get(3).getFieldValue(0), recordsInTableA.get(3).getFieldValue(1)); + List expectedRows = + IntStream + .range(0, 4) + .mapToObj(i -> rowAtIndex(resultType, i)) + .collect(toList()); - PAssert.that(result).containsInAnyOrder(record1, record2, record3, record4); + PAssert + .that(result) + .containsInAnyOrder(expectedRows); pipeline.run().waitUntilFinish(); } @@ -193,19 +206,19 @@ public void testLiteralFieldWithUnbounded() throws Exception { runLiteralField(unboundedInput2); } - public void runLiteralField(PCollection input) throws Exception { + public void runLiteralField(PCollection input) throws Exception { String sql = "SELECT 1 as literal_field FROM TABLE_A"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), input) - .apply("testLiteralField", BeamSql.queryMulti(sql)); + .apply("testLiteralField", BeamSql.query(sql)); - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("literal_field"), - Arrays.asList(Types.INTEGER)); + RowType resultType = + RowSqlType.builder().withIntegerField("literal_field").build(); - BeamRecord record = new BeamRecord(resultType, 1); + Row row = Row.withRowType(resultType).addValues(1).build(); - PAssert.that(result).containsInAnyOrder(record); + PAssert.that(result).containsInAnyOrder(row); pipeline.run().waitUntilFinish(); } @@ -218,9 +231,9 @@ public void testProjectUnknownField() throws Exception { String sql = "SELECT f_int_na FROM TABLE_A"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), boundedInput1) - .apply("testProjectUnknownField", BeamSql.queryMulti(sql)); + .apply("testProjectUnknownField", BeamSql.query(sql)); pipeline.run().waitUntilFinish(); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java index bb7d8ac7befa..31bd9edfc0a2 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java @@ -17,14 +17,13 @@ */ package org.apache.beam.sdk.extensions.sql; -import java.sql.Types; -import java.util.Arrays; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.beam.sdk.values.TupleTag; import org.junit.Test; @@ -37,24 +36,31 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase { */ @Test public void testUdaf() throws Exception { - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int2", "squaresum"), - Arrays.asList(Types.INTEGER, Types.INTEGER)); + RowType resultType = RowSqlType.builder() + .withIntegerField("f_int2") + .withIntegerField("squaresum") + .build(); - BeamRecord record = new BeamRecord(resultType, 0, 30); + Row row = Row.withRowType(resultType).addValues(0, 30).build(); String sql1 = "SELECT f_int2, squaresum1(f_int) AS `squaresum`" + " FROM PCOLLECTION GROUP BY f_int2"; - PCollection result1 = - boundedInput1.apply("testUdaf1", - BeamSql.query(sql1).withUdaf("squaresum1", new SquareSum())); - PAssert.that(result1).containsInAnyOrder(record); + PCollection result1 = + boundedInput1.apply( + "testUdaf1", + BeamSql.query(sql1).registerUdaf("squaresum1", new SquareSum())); + PAssert.that(result1).containsInAnyOrder(row); String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum`" + " FROM PCOLLECTION GROUP BY f_int2"; - PCollection result2 = - PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1) - .apply("testUdaf2", BeamSql.queryMulti(sql2).withUdaf("squaresum2", new SquareSum())); - PAssert.that(result2).containsInAnyOrder(record); + PCollection result2 = + PCollectionTuple + .of(new TupleTag<>("PCOLLECTION"), boundedInput1) + .apply("testUdaf2", + BeamSql + .query(sql2) + .registerUdaf("squaresum2", new SquareSum())); + PAssert.that(result2).containsInAnyOrder(row); pipeline.run().waitUntilFinish(); } @@ -64,22 +70,25 @@ public void testUdaf() throws Exception { */ @Test public void testUdf() throws Exception{ - BeamRecordSqlType resultType = BeamRecordSqlType.create(Arrays.asList("f_int", "cubicvalue"), - Arrays.asList(Types.INTEGER, Types.INTEGER)); + RowType resultType = RowSqlType.builder() + .withIntegerField("f_int") + .withIntegerField("cubicvalue") + .build(); - BeamRecord record = new BeamRecord(resultType, 2, 8); + Row row = Row.withRowType(resultType).addValues(2, 8).build(); String sql1 = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2"; - PCollection result1 = + PCollection result1 = boundedInput1.apply("testUdf1", - BeamSql.query(sql1).withUdf("cubic1", CubicInteger.class)); - PAssert.that(result1).containsInAnyOrder(record); + BeamSql.query(sql1).registerUdf("cubic1", CubicInteger.class)); + PAssert.that(result1).containsInAnyOrder(row); String sql2 = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2"; - PCollection result2 = + PCollection result2 = PCollectionTuple.of(new TupleTag<>("PCOLLECTION"), boundedInput1) - .apply("testUdf2", BeamSql.queryMulti(sql2).withUdf("cubic2", new CubicIntegerFn())); - PAssert.that(result2).containsInAnyOrder(record); + .apply("testUdf2", + 
BeamSql.query(sql2).registerUdf("cubic2", new CubicIntegerFn())); + PAssert.that(result2).containsInAnyOrder(row); pipeline.run().waitUntilFinish(); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlNonAsciiTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlNonAsciiTest.java index 8c8537da3a4d..219c84ab47e8 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlNonAsciiTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlNonAsciiTest.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.extensions.sql; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.junit.Test; @@ -33,11 +33,11 @@ public class BeamSqlNonAsciiTest extends BeamSqlDslBase { public void testDefaultCharsetLiteral() { String sql = "SELECT * FROM TABLE_A WHERE f_string = '第四行'"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), boundedInput1) - .apply("testCompositeFilter", BeamSql.queryMulti(sql)); + .apply("testCompositeFilter", BeamSql.query(sql)); - PAssert.that(result).containsInAnyOrder(recordsInTableA.get(3)); + PAssert.that(result).containsInAnyOrder(rowsInTableA.get(3)); pipeline.run().waitUntilFinish(); } @@ -46,11 +46,11 @@ public void testDefaultCharsetLiteral() { public void testNationalCharsetLiteral() { String sql = "SELECT * FROM TABLE_A WHERE f_string = N'第四行'"; - PCollection result = + PCollection result = PCollectionTuple.of(new TupleTag<>("TABLE_A"), boundedInput1) - .apply("testCompositeFilter", BeamSql.queryMulti(sql)); + .apply("testCompositeFilter", BeamSql.query(sql)); - PAssert.that(result).containsInAnyOrder(recordsInTableA.get(3)); + PAssert.that(result).containsInAnyOrder(rowsInTableA.get(3)); pipeline.run().waitUntilFinish(); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/SqlRowTypeFactoryTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/SqlRowTypeFactoryTest.java new file mode 100644 index 000000000000..5fbb33087a05 --- /dev/null +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/SqlRowTypeFactoryTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.sdk.extensions.sql; + +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableList; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Date; +import java.util.GregorianCalendar; +import java.util.List; +import org.apache.beam.sdk.values.RowType; +import org.apache.beam.sdk.values.reflect.FieldValueGetter; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * Unit tests for {@link SqlRowTypeFactory}. + */ +public class SqlRowTypeFactoryTest { + + private static final List GETTERS_FOR_KNOWN_TYPES = ImmutableList + .builder() + .add(getter("byteGetter", Byte.class)) + .add(getter("shortGetter", Short.class)) + .add(getter("integerGetter", Integer.class)) + .add(getter("longGetter", Long.class)) + .add(getter("floatGetter", Float.class)) + .add(getter("doubleGetter", Double.class)) + .add(getter("bigDecimalGetter", BigDecimal.class)) + .add(getter("booleanGetter", Boolean.class)) + .add(getter("stringGetter", String.class)) + .add(getter("timeGetter", GregorianCalendar.class)) + .add(getter("dateGetter", Date.class)) + .build(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testContainsCorrectFields() throws Exception { + SqlRowTypeFactory factory = new SqlRowTypeFactory(); + + RowType rowType = factory.createRowType(GETTERS_FOR_KNOWN_TYPES); + + assertEquals(GETTERS_FOR_KNOWN_TYPES.size(), rowType.getFieldCount()); + assertEquals( + Arrays.asList( + "byteGetter", + "shortGetter", + "integerGetter", + "longGetter", + "floatGetter", + "doubleGetter", + "bigDecimalGetter", + "booleanGetter", + "stringGetter", + "timeGetter", + "dateGetter"), + rowType.getFieldNames()); + } + + @Test + public void testContainsCorrectCoders() throws Exception { + SqlRowTypeFactory factory = new SqlRowTypeFactory(); + + RowType rowType = factory.createRowType(GETTERS_FOR_KNOWN_TYPES); + + assertEquals(GETTERS_FOR_KNOWN_TYPES.size(), rowType.getFieldCount()); + assertEquals( + Arrays.asList( + SqlTypeCoders.TINYINT, + SqlTypeCoders.SMALLINT, + SqlTypeCoders.INTEGER, + SqlTypeCoders.BIGINT, + SqlTypeCoders.FLOAT, + SqlTypeCoders.DOUBLE, + SqlTypeCoders.DECIMAL, + SqlTypeCoders.BOOLEAN, + SqlTypeCoders.VARCHAR, + SqlTypeCoders.TIME, + SqlTypeCoders.TIMESTAMP), + rowType.getRowCoder().getCoders()); + } + + @Test + public void testThrowsForUnsupportedTypes() throws Exception { + thrown.expect(UnsupportedOperationException.class); + + SqlRowTypeFactory factory = new SqlRowTypeFactory(); + + factory.createRowType( + Arrays.asList(getter("arrayListGetter", ArrayList.class))); + } + + private static FieldValueGetter getter(final String fieldName, final Class fieldType) { + return new FieldValueGetter() { + @Override + public Object get(Object object) { + return null; + } + + @Override + public String name() { + return fieldName; + } + + @Override + public Class type() { + return fieldType; + } + }; + } +} diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java index d4cc53a74bc0..fff1f3e266fa 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/TestUtils.java @@ -18,11 +18,27 @@ package org.apache.beam.sdk.extensions.sql; +import static 
com.google.common.base.Preconditions.checkArgument; +import static java.util.stream.Collectors.toList; +import static org.apache.beam.sdk.values.Row.toRow; +import static org.apache.beam.sdk.values.RowType.toRowType; + +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Stream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Instant; /** * Test utilities. @@ -31,7 +47,7 @@ public class TestUtils { /** * A {@code DoFn} to convert a {@code BeamSqlRow} to a comparable {@code String}. */ - public static class BeamSqlRow2StringDoFn extends DoFn { + public static class BeamSqlRow2StringDoFn extends DoFn { @ProcessElement public void processElement(ProcessContext ctx) { ctx.output(ctx.element().toString()); @@ -41,20 +57,24 @@ public void processElement(ProcessContext ctx) { /** * Convert list of {@code BeamSqlRow} to list of {@code String}. */ - public static List beamSqlRows2Strings(List rows) { + public static List beamSqlRows2Strings(List rows) { List strs = new ArrayList<>(); - for (BeamRecord row : rows) { + for (Row row : rows) { strs.add(row.toString()); } return strs; } + public static RowsBuilder rowsBuilderOf(RowType type) { + return RowsBuilder.of(type); + } + /** * Convenient way to build a list of {@code BeamSqlRow}s. * *

    * <p>You can use it like this:
-   *
+   *
    * <pre>{@code
    * TestUtils.RowsBuilder.of(
    *   Types.INTEGER, "order_id",
@@ -68,8 +88,8 @@ public static List<String> beamSqlRows2Strings(List<BeamRecord> rows) {
    * {@code}
    */
       public static class RowsBuilder {
-    private BeamRecordSqlType type;
-    private List<BeamRecord> rows = new ArrayList<>();
+    private RowType type;
+    private List<Row> rows = new ArrayList<>();
     
         /**
          * Create a RowsBuilder with the specified row type info.
    @@ -85,7 +105,7 @@ public static class RowsBuilder {
          * @args pairs of column type and column names.
          */
         public static RowsBuilder of(final Object... args) {
    -      BeamRecordSqlType beamSQLRowType = buildBeamSqlRowType(args);
    +      RowType beamSQLRowType = buildBeamSqlRowType(args);
           RowsBuilder builder = new RowsBuilder();
           builder.type = beamSQLRowType;
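The hunks around this point migrate TestUtils.RowsBuilder from the BeamRecordSqlType/BeamRecord API to RowType/Row. As a rough usage sketch (not part of the patch; it assumes the RowSqlType builder and the TestUtils.RowsBuilder helpers shown elsewhere in this diff), a test could build its expected rows like this:

package org.apache.beam.sdk.extensions.sql;

import java.util.List;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.RowType;

class RowsBuilderUsageSketch {
  // Builds two rows for an illustrative (order_id, sum_site_id, buyer) schema.
  static List<Row> sampleRows() {
    RowType type = RowSqlType.builder()
        .withIntegerField("order_id")
        .withIntegerField("sum_site_id")
        .withVarcharField("buyer")
        .build();

    // Values are supplied flat; the builder partitions them into rows of
    // type.getFieldCount() fields each (three per row here).
    return TestUtils.RowsBuilder
        .of(type)
        .addRows(
            1, 3, "james",
            2, 5, "bond")
        .getRows();
  }
}

The same chain can continue with getPCollectionBuilder().inPipeline(...).withTimestampField(...).buildUnbounded() to turn the rows into a TestStream-backed PCollection, as the later hunks in this file add.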
     
    @@ -98,13 +118,14 @@ public static RowsBuilder of(final Object... args) {
      *
      * <p>For example:
      * <pre>{@code
      * TestUtils.RowsBuilder.of(
-     *   beamRecordSqlType
+     *   rowType
      * )}</pre>
    - * @beamSQLRowType the record type. + * + * @beamSQLRowType the row type. */ - public static RowsBuilder of(final BeamRecordSqlType beamSQLRowType) { + public static RowsBuilder of(final RowType rowType) { RowsBuilder builder = new RowsBuilder(); - builder.type = beamSQLRowType; + builder.type = rowType; return builder; } @@ -129,13 +150,84 @@ public RowsBuilder addRows(final List args) { return this; } - public List getRows() { + public List getRows() { return rows; } public List getStringRows() { return beamSqlRows2Strings(rows); } + + public PCollectionBuilder getPCollectionBuilder() { + return + pCollectionBuilder() + .withRowType(type) + .withRows(rows); + } + } + + public static PCollectionBuilder pCollectionBuilder() { + return new PCollectionBuilder(); + } + + static class PCollectionBuilder { + private RowType type; + private List rows; + private String timestampField; + private Pipeline pipeline; + + public PCollectionBuilder withRowType(RowType type) { + this.type = type; + return this; + } + + public PCollectionBuilder withRows(List rows) { + this.rows = rows; + return this; + } + + /** + * Event time field, defines watermark. + */ + public PCollectionBuilder withTimestampField(String timestampField) { + this.timestampField = timestampField; + return this; + } + + public PCollectionBuilder inPipeline(Pipeline pipeline) { + this.pipeline = pipeline; + return this; + } + + /** + * Builds an unbounded {@link PCollection} in {@link Pipeline} + * set by {@link #inPipeline(Pipeline)}. + * + *

    If timestamp field was set with {@link #withTimestampField(String)} then + * watermark will be advanced to the values from that field. + */ + public PCollection buildUnbounded() { + checkArgument(pipeline != null); + checkArgument(rows.size() > 0); + + if (type == null) { + type = rows.get(0).getRowType(); + } + + TestStream.Builder values = TestStream.create(type.getRowCoder()); + + for (Row row : rows) { + if (timestampField != null) { + values = values.advanceWatermarkTo(new Instant(row.getDate(timestampField))); + } + + values = values.addElements(row); + } + + return PBegin + .in(pipeline) + .apply("unboundedPCollection", values.advanceWatermarkToInfinity()); + } } /** @@ -145,23 +237,24 @@ public List getStringRows() { * *

    * <pre>{@code
    *   buildBeamSqlRowType(
-   *       Types.BIGINT, "order_id",
-   *       Types.INTEGER, "site_id",
-   *       Types.DOUBLE, "price",
-   *       Types.TIMESTAMP, "order_time"
+   *       SqlCoders.BIGINT, "order_id",
+   *       SqlCoders.INTEGER, "site_id",
+   *       SqlCoders.DOUBLE, "price",
+   *       SqlCoders.TIMESTAMP, "order_time"
    *   )
    * }</pre>
    */ - public static BeamRecordSqlType buildBeamSqlRowType(Object... args) { - List types = new ArrayList<>(); - List names = new ArrayList<>(); - - for (int i = 0; i < args.length - 1; i += 2) { - types.add((int) args[i]); - names.add((String) args[i + 1]); - } + public static RowType buildBeamSqlRowType(Object... args) { + return + Stream + .iterate(0, i -> i + 2) + .limit(args.length / 2) + .map(i -> toRecordField(args, i)) + .collect(toRowType()); + } - return BeamRecordSqlType.create(names, types); + private static RowType.Field toRecordField(Object[] args, int i) { + return RowType.newField((String) args[i + 1], (Coder) args[i]); } /** @@ -178,13 +271,29 @@ public static BeamRecordSqlType buildBeamSqlRowType(Object... args) { * ) * }
    */ - public static List buildRows(BeamRecordSqlType type, List args) { - List rows = new ArrayList<>(); - int fieldCount = type.getFieldCount(); + public static List buildRows(RowType type, List rowsValues) { + return + Lists + .partition(rowsValues, type.getFieldCount()) + .stream() + .map(values -> values.stream().collect(toRow(type))) + .collect(toList()); + } - for (int i = 0; i < args.size(); i += fieldCount) { - rows.add(new BeamRecord(type, args.subList(i, i + fieldCount))); - } - return rows; + public static PCollectionTuple tuple(String tag, PCollection pCollection) { + return PCollectionTuple.of(new TupleTag<>(tag), pCollection); + } + + public static PCollectionTuple tuple(String tag1, PCollection pCollection1, + String tag2, PCollection pCollection2) { + return tuple(tag1, pCollection1).and(new TupleTag<>(tag2), pCollection2); + } + + public static PCollectionTuple tuple(String tag1, PCollection pCollection1, + String tag2, PCollection pCollection2, + String tag3, PCollection pCollection3) { + return tuple( + tag1, pCollection1, + tag2, pCollection2).and(new TupleTag<>(tag3), pCollection3); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java index 9d1212643966..b6ac343a0ab3 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/BeamSqlFnExecutorTestBase.java @@ -19,13 +19,12 @@ import java.util.ArrayList; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelDataTypeSystem; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRuleSets; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.adapter.java.JavaTypeFactory; import org.apache.calcite.config.Lex; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; @@ -50,17 +49,13 @@ * base class to test {@link BeamSqlFnExecutor} and subclasses of {@link BeamSqlExpression}. 
*/ public class BeamSqlFnExecutorTestBase { - public static RexBuilder rexBuilder = new RexBuilder(BeamQueryPlanner.TYPE_FACTORY); - public static RelOptCluster cluster = RelOptCluster.create(new VolcanoPlanner(), rexBuilder); + static final JavaTypeFactory TYPE_FACTORY = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT); + static RexBuilder rexBuilder = new RexBuilder(BeamQueryPlanner.TYPE_FACTORY); + static RelOptCluster cluster = RelOptCluster.create(new VolcanoPlanner(), rexBuilder); + static RelDataType relDataType; + static RelBuilder relBuilder; - public static final JavaTypeFactory TYPE_FACTORY = new JavaTypeFactoryImpl( - RelDataTypeSystem.DEFAULT); - public static RelDataType relDataType; - - public static BeamRecordSqlType beamRowType; - public static BeamRecord record; - - public static RelBuilder relBuilder; + public static Row row; @BeforeClass public static void prepare() { @@ -70,9 +65,15 @@ public static void prepare() { .add("price", SqlTypeName.DOUBLE) .add("order_time", SqlTypeName.BIGINT).build(); - beamRowType = CalciteUtils.toBeamRowType(relDataType); - record = new BeamRecord(beamRowType - , 1234567L, 0, 8.9, 1234567L); + row = + Row + .withRowType(CalciteUtils.toBeamRowType(relDataType)) + .addValues( + 1234567L, + 0, + 8.9, + 1234567L) + .build(); SchemaPlus schema = Frameworks.createRootSchema(true); final List traitDefs = new ArrayList<>(); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamNullExperssionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamNullExperssionTest.java index 1bcda2cf121f..839be751e269 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamNullExperssionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamNullExperssionTest.java @@ -34,22 +34,22 @@ public class BeamNullExperssionTest extends BeamSqlFnExecutorTestBase { public void testIsNull() { BeamSqlIsNullExpression exp1 = new BeamSqlIsNullExpression( new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0)); - Assert.assertEquals(false, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp1.evaluate(row, null).getValue()); BeamSqlIsNullExpression exp2 = new BeamSqlIsNullExpression( BeamSqlPrimitive.of(SqlTypeName.BIGINT, null)); - Assert.assertEquals(true, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp2.evaluate(row, null).getValue()); } @Test public void testIsNotNull() { BeamSqlIsNotNullExpression exp1 = new BeamSqlIsNotNullExpression( new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0)); - Assert.assertEquals(true, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp1.evaluate(row, null).getValue()); BeamSqlIsNotNullExpression exp2 = new BeamSqlIsNotNullExpression( BeamSqlPrimitive.of(SqlTypeName.BIGINT, null)); - Assert.assertEquals(false, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp2.evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlAndOrExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlAndOrExpressionTest.java index 51a170d92935..090dde907897 100644 --- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlAndOrExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlAndOrExpressionTest.java @@ -37,11 +37,11 @@ public void testAnd() { operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, true)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, true)); - Assert.assertTrue(new BeamSqlAndExpression(operands).evaluate(record, null).getValue()); + Assert.assertTrue(new BeamSqlAndExpression(operands).evaluate(row, null).getValue()); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, false)); - Assert.assertFalse(new BeamSqlAndExpression(operands).evaluate(record, null).getValue()); + Assert.assertFalse(new BeamSqlAndExpression(operands).evaluate(row, null).getValue()); } @Test @@ -50,11 +50,11 @@ public void testOr() { operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, false)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, false)); - Assert.assertFalse(new BeamSqlOrExpression(operands).evaluate(record, null).getValue()); + Assert.assertFalse(new BeamSqlOrExpression(operands).evaluate(row, null).getValue()); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, true)); - Assert.assertTrue(new BeamSqlOrExpression(operands).evaluate(record, null).getValue()); + Assert.assertTrue(new BeamSqlOrExpression(operands).evaluate(row, null).getValue()); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpressionTest.java index e02554ff3b1b..64d9161e83a6 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCaseExpressionTest.java @@ -72,14 +72,14 @@ public class BeamSqlCaseExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "world")); assertEquals("hello", new BeamSqlCaseExpression(operands) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, false)); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "world")); assertEquals("world", new BeamSqlCaseExpression(operands) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, false)); @@ -88,6 +88,6 @@ public class BeamSqlCaseExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello1")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "world")); assertEquals("hello1", new BeamSqlCaseExpression(operands) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpressionTest.java index f4e3cf9694f1..999e309acdde 100644 --- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCastExpressionTest.java @@ -52,14 +52,14 @@ public void testForOperands() { public void testForIntegerToBigintTypeCasting() { operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 5)); Assert.assertEquals(5L, - new BeamSqlCastExpression(operands, SqlTypeName.BIGINT).evaluate(record, null).getLong()); + new BeamSqlCastExpression(operands, SqlTypeName.BIGINT).evaluate(row, null).getLong()); } @Test public void testForDoubleToBigIntCasting() { operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 5.45)); Assert.assertEquals(5L, - new BeamSqlCastExpression(operands, SqlTypeName.BIGINT).evaluate(record, null).getLong()); + new BeamSqlCastExpression(operands, SqlTypeName.BIGINT).evaluate(row, null).getLong()); } @Test @@ -67,7 +67,7 @@ public void testForIntegerToDateCast() { // test for yyyyMMdd format operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 20170521)); Assert.assertEquals(Date.valueOf("2017-05-21"), - new BeamSqlCastExpression(operands, SqlTypeName.DATE).evaluate(record, null).getValue()); + new BeamSqlCastExpression(operands, SqlTypeName.DATE).evaluate(row, null).getValue()); } @Test @@ -75,7 +75,7 @@ public void testyyyyMMddDateFormat() { //test for yyyy-MM-dd format operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "2017-05-21")); Assert.assertEquals(Date.valueOf("2017-05-21"), - new BeamSqlCastExpression(operands, SqlTypeName.DATE).evaluate(record, null).getValue()); + new BeamSqlCastExpression(operands, SqlTypeName.DATE).evaluate(row, null).getValue()); } @Test @@ -83,14 +83,14 @@ public void testyyMMddDateFormat() { // test for yy.MM.dd format operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "17.05.21")); Assert.assertEquals(Date.valueOf("2017-05-21"), - new BeamSqlCastExpression(operands, SqlTypeName.DATE).evaluate(record, null).getValue()); + new BeamSqlCastExpression(operands, SqlTypeName.DATE).evaluate(row, null).getValue()); } @Test public void testForTimestampCastExpression() { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "17-05-21 23:59:59.989")); Assert.assertEquals(SqlTypeName.TIMESTAMP, - new BeamSqlCastExpression(operands, SqlTypeName.TIMESTAMP).evaluate(record, null) + new BeamSqlCastExpression(operands, SqlTypeName.TIMESTAMP).evaluate(row, null) .getOutputType()); } @@ -99,7 +99,7 @@ public void testDateTimeFormatWithMillis() { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "2017-05-21 23:59:59.989")); Assert.assertEquals(Timestamp.valueOf("2017-05-22 00:00:00.0"), new BeamSqlCastExpression(operands, SqlTypeName.TIMESTAMP) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); } @Test @@ -107,7 +107,7 @@ public void testDateTimeFormatWithTimezone() { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "2017-05-21 23:59:59.89079 PST")); Assert.assertEquals(Timestamp.valueOf("2017-05-22 00:00:00.0"), new BeamSqlCastExpression(operands, SqlTypeName.TIMESTAMP) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); } @Test @@ -115,7 +115,7 @@ public void testDateTimeFormat() { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "2017-05-21 23:59:59")); Assert.assertEquals(Timestamp.valueOf("2017-05-21 23:59:59"), new BeamSqlCastExpression(operands, SqlTypeName.TIMESTAMP) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); } 
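For illustration (not part of the patch): every test file above follows the same migration pattern. The base class now exposes a static Row built with Row.withRowType(CalciteUtils.toBeamRowType(relDataType)).addValues(...).build(), and each expression is evaluated with evaluate(row, null) instead of evaluate(record, null). A hypothetical subclass showing that shape (the class and test names are invented; the packages follow the file paths in this diff):

    package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator;

    import org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlFnExecutorTestBase;
    import org.apache.calcite.sql.type.SqlTypeName;
    import org.junit.Assert;
    import org.junit.Test;

    /** Hypothetical example of the record-to-row test pattern; not part of the patch. */
    public class RowMigrationSketchTest extends BeamSqlFnExecutorTestBase {
      @Test
      public void evaluatesAgainstSharedRow() {
        // Column 0 of the shared row holds 1234567L (see the base class hunk above),
        // so an input-ref expression over column 0 evaluates back to that value.
        BeamSqlInputRefExpression ref0 = new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0);
        Assert.assertEquals(row.getLong(0), ref0.evaluate(row, null).getValue());
      }
    }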
@Test(expected = RuntimeException.class) @@ -123,7 +123,7 @@ public void testForCastTypeNotSupported() { operands.add(BeamSqlPrimitive.of(SqlTypeName.TIME, Calendar.getInstance().getTime())); Assert.assertEquals(Timestamp.valueOf("2017-05-22 00:00:00.0"), new BeamSqlCastExpression(operands, SqlTypeName.TIMESTAMP) - .evaluate(record, null).getValue()); + .evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCompareExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCompareExpressionTest.java index 8aad6b38b54c..2af6e9d9c22d 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCompareExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlCompareExpressionTest.java @@ -40,12 +40,12 @@ public void testEqual() { BeamSqlEqualsExpression exp1 = new BeamSqlEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 100L))); - Assert.assertEquals(false, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp1.evaluate(row, null).getValue()); BeamSqlEqualsExpression exp2 = new BeamSqlEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1234567L))); - Assert.assertEquals(true, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp2.evaluate(row, null).getValue()); } @Test @@ -53,12 +53,12 @@ public void testLargerThan(){ BeamSqlGreaterThanExpression exp1 = new BeamSqlGreaterThanExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1234567L))); - Assert.assertEquals(false, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp1.evaluate(row, null).getValue()); BeamSqlGreaterThanExpression exp2 = new BeamSqlGreaterThanExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1234566L))); - Assert.assertEquals(true, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp2.evaluate(row, null).getValue()); } @Test @@ -66,12 +66,12 @@ public void testLargerThanEqual(){ BeamSqlGreaterThanOrEqualsExpression exp1 = new BeamSqlGreaterThanOrEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1234567L))); - Assert.assertEquals(true, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp1.evaluate(row, null).getValue()); BeamSqlGreaterThanOrEqualsExpression exp2 = new BeamSqlGreaterThanOrEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1234568L))); - Assert.assertEquals(false, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp2.evaluate(row, null).getValue()); } @Test @@ -79,12 +79,12 @@ public void testLessThan(){ BeamSqlLessThanExpression exp1 = new BeamSqlLessThanExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.INTEGER, 1), BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1))); - Assert.assertEquals(true, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp1.evaluate(row, null).getValue()); BeamSqlLessThanExpression exp2 = new 
BeamSqlLessThanExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.INTEGER, 1), BeamSqlPrimitive.of(SqlTypeName.INTEGER, -1))); - Assert.assertEquals(false, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp2.evaluate(row, null).getValue()); } @Test @@ -92,12 +92,12 @@ public void testLessThanEqual(){ BeamSqlLessThanOrEqualsExpression exp1 = new BeamSqlLessThanOrEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.DOUBLE, 2), BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 8.9))); - Assert.assertEquals(true, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp1.evaluate(row, null).getValue()); BeamSqlLessThanOrEqualsExpression exp2 = new BeamSqlLessThanOrEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.DOUBLE, 2), BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 8.0))); - Assert.assertEquals(false, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp2.evaluate(row, null).getValue()); } @Test @@ -105,11 +105,11 @@ public void testNotEqual(){ BeamSqlNotEqualsExpression exp1 = new BeamSqlNotEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 3), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1234567L))); - Assert.assertEquals(false, exp1.evaluate(record, null).getValue()); + Assert.assertEquals(false, exp1.evaluate(row, null).getValue()); BeamSqlNotEqualsExpression exp2 = new BeamSqlNotEqualsExpression( Arrays.asList(new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 3), BeamSqlPrimitive.of(SqlTypeName.BIGINT, 0L))); - Assert.assertEquals(true, exp2.evaluate(record, null).getValue()); + Assert.assertEquals(true, exp2.evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpressionTest.java index e543d4ff9a23..94c88800e0ab 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpressionTest.java @@ -30,28 +30,28 @@ public class BeamSqlInputRefExpressionTest extends BeamSqlFnExecutorTestBase { @Test public void testRefInRange() { BeamSqlInputRefExpression ref0 = new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0); - Assert.assertEquals(record.getLong(0), ref0.evaluate(record, null).getValue()); + Assert.assertEquals(row.getLong(0), ref0.evaluate(row, null).getValue()); BeamSqlInputRefExpression ref1 = new BeamSqlInputRefExpression(SqlTypeName.INTEGER, 1); - Assert.assertEquals(record.getInteger(1), ref1.evaluate(record, null).getValue()); + Assert.assertEquals(row.getInteger(1), ref1.evaluate(row, null).getValue()); BeamSqlInputRefExpression ref2 = new BeamSqlInputRefExpression(SqlTypeName.DOUBLE, 2); - Assert.assertEquals(record.getDouble(2), ref2.evaluate(record, null).getValue()); + Assert.assertEquals(row.getDouble(2), ref2.evaluate(row, null).getValue()); BeamSqlInputRefExpression ref3 = new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 3); - Assert.assertEquals(record.getLong(3), ref3.evaluate(record, null).getValue()); + Assert.assertEquals(row.getLong(3), ref3.evaluate(row, null).getValue()); } @Test(expected = IndexOutOfBoundsException.class) public void testRefOutOfRange(){ BeamSqlInputRefExpression 
ref = new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 4); - ref.evaluate(record, null).getValue(); + ref.evaluate(row, null).getValue(); } @Test(expected = IllegalArgumentException.class) public void testTypeUnMatch(){ BeamSqlInputRefExpression ref = new BeamSqlInputRefExpression(SqlTypeName.INTEGER, 0); - ref.evaluate(record, null).getValue(); + ref.evaluate(row, null).getValue(); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitiveTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitiveTest.java index 81f9ce0888d3..e2acb21b1114 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitiveTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlPrimitiveTest.java @@ -31,28 +31,28 @@ public class BeamSqlPrimitiveTest extends BeamSqlFnExecutorTestBase { @Test public void testPrimitiveInt(){ BeamSqlPrimitive expInt = BeamSqlPrimitive.of(SqlTypeName.INTEGER, 100); - Assert.assertEquals(expInt.getValue(), expInt.evaluate(record, null).getValue()); + Assert.assertEquals(expInt.getValue(), expInt.evaluate(row, null).getValue()); } @Test(expected = IllegalArgumentException.class) public void testPrimitiveTypeUnMatch1(){ BeamSqlPrimitive expInt = BeamSqlPrimitive.of(SqlTypeName.INTEGER, 100L); - Assert.assertEquals(expInt.getValue(), expInt.evaluate(record, null).getValue()); + Assert.assertEquals(expInt.getValue(), expInt.evaluate(row, null).getValue()); } @Test(expected = IllegalArgumentException.class) public void testPrimitiveTypeUnMatch2(){ BeamSqlPrimitive expInt = BeamSqlPrimitive.of(SqlTypeName.DECIMAL, 100L); - Assert.assertEquals(expInt.getValue(), expInt.evaluate(record, null).getValue()); + Assert.assertEquals(expInt.getValue(), expInt.evaluate(row, null).getValue()); } @Test(expected = IllegalArgumentException.class) public void testPrimitiveTypeUnMatch3(){ BeamSqlPrimitive expInt = BeamSqlPrimitive.of(SqlTypeName.FLOAT, 100L); - Assert.assertEquals(expInt.getValue(), expInt.evaluate(record, null).getValue()); + Assert.assertEquals(expInt.getValue(), expInt.evaluate(row, null).getValue()); } @Test(expected = IllegalArgumentException.class) public void testPrimitiveTypeUnMatch4(){ BeamSqlPrimitive expInt = BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 100L); - Assert.assertEquals(expInt.getValue(), expInt.evaluate(record, null).getValue()); + Assert.assertEquals(expInt.getValue(), expInt.evaluate(row, null).getValue()); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlReinterpretExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlReinterpretExpressionTest.java index 3d7b8ada2e74..baeeb8685d88 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlReinterpretExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlReinterpretExpressionTest.java @@ -25,11 +25,10 @@ import java.util.Arrays; import java.util.Date; import java.util.GregorianCalendar; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlFnExecutorTestBase; import 
org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.reinterpret.BeamSqlReinterpretExpression; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Test; @@ -41,7 +40,7 @@ public class BeamSqlReinterpretExpressionTest extends BeamSqlFnExecutorTestBase private static final Date DATE = new Date(DATE_LONG); private static final GregorianCalendar CALENDAR = new GregorianCalendar(2017, 8, 9); - private static final BeamRecord NULL_ROW = null; + private static final Row NULL_ROW = null; private static final BoundedWindow NULL_WINDOW = null; private static final BeamSqlExpression DATE_PRIMITIVE = BeamSqlPrimitive.of( diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpressionTest.java index 19098a6db2d7..aa9045c4360a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlUdfExpressionTest.java @@ -37,7 +37,7 @@ public void testUdf() throws NoSuchMethodException, SecurityException { BeamSqlUdfExpression exp = new BeamSqlUdfExpression( UdfFn.class.getMethod("negative", Integer.class), operands, SqlTypeName.INTEGER); - Assert.assertEquals(-10, exp.evaluate(record, null).getValue()); + Assert.assertEquals(-10, exp.evaluate(row, null).getValue()); } /** diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpressionTest.java index 88eaa410739b..d2e950418e1a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/arithmetic/BeamSqlArithmeticExpressionTest.java @@ -84,32 +84,32 @@ public class BeamSqlArithmeticExpressionTest extends BeamSqlFnExecutorTestBase { // integer + integer => integer operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); - assertEquals(2, new BeamSqlPlusExpression(operands).evaluate(record, null).getValue()); + assertEquals(2, new BeamSqlPlusExpression(operands).evaluate(row, null).getValue()); // integer + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2L, new BeamSqlPlusExpression(operands).evaluate(record, null).getValue()); + assertEquals(2L, new BeamSqlPlusExpression(operands).evaluate(row, null).getValue()); // long + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2L, new BeamSqlPlusExpression(operands).evaluate(record, null).getValue()); + assertEquals(2L, new BeamSqlPlusExpression(operands).evaluate(row, null).getValue()); // float + long => float operands.clear(); 
operands.add(BeamSqlPrimitive.of(SqlTypeName.FLOAT, 1.1F)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); assertEquals(1.1F + 1, - new BeamSqlPlusExpression(operands).evaluate(record, null).getValue()); + new BeamSqlPlusExpression(operands).evaluate(row, null).getValue()); // double + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 1.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2.1, new BeamSqlPlusExpression(operands).evaluate(record, null).getValue()); + assertEquals(2.1, new BeamSqlPlusExpression(operands).evaluate(row, null).getValue()); } @Test public void testMinus() { @@ -118,32 +118,34 @@ public class BeamSqlArithmeticExpressionTest extends BeamSqlFnExecutorTestBase { // integer + integer => long operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); - assertEquals(1, new BeamSqlMinusExpression(operands).evaluate(record, null).getValue()); + assertEquals(1, new BeamSqlMinusExpression(operands).evaluate(row, null).getValue()); // integer + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(1L, new BeamSqlMinusExpression(operands).evaluate(record, null).getValue()); + assertEquals(1L, new BeamSqlMinusExpression(operands).evaluate(row, null).getValue()); // long + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(1L, new BeamSqlMinusExpression(operands).evaluate(record, null).getValue()); + assertEquals(1L, new BeamSqlMinusExpression(operands).evaluate(row, null).getValue()); // float + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.FLOAT, 2.1F)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2.1F - 1L, - new BeamSqlMinusExpression(operands).evaluate(record, null).getValue().floatValue(), 0.1); + assertEquals( + 2.1F - 1L, + new BeamSqlMinusExpression(operands).evaluate(row, null).getValue().floatValue(), + 0.1); // double + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(1.1, new BeamSqlMinusExpression(operands).evaluate(record, null).getValue()); + assertEquals(1.1, new BeamSqlMinusExpression(operands).evaluate(row, null).getValue()); } @Test public void testMultiply() { @@ -152,32 +154,32 @@ public class BeamSqlArithmeticExpressionTest extends BeamSqlFnExecutorTestBase { // integer + integer => integer operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); - assertEquals(2, new BeamSqlMultiplyExpression(operands).evaluate(record, null).getValue()); + assertEquals(2, new BeamSqlMultiplyExpression(operands).evaluate(row, null).getValue()); // integer + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2L, new BeamSqlMultiplyExpression(operands).evaluate(record, null).getValue()); + assertEquals(2L, new BeamSqlMultiplyExpression(operands).evaluate(row, null).getValue()); // long + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2L, new 
BeamSqlMultiplyExpression(operands).evaluate(record, null).getValue()); + assertEquals(2L, new BeamSqlMultiplyExpression(operands).evaluate(row, null).getValue()); // float + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.FLOAT, 2.1F)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); assertEquals(2.1F * 1L, - new BeamSqlMultiplyExpression(operands).evaluate(record, null).getValue()); + new BeamSqlMultiplyExpression(operands).evaluate(row, null).getValue()); // double + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2.1, new BeamSqlMultiplyExpression(operands).evaluate(record, null).getValue()); + assertEquals(2.1, new BeamSqlMultiplyExpression(operands).evaluate(row, null).getValue()); } @Test public void testDivide() { @@ -186,32 +188,32 @@ public class BeamSqlArithmeticExpressionTest extends BeamSqlFnExecutorTestBase { // integer + integer => integer operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); - assertEquals(2, new BeamSqlDivideExpression(operands).evaluate(record, null).getValue()); + assertEquals(2, new BeamSqlDivideExpression(operands).evaluate(row, null).getValue()); // integer + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2L, new BeamSqlDivideExpression(operands).evaluate(record, null).getValue()); + assertEquals(2L, new BeamSqlDivideExpression(operands).evaluate(row, null).getValue()); // long + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2L, new BeamSqlDivideExpression(operands).evaluate(record, null).getValue()); + assertEquals(2L, new BeamSqlDivideExpression(operands).evaluate(row, null).getValue()); // float + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.FLOAT, 2.1F)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); assertEquals(2.1F / 1, - new BeamSqlDivideExpression(operands).evaluate(record, null).getValue()); + new BeamSqlDivideExpression(operands).evaluate(row, null).getValue()); // double + long => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - assertEquals(2.1, new BeamSqlDivideExpression(operands).evaluate(record, null).getValue()); + assertEquals(2.1, new BeamSqlDivideExpression(operands).evaluate(row, null).getValue()); } @Test public void testMod() { @@ -220,18 +222,18 @@ public class BeamSqlArithmeticExpressionTest extends BeamSqlFnExecutorTestBase { // integer + integer => long operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 3)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); - assertEquals(1, new BeamSqlModExpression(operands).evaluate(record, null).getValue()); + assertEquals(1, new BeamSqlModExpression(operands).evaluate(row, null).getValue()); // integer + long => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 3)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); - assertEquals(1L, new BeamSqlModExpression(operands).evaluate(record, null).getValue()); + assertEquals(1L, new BeamSqlModExpression(operands).evaluate(row, null).getValue()); // long + long => long operands.clear(); 
operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 3L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); - assertEquals(1L, new BeamSqlModExpression(operands).evaluate(record, null).getValue()); + assertEquals(1L, new BeamSqlModExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpressionTest.java index bfca72032768..84c464a9955b 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentDateExpressionTest.java @@ -32,7 +32,7 @@ public void test() { Assert.assertEquals( SqlTypeName.DATE, new BeamSqlCurrentDateExpression() - .evaluate(BeamSqlFnExecutorTestBase.record, null).getOutputType() + .evaluate(BeamSqlFnExecutorTestBase.row, null).getOutputType() ); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpressionTest.java index af3cacd15399..1a33a29dc921 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimeExpressionTest.java @@ -34,6 +34,6 @@ public class BeamSqlCurrentTimeExpressionTest extends BeamSqlDateExpressionTestB public void test() { List operands = new ArrayList<>(); assertEquals(SqlTypeName.TIME, - new BeamSqlCurrentTimeExpression(operands).evaluate(record, null).getOutputType()); + new BeamSqlCurrentTimeExpression(operands).evaluate(row, null).getOutputType()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpressionTest.java index c171e403d776..26ee8b7f440b 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlCurrentTimestampExpressionTest.java @@ -34,6 +34,6 @@ public class BeamSqlCurrentTimestampExpressionTest extends BeamSqlDateExpression public void test() { List operands = new ArrayList<>(); assertEquals(SqlTypeName.TIMESTAMP, - new BeamSqlCurrentTimestampExpression(operands).evaluate(record, null).getOutputType()); + new BeamSqlCurrentTimestampExpression(operands).evaluate(row, null).getOutputType()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpressionTest.java index 141bbf57a251..b2ead8ea2a1a 
100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateCeilExpressionTest.java @@ -40,11 +40,11 @@ public class BeamSqlDateCeilExpressionTest extends BeamSqlDateExpressionTestBase operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.YEAR)); Assert.assertEquals(str2DateTime("2018-01-01 00:00:00"), new BeamSqlDateCeilExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getDate()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getDate()); operands.set(1, BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.MONTH)); Assert.assertEquals(str2DateTime("2017-06-01 00:00:00"), new BeamSqlDateCeilExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getDate()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getDate()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateExpressionTestBase.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateExpressionTestBase.java index cb0b6ec6cc6f..5e32e01e109a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateExpressionTestBase.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateExpressionTestBase.java @@ -22,7 +22,6 @@ import java.text.SimpleDateFormat; import java.util.Date; import java.util.TimeZone; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlFnExecutorTestBase; /** diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpressionTest.java index ede12ced5d7a..5a23c0de62ab 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDateFloorExpressionTest.java @@ -39,11 +39,11 @@ public class BeamSqlDateFloorExpressionTest extends BeamSqlDateExpressionTestBas // YEAR operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.YEAR)); assertEquals(str2DateTime("2017-01-01 00:00:00"), - new BeamSqlDateFloorExpression(operands).evaluate(record, null).getDate()); + new BeamSqlDateFloorExpression(operands).evaluate(row, null).getDate()); // MONTH operands.set(1, BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.MONTH)); assertEquals(str2DateTime("2017-05-01 00:00:00"), - new BeamSqlDateFloorExpression(operands).evaluate(record, null).getDate()); + new BeamSqlDateFloorExpression(operands).evaluate(row, null).getDate()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpressionTest.java index ef837ca75935..ed6f48ed345b 100644 --- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimeMinusExpressionTest.java @@ -25,11 +25,10 @@ import java.math.BigDecimal; import java.util.Arrays; import java.util.Date; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; @@ -40,7 +39,7 @@ */ public class BeamSqlDatetimeMinusExpressionTest { - private static final BeamRecord NULL_ROW = null; + private static final Row NULL_ROW = null; private static final BoundedWindow NULL_WINDOW = null; private static final Date DATE = new Date(329281L); @@ -58,9 +57,6 @@ public class BeamSqlDatetimeMinusExpressionTest { private static final BeamSqlPrimitive STRING = BeamSqlPrimitive.of( SqlTypeName.VARCHAR, "hello"); - private static final BeamSqlPrimitive INTERVAL_3_MONTHS = BeamSqlPrimitive.of( - SqlTypeName.INTERVAL_MONTH, TimeUnit.MONTH.multiplier.multiply(new BigDecimal(3))); - @Test public void testOutputType() { BeamSqlDatetimeMinusExpression minusExpression1 = minusExpression(SqlTypeName.TIMESTAMP, TIMESTAMP, INTERVAL_2_SEC); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpressionTest.java index 57e709f601cd..0cb5ce0dfee7 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlDatetimePlusExpressionTest.java @@ -25,11 +25,10 @@ import java.math.BigDecimal; import java.util.Arrays; import java.util.Date; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; @@ -43,7 +42,7 @@ public class BeamSqlDatetimePlusExpressionTest extends BeamSqlDateExpressionTestBase { @Rule public ExpectedException thrown = ExpectedException.none(); - private static final BeamRecord NULL_INPUT_ROW = null; + private static final Row NULL_INPUT_ROW = null; private static final BoundedWindow NULL_WINDOW = null; private static final Date DATE = str2DateTime("1984-04-19 01:02:03"); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpressionTest.java index b03827a82d6f..e10562d83f34 100644 --- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlExtractExpressionTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import java.util.ArrayList; +import java.util.Date; import java.util.List; import org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlFnExecutorTestBase; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; @@ -35,69 +36,69 @@ public class BeamSqlExtractExpressionTest extends BeamSqlDateExpressionTestBase { @Test public void evaluate() throws Exception { List operands = new ArrayList<>(); - long time = str2LongTime("2017-05-22 16:17:18"); + Date time = str2DateTime("2017-05-22 16:17:18"); // YEAR operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.YEAR)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); assertEquals(2017L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); // MONTH operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.MONTH)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); assertEquals(5L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); // DAY operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.DAY)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); assertEquals(22L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); // DAY_OF_WEEK operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.DOW)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); assertEquals(2L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); // DAY_OF_YEAR operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.DOY)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); assertEquals(142L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); // WEEK operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.WEEK)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); assertEquals(21L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); // QUARTER operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.QUARTER)); - operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, + operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time)); 
assertEquals(2L, new BeamSqlExtractExpression(operands) - .evaluate(BeamSqlFnExecutorTestBase.record, null).getValue()); + .evaluate(BeamSqlFnExecutorTestBase.row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpressionTest.java index 0c91f4018990..6d79a019d63d 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlIntervalMultiplyExpressionTest.java @@ -26,11 +26,10 @@ import java.math.BigDecimal; import java.util.Arrays; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Test; @@ -38,7 +37,7 @@ * Test for BeamSqlIntervalMultiplyExpression. */ public class BeamSqlIntervalMultiplyExpressionTest { - private static final BeamRecord NULL_INPUT_ROW = null; + private static final Row NULL_INPUT_ROW = null; private static final BoundedWindow NULL_WINDOW = null; private static final BigDecimal DECIMAL_THREE = new BigDecimal(3); private static final BigDecimal DECIMAL_FOUR = new BigDecimal(4); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpressionTest.java index 5232487f3324..10af52ce0dfd 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusIntervalExpressionTest.java @@ -30,11 +30,10 @@ import java.util.Date; import java.util.HashSet; import java.util.Set; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; @@ -46,7 +45,7 @@ * Unit tests for {@link BeamSqlTimestampMinusIntervalExpression}. 
*/ public class BeamSqlTimestampMinusIntervalExpressionTest { - private static final BeamRecord NULL_ROW = null; + private static final Row NULL_ROW = null; private static final BoundedWindow NULL_WINDOW = null; private static final Date DATE = new Date(329281L); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpressionTest.java index 54bf52d32273..e4141aa9084a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/BeamSqlTimestampMinusTimestampExpressionTest.java @@ -24,11 +24,10 @@ import java.util.Arrays; import java.util.Date; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression; import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.sql.type.SqlTypeName; import org.joda.time.DateTime; @@ -41,7 +40,7 @@ */ public class BeamSqlTimestampMinusTimestampExpressionTest { - private static final BeamRecord NULL_ROW = null; + private static final Row NULL_ROW = null; private static final BoundedWindow NULL_WINDOW = null; private static final Date DATE = new Date(2017, 3, 4, 3, 2, 1); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtilsTest.java index 91552aeb7094..74db2b5a3460 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtilsTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/date/TimeUnitUtilsTest.java @@ -1,13 +1,3 @@ -package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.date; - -import static org.junit.Assert.assertEquals; - -import org.apache.calcite.avatica.util.TimeUnit; -import org.apache.calcite.sql.type.SqlTypeName; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -25,6 +15,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.date; + +import static org.junit.Assert.assertEquals; + +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.sql.type.SqlTypeName; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; /** * Unit tests for {@link TimeUnitUtils}. 
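For illustration (not part of the patch): in BeamSqlExtractExpressionTest above, the time operand of EXTRACT changes from a BIGINT epoch value produced by str2LongTime to a TIMESTAMP primitive wrapping a java.util.Date produced by str2DateTime. A small sketch of the new operand construction (the class and method names are invented; only the BeamSqlPrimitive calls mirror the hunks above):

    import java.util.ArrayList;
    import java.util.Date;
    import java.util.List;
    import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression;
    import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive;
    import org.apache.calcite.avatica.util.TimeUnitRange;
    import org.apache.calcite.sql.type.SqlTypeName;

    /** Hypothetical helper showing the new EXTRACT operands; not part of the patch. */
    class ExtractOperandsSketch {
      static List<BeamSqlExpression> yearOperands(Date time) {
        List<BeamSqlExpression> operands = new ArrayList<>();
        operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, TimeUnitRange.YEAR));
        // Before this patch the second operand was
        // BeamSqlPrimitive.of(SqlTypeName.BIGINT, epochMillis).
        operands.add(BeamSqlPrimitive.of(SqlTypeName.TIMESTAMP, time));
        return operands;
      }
    }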
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpressionTest.java index c98ce233a382..56d695eb59a6 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/logical/BeamSqlNotExpressionTest.java @@ -34,14 +34,14 @@ public class BeamSqlNotExpressionTest extends BeamSqlFnExecutorTestBase { @Test public void evaluate() throws Exception { List operands = new ArrayList<>(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, false)); - Assert.assertTrue(new BeamSqlNotExpression(operands).evaluate(record, null).getBoolean()); + Assert.assertTrue(new BeamSqlNotExpression(operands).evaluate(row, null).getBoolean()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, true)); - Assert.assertFalse(new BeamSqlNotExpression(operands).evaluate(record, null).getBoolean()); + Assert.assertFalse(new BeamSqlNotExpression(operands).evaluate(row, null).getBoolean()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BOOLEAN, null)); - Assert.assertNull(new BeamSqlNotExpression(operands).evaluate(record, null).getValue()); + Assert.assertNull(new BeamSqlNotExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpressionTest.java index 666525356b4a..e2a6b6e5fcc1 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathBinaryExpressionTest.java @@ -68,75 +68,75 @@ public class BeamSqlMathBinaryExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.0)); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 4.0)); Assert.assertEquals(2.0, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // round(integer,integer) => integer operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); - Assert.assertEquals(2, new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + Assert.assertEquals(2, new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // round(long,long) => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 5L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 3L)); - Assert.assertEquals(5L, new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + Assert.assertEquals(5L, new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // round(short) => short operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, new Short("4"))); Assert.assertEquals(SqlFunctions.toShort(4), - new BeamSqlRoundExpression(operands).evaluate(record, 
null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // round(long,long) => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); - Assert.assertEquals(2L, new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + Assert.assertEquals(2L, new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // round(double, long) => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 1.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); Assert.assertEquals(1.1, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.368768)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); Assert.assertEquals(2.37, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 3.78683686458)); Assert.assertEquals(4.0, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 378.683686458)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, -2)); Assert.assertEquals(400.0, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 378.683686458)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, -1)); Assert.assertEquals(380.0, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // round(integer, double) => integer operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.2)); - Assert.assertEquals(2, new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + Assert.assertEquals(2, new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); // operand with a BeamSqlInputRefExpression - // to select a column value from row of a record + // to select a column value from row of a row operands.clear(); BeamSqlInputRefExpression ref0 = new BeamSqlInputRefExpression(SqlTypeName.BIGINT, 0); operands.add(ref0); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); Assert.assertEquals(1234567L, - new BeamSqlRoundExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRoundExpression(operands).evaluate(row, null).getValue()); } @Test public void testPowerFunction() { @@ -147,47 +147,47 @@ public class BeamSqlMathBinaryExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.0)); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 4.0)); Assert.assertEquals(16.0, - new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + new BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); // power(integer,integer) => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); Assert.assertEquals(4.0, - new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + new 
BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); // power(integer,long) => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 3L)); Assert.assertEquals(8.0 - , new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + , new BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); // power(long,long) => long operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 2L)); Assert.assertEquals(4.0, - new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + new BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); // power(double, int) => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 1.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); Assert.assertEquals(1.1, - new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + new BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); // power(double, long) => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 1.1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); Assert.assertEquals(1.1, - new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + new BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); // power(integer, double) => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.2)); Assert.assertEquals(Math.pow(2, 2.2), - new BeamSqlPowerExpression(operands).evaluate(record, null).getValue()); + new BeamSqlPowerExpression(operands).evaluate(row, null).getValue()); } @Test public void testForTruncate() { @@ -195,13 +195,13 @@ public class BeamSqlMathBinaryExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.0)); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 4.0)); Assert.assertEquals(2.0, - new BeamSqlTruncateExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTruncateExpression(operands).evaluate(row, null).getValue()); // truncate(double, integer) => double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.80685)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 4)); Assert.assertEquals(2.8068, - new BeamSqlTruncateExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTruncateExpression(operands).evaluate(row, null).getValue()); } @Test public void testForAtan2() { @@ -209,7 +209,7 @@ public class BeamSqlMathBinaryExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 0.875)); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 0.56)); Assert.assertEquals(Math.atan2(0.875, 0.56), - new BeamSqlAtan2Expression(operands).evaluate(record, null).getValue()); + new BeamSqlAtan2Expression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpressionTest.java index d80a67071e9a..e12d2bdc831d 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpressionTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/math/BeamSqlMathUnaryExpressionTest.java @@ -60,7 +60,7 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, -28965734597L)); Assert.assertEquals(28965734597L, - new BeamSqlAbsExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAbsExpression(operands).evaluate(row, null).getValue()); } @Test public void testForLnExpression() { @@ -69,19 +69,19 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for LN function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.log(2), - new BeamSqlLnExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLnExpression(operands).evaluate(row, null).getValue()); // test for LN function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert .assertEquals(Math.log(2.4), - new BeamSqlLnExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLnExpression(operands).evaluate(row, null).getValue()); // test for LN function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.log(2.56), - new BeamSqlLnExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLnExpression(operands).evaluate(row, null).getValue()); } @Test public void testForLog10Expression() { @@ -90,17 +90,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for log10 function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.log10(2), - new BeamSqlLogExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLogExpression(operands).evaluate(row, null).getValue()); // test for log10 function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert.assertEquals(Math.log10(2.4), - new BeamSqlLogExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLogExpression(operands).evaluate(row, null).getValue()); // test for log10 function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.log10(2.56), - new BeamSqlLogExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLogExpression(operands).evaluate(row, null).getValue()); } @Test public void testForExpExpression() { @@ -109,17 +109,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.exp(2), - new BeamSqlExpExpression(operands).evaluate(record, null).getValue()); + new BeamSqlExpExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert.assertEquals(Math.exp(2.4), - new BeamSqlExpExpression(operands).evaluate(record, null).getValue()); + new BeamSqlExpExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); 
operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.exp(2.56), - new BeamSqlExpExpression(operands).evaluate(record, null).getValue()); + new BeamSqlExpExpression(operands).evaluate(row, null).getValue()); } @Test public void testForAcosExpression() { @@ -128,17 +128,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Double.NaN, - new BeamSqlAcosExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAcosExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 0.45)); Assert.assertEquals(Math.acos(0.45), - new BeamSqlAcosExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAcosExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(-0.367))); Assert.assertEquals(Math.acos(-0.367), - new BeamSqlAcosExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAcosExpression(operands).evaluate(row, null).getValue()); } @Test public void testForAsinExpression() { @@ -147,12 +147,12 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type double operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 0.45)); Assert.assertEquals(Math.asin(0.45), - new BeamSqlAsinExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAsinExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(-0.367))); Assert.assertEquals(Math.asin(-0.367), - new BeamSqlAsinExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAsinExpression(operands).evaluate(row, null).getValue()); } @Test public void testForAtanExpression() { @@ -161,12 +161,12 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type double operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 0.45)); Assert.assertEquals(Math.atan(0.45), - new BeamSqlAtanExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAtanExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(-0.367))); Assert.assertEquals(Math.atan(-0.367), - new BeamSqlAtanExpression(operands).evaluate(record, null).getValue()); + new BeamSqlAtanExpression(operands).evaluate(row, null).getValue()); } @Test public void testForCosExpression() { @@ -175,12 +175,12 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type double operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 0.45)); Assert.assertEquals(Math.cos(0.45), - new BeamSqlCosExpression(operands).evaluate(record, null).getValue()); + new BeamSqlCosExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(-0.367))); Assert.assertEquals(Math.cos(-0.367), - new 
BeamSqlCosExpression(operands).evaluate(record, null).getValue()); + new BeamSqlCosExpression(operands).evaluate(row, null).getValue()); } @Test public void testForCotExpression() { @@ -189,12 +189,12 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type double operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, .45)); Assert.assertEquals(1.0d / Math.tan(0.45), - new BeamSqlCotExpression(operands).evaluate(record, null).getValue()); + new BeamSqlCotExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(-.367))); Assert.assertEquals(1.0d / Math.tan(-0.367), - new BeamSqlCotExpression(operands).evaluate(record, null).getValue()); + new BeamSqlCotExpression(operands).evaluate(row, null).getValue()); } @Test public void testForDegreesExpression() { @@ -203,17 +203,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.toDegrees(2), - new BeamSqlDegreesExpression(operands).evaluate(record, null).getValue()); + new BeamSqlDegreesExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert.assertEquals(Math.toDegrees(2.4), - new BeamSqlDegreesExpression(operands).evaluate(record, null).getValue()); + new BeamSqlDegreesExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.toDegrees(2.56), - new BeamSqlDegreesExpression(operands).evaluate(record, null).getValue()); + new BeamSqlDegreesExpression(operands).evaluate(row, null).getValue()); } @Test public void testForRadiansExpression() { @@ -222,17 +222,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.toRadians(2), - new BeamSqlRadiansExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRadiansExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert.assertEquals(Math.toRadians(2.4), - new BeamSqlRadiansExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRadiansExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.toRadians(2.56), - new BeamSqlRadiansExpression(operands).evaluate(record, null).getValue()); + new BeamSqlRadiansExpression(operands).evaluate(row, null).getValue()); } @Test public void testForSinExpression() { @@ -241,17 +241,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.sin(2), - new BeamSqlSinExpression(operands).evaluate(record, 
null).getValue()); + new BeamSqlSinExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert.assertEquals(Math.sin(2.4), - new BeamSqlSinExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSinExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.sin(2.56), - new BeamSqlSinExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSinExpression(operands).evaluate(row, null).getValue()); } @Test public void testForTanExpression() { @@ -260,17 +260,17 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals(Math.tan(2), - new BeamSqlTanExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTanExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); Assert.assertEquals(Math.tan(2.4), - new BeamSqlTanExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTanExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(Math.tan(2.56), - new BeamSqlTanExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTanExpression(operands).evaluate(row, null).getValue()); } @Test public void testForSignExpression() { @@ -279,34 +279,34 @@ public class BeamSqlMathUnaryExpressionTest extends BeamSqlFnExecutorTestBase { // test for exp function with operand type smallint operands.add(BeamSqlPrimitive.of(SqlTypeName.SMALLINT, Short.valueOf("2"))); Assert.assertEquals((short) 1 - , new BeamSqlSignExpression(operands).evaluate(record, null).getValue()); + , new BeamSqlSignExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type double operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.4)); - Assert.assertEquals(1.0, new BeamSqlSignExpression(operands).evaluate(record, null).getValue()); + Assert.assertEquals(1.0, new BeamSqlSignExpression(operands).evaluate(row, null).getValue()); // test for exp function with operand type decimal operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DECIMAL, BigDecimal.valueOf(2.56))); Assert.assertEquals(BigDecimal.ONE, - new BeamSqlSignExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSignExpression(operands).evaluate(row, null).getValue()); } @Test public void testForPi() { - Assert.assertEquals(Math.PI, new BeamSqlPiExpression().evaluate(record, null).getValue()); + Assert.assertEquals(Math.PI, new BeamSqlPiExpression().evaluate(row, null).getValue()); } @Test public void testForCeil() { List operands = new ArrayList<>(); operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.68687979)); Assert.assertEquals(Math.ceil(2.68687979), - new BeamSqlCeilExpression(operands).evaluate(record, null).getValue()); + new BeamSqlCeilExpression(operands).evaluate(row, null).getValue()); } @Test public void testForFloor() { List operands = new ArrayList<>(); 
operands.add(BeamSqlPrimitive.of(SqlTypeName.DOUBLE, 2.68687979)); Assert.assertEquals(Math.floor(2.68687979), - new BeamSqlFloorExpression(operands).evaluate(record, null).getValue()); + new BeamSqlFloorExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/DatetimeReinterpretConversionsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/DatetimeReinterpretConversionsTest.java index 894d09445dfd..1ccdc65f6513 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/DatetimeReinterpretConversionsTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/DatetimeReinterpretConversionsTest.java @@ -22,7 +22,6 @@ import java.util.Date; import java.util.GregorianCalendar; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Test; diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversionTest.java index 31cdab89c160..5f78f2a95d2f 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpretConversionTest.java @@ -27,9 +27,7 @@ import com.google.common.base.Function; import com.google.common.collect.ImmutableSet; - import java.util.Set; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Rule; diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpreterTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpreterTest.java index 6406831c0799..939eaab43b93 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpreterTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/reinterpret/ReinterpreterTest.java @@ -13,7 +13,6 @@ import java.util.Date; import java.util.HashSet; import java.util.Set; - import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlPrimitive; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Rule; diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpressionTest.java index d6c356585d34..876fbfee3f47 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpressionTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlCharLengthExpressionTest.java @@ -38,7 +38,7 @@ public class BeamSqlCharLengthExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); assertEquals(5, - new BeamSqlCharLengthExpression(operands).evaluate(record, null).getValue()); + new BeamSqlCharLengthExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpressionTest.java index c350fe2ddfe1..085ecd7419d7 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlConcatExpressionTest.java @@ -60,7 +60,7 @@ public class BeamSqlConcatExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, " world")); Assert.assertEquals("hello world", - new BeamSqlConcatExpression(operands).evaluate(record, null).getValue()); + new BeamSqlConcatExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpressionTest.java index 7ea83d16eb50..c17680bf286f 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlInitCapExpressionTest.java @@ -38,17 +38,17 @@ public class BeamSqlInitCapExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello world")); assertEquals("Hello World", - new BeamSqlInitCapExpression(operands).evaluate(record, null).getValue()); + new BeamSqlInitCapExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hEllO wOrld")); assertEquals("Hello World", - new BeamSqlInitCapExpression(operands).evaluate(record, null).getValue()); + new BeamSqlInitCapExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello world")); assertEquals("Hello World", - new BeamSqlInitCapExpression(operands).evaluate(record, null).getValue()); + new BeamSqlInitCapExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpressionTest.java index 393680ce5be5..9807b9772327 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpressionTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlLowerExpressionTest.java @@ -38,7 +38,7 @@ public class BeamSqlLowerExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "HELLO")); assertEquals("hello", - new BeamSqlLowerExpression(operands).evaluate(record, null).getValue()); + new BeamSqlLowerExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpressionTest.java index 2b4c0ea8beb3..3e4443d4e2fd 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlOverlayExpressionTest.java @@ -57,7 +57,7 @@ public class BeamSqlOverlayExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "resou")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 3)); Assert.assertEquals("w3resou3rce", - new BeamSqlOverlayExpression(operands).evaluate(record, null).getValue()); + new BeamSqlOverlayExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "w3333333rce")); @@ -65,7 +65,7 @@ public class BeamSqlOverlayExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 3)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 4)); Assert.assertEquals("w3resou33rce", - new BeamSqlOverlayExpression(operands).evaluate(record, null).getValue()); + new BeamSqlOverlayExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "w3333333rce")); @@ -73,7 +73,7 @@ public class BeamSqlOverlayExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 3)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 5)); Assert.assertEquals("w3resou3rce", - new BeamSqlOverlayExpression(operands).evaluate(record, null).getValue()); + new BeamSqlOverlayExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "w3333333rce")); @@ -81,7 +81,7 @@ public class BeamSqlOverlayExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 3)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 7)); Assert.assertEquals("w3resouce", - new BeamSqlOverlayExpression(operands).evaluate(record, null).getValue()); + new BeamSqlOverlayExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpressionTest.java index 3b477ccdd795..efe650db34fd 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpressionTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlPositionExpressionTest.java @@ -66,19 +66,19 @@ public class BeamSqlPositionExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "worldhello")); - assertEquals(5, new BeamSqlPositionExpression(operands).evaluate(record, null).getValue()); + assertEquals(5, new BeamSqlPositionExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "worldhello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); - assertEquals(5, new BeamSqlPositionExpression(operands).evaluate(record, null).getValue()); + assertEquals(5, new BeamSqlPositionExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "world")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); - assertEquals(-1, new BeamSqlPositionExpression(operands).evaluate(record, null).getValue()); + assertEquals(-1, new BeamSqlPositionExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpressionTest.java index b48a8beddcbc..4cb06e665ec0 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlSubstringExpressionTest.java @@ -54,48 +54,48 @@ public class BeamSqlSubstringExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); assertEquals("hello", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 2)); assertEquals("he", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 5)); assertEquals("hello", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 100)); assertEquals("hello", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); operands.clear(); 
operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 0)); assertEquals("", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, 1)); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, -1)); assertEquals("", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); operands.add(BeamSqlPrimitive.of(SqlTypeName.INTEGER, -1)); assertEquals("o", - new BeamSqlSubstringExpression(operands).evaluate(record, null).getValue()); + new BeamSqlSubstringExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpressionTest.java index 36450825335b..8db93525ac14 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlTrimExpressionTest.java @@ -62,26 +62,26 @@ public class BeamSqlTrimExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "he")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hehe__hehe")); Assert.assertEquals("__hehe", - new BeamSqlTrimExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTrimExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, SqlTrimFunction.Flag.TRAILING)); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "he")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hehe__hehe")); Assert.assertEquals("hehe__", - new BeamSqlTrimExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTrimExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.SYMBOL, SqlTrimFunction.Flag.BOTH)); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "he")); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "__")); Assert.assertEquals("__", - new BeamSqlTrimExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTrimExpression(operands).evaluate(row, null).getValue()); operands.clear(); operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, " hello ")); Assert.assertEquals("hello", - new BeamSqlTrimExpression(operands).evaluate(record, null).getValue()); + new BeamSqlTrimExpression(operands).evaluate(row, null).getValue()); } @Test public void leadingTrim() throws Exception { diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpressionTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpressionTest.java index 41e5a285077b..9187b090cdc2 100644 --- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpressionTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/string/BeamSqlUpperExpressionTest.java @@ -38,7 +38,7 @@ public class BeamSqlUpperExpressionTest extends BeamSqlFnExecutorTestBase { operands.add(BeamSqlPrimitive.of(SqlTypeName.VARCHAR, "hello")); assertEquals("HELLO", - new BeamSqlUpperExpression(operands).evaluate(record, null).getValue()); + new BeamSqlUpperExpression(operands).evaluate(row, null).getValue()); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/parser/BeamSqlParserTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/parser/BeamSqlParserTest.java index c7c8bf4d9ad4..f6bc5d0e2276 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/parser/BeamSqlParserTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/parser/BeamSqlParserTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.extensions.sql.impl.parser; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.INTEGER; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.VARCHAR; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -25,7 +27,6 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.ImmutableList; import java.net.URI; -import java.sql.Types; import org.apache.beam.sdk.extensions.sql.meta.Column; import org.apache.beam.sdk.extensions.sql.meta.Table; import org.apache.calcite.sql.SqlNode; @@ -121,6 +122,18 @@ public void testParseCreateTable_withoutLocation() throws Exception { ); } + @Test + public void testParseDropTable() throws Exception { + BeamSqlParser parser = new BeamSqlParser("drop table person"); + SqlNode sqlNode = parser.impl().parseSqlStmtEof(); + + assertNotNull(sqlNode); + assertTrue(sqlNode instanceof SqlDropTable); + SqlDropTable stmt = (SqlDropTable) sqlNode; + assertNotNull(stmt); + assertEquals("person", stmt.tableName()); + } + private Table parseTable(String sql) throws Exception { BeamSqlParser parser = new BeamSqlParser(sql); SqlNode sqlNode = parser.impl().parseSqlStmtEof(); @@ -150,13 +163,13 @@ private static Table mockTable(String name, String type, String comment, JSONObj .columns(ImmutableList.of( Column.builder() .name("id") - .type(Types.INTEGER) + .coder(INTEGER) .primaryKey(false) .comment("id") .build(), Column.builder() .name("name") - .type(Types.VARCHAR) + .coder(VARCHAR) .primaryKey(false) .comment("name") .build() diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BaseRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BaseRelTest.java index 906ccfd09d4a..6a09d9c0958b 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BaseRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BaseRelTest.java @@ -20,14 +20,14 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; /** * Base class for rel test. 
*/ public class BaseRelTest { - public PCollection compilePipeline ( + public PCollection compilePipeline ( String sql, Pipeline pipeline, BeamSqlEnv sqlEnv) throws Exception { return sqlEnv.getPlanner().compileBeamPipeline(sql, pipeline, sqlEnv); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRelTest.java index 8e41d0a683d4..4a44c3548d7a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamIntersectRelTest.java @@ -18,14 +18,14 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -43,9 +43,9 @@ public class BeamIntersectRelTest extends BaseRelTest { public static void prepare() { sqlEnv.registerTable("ORDER_DETAILS1", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 1L, 1, 1.0, @@ -56,9 +56,9 @@ public static void prepare() { sqlEnv.registerTable("ORDER_DETAILS2", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 2L, 2, 2.0, @@ -76,12 +76,12 @@ public void testIntersect() throws Exception { + "SELECT order_id, site_id, price " + "FROM ORDER_DETAILS2 "; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 2L, 2, 2.0 @@ -99,14 +99,14 @@ public void testIntersectAll() throws Exception { + "SELECT order_id, site_id, price " + "FROM ORDER_DETAILS2 "; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).satisfies(new CheckSize(3)); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 1L, 1, 1.0, diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java index e0d691bb1e24..4e01891a0d72 100644 --- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelBoundedVsBoundedTest.java @@ -18,14 +18,14 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -40,9 +40,9 @@ public class BeamJoinRelBoundedVsBoundedTest extends BaseRelTest { public static final MockedBoundedTable ORDER_DETAILS1 = MockedBoundedTable.of( - Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price" ).addRows( 1, 2, 3, 2, 3, 3, @@ -51,9 +51,9 @@ public class BeamJoinRelBoundedVsBoundedTest extends BaseRelTest { public static final MockedBoundedTable ORDER_DETAILS2 = MockedBoundedTable.of( - Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price" ).addRows( 1, 2, 3, 2, 3, 3, @@ -76,15 +76,15 @@ public void testInnerJoin() throws Exception { + " o1.order_id=o2.site_id AND o2.price=o1.site_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price", - Types.INTEGER, "order_id0", - Types.INTEGER, "site_id0", - Types.INTEGER, "price0" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "site_id0", + SqlTypeCoders.INTEGER, "price0" ).addRows( 2, 3, 3, 1, 2, 3 ).getRows()); @@ -101,16 +101,16 @@ public void testLeftOuterJoin() throws Exception { + " o1.order_id=o2.site_id AND o2.price=o1.site_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); pipeline.enableAbandonedNodeEnforcement(false); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price", - Types.INTEGER, "order_id0", - Types.INTEGER, "site_id0", - Types.INTEGER, "price0" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "site_id0", + SqlTypeCoders.INTEGER, "price0" ).addRows( 1, 2, 3, null, null, null, 2, 3, 3, 1, 2, 3, @@ -129,15 +129,15 @@ public void testRightOuterJoin() throws Exception { + " o1.order_id=o2.site_id AND o2.price=o1.site_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, 
"order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price", - Types.INTEGER, "order_id0", - Types.INTEGER, "site_id0", - Types.INTEGER, "price0" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "site_id0", + SqlTypeCoders.INTEGER, "price0" ).addRows( 2, 3, 3, 1, 2, 3, null, null, null, 2, 3, 3, @@ -156,15 +156,15 @@ public void testFullOuterJoin() throws Exception { + " o1.order_id=o2.site_id AND o2.price=o1.site_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price", - Types.INTEGER, "order_id0", - Types.INTEGER, "site_id0", - Types.INTEGER, "price0" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "site_id0", + SqlTypeCoders.INTEGER, "price0" ).addRows( 2, 3, 3, 1, 2, 3, 1, 2, 3, null, null, null, diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java index c6053391f919..ca639bdd04cd 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsBoundedTest.java @@ -18,13 +18,12 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; import java.util.Arrays; import java.util.Date; import java.util.List; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable; @@ -36,9 +35,10 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.joda.time.Duration; import org.junit.BeforeClass; import org.junit.Rule; @@ -60,10 +60,10 @@ public class BeamJoinRelUnboundedVsBoundedTest extends BaseRelTest { public static void prepare() { BEAM_SQL_ENV.registerTable("ORDER_DETAILS", MockedUnboundedTable .of( - Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price", - Types.TIMESTAMP, "order_time" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price", + SqlTypeCoders.TIMESTAMP, "order_time" ) .timestampColumnIndex(3) .addRows( @@ -87,14 +87,19 @@ public static void prepare() { ); BEAM_SQL_ENV.registerTable("ORDER_DETAILS1", MockedBoundedTable - .of(Types.INTEGER, "order_id", - Types.VARCHAR, "buyer" + .of(SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.VARCHAR, "buyer" ).addRows( 1, "james", 2, "bond" 
)); - BEAM_SQL_ENV.registerTable("SITE_LKP", new SiteLookupTable( - TestUtils.buildBeamSqlRowType(Types.INTEGER, "site_id", Types.VARCHAR, "site_name"))); + + BEAM_SQL_ENV.registerTable( + "SITE_LKP", + new SiteLookupTable( + TestUtils.buildBeamSqlRowType( + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.VARCHAR, "site_name"))); } /** @@ -103,8 +108,8 @@ public static void prepare() { */ public static class SiteLookupTable extends BaseBeamTable implements BeamSqlSeekableTable{ - public SiteLookupTable(BeamRecordSqlType beamRecordSqlType) { - super(beamRecordSqlType); + public SiteLookupTable(RowType rowType) { + super(rowType); } @Override @@ -113,20 +118,21 @@ public BeamIOType getSourceType() { } @Override - public PCollection buildIOReader(Pipeline pipeline) { + public PCollection buildIOReader(Pipeline pipeline) { throw new UnsupportedOperationException(); } @Override - public PTransform, PDone> buildIOWriter() { + public PTransform, PDone> buildIOWriter() { throw new UnsupportedOperationException(); } @Override - public List seekRecord(BeamRecord lookupSubRecord) { - return Arrays.asList(new BeamRecord(getRowType(), 1, "SITE1")); + public List seekRow(Row lookupSubRow) { + return Arrays.asList(Row.withRowType(getRowType()).addValues(1, "SITE1").build()); } } + @Test public void testInnerJoin_unboundedTableOnTheLeftSide() throws Exception { String sql = "SELECT o1.order_id, o1.sum_site_id, o2.buyer FROM " @@ -138,13 +144,13 @@ public void testInnerJoin_unboundedTableOnTheLeftSide() throws Exception { + " o1.order_id=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.VARCHAR, "buyer" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.VARCHAR, "buyer" ).addRows( 1, 3, "james", 2, 5, "bond" @@ -164,13 +170,13 @@ public void testInnerJoin_boundedTableOnTheLeftSide() throws Exception { + " o1.order_id=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.VARCHAR, "buyer" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.VARCHAR, "buyer" ).addRows( 1, 3, "james", 2, 5, "bond" @@ -190,14 +196,14 @@ public void testLeftOuterJoin() throws Exception { + " o1.order_id=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); rows.apply(ParDo.of(new BeamSqlOutputToConsoleFn("helloworld"))); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.VARCHAR, "buyer" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.VARCHAR, "buyer" ).addRows( 1, 3, "james", 2, 5, "bond", @@ -232,13 +238,13 @@ public void testRightOuterJoin() throws Exception { + " on " + " o1.order_id=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, 
pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.VARCHAR, "buyer" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.VARCHAR, "buyer" ).addRows( 1, 3, "james", 2, 5, "bond", @@ -288,12 +294,12 @@ public void testJoinAsLookup() throws Exception { + " o1.site_id=o2.site_id " + " WHERE o1.site_id=1" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.VARCHAR, "site_name" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.VARCHAR, "site_name" ).addRows( 1, "SITE1" ).getStringRows() diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java index e5470ca68624..28916d1d7692 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamJoinRelUnboundedVsUnboundedTest.java @@ -18,8 +18,8 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; import java.util.Date; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.impl.transform.BeamSqlOutputToConsoleFn; @@ -27,8 +27,8 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.joda.time.Duration; import org.junit.BeforeClass; import org.junit.Rule; @@ -49,10 +49,10 @@ public class BeamJoinRelUnboundedVsUnboundedTest extends BaseRelTest { @BeforeClass public static void prepare() { BEAM_SQL_ENV.registerTable("ORDER_DETAILS", MockedUnboundedTable - .of(Types.INTEGER, "order_id", - Types.INTEGER, "site_id", - Types.INTEGER, "price", - Types.TIMESTAMP, "order_time" + .of(SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.INTEGER, "price", + SqlTypeCoders.TIMESTAMP, "order_time" ) .timestampColumnIndex(3) .addRows( @@ -87,14 +87,14 @@ public void testInnerJoin() throws Exception { + " o1.order_id=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.INTEGER, "order_id0", - Types.INTEGER, "sum_site_id0").addRows( + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "sum_site_id0").addRows( 1, 3, 1, 3, 2, 5, 2, 5 ).getStringRows() @@ -120,14 +120,14 @@ public void testLeftOuterJoin() throws Exception { // 2, 2 | 2, 5 // 3, 3 | NULL, NULL - 
PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.INTEGER, "order_id0", - Types.INTEGER, "sum_site_id0" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "sum_site_id0" ).addRows( 1, 1, 1, 3, 2, 2, null, null, @@ -150,14 +150,14 @@ public void testRightOuterJoin() throws Exception { + " o1.order_id=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id", - Types.INTEGER, "order_id0", - Types.INTEGER, "sum_site_id0" + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.INTEGER, "order_id0", + SqlTypeCoders.INTEGER, "sum_site_id0" ).addRows( 1, 3, 1, 1, null, null, 2, 2, @@ -180,15 +180,15 @@ public void testFullOuterJoin() throws Exception { + " o1.order_id1=o2.order_id" ; - PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); + PCollection rows = compilePipeline(sql, pipeline, BEAM_SQL_ENV); rows.apply(ParDo.of(new BeamSqlOutputToConsoleFn("hello"))); PAssert.that(rows.apply(ParDo.of(new TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "order_id1", - Types.INTEGER, "sum_site_id", - Types.INTEGER, "order_id", - Types.INTEGER, "sum_site_id0" + SqlTypeCoders.INTEGER, "order_id1", + SqlTypeCoders.INTEGER, "sum_site_id", + SqlTypeCoders.INTEGER, "order_id", + SqlTypeCoders.INTEGER, "sum_site_id0" ).addRows( 1, 1, 1, 3, 6, 2, null, null, diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java index 5c4ae2ca04f2..130232bb1322 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamMinusRelTest.java @@ -18,14 +18,14 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -43,9 +43,9 @@ public class BeamMinusRelTest extends BaseRelTest { public static void prepare() { sqlEnv.registerTable("ORDER_DETAILS1", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 1L, 1, 1.0, @@ -57,9 +57,9 @@ public static void prepare() { 
sqlEnv.registerTable("ORDER_DETAILS2", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 2L, 2, 2.0, @@ -77,12 +77,12 @@ public void testExcept() throws Exception { + "SELECT order_id, site_id, price " + "FROM ORDER_DETAILS2 "; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 4L, 4, 4.0 ).getRows()); @@ -99,14 +99,14 @@ public void testExceptAll() throws Exception { + "SELECT order_id, site_id, price " + "FROM ORDER_DETAILS2 "; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).satisfies(new CheckSize(2)); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 4L, 4, 4.0, 4L, 4, 4.0 diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBaseTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBaseTest.java index cd0297ad1d16..5830ee0614f0 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBaseTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSetOperatorRelBaseTest.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; import java.util.Date; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; @@ -28,8 +28,8 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -48,10 +48,10 @@ public class BeamSetOperatorRelBaseTest extends BaseRelTest { public static void prepare() { sqlEnv.registerTable("ORDER_DETAILS", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price", - Types.TIMESTAMP, "order_time" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price", + SqlTypeCoders.TIMESTAMP, "order_time" ).addRows( 1L, 1, 1.0, THE_DATE, 2L, 2, 2.0, THE_DATE @@ -70,14 +70,14 @@ public void testSameWindow() throws Exception { + "FROM ORDER_DETAILS GROUP BY order_id, site_id" + ", TUMBLE(order_time, INTERVAL '1' HOUR) "; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); // compare valueInString to ignore the windowStart & windowEnd PAssert.that(rows.apply(ParDo.of(new 
TestUtils.BeamSqlRow2StringDoFn()))) .containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.BIGINT, "cnt" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.BIGINT, "cnt" ).addRows( 1L, 1, 1L, 2L, 2, 1L diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java index bab52967ef9f..6b37444d5e00 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRelTest.java @@ -18,15 +18,16 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; import java.util.Date; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.calcite.tools.ValidationException; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -44,10 +45,10 @@ public class BeamSortRelTest extends BaseRelTest { public void prepare() { sqlEnv.registerTable("ORDER_DETAILS", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price", - Types.TIMESTAMP, "order_time" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price", + SqlTypeCoders.TIMESTAMP, "order_time" ).addRows( 1L, 2, 1.0, new Date(0), 1L, 1, 2.0, new Date(1), @@ -63,9 +64,9 @@ public void prepare() { ); sqlEnv.registerTable("SUB_ORDER_RAM", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ) ); } @@ -77,11 +78,11 @@ public void testOrderBy_basic() throws Exception { + "FROM ORDER_DETAILS " + "ORDER BY order_id asc, site_id desc limit 4"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder(TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 2, 1.0, 1L, 1, 2.0, @@ -97,12 +98,12 @@ public void testOrderBy_timestamp() throws Exception { + "FROM ORDER_DETAILS " + "ORDER BY order_time desc limit 4"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder(TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price", - Types.TIMESTAMP, "order_time" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price", + SqlTypeCoders.TIMESTAMP, "order_time" ).addRows( 7L, 7, 7.0, new Date(6), 8L, 8888, 8.0, new Date(7), @@ -116,9 +117,9 @@ public void testOrderBy_timestamp() throws Exception { public void 
testOrderBy_nullsFirst() throws Exception { sqlEnv.registerTable("ORDER_DETAILS", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 2, 1.0, 1L, null, 2.0, @@ -128,21 +129,21 @@ public void testOrderBy_nullsFirst() throws Exception { ) ); sqlEnv.registerTable("SUB_ORDER_RAM", MockedBoundedTable - .of(Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price")); + .of(SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price")); String sql = "INSERT INTO SUB_ORDER_RAM(order_id, site_id, price) SELECT " + " order_id, site_id, price " + "FROM ORDER_DETAILS " + "ORDER BY order_id asc, site_id desc NULLS FIRST limit 4"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, null, 2.0, 1L, 2, 1.0, @@ -156,9 +157,9 @@ public void testOrderBy_nullsFirst() throws Exception { @Test public void testOrderBy_nullsLast() throws Exception { sqlEnv.registerTable("ORDER_DETAILS", MockedBoundedTable - .of(Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + .of(SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 2, 1.0, 1L, null, 2.0, @@ -166,21 +167,21 @@ public void testOrderBy_nullsLast() throws Exception { 2L, null, 4.0, 5L, 5, 5.0)); sqlEnv.registerTable("SUB_ORDER_RAM", MockedBoundedTable - .of(Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price")); + .of(SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price")); String sql = "INSERT INTO SUB_ORDER_RAM(order_id, site_id, price) SELECT " + " order_id, site_id, price " + "FROM ORDER_DETAILS " + "ORDER BY order_id asc, site_id desc NULLS LAST limit 4"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 2, 1.0, 1L, null, 2.0, @@ -198,12 +199,12 @@ public void testOrderBy_with_offset() throws Exception { + "FROM ORDER_DETAILS " + "ORDER BY order_id asc, site_id desc limit 4 offset 4"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 5L, 5, 5.0, 6L, 6, 6.0, @@ -221,12 +222,12 @@ public void testOrderBy_bigFetch() throws Exception { + "FROM ORDER_DETAILS " + "ORDER BY order_id asc, site_id desc limit 11"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - 
Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 2, 1.0, 1L, 1, 2.0, @@ -243,7 +244,7 @@ public void testOrderBy_bigFetch() throws Exception { pipeline.run().waitUntilFinish(); } - @Test(expected = UnsupportedOperationException.class) + @Test(expected = ValidationException.class) public void testOrderBy_exception() throws Exception { String sql = "INSERT INTO SUB_ORDER_RAM(order_id, site_id) SELECT " + " order_id, COUNT(*) " diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRelTest.java index d79a54eaadc6..4d166bfef270 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnionRelTest.java @@ -18,14 +18,14 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -43,9 +43,9 @@ public class BeamUnionRelTest extends BaseRelTest { public static void prepare() { sqlEnv.registerTable("ORDER_DETAILS", MockedBoundedTable.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 2L, 2, 2.0 @@ -62,12 +62,12 @@ public void testUnion() throws Exception { + " order_id, site_id, price " + "FROM ORDER_DETAILS "; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 2L, 2, 2.0 @@ -85,12 +85,12 @@ public void testUnionAll() throws Exception { + " SELECT order_id, site_id, price " + "FROM ORDER_DETAILS"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.BIGINT, "order_id", - Types.INTEGER, "site_id", - Types.DOUBLE, "price" + SqlTypeCoders.BIGINT, "order_id", + SqlTypeCoders.INTEGER, "site_id", + SqlTypeCoders.DOUBLE, "price" ).addRows( 1L, 1, 1.0, 1L, 1, 1.0, diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRelTest.java index 5604e3205395..18322e5d57b4 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRelTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamValuesRelTest.java @@ -18,14 +18,14 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.sql.Types; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -43,14 +43,14 @@ public class BeamValuesRelTest extends BaseRelTest { public static void prepare() { sqlEnv.registerTable("string_table", MockedBoundedTable.of( - Types.VARCHAR, "name", - Types.VARCHAR, "description" + SqlTypeCoders.VARCHAR, "name", + SqlTypeCoders.VARCHAR, "description" ) ); sqlEnv.registerTable("int_table", MockedBoundedTable.of( - Types.INTEGER, "c0", - Types.INTEGER, "c1" + SqlTypeCoders.INTEGER, "c0", + SqlTypeCoders.INTEGER, "c1" ) ); } @@ -59,11 +59,11 @@ public static void prepare() { public void testValues() throws Exception { String sql = "insert into string_table(name, description) values " + "('hello', 'world'), ('james', 'bond')"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.VARCHAR, "name", - Types.VARCHAR, "description" + SqlTypeCoders.VARCHAR, "name", + SqlTypeCoders.VARCHAR, "description" ).addRows( "hello", "world", "james", "bond" @@ -75,11 +75,11 @@ public void testValues() throws Exception { @Test public void testValues_castInt() throws Exception { String sql = "insert into int_table (c0, c1) values(cast(1 as int), cast(2 as int))"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "c0", - Types.INTEGER, "c1" + SqlTypeCoders.INTEGER, "c0", + SqlTypeCoders.INTEGER, "c1" ).addRows( 1, 2 ).getRows() @@ -90,11 +90,11 @@ public void testValues_castInt() throws Exception { @Test public void testValues_onlySelect() throws Exception { String sql = "select 1, '1'"; - PCollection rows = compilePipeline(sql, pipeline, sqlEnv); + PCollection rows = compilePipeline(sql, pipeline, sqlEnv); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder.of( - Types.INTEGER, "EXPR$0", - Types.CHAR, "EXPR$1" + SqlTypeCoders.INTEGER, "EXPR$0", + SqlTypeCoders.CHAR, "EXPR$1" ).addRows( 1, "1" ).getRows() diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/CheckSize.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/CheckSize.java index 7407a7667fd8..73bd7cd92cf0 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/CheckSize.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/CheckSize.java @@ -19,20 +19,21 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; import org.junit.Assert; /** * Utility class to check size of BeamSQLRow 
iterable. */ -public class CheckSize implements SerializableFunction, Void> { +public class CheckSize implements SerializableFunction, Void> { private int size; public CheckSize(int size) { this.size = size; } - @Override public Void apply(Iterable input) { + + @Override public Void apply(Iterable input) { int count = 0; - for (BeamRecord row : input) { + for (Row row : input) { count++; } Assert.assertEquals(size, count); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java index 08df6da6af54..121dfb607dc1 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java @@ -21,14 +21,14 @@ import java.math.BigDecimal; import java.util.Date; import java.util.GregorianCalendar; -import org.apache.beam.sdk.coders.BeamRecordCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.testing.CoderProperties; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeSystem; -import org.apache.calcite.rel.type.RelProtoDataType; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Test; @@ -39,34 +39,43 @@ public class BeamSqlRowCoderTest { @Test public void encodeAndDecode() throws Exception { - final RelProtoDataType protoRowType = - a0 -> - a0.builder() - .add("col_tinyint", SqlTypeName.TINYINT) - .add("col_smallint", SqlTypeName.SMALLINT) - .add("col_integer", SqlTypeName.INTEGER) - .add("col_bigint", SqlTypeName.BIGINT) - .add("col_float", SqlTypeName.FLOAT) - .add("col_double", SqlTypeName.DOUBLE) - .add("col_decimal", SqlTypeName.DECIMAL) - .add("col_string_varchar", SqlTypeName.VARCHAR) - .add("col_time", SqlTypeName.TIME) - .add("col_timestamp", SqlTypeName.TIMESTAMP) - .add("col_boolean", SqlTypeName.BOOLEAN) - .build(); + RelDataType relDataType = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT) + .builder() + .add("col_tinyint", SqlTypeName.TINYINT) + .add("col_smallint", SqlTypeName.SMALLINT) + .add("col_integer", SqlTypeName.INTEGER) + .add("col_bigint", SqlTypeName.BIGINT) + .add("col_float", SqlTypeName.FLOAT) + .add("col_double", SqlTypeName.DOUBLE) + .add("col_decimal", SqlTypeName.DECIMAL) + .add("col_string_varchar", SqlTypeName.VARCHAR) + .add("col_time", SqlTypeName.TIME) + .add("col_timestamp", SqlTypeName.TIMESTAMP) + .add("col_boolean", SqlTypeName.BOOLEAN) + .build(); - BeamRecordSqlType beamSQLRowType = CalciteUtils.toBeamRowType( - protoRowType.apply(new JavaTypeFactoryImpl( - RelDataTypeSystem.DEFAULT))); + RowType beamRowType = CalciteUtils.toBeamRowType(relDataType); GregorianCalendar calendar = new GregorianCalendar(); calendar.setTime(new Date()); - BeamRecord row = new BeamRecord(beamSQLRowType - , Byte.valueOf("1"), Short.valueOf("1"), 1, 1L, 1.1F, 1.1 - , BigDecimal.ZERO, "hello", calendar, new Date(), true); + Row row = + Row + .withRowType(beamRowType) + .addValues( + Byte.valueOf("1"), + Short.valueOf("1"), + 1, + 1L, + 1.1F, + 
1.1, + BigDecimal.ZERO, + "hello", + calendar, + new Date(), + true) + .build(); - - BeamRecordCoder coder = beamSQLRowType.getRecordCoder(); + RowCoder coder = beamRowType.getRowCoder(); CoderProperties.coderDecodeEncodeEqual(coder, row); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamAggregationTransformTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamAggregationTransformTest.java index 418c5ec9914b..e76e8e207586 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamAggregationTransformTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamAggregationTransformTest.java @@ -21,13 +21,13 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.beam.sdk.coders.BeamRecordCoder; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; -import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.extensions.sql.impl.transform.BeamAggregationTransforms; -import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Combine; @@ -35,11 +35,11 @@ import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.WithKeys; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder; import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlAvgAggFunction; @@ -54,23 +54,22 @@ /** * Unit tests for {@link BeamAggregationTransforms}. - * */ -public class BeamAggregationTransformTest extends BeamTransformBaseTest{ +public class BeamAggregationTransformTest extends BeamTransformBaseTest { @Rule public TestPipeline p = TestPipeline.create(); private List aggCalls; - private BeamRecordSqlType keyType; - private BeamRecordSqlType aggPartType; - private BeamRecordSqlType outputType; + private RowType keyType; + private RowType aggPartType; + private RowType outputType; - private BeamRecordCoder inRecordCoder; - private BeamRecordCoder keyCoder; - private BeamRecordCoder aggCoder; - private BeamRecordCoder outRecordCoder; + private RowCoder inRecordCoder; + private RowCoder keyCoder; + private RowCoder aggCoder; + private RowCoder outRecordCoder; /** * This step equals to below query. @@ -93,16 +92,17 @@ public class BeamAggregationTransformTest extends BeamTransformBaseTest{ * FROM TABLE_NAME * GROUP BY `f_int` * + * * @throws ParseException */ @Test public void testCountPerElementBasic() throws ParseException { setupEnvironment(); - PCollection input = p.apply(Create.of(inputRows)); + PCollection input = p.apply(Create.of(inputRows)); // 1. 
extract fields in group-by key part - PCollection> exGroupByStream = + PCollection> exGroupByStream = input .apply( "exGroupBy", @@ -112,13 +112,13 @@ public void testCountPerElementBasic() throws ParseException { .setCoder(KvCoder.of(keyCoder, inRecordCoder)); // 2. apply a GroupByKey. - PCollection>> groupedStream = + PCollection>> groupedStream = exGroupByStream .apply("groupBy", GroupByKey.create()) .setCoder(KvCoder.of(keyCoder, IterableCoder.of(inRecordCoder))); // 3. run aggregation functions - PCollection> aggregatedStream = + PCollection> aggregatedStream = groupedStream .apply( "aggregation", @@ -127,7 +127,8 @@ public void testCountPerElementBasic() throws ParseException { .setCoder(KvCoder.of(keyCoder, aggCoder)); //4. flat KV to a single record - PCollection mergedStream = aggregatedStream.apply("mergeRecord", + PCollection mergedStream = aggregatedStream.apply( + "mergeRecord", ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(outputType, aggCalls, -1))); mergedStream.setCoder(outRecordCoder); @@ -137,11 +138,11 @@ public void testCountPerElementBasic() throws ParseException { //assert BeamAggregationTransform.AggregationCombineFn PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn()); - //assert BeamAggregationTransform.MergeAggregationRecord - PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRecord()); + //assert BeamAggregationTransform.MergeAggregationRecord + PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRow()); p.run(); -} + } private void setupEnvironment() { prepareAggregationCalls(); @@ -157,7 +158,7 @@ private void prepareAggregationCalls() { aggCalls = new ArrayList<>(); aggCalls.add( new AggregateCall( - new SqlCountAggFunction(), + new SqlCountAggFunction("COUNT"), false, Arrays.asList(), new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), @@ -357,70 +358,91 @@ private void prepareAggregationCalls() { * Coders used in aggregation steps. 
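
Each intermediate PCollection in testCountPerElementBasic above carries Row values, so the test pins every coder explicitly with the RowCoder obtained from the matching RowType. A minimal sketch of that wiring, reusing the test's keyType, inputRowType and input fields; the key-extraction lambda here is illustrative only, the test itself uses BeamAggregationTransforms.AggregationGroupByKeyFn:

    // Key each input Row by a one-field key Row; set the KV coder explicitly.
    PCollection<KV<Row, Row>> keyed =
        input
            .apply("exGroupBy", WithKeys.of((Row row) ->
                Row.withRowType(keyType).addValues(row.getInteger(0)).build()))
            .setCoder(KvCoder.of(keyType.getRowCoder(), inputRowType.getRowCoder()));

    // Group by the key Row; the value side becomes Iterable<Row>, hence IterableCoder.
    PCollection<KV<Row, Iterable<Row>>> grouped =
        keyed
            .apply("groupBy", GroupByKey.<Row, Row>create())
            .setCoder(KvCoder.of(
                keyType.getRowCoder(),
                IterableCoder.of(inputRowType.getRowCoder())));
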
*/ private void prepareTypeAndCoder() { - inRecordCoder = inputRowType.getRecordCoder(); + inRecordCoder = inputRowType.getRowCoder(); + + keyType = + RowSqlType + .builder() + .withIntegerField("f_int") + .build(); + + keyCoder = keyType.getRowCoder(); + + aggPartType = RowSqlType + .builder() + .withBigIntField("count") - keyType = initTypeOfSqlRow(Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER))); - keyCoder = keyType.getRecordCoder(); + .withBigIntField("sum1") + .withBigIntField("avg1") + .withBigIntField("max1") + .withBigIntField("min1") - aggPartType = initTypeOfSqlRow( - Arrays.asList(KV.of("count", SqlTypeName.BIGINT), + .withSmallIntField("sum2") + .withSmallIntField("avg2") + .withSmallIntField("max2") + .withSmallIntField("min2") - KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT), - KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT), + .withTinyIntField("sum3") + .withTinyIntField("avg3") + .withTinyIntField("max3") + .withTinyIntField("min3") - KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT), - KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT), + .withFloatField("sum4") + .withFloatField("avg4") + .withFloatField("max4") + .withFloatField("min4") - KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT), - KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT), + .withDoubleField("sum5") + .withDoubleField("avg5") + .withDoubleField("max5") + .withDoubleField("min5") - KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT), - KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT), + .withTimestampField("max7") + .withTimestampField("min7") - KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE), - KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE), + .withIntegerField("sum8") + .withIntegerField("avg8") + .withIntegerField("max8") + .withIntegerField("min8") - KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP), + .build(); - KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER), - KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER) - )); - aggCoder = aggPartType.getRecordCoder(); + aggCoder = aggPartType.getRowCoder(); outputType = prepareFinalRowType(); - outRecordCoder = outputType.getRecordCoder(); + outRecordCoder = outputType.getRowCoder(); } /** * expected results after {@link BeamAggregationTransforms.AggregationGroupByKeyFn}. */ - private List> prepareResultOfAggregationGroupByKeyFn() { - return Arrays.asList( - KV.of( - new BeamRecord(keyType, Arrays.asList(inputRows.get(0).getInteger(0))), - inputRows.get(0)), - KV.of( - new BeamRecord(keyType, Arrays.asList(inputRows.get(1).getInteger(0))), - inputRows.get(1)), - KV.of( - new BeamRecord(keyType, Arrays.asList(inputRows.get(2).getInteger(0))), - inputRows.get(2)), - KV.of( - new BeamRecord(keyType, Arrays.asList(inputRows.get(3).getInteger(0))), - inputRows.get(3))); + private List> prepareResultOfAggregationGroupByKeyFn() { + return + IntStream + .range(0, 4) + .mapToObj(i -> KV.of( + Row + .withRowType(keyType) + .addValues(inputRows.get(i).getInteger(0)) + .build(), + inputRows.get(i) + )).collect(Collectors.toList()); } /** - * expected results after {@link BeamAggregationTransforms.AggregationCombineFn}. + * expected results. 
*/ - private List> prepareResultOfAggregationCombineFn() + private List> prepareResultOfAggregationCombineFn() throws ParseException { return Arrays.asList( KV.of( - new BeamRecord(keyType, Arrays.asList(inputRows.get(0).getInteger(0))), - new BeamRecord( - aggPartType, - Arrays.asList( + Row + .withRowType(keyType) + .addValues(inputRows.get(0).getInteger(0)) + .build(), + Row + .withRowType(aggPartType) + .addValues( 4L, 10000L, 2500L, @@ -447,50 +469,64 @@ private List> prepareResultOfAggregationCombineFn() 10, 2, 4, - 1)))); + 1) + .build())); } + /** * Row type of final output row. */ - private BeamRecordSqlType prepareFinalRowType() { - FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder(); - List> columnMetadata = - Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER), KV.of("count", SqlTypeName.BIGINT), - - KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT), - KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT), - - KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT), - KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT), - - KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT), - KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT), - - KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT), - KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT), - - KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE), - KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE), - - KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP), - - KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER), - KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER) - ); - for (KV cm : columnMetadata) { - builder.add(cm.getKey(), cm.getValue()); - } - return CalciteUtils.toBeamRowType(builder.build()); + private RowType prepareFinalRowType() { + return + RowSqlType + .builder() + .withIntegerField("f_int") + .withBigIntField("count") + + .withBigIntField("sum1") + .withBigIntField("avg1") + .withBigIntField("max1") + .withBigIntField("min1") + + .withSmallIntField("sum2") + .withSmallIntField("avg2") + .withSmallIntField("max2") + .withSmallIntField("min2") + + .withTinyIntField("sum3") + .withTinyIntField("avg3") + .withTinyIntField("max3") + .withTinyIntField("min3") + + .withFloatField("sum4") + .withFloatField("avg4") + .withFloatField("max4") + .withFloatField("min4") + + .withDoubleField("sum5") + .withDoubleField("avg5") + .withDoubleField("max5") + .withDoubleField("min5") + + .withTimestampField("max7") + .withTimestampField("min7") + + .withIntegerField("sum8") + .withIntegerField("avg8") + .withIntegerField("max8") + .withIntegerField("min8") + + .build(); } /** * expected results after {@link BeamAggregationTransforms.MergeAggregationRecord}. 
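
The reworked helpers above drop the List of KV(String, SqlTypeName) column metadata in favour of RowSqlType builders, and build the expected values directly as Rows. A minimal round-trip sketch of that pattern (the single field mirrors the test's key type, the value is illustrative), usable inside a test method that declares throws Exception, as BeamSqlRowCoderTest above does:

    RowType keyType =
        RowSqlType
            .builder()
            .withIntegerField("f_int")
            .build();

    Row keyRow =
        Row
            .withRowType(keyType)
            .addValues(1)
            .build();

    // The coder comes straight from the row type, for setCoder(...) or coder tests.
    RowCoder keyCoder = keyType.getRowCoder();
    CoderProperties.coderDecodeEncodeEqual(keyCoder, keyRow);
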
*/ - private BeamRecord prepareResultOfMergeAggregationRecord() throws ParseException { - return new BeamRecord( - outputType, - Arrays.asList( + private Row prepareResultOfMergeAggregationRow() throws ParseException { + return Row + .withRowType(outputType) + .addValues( 1, 4L, 10000L, @@ -518,6 +554,7 @@ private BeamRecord prepareResultOfMergeAggregationRecord() throws ParseException 10, 2, 4, - 1)); + 1) + .build(); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamTransformBaseTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamTransformBaseTest.java index 5e5d82746b3f..9c3a7978dbea 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamTransformBaseTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/transform/BeamTransformBaseTest.java @@ -20,116 +20,81 @@ import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; -import java.util.Arrays; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; -import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner; -import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; -import org.apache.beam.sdk.values.BeamRecord; -import org.apache.beam.sdk.values.KV; -import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder; -import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.beam.sdk.extensions.sql.RowSqlType; +import org.apache.beam.sdk.extensions.sql.TestUtils; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.junit.BeforeClass; /** * shared methods to test PTransforms which execute Beam SQL steps. 
- * */ public class BeamTransformBaseTest { - public static DateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + static DateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - public static BeamRecordSqlType inputRowType; - public static List inputRows; + static RowType inputRowType; + static List inputRows; @BeforeClass - public static void prepareInput() throws NumberFormatException, ParseException{ - List> columnMetadata = Arrays.asList( - KV.of("f_int", SqlTypeName.INTEGER), KV.of("f_long", SqlTypeName.BIGINT), - KV.of("f_short", SqlTypeName.SMALLINT), KV.of("f_byte", SqlTypeName.TINYINT), - KV.of("f_float", SqlTypeName.FLOAT), KV.of("f_double", SqlTypeName.DOUBLE), - KV.of("f_string", SqlTypeName.VARCHAR), KV.of("f_timestamp", SqlTypeName.TIMESTAMP), - KV.of("f_int2", SqlTypeName.INTEGER) - ); - inputRowType = initTypeOfSqlRow(columnMetadata); - inputRows = - Arrays.asList( - initBeamSqlRow( - columnMetadata, - Arrays.asList( - 1, - 1000L, - Short.valueOf("1"), - Byte.valueOf("1"), - 1.0F, - 1.0, - "string_row1", - format.parse("2017-01-01 01:01:03"), - 1)), - initBeamSqlRow( - columnMetadata, - Arrays.asList( - 1, - 2000L, - Short.valueOf("2"), - Byte.valueOf("2"), - 2.0F, - 2.0, - "string_row2", - format.parse("2017-01-01 01:02:03"), - 2)), - initBeamSqlRow( - columnMetadata, - Arrays.asList( - 1, - 3000L, - Short.valueOf("3"), - Byte.valueOf("3"), - 3.0F, - 3.0, - "string_row3", - format.parse("2017-01-01 01:03:03"), - 3)), - initBeamSqlRow( - columnMetadata, - Arrays.asList( - 1, - 4000L, - Short.valueOf("4"), - Byte.valueOf("4"), - 4.0F, - 4.0, - "string_row4", - format.parse("2017-01-01 02:04:03"), - 4))); - } + public static void prepareInput() throws NumberFormatException, ParseException { + inputRowType = + RowSqlType + .builder() + .withIntegerField("f_int") + .withBigIntField("f_long") + .withSmallIntField("f_short") + .withTinyIntField("f_byte") + .withFloatField("f_float") + .withDoubleField("f_double") + .withVarcharField("f_string") + .withTimestampField("f_timestamp") + .withIntegerField("f_int2") + .build(); - /** - * create a {@code BeamSqlRowType} for given column metadata. - */ - public static BeamRecordSqlType initTypeOfSqlRow(List> columnMetadata){ - FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder(); - for (KV cm : columnMetadata) { - builder.add(cm.getKey(), cm.getValue()); - } - return CalciteUtils.toBeamRowType(builder.build()); - } - - /** - * Create an empty row with given column metadata. - */ - public static BeamRecord initBeamSqlRow(List> columnMetadata) { - return initBeamSqlRow(columnMetadata, Arrays.asList()); - } - - /** - * Create a row with given column metadata, and values for each column. 
- * - */ - public static BeamRecord initBeamSqlRow(List> columnMetadata, - List rowValues){ - BeamRecordSqlType rowType = initTypeOfSqlRow(columnMetadata); - - return new BeamRecord(rowType, rowValues); + inputRows = + TestUtils.RowsBuilder + .of(inputRowType) + .addRows( + 1, + 1000L, + Short.valueOf("1"), + Byte.valueOf("1"), + 1.0F, + 1.0, + "string_row1", + format.parse("2017-01-01 01:01:03"), + 1) + .addRows( + 1, + 2000L, + Short.valueOf("2"), + Byte.valueOf("2"), + 2.0F, + 2.0, + "string_row2", + format.parse("2017-01-01 01:02:03"), + 2) + .addRows( + 1, + 3000L, + Short.valueOf("3"), + Byte.valueOf("3"), + 3.0F, + 3.0, + "string_row3", + format.parse("2017-01-01 01:03:03"), + 3) + .addRows( + 1, + 4000L, + Short.valueOf("4"), + Byte.valueOf("4"), + 4.0F, + 4.0, + "string_row4", + format.parse("2017-01-01 02:04:03"), + 4) + .getRows(); } - } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/BigDecimalConverterTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/BigDecimalConverterTest.java similarity index 74% rename from sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/BigDecimalConverterTest.java rename to sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/BigDecimalConverterTest.java index c8099f6631b5..144fc8adcc70 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/BigDecimalConverterTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/BigDecimalConverterTest.java @@ -16,15 +16,15 @@ * limitations under the License. */ -package org.apache.beam.sdk.extensions.sql.impl.transform.agg; +package org.apache.beam.sdk.extensions.sql.impl.utils; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.NUMERIC_TYPES; import static org.junit.Assert.assertNotNull; -import com.google.common.collect.ImmutableSet; import java.math.BigDecimal; -import java.util.Set; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.calcite.sql.type.SqlTypeName; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -36,18 +36,9 @@ public class BigDecimalConverterTest { @Rule public ExpectedException thrown = ExpectedException.none(); - private static final Set NUMERIC_TYPES = ImmutableSet.of( - SqlTypeName.TINYINT, - SqlTypeName.SMALLINT, - SqlTypeName.INTEGER, - SqlTypeName.BIGINT, - SqlTypeName.FLOAT, - SqlTypeName.DOUBLE, - SqlTypeName.DECIMAL); - @Test public void testReturnsConverterForNumericTypes() { - for (SqlTypeName numericType : NUMERIC_TYPES) { + for (SqlTypeCoder numericType : NUMERIC_TYPES) { SerializableFunction converter = BigDecimalConverter.forSqlType(numericType); @@ -59,6 +50,6 @@ public void testReturnsConverterForNumericTypes() { @Test public void testThrowsForUnsupportedTypes() { thrown.expect(UnsupportedOperationException.class); - BigDecimalConverter.forSqlType(SqlTypeName.ARRAY); + BigDecimalConverter.forSqlType(SqlTypeCoders.VARCHAR); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlArithmeticOperatorsIntegrationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlArithmeticOperatorsIntegrationTest.java 
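
BeamTransformBaseTest above now declares its schema once with RowSqlType and builds the fixture rows through TestUtils.RowsBuilder, replacing the removed initTypeOfSqlRow and initBeamSqlRow helpers. A minimal sketch of the builder usage, with an illustrative two-column schema:

    RowType rowType =
        RowSqlType
            .builder()
            .withIntegerField("f_int")
            .withVarcharField("f_string")
            .build();

    List<Row> rows =
        TestUtils.RowsBuilder
            .of(rowType)
            .addRows(1, "string_row1")
            .addRows(2, "string_row2")
            .getRows();
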
index 5e626a2ac1a7..e5409814d14b 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlArithmeticOperatorsIntegrationTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlArithmeticOperatorsIntegrationTest.java @@ -29,7 +29,6 @@ public class BeamSqlArithmeticOperatorsIntegrationTest extends BeamSqlBuiltinFunctionsIntegrationTestBase { private static final BigDecimal ZERO = BigDecimal.valueOf(0.0); - private static final BigDecimal ONE0 = BigDecimal.valueOf(1); private static final BigDecimal ONE = BigDecimal.valueOf(1.0); private static final BigDecimal ONE2 = BigDecimal.valueOf(1.0).multiply(BigDecimal.valueOf(1.0)); private static final BigDecimal ONE10 = BigDecimal.ONE.divide( diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlBuiltinFunctionsIntegrationTestBase.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlBuiltinFunctionsIntegrationTestBase.java index 5997540099c5..bb1341f18b8a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlBuiltinFunctionsIntegrationTestBase.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlBuiltinFunctionsIntegrationTestBase.java @@ -18,24 +18,29 @@ package org.apache.beam.sdk.extensions.sql.integrationtest; +import static java.util.stream.Collectors.toList; +import static org.apache.beam.sdk.values.RowType.toRowType; + import com.google.common.base.Joiner; import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; -import java.sql.Types; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; import java.util.List; import java.util.Map; import java.util.TimeZone; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSql; +import org.apache.beam.sdk.extensions.sql.RowSqlType; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.util.Pair; import org.junit.Rule; @@ -43,21 +48,21 @@ * Base class for all built-in functions integration tests. */ public class BeamSqlBuiltinFunctionsIntegrationTestBase { - private static final Map JAVA_CLASS_TO_SQL_TYPE = ImmutableMap - . 
builder() - .put(Byte.class, Types.TINYINT) - .put(Short.class, Types.SMALLINT) - .put(Integer.class, Types.INTEGER) - .put(Long.class, Types.BIGINT) - .put(Float.class, Types.FLOAT) - .put(Double.class, Types.DOUBLE) - .put(BigDecimal.class, Types.DECIMAL) - .put(String.class, Types.VARCHAR) - .put(Date.class, Types.DATE) - .put(Boolean.class, Types.BOOLEAN) + private static final Map JAVA_CLASS_TO_CODER = ImmutableMap + .builder() + .put(Byte.class, SqlTypeCoders.TINYINT) + .put(Short.class, SqlTypeCoders.SMALLINT) + .put(Integer.class, SqlTypeCoders.INTEGER) + .put(Long.class, SqlTypeCoders.BIGINT) + .put(Float.class, SqlTypeCoders.FLOAT) + .put(Double.class, SqlTypeCoders.DOUBLE) + .put(BigDecimal.class, SqlTypeCoders.DECIMAL) + .put(String.class, SqlTypeCoders.VARCHAR) + .put(Date.class, SqlTypeCoders.DATE) + .put(Boolean.class, SqlTypeCoders.BOOLEAN) .build(); - private static final BeamRecordSqlType RECORD_SQL_TYPE = BeamRecordSqlType.builder() + private static final RowType ROW_TYPE = RowSqlType.builder() .withDateField("ts") .withTinyIntField("c_tinyint") .withSmallIntField("c_smallint") @@ -75,10 +80,10 @@ public class BeamSqlBuiltinFunctionsIntegrationTestBase { @Rule public final TestPipeline pipeline = TestPipeline.create(); - protected PCollection getTestPCollection() { + protected PCollection getTestPCollection() { try { return MockedBoundedTable - .of(RECORD_SQL_TYPE) + .of(ROW_TYPE) .addRows( parseDate("1986-02-15 11:35:26"), (byte) 1, @@ -94,7 +99,7 @@ protected PCollection getTestPCollection() { 9223372036854775807L ) .buildIOReader(pipeline) - .setCoder(RECORD_SQL_TYPE.getRecordCoder()); + .setCoder(ROW_TYPE.getRowCoder()); } catch (Exception e) { throw new RuntimeException(e); } @@ -145,23 +150,22 @@ private String getSql() { * Build the corresponding SQL, compile to Beam Pipeline, run it, and check the result. 
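
getTestPCollection() above derives the mocked table, its rows and the coder from the same RowType, which keeps the test input consistent with the declared SQL schema. A minimal sketch of that pattern, assuming the TestPipeline pipeline rule used by these tests; the two-column layout and the query text are illustrative:

    RowType rowType =
        RowSqlType
            .builder()
            .withIntegerField("c_integer")
            .withVarcharField("c_varchar")
            .build();

    PCollection<Row> input =
        MockedBoundedTable
            .of(rowType)
            .addRows(1, "hello")
            .buildIOReader(pipeline)
            .setCoder(rowType.getRowCoder());

    PCollection<Row> result =
        input.apply(BeamSql.query("SELECT c_integer, c_varchar FROM PCOLLECTION"));
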
*/ public void buildRunAndCheck() { - PCollection inputCollection = getTestPCollection(); + PCollection inputCollection = getTestPCollection(); System.out.println("SQL:>\n" + getSql()); try { - List names = new ArrayList<>(); - List types = new ArrayList<>(); - List values = new ArrayList<>(); + RowType rowType = + exps.stream() + .map(exp -> RowType.newField( + exp.getKey(), + JAVA_CLASS_TO_CODER.get(exp.getValue().getClass()))) + .collect(toRowType()); - for (Pair pair : exps) { - names.add(pair.getKey()); - types.add(JAVA_CLASS_TO_SQL_TYPE.get(pair.getValue().getClass())); - values.add(pair.getValue()); - } + List values = exps.stream().map(Pair::getValue).collect(toList()); - PCollection rows = inputCollection.apply(BeamSql.query(getSql())); + PCollection rows = inputCollection.apply(BeamSql.query(getSql())); PAssert.that(rows).containsInAnyOrder( TestUtils.RowsBuilder - .of(BeamRecordSqlType.create(names, types)) + .of(rowType) .addRows(values) .getRows() ); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlComparisonOperatorsIntegrationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlComparisonOperatorsIntegrationTest.java index a836f79a475a..f936070ba35f 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlComparisonOperatorsIntegrationTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlComparisonOperatorsIntegrationTest.java @@ -19,12 +19,11 @@ package org.apache.beam.sdk.extensions.sql.integrationtest; import java.math.BigDecimal; -import java.sql.Types; -import java.util.Arrays; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.junit.Test; /** @@ -281,31 +280,36 @@ public void testIsNullAndIsNotNull() throws Exception { checker.buildRunAndCheck(); } - @Override protected PCollection getTestPCollection() { - BeamRecordSqlType type = BeamRecordSqlType.create( - Arrays.asList( - "c_tinyint_0", "c_tinyint_1", "c_tinyint_2", - "c_smallint_0", "c_smallint_1", "c_smallint_2", - "c_integer_0", "c_integer_1", "c_integer_2", - "c_bigint_0", "c_bigint_1", "c_bigint_2", - "c_float_0", "c_float_1", "c_float_2", - "c_double_0", "c_double_1", "c_double_2", - "c_decimal_0", "c_decimal_1", "c_decimal_2", - "c_varchar_0", "c_varchar_1", "c_varchar_2", - "c_boolean_false", "c_boolean_true" - ), - Arrays.asList( - Types.TINYINT, Types.TINYINT, Types.TINYINT, - Types.SMALLINT, Types.SMALLINT, Types.SMALLINT, - Types.INTEGER, Types.INTEGER, Types.INTEGER, - Types.BIGINT, Types.BIGINT, Types.BIGINT, - Types.FLOAT, Types.FLOAT, Types.FLOAT, - Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, - Types.DECIMAL, Types.DECIMAL, Types.DECIMAL, - Types.VARCHAR, Types.VARCHAR, Types.VARCHAR, - Types.BOOLEAN, Types.BOOLEAN - ) - ); + @Override protected PCollection getTestPCollection() { + RowType type = RowSqlType.builder() + .withTinyIntField("c_tinyint_0") + .withTinyIntField("c_tinyint_1") + .withTinyIntField("c_tinyint_2") + .withSmallIntField("c_smallint_0") + .withSmallIntField("c_smallint_1") + .withSmallIntField("c_smallint_2") + 
.withIntegerField("c_integer_0") + .withIntegerField("c_integer_1") + .withIntegerField("c_integer_2") + .withBigIntField("c_bigint_0") + .withBigIntField("c_bigint_1") + .withBigIntField("c_bigint_2") + .withFloatField("c_float_0") + .withFloatField("c_float_1") + .withFloatField("c_float_2") + .withDoubleField("c_double_0") + .withDoubleField("c_double_1") + .withDoubleField("c_double_2") + .withDecimalField("c_decimal_0") + .withDecimalField("c_decimal_1") + .withDecimalField("c_decimal_2") + .withVarcharField("c_varchar_0") + .withVarcharField("c_varchar_1") + .withVarcharField("c_varchar_2") + .withBooleanField("c_boolean_false") + .withBooleanField("c_boolean_true") + .build(); + try { return MockedBoundedTable .of(type) @@ -321,7 +325,7 @@ public void testIsNullAndIsNotNull() throws Exception { false, true ) .buildIOReader(pipeline) - .setCoder(type.getRecordCoder()); + .setCoder(type.getRowCoder()); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlDateFunctionsIntegrationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlDateFunctionsIntegrationTest.java index ec5b2953a186..c6069033a8e2 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlDateFunctionsIntegrationTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/integrationtest/BeamSqlDateFunctionsIntegrationTest.java @@ -23,12 +23,11 @@ import java.util.Date; import java.util.Iterator; - import org.apache.beam.sdk.extensions.sql.BeamSql; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.Test; /** @@ -170,17 +169,17 @@ public class BeamSqlDateFunctionsIntegrationTest + "CURRENT_TIMESTAMP as c3" + " FROM PCOLLECTION" ; - PCollection rows = getTestPCollection().apply( + PCollection rows = getTestPCollection().apply( BeamSql.query(sql)); PAssert.that(rows).satisfies(new Checker()); pipeline.run(); } - private static class Checker implements SerializableFunction, Void> { - @Override public Void apply(Iterable input) { - Iterator iter = input.iterator(); + private static class Checker implements SerializableFunction, Void> { + @Override public Void apply(Iterable input) { + Iterator iter = input.iterator(); assertTrue(iter.hasNext()); - BeamRecord row = iter.next(); + Row row = iter.next(); // LOCALTIME Date date = new Date(); assertTrue(date.getTime() - row.getGregorianCalendar(0).getTime().getTime() < 1000); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java index d2e7d40b7f29..347b490a84ef 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java @@ -19,7 +19,6 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka; import java.io.Serializable; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import 
org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.testing.PAssert; @@ -27,13 +26,12 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.calcite.rel.type.RelProtoDataType; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.csv.CSVFormat; -import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -41,20 +39,21 @@ * Test for BeamKafkaCSVTable. */ public class BeamKafkaCSVTableTest { - @Rule - public TestPipeline pipeline = TestPipeline.create(); - public static BeamRecord row1; - public static BeamRecord row2; + @Rule public TestPipeline pipeline = TestPipeline.create(); - @BeforeClass - public static void setUp() { - row1 = new BeamRecord(genRowType(), 1L, 1, 1.0); + private static final Row ROW1 = + Row + .withRowType(genRowType()) + .addValues(1L, 1, 1.0) + .build(); - row2 = new BeamRecord(genRowType(), 2L, 2, 2.0); - } + private static final Row ROW2 = + Row.withRowType(genRowType()) + .addValues(2L, 2, 2.0) + .build(); @Test public void testCsvRecorderDecoder() throws Exception { - PCollection result = pipeline + PCollection result = pipeline .apply( Create.of("1,\"1\",1.0", "2,2,2.0") ) @@ -63,15 +62,15 @@ public static void setUp() { new BeamKafkaCSVTable.CsvRecorderDecoder(genRowType(), CSVFormat.DEFAULT) ); - PAssert.that(result).containsInAnyOrder(row1, row2); + PAssert.that(result).containsInAnyOrder(ROW1, ROW2); pipeline.run(); } @Test public void testCsvRecorderEncoder() throws Exception { - PCollection result = pipeline + PCollection result = pipeline .apply( - Create.of(row1, row2) + Create.of(ROW1, ROW2) ) .apply( new BeamKafkaCSVTable.CsvRecorderEncoder(genRowType(), CSVFormat.DEFAULT) @@ -79,21 +78,18 @@ public static void setUp() { new BeamKafkaCSVTable.CsvRecorderDecoder(genRowType(), CSVFormat.DEFAULT) ); - PAssert.that(result).containsInAnyOrder(row1, row2); + PAssert.that(result).containsInAnyOrder(ROW1, ROW2); pipeline.run(); } - private static BeamRecordSqlType genRowType() { + private static RowType genRowType() { return CalciteUtils.toBeamRowType( - ((RelProtoDataType) - a0 -> - a0.builder() - .add("order_id", SqlTypeName.BIGINT) - .add("site_id", SqlTypeName.INTEGER) - .add("price", SqlTypeName.DOUBLE) - .build()) - .apply(BeamQueryPlanner.TYPE_FACTORY)); + BeamQueryPlanner.TYPE_FACTORY.builder() + .add("order_id", SqlTypeName.BIGINT) + .add("site_id", SqlTypeName.INTEGER) + .add("price", SqlTypeName.DOUBLE) + .build()); } private static class String2KvBytes extends DoFn> diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProviderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProviderTest.java index a7c27193798b..10ed18a77cd7 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProviderTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTableProviderTest.java @@ -17,6 +17,8 @@ */ package 
org.apache.beam.sdk.extensions.sql.meta.provider.kafka; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.INTEGER; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.VARCHAR; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -25,7 +27,6 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.ImmutableList; import java.net.URI; -import java.sql.Types; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.meta.Column; import org.apache.beam.sdk.extensions.sql.meta.Table; @@ -66,8 +67,8 @@ private static Table mockTable(String name) { .comment(name + " table") .location(URI.create("kafka://localhost:2181/brokers?topic=test")) .columns(ImmutableList.of( - Column.builder().name("id").type(Types.INTEGER).primaryKey(true).build(), - Column.builder().name("name").type(Types.VARCHAR).primaryKey(false).build() + Column.builder().name("id").coder(INTEGER).primaryKey(true).build(), + Column.builder().name("name").coder(VARCHAR).primaryKey(false).build() )) .type("kafka") .properties(properties) diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableTest.java index 39474d4bdc3c..42a941067871 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/BeamTextCSVTableTest.java @@ -28,19 +28,14 @@ import java.nio.file.Path; import java.nio.file.SimpleFileVisitor; import java.nio.file.attribute.BasicFileAttributes; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; -import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner; -import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelProtoDataType; -import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVPrinter; import org.junit.AfterClass; @@ -64,36 +59,46 @@ public class BeamTextCSVTableTest { * integer,bigint,float,double,string *

    */ + private static RowType rowType = + RowSqlType + .builder() + .withIntegerField("id") + .withBigIntField("order_id") + .withFloatField("price") + .withDoubleField("amount") + .withVarcharField("user_name") + .build(); private static Object[] data1 = new Object[] { 1, 1L, 1.1F, 1.1, "james" }; private static Object[] data2 = new Object[] { 2, 2L, 2.2F, 2.2, "bond" }; private static List testData = Arrays.asList(data1, data2); - private static List testDataRows = new ArrayList() {{ - for (Object[] data : testData) { - add(buildRow(data)); - } - }}; + private static List testDataRows = Arrays.asList( + Row.withRowType(rowType).addValues(data1).build(), + Row.withRowType(rowType).addValues(data2).build()); private static Path tempFolder; private static File readerSourceFile; private static File writerTargetFile; @Test public void testBuildIOReader() { - PCollection rows = new BeamTextCSVTable(buildBeamSqlRowType(), - readerSourceFile.getAbsolutePath()).buildIOReader(pipeline); + PCollection rows = + new BeamTextCSVTable(rowType, readerSourceFile.getAbsolutePath()) + .buildIOReader(pipeline); PAssert.that(rows).containsInAnyOrder(testDataRows); pipeline.run(); } @Test public void testBuildIOWriter() { - new BeamTextCSVTable(buildBeamSqlRowType(), - readerSourceFile.getAbsolutePath()).buildIOReader(pipeline) - .apply(new BeamTextCSVTable(buildBeamSqlRowType(), writerTargetFile.getAbsolutePath()) - .buildIOWriter()); + new BeamTextCSVTable(rowType, readerSourceFile.getAbsolutePath()) + .buildIOReader(pipeline) + .apply( + new BeamTextCSVTable(rowType, writerTargetFile.getAbsolutePath()) + .buildIOWriter()); pipeline.run(); - PCollection rows = new BeamTextCSVTable(buildBeamSqlRowType(), - writerTargetFile.getAbsolutePath()).buildIOReader(pipeline2); + PCollection rows = + new BeamTextCSVTable(rowType, writerTargetFile.getAbsolutePath()) + .buildIOReader(pipeline2); // confirm the two reads match PAssert.that(rows).containsInAnyOrder(testDataRows); @@ -147,29 +152,4 @@ private static void writeToStreamAndClose(List rows, OutputStream outp e.printStackTrace(); } } - - private RelProtoDataType buildRowType() { - return a0 -> - a0.builder() - .add("id", SqlTypeName.INTEGER) - .add("order_id", SqlTypeName.BIGINT) - .add("price", SqlTypeName.FLOAT) - .add("amount", SqlTypeName.DOUBLE) - .add("user_name", SqlTypeName.VARCHAR) - .build(); - } - - private static RelDataType buildRelDataType() { - return BeamQueryPlanner.TYPE_FACTORY.builder().add("id", SqlTypeName.INTEGER) - .add("order_id", SqlTypeName.BIGINT).add("price", SqlTypeName.FLOAT) - .add("amount", SqlTypeName.DOUBLE).add("user_name", SqlTypeName.VARCHAR).build(); - } - - private static BeamRecordSqlType buildBeamSqlRowType() { - return CalciteUtils.toBeamRowType(buildRelDataType()); - } - - private static BeamRecord buildRow(Object[] data) { - return new BeamRecord(buildBeamSqlRowType(), Arrays.asList(data)); - } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProviderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProviderTest.java index 86edd47a9714..030c785ca8f9 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProviderTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/text/TextTableProviderTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.extensions.sql.meta.provider.text; 
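
The table providers and the meta store in these hunks describe each column with a SqlTypeCoder through Column.builder().coder(...), instead of a java.sql.Types constant, so the column definitions line up with the RowType schemas used elsewhere in the tests. A minimal sketch, assuming the static imports of SqlTypeCoders.INTEGER and SqlTypeCoders.VARCHAR added in these files; the equivalent RowType is shown for comparison:

    List<Column> columns = ImmutableList.of(
        Column.builder().name("id").coder(INTEGER).primaryKey(true).build(),
        Column.builder().name("name").coder(VARCHAR).primaryKey(false).build());

    // The same two-column schema expressed as a RowType.
    RowType schema =
        RowSqlType
            .builder()
            .withIntegerField("id")
            .withVarcharField("name")
            .build();
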
+import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.INTEGER; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.VARCHAR; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -24,7 +26,6 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.ImmutableList; import java.net.URI; -import java.sql.Types; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; import org.apache.beam.sdk.extensions.sql.meta.Column; import org.apache.beam.sdk.extensions.sql.meta.Table; @@ -77,8 +78,8 @@ private static Table mockTable(String name, String format) { .comment(name + " table") .location(URI.create("text://home/admin/" + name)) .columns(ImmutableList.of( - Column.builder().name("id").type(Types.INTEGER).primaryKey(true).build(), - Column.builder().name("name").type(Types.VARCHAR).primaryKey(false).build() + Column.builder().name("id").coder(INTEGER).primaryKey(true).build(), + Column.builder().name("name").coder(VARCHAR).primaryKey(false).build() )) .type("text") .properties(properties) diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStoreTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStoreTest.java index 2be5e8a4c8e5..e86ba54e3abb 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStoreTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/store/InMemoryMetaStoreTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.extensions.sql.meta.store; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.INTEGER; +import static org.apache.beam.sdk.extensions.sql.SqlTypeCoders.VARCHAR; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -25,11 +27,10 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.ImmutableList; import java.net.URI; -import java.sql.Types; import java.util.ArrayList; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.BeamSqlTable; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.extensions.sql.meta.Column; import org.apache.beam.sdk.extensions.sql.meta.Table; import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider; @@ -91,10 +92,7 @@ public void testGetTable_nullName() throws Exception { BeamSqlTable actualSqlTable = store.buildBeamSqlTable("hello"); assertNotNull(actualSqlTable); assertEquals( - BeamRecordSqlType.create( - ImmutableList.of("id", "name"), - ImmutableList.of(Types.INTEGER, Types.VARCHAR) - ), + RowSqlType.builder().withIntegerField("id").withVarcharField("name").build(), actualSqlTable.getRowType() ); } @@ -133,8 +131,8 @@ private static Table mockTable(String name, String type) { .comment(name + " table") .location(URI.create("text://home/admin/" + name)) .columns(ImmutableList.of( - Column.builder().name("id").type(Types.INTEGER).primaryKey(true).build(), - Column.builder().name("name").type(Types.VARCHAR).primaryKey(false).build() + Column.builder().name("id").coder(INTEGER).primaryKey(true).build(), + Column.builder().name("name").coder(VARCHAR).primaryKey(false).build() )) .type(type) .properties(new JSONObject()) @@ -165,6 +163,10 @@ public MockTableProvider(String type, String... 
names) { } + @Override public void dropTable(String tableName) { + + } + @Override public List
    listTables() { List
    ret = new ArrayList<>(names.length); for (String name : names) { diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedBoundedTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedBoundedTable.java index cf66268b13e8..ac4f2df0f95c 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedBoundedTable.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedBoundedTable.java @@ -25,28 +25,28 @@ import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.impl.schema.BeamIOType; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** * Mocked table for bounded data sources. */ public class MockedBoundedTable extends MockedTable { /** rows written to this table. */ - private static final ConcurrentLinkedQueue CONTENT = new ConcurrentLinkedQueue<>(); + private static final ConcurrentLinkedQueue CONTENT = new ConcurrentLinkedQueue<>(); /** rows flow out from this table. */ - private final List rows = new ArrayList<>(); + private final List rows = new ArrayList<>(); - public MockedBoundedTable(BeamRecordSqlType beamSqlRowType) { - super(beamSqlRowType); + public MockedBoundedTable(RowType beamRowType) { + super(beamRowType); } /** @@ -69,7 +69,7 @@ public static MockedBoundedTable of(final Object... args){ /** * Build a mocked bounded table with the specified type. */ - public static MockedBoundedTable of(final BeamRecordSqlType type) { + public static MockedBoundedTable of(final RowType type) { return new MockedBoundedTable(type); } @@ -88,7 +88,7 @@ public static MockedBoundedTable of(final BeamRecordSqlType type) { * } */ public MockedBoundedTable addRows(Object... args) { - List rows = buildRows(getRowType(), Arrays.asList(args)); + List rows = buildRows(getRowType(), Arrays.asList(args)); this.rows.addAll(rows); return this; } @@ -99,12 +99,12 @@ public BeamIOType getSourceType() { } @Override - public PCollection buildIOReader(Pipeline pipeline) { + public PCollection buildIOReader(Pipeline pipeline) { return PBegin.in(pipeline).apply( "MockedBoundedTable_Reader_" + COUNTER.incrementAndGet(), Create.of(rows)); } - @Override public PTransform, PDone> buildIOWriter() { + @Override public PTransform, PDone> buildIOWriter() { return new OutputStore(); } @@ -112,11 +112,11 @@ public PCollection buildIOReader(Pipeline pipeline) { * Keep output in {@code CONTENT} for validation. 
* */ - public static class OutputStore extends PTransform, PDone> { + public static class OutputStore extends PTransform, PDone> { @Override - public PDone expand(PCollection input) { - input.apply(ParDo.of(new DoFn() { + public PDone expand(PCollection input) { + input.apply(ParDo.of(new DoFn() { @ProcessElement public void processElement(ProcessContext c) { CONTENT.add(c.element()); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedTable.java index d661866bcd4e..adafee3483e5 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedTable.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedTable.java @@ -19,24 +19,24 @@ package org.apache.beam.sdk.extensions.sql.mock; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** * Base class for mocked table. */ public abstract class MockedTable extends BaseBeamTable { public static final AtomicInteger COUNTER = new AtomicInteger(); - public MockedTable(BeamRecordSqlType beamSqlRowType) { - super(beamSqlRowType); + public MockedTable(RowType beamRowType) { + super(beamRowType); } @Override - public PTransform, PDone> buildIOWriter() { + public PTransform, PDone> buildIOWriter() { throw new UnsupportedOperationException("buildIOWriter unsupported!"); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedUnboundedTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedUnboundedTable.java index 2e4790be587b..91ad4babdf6d 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedUnboundedTable.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/mock/MockedUnboundedTable.java @@ -22,12 +22,12 @@ import java.util.Arrays; import java.util.List; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; import org.apache.beam.sdk.extensions.sql.TestUtils; import org.apache.beam.sdk.extensions.sql.impl.schema.BeamIOType; import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.calcite.util.Pair; import org.joda.time.Duration; @@ -38,11 +38,11 @@ */ public class MockedUnboundedTable extends MockedTable { /** rows flow out from this table with the specified watermark instant. */ - private final List>> timestampedRows = new ArrayList<>(); + private final List>> timestampedRows = new ArrayList<>(); /** specify the index of column in the row which stands for the event time field. 
*/ private int timestampField; - private MockedUnboundedTable(BeamRecordSqlType beamSqlRowType) { - super(beamSqlRowType); + private MockedUnboundedTable(RowType beamRowType) { + super(beamRowType); } /** @@ -82,7 +82,7 @@ public MockedUnboundedTable timestampColumnIndex(int idx) { * } */ public MockedUnboundedTable addRows(Duration duration, Object... args) { - List rows = TestUtils.buildRows(getRowType(), Arrays.asList(args)); + List rows = TestUtils.buildRows(getRowType(), Arrays.asList(args)); // record the watermark + rows this.timestampedRows.add(Pair.of(duration, rows)); return this; @@ -92,10 +92,10 @@ public MockedUnboundedTable addRows(Duration duration, Object... args) { return BeamIOType.UNBOUNDED; } - @Override public PCollection buildIOReader(Pipeline pipeline) { - TestStream.Builder values = TestStream.create(beamRecordSqlType.getRecordCoder()); + @Override public PCollection buildIOReader(Pipeline pipeline) { + TestStream.Builder values = TestStream.create(rowType.getRowCoder()); - for (Pair> pair : timestampedRows) { + for (Pair> pair : timestampedRows) { values = values.advanceWatermarkTo(new Instant(0).plus(pair.getKey())); for (int i = 0; i < pair.getValue().size(); i++) { values = values.addElements(TimestampedValue.of(pair.getValue().get(i), diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/BeamRecordAsserts.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java similarity index 63% rename from sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/BeamRecordAsserts.java rename to sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java index 6f2da2c84550..a79677a997e4 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/BeamRecordAsserts.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java @@ -23,35 +23,35 @@ import com.google.common.collect.Iterables; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** - * Contain helpers to assert {@link BeamRecord}s. + * Contain helpers to assert {@link Row}s. */ -public class BeamRecordAsserts { +public class RowAsserts { /** - * Asserts result contains single record with an int field. + * Asserts result contains single row with an int field. */ - public static SerializableFunction, Void> matchesScalar(int expected) { + public static SerializableFunction, Void> matchesScalar(int expected) { return records -> { - BeamRecord record = Iterables.getOnlyElement(records); - assertNotNull(record); - assertEquals(expected, (int) record.getInteger(0)); + Row row = Iterables.getOnlyElement(records); + assertNotNull(row); + assertEquals(expected, (int) row.getInteger(0)); return null; }; } /** - * Asserts result contains single record with a double field. + * Asserts result contains single row with a double field. 
*/ - public static SerializableFunction, Void> matchesScalar( + public static SerializableFunction, Void> matchesScalar( double expected, double delta) { return input -> { - BeamRecord record = Iterables.getOnlyElement(input); - assertNotNull(record); - assertEquals(expected, record.getDouble(0), delta); + Row row = Iterables.getOnlyElement(input); + assertNotNull(row); + assertEquals(expected, row.getDouble(0), delta); return null; }; } diff --git a/sdks/java/fn-execution/pom.xml b/sdks/java/fn-execution/pom.xml index 52a9b98f3a02..0535f7e02289 100644 --- a/sdks/java/fn-execution/pom.xml +++ b/sdks/java/fn-execution/pom.xml @@ -111,16 +111,22 @@ junit test - + + + org.hamcrest + hamcrest-core + test + + org.hamcrest - hamcrest-all + hamcrest-library test org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java index 35abc4c2a4fc..4ddd5127f939 100644 --- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java +++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java @@ -126,7 +126,8 @@ public int read(byte[] b, int off, int len) throws IOException { * first access to {@link #next()} or {@link #hasNext()}. */ public static class DataStreamDecoder implements Iterator { - private enum State { READ_REQUIRED, HAS_NEXT, EOF }; + + private enum State { READ_REQUIRED, HAS_NEXT, EOF } private final CountingInputStream countingInputStream; private final PushbackInputStream pushbackInputStream; diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/SynchronizedStreamObserver.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/SynchronizedStreamObserver.java new file mode 100644 index 000000000000..62bd46aa228f --- /dev/null +++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/SynchronizedStreamObserver.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.fn.stream; + +import io.grpc.stub.StreamObserver; + +/** + * A {@link StreamObserver} which provides synchronous access access to an underlying {@link + * StreamObserver}. + * + *
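As a usage note for the renamed RowAsserts helpers above: the matchers are SerializableFunctions over the materialized rows and are meant to be handed to PAssert.that(...).satisfies(...). A brief sketch, with the expected values and the surrounding class purely illustrative:

    import org.apache.beam.sdk.extensions.sql.utils.RowAsserts;
    import org.apache.beam.sdk.testing.PAssert;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.Row;

    class RowAssertsUsageSketch {
      // Check single-row aggregation results inside the pipeline under test.
      static void assertResults(PCollection<Row> sumResult, PCollection<Row> meanResult) {
        PAssert.that(sumResult).satisfies(RowAsserts.matchesScalar(42));
        PAssert.that(meanResult).satisfies(RowAsserts.matchesScalar(3.5, 0.001));
      }
    }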

    The underlying {@link StreamObserver} must not be used by any other clients. + */ +public class SynchronizedStreamObserver implements StreamObserver { + private final StreamObserver underlying; + + private SynchronizedStreamObserver(StreamObserver underlying) { + this.underlying = underlying; + } + + /** + * Create a new {@link SynchronizedStreamObserver} which will delegate all calls to the underlying + * {@link StreamObserver}, synchronizing access to that observer. + */ + public static StreamObserver wrapping(StreamObserver underlying) { + return new SynchronizedStreamObserver<>(underlying); + } + + @Override + public void onNext(V value) { + synchronized (underlying) { + underlying.onNext(value); + } + } + + @Override + public synchronized void onError(Throwable t) { + synchronized (underlying) { + underlying.onError(t); + } + } + + @Override + public synchronized void onCompleted() { + synchronized (underlying) { + underlying.onCompleted(); + } + } +} diff --git a/sdks/java/harness/pom.xml b/sdks/java/harness/pom.xml index f82c1600414e..6d830174995f 100644 --- a/sdks/java/harness/pom.xml +++ b/sdks/java/harness/pom.xml @@ -266,10 +266,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + junit junit @@ -278,7 +284,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index cf3a227d49be..f7dcb650b4b0 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -512,7 +512,7 @@ public DoFn.OnTimerContext onTimerContext(DoFn } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException( "Cannot access RestrictionTracker outside of @ProcessElement method."); } @@ -569,7 +569,7 @@ public OnTimerContext onTimerContext(DoFn doFn) { } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); } @@ -728,7 +728,7 @@ public DoFn.OnTimerContext onTimerContext(DoFn } @Override - public RestrictionTracker restrictionTracker() { + public RestrictionTracker restrictionTracker() { throw new UnsupportedOperationException( "Cannot access RestrictionTracker outside of @ProcessElement method."); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunner.java new file mode 100644 index 000000000000..33b378786581 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunner.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
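A brief sketch of how the SynchronizedStreamObserver introduced above is meant to be used: producer threads share the synchronized wrapper instead of the raw gRPC observer, so calls on the underlying stream never interleave. Only wrapping(...) and io.grpc.stub.StreamObserver come from the code above; the thread setup is hypothetical.

    import io.grpc.stub.StreamObserver;
    import org.apache.beam.sdk.fn.stream.SynchronizedStreamObserver;

    class SharedObserverSketch {
      static void emitFromTwoThreads(StreamObserver<String> rawOutbound) throws InterruptedException {
        // Wrap once; every producer uses the wrapper, never rawOutbound directly.
        StreamObserver<String> shared = SynchronizedStreamObserver.wrapping(rawOutbound);
        Runnable producer = () -> {
          for (int i = 0; i < 10; i++) {
            shared.onNext(Thread.currentThread().getName() + "-" + i);
          }
        };
        Thread first = new Thread(producer, "producer-1");
        Thread second = new Thread(producer, "producer-2");
        first.start();
        second.start();
        first.join();
        second.join();
        // Complete exactly once, after all producers have finished.
        shared.onCompleted();
      }
    }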
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness; + +import static com.google.common.collect.Iterables.getOnlyElement; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import java.io.IOException; +import java.util.Collection; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver; +import org.apache.beam.fn.harness.fn.ThrowingFunction; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.harness.state.BeamFnStateClient; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; + +/** + * A {@code PTransformRunner} which executes simple map functions. + * + *

    Simple map functions are used in a large number of transforms, especially runner-managed + * transforms, such as map_windows. + * + *

    TODO: Add support for DoFns which are actually user supplied map/lambda functions instead + * of using the {@link FnApiDoFnRunner} instance. + */ +public class MapFnRunner { + + public static PTransformRunnerFactory + createMapFnRunnerFactoryWith( + CreateMapFunctionForPTransform fnFactory) { + return new Factory<>(fnFactory); + } + + /** A function factory which given a PTransform returns a map function. */ + public interface CreateMapFunctionForPTransform { + ThrowingFunction createMapFunctionForPTransform( + String ptransformId, + PTransform pTransform) throws IOException; + } + + /** A factory for {@link MapFnRunner}s. */ + static class Factory + implements PTransformRunnerFactory> { + + private final CreateMapFunctionForPTransform fnFactory; + + Factory(CreateMapFunctionForPTransform fnFactory) { + this.fnFactory = fnFactory; + } + + @Override + public MapFnRunner createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + BeamFnStateClient beamFnStateClient, + String pTransformId, + PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Map windowingStrategies, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + + Collection>> consumers = + (Collection) pCollectionIdsToConsumers.get( + getOnlyElement(pTransform.getOutputsMap().values())); + + MapFnRunner runner = new MapFnRunner<>( + fnFactory.createMapFunctionForPTransform(pTransformId, pTransform), + MultiplexingFnDataReceiver.forConsumers(consumers)); + + pCollectionIdsToConsumers.put( + Iterables.getOnlyElement(pTransform.getInputsMap().values()), + (FnDataReceiver) (FnDataReceiver>) runner::map); + return runner; + } + } + + private final ThrowingFunction mapFunction; + private final FnDataReceiver> consumer; + + MapFnRunner( + ThrowingFunction mapFunction, + FnDataReceiver> consumer) { + this.mapFunction = mapFunction; + this.consumer = consumer; + } + + public void map(WindowedValue element) throws Exception { + WindowedValue output = element.withValue(mapFunction.apply(element.getValue())); + consumer.accept(output); + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/WindowMappingFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/WindowMappingFnRunner.java new file mode 100644 index 000000000000..f6e1deffd434 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/WindowMappingFnRunner.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
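To make the element-level behavior of MapFnRunner concrete, a small sketch of what map(...) does for one WindowedValue, using a toy uppercase function like the one in MapFnRunnerTest later in this change; ThrowingFunction, FnDataReceiver and WindowedValue are the interfaces referenced above, while the class and variable names here are illustrative.

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.beam.fn.harness.fn.ThrowingFunction;
    import org.apache.beam.sdk.fn.data.FnDataReceiver;
    import org.apache.beam.sdk.util.WindowedValue;

    class MapFnRunnerSketch {
      static void demo() throws Exception {
        List<WindowedValue<String>> outputs = new ArrayList<>();
        FnDataReceiver<WindowedValue<String>> downstream = outputs::add;
        ThrowingFunction<String, String> mapFunction = String::toUpperCase;

        WindowedValue<String> element = WindowedValue.valueInGlobalWindow("abc");
        // Same shape as MapFnRunner.map: keep window/timestamp/pane, replace only the value.
        WindowedValue<String> output = element.withValue(mapFunction.apply(element.getValue()));
        downstream.accept(output);
        // outputs now holds valueInGlobalWindow("ABC"), matching MapFnRunnerTest's expectation.
      }
    }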
+ */ +package org.apache.beam.fn.harness; + +import static org.apache.beam.runners.core.construction.UrnUtils.validateCommonUrn; + +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.util.Map; +import org.apache.beam.fn.harness.fn.ThrowingFunction; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec; +import org.apache.beam.runners.core.construction.PCollectionViewTranslation; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; +import org.apache.beam.sdk.values.KV; + +/** + * Maps windows using a window mapping fn. The input is {@link KV} with the key being a nonce + * and the value being a window, the output must be a {@link KV} with the key being the same nonce + * as the input and the value being the mapped window. + */ +public class WindowMappingFnRunner { + static final String URN = validateCommonUrn("beam:transform:map_windows:v1"); + + /** + * A registrar which provides a factory to handle mapping main input windows onto side input + * windows. + */ + @AutoService(PTransformRunnerFactory.Registrar.class) + public static class Registrar implements PTransformRunnerFactory.Registrar { + + @Override + public Map getPTransformRunnerFactories() { + return ImmutableMap.of(URN, MapFnRunner.createMapFnRunnerFactoryWith( + WindowMappingFnRunner::createMapFunctionForPTransform)); + } + } + + static + ThrowingFunction, KV> createMapFunctionForPTransform( + String ptransformId, PTransform pTransform) throws IOException { + SdkFunctionSpec windowMappingFnPayload = + SdkFunctionSpec.parseFrom(pTransform.getSpec().getPayload()); + WindowMappingFn windowMappingFn = + (WindowMappingFn) PCollectionViewTranslation.windowMappingFnFromProto( + windowMappingFnPayload); + return (KV input) -> + KV.of(input.getKey(), windowMappingFn.getSideInputWindow(input.getValue())); + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserver.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserver.java index 2b67bee11fc4..51856c3a8c5d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserver.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserver.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.BufferedElementCountingOutputStream; import org.apache.beam.sdk.util.WindowedValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,7 +74,8 @@ public BeamFnDataBufferingOutboundObserver( this.outputLocation = outputLocation; this.coder = coder; this.outboundObserver = outboundObserver; - this.bufferedElements = ByteString.newOutput(); + this.bufferedElements = ByteString.newOutput( + BufferedElementCountingOutputStream.DEFAULT_BUFFER_SIZE); } /** diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 683314afb011..014695f07bad 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -59,7 +59,9 @@ public static Iterator forFirstChunk( * pre-fetch any future chunks and blocks whenever required to fetch the next block. */ static class LazyBlockingStateFetchingIterator implements Iterator { - private enum State { READ_REQUIRED, HAS_NEXT, EOF }; + + private enum State { READ_REQUIRED, HAS_NEXT, EOF } + private final BeamFnStateClient beamFnStateClient; private final StateRequest stateRequestForFirstChunk; private State currentState; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java index e0910582d22e..8f5a592cc5b8 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java @@ -107,6 +107,7 @@ public class BeamFnDataReadRunnerTest { throw new ExceptionInInitializerError(e); } } + private static final BeamFnApi.Target INPUT_TARGET = BeamFnApi.Target.newBuilder() .setPrimitiveTransformReference("1") .setName("out") diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index aaf2b3fc9a37..68d42894ea12 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -102,6 +102,7 @@ public class BeamFnDataWriteRunnerTest { throw new ExceptionInInitializerError(e); } } + private static final BeamFnApi.Target OUTPUT_TARGET = BeamFnApi.Target.newBuilder() .setPrimitiveTransformReference("1") .setName("out") diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnerTest.java new file mode 100644 index 000000000000..09b9b6b72970 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnerTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.fn.harness; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.apache.beam.fn.harness.fn.ThrowingFunction; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.WindowedValue; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link MapFnRunner}. */ +@RunWith(JUnit4.class) +public class MapFnRunnerTest { + private static final String EXPECTED_ID = "pTransformId"; + private static final RunnerApi.PTransform EXPECTED_PTRANSFORM = RunnerApi.PTransform.newBuilder() + .putInputs("input", "inputPC") + .putOutputs("output", "outputPC") + .build(); + + @Test + public void testWindowMapping() throws Exception { + + List> outputConsumer = new ArrayList<>(); + Multimap>> consumers = HashMultimap.create(); + consumers.put("outputPC", outputConsumer::add); + + List startFunctions = new ArrayList<>(); + List finishFunctions = new ArrayList<>(); + + new MapFnRunner.Factory<>(this::createMapFunctionForPTransform) + .createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + null /* beamFnStateClient */, + EXPECTED_ID, + EXPECTED_PTRANSFORM, + Suppliers.ofInstance("57L")::get, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + consumers, + startFunctions::add, + finishFunctions::add); + + assertThat(startFunctions, empty()); + assertThat(finishFunctions, empty()); + + assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); + + Iterables.getOnlyElement( + consumers.get("inputPC")).accept(valueInGlobalWindow("abc")); + + assertThat(outputConsumer, contains(valueInGlobalWindow("ABC"))); + } + + public ThrowingFunction createMapFunctionForPTransform(String ptransformId, + PTransform pTransform) throws IOException { + assertEquals(EXPECTED_ID, ptransformId); + assertEquals(EXPECTED_PTRANSFORM, pTransform); + return (String str) -> str.toUpperCase(); + } +} diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/WindowMappingFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/WindowMappingFnRunnerTest.java new file mode 100644 index 000000000000..1aa6e15ceac3 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/WindowMappingFnRunnerTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness; + +import static org.junit.Assert.assertEquals; + +import org.apache.beam.fn.harness.fn.ThrowingFunction; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.core.construction.SdkComponents; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link WindowMappingFnRunner}. */ +@RunWith(JUnit4.class) +public class WindowMappingFnRunnerTest { + @Test + public void testWindowMapping() throws Exception { + String pTransformId = "pTransformId"; + + RunnerApi.FunctionSpec functionSpec = + RunnerApi.FunctionSpec.newBuilder() + .setUrn(WindowMappingFnRunner.URN) + .setPayload( + ParDoTranslation.translateWindowMappingFn( + new GlobalWindows().getDefaultWindowMappingFn(), + SdkComponents.create() + ).toByteString()) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .build(); + + + ThrowingFunction, KV> mapFunction = + WindowMappingFnRunner.createMapFunctionForPTransform(pTransformId, pTransform); + + KV input = + KV.of("abc", new IntervalWindow(Instant.now(), Duration.standardMinutes(1))); + + assertEquals( + KV.of(input.getKey(), GlobalWindow.INSTANCE), + mapFunction.apply(input)); + } +} diff --git a/sdks/java/io/amazon-web-services/pom.xml b/sdks/java/io/amazon-web-services/pom.xml index c68fb2b3bb33..5fdf6a71b17d 100644 --- a/sdks/java/io/amazon-web-services/pom.xml +++ b/sdks/java/io/amazon-web-services/pom.xml @@ -144,13 +144,19 @@ org.hamcrest - hamcrest-all + hamcrest-core provided + + org.hamcrest + hamcrest-library + provided + + org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java index 5adf42a11749..05fdda877186 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java +++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java @@ -27,6 +27,7 @@ import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.amazonaws.services.s3.model.AmazonS3Exception; import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest; +import com.amazonaws.services.s3.model.CopyObjectRequest; import com.amazonaws.services.s3.model.CopyPartRequest; import com.amazonaws.services.s3.model.CopyPartResult; import com.amazonaws.services.s3.model.DeleteObjectsRequest; @@ -47,8 +48,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import 
com.google.common.collect.Multimap; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -64,6 +63,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Callable; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -73,6 +73,7 @@ import org.apache.beam.sdk.io.aws.options.S3Options; import org.apache.beam.sdk.io.fs.CreateOptions; import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.util.MoreFutures; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,6 +87,9 @@ class S3FileSystem extends FileSystem { Runtime.getRuntime().maxMemory() < 512 * 1024 * 1024 ? MINIMUM_UPLOAD_BUFFER_SIZE_BYTES : 64 * 1024 * 1024; + // Amazon S3 API: You can create a copy of your object up to 5 GB in a single atomic operation + // Ref. https://docs.aws.amazon.com/AmazonS3/latest/dev/CopyingObjectsExamples.html + private static final int MAX_COPY_OBJECT_SIZE_BYTES = 5 * 1024 * 1024 * 1024; // S3 API, delete-objects: "You may specify up to 1000 keys." private static final int MAX_DELETE_OBJECTS_PER_REQUEST = 1000; @@ -475,8 +479,7 @@ protected ReadableByteChannel open(S3ResourceId resourceId) throws IOException { } @Override - protected void copy( - List sourcePaths, List destinationPaths) + protected void copy(List sourcePaths, List destinationPaths) throws IOException { checkArgument( sourcePaths.size() == destinationPaths.size(), @@ -502,28 +505,47 @@ protected void copy( @VisibleForTesting void copy(S3ResourceId sourcePath, S3ResourceId destinationPath) throws IOException { - String uploadId; - long objectSize; try { ObjectMetadata objectMetadata = amazonS3.getObjectMetadata(sourcePath.getBucket(), sourcePath.getKey()); - objectSize = objectMetadata.getContentLength(); - - InitiateMultipartUploadRequest initiateUploadRequest = - new InitiateMultipartUploadRequest(destinationPath.getBucket(), destinationPath.getKey()) - .withStorageClass(storageClass) - .withObjectMetadata(objectMetadata); - - InitiateMultipartUploadResult initiateUploadResult = - amazonS3.initiateMultipartUpload(initiateUploadRequest); - uploadId = initiateUploadResult.getUploadId(); - + if (objectMetadata.getContentLength() < MAX_COPY_OBJECT_SIZE_BYTES) { + atomicCopy(sourcePath, destinationPath); + } else { + multipartCopy(sourcePath, destinationPath, objectMetadata); + } } catch (AmazonClientException e) { throw new IOException(e); } + } - List eTags = new ArrayList<>(); + private void atomicCopy(S3ResourceId sourcePath, S3ResourceId destinationPath) + throws AmazonClientException { + CopyObjectRequest copyObjectRequest = + new CopyObjectRequest( + sourcePath.getBucket(), + sourcePath.getKey(), + destinationPath.getBucket(), + destinationPath.getKey()); + copyObjectRequest.setStorageClass(storageClass); + + amazonS3.copyObject(copyObjectRequest); + } + + private void multipartCopy( + S3ResourceId sourcePath, S3ResourceId destinationPath, ObjectMetadata objectMetadata) + throws AmazonClientException { + InitiateMultipartUploadRequest initiateUploadRequest = + new InitiateMultipartUploadRequest(destinationPath.getBucket(), destinationPath.getKey()) + .withStorageClass(storageClass) + .withObjectMetadata(objectMetadata); + + 
InitiateMultipartUploadResult initiateUploadResult = + amazonS3.initiateMultipartUpload(initiateUploadRequest); + String uploadId = initiateUploadResult.getUploadId(); + final long objectSize = objectMetadata.getContentLength(); + + List eTags = new ArrayList<>(); long bytePosition = 0; // Amazon parts are 1-indexed, not zero-indexed. @@ -539,12 +561,7 @@ void copy(S3ResourceId sourcePath, S3ResourceId destinationPath) throws IOExcept .withFirstByte(bytePosition) .withLastByte(Math.min(objectSize - 1, bytePosition + s3UploadBufferSizeBytes - 1)); - CopyPartResult copyPartResult; - try { - copyPartResult = amazonS3.copyPart(copyPartRequest); - } catch (AmazonClientException e) { - throw new IOException(e); - } + CopyPartResult copyPartResult = amazonS3.copyPart(copyPartRequest); eTags.add(copyPartResult.getPartETag()); bytePosition += s3UploadBufferSizeBytes; @@ -556,12 +573,7 @@ void copy(S3ResourceId sourcePath, S3ResourceId destinationPath) throws IOExcept .withKey(destinationPath.getKey()) .withUploadId(uploadId) .withPartETags(eTags); - - try { - amazonS3.completeMultipartUpload(completeUploadRequest); - } catch (AmazonClientException e) { - throw new IOException(e); - } + amazonS3.completeMultipartUpload(completeUploadRequest); } @Override @@ -637,11 +649,11 @@ protected S3ResourceId matchNewResource(String singleResourceSpec, boolean isDir private List callTasks(Collection> tasks) throws IOException { try { - List> futures = new ArrayList<>(tasks.size()); + List> futures = new ArrayList<>(tasks.size()); for (Callable task : tasks) { - futures.add(executorService.submit(task)); + futures.add(MoreFutures.supplyAsync(() -> task.call(), executorService)); } - return Futures.allAsList(futures).get(); + return MoreFutures.get(MoreFutures.allAsList(futures)); } catch (ExecutionException e) { if (e.getCause() != null) { diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java index 931e9d0a7478..d39f26ae4518 100644 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java @@ -33,6 +33,8 @@ import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.model.AmazonS3Exception; import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest; +import com.amazonaws.services.s3.model.CopyObjectRequest; +import com.amazonaws.services.s3.model.CopyObjectResult; import com.amazonaws.services.s3.model.CopyPartRequest; import com.amazonaws.services.s3.model.CopyPartResult; import com.amazonaws.services.s3.model.DeleteObjectsRequest; @@ -100,6 +102,22 @@ public void testCopyMultipleParts() throws IOException { AmazonS3 mockAmazonS3 = Mockito.mock(AmazonS3.class); s3FileSystem.setAmazonS3Client(mockAmazonS3); + // Test atomic copy + S3ResourceId smallSourcePath = S3ResourceId.fromUri("s3://bucket/small"); + S3ResourceId smallDestinationPath = S3ResourceId.fromUri("s3://bucket/tosmall"); + + ObjectMetadata smallobjectMetadata = new ObjectMetadata(); + smallobjectMetadata.setContentLength(0); + when(mockAmazonS3.getObjectMetadata(smallSourcePath.getBucket(), smallSourcePath.getKey())) + .thenReturn(smallobjectMetadata); + + s3FileSystem.copy(smallSourcePath, smallDestinationPath); + + CopyObjectResult copyObjectResult = new CopyObjectResult(); + 
when(mockAmazonS3.copyObject(argThat(notNullValue(CopyObjectRequest.class)))) + .thenReturn(copyObjectResult); + + // Test multi-part copy S3ResourceId sourcePath = S3ResourceId.fromUri("s3://bucket/from"); S3ResourceId destinationPath = S3ResourceId.fromUri("s3://bucket/to"); @@ -107,12 +125,11 @@ public void testCopyMultipleParts() throws IOException { new InitiateMultipartUploadResult(); initiateMultipartUploadResult.setUploadId("upload-id"); when(mockAmazonS3.initiateMultipartUpload( - argThat(notNullValue(InitiateMultipartUploadRequest.class)))) + argThat(notNullValue(InitiateMultipartUploadRequest.class)))) .thenReturn(initiateMultipartUploadResult); ObjectMetadata sourceS3ObjectMetadata = new ObjectMetadata(); - sourceS3ObjectMetadata - .setContentLength((long) (s3FileSystem.getS3UploadBufferSizeBytes() * 1.5)); + sourceS3ObjectMetadata.setContentLength((long) 5 * 1024 * 1024 * 1024); sourceS3ObjectMetadata.setContentEncoding("read-seek-efficient"); when(mockAmazonS3.getObjectMetadata(sourcePath.getBucket(), sourcePath.getKey())) .thenReturn(sourceS3ObjectMetadata); diff --git a/sdks/java/io/amqp/pom.xml b/sdks/java/io/amqp/pom.xml index 2654ccabc9bb..90cb4f13a7d9 100644 --- a/sdks/java/io/amqp/pom.xml +++ b/sdks/java/io/amqp/pom.xml @@ -92,7 +92,12 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test @@ -120,4 +125,4 @@ - \ No newline at end of file + diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java index 5a552600168f..7b6335c66a28 100644 --- a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java @@ -18,12 +18,10 @@ package org.apache.beam.sdk.io.amqp; import com.google.common.io.ByteStreams; - import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.BufferOverflowException; - import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.util.VarInt; diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java index bc3445cf9781..6a85f28e2a0f 100644 --- a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java @@ -19,9 +19,7 @@ import com.google.auto.service.AutoService; import com.google.common.collect.ImmutableList; - import java.util.List; - import org.apache.beam.sdk.coders.CoderProvider; import org.apache.beam.sdk.coders.CoderProviderRegistrar; import org.apache.beam.sdk.coders.CoderProviders; diff --git a/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java b/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java index 7a8efeb61c1c..f02a1b54b68d 100644 --- a/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java +++ b/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java @@ -20,9 +20,7 @@ import static org.junit.Assert.assertEquals; import com.google.common.base.Joiner; - import java.util.Collections; - import org.apache.beam.sdk.coders.CoderException; import 
org.apache.beam.sdk.util.CoderUtils; import org.apache.qpid.proton.amqp.messaging.AmqpValue; diff --git a/sdks/java/io/cassandra/pom.xml b/sdks/java/io/cassandra/pom.xml index cb042df3a78a..7aa49119c257 100644 --- a/sdks/java/io/cassandra/pom.xml +++ b/sdks/java/io/cassandra/pom.xml @@ -84,7 +84,7 @@ org.hamcrest - hamcrest-all + hamcrest-library test @@ -99,7 +99,7 @@ org.mockito - mockito-all + mockito-core test @@ -110,4 +110,4 @@ - \ No newline at end of file + diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java index 183805e4b9d6..977b995d9931 100644 --- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java +++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java @@ -35,8 +35,6 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * An IO to read from Apache Cassandra. @@ -82,8 +80,6 @@ @Experimental(Experimental.Kind.SOURCE_SINK) public class CassandraIO { - private static final Logger LOG = LoggerFactory.getLogger(CassandraIO.class); - private CassandraIO() {} /** diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraService.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraService.java index 50717621fe94..4b6015e2f40e 100644 --- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraService.java +++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraService.java @@ -19,7 +19,6 @@ import java.io.Serializable; import java.util.List; - import org.apache.beam.sdk.io.BoundedSource; /** diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java index a196361cfe58..f6f8aa1e8fa8 100644 --- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java +++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java @@ -44,14 +44,10 @@ import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** Tests of {@link CassandraIO}. 
*/ public class CassandraIOTest implements Serializable { - private static final Logger LOG = LoggerFactory.getLogger(CassandraIOTest.class); - @Rule public transient TestPipeline pipeline = TestPipeline.create(); @Test diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java index 6a68e9057a5f..1b27dc2e5dc4 100644 --- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java +++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java @@ -22,11 +22,9 @@ import com.datastax.driver.core.Cluster; import com.datastax.driver.core.Metadata; - import java.math.BigInteger; import java.util.ArrayList; import java.util.List; - import org.junit.Test; import org.mockito.Mockito; import org.slf4j.Logger; diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraTestDataSet.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraTestDataSet.java index 461f5eac490d..55a3fc182ddb 100644 --- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraTestDataSet.java +++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraTestDataSet.java @@ -19,7 +19,6 @@ import com.datastax.driver.core.Cluster; import com.datastax.driver.core.Session; - import org.apache.beam.sdk.io.common.IOTestPipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.slf4j.Logger; diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/pom.xml b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/pom.xml index 1dd90275ae73..6fd24c49241f 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/pom.xml +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/pom.xml @@ -75,6 +75,22 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + analyze-only + + + + org.hamcrest:hamcrest-all:jar:${hamcrest.version} + + + + + org.apache.maven.plugins maven-surefire-plugin diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/pom.xml b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/pom.xml index 1187f22fb7bd..ec3ce91fb95d 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/pom.xml +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/pom.xml @@ -74,4 +74,4 @@ - \ No newline at end of file + diff --git a/sdks/java/io/elasticsearch-tests/pom.xml b/sdks/java/io/elasticsearch-tests/pom.xml index 0e42bcee8501..f30e1334dc97 100644 --- a/sdks/java/io/elasticsearch-tests/pom.xml +++ b/sdks/java/io/elasticsearch-tests/pom.xml @@ -84,7 +84,13 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + + org.hamcrest + hamcrest-library test diff --git a/sdks/java/io/elasticsearch/pom.xml b/sdks/java/io/elasticsearch/pom.xml index 9060041dc843..b4a2e4c8b891 100644 --- a/sdks/java/io/elasticsearch/pom.xml +++ b/sdks/java/io/elasticsearch/pom.xml @@ -104,4 +104,4 @@ - \ No newline at end of file + diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java index bd10e9bcb6ee..c309d666d9a0 100644 --- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java +++ 
b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java @@ -472,6 +472,7 @@ private BoundedElasticsearchSource(Read spec, @Nullable String shardPreference, this.numSlices = numSlices; this.sliceId = sliceId; } + @Override public List> split( long desiredBundleSizeBytes, PipelineOptions options) throws Exception { @@ -605,15 +606,12 @@ public boolean start() throws IOException { if (query == null) { query = "{\"query\": { \"match_all\": {} }}"; } - if (source.backendVersion == 5){ - //if there is more than one slice - if (source.numSlices != null && source.numSlices > 1){ - // add slice to the user query - String sliceQuery = String - .format("\"slice\": {\"id\": %s,\"max\": %s}", source.sliceId, - source.numSlices); - query = query.replaceFirst("\\{", "{" + sliceQuery + ","); - } + if (source.backendVersion == 5 && source.numSlices != null && source.numSlices > 1){ + //if there is more than one slice, add the slice to the user query + String sliceQuery = String + .format("\"slice\": {\"id\": %s,\"max\": %s}", source.sliceId, + source.numSlices); + query = query.replaceFirst("\\{", "{" + sliceQuery + ","); } Response response; String endPoint = @@ -861,6 +859,7 @@ public void closeClient() throws Exception { } } } + static int getBackendVersion(ConnectionConfiguration connectionConfiguration) { try (RestClient restClient = connectionConfiguration.createClient()) { Response response = restClient.performRequest("GET", ""); diff --git a/sdks/java/io/file-based-io-tests/pom.xml b/sdks/java/io/file-based-io-tests/pom.xml index f44c63de73f8..6f6d40a93fd1 100644 --- a/sdks/java/io/file-based-io-tests/pom.xml +++ b/sdks/java/io/file-based-io-tests/pom.xml @@ -326,9 +326,14 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + org.apache.beam beam-sdks-java-io-common diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml index 7aae03b90c77..9b783ee361dc 100644 --- a/sdks/java/io/google-cloud-platform/pom.xml +++ b/sdks/java/io/google-cloud-platform/pom.xml @@ -382,13 +382,19 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index a70f25d4d42e..a9835140ec7e 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -134,17 +134,19 @@ * into a custom type using a specified parse function, and by {@link #readTableRows} which parses * them into {@link TableRow}, which may be more convenient but has lower performance. * - *

    Both functions support reading either from a table or from the result of a query, via - * {@link TypedRead#from(String)} and {@link TypedRead#fromQuery} respectively. Exactly one - * of these must be specified. + *

    Both functions support reading either from a table or from the result of a query, via {@link + * TypedRead#from(String)} and {@link TypedRead#fromQuery} respectively. Exactly one of these must + * be specified. + * + *

    Example: Reading rows of a table as {@link TableRow}. * - * Example: Reading rows of a table as {@link TableRow}. *

    {@code
* PCollection<TableRow> weatherData = pipeline.apply(
      *     BigQueryIO.readTableRows().from("clouddataflow-readonly:samples.weather_stations"));
      * }
    * * Example: Reading rows of a table and parsing them into a custom type. + * *
    {@code
* PCollection<WeatherRecord> weatherData = pipeline.apply(
      *    BigQueryIO
    @@ -157,11 +159,12 @@
      *      .withCoder(SerializableCoder.of(WeatherRecord.class));
      * }
    * - *

    Note: When using {@link #read(SerializableFunction)}, you may sometimes need to use - * {@link TypedRead#withCoder(Coder)} to specify a {@link Coder} for the result type, if Beam - * fails to infer it automatically. + *

    Note: When using {@link #read(SerializableFunction)}, you may sometimes need to use {@link + * TypedRead#withCoder(Coder)} to specify a {@link Coder} for the result type, if Beam fails to + * infer it automatically. + * + *

    Example: Reading results of a query as {@link TableRow}. * - * Example: Reading results of a query as {@link TableRow}. *

    {@code
* PCollection<TableRow> meanTemperatureData = pipeline.apply(BigQueryIO.readTableRows()
      *     .fromQuery("SELECT year, mean_temp FROM [samples.weather_stations]"));
    @@ -169,23 +172,27 @@
      *
      * 

<h3>Writing</h3>

    * - *

    To write to a BigQuery table, apply a {@link BigQueryIO.Write} transformation. This consumes - * either a {@link PCollection} of {@link TableRow TableRows} as input when using {@link - * BigQueryIO#writeTableRows()} or of a user-defined type when using {@link BigQueryIO#write()}. - * When using a user-defined type, a function must be provided to turn this type into a {@link - * TableRow} using {@link BigQueryIO.Write#withFormatFunction(SerializableFunction)}. + *

    To write to a BigQuery table, apply a {@link BigQueryIO.Write} transformation. This consumes a + * {@link PCollection} of a user-defined type when using {@link BigQueryIO#write()} (recommended), + * or a {@link PCollection} of {@link TableRow TableRows} as input when using {@link + * BigQueryIO#writeTableRows()} (not recommended). When using a user-defined type, a function must + * be provided to turn this type into a {@link TableRow} using {@link + * BigQueryIO.Write#withFormatFunction(SerializableFunction)}. * *

    {@code
- * PCollection<TableRow> quotes = ...
    + * class Quote { Instant timestamp; String exchange; String symbol; double price; }
      *
- * List<TableFieldSchema> fields = new ArrayList<>();
    - * fields.add(new TableFieldSchema().setName("source").setType("STRING"));
    - * fields.add(new TableFieldSchema().setName("quote").setType("STRING"));
    - * TableSchema schema = new TableSchema().setFields(fields);
+ * PCollection<Quote> quotes = ...
      *
    - * quotes.apply(BigQueryIO.writeTableRows()
    - *     .to("my-project:output.output_table")
    - *     .withSchema(schema)
    + * quotes.apply(BigQueryIO.write()
    + *     .to("my-project:my_dataset.my_table")
    + *     .withSchema(new TableSchema().setFields(
    + *         ImmutableList.of(
    + *           new TableFieldSchema().setName("timestamp").setType("TIMESTAMP"),
    + *           new TableFieldSchema().setName("exchange").setType("STRING"),
    + *           new TableFieldSchema().setName("symbol").setType("STRING"),
    + *           new TableFieldSchema().setName("price").setType("FLOAT"))))
    + *     .withFormatFunction(quote -> new TableRow().set(..set the columns..))
      *     .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
      * }
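For reference, a fuller, self-contained version of the write example above might look like the sketch below. It is only illustrative: the Quote POJO mirrors the hypothetical class from the snippet, the project, dataset, and table names are placeholders, and the coder is pinned explicitly to keep the example unambiguous.

import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import com.google.common.collect.ImmutableList;
import java.io.Serializable;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Instant;

public class QuoteWriteExample {
  // Hypothetical POJO, mirroring the Quote class sketched in the javadoc above.
  static class Quote implements Serializable {
    final Instant timestamp;
    final String exchange;
    final String symbol;
    final double price;

    Quote(Instant timestamp, String exchange, String symbol, double price) {
      this.timestamp = timestamp;
      this.exchange = exchange;
      this.symbol = symbol;
      this.price = price;
    }
  }

  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    PCollection<Quote> quotes =
        p.apply(
            Create.of(new Quote(Instant.now(), "NYSE", "ABC", 3.14))
                .withCoder(SerializableCoder.of(Quote.class)));

    quotes.apply(
        BigQueryIO.<Quote>write()
            // Placeholder table spec; replace with a real project:dataset.table.
            .to("my-project:my_dataset.my_table")
            .withSchema(
                new TableSchema()
                    .setFields(
                        ImmutableList.of(
                            new TableFieldSchema().setName("timestamp").setType("TIMESTAMP"),
                            new TableFieldSchema().setName("exchange").setType("STRING"),
                            new TableFieldSchema().setName("symbol").setType("STRING"),
                            new TableFieldSchema().setName("price").setType("FLOAT"))))
            // Convert each Quote into the TableRow that is loaded into BigQuery.
            .withFormatFunction(
                quote ->
                    new TableRow()
                        .set("timestamp", quote.timestamp.toString())
                        .set("exchange", quote.exchange)
                        .set("symbol", quote.symbol)
                        .set("price", quote.price))
            .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));

    p.run().waitUntilFinish();
  }
}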
    * @@ -194,39 +201,48 @@ * written to must already exist. Unbounded PCollections can only be written using {@link * Write.WriteDisposition#WRITE_EMPTY} or {@link Write.WriteDisposition#WRITE_APPEND}. * - *

<h3>Sharding BigQuery output tables</h3>

    + *

<h3>Loading historical data into time-partitioned BigQuery tables</h3>

    * - *

    A common use case is to dynamically generate BigQuery table names based on the current window - * or the current value. To support this, {@link BigQueryIO.Write#to(SerializableFunction)} accepts - * a function mapping the current element to a tablespec. For example, here's code that outputs - * daily tables to BigQuery: + *

    To load historical data into a time-partitioned BigQuery table, specify {@link + * BigQueryIO.Write#withTimePartitioning} with a {@link TimePartitioning#setField(String) field} + * used for column-based + * partitioning. For example: * *

    {@code
- * PCollection<TableRow> quotes = ...
    - * quotes.apply(Window.into(CalendarWindows.days(1)))
    - *       .apply(BigQueryIO.writeTableRows()
+ * PCollection<Quote> quotes = ...;
    + *
    + * quotes.apply(BigQueryIO.write()
    + *         .withSchema(schema)
    + *         .withFormatFunction(quote -> new TableRow()
    + *            .set("timestamp", quote.getTimestamp())
    + *            .set(..other columns..))
    + *         .to("my-project:my_dataset.my_table")
    + *         .withTimePartitioning(new TimePartitioning().setField("time")));
    + * }
    + * + *
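A slightly fuller sketch of the time-partitioned load above, under the same assumptions as the earlier example (hypothetical Quote type with simple getters, placeholder table name, and TimePartitioning coming from com.google.api.services.bigquery.model):

TableSchema schema =
    new TableSchema()
        .setFields(
            ImmutableList.of(
                new TableFieldSchema().setName("timestamp").setType("TIMESTAMP"),
                new TableFieldSchema().setName("symbol").setType("STRING"),
                new TableFieldSchema().setName("price").setType("FLOAT")));

quotes.apply(
    BigQueryIO.<Quote>write()
        .to("my-project:my_dataset.historical_quotes")
        .withSchema(schema)
        .withFormatFunction(
            quote ->
                new TableRow()
                    .set("timestamp", quote.getTimestamp().toString())
                    .set("symbol", quote.getSymbol())
                    .set("price", quote.getPrice()))
        // Column-based partitioning on the TIMESTAMP column declared in the schema.
        .withTimePartitioning(new TimePartitioning().setField("timestamp")));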

<h3>Writing different values to different tables</h3>

    + * + *

    A common use case is to dynamically generate BigQuery table names based on the current value. + * To support this, {@link BigQueryIO.Write#to(SerializableFunction)} accepts a function mapping the + * current element to a tablespec. For example, here's code that outputs quotes of different stocks + * to different tables: + * + *

    {@code
+ * PCollection<Quote> quotes = ...;
    + *
    + * quotes.apply(BigQueryIO.write()
      *         .withSchema(schema)
- *         .to(new SerializableFunction<ValueInSingleWindow<TableRow>, TableDestination>() {
- *           public TableDestination apply(ValueInSingleWindow<TableRow> value) {
    - *             // The cast below is safe because CalendarWindows.days(1) produces IntervalWindows.
    - *             String dayString = DateTimeFormat.forPattern("yyyy_MM_dd")
    - *                  .withZone(DateTimeZone.UTC)
    - *                  .print(((IntervalWindow) value.getWindow()).start());
    + *         .withFormatFunction(quote -> new TableRow()...)
+ *         .to((ValueInSingleWindow<Quote> quote) -> {
+ *             String symbol = quote.getValue().getSymbol();
      *             return new TableDestination(
    - *                 "my-project:output.output_table_" + dayString, // Table spec
    - *                 "Output for day " + dayString // Table description
    + *                 "my-project:my_dataset.quotes_" + symbol, // Table spec
    + *                 "Quotes of stock " + symbol // Table description
      *               );
    - *           }
    - *         }));
    + *           });
      * }
    * - *

    Note that this also allows the table to be a function of the element as well as the current - * pane, in the case of triggered windows. In this case it might be convenient to call {@link - * BigQueryIO#write()} directly instead of using the {@link BigQueryIO#writeTableRows()} helper. - * This will allow the mapping function to access the element of the user-defined type. In this - * case, a formatting function must be specified using {@link BigQueryIO.Write#withFormatFunction} - * to convert each element into a {@link TableRow} object. - * *

    Per-table schemas can also be provided using {@link BigQueryIO.Write#withSchemaFromView}. This * allows you the schemas to be calculated based on a previous pipeline stage or statically via a * {@link org.apache.beam.sdk.transforms.Create} transform. This method expects to receive a @@ -261,9 +277,9 @@ * loads involves writing temporary files to this location, so the location must be accessible at * pipeline execution time. By default, this location is captured at pipeline construction * time, may be inaccessible if the template may be reused from a different project or at a moment - * when the original location no longer exists. - * {@link Write#withCustomGcsTempLocation(ValueProvider)} allows specifying the location as an - * argument to the template invocation. + * when the original location no longer exists. {@link + * Write#withCustomGcsTempLocation(ValueProvider)} allows specifying the location as an argument to + * the template invocation. * *

<h3>Permissions</h3>

* @@ -347,10 +363,12 @@ public static TypedRead<TableRow> readTableRows() { * sample parse function that parses click events from a table. * *
    {@code
    +   * class ClickEvent { long userId; String url; ... }
    +   *
        * p.apply(BigQueryIO.read(new SerializableFunction() {
    -   *   public Event apply(SchemaAndRecord record) {
    +   *   public ClickEvent apply(SchemaAndRecord record) {
        *     GenericRecord r = record.getRecord();
    -   *     return new Event((Long) r.get("userId"), (String) r.get("url"));
    +   *     return new ClickEvent((Long) r.get("userId"), (String) r.get("url"));
        *   }
        * }).from("...");
        * }
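Under the same assumptions (a hypothetical ClickEvent POJO with the fields shown above that also implements Serializable, and a placeholder table name), the parse function can equally be supplied as a lambda, with an explicit coder in case Beam cannot infer one:

PCollection<ClickEvent> events =
    p.apply(
        BigQueryIO.read(
                (SchemaAndRecord record) -> {
                  GenericRecord r = record.getRecord();
                  // userId is read as a Long; url may be an Avro Utf8, so use toString().
                  return new ClickEvent((Long) r.get("userId"), r.get("url").toString());
                })
            .from("my-project:my_dataset.click_events")
            .withCoder(SerializableCoder.of(ClickEvent.class)));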
    @@ -529,6 +547,7 @@ abstract static class Builder { abstract Builder setUseLegacySql(Boolean useLegacySql); abstract Builder setWithTemplateCompatibility(Boolean useTemplateCompatibility); abstract Builder setBigQueryServices(BigQueryServices bigQueryServices); + abstract Builder setPriority(Priority priority); abstract TypedRead build(); abstract Builder setParseFn( @@ -548,8 +567,36 @@ abstract Builder setParseFn( abstract SerializableFunction getParseFn(); + @Nullable abstract Priority getPriority(); + @Nullable abstract Coder getCoder(); + /** + * An enumeration type for the priority of a query. + * + * @see + * + * Running Interactive and Batch Queries in the BigQuery documentation + */ + public enum Priority { + /** + * Specifies that a query should be run with an INTERACTIVE priority. + * + *

Interactive mode allows BigQuery to execute the query as soon as possible. These + * queries count towards your concurrent rate limit and your daily limit. + */ + INTERACTIVE, + + /** + * Specifies that a query should be run with a BATCH priority. + * + *

    Batch mode queries are queued by BigQuery. These are started as soon as idle + * resources are available, usually within a few minutes. Batch queries don’t count + * towards your concurrent rate limit. + */ + BATCH + } + @VisibleForTesting Coder inferCoder(CoderRegistry coderRegistry) { if (getCoder() != null) { @@ -583,7 +630,8 @@ private BigQuerySourceBase createSource(String jobUuid, Coder coder) { getUseLegacySql(), getBigQueryServices(), coder, - getParseFn()); + getParseFn(), + getPriority()); } return source; } @@ -882,7 +930,17 @@ public TypedRead usingStandardSql() { return toBuilder().setUseLegacySql(false).build(); } - /** See {@link Read#withTemplateCompatibility()}. */ + /** See {@link Priority#INTERACTIVE}. */ + public TypedRead usingInteractivePriority() { + return toBuilder().setPriority(Priority.INTERACTIVE).build(); + } + + /** See {@link Priority#BATCH}. */ + public TypedRead usingBatchPriority() { + return toBuilder().setPriority(Priority.BATCH).build(); + } + + /** See {@link TypedRead#withTemplateCompatibility()}. */ @Experimental(Experimental.Kind.SOURCE_SINK) public TypedRead withTemplateCompatibility() { return toBuilder().setWithTemplateCompatibility(true).build(); @@ -971,6 +1029,9 @@ public static Write write() { /** * A {@link PTransform} that writes a {@link PCollection} containing {@link TableRow TableRows} to * a BigQuery table. + * + *

    It is recommended to instead use {@link #write} with {@link + * Write#withFormatFunction(SerializableFunction)}. */ public static Write writeTableRows() { return BigQueryIO.write().withFormatFunction(IDENTITY_FORMATTER); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java index df3be1534786..4b4684ebce16 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java @@ -34,6 +34,7 @@ import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.Status; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Priority; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService; import org.apache.beam.sdk.options.PipelineOptions; @@ -57,15 +58,17 @@ static BigQueryQuerySource create( Boolean useLegacySql, BigQueryServices bqServices, Coder coder, - SerializableFunction parseFn) { + SerializableFunction parseFn, + Priority priority) { return new BigQueryQuerySource<>( - stepUuid, query, flattenResults, useLegacySql, bqServices, coder, parseFn); + stepUuid, query, flattenResults, useLegacySql, bqServices, coder, parseFn, priority); } private final ValueProvider query; private final Boolean flattenResults; private final Boolean useLegacySql; private transient AtomicReference dryRunJobStats; + private final Priority priority; private BigQueryQuerySource( String stepUuid, @@ -74,12 +77,19 @@ private BigQueryQuerySource( Boolean useLegacySql, BigQueryServices bqServices, Coder coder, - SerializableFunction parseFn) { + SerializableFunction parseFn, + Priority priority) { super(stepUuid, bqServices, coder, parseFn); this.query = checkNotNull(query, "query"); this.flattenResults = checkNotNull(flattenResults, "flattenResults"); this.useLegacySql = checkNotNull(useLegacySql, "useLegacySql"); this.dryRunJobStats = new AtomicReference<>(); + if (priority != BigQueryIO.TypedRead.Priority.BATCH + || priority != BigQueryIO.TypedRead.Priority.INTERACTIVE) { + this.priority = BigQueryIO.TypedRead.Priority.BATCH; + } else { + this.priority = priority; + } } @Override @@ -174,7 +184,7 @@ private void executeQuery( .setAllowLargeResults(true) .setCreateDisposition("CREATE_IF_NEEDED") .setDestinationTable(destinationTable) - .setPriority("BATCH") + .setPriority(this.priority.name()) .setWriteDisposition("WRITE_EMPTY"); jobService.startQueryJob(jobRef, queryConfig); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java index 6e470a8d8c14..b0a26bf8d97d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinationsHelpers.java @@ -135,6 +135,7 @@ static class DelegatingDynamicDestinations DelegatingDynamicDestinations(DynamicDestinations inner) { this.inner = inner; } + @Override public DestinationT 
getDestination(ValueInSingleWindow element) { return inner.getDestination(element); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java index a21085864032..b7060a2b0cdf 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java @@ -29,7 +29,6 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.SinkMetrics; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.SystemDoFnInternal; import org.apache.beam.sdk.values.KV; @@ -109,11 +108,6 @@ public void finishBundle(FinishBundleContext context) throws Exception { } } - @Override - public void populateDisplayData(DisplayData.Builder builder) { - super.populateDisplayData(builder); - } - /** * Writes the accumulated rows into BigQuery with streaming API. */ diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java index 51b9375a587f..d58875f74e9e 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.util.UUID; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.ShardedKey; @@ -56,8 +55,4 @@ public void processElement(ProcessContext context, BoundedWindow window) throws context.element().getKey(), new TableRowInfo(context.element().getValue(), uniqueId))); } - @Override - public void populateDisplayData(DisplayData.Builder builder) { - super.populateDisplayData(builder); - } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java index 017d5c15ceab..68eebf63ce54 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java @@ -44,8 +44,6 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Writes each bundle of {@link TableRow} elements out to separate file using {@link @@ -57,7 +55,6 @@ */ class WriteBundlesToFiles extends DoFn, Result> { - private static final Logger LOG = LoggerFactory.getLogger(WriteBundlesToFiles.class); // When we spill records, shard the output keys to prevent hotspots. Experiments running up to // 10TB of data have shown a sharding of 10 to be a good choice. 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java index e82b29d3d09b..cc1df928225d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java @@ -23,8 +23,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.ShardedKey; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Receives elements grouped by their (sharded) destination, and writes them out to a file. @@ -34,7 +32,6 @@ class WriteGroupedRecordsToFiles extends DoFn, Iterable>, WriteBundlesToFiles.Result> { - private static final Logger LOG = LoggerFactory.getLogger(WriteGroupedRecordsToFiles.class); private final PCollectionView tempFilePrefix; private final long maxFileSize; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index 4cad3b16e242..74eef6a9ca75 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -61,8 +61,6 @@ import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Writes partitions to BigQuery tables. 
@@ -79,7 +77,6 @@ class WriteTables extends PTransform, List>>, PCollection>> { - private static final Logger LOG = LoggerFactory.getLogger(WriteTables.class); private final boolean singlePartition; private final BigQueryServices bqServices; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java index 1de30de5c0e3..4e602699e580 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java @@ -22,7 +22,6 @@ import static com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; -import com.google.bigtable.v2.MutateRowResponse; import com.google.bigtable.v2.Mutation; import com.google.bigtable.v2.Row; import com.google.bigtable.v2.RowFilter; @@ -33,9 +32,6 @@ import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; import java.io.IOException; import java.util.Arrays; @@ -705,10 +701,14 @@ public void startBundle(StartBundleContext c) throws IOException { @ProcessElement public void processElement(ProcessContext c) throws Exception { checkForFailures(); - Futures.addCallback( - bigtableWriter.writeRecord(c.element()), - new WriteExceptionCallback(c.element()), - MoreExecutors.directExecutor()); + bigtableWriter + .writeRecord(c.element()) + .whenComplete( + (mutationResult, exception) -> { + if (exception != null) { + failures.add(new BigtableWriteException(c.element(), exception)); + } + }); ++recordsWritten; } @@ -716,7 +716,7 @@ public void processElement(ProcessContext c) throws Exception { public void finishBundle() throws Exception { bigtableWriter.flush(); checkForFailures(); - LOG.info("Wrote {} records", recordsWritten); + LOG.debug("Wrote {} records", recordsWritten); } @Teardown @@ -772,22 +772,6 @@ private void checkForFailures() throws IOException { } throw exception; } - - private class WriteExceptionCallback implements FutureCallback { - private final KV> value; - - public WriteExceptionCallback(KV> value) { - this.value = value; - } - - @Override - public void onFailure(Throwable cause) { - failures.add(new BigtableWriteException(value, cause)); - } - - @Override - public void onSuccess(MutateRowResponse produced) {} - } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableService.java index ecb7b32b0300..1c9fffff5a78 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableService.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableService.java @@ -22,12 +22,12 @@ import com.google.bigtable.v2.Row; import com.google.bigtable.v2.SampleRowKeysResponse; import com.google.cloud.bigtable.config.BigtableOptions; -import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.ByteString; import java.io.IOException; import java.io.Serializable; import java.util.List; 
import java.util.NoSuchElementException; +import java.util.concurrent.CompletionStage; import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource; import org.apache.beam.sdk.values.KV; @@ -47,7 +47,7 @@ interface Writer { * * @throws IOException if there is an error submitting the write. */ - ListenableFuture writeRecord(KV> record) + CompletionStage writeRecord(KV> record) throws IOException; /** diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java index fe25e20b5d0a..b9492b3ab4e3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java @@ -35,13 +35,16 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.io.Closer; -import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; import com.google.protobuf.ByteString; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import java.io.IOException; import java.util.List; import java.util.NoSuchElementException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource; import org.apache.beam.sdk.io.range.ByteKeyRange; import org.apache.beam.sdk.values.KV; @@ -216,14 +219,29 @@ public void close() throws IOException { } @Override - public ListenableFuture writeRecord( + public CompletionStage writeRecord( KV> record) throws IOException { MutateRowsRequest.Entry request = MutateRowsRequest.Entry.newBuilder() .setRowKey(record.getKey()) .addAllMutations(record.getValue()) .build(); - return bulkMutation.add(request); + + CompletableFuture result = new CompletableFuture<>(); + Futures.addCallback( + bulkMutation.add(request), + new FutureCallback() { + @Override + public void onSuccess(MutateRowResponse mutateRowResponse) { + result.complete(mutateRowResponse); + } + + @Override + public void onFailure(Throwable throwable) { + result.completeExceptionally(throwable); + } + }); + return result; } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java index 5b08da2f2536..5f9fd5ed4dc2 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java @@ -19,7 +19,6 @@ import com.google.cloud.spanner.Mutation; import com.google.common.collect.ImmutableList; - import java.io.Serializable; import java.util.Arrays; import java.util.Iterator; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 45cc842725a5..de0d02d9c205 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -293,7 +293,7 @@ public ReadAll withDatabaseId(String databaseId) { } /** Specifies the Cloud Spanner host. */ - public ReadAll witHost(String host) { + public ReadAll withHost(String host) { SpannerConfig config = getSpannerConfig(); return withSpannerConfig(config.withHost(host)); } @@ -399,7 +399,7 @@ public Read withDatabaseId(ValueProvider databaseId) { } /** Specifies the Cloud Spanner host. */ - public Read witHost(String host) { + public Read withHost(String host) { SpannerConfig config = getSpannerConfig(); return withSpannerConfig(config.withHost(host)); } @@ -559,7 +559,7 @@ public CreateTransaction withDatabaseId(ValueProvider databaseId) { } /** Specifies the Cloud Spanner host. */ - public CreateTransaction witHost(String host) { + public CreateTransaction withHost(String host) { SpannerConfig config = getSpannerConfig(); return withSpannerConfig(config.withHost(host)); } @@ -663,7 +663,7 @@ public Write withDatabaseId(ValueProvider databaseId) { } /** Specifies the Cloud Spanner host. */ - public Write witHost(String host) { + public Write withHost(String host) { SpannerConfig config = getSpannerConfig(); return withSpannerConfig(config.withHost(host)); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java index 0fd41e8fafbe..2a9d5ca3fd7c 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java @@ -354,7 +354,8 @@ private void testReadFromTable(boolean useTemplateCompatibility, boolean useRead BigQueryIO.readTableRows() .from("non-executing-project:somedataset.sometable") .withTestServices(fakeBqServices) - .withoutValidation(); + .withoutValidation() + .usingInteractivePriority(); readTransform = useTemplateCompatibility ? 
read.withTemplateCompatibility() : read; } PCollection> output = @@ -376,6 +377,124 @@ public void processElement(ProcessContext c) throws Exception { p.run(); } + @Test + public void testReadFromTableInteractive() + throws IOException, InterruptedException { + Table sometable = new Table(); + sometable.setSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER")))); + sometable.setTableReference( + new TableReference() + .setProjectId("non-executing-project") + .setDatasetId("somedataset") + .setTableId("sometable")); + sometable.setNumBytes(1024L * 1024L); + FakeDatasetService fakeDatasetService = new FakeDatasetService(); + fakeDatasetService.createDataset("non-executing-project", "somedataset", "", "", null); + fakeDatasetService.createTable(sometable); + + List records = Lists.newArrayList( + new TableRow().set("name", "a").set("number", 1L), + new TableRow().set("name", "b").set("number", 2L), + new TableRow().set("name", "c").set("number", 3L)); + fakeDatasetService.insertAll(sometable.getTableReference(), records, null); + + FakeBigQueryServices fakeBqServices = new FakeBigQueryServices() + .withJobService(new FakeJobService()) + .withDatasetService(fakeDatasetService); + + PTransform> readTransform; + BigQueryIO.TypedRead read = + BigQueryIO.readTableRows() + .from("non-executing-project:somedataset.sometable") + .withTestServices(fakeBqServices) + .withoutValidation() + .usingInteractivePriority(); + readTransform = read; + + PCollection> output = + p.apply(readTransform) + .apply( + ParDo.of( + new DoFn>() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output( + KV.of( + (String) c.element().get("name"), + Long.valueOf((String) c.element().get("number")))); + } + })); + + PAssert.that(output) + .containsInAnyOrder(ImmutableList.of(KV.of("a", 1L), KV.of("b", 2L), KV.of("c", 3L))); + assertEquals(read.getPriority(), BigQueryIO.TypedRead.Priority.INTERACTIVE); + p.run(); + } + + @Test + public void testReadFromTableBatch() + throws IOException, InterruptedException { + Table sometable = new Table(); + sometable.setSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER")))); + sometable.setTableReference( + new TableReference() + .setProjectId("non-executing-project") + .setDatasetId("somedataset") + .setTableId("sometable")); + sometable.setNumBytes(1024L * 1024L); + FakeDatasetService fakeDatasetService = new FakeDatasetService(); + fakeDatasetService.createDataset("non-executing-project", "somedataset", "", "", null); + fakeDatasetService.createTable(sometable); + + List records = Lists.newArrayList( + new TableRow().set("name", "a").set("number", 1L), + new TableRow().set("name", "b").set("number", 2L), + new TableRow().set("name", "c").set("number", 3L)); + fakeDatasetService.insertAll(sometable.getTableReference(), records, null); + + FakeBigQueryServices fakeBqServices = new FakeBigQueryServices() + .withJobService(new FakeJobService()) + .withDatasetService(fakeDatasetService); + + PTransform> readTransform; + BigQueryIO.TypedRead read = + BigQueryIO.readTableRows() + .from("non-executing-project:somedataset.sometable") + .withTestServices(fakeBqServices) + .withoutValidation() + .usingBatchPriority(); + readTransform = read; + + PCollection> output = + p.apply(readTransform) + 
.apply( + ParDo.of( + new DoFn>() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output( + KV.of( + (String) c.element().get("name"), + Long.valueOf((String) c.element().get("number")))); + } + })); + + PAssert.that(output) + .containsInAnyOrder(ImmutableList.of(KV.of("a", 1L), KV.of("b", 2L), KV.of("c", 3L))); + assertEquals(read.getPriority(), BigQueryIO.TypedRead.Priority.BATCH); + p.run(); + } + @Test public void testBuildSourceDisplayDataTable() { String tableSpec = "project:dataset.tableid"; @@ -612,7 +731,8 @@ public void testBigQueryQuerySourceInitSplit() throws Exception { true /* useLegacySql */, fakeBqServices, TableRowJsonCoder.of(), - BigQueryIO.TableRowParser.INSTANCE); + BigQueryIO.TableRowParser.INSTANCE, + BigQueryIO.TypedRead.Priority.BATCH); options.setTempLocation(testFolder.getRoot().getAbsolutePath()); TableReference queryTable = new TableReference() @@ -690,7 +810,8 @@ public void testBigQueryNoTableQuerySourceInitSplit() throws Exception { true /* useLegacySql */, fakeBqServices, TableRowJsonCoder.of(), - BigQueryIO.TableRowParser.INSTANCE); + BigQueryIO.TableRowParser.INSTANCE, + BigQueryIO.TypedRead.Priority.BATCH); options.setTempLocation(testFolder.getRoot().getAbsolutePath()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowJsonCoderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowJsonCoderTest.java index 1fb97f5e4fdb..a4606ec5945e 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowJsonCoderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowJsonCoderTest.java @@ -41,6 +41,7 @@ private static class TableRowBuilder { public TableRowBuilder() { row = new TableRow(); } + public TableRowBuilder set(String fieldName, Object value) { row.set(fieldName, value); return this; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java index efe629293ffb..9f98371bdf83 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java @@ -56,8 +56,6 @@ import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.ByteString; import java.io.IOException; import java.io.Serializable; @@ -72,6 +70,8 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; @@ -825,7 +825,7 @@ public void testWriting() throws Exception { .apply("write", defaultWrite.withTableId(table)); p.run(); - logged.verifyInfo("Wrote 1 records"); + logged.verifyDebug("Wrote 1 records"); assertEquals(1, service.tables.size()); assertNotNull(service.getTable(table)); @@ -1180,7 +1180,7 @@ public void close() { * entries. 
The column family in the {@link SetCell} is ignored; only the value is used. * *

    When no {@link SetCell} is provided, the write will fail and this will be exposed via an - * exception on the returned {@link ListenableFuture}. + * exception on the returned {@link CompletionStage}. */ private static class FakeBigtableWriter implements BigtableService.Writer { private final String tableId; @@ -1190,7 +1190,7 @@ public FakeBigtableWriter(String tableId) { } @Override - public ListenableFuture writeRecord( + public CompletionStage writeRecord( KV> record) { service.verifyTableExists(tableId); Map table = service.getTable(tableId); @@ -1198,11 +1198,13 @@ public ListenableFuture writeRecord( for (Mutation m : record.getValue()) { SetCell cell = m.getSetCell(); if (cell.getValue().isEmpty()) { - return Futures.immediateFailedCheckedFuture(new IOException("cell value missing")); + CompletableFuture result = new CompletableFuture<>(); + result.completeExceptionally(new IOException("cell value missing")); + return result; } table.put(key, cell.getValue()); } - return Futures.immediateFuture(MutateRowResponse.getDefaultInstance()); + return CompletableFuture.completedFuture(MutateRowResponse.getDefaultInstance()); } @Override diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java index 1494bd573dce..fb56ee49d176 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.bigtable.v2.MutateRowResponse; import com.google.bigtable.v2.MutateRowsRequest; import com.google.bigtable.v2.MutateRowsRequest.Entry; import com.google.bigtable.v2.Mutation; @@ -36,10 +37,11 @@ import com.google.cloud.bigtable.grpc.BigtableTableName; import com.google.cloud.bigtable.grpc.async.BulkMutation; import com.google.cloud.bigtable.grpc.scanner.ResultScanner; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.ByteString; import java.io.IOException; import java.util.Arrays; - import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource; import org.apache.beam.sdk.io.range.ByteKey; import org.apache.beam.sdk.io.range.ByteKeyRange; @@ -133,7 +135,11 @@ public void testWrite() throws IOException, InterruptedException { Mutation mutation = Mutation.newBuilder() .setSetCell(SetCell.newBuilder().setFamilyName("Family").build()).build(); ByteString key = ByteString.copyFromUtf8("key"); - underTest.writeRecord(KV.of(key, (Iterable) Arrays.asList(mutation))); + + SettableFuture fakeResponse = SettableFuture.create(); + when(mockBulkMutation.add(any(MutateRowsRequest.Entry.class))).thenReturn(fakeResponse); + + underTest.writeRecord(KV.of(key, ImmutableList.of(mutation))); Entry expected = MutateRowsRequest.Entry.newBuilder() .setRowKey(key) .addMutations(mutation) diff --git a/sdks/java/io/hadoop-common/pom.xml b/sdks/java/io/hadoop-common/pom.xml index b90242d3719f..a9c181b43055 100644 --- a/sdks/java/io/hadoop-common/pom.xml +++ b/sdks/java/io/hadoop-common/pom.xml @@ -68,10 +68,16 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + junit junit diff --git 
a/sdks/java/io/hadoop-file-system/pom.xml b/sdks/java/io/hadoop-file-system/pom.xml index 9c55b7498c69..e890cf675762 100644 --- a/sdks/java/io/hadoop-file-system/pom.xml +++ b/sdks/java/io/hadoop-file-system/pom.xml @@ -126,13 +126,19 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + + org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/io/hadoop-input-format/build.gradle b/sdks/java/io/hadoop-input-format/build.gradle index ba4315c93160..7c2d8d60d9ce 100644 --- a/sdks/java/io/hadoop-input-format/build.gradle +++ b/sdks/java/io/hadoop-input-format/build.gradle @@ -61,6 +61,10 @@ dependencies { testCompile "io.netty:netty-transport-native-epoll:4.1.0.CR3" testCompile "org.elasticsearch:elasticsearch:$elastic_search_version" testCompile ("org.elasticsearch:elasticsearch-hadoop:$elastic_search_version") { + // TODO(https://issues.apache.org/jira/browse/BEAM-3715) + // These are all optional deps of elasticsearch-hadoop. Why do they have to be excluded? + exclude group: "cascading", module: "cascading-local" + exclude group: "org.apache.hive", module: "hive-service" exclude group: "org.apache.spark", module: "spark-core_2.10" exclude group: "org.apache.spark", module: "spark-streaming_2.10" exclude group: "org.apache.spark", module: "spark-sql_2.10" diff --git a/sdks/java/io/hadoop-input-format/pom.xml b/sdks/java/io/hadoop-input-format/pom.xml index f998ac8d8d76..e5d00f28ebb1 100644 --- a/sdks/java/io/hadoop-input-format/pom.xml +++ b/sdks/java/io/hadoop-input-format/pom.xml @@ -39,13 +39,6 @@ none - - org.apache.maven.plugins - maven-deploy-plugin - - true - - @@ -75,6 +68,9 @@ spark-runner + + 4.0.43.Final + org.apache.beam @@ -427,7 +423,12 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test @@ -443,7 +444,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/sdks/java/io/hadoop-input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java b/sdks/java/io/hadoop-input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java index f3dba7e1967f..508c7bb74002 100644 --- a/sdks/java/io/hadoop-input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java +++ b/sdks/java/io/hadoop-input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java @@ -316,11 +316,9 @@ void validateTransform() { */ private void validateTranslationFunction(TypeDescriptor inputType, SimpleFunction simpleFunction, String errorMsg) { - if (simpleFunction != null) { - if (!simpleFunction.getInputTypeDescriptor().equals(inputType)) { - throw new IllegalArgumentException( - String.format(errorMsg, getinputFormatClass().getRawType(), inputType.getRawType())); - } + if (simpleFunction != null && !simpleFunction.getInputTypeDescriptor().equals(inputType)) { + throw new IllegalArgumentException( + String.format(errorMsg, getinputFormatClass().getRawType(), inputType.getRawType())); } } diff --git a/sdks/java/io/hbase/pom.xml b/sdks/java/io/hbase/pom.xml index 7e7dd0a778a9..e36383496b2a 100644 --- a/sdks/java/io/hbase/pom.xml +++ b/sdks/java/io/hbase/pom.xml @@ -171,7 +171,12 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java index f1fdea833a00..41350e5b2ada 100644 --- 
a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java @@ -30,6 +30,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.hbase.HBaseIO.HBaseSource; @@ -72,6 +73,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.rules.ExternalResource; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -80,6 +82,7 @@ public class HBaseIOTest { @Rule public final transient TestPipeline p = TestPipeline.create(); @Rule public ExpectedException thrown = ExpectedException.none(); + @Rule public TemporaryHBaseTable tmpTable = new TemporaryHBaseTable(); private static HBaseTestingUtility htu; private static HBaseAdmin admin; @@ -164,7 +167,7 @@ public void testWriteValidationFailsMissingConfiguration() { /** Tests that when reading from a non-existent table, the read fails. */ @Test public void testReadingFailsTableDoesNotExist() throws Exception { - final String table = "TEST-TABLE-INVALID"; + final String table = tmpTable.getName(); // Exception will be thrown by read.expand() when read is applied. thrown.expect(IllegalArgumentException.class); thrown.expectMessage(String.format("Table %s does not exist", table)); @@ -174,14 +177,14 @@ public void testReadingFailsTableDoesNotExist() throws Exception { /** Tests that when reading from an empty table, the read succeeds. */ @Test public void testReadingEmptyTable() throws Exception { - final String table = "TEST-EMPTY-TABLE"; + final String table = tmpTable.getName(); createTable(table); runReadTest(HBaseIO.read().withConfiguration(conf).withTableId(table), new ArrayList<>()); } @Test public void testReading() throws Exception { - final String table = "TEST-MANY-ROWS-TABLE"; + final String table = tmpTable.getName(); final int numRows = 1001; createTable(table); writeData(table, numRows); @@ -191,7 +194,7 @@ public void testReading() throws Exception { /** Tests reading all rows from a split table. */ @Test public void testReadingWithSplits() throws Exception { - final String table = "TEST-MANY-ROWS-SPLITS-TABLE"; + final String table = tmpTable.getName(); final int numRows = 1500; final int numRegions = 4; final long bytesPerRow = 100L; @@ -213,7 +216,7 @@ public void testReadingWithSplits() throws Exception { /** Tests that a {@link HBaseSource} can be read twice, verifying its immutability. */ @Test public void testReadingSourceTwice() throws Exception { - final String table = "TEST-READING-TWICE"; + final String table = tmpTable.getName(); final int numRows = 10; // Set up test table data and sample row keys for size estimation and splitting. @@ -230,7 +233,7 @@ public void testReadingSourceTwice() throws Exception { /** Tests reading all rows using a filter. */ @Test public void testReadingWithFilter() throws Exception { - final String table = "TEST-FILTER-TABLE"; + final String table = tmpTable.getName(); final int numRows = 1001; createTable(table); @@ -248,8 +251,8 @@ public void testReadingWithFilter() throws Exception { * [] and that some properties hold across them. 
*/ @Test - public void testReadingWithKeyRange() throws Exception { - final String table = "TEST-KEY-RANGE-TABLE"; + public void testReadingKeyRangePrefix() throws Exception { + final String table = tmpTable.getName(); final int numRows = 1001; final byte[] startRow = "2".getBytes(); final byte[] stopRow = "9".getBytes(); @@ -262,11 +265,43 @@ public void testReadingWithKeyRange() throws Exception { final ByteKeyRange prefixRange = ByteKeyRange.ALL_KEYS.withEndKey(startKey); runReadTestLength( HBaseIO.read().withConfiguration(conf).withTableId(table).withKeyRange(prefixRange), 126); + } + + /** + * Tests reading all rows using key ranges. Tests a prefix [), a suffix (], and a restricted range + * [] and that some properties hold across them. + */ + @Test + public void testReadingKeyRangeSuffix() throws Exception { + final String table = tmpTable.getName(); + final int numRows = 1001; + final byte[] startRow = "2".getBytes(); + final byte[] stopRow = "9".getBytes(); + final ByteKey startKey = ByteKey.copyFrom(startRow); + + createTable(table); + writeData(table, numRows); // Test suffix: [startKey, end). final ByteKeyRange suffixRange = ByteKeyRange.ALL_KEYS.withStartKey(startKey); runReadTestLength( HBaseIO.read().withConfiguration(conf).withTableId(table).withKeyRange(suffixRange), 875); + } + + /** + * Tests reading all rows using key ranges. Tests a prefix [), a suffix (], and a restricted range + * [] and that some properties hold across them. + */ + @Test + public void testReadingKeyRangeMiddle() throws Exception { + final String table = tmpTable.getName(); + final int numRows = 1001; + final byte[] startRow = "2".getBytes(); + final byte[] stopRow = "9".getBytes(); + final ByteKey startKey = ByteKey.copyFrom(startRow); + + createTable(table); + writeData(table, numRows); // Test restricted range: [startKey, endKey). // This one tests the second signature of .withKeyRange @@ -278,7 +313,7 @@ public void testReadingWithKeyRange() throws Exception { /** Tests dynamic work rebalancing exhaustively. */ @Test public void testReadingSplitAtFractionExhaustive() throws Exception { - final String table = "TEST-FEW-ROWS-SPLIT-EXHAUSTIVE-TABLE"; + final String table = tmpTable.getName(); final int numRows = 7; createTable(table); @@ -296,7 +331,7 @@ public void testReadingSplitAtFractionExhaustive() throws Exception { /** Unit tests of splitAtFraction. */ @Test public void testReadingSplitAtFraction() throws Exception { - final String table = "TEST-SPLIT-AT-FRACTION"; + final String table = tmpTable.getName(); final int numRows = 10; createTable(table); @@ -335,7 +370,7 @@ public void testReadingDisplayData() { /** Tests that a record gets written to the service and messages are logged. */ @Test public void testWriting() throws Exception { - final String table = "table"; + final String table = tmpTable.getName(); final String key = "key"; final String value = "value"; final int numMutations = 100; @@ -353,7 +388,7 @@ public void testWriting() throws Exception { /** Tests that when writing to a non-existent table, the write fails. */ @Test public void testWritingFailsTableDoesNotExist() throws Exception { - final String table = "TEST-TABLE-DOES-NOT-EXIST"; + final String table = tmpTable.getName(); // Exception will be thrown by write.expand() when writeToDynamic is applied. thrown.expect(IllegalArgumentException.class); @@ -365,7 +400,7 @@ public void testWritingFailsTableDoesNotExist() throws Exception { /** Tests that when writing an element fails, the write fails. 
*/ @Test public void testWritingFailsBadElement() throws Exception { - final String table = "TEST-TABLE-BAD-ELEMENT"; + final String table = tmpTable.getName(); final String key = "KEY"; createTable(table); @@ -380,9 +415,10 @@ public void testWritingFailsBadElement() throws Exception { @Test public void testWritingDisplayData() { - HBaseIO.Write write = HBaseIO.write().withTableId("fooTable").withConfiguration(conf); + final String table = tmpTable.getName(); + HBaseIO.Write write = HBaseIO.write().withTableId(table).withConfiguration(conf); DisplayData displayData = DisplayData.from(write); - assertThat(displayData, hasDisplayItem("tableId", "fooTable")); + assertThat(displayData, hasDisplayItem("tableId", table)); } // HBase helper methods @@ -477,4 +513,17 @@ private void runReadTestLength(HBaseIO.Read read, long numElements) { .isEqualTo(numElements); p.run().waitUntilFinish(); } + + private class TemporaryHBaseTable extends ExternalResource { + + private String name; + + @Override protected void before() { + name = "table_" + UUID.randomUUID(); + } + + public String getName() { + return name; + } + } } diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/SerializableScanTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/SerializableScanTest.java index bd2faba22e7a..8692d4087450 100644 --- a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/SerializableScanTest.java +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/SerializableScanTest.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertNotNull; import java.nio.charset.StandardCharsets; - import org.apache.commons.lang3.SerializationUtils; import org.apache.hadoop.hbase.client.Scan; import org.junit.Rule; diff --git a/sdks/java/io/hcatalog/pom.xml b/sdks/java/io/hcatalog/pom.xml index 00168508ab57..f5c15ba508e6 100644 --- a/sdks/java/io/hcatalog/pom.xml +++ b/sdks/java/io/hcatalog/pom.xml @@ -172,8 +172,13 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test - \ No newline at end of file + diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle index efbc84d59d70..503993ae4f29 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -35,6 +35,7 @@ dependencies { shadow project(path: ":sdks:java:core", configuration: "shadow") shadow library.java.findbugs_jsr305 shadow "org.apache.commons:commons-dbcp2:2.1.1" + testCompile project(path: ":sdks:java:core", configuration: "shadowTest") testCompile project(path: ":runners:direct-java", configuration: "shadow") testCompile project(path: ":sdks:java:io:common", configuration: "shadow") testCompile project(":sdks:java:io:common").sourceSets.test.output diff --git a/sdks/java/io/jdbc/pom.xml b/sdks/java/io/jdbc/pom.xml index 6e0fc3e06409..2a0d3528897a 100644 --- a/sdks/java/io/jdbc/pom.xml +++ b/sdks/java/io/jdbc/pom.xml @@ -40,6 +40,9 @@ spark-runner + + 4.0.43.Final + org.apache.beam @@ -263,6 +266,11 @@ guava + + org.slf4j + slf4j-api + + com.google.code.findbugs jsr305 @@ -274,6 +282,11 @@ 2.1.1 + + joda-time + joda-time + + com.google.auto.value @@ -312,14 +325,14 @@ org.hamcrest - hamcrest-all + hamcrest-core test - org.slf4j - slf4j-api + org.hamcrest + hamcrest-library test - + org.slf4j slf4j-jdk14 @@ -330,6 +343,12 @@ postgresql test + + org.apache.beam + beam-sdks-java-core + test + tests + org.apache.beam beam-sdks-java-io-common diff --git 
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 8b47aa9d985a..f7a66045886b 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -20,11 +20,15 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.auto.value.AutoValue; +import java.io.IOException; import java.io.Serializable; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; import javax.annotation.Nullable; import javax.sql.DataSource; import org.apache.beam.sdk.annotations.Experimental; @@ -39,11 +43,18 @@ import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.sdk.util.BackOffUtils; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.sdk.util.Sleeper; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; import org.apache.commons.dbcp2.BasicDataSource; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * IO to read and write data on JDBC. @@ -134,6 +145,9 @@ */ @Experimental(Experimental.Kind.SOURCE_SINK) public class JdbcIO { + + private static final Logger LOG = LoggerFactory.getLogger(JdbcIO.class); + /** * Read data from a JDBC datasource. * @@ -164,9 +178,22 @@ public static ReadAll readAll() { public static Write write() { return new AutoValue_JdbcIO_Write.Builder() .setBatchSize(DEFAULT_BATCH_SIZE) + .setRetryStrategy(new DefaultRetryStrategy()) .build(); } + /** + * This is the default {@link Predicate} we use to detect DeadLock. + * It basically test if the {@link SQLException#getSQLState()} equals 40001. + * 40001 is the SQL State used by most of database to identify deadlock. 
+ */ + public static class DefaultRetryStrategy implements RetryStrategy { + @Override + public boolean apply(SQLException e) { + return (e.getSQLState().equals("40001")); + } + } + private JdbcIO() {} /** @@ -184,22 +211,22 @@ public interface RowMapper extends Serializable { */ @AutoValue public abstract static class DataSourceConfiguration implements Serializable { - @Nullable abstract String getDriverClassName(); - @Nullable abstract String getUrl(); - @Nullable abstract String getUsername(); - @Nullable abstract String getPassword(); - @Nullable abstract String getConnectionProperties(); + @Nullable abstract ValueProvider getDriverClassName(); + @Nullable abstract ValueProvider getUrl(); + @Nullable abstract ValueProvider getUsername(); + @Nullable abstract ValueProvider getPassword(); + @Nullable abstract ValueProvider getConnectionProperties(); @Nullable abstract DataSource getDataSource(); abstract Builder builder(); @AutoValue.Builder abstract static class Builder { - abstract Builder setDriverClassName(String driverClassName); - abstract Builder setUrl(String url); - abstract Builder setUsername(String username); - abstract Builder setPassword(String password); - abstract Builder setConnectionProperties(String connectionProperties); + abstract Builder setDriverClassName(ValueProvider driverClassName); + abstract Builder setUrl(ValueProvider url); + abstract Builder setUsername(ValueProvider username); + abstract Builder setPassword(ValueProvider password); + abstract Builder setConnectionProperties(ValueProvider connectionProperties); abstract Builder setDataSource(DataSource dataSource); abstract DataSourceConfiguration build(); } @@ -216,16 +243,34 @@ public static DataSourceConfiguration create(String driverClassName, String url) checkArgument(driverClassName != null, "driverClassName can not be null"); checkArgument(url != null, "url can not be null"); return new AutoValue_JdbcIO_DataSourceConfiguration.Builder() - .setDriverClassName(driverClassName) - .setUrl(url) + .setDriverClassName(ValueProvider.StaticValueProvider.of(driverClassName)) + .setUrl(ValueProvider.StaticValueProvider.of(url)) .build(); } + public static DataSourceConfiguration create(ValueProvider driverClassName, + ValueProvider url) { + checkArgument(driverClassName != null, "driverClassName can not be null"); + checkArgument(url != null, "url can not be null"); + return new AutoValue_JdbcIO_DataSourceConfiguration.Builder() + .setDriverClassName(driverClassName) + .setUrl(url) + .build(); + } + public DataSourceConfiguration withUsername(String username) { + return builder().setUsername(ValueProvider.StaticValueProvider.of(username)).build(); + } + + public DataSourceConfiguration withUsername(ValueProvider username) { return builder().setUsername(username).build(); } public DataSourceConfiguration withPassword(String password) { + return builder().setPassword(ValueProvider.StaticValueProvider.of(password)).build(); + } + + public DataSourceConfiguration withPassword(ValueProvider password) { return builder().setPassword(password).build(); } @@ -238,6 +283,17 @@ public DataSourceConfiguration withPassword(String password) { */ public DataSourceConfiguration withConnectionProperties(String connectionProperties) { checkArgument(connectionProperties != null, "connectionProperties can not be null"); + return builder() + .setConnectionProperties(ValueProvider.StaticValueProvider.of(connectionProperties)) + .build(); + } + + /** + * Same as {@link #withConnectionProperties(String)} but accepting a 
ValueProvider. + */ + public DataSourceConfiguration withConnectionProperties( + ValueProvider connectionProperties) { + checkArgument(connectionProperties != null, "connectionProperties can not be null"); return builder().setConnectionProperties(connectionProperties).build(); } @@ -256,12 +312,20 @@ DataSource buildDatasource() throws Exception{ return getDataSource(); } else { BasicDataSource basicDataSource = new BasicDataSource(); - basicDataSource.setDriverClassName(getDriverClassName()); - basicDataSource.setUrl(getUrl()); - basicDataSource.setUsername(getUsername()); - basicDataSource.setPassword(getPassword()); - if (getConnectionProperties() != null) { - basicDataSource.setConnectionProperties(getConnectionProperties()); + if (getDriverClassName() != null) { + basicDataSource.setDriverClassName(getDriverClassName().get()); + } + if (getUrl() != null) { + basicDataSource.setUrl(getUrl().get()); + } + if (getUsername() != null) { + basicDataSource.setUsername(getUsername().get()); + } + if (getPassword() != null) { + basicDataSource.setPassword(getPassword().get()); + } + if (getConnectionProperties() != null && getConnectionProperties().get() != null) { + basicDataSource.setConnectionProperties(getConnectionProperties().get()); } return basicDataSource; } @@ -506,6 +570,16 @@ public interface PreparedStatementSetter extends Serializable { void setParameters(T element, PreparedStatement preparedStatement) throws Exception; } + /** + * An interface used to control if we retry the statements when a {@link SQLException} occurs. + * If {@link RetryStrategy#apply(SQLException)} returns true, {@link Write} tries + * to replay the statements. + */ + @FunctionalInterface + public interface RetryStrategy extends Serializable { + boolean apply(SQLException sqlException); + } + /** A {@link PTransform} to write to a JDBC datasource. */ @AutoValue public abstract static class Write extends PTransform, PDone> { @@ -513,6 +587,7 @@ public abstract static class Write extends PTransform, PDone> @Nullable abstract String getStatement(); abstract long getBatchSize(); @Nullable abstract PreparedStatementSetter getPreparedStatementSetter(); + @Nullable abstract RetryStrategy getRetryStrategy(); abstract Builder toBuilder(); @@ -522,6 +597,7 @@ abstract static class Builder { abstract Builder setStatement(String statement); abstract Builder setBatchSize(long batchSize); abstract Builder setPreparedStatementSetter(PreparedStatementSetter setter); + abstract Builder setRetryStrategy(RetryStrategy deadlockPredicate); abstract Write build(); } @@ -540,13 +616,23 @@ public Write withPreparedStatementSetter(PreparedStatementSetter setter) { * Provide a maximum size in number of SQL statenebt for the batch. Default is 1000. * * @param batchSize maximum batch size in number of statements - * @return the {@link Write} with connection batch size set */ public Write withBatchSize(long batchSize) { checkArgument(batchSize > 0, "batchSize must be > 0, but was %d", batchSize); return toBuilder().setBatchSize(batchSize).build(); } + /** + * When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine + * if it will retry the statements. + * If {@link RetryStrategy#apply(SQLException)} returns {@code true}, + * then {@link Write} retries the statements. 
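A minimal usage sketch of the retry hook declared just below, assuming a hypothetical PostgreSQL table example_table and KV elements; the driver, URL and table are placeholders, while the JdbcIO calls are the ones added or already present in this file. Per this patch, a batch whose failure the strategy accepts is replayed under FluentBackoff (up to five retries, five-second initial backoff).

    import org.apache.beam.sdk.io.jdbc.JdbcIO;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;

    class JdbcWriteWithRetryExample {
      // Hypothetical connection details and element type; only the builder calls come from JdbcIO.
      static void writeWithRetry(PCollection<KV<Integer, String>> rows) {
        rows.apply(
            JdbcIO.<KV<Integer, String>>write()
                .withDataSourceConfiguration(
                    JdbcIO.DataSourceConfiguration.create(
                        "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/beam"))
                .withStatement("insert into example_table values(?, ?)")
                .withPreparedStatementSetter(
                    (element, statement) -> {
                      statement.setInt(1, element.getKey());
                      statement.setString(2, element.getValue());
                    })
                .withBatchSize(500)
                .withRetryStrategy(e -> "40001".equals(e.getSQLState())));
      }
    }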
+ */ + public Write withRetryStrategy(RetryStrategy retryStrategy) { + checkArgument(retryStrategy != null, "retryStrategy can not be null"); + return toBuilder().setRetryStrategy(retryStrategy).build(); + } + @Override public PDone expand(PCollection input) { checkArgument( @@ -563,10 +649,14 @@ private static class WriteFn extends DoFn { private final Write spec; + private static final int MAX_RETRIES = 5; + private static final FluentBackoff BUNDLE_WRITE_BACKOFF = + FluentBackoff.DEFAULT + .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5)); + private DataSource dataSource; private Connection connection; - private PreparedStatement preparedStatement; - private int batchCount; + private List records = new ArrayList<>(); public WriteFn(Write spec) { this.spec = spec; @@ -577,55 +667,78 @@ public void setup() throws Exception { dataSource = spec.getDataSourceConfiguration().buildDatasource(); connection = dataSource.getConnection(); connection.setAutoCommit(false); - preparedStatement = connection.prepareStatement(spec.getStatement()); - } - - @StartBundle - public void startBundle() { - batchCount = 0; } @ProcessElement public void processElement(ProcessContext context) throws Exception { T record = context.element(); - preparedStatement.clearParameters(); - spec.getPreparedStatementSetter().setParameters(record, preparedStatement); - preparedStatement.addBatch(); + records.add(record); - batchCount++; - - if (batchCount >= spec.getBatchSize()) { + if (records.size() >= spec.getBatchSize()) { executeBatch(); } } + private void processRecord(T record, PreparedStatement preparedStatement) { + try { + preparedStatement.clearParameters(); + spec.getPreparedStatementSetter().setParameters(record, preparedStatement); + preparedStatement.addBatch(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + @FinishBundle public void finishBundle() throws Exception { executeBatch(); } - private void executeBatch() throws SQLException { - if (batchCount > 0) { - preparedStatement.executeBatch(); - connection.commit(); - batchCount = 0; + private void executeBatch() throws SQLException, IOException, InterruptedException { + if (records.size() == 0) { + return; + } + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backoff = BUNDLE_WRITE_BACKOFF.backoff(); + while (true) { + try (PreparedStatement preparedStatement = + connection.prepareStatement(spec.getStatement())) { + try { + // add each record in the statement batch + for (T record : records) { + processRecord(record, preparedStatement); + } + // execute the batch + preparedStatement.executeBatch(); + // commit the changes + connection.commit(); + break; + } catch (SQLException exception) { + if (!spec.getRetryStrategy().apply(exception)) { + throw exception; + } + LOG.warn("Deadlock detected, retrying", exception); + // clean up the statement batch and the connection state + preparedStatement.clearBatch(); + connection.rollback(); + if (!BackOffUtils.next(sleeper, backoff)) { + // we tried the max number of times + throw exception; + } + } + } } + records.clear(); } @Teardown public void teardown() throws Exception { - try { - if (preparedStatement != null) { - preparedStatement.close(); - } - } finally { - if (connection != null) { - connection.close(); - } - if (dataSource instanceof AutoCloseable) { - ((AutoCloseable) dataSource).close(); - } + if (connection != null) { + connection.close(); + } + if (dataSource instanceof AutoCloseable) { + ((AutoCloseable) dataSource).close(); } } } diff --git 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index beb368592453..304c1c775290 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -39,6 +39,7 @@ import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.common.DatabaseTestHelper; import org.apache.beam.sdk.io.common.TestRow; +import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Count; @@ -47,6 +48,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.derby.drda.NetworkServerControl; import org.apache.derby.jdbc.ClientDataSource; +import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; @@ -59,8 +61,10 @@ * Test on the JdbcIO. */ public class JdbcIOTest implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(JdbcIOTest.class); public static final int EXPECTED_ROW_COUNT = 1000; + public static final String BACKOFF_TABLE = "UT_WRITE_BACKOFF"; private static NetworkServerControl derbyServer; private static ClientDataSource dataSource; @@ -71,6 +75,9 @@ public class JdbcIOTest implements Serializable { @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + @Rule + public final transient ExpectedLogs expectedLogs = ExpectedLogs.none(JdbcIO.class); + @BeforeClass public static void startDatabase() throws Exception { ServerSocket socket = new ServerSocket(0); @@ -79,6 +86,9 @@ public static void startDatabase() throws Exception { LOG.info("Starting Derby database on {}", port); + // by default, derby uses a lock timeout of 60 seconds. 
In order to speed up the test + // and detect the lock faster, we decrease this timeout + System.setProperty("derby.locks.waitTimeout", "2"); System.setProperty("derby.stream.error.file", "target/derby.log"); derbyServer = new NetworkServerControl(InetAddress.getByName("localhost"), port); @@ -149,11 +159,13 @@ public void testDataSourceConfigurationDriverAndUrl() throws Exception { @Test public void testDataSourceConfigurationUsernameAndPassword() throws Exception { + String username = "sa"; + String password = "sa"; JdbcIO.DataSourceConfiguration config = JdbcIO.DataSourceConfiguration.create( "org.apache.derby.jdbc.ClientDriver", "jdbc:derby://localhost:" + port + "/target/beam") - .withUsername("sa") - .withPassword("sa"); + .withUsername(username) + .withPassword(password); try (Connection conn = config.buildDatasource().getConnection()) { assertTrue(conn.isValid(0)); } @@ -161,11 +173,13 @@ public void testDataSourceConfigurationUsernameAndPassword() throws Exception { @Test public void testDataSourceConfigurationNullPassword() throws Exception { + String username = "sa"; + String password = null; JdbcIO.DataSourceConfiguration config = JdbcIO.DataSourceConfiguration.create( "org.apache.derby.jdbc.ClientDriver", "jdbc:derby://localhost:" + port + "/target/beam") - .withUsername("sa") - .withPassword(null); + .withUsername(username) + .withPassword(password); try (Connection conn = config.buildDatasource().getConnection()) { assertTrue(conn.isValid(0)); } @@ -173,11 +187,13 @@ public void testDataSourceConfigurationNullPassword() throws Exception { @Test public void testDataSourceConfigurationNullUsernameAndPassword() throws Exception { + String username = null; + String password = null; JdbcIO.DataSourceConfiguration config = JdbcIO.DataSourceConfiguration.create( "org.apache.derby.jdbc.ClientDriver", "jdbc:derby://localhost:" + port + "/target/beam") - .withUsername(null) - .withPassword(null); + .withUsername(username) + .withPassword(password); try (Connection conn = config.buildDatasource().getConnection()) { assertTrue(conn.isValid(0)); } @@ -222,25 +238,25 @@ public void testRead() throws Exception { pipeline.run(); } - @Test - public void testReadWithSingleStringParameter() throws Exception { + @Test + public void testReadWithSingleStringParameter() throws Exception { PCollection rows = - pipeline.apply( - JdbcIO.read() - .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) - .withQuery(String.format("select name,id from %s where name = ?", readTableName)) - .withStatementPreparator( - (preparedStatement) -> preparedStatement.setString(1, getNameForSeed(1))) - .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()) - .withCoder(SerializableCoder.of(TestRow.class))); + pipeline.apply( + JdbcIO.read() + .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) + .withQuery(String.format("select name,id from %s where name = ?", readTableName)) + .withStatementPreparator( + (preparedStatement) -> preparedStatement.setString(1, getNameForSeed(1))) + .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()) + .withCoder(SerializableCoder.of(TestRow.class))); PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo(1L); - Iterable expectedValues = Collections.singletonList(TestRow.fromSeed(1)); - PAssert.that(rows).containsInAnyOrder(expectedValues); + Iterable expectedValues = Collections.singletonList(TestRow.fromSeed(1)); + PAssert.that(rows).containsInAnyOrder(expectedValues); - pipeline.run(); - } + 
pipeline.run(); + } @Test public void testWrite() throws Exception { @@ -275,7 +291,7 @@ public void testWrite() throws Exception { try (Connection connection = dataSource.getConnection()) { try (Statement statement = connection.createStatement()) { try (ResultSet resultSet = statement.executeQuery("select count(*) from " - + tableName)) { + + tableName)) { resultSet.next(); int count = resultSet.getInt(1); @@ -288,6 +304,84 @@ public void testWrite() throws Exception { } } + @Test + public void testWriteWithBackoff() throws Exception { + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE_BACKOFF"); + DatabaseTestHelper.createTable(dataSource, tableName); + + // lock table + Connection connection = dataSource.getConnection(); + Statement lockStatement = connection.createStatement(); + lockStatement.execute("ALTER TABLE " + tableName + " LOCKSIZE TABLE"); + lockStatement.execute("LOCK TABLE " + tableName + " IN EXCLUSIVE MODE"); + + // start a first transaction + connection.setAutoCommit(false); + PreparedStatement insertStatement = + connection.prepareStatement("insert into " + tableName + " values(?, ?)"); + insertStatement.setInt(1, 1); + insertStatement.setString(2, "TEST"); + insertStatement.execute(); + + // try to write to this table + pipeline + .apply(Create.of(Collections.singletonList(KV.of(1, "TEST")))) + .apply( + JdbcIO.>write() + .withDataSourceConfiguration( + JdbcIO.DataSourceConfiguration.create( + "org.apache.derby.jdbc.ClientDriver", + "jdbc:derby://localhost:" + port + "/target/beam")) + .withStatement(String.format("insert into %s values(?, ?)", tableName)) + .withRetryStrategy((JdbcIO.RetryStrategy) e -> { + return e.getSQLState().equals("XJ208"); // we fake a deadlock with a lock here + }) + .withPreparedStatementSetter( + (element, statement) -> { + statement.setInt(1, element.getKey()); + statement.setString(2, element.getValue()); + })); + + // starting a thread to perform the commit later, while the pipeline is running into the backoff + Thread commitThread = new Thread(() -> { + try { + Thread.sleep(10000); + connection.commit(); + } catch (Exception e) { + // nothing to do + } + }); + commitThread.start(); + pipeline.run(); + commitThread.join(); + + // we verify the the backoff has been called thanks to the log message + expectedLogs.verifyWarn("Deadlock detected, retrying"); + + try (Connection readConnection = dataSource.getConnection()) { + try (Statement statement = readConnection.createStatement()) { + try (ResultSet resultSet = statement.executeQuery("select count(*) from " + + tableName)) { + resultSet.next(); + int count = resultSet.getInt(1); + // here we have the record inserted by the first transaction (by hand), and a second one + // inserted by the pipeline + Assert.assertEquals(2, count); + } + } + } + + } + + @After + public void tearDown() { + try { + DatabaseTestHelper.deleteTable(dataSource, BACKOFF_TABLE); + } catch (Exception e) { + // nothing to do + } + } + @Test public void testWriteWithEmptyPCollection() throws Exception { pipeline diff --git a/sdks/java/io/jms/pom.xml b/sdks/java/io/jms/pom.xml index 0d1cc0b53ed0..71aaeb47908b 100644 --- a/sdks/java/io/jms/pom.xml +++ b/sdks/java/io/jms/pom.xml @@ -42,11 +42,6 @@ beam-sdks-java-core - - org.slf4j - slf4j-api - - joda-time joda-time @@ -112,14 +107,14 @@ org.hamcrest - hamcrest-all + hamcrest-core test - org.slf4j - slf4j-jdk14 + org.hamcrest + hamcrest-library test - + diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java 
b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java index e167762bef55..ae88b41c1bab 100644 --- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.jms; import static com.google.common.base.Preconditions.checkArgument; + import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import java.io.IOException; @@ -55,8 +56,6 @@ import org.apache.beam.sdk.values.PDone; import org.joda.time.Duration; import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** @@ -120,8 +119,6 @@ @Experimental(Experimental.Kind.SOURCE_SINK) public class JmsIO { - private static final Logger LOG = LoggerFactory.getLogger(JmsIO.class); - public static Read read() { return new AutoValue_JmsIO_Read.Builder() .setMaxNumRecords(Long.MAX_VALUE) diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java index 95a143b86da1..fa6b916cd318 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java @@ -37,7 +37,6 @@ import javax.jms.QueueBrowser; import javax.jms.Session; import javax.jms.TextMessage; - import org.apache.activemq.ActiveMQConnectionFactory; import org.apache.activemq.broker.BrokerPlugin; import org.apache.activemq.broker.BrokerService; diff --git a/sdks/java/io/kafka/pom.xml b/sdks/java/io/kafka/pom.xml index b04f5bfd18a7..e8292c4ba980 100644 --- a/sdks/java/io/kafka/pom.xml +++ b/sdks/java/io/kafka/pom.xml @@ -53,6 +53,7 @@ org.apache.kafka kafka-clients + provided @@ -118,10 +119,14 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test - junit junit diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ConsumerSpEL.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ConsumerSpEL.java index 8cdad228f210..a3bd43925884 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ConsumerSpEL.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ConsumerSpEL.java @@ -21,12 +21,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.util.Collection; import java.util.Map; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; import org.apache.kafka.common.TopicPartition; import org.joda.time.Instant; import org.slf4j.Logger; @@ -54,29 +53,29 @@ class ConsumerSpEL { private Expression assignExpression = parser.parseExpression("#consumer.assign(#tp)"); - private Method timestampMethod; private boolean hasRecordTimestamp = false; - - private Method offsetGetterMethod; - private Method offsetsForTimesMethod; private boolean hasOffsetsForTimes = false; public ConsumerSpEL() { try { // It is supported by Kafka Client 0.10.0.0 onwards. 
- timestampMethod = ConsumerRecord.class.getMethod("timestamp", (Class[]) null); - hasRecordTimestamp = timestampMethod.getReturnType().equals(Long.TYPE); + hasRecordTimestamp = ConsumerRecord + .class + .getMethod("timestamp", (Class[]) null) + .getReturnType() + .equals(Long.TYPE); } catch (NoSuchMethodException | SecurityException e) { LOG.debug("Timestamp for Kafka message is not available."); } try { // It is supported by Kafka Client 0.10.1.0 onwards. - offsetGetterMethod = Class.forName("org.apache.kafka.clients.consumer.OffsetAndTimestamp") - .getMethod("offset", (Class[]) null); - offsetsForTimesMethod = Consumer.class.getMethod("offsetsForTimes", Map.class); - hasOffsetsForTimes = offsetsForTimesMethod.getReturnType().equals(Map.class); - } catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) { + hasOffsetsForTimes = Consumer + .class + .getMethod("offsetsForTimes", Map.class) + .getReturnType() + .equals(Map.class); + } catch (NoSuchMethodException | SecurityException e) { LOG.debug("OffsetsForTimes is not available."); } } @@ -97,15 +96,8 @@ public void evaluateAssign(Consumer consumer, Collection topicPa public long getRecordTimestamp(ConsumerRecord rawRecord) { long timestamp; - try { - //for Kafka 0.9, set to System.currentTimeMillis(); - //for kafka 0.10, when NO_TIMESTAMP also set to System.currentTimeMillis(); - if (!hasRecordTimestamp || (timestamp = (long) timestampMethod.invoke(rawRecord)) <= 0L) { - timestamp = System.currentTimeMillis(); - } - } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { - // Not expected. Method timestamp() is already checked. - throw new RuntimeException(e); + if (!hasRecordTimestamp || (timestamp = rawRecord.timestamp()) <= 0L) { + timestamp = System.currentTimeMillis(); } return timestamp; } @@ -125,23 +117,18 @@ public long offsetForTime(Consumer consumer, TopicPartition topicPartition checkArgument(hasOffsetsForTimes, "This Kafka Client must support Consumer.OffsetsForTimes()."); - Map timestampsToSearch = - ImmutableMap.of(topicPartition, time.getMillis()); - try { - Map offsetsByTimes = (Map) offsetsForTimesMethod.invoke(consumer, timestampsToSearch); - Object offsetAndTimestamp = Iterables.getOnlyElement(offsetsByTimes.values()); - - if (offsetAndTimestamp == null) { - throw new RuntimeException("There are no messages has a timestamp that is greater than or " - + "equals to the target time or the message format version in this partition is " - + "before 0.10.0, topicPartition is: " + topicPartition); - } else { - return (long) offsetGetterMethod.invoke(offsetAndTimestamp); - } - } catch (IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e); + // 'value' in the map returned by offsetFoTime() is null if there is no offset for the time. 
+ OffsetAndTimestamp offsetAndTimestamp = Iterables.getOnlyElement( + consumer + .offsetsForTimes(ImmutableMap.of(topicPartition, time.getMillis())) + .values()); + + if (offsetAndTimestamp == null) { + throw new RuntimeException("There are no messages has a timestamp that is greater than or " + + "equals to the target time or the message format version in this partition is " + + "before 0.10.0, topicPartition is: " + topicPartition); + } else { + return offsetAndTimestamp.offset(); } - } - } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java index 0856c7c6ec89..791e594bb232 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java @@ -21,14 +21,13 @@ import java.io.IOException; import java.io.Serializable; import java.util.List; - import org.apache.avro.reflect.AvroIgnore; import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.coders.DefaultCoder; import org.apache.beam.sdk.io.UnboundedSource; /** - * Checkpoint for an unbounded KafkaIO.Read. Consists of Kafka topic name, partition id, + * Checkpoint for a {@link KafkaUnboundedReader}. Consists of Kafka topic name, partition id, * and the latest offset consumed so far. */ @DefaultCoder(AvroCoder.class) @@ -37,12 +36,12 @@ public class KafkaCheckpointMark implements UnboundedSource.CheckpointMark { private List partitions; @AvroIgnore - private KafkaIO.UnboundedKafkaReader reader; // Non-null when offsets need to be committed. + private KafkaUnboundedReader reader; // Non-null when offsets need to be committed. private KafkaCheckpointMark() {} // for Avro public KafkaCheckpointMark(List partitions, - KafkaIO.UnboundedKafkaReader reader) { + KafkaUnboundedReader reader) { this.partitions = partitions; this.reader = reader; } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java new file mode 100644 index 000000000000..7345a92f2015 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java @@ -0,0 +1,643 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.MoreObjects; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.cache.RemovalCause; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.io.kafka.KafkaIO.Write; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.metrics.SinkMetrics; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.joda.time.DateTimeUtils; +import org.joda.time.DateTimeZone; +import org.joda.time.format.DateTimeFormat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Exactly-once sink transform for Kafka. + * See {@link KafkaIO} for user visible documentation and example usage. + */ +class KafkaExactlyOnceSink extends PTransform>, PCollection> { + + // Dataflow ensures at-least once processing for side effects like sinks. In order to provide + // exactly-once semantics, a sink needs to be idempotent or it should avoid writing records + // that have already been written. This snk does the latter. All the the records are ordered + // across a fixed number of shards and records in each shard are written in order. 
It drops + // any records that are already written and buffers those arriving out of order. + // + // Exactly once sink involves two shuffles of the records: + // A : Assign a shard ---> B : Assign sequential ID ---> C : Write to Kafka in order + // + // Processing guarantees also require deterministic processing within user transforms. + // Here, that requires order of the records committed to Kafka by C should not be affected by + // restarts in C and its upstream stages. + // + // A : Assigns a random shard for message. Note that there are no ordering guarantees for + // writing user records to Kafka. User can still control partitioning among topic + // partitions as with regular sink (of course, there are no ordering guarantees in + // regular Kafka sink either). + // B : Assigns an id sequentially for each messages within a shard. + // C : Writes each shard to Kafka in sequential id order. In Dataflow, when C sees a record + // and id, it implies that record and the associated id are checkpointed to persistent + // storage and this record will always have same id, even in retries. + // Exactly-once semantics are achieved by writing records in the strict order of + // these check-pointed sequence ids. + // + // Parallelism for B and C is fixed to 'numShards', which defaults to number of partitions + // for the topic. A few reasons for that: + // - B & C implement their functionality using per-key state. Shard id makes it independent + // of cardinality of user key. + // - We create one producer per shard, and its 'transactional id' is based on shard id. This + // requires that number of shards to be finite. This also helps with batching. and avoids + // initializing producers and transactions. + // - Most importantly, each of sharded writers stores 'next message id' in partition + // metadata, which is committed atomically with Kafka transactions. This is critical + // to handle retries of C correctly. Initial testing showed number of shards could be + // larger than number of partitions for the topic. + // + // Number of shards can change across multiple runs of a pipeline (job upgrade in Dataflow). + // + + private static final Logger LOG = LoggerFactory.getLogger(KafkaExactlyOnceSink.class); + private static final String METRIC_NAMESPACE = "KafkaExactlyOnceSink"; + + private final Write spec; + + static void ensureEOSSupport() { + checkArgument( + ProducerSpEL.supportsTransactions(), "%s %s", + "This version of Kafka client does not support transactions required to support", + "exactly-once semantics. Please use Kafka client version 0.11 or newer."); + } + + KafkaExactlyOnceSink(Write spec) { + this.spec = spec; + } + + @Override + public PCollection expand(PCollection> input) { + + int numShards = spec.getNumShards(); + if (numShards <= 0) { + try (Consumer consumer = openConsumer(spec)) { + numShards = consumer.partitionsFor(spec.getTopic()).size(); + LOG.info("Using {} shards for exactly-once writer, matching number of partitions " + + "for topic '{}'", numShards, spec.getTopic()); + } + } + checkState(numShards > 0, "Could not set number of shards"); + + return input + .apply( + Window.>into(new GlobalWindows()) // Everything into global window. 
+ .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .discardingFiredPanes()) + .apply( + String.format("Shuffle across %d shards", numShards), + ParDo.of(new Reshard<>(numShards))) + .apply("Persist sharding", GroupByKey.create()) + .apply("Assign sequential ids", ParDo.of(new Sequencer<>())) + .apply("Persist ids", GroupByKey.create()) + .apply( + String.format("Write to Kafka topic '%s'", spec.getTopic()), + ParDo.of(new ExactlyOnceWriter<>(spec, input.getCoder()))); + } + + /** + * Shuffle messages assigning each randomly to a shard. + */ + private static class Reshard extends DoFn, KV>> { + + private final int numShards; + private transient int shardId; + + Reshard(int numShards) { + this.numShards = numShards; + } + + @Setup + public void setup() { + shardId = ThreadLocalRandom.current().nextInt(numShards); + } + + @ProcessElement + public void processElement(ProcessContext ctx) { + shardId = (shardId + 1) % numShards; // round-robin among shards. + ctx.output(KV.of(shardId, ctx.element())); + } + } + + private static class Sequencer + extends DoFn>>, KV>>> { + + private static final String NEXT_ID = "nextId"; + @StateId(NEXT_ID) + private final StateSpec> nextIdSpec = StateSpecs.value(); + + @ProcessElement + public void processElement(@StateId(NEXT_ID) ValueState nextIdState, ProcessContext ctx) { + long nextId = MoreObjects.firstNonNull(nextIdState.read(), 0L); + int shard = ctx.element().getKey(); + for (KV value : ctx.element().getValue()) { + ctx.output(KV.of(shard, KV.of(nextId, value))); + nextId++; + } + nextIdState.write(nextId); + } + } + + private static class ExactlyOnceWriter + extends DoFn>>>, Void> { + + private static final String NEXT_ID = "nextId"; + private static final String MIN_BUFFERED_ID = "minBufferedId"; + private static final String OUT_OF_ORDER_BUFFER = "outOfOrderBuffer"; + private static final String WRITER_ID = "writerId"; + + // Not sure of a good limit. This applies only for large bundles. + private static final int MAX_RECORDS_PER_TXN = 1000; + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + + @StateId(NEXT_ID) + private final StateSpec> sequenceIdSpec = StateSpecs.value(); + @StateId(MIN_BUFFERED_ID) + private final StateSpec> minBufferedIdSpec = StateSpecs.value(); + @StateId(OUT_OF_ORDER_BUFFER) + private final StateSpec>>> outOfOrderBufferSpec; + // A random id assigned to each shard. Helps with detecting when multiple jobs are mistakenly + // started with same groupId used for storing state on Kafka side, including the case where + // a job is restarted with same groupId, but the metadata from previous run was not cleared. + // Better to be safe and error out with a clear message. + @StateId(WRITER_ID) + private final StateSpec> writerIdSpec = StateSpecs.value(); + + private final Write spec; + + // Metrics + private final Counter elementsWritten = SinkMetrics.elementsWritten(); + // Elements buffered due to out of order arrivals. + private final Counter elementsBuffered = Metrics.counter(METRIC_NAMESPACE, "elementsBuffered"); + private final Counter numTransactions = Metrics.counter(METRIC_NAMESPACE, "numTransactions"); + + ExactlyOnceWriter(Write spec, Coder> elemCoder) { + this.spec = spec; + this.outOfOrderBufferSpec = StateSpecs.bag(KvCoder.of(BigEndianLongCoder.of(), elemCoder)); + } + + @Setup + public void setup() { + // This is on the worker. Ensure the runtime version is till compatible. 
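Since the Sequencer above relies on Beam's per-key state API, a stand-alone sketch of the same ValueState pattern may help; the class name, key/value types and counter semantics here are illustrative, not part of the sink.

    import com.google.common.base.MoreObjects;
    import org.apache.beam.sdk.state.StateSpec;
    import org.apache.beam.sdk.state.StateSpecs;
    import org.apache.beam.sdk.state.ValueState;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.values.KV;

    // Illustrative stateful DoFn: each key keeps a counter in ValueState, so the ids emitted for
    // a key are sequential and survive across bundles, just like Sequencer's NEXT_ID state.
    class PerKeySequenceFn extends DoFn<KV<Integer, String>, KV<Integer, Long>> {
      private static final String NEXT = "next";

      @StateId(NEXT)
      private final StateSpec<ValueState<Long>> nextSpec = StateSpecs.value();

      @ProcessElement
      public void processElement(@StateId(NEXT) ValueState<Long> next, ProcessContext ctx) {
        long id = MoreObjects.firstNonNull(next.read(), 0L);
        ctx.output(KV.of(ctx.element().getKey(), id));
        next.write(id + 1);
      }
    }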
+ KafkaExactlyOnceSink.ensureEOSSupport(); + } + + @ProcessElement + public void processElement(@StateId(NEXT_ID) ValueState nextIdState, + @StateId(MIN_BUFFERED_ID) ValueState minBufferedIdState, + @StateId(OUT_OF_ORDER_BUFFER) + BagState>> oooBufferState, + @StateId(WRITER_ID) ValueState writerIdState, + ProcessContext ctx) + throws IOException { + + int shard = ctx.element().getKey(); + + minBufferedIdState.readLater(); + long nextId = MoreObjects.firstNonNull(nextIdState.read(), 0L); + long minBufferedId = MoreObjects.firstNonNull(minBufferedIdState.read(), Long.MAX_VALUE); + + ShardWriterCache cache = + (ShardWriterCache) CACHE_BY_GROUP_ID.getUnchecked(spec.getSinkGroupId()); + ShardWriter writer = cache.removeIfPresent(shard); + if (writer == null) { + writer = initShardWriter(shard, writerIdState, nextId); + } + + long committedId = writer.committedId; + + if (committedId >= nextId) { + // This is a retry of an already committed batch. + LOG.info("{}: committed id {} is ahead of expected {}. {} records will be dropped " + + "(these are already written).", + shard, committedId, nextId - 1, committedId - nextId + 1); + nextId = committedId + 1; + } + + try { + writer.beginTxn(); + int txnSize = 0; + + // Iterate in recordId order. The input iterator could be mostly sorted. + // There might be out of order messages buffered in earlier iterations. These + // will get merged if and when minBufferedId matches nextId. + + Iterator>> iter = ctx.element().getValue().iterator(); + + while (iter.hasNext()) { + KV> kv = iter.next(); + long recordId = kv.getKey(); + + if (recordId < nextId) { + LOG.info("{}: dropping older record {}. Already committed till {}", + shard, recordId, committedId); + continue; + } + + if (recordId > nextId) { + // Out of order delivery. Should be pretty rare (what about in a batch pipeline?) + LOG.info("{}: Saving out of order record {}, next record id to be written is {}", + shard, recordId, nextId); + + // checkState(recordId - nextId < 10000, "records are way out of order"); + + oooBufferState.add(kv); + minBufferedId = Math.min(minBufferedId, recordId); + minBufferedIdState.write(minBufferedId); + elementsBuffered.inc(); + continue; + } + + // recordId and nextId match. Finally write the record. + + writer.sendRecord(kv.getValue(), elementsWritten); + nextId++; + + if (++txnSize >= MAX_RECORDS_PER_TXN) { + writer.commitTxn(recordId, numTransactions); + txnSize = 0; + writer.beginTxn(); + } + + if (minBufferedId == nextId) { + // One or more of the buffered records can be committed now. + // Read all of them in to memory and sort them. Reading into memory + // might be problematic in extreme cases. Might need to improve it in future. + + List>> buffered = Lists.newArrayList(oooBufferState.read()); + buffered.sort(new KV.OrderByKey<>()); + + LOG.info("{} : merging {} buffered records (min buffered id is {}).", + shard, buffered.size(), minBufferedId); + + oooBufferState.clear(); + minBufferedIdState.clear(); + minBufferedId = Long.MAX_VALUE; + + iter = + Iterators.mergeSorted( + ImmutableList.of(iter, buffered.iterator()), new KV.OrderByKey<>()); + } + } + + writer.commitTxn(nextId - 1, numTransactions); + nextIdState.write(nextId); + + } catch (ProducerSpEL.UnrecoverableProducerException e) { + // Producer JavaDoc says these are not recoverable errors and producer should be closed. + + // Close the producer and a new producer will be initialized in retry. + // It is possible that a rough worker keeps retrying and ends up fencing off + // active producers. 
How likely this might be or how well such a scenario is handled + // depends on the runner. For now we will leave it to upper layers, will need to revisit. + + LOG.warn("{} : closing producer {} after unrecoverable error. The work might have migrated." + + " Committed id {}, current id {}.", + writer.shard, writer.producerName, writer.committedId, nextId - 1, e); + + writer.producer.close(); + writer = null; // No need to cache it. + throw e; + } finally { + if (writer != null) { + cache.insert(shard, writer); + } + } + } + + private static class ShardMetadata { + + @JsonProperty("seq") + public final long sequenceId; + @JsonProperty("id") + public final String writerId; + + private ShardMetadata() { // for json deserializer + sequenceId = -1; + writerId = null; + } + + ShardMetadata(long sequenceId, String writerId) { + this.sequenceId = sequenceId; + this.writerId = writerId; + } + } + + /** + * A wrapper around Kafka producer. One for each of the shards. + */ + private static class ShardWriter { + + private final int shard; + private final String writerId; + private final Producer producer; + private final String producerName; + private final Write spec; + private long committedId; + + ShardWriter(int shard, + String writerId, + Producer producer, + String producerName, + Write spec, + long committedId) { + this.shard = shard; + this.writerId = writerId; + this.producer = producer; + this.producerName = producerName; + this.spec = spec; + this.committedId = committedId; + } + + void beginTxn() { + ProducerSpEL.beginTransaction(producer); + } + + void sendRecord(KV record, Counter sendCounter) { + try { + producer.send( + new ProducerRecord<>(spec.getTopic(), record.getKey(), record.getValue())); + sendCounter.inc(); + } catch (KafkaException e) { + ProducerSpEL.abortTransaction(producer); + throw e; + } + } + + void commitTxn(long lastRecordId, Counter numTransactions) throws IOException { + try { + // Store id in consumer group metadata for the partition. + // NOTE: Kafka keeps this metadata for 24 hours since the last update. This limits + // how long the pipeline could be down before resuming it. It does not look like + // this TTL can be adjusted (asked about it on Kafka users list). + ProducerSpEL.sendOffsetsToTransaction( + producer, + ImmutableMap.of(new TopicPartition(spec.getTopic(), shard), + new OffsetAndMetadata(0L, + JSON_MAPPER.writeValueAsString( + new ShardMetadata(lastRecordId, writerId)))), + spec.getSinkGroupId()); + ProducerSpEL.commitTransaction(producer); + + numTransactions.inc(); + LOG.debug("{} : committed {} records", shard, lastRecordId - committedId); + + committedId = lastRecordId; + } catch (KafkaException e) { + ProducerSpEL.abortTransaction(producer); + throw e; + } + } + } + + private ShardWriter initShardWriter(int shard, + ValueState writerIdState, + long nextId) throws IOException { + + String producerName = String.format("producer_%d_for_%s", shard, spec.getSinkGroupId()); + Producer producer = initializeExactlyOnceProducer(spec, producerName); + + // Fetch latest committed metadata for the partition (if any). Checks committed sequence ids. 
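To make the metadata round trip concrete: commitTxn() above stores a small JSON blob ({"seq": <last id>, "id": "<writer id>"}) as the offset metadata of the sink's consumer group, and initShardWriter() reads it back on restart. A stand-alone sketch of the read side, with broker address, topic, group id and shard number as placeholder values:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import java.util.Properties;
    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.clients.consumer.OffsetAndMetadata;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.serialization.ByteArrayDeserializer;

    class CommittedSequenceIdSketch {
      // Reads the sequence id a previous run committed for shard 0 of "my-topic" under the sink
      // group "my-sink-group"; returns -1 when no metadata has been committed yet.
      static long readCommittedSequenceId() throws Exception {
        Properties config = new Properties();
        config.put("bootstrap.servers", "localhost:9092");
        config.put("group.id", "my-sink-group");
        config.put("key.deserializer", ByteArrayDeserializer.class.getName());
        config.put("value.deserializer", ByteArrayDeserializer.class.getName());
        try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(config)) {
          OffsetAndMetadata committed = consumer.committed(new TopicPartition("my-topic", 0));
          if (committed == null || committed.metadata() == null || committed.metadata().isEmpty()) {
            return -1L;
          }
          // The metadata string is the JSON written by ShardMetadata, e.g. {"seq":41,"id":"..."}.
          return new ObjectMapper().readTree(committed.metadata()).get("seq").asLong();
        }
      }
    }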
+ try { + + String writerId = writerIdState.read(); + + OffsetAndMetadata committed; + + try (Consumer consumer = openConsumer(spec)) { + committed = consumer.committed(new TopicPartition(spec.getTopic(), shard)); + } + + long committedSeqId = -1; + + if (committed == null || committed.metadata() == null || committed.metadata().isEmpty()) { + checkState(nextId == 0 && writerId == null, + "State exists for shard %s (nextId %s, writerId '%s'), but there is no state " + + "stored with Kafka topic '%s' group id '%s'", + shard, nextId, writerId, spec.getTopic(), spec.getSinkGroupId()); + + writerId = String.format("%X - %s", + new Random().nextInt(Integer.MAX_VALUE), + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss") + .withZone(DateTimeZone.UTC) + .print(DateTimeUtils.currentTimeMillis())); + writerIdState.write(writerId); + LOG.info("Assigned writer id '{}' to shard {}", writerId, shard); + + } else { + ShardMetadata metadata = JSON_MAPPER.readValue(committed.metadata(), + ShardMetadata.class); + + checkNotNull(metadata.writerId); + + if (writerId == null) { + // a) This might be a restart of the job from scratch, in which case metatdata + // should be ignored and overwritten with new one. + // b) This job might be started with an incorrect group id which is an error. + // c) There is an extremely small chance that this is a retry of the first bundle + // where metatdate was committed to Kafka but the bundle results were not committed + // in Beam, in which case it should be treated as correct metadata. + // How can we tell these three cases apart? Be safe and throw an exception. + // + // We could let users explicitly an option to override the existing metadata. + // + throw new IllegalStateException(String.format( + "Kafka metadata exists for shard %s, but there is no stored state for it. " + + "This mostly indicates groupId '%s' is used else where or in earlier runs. " + + "Try another group id. Metadata for this shard on Kafka : '%s'", + shard, spec.getSinkGroupId(), committed.metadata())); + } + + checkState(writerId.equals(metadata.writerId), + "Writer ids don't match. This is mostly a unintended misuse of groupId('%s')." + + "Beam '%s', Kafka '%s'", + spec.getSinkGroupId(), writerId, metadata.writerId); + + committedSeqId = metadata.sequenceId; + + checkState(committedSeqId >= (nextId - 1), + "Committed sequence id can not be lower than %s, partition metadata : %s", + nextId - 1, committed.metadata()); + } + + LOG.info("{} : initialized producer {} with committed sequence id {}", + shard, producerName, committedSeqId); + + return new ShardWriter<>(shard, writerId, producer, producerName, spec, committedSeqId); + + } catch (Exception e) { + producer.close(); + throw e; + } + } + + /** + * A wrapper around guava cache to provide insert()/remove() semantics. A ShardWriter will be + * closed if it is stays in cache for more than 1 minute, i.e. not used inside + * KafkaExactlyOnceSink DoFn for a minute. 
+ */ + private static class ShardWriterCache { + + static final ScheduledExecutorService SCHEDULED_CLEAN_UP_THREAD = + Executors.newSingleThreadScheduledExecutor(); + + static final int CLEAN_UP_CHECK_INTERVAL_MS = 10 * 1000; + static final int IDLE_TIMEOUT_MS = 60 * 1000; + + private final Cache> cache; + + ShardWriterCache() { + this.cache = + CacheBuilder.newBuilder() + .expireAfterWrite(IDLE_TIMEOUT_MS, TimeUnit.MILLISECONDS) + .>removalListener( + notification -> { + if (notification.getCause() != RemovalCause.EXPLICIT) { + ShardWriter writer = notification.getValue(); + LOG.info( + "{} : Closing idle shard writer {} after 1 minute of idle time.", + writer.shard, + writer.producerName); + writer.producer.close(); + } + }) + .build(); + + // run cache.cleanUp() every 10 seconds. + SCHEDULED_CLEAN_UP_THREAD.scheduleAtFixedRate( + cache::cleanUp, + CLEAN_UP_CHECK_INTERVAL_MS, + CLEAN_UP_CHECK_INTERVAL_MS, + TimeUnit.MILLISECONDS); + } + + ShardWriter removeIfPresent(int shard) { + return cache.asMap().remove(shard); + } + + void insert(int shard, ShardWriter writer) { + ShardWriter existing = cache.asMap().putIfAbsent(shard, writer); + checkState(existing == null, + "Unexpected multiple instances of writers for shard %s", shard); + } + } + + // One cache for each sink (usually there is only one sink per pipeline) + private static final LoadingCache> CACHE_BY_GROUP_ID = + CacheBuilder.newBuilder() + .build(new CacheLoader>() { + @Override + public ShardWriterCache load(String key) throws Exception { + return new ShardWriterCache<>(); + } + }); + } + + /** + * Opens a generic consumer that is mainly meant for metadata operations like fetching number of + * partitions for a topic rather than for fetching messages. + */ + private static Consumer openConsumer(Write spec) { + return spec.getConsumerFactoryFn().apply((ImmutableMap.of( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, spec + .getProducerConfig().get(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG), + ConsumerConfig.GROUP_ID_CONFIG, spec.getSinkGroupId(), + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class, + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class + ))); + } + + private static Producer initializeExactlyOnceProducer(Write spec, + String producerName) { + + Map producerConfig = new HashMap<>(spec.getProducerConfig()); + producerConfig.putAll(ImmutableMap.of( + ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, spec.getKeySerializer(), + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, spec.getValueSerializer(), + ProducerSpEL.ENABLE_IDEMPOTENCE_CONFIG, true, + ProducerSpEL.TRANSACTIONAL_ID_CONFIG, producerName)); + + Producer producer = + spec.getProducerFactoryFn() != null + ? 
spec.getProducerFactoryFn().apply((producerConfig)) + : new KafkaProducer<>(producerConfig); + + ProducerSpEL.initTransactions(producer); + return producer; + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 996a4604692b..bd8ac6443820 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -21,54 +21,23 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; -import com.google.common.base.MoreObjects; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import com.google.common.cache.RemovalCause; -import com.google.common.collect.ComparisonChain; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.io.Closeables; -import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Random; import java.util.Set; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.AvroCoder; -import org.apache.beam.sdk.coders.BigEndianLongCoder; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; @@ -77,62 +46,34 @@ import org.apache.beam.sdk.io.Read.Unbounded; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; -import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; -import org.apache.beam.sdk.io.kafka.KafkaCheckpointMark.PartitionMark; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Gauge; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.metrics.SinkMetrics; -import org.apache.beam.sdk.metrics.SourceMetrics; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.StateSpec; -import 
org.apache.beam.sdk.state.StateSpecs; -import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.windowing.AfterPane; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; -import org.apache.beam.sdk.transforms.windowing.Repeatedly; -import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; -import org.apache.kafka.clients.consumer.ConsumerRecord; -import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.KafkaConsumer; -import org.apache.kafka.clients.consumer.OffsetAndMetadata; -import org.apache.kafka.clients.producer.Callback; import org.apache.kafka.clients.producer.KafkaProducer; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; -import org.apache.kafka.clients.producer.RecordMetadata; -import org.apache.kafka.common.KafkaException; -import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; -import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.serialization.ByteArrayDeserializer; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.Serializer; import org.apache.kafka.common.serialization.StringSerializer; import org.apache.kafka.common.utils.AppInfoParser; -import org.joda.time.DateTimeUtils; -import org.joda.time.DateTimeZone; import org.joda.time.Duration; import org.joda.time.Instant; -import org.joda.time.format.DateTimeFormat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -203,7 +144,7 @@ * *

    Checkpointing is fully supported and each split can resume from previous checkpoint * (to the extent supported by runner). - * See {@link UnboundedKafkaSource#split(int, PipelineOptions)} for more details on + * See {@link KafkaUnboundedSource#split(int, PipelineOptions)} for more details on * splits and checkpoint support. * *

    When the pipeline starts for the first time, or without any checkpoint, the source starts @@ -384,7 +325,7 @@ public Read withBootstrapServers(String bootstrapServers) { /** * Sets the topic to read from. * - *

    See {@link UnboundedKafkaSource#split(int, PipelineOptions)} for description + *

    See {@link KafkaUnboundedSource#split(int, PipelineOptions)} for description * of how the partitions are distributed among the splits. */ public Read withTopic(String topic) { @@ -395,7 +336,7 @@ public Read withTopic(String topic) { * Sets a list of topics to read from. All the partitions from each * of the topics are read. * - *

    See {@link UnboundedKafkaSource#split(int, PipelineOptions)} for description + *

    See {@link KafkaUnboundedSource#split(int, PipelineOptions)} for description * of how the partitions are distributed among the splits. */ public Read withTopics(List topics) { @@ -408,7 +349,7 @@ public Read withTopics(List topics) { * Sets a list of partitions to read from. This allows reading only a subset * of partitions for one or more topics when (if ever) needed. * - *

    See {@link UnboundedKafkaSource#split(int, PipelineOptions)} for description + *

    See {@link KafkaUnboundedSource#split(int, PipelineOptions)} for description * of how the partitions are distributed among the splits. */ public Read withTopicPartitions(List topicPartitions) { @@ -595,8 +536,15 @@ public PCollection> expand(PBegin input) { "Either withTopic(), withTopics() or withTopicPartitions() is required"); checkArgument(getKeyDeserializer() != null, "withKeyDeserializer() is required"); checkArgument(getValueDeserializer() != null, "withValueDeserializer() is required"); + ConsumerSpEL consumerSpEL = new ConsumerSpEL(); + + if (!consumerSpEL.hasOffsetsForTimes()) { + LOG.warn("Kafka client version {} is too old. Versions before 0.10.1.0 are deprecated and " + + "may not be supported in next release of Apache Beam. " + + "Please upgrade your Kafka client version.", AppInfoParser.getVersion()); + } if (getStartReadTime() != null) { - checkArgument(new ConsumerSpEL().hasOffsetsForTimes(), + checkArgument(consumerSpEL.hasOffsetsForTimes(), "Consumer.offsetsForTimes is only supported by Kafka Client 0.10.1.0 onwards, " + "current version of Kafka Client is " + AppInfoParser.getVersion() + ". If you are building with maven, set \"kafka.clients.version\" " @@ -657,8 +605,7 @@ public PCollection> expand(PBegin input) { */ @VisibleForTesting UnboundedSource, KafkaCheckpointMark> makeSource() { - - return new UnboundedKafkaSource<>(this, -1); + return new KafkaUnboundedSource<>(this, -1); } // utility method to convert KafkRecord to user KV before applying user functions @@ -692,7 +639,7 @@ UnboundedSource, KafkaCheckpointMark> makeSource() { // takes many polls before a 1MB chunk from the server is fully read. In my testing // about half of the time select() inside kafka consumer waited for 20-30ms, though // the server had lots of data in tcp send buffers on its side. Compared to default, - // this setting increased throughput increased by many fold (3-4x). + // this setting increased throughput by many fold (3-4x). ConsumerConfig.RECEIVE_BUFFER_CONFIG, 512 * 1024, @@ -791,707 +738,6 @@ private static Map updateKafkaProperties( /** Static class, prevent instantiation. */ private KafkaIO() {} - private static class UnboundedKafkaSource - extends UnboundedSource, KafkaCheckpointMark> { - private Read spec; - private final int id; // split id, mainly for debugging - - public UnboundedKafkaSource(Read spec, int id) { - this.spec = spec; - this.id = id; - } - - /** - * The partitions are evenly distributed among the splits. The number of splits returned is - * {@code min(desiredNumSplits, totalNumPartitions)}, though better not to depend on the exact - * count. - * - *

    It is important to assign the partitions deterministically so that we can support - * resuming a split from last checkpoint. The Kafka partitions are sorted by - * {@code } and then assigned to splits in round-robin order. - */ - @Override - public List> split( - int desiredNumSplits, PipelineOptions options) throws Exception { - - List partitions = new ArrayList<>(spec.getTopicPartitions()); - - // (a) fetch partitions for each topic - // (b) sort by - // (c) round-robin assign the partitions to splits - - if (partitions.isEmpty()) { - try (Consumer consumer = - spec.getConsumerFactoryFn().apply(spec.getConsumerConfig())) { - for (String topic : spec.getTopics()) { - for (PartitionInfo p : consumer.partitionsFor(topic)) { - partitions.add(new TopicPartition(p.topic(), p.partition())); - } - } - } - } - - partitions.sort( - (tp1, tp2) -> - ComparisonChain.start() - .compare(tp1.topic(), tp2.topic()) - .compare(tp1.partition(), tp2.partition()) - .result()); - - checkArgument(desiredNumSplits > 0); - checkState(partitions.size() > 0, - "Could not find any partitions. Please check Kafka configuration and topic names"); - - int numSplits = Math.min(desiredNumSplits, partitions.size()); - List> assignments = new ArrayList<>(numSplits); - - for (int i = 0; i < numSplits; i++) { - assignments.add(new ArrayList<>()); - } - for (int i = 0; i < partitions.size(); i++) { - assignments.get(i % numSplits).add(partitions.get(i)); - } - - List> result = new ArrayList<>(numSplits); - - for (int i = 0; i < numSplits; i++) { - List assignedToSplit = assignments.get(i); - - LOG.info("Partitions assigned to split {} (total {}): {}", - i, assignedToSplit.size(), Joiner.on(",").join(assignedToSplit)); - - result.add( - new UnboundedKafkaSource<>( - spec.toBuilder() - .setTopics(Collections.emptyList()) - .setTopicPartitions(assignedToSplit) - .build(), - i)); - } - - return result; - } - - @Override - public UnboundedKafkaReader createReader(PipelineOptions options, - KafkaCheckpointMark checkpointMark) { - if (spec.getTopicPartitions().isEmpty()) { - LOG.warn("Looks like generateSplits() is not called. Generate single split."); - try { - return new UnboundedKafkaReader<>(split(1, options).get(0), checkpointMark); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - return new UnboundedKafkaReader<>(this, checkpointMark); - } - - @Override - public Coder getCheckpointMarkCoder() { - return AvroCoder.of(KafkaCheckpointMark.class); - } - - @Override - public boolean requiresDeduping() { - // Kafka records are ordered with in partitions. In addition checkpoint guarantees - // records are not consumed twice. - return false; - } - - @Override - public Coder> getOutputCoder() { - return KafkaRecordCoder.of(spec.getKeyCoder(), spec.getValueCoder()); - } - } - - @VisibleForTesting - static class UnboundedKafkaReader extends UnboundedReader> { - // package private, accessed in KafkaCheckpointMark.finalizeCheckpoint. 
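// --- Illustrative sketch (not part of this patch) of the deterministic assignment described
// --- above: partitions are sorted by <topic, partition> and dealt out round-robin, so a
// --- resumed split always gets the same partitions. The helper name is hypothetical.
static List<List<TopicPartition>> assignPartitionsToSplits(
    List<TopicPartition> partitions, int desiredNumSplits) {
  partitions.sort(
      (tp1, tp2) -> tp1.topic().equals(tp2.topic())
          ? Integer.compare(tp1.partition(), tp2.partition())
          : tp1.topic().compareTo(tp2.topic()));
  int numSplits = Math.min(desiredNumSplits, partitions.size());
  List<List<TopicPartition>> assignments = new ArrayList<>(numSplits);
  for (int i = 0; i < numSplits; i++) {
    assignments.add(new ArrayList<>());
  }
  for (int i = 0; i < partitions.size(); i++) {
    assignments.get(i % numSplits).add(partitions.get(i)); // split i gets every numSplits-th partition
  }
  return assignments;
}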
- - @VisibleForTesting - static final String METRIC_NAMESPACE = "KafkaIOReader"; - @VisibleForTesting - static final String CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC = "checkpointMarkCommitsEnqueued"; - private static final String CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC = - "checkpointMarkCommitsSkipped"; - - private final UnboundedKafkaSource source; - private final String name; - private Consumer consumer; - private final List partitionStates; - private KafkaRecord curRecord; - private Instant curTimestamp; - private Iterator curBatch = Collections.emptyIterator(); - - private Deserializer keyDeserializerInstance = null; - private Deserializer valueDeserializerInstance = null; - - private final Counter elementsRead = SourceMetrics.elementsRead(); - private final Counter bytesRead = SourceMetrics.bytesRead(); - private final Counter elementsReadBySplit; - private final Counter bytesReadBySplit; - private final Gauge backlogBytesOfSplit; - private final Gauge backlogElementsOfSplit; - private final Counter checkpointMarkCommitsEnqueued = Metrics.counter( - METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC); - // Checkpoint marks skipped in favor of newer mark (only the latest needs to be committed). - private final Counter checkpointMarkCommitsSkipped = Metrics.counter( - METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC); - - /** - * The poll timeout while reading records from Kafka. - * If option to commit reader offsets in to Kafka in - * {@link KafkaCheckpointMark#finalizeCheckpoint()} is enabled, it would be delayed until - * this poll returns. It should be reasonably low as a result. - * At the same time it probably can't be very low like 10 millis, I am not sure how it affects - * when the latency is high. Probably good to experiment. Often multiple marks would be - * finalized in a batch, it it reduce finalization overhead to wait a short while and finalize - * only the last checkpoint mark. - */ - private static final Duration KAFKA_POLL_TIMEOUT = Duration.millis(1000); - private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT = Duration.millis(10); - private static final Duration RECORDS_ENQUEUE_POLL_TIMEOUT = Duration.millis(100); - - // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including - // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout - // like 100 milliseconds does not work well. This along with large receive buffer for - // consumer achieved best throughput in tests (see `defaultConsumerProperties`). - private final ExecutorService consumerPollThread = Executors.newSingleThreadExecutor(); - private final SynchronousQueue> availableRecordsQueue = - new SynchronousQueue<>(); - private AtomicReference finalizedCheckpointMark = new AtomicReference<>(); - private AtomicBoolean closed = new AtomicBoolean(false); - - // Backlog support : - // Kafka consumer does not have an API to fetch latest offset for topic. We need to seekToEnd() - // then look at position(). Use another consumer to do this so that the primary consumer does - // not need to be interrupted. The latest offsets are fetched periodically on a thread. This is - // still a bit of a hack, but so far there haven't been any issues reported by the users. 
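// --- Illustrative sketch (not part of this patch) of the backlog probe described above,
// --- assuming a Kafka 0.10.1+ client: a second consumer, assigned the same partitions,
// --- seeks to the end of the log and reports position(), so the main consumer is never
// --- interrupted while reading records.
private static long fetchLatestOffset(Consumer<byte[], byte[]> offsetConsumer, TopicPartition tp) {
  offsetConsumer.seekToEnd(Collections.singleton(tp)); // jump to the end of the partition
  return offsetConsumer.position(tp);                  // the end position is the latest offset
}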
- private Consumer offsetConsumer; - private final ScheduledExecutorService offsetFetcherThread = - Executors.newSingleThreadScheduledExecutor(); - private static final int OFFSET_UPDATE_INTERVAL_SECONDS = 5; - - private static final long UNINITIALIZED_OFFSET = -1; - - //Add SpEL instance to cover the interface difference of Kafka client - private transient ConsumerSpEL consumerSpEL; - - /** watermark before any records have been read. */ - private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; - - @Override - public String toString() { - return name; - } - - // Maintains approximate average over last 1000 elements - private static class MovingAvg { - private static final int MOVING_AVG_WINDOW = 1000; - private double avg = 0; - private long numUpdates = 0; - - void update(double quantity) { - numUpdates++; - avg += (quantity - avg) / Math.min(MOVING_AVG_WINDOW, numUpdates); - } - - double get() { - return avg; - } - } - - // maintains state of each assigned partition (buffered records, consumed offset, etc) - private static class PartitionState { - private final TopicPartition topicPartition; - private long nextOffset; - private long latestOffset; - private Iterator> recordIter = Collections.emptyIterator(); - - private MovingAvg avgRecordSize = new MovingAvg(); - private MovingAvg avgOffsetGap = new MovingAvg(); // > 0 only when log compaction is enabled. - - PartitionState(TopicPartition partition, long nextOffset) { - this.topicPartition = partition; - this.nextOffset = nextOffset; - this.latestOffset = UNINITIALIZED_OFFSET; - } - - // Update consumedOffset, avgRecordSize, and avgOffsetGap - void recordConsumed(long offset, int size, long offsetGap) { - nextOffset = offset + 1; - - // This is always updated from single thread. Probably not worth making atomic. - avgRecordSize.update(size); - avgOffsetGap.update(offsetGap); - } - - synchronized void setLatestOffset(long latestOffset) { - this.latestOffset = latestOffset; - } - - synchronized long approxBacklogInBytes() { - // Note that is an an estimate of uncompressed backlog. 
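// --- Comments added for clarity (not in the original patch), summarizing the estimate
// --- computed just below:
// ---   backlogMessages ~= ceil((latestOffset - nextOffset) / (1 + avgOffsetGap))
// ---   backlogBytes    ~= backlogMessages * avgRecordSize
// --- avgOffsetGap is non-zero only when log compaction has removed offsets from the log,
// --- and avgRecordSize is the moving average maintained over roughly the last 1000 records.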
- long backlogMessageCount = backlogMessageCount(); - if (backlogMessageCount == UnboundedReader.BACKLOG_UNKNOWN) { - return UnboundedReader.BACKLOG_UNKNOWN; - } - return (long) (backlogMessageCount * avgRecordSize.get()); - } - - synchronized long backlogMessageCount() { - if (latestOffset < 0 || nextOffset < 0) { - return UnboundedReader.BACKLOG_UNKNOWN; - } - double remaining = (latestOffset - nextOffset) / (1 + avgOffsetGap.get()); - return Math.max(0, (long) Math.ceil(remaining)); - } - } - - public UnboundedKafkaReader( - UnboundedKafkaSource source, - @Nullable KafkaCheckpointMark checkpointMark) { - this.consumerSpEL = new ConsumerSpEL(); - this.source = source; - this.name = "Reader-" + source.id; - - List partitions = source.spec.getTopicPartitions(); - partitionStates = - ImmutableList.copyOf( - partitions - .stream() - .map(tp -> new PartitionState(tp, UNINITIALIZED_OFFSET)) - .collect(Collectors.toList())); - - if (checkpointMark != null) { - // a) verify that assigned and check-pointed partitions match exactly - // b) set consumed offsets - - checkState(checkpointMark.getPartitions().size() == partitions.size(), - "checkPointMark and assignedPartitions should match"); - - for (int i = 0; i < partitions.size(); i++) { - PartitionMark ckptMark = checkpointMark.getPartitions().get(i); - TopicPartition assigned = partitions.get(i); - TopicPartition partition = new TopicPartition(ckptMark.getTopic(), - ckptMark.getPartition()); - checkState(partition.equals(assigned), - "checkpointed partition %s and assigned partition %s don't match", - partition, assigned); - - partitionStates.get(i).nextOffset = ckptMark.getNextOffset(); - } - } - - String splitId = String.valueOf(source.id); - - elementsReadBySplit = SourceMetrics.elementsReadBySplit(splitId); - bytesReadBySplit = SourceMetrics.bytesReadBySplit(splitId); - backlogBytesOfSplit = SourceMetrics.backlogBytesOfSplit(splitId); - backlogElementsOfSplit = SourceMetrics.backlogElementsOfSplit(splitId); - } - - private void consumerPollLoop() { - // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. - - ConsumerRecords records = ConsumerRecords.empty(); - while (!closed.get()) { - try { - if (records.isEmpty()) { - records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); - } else if (availableRecordsQueue.offer(records, - RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), - TimeUnit.MILLISECONDS)) { - records = ConsumerRecords.empty(); - } - KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); - if (checkpointMark != null) { - commitCheckpointMark(checkpointMark); - } - } catch (InterruptedException e) { - LOG.warn("{}: consumer thread is interrupted", this, e); // not expected - break; - } catch (WakeupException e) { - break; - } - } - - LOG.info("{}: Returning from consumer pool loop", this); - } - - private void commitCheckpointMark(KafkaCheckpointMark checkpointMark) { - consumer.commitSync( - checkpointMark - .getPartitions() - .stream() - .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) - .collect(Collectors.toMap( - p -> new TopicPartition(p.getTopic(), p.getPartition()), - p -> new OffsetAndMetadata(p.getNextOffset()) - )) - ); - } - - /** - * Enqueue checkpoint mark to be committed to Kafka. This does not block until - * it is committed. There could be a delay of up to KAFKA_POLL_TIMEOUT (1 second). - * Any checkpoint mark enqueued earlier is dropped in favor of this checkpoint mark. 
- * Documentation for {@link CheckpointMark#finalizeCheckpoint()} says these are finalized - * in order. Only the latest offsets need to be committed. - */ - void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { - if (finalizedCheckpointMark.getAndSet(checkpointMark) != null) { - checkpointMarkCommitsSkipped.inc(); - } - checkpointMarkCommitsEnqueued.inc(); - } - - private void nextBatch() { - curBatch = Collections.emptyIterator(); - - ConsumerRecords records; - try { - // poll available records, wait (if necessary) up to the specified timeout. - records = availableRecordsQueue.poll(RECORDS_DEQUEUE_POLL_TIMEOUT.getMillis(), - TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.warn("{}: Unexpected", this, e); - return; - } - - if (records == null) { - return; - } - - List nonEmpty = new LinkedList<>(); - - for (PartitionState p : partitionStates) { - p.recordIter = records.records(p.topicPartition).iterator(); - if (p.recordIter.hasNext()) { - nonEmpty.add(p); - } - } - - // cycle through the partitions in order to interleave records from each. - curBatch = Iterators.cycle(nonEmpty); - } - - private void setupInitialOffset(PartitionState pState) { - Read spec = source.spec; - - if (pState.nextOffset != UNINITIALIZED_OFFSET) { - consumer.seek(pState.topicPartition, pState.nextOffset); - } else { - // nextOffset is unininitialized here, meaning start reading from latest record as of now - // ('latest' is the default, and is configurable) or 'look up offset by startReadTime. - // Remember the current position without waiting until the first record is read. This - // ensures checkpoint is accurate even if the reader is closed before reading any records. - Instant startReadTime = spec.getStartReadTime(); - if (startReadTime != null) { - pState.nextOffset = - consumerSpEL.offsetForTime(consumer, pState.topicPartition, spec.getStartReadTime()); - consumer.seek(pState.topicPartition, pState.nextOffset); - } else { - pState.nextOffset = consumer.position(pState.topicPartition); - } - } - } - - @Override - public boolean start() throws IOException { - final int defaultPartitionInitTimeout = 60 * 1000; - final int kafkaRequestTimeoutMultiple = 2; - - Read spec = source.spec; - consumer = spec.getConsumerFactoryFn().apply(spec.getConsumerConfig()); - consumerSpEL.evaluateAssign(consumer, spec.getTopicPartitions()); - - try { - keyDeserializerInstance = source.spec.getKeyDeserializer().newInstance(); - valueDeserializerInstance = source.spec.getValueDeserializer().newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new IOException("Could not instantiate deserializers", e); - } - - keyDeserializerInstance.configure(spec.getConsumerConfig(), true); - valueDeserializerInstance.configure(spec.getConsumerConfig(), false); - - // Seek to start offset for each partition. This is the first interaction with the server. - // Unfortunately it can block forever in case of network issues like incorrect ACLs. - // Initialize partition in a separate thread and cancel it if takes longer than a minute. - for (final PartitionState pState : partitionStates) { - Future future = consumerPollThread.submit(() -> setupInitialOffset(pState)); - - try { - // Timeout : 1 minute OR 2 * Kafka consumer request timeout if it is set. - Integer reqTimeout = (Integer) source.spec.getConsumerConfig().get( - ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG); - future.get(reqTimeout != null ? 
kafkaRequestTimeoutMultiple * reqTimeout - : defaultPartitionInitTimeout, - TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - consumer.wakeup(); // This unblocks consumer stuck on network I/O. - // Likely reason : Kafka servers are configured to advertise internal ips, but - // those ips are not accessible from workers outside. - String msg = String.format( - "%s: Timeout while initializing partition '%s'. " - + "Kafka client may not be able to connect to servers.", - this, pState.topicPartition); - LOG.error("{}", msg); - throw new IOException(msg); - } catch (Exception e) { - throw new IOException(e); - } - LOG.info("{}: reading from {} starting at offset {}", - name, pState.topicPartition, pState.nextOffset); - } - - // Start consumer read loop. - // Note that consumer is not thread safe, should not be accessed out side consumerPollLoop(). - consumerPollThread.submit(this::consumerPollLoop); - - // offsetConsumer setup : - - Object groupId = spec.getConsumerConfig().get(ConsumerConfig.GROUP_ID_CONFIG); - // override group_id and disable auto_commit so that it does not interfere with main consumer - String offsetGroupId = String.format("%s_offset_consumer_%d_%s", name, - (new Random()).nextInt(Integer.MAX_VALUE), (groupId == null ? "none" : groupId)); - Map offsetConsumerConfig = new HashMap<>(spec.getConsumerConfig()); - offsetConsumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId); - offsetConsumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); - - offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig); - consumerSpEL.evaluateAssign(offsetConsumer, spec.getTopicPartitions()); - - offsetFetcherThread.scheduleAtFixedRate( - this::updateLatestOffsets, 0, OFFSET_UPDATE_INTERVAL_SECONDS, TimeUnit.SECONDS); - - nextBatch(); - return advance(); - } - - @Override - public boolean advance() throws IOException { - /* Read first record (if any). we need to loop here because : - * - (a) some records initially need to be skipped if they are before consumedOffset - * - (b) if curBatch is empty, we want to fetch next batch and then advance. - * - (c) curBatch is an iterator of iterators. we interleave the records from each. - * curBatch.next() might return an empty iterator. - */ - while (true) { - if (curBatch.hasNext()) { - PartitionState pState = curBatch.next(); - - elementsRead.inc(); - elementsReadBySplit.inc(); - - if (!pState.recordIter.hasNext()) { // -- (c) - pState.recordIter = Collections.emptyIterator(); // drop ref - curBatch.remove(); - continue; - } - - ConsumerRecord rawRecord = pState.recordIter.next(); - long expected = pState.nextOffset; - long offset = rawRecord.offset(); - - if (offset < expected) { // -- (a) - // this can happen when compression is enabled in Kafka (seems to be fixed in 0.10) - // should we check if the offset is way off from consumedOffset (say > 1M)? - LOG.warn("{}: ignoring already consumed offset {} for {}", - this, offset, pState.topicPartition); - continue; - } - - long offsetGap = offset - expected; // could be > 0 when Kafka log compaction is enabled. - - if (curRecord == null) { - LOG.info("{}: first record offset {}", name, offset); - offsetGap = 0; - } - - // Apply user deserializers. User deserializers might throw, which will be propagated up - // and 'curRecord' remains unchanged. The runner should close this reader. - // TODO: write records that can't be deserialized to a "dead-letter" additional output. 
- KafkaRecord record = new KafkaRecord<>( - rawRecord.topic(), - rawRecord.partition(), - rawRecord.offset(), - consumerSpEL.getRecordTimestamp(rawRecord), - keyDeserializerInstance.deserialize(rawRecord.topic(), rawRecord.key()), - valueDeserializerInstance.deserialize(rawRecord.topic(), rawRecord.value())); - - curTimestamp = (source.spec.getTimestampFn() == null) - ? Instant.now() : source.spec.getTimestampFn().apply(record); - curRecord = record; - - int recordSize = (rawRecord.key() == null ? 0 : rawRecord.key().length) - + (rawRecord.value() == null ? 0 : rawRecord.value().length); - pState.recordConsumed(offset, recordSize, offsetGap); - bytesRead.inc(recordSize); - bytesReadBySplit.inc(recordSize); - return true; - - } else { // -- (b) - nextBatch(); - - if (!curBatch.hasNext()) { - return false; - } - } - } - } - - // update latest offset for each partition. - // called from offsetFetcher thread - private void updateLatestOffsets() { - for (PartitionState p : partitionStates) { - try { - // If "read_committed" is enabled in the config, this seeks to 'Last Stable Offset'. - // As a result uncommitted messages are not counted in backlog. It is correct since - // the reader can not read them anyway. - consumerSpEL.evaluateSeek2End(offsetConsumer, p.topicPartition); - long offset = offsetConsumer.position(p.topicPartition); - p.setLatestOffset(offset); - } catch (Exception e) { - // An exception is expected if we've closed the reader in another thread. Ignore and exit. - if (closed.get()) { - break; - } - LOG.warn("{}: exception while fetching latest offset for partition {}. will be retried.", - this, p.topicPartition, e); - p.setLatestOffset(UNINITIALIZED_OFFSET); // reset - } - - LOG.debug("{}: latest offset update for {} : {} (consumer offset {}, avg record size {})", - this, p.topicPartition, p.latestOffset, p.nextOffset, p.avgRecordSize); - } - - LOG.debug("{}: backlog {}", this, getSplitBacklogBytes()); - } - - private void reportBacklog() { - long splitBacklogBytes = getSplitBacklogBytes(); - if (splitBacklogBytes < 0) { - splitBacklogBytes = UnboundedReader.BACKLOG_UNKNOWN; - } - backlogBytesOfSplit.set(splitBacklogBytes); - long splitBacklogMessages = getSplitBacklogMessageCount(); - if (splitBacklogMessages < 0) { - splitBacklogMessages = UnboundedReader.BACKLOG_UNKNOWN; - } - backlogElementsOfSplit.set(splitBacklogMessages); - } - - @Override - public Instant getWatermark() { - if (curRecord == null) { - LOG.debug("{}: getWatermark() : no records have been read yet.", name); - return initialWatermark; - } - - return source.spec.getWatermarkFn() != null - ? source.spec.getWatermarkFn().apply(curRecord) : curTimestamp; - } - - @Override - public CheckpointMark getCheckpointMark() { - reportBacklog(); - return new KafkaCheckpointMark( - partitionStates.stream() - .map((p) -> new PartitionMark(p.topicPartition.topic(), - p.topicPartition.partition(), - p.nextOffset)) - .collect(Collectors.toList()), - source.spec.isCommitOffsetsInFinalizeEnabled() ? this : null - ); - } - - @Override - public UnboundedSource, ?> getCurrentSource() { - return source; - } - - @Override - public KafkaRecord getCurrent() throws NoSuchElementException { - // should we delay updating consumed offset till this point? Mostly not required. 
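// --- Comments added for clarity (not in the original patch): the watermark above falls back
// --- in this order: the user's watermarkFn(curRecord) if provided, otherwise the current
// --- record's timestamp (user timestampFn or processing time), and TIMESTAMP_MIN_VALUE before
// --- any record has been read. The checkpoint mark returned above records, per partition,
// --- the next offset to read, so a restarted reader resumes exactly where this one stopped.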
- return curRecord; - } - - @Override - public Instant getCurrentTimestamp() throws NoSuchElementException { - return curTimestamp; - } - - @Override - public long getSplitBacklogBytes() { - long backlogBytes = 0; - - for (PartitionState p : partitionStates) { - long pBacklog = p.approxBacklogInBytes(); - if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { - return UnboundedReader.BACKLOG_UNKNOWN; - } - backlogBytes += pBacklog; - } - - return backlogBytes; - } - - private long getSplitBacklogMessageCount() { - long backlogCount = 0; - - for (PartitionState p : partitionStates) { - long pBacklog = p.backlogMessageCount(); - if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { - return UnboundedReader.BACKLOG_UNKNOWN; - } - backlogCount += pBacklog; - } - - return backlogCount; - } - - @Override - public void close() throws IOException { - closed.set(true); - consumerPollThread.shutdown(); - offsetFetcherThread.shutdown(); - - boolean isShutdown = false; - - // Wait for threads to shutdown. Trying this as a loop to handle a tiny race where poll thread - // might block to enqueue right after availableRecordsQueue.poll() below. - while (!isShutdown) { - - if (consumer != null) { - consumer.wakeup(); - } - if (offsetConsumer != null) { - offsetConsumer.wakeup(); - } - availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. - try { - isShutdown = consumerPollThread.awaitTermination(10, TimeUnit.SECONDS) - && offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); // not expected - } - - if (!isShutdown) { - LOG.warn("An internal thread is taking a long time to shutdown. will retry."); - } - } - - Closeables.close(keyDeserializerInstance, true); - Closeables.close(valueDeserializerInstance, true); - - Closeables.close(offsetConsumer, true); - Closeables.close(consumer, true); - } - } //////////////////////// Sink Support \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ @@ -1624,7 +870,7 @@ public Write withProducerFactoryFn( * be written by the same job. */ public Write withEOS(int numShards, String sinkGroupId) { - EOSWrite.ensureEOSSupport(); + KafkaExactlyOnceSink.ensureEOSSupport(); checkArgument(numShards >= 1, "numShards should be >= 1"); checkArgument(sinkGroupId != null, "sinkGroupId is required for exactly-once sink"); return toBuilder() @@ -1669,7 +915,7 @@ public PDone expand(PCollection> input) { checkArgument(getValueSerializer() != null, "withValueSerializer() is required"); if (isEOS()) { - EOSWrite.ensureEOSSupport(); + KafkaExactlyOnceSink.ensureEOSSupport(); // TODO: Verify that the group_id does not have existing state stored on Kafka unless // this is an upgrade. This avoids issues with simple mistake of reusing group_id @@ -1677,7 +923,7 @@ public PDone expand(PCollection> input) { // transform initializes while processing the output. It might be better to // check here to catch common mistake. 
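// --- Hedged usage sketch (not part of this patch) of the exactly-once path chosen below.
// --- The broker address, topic, and shard count are hypothetical; withEOS(), withKeySerializer()
// --- and withValueSerializer() appear in this patch, the remaining builder methods are assumed
// --- from the KafkaIO.Write API.
// pipeline
//     .apply(...)   // produces PCollection<KV<Long, String>>
//     .apply(KafkaIO.<Long, String>write()
//         .withBootstrapServers("broker:9092")
//         .withTopic("results")
//         .withKeySerializer(LongSerializer.class)
//         .withValueSerializer(StringSerializer.class)
//         .withEOS(20, "results-sink-group"));  // 20 shards; sinkGroupId stores sequence ids on Kafka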
- input.apply(new EOSWrite<>(this)); + input.apply(new KafkaExactlyOnceSink<>(this)); } else { input.apply(ParDo.of(new KafkaWriter<>(this))); } @@ -1776,97 +1022,6 @@ public T decode(InputStream inStream) { } } - private static class KafkaWriter extends DoFn, Void> { - - @Setup - public void setup() { - if (spec.getProducerFactoryFn() != null) { - producer = spec.getProducerFactoryFn().apply(producerConfig); - } else { - producer = new KafkaProducer<>(producerConfig); - } - } - - @ProcessElement - public void processElement(ProcessContext ctx) throws Exception { - checkForFailures(); - - KV kv = ctx.element(); - producer.send( - new ProducerRecord<>(spec.getTopic(), kv.getKey(), kv.getValue()), new SendCallback()); - - elementsWritten.inc(); - } - - @FinishBundle - public void finishBundle() throws IOException { - producer.flush(); - checkForFailures(); - } - - @Teardown - public void teardown() { - producer.close(); - } - - /////////////////////////////////////////////////////////////////////////////////// - - private final Write spec; - private final Map producerConfig; - - private transient Producer producer = null; - //private transient Callback sendCallback = new SendCallback(); - // first exception and number of failures since last invocation of checkForFailures(): - private transient Exception sendException = null; - private transient long numSendFailures = 0; - - private final Counter elementsWritten = SinkMetrics.elementsWritten(); - - KafkaWriter(Write spec) { - this.spec = spec; - - this.producerConfig = new HashMap<>(spec.getProducerConfig()); - - this.producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, - spec.getKeySerializer()); - this.producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, - spec.getValueSerializer()); - } - - private synchronized void checkForFailures() throws IOException { - if (numSendFailures == 0) { - return; - } - - String msg = String.format( - "KafkaWriter : failed to send %d records (since last report)", numSendFailures); - - Exception e = sendException; - sendException = null; - numSendFailures = 0; - - LOG.warn(msg); - throw new IOException(msg, e); - } - - private class SendCallback implements Callback { - @Override - public void onCompletion(RecordMetadata metadata, Exception exception) { - if (exception == null) { - return; - } - - synchronized (KafkaWriter.this) { - if (sendException == null) { - sendException = exception; - } - numSendFailures++; - } - // don't log exception stacktrace here, exception will be propagated up. - LOG.warn("KafkaWriter send failed : '{}'", exception.getMessage()); - } - } - } /** * Attempt to infer a {@link Coder} by extracting the type of the deserialized-class from the @@ -1906,565 +1061,4 @@ static NullableCoder inferCoder( throw new RuntimeException(String.format( "Could not extract the Kafka Deserializer type from %s", deserializer)); } - - ////////////////////////////////// Exactly-Once Sink \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - - /** - * Exactly-once sink transform. - */ - private static class EOSWrite extends PTransform>, PCollection> { - // - // Dataflow ensures at-least once processing for side effects like sinks. In order to provide - // exactly-once semantics, a sink needs to be idempotent or it should avoid writing records - // that have already been written. This snk does the latter. All the the records are ordered - // across a fixed number of shards and records in each shard are written in order. 
It drops - // any records that are already written and buffers those arriving out of order. - // - // Exactly once sink involves two shuffles of the records: - // A : Assign a shard ---> B : Assign sequential ID ---> C : Write to Kafka in order - // - // Processing guarantees also require deterministic processing within user transforms. - // Here, that requires order of the records committed to Kafka by C should not be affected by - // restarts in C and its upstream stages. - // - // A : Assigns a random shard for message. Note that there are no ordering guarantees for - // writing user records to Kafka. User can still control partitioning among topic - // partitions as with regular sink (of course, there are no ordering guarantees in - // regular Kafka sink either). - // B : Assigns an id sequentially for each messages within a shard. - // C : Writes each shard to Kafka in sequential id order. In Dataflow, when C sees a record - // and id, it implies that record and the associated id are checkpointed to persistent - // storage and this record will always have same id, even in retries. - // Exactly-once semantics are achieved by writing records in the strict order of - // these check-pointed sequence ids. - // - // Parallelism for B and C is fixed to 'numShards', which defaults to number of partitions - // for the topic. A few reasons for that: - // - B & C implement their functionality using per-key state. Shard id makes it independent - // of cardinality of user key. - // - We create one producer per shard, and its 'transactional id' is based on shard id. This - // requires that number of shards to be finite. This also helps with batching. and avoids - // initializing producers and transactions. - // - Most importantly, each of sharded writers stores 'next message id' in partition - // metadata, which is committed atomically with Kafka transactions. This is critical - // to handle retries of C correctly. Initial testing showed number of shards could be - // larger than number of partitions for the topic. - // - // Number of shards can change across multiple runs of a pipeline (job upgrade in Dataflow). - // - - private final Write spec; - - static void ensureEOSSupport() { - checkArgument( - ProducerSpEL.supportsTransactions(), "%s %s", - "This version of Kafka client does not support transactions required to support", - "exactly-once semantics. Please use Kafka client version 0.11 or newer."); - } - - EOSWrite(Write spec) { - this.spec = spec; - } - - @Override - public PCollection expand(PCollection> input) { - - int numShards = spec.getNumShards(); - if (numShards <= 0) { - try (Consumer consumer = openConsumer(spec)) { - numShards = consumer.partitionsFor(spec.getTopic()).size(); - LOG.info("Using {} shards for exactly-once writer, matching number of partitions " - + "for topic '{}'", numShards, spec.getTopic()); - } - } - checkState(numShards > 0, "Could not set number of shards"); - - return input - .apply( - Window.>into(new GlobalWindows()) // Everything into global window. 
- .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) - .discardingFiredPanes()) - .apply( - String.format("Shuffle across %d shards", numShards), - ParDo.of(new EOSReshard<>(numShards))) - .apply("Persist sharding", GroupByKey.create()) - .apply("Assign sequential ids", ParDo.of(new EOSSequencer<>())) - .apply("Persist ids", GroupByKey.create()) - .apply( - String.format("Write to Kafka topic '%s'", spec.getTopic()), - ParDo.of(new KafkaEOWriter<>(spec, input.getCoder()))); - } - } - - /** - * Shuffle messages assigning each randomly to a shard. - */ - private static class EOSReshard extends DoFn, KV>> { - private final int numShards; - private transient int shardId; - - EOSReshard(int numShards) { - this.numShards = numShards; - } - - @Setup - public void setup() { - shardId = ThreadLocalRandom.current().nextInt(numShards); - } - - @ProcessElement - public void processElement(ProcessContext ctx) { - shardId = (shardId + 1) % numShards; // round-robin among shards. - ctx.output(KV.of(shardId, ctx.element())); - } - } - - private static class EOSSequencer - extends DoFn>>, KV>>> { - private static final String NEXT_ID = "nextId"; - @StateId(NEXT_ID) - private final StateSpec> nextIdSpec = StateSpecs.value(); - - @ProcessElement - public void processElement(@StateId(NEXT_ID) ValueState nextIdState, ProcessContext ctx) { - long nextId = MoreObjects.firstNonNull(nextIdState.read(), 0L); - int shard = ctx.element().getKey(); - for (KV value : ctx.element().getValue()) { - ctx.output(KV.of(shard, KV.of(nextId, value))); - nextId++; - } - nextIdState.write(nextId); - } - } - - private static class KafkaEOWriter - extends DoFn>>>, Void> { - - private static final String NEXT_ID = "nextId"; - private static final String MIN_BUFFERED_ID = "minBufferedId"; - private static final String OUT_OF_ORDER_BUFFER = "outOfOrderBuffer"; - private static final String WRITER_ID = "writerId"; - - private static final String METRIC_NAMESPACE = "KafkaEOSink"; - - // Not sure of a good limit. This applies only for large bundles. - private static final int MAX_RECORDS_PER_TXN = 1000; - private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); - - @StateId(NEXT_ID) - private final StateSpec> sequenceIdSpec = StateSpecs.value(); - @StateId(MIN_BUFFERED_ID) - private final StateSpec> minBufferedId = StateSpecs.value(); - @StateId(OUT_OF_ORDER_BUFFER) - private final StateSpec>>> outOfOrderBuffer; - // A random id assigned to each shard. Helps with detecting when multiple jobs are mistakenly - // started with same groupId used for storing state on Kafka side including the case where - // a job is restarted with same groupId, but the metadata from previous run is not removed. - // Better to be safe and error out with a clear message. - @StateId(WRITER_ID) - private final StateSpec> writerIdSpec = StateSpecs.value(); - - private final Write spec; - - // Metrics - private final Counter elementsWritten = SinkMetrics.elementsWritten(); - // Elements buffered due to out of order arrivals. - private final Counter elementsBuffered = Metrics.counter(METRIC_NAMESPACE, "elementsBuffered"); - private final Counter numTransactions = Metrics.counter(METRIC_NAMESPACE, "numTransactions"); - - KafkaEOWriter(Write spec, Coder> elemCoder) { - this.spec = spec; - this.outOfOrderBuffer = StateSpecs.bag(KvCoder.of(BigEndianLongCoder.of(), elemCoder)); - } - - @Setup - public void setup() { - // This is on the worker. Ensure the runtime version is till compatible. 
- EOSWrite.ensureEOSSupport(); - } - - @ProcessElement - public void processElement(@StateId(NEXT_ID) ValueState nextIdState, - @StateId(MIN_BUFFERED_ID) ValueState minBufferedIdState, - @StateId(OUT_OF_ORDER_BUFFER) - BagState>> oooBufferState, - @StateId(WRITER_ID) ValueState writerIdState, - ProcessContext ctx) - throws IOException { - - int shard = ctx.element().getKey(); - - minBufferedIdState.readLater(); - long nextId = MoreObjects.firstNonNull(nextIdState.read(), 0L); - long minBufferedId = MoreObjects.firstNonNull(minBufferedIdState.read(), Long.MAX_VALUE); - - ShardWriterCache cache = - (ShardWriterCache) CACHE_BY_GROUP_ID.getUnchecked(spec.getSinkGroupId()); - ShardWriter writer = cache.removeIfPresent(shard); - if (writer == null) { - writer = initShardWriter(shard, writerIdState, nextId); - } - - long committedId = writer.committedId; - - if (committedId >= nextId) { - // This is a retry of an already committed batch. - LOG.info("{}: committed id {} is ahead of expected {}. {} records will be dropped " - + "(these are already written).", - shard, committedId, nextId - 1, committedId - nextId + 1); - nextId = committedId + 1; - } - - try { - writer.beginTxn(); - int txnSize = 0; - - // Iterate in recordId order. The input iterator could be mostly sorted. - // There might be out of order messages buffered in earlier iterations. These - // will get merged if and when minBufferedId matches nextId. - - Iterator>> iter = ctx.element().getValue().iterator(); - - while (iter.hasNext()) { - KV> kv = iter.next(); - long recordId = kv.getKey(); - - if (recordId < nextId) { - LOG.info("{}: dropping older record {}. Already committed till {}", - shard, recordId, committedId); - continue; - } - - if (recordId > nextId) { - // Out of order delivery. Should be pretty rare (what about in a batch pipeline?) - - LOG.info("{}: Saving out of order record {}, next record id to be written is {}", - shard, recordId, nextId); - - // checkState(recordId - nextId < 10000, "records are way out of order"); - - oooBufferState.add(kv); - minBufferedId = Math.min(minBufferedId, recordId); - minBufferedIdState.write(minBufferedId); - elementsBuffered.inc(); - continue; - } - - // recordId and nextId match. Finally write record. - - writer.sendRecord(kv.getValue(), elementsWritten); - nextId++; - - if (++txnSize >= MAX_RECORDS_PER_TXN) { - writer.commitTxn(recordId, numTransactions); - txnSize = 0; - writer.beginTxn(); - } - - if (minBufferedId == nextId) { - // One or more of the buffered records can be committed now. - // Read all of them in to memory and sort them. Reading into memory - // might be problematic in extreme cases. Might need to improve it in future. - - List>> buffered = Lists.newArrayList(oooBufferState.read()); - buffered.sort(new KV.OrderByKey<>()); - - LOG.info("{} : merging {} buffered records (min buffered id is {}).", - shard, buffered.size(), minBufferedId); - - oooBufferState.clear(); - minBufferedIdState.clear(); - minBufferedId = Long.MAX_VALUE; - - iter = - Iterators.mergeSorted( - ImmutableList.of(iter, buffered.iterator()), new KV.OrderByKey<>()); - } - } - - writer.commitTxn(nextId - 1, numTransactions); - nextIdState.write(nextId); - - } catch (ProducerSpEL.UnrecoverableProducerException e) { - // Producer JavaDoc says these are not recoverable errors and producer should be closed. - - // Close the producer and a new producer will be initialized in retry. - // It is possible that a rough worker keeps retrying and ends up fencing off - // active producers. 
How likely this might be or how well such a scenario is handled - // depends on the runner. For now we will leave it to upper layers, will need to revisit. - - LOG.warn("{} : closing producer {} after unrecoverable error. The work might have migrated." - + " Committed id {}, current id {}.", - writer.shard, writer.producerName, writer.committedId, nextId - 1, e); - - writer.producer.close(); - writer = null; // No need to cache it. - throw e; - } finally { - if (writer != null) { - cache.insert(shard, writer); - } - } - } - - private static class ShardMetadata { - - @JsonProperty("seq") - public final long sequenceId; - @JsonProperty("id") - public final String writerId; - - private ShardMetadata() { // for json deserializer - sequenceId = -1; - writerId = null; - } - - ShardMetadata(long sequenceId, String writerId) { - this.sequenceId = sequenceId; - this.writerId = writerId; - } - } - - /** - * A wrapper around Kafka producer. One for each of the shards. - */ - private static class ShardWriter { - - private final int shard; - private final String writerId; - private final Producer producer; - private final String producerName; - private final Write spec; - private long committedId; - - ShardWriter(int shard, - String writerId, - Producer producer, - String producerName, - Write spec, - long committedId) { - this.shard = shard; - this.writerId = writerId; - this.producer = producer; - this.producerName = producerName; - this.spec = spec; - this.committedId = committedId; - } - - void beginTxn() { - ProducerSpEL.beginTransaction(producer); - } - - void sendRecord(KV record, Counter sendCounter) { - try { - producer.send( - new ProducerRecord<>(spec.getTopic(), record.getKey(), record.getValue())); - sendCounter.inc(); - } catch (KafkaException e) { - ProducerSpEL.abortTransaction(producer); - throw e; - } - } - - void commitTxn(long lastRecordId, Counter numTransactions) throws IOException { - try { - // Store id in consumer group metadata for the partition. - // NOTE: Kafka keeps this metadata for 24 hours since the last update. This limits - // how long the pipeline could be down before resuming it. It does not look like - // this TTL can be adjusted (asked about it on Kafka users list). - ProducerSpEL.sendOffsetsToTransaction( - producer, - ImmutableMap.of(new TopicPartition(spec.getTopic(), shard), - new OffsetAndMetadata(0L, - JSON_MAPPER.writeValueAsString( - new ShardMetadata(lastRecordId, writerId)))), - spec.getSinkGroupId()); - ProducerSpEL.commitTransaction(producer); - - numTransactions.inc(); - LOG.debug("{} : committed {} records", shard, lastRecordId - committedId); - - committedId = lastRecordId; - } catch (KafkaException e) { - ProducerSpEL.abortTransaction(producer); - throw e; - } - } - } - - private ShardWriter initShardWriter(int shard, - ValueState writerIdState, - long nextId) throws IOException { - - String producerName = String.format("producer_%d_for_%s", shard, spec.getSinkGroupId()); - Producer producer = initializeEosProducer(spec, producerName); - - // Fetch latest committed metadata for the partition (if any). Checks committed sequence ids. 
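// --- Comments added for clarity (not in the original patch): the metadata fetched below is
// --- the JSON written by ShardWriter.commitTxn() above, stored in the sink group's
// --- OffsetAndMetadata for the partition matching the shard id. A hypothetical example:
// ---   {"seq": 41233, "id": "5F3A2B1C - 2018-01-01 00:00:00"}
// --- "seq" is the last committed sequence id and "id" identifies the writer, so a restart
// --- with a reused groupId is detected and reported instead of silently double-writing.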
- try { - - String writerId = writerIdState.read(); - - OffsetAndMetadata committed; - - try (Consumer consumer = openConsumer(spec)) { - committed = consumer.committed(new TopicPartition(spec.getTopic(), shard)); - } - - long committedSeqId = -1; - - if (committed == null || committed.metadata() == null || committed.metadata().isEmpty()) { - checkState(nextId == 0 && writerId == null, - "State exists for shard %s (nextId %s, writerId '%s'), but there is no state " - + "stored with Kafka topic '%s' group id '%s'", - shard, nextId, writerId, spec.getTopic(), spec.getSinkGroupId()); - - writerId = String.format("%X - %s", - new Random().nextInt(Integer.MAX_VALUE), - DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss") - .withZone(DateTimeZone.UTC) - .print(DateTimeUtils.currentTimeMillis())); - writerIdState.write(writerId); - LOG.info("Assigned writer id '{}' to shard {}", writerId, shard); - - } else { - ShardMetadata metadata = JSON_MAPPER.readValue(committed.metadata(), - ShardMetadata.class); - - checkNotNull(metadata.writerId); - - if (writerId == null) { - // a) This might be a restart of the job from scratch, in which case metatdata - // should be ignored and overwritten with new one. - // b) This job might be started with an incorrect group id which is an error. - // c) There is an extremely small chance that this is a retry of the first bundle - // where metatdate was committed to Kafka but the bundle results were not committed - // in Beam, in which case it should be treated as correct metadata. - // How can we tell these three cases apart? Be safe and throw an exception. - // - // We could let users explicitly an option to override the existing metadata. - // - throw new IllegalStateException(String.format( - "Kafka metadata exists for shard %s, but there is no stored state for it. " - + "This mostly indicates groupId '%s' is used else where or in earlier runs. " - + "Try another group id. Metadata for this shard on Kafka : '%s'", - shard, spec.getSinkGroupId(), committed.metadata())); - } - - checkState(writerId.equals(metadata.writerId), - "Writer ids don't match. This is mostly a unintended misuse of groupId('%s')." - + "Beam '%s', Kafka '%s'", - spec.getSinkGroupId(), writerId, metadata.writerId); - - committedSeqId = metadata.sequenceId; - - checkState(committedSeqId >= (nextId - 1), - "Committed sequence id can not be lower than %s, partition metadata : %s", - nextId - 1, committed.metadata()); - } - - LOG.info("{} : initialized producer {} with committed sequence id {}", - shard, producerName, committedSeqId); - - return new ShardWriter<>(shard, writerId, producer, producerName, spec, committedSeqId); - - } catch (Exception e) { - producer.close(); - throw e; - } - } - - /** - * A wrapper around guava cache to provide insert()/remove() semantics. A ShardWriter will - * be closed if it is stays in cache for more than 1 minute, i.e. not used inside EOSWrite - * DoFn for a minute or more. 
- */ - private static class ShardWriterCache { - - static final ScheduledExecutorService SCHEDULED_CLEAN_UP_THREAD = - Executors.newSingleThreadScheduledExecutor(); - - static final int CLEAN_UP_CHECK_INTERVAL_MS = 10 * 1000; - static final int IDLE_TIMEOUT_MS = 60 * 1000; - - private final Cache> cache; - - ShardWriterCache() { - this.cache = - CacheBuilder.newBuilder() - .expireAfterWrite(IDLE_TIMEOUT_MS, TimeUnit.MILLISECONDS) - .>removalListener( - notification -> { - if (notification.getCause() != RemovalCause.EXPLICIT) { - ShardWriter writer = notification.getValue(); - LOG.info( - "{} : Closing idle shard writer {} after 1 minute of idle time.", - writer.shard, - writer.producerName); - writer.producer.close(); - } - }) - .build(); - - // run cache.cleanUp() every 10 seconds. - SCHEDULED_CLEAN_UP_THREAD.scheduleAtFixedRate( - cache::cleanUp, - CLEAN_UP_CHECK_INTERVAL_MS, - CLEAN_UP_CHECK_INTERVAL_MS, - TimeUnit.MILLISECONDS); - } - - ShardWriter removeIfPresent(int shard) { - return cache.asMap().remove(shard); - } - - void insert(int shard, ShardWriter writer) { - ShardWriter existing = cache.asMap().putIfAbsent(shard, writer); - checkState(existing == null, - "Unexpected multiple instances of writers for shard %s", shard); - } - } - - // One cache for each sink (usually there is only one sink per pipeline) - private static final LoadingCache> CACHE_BY_GROUP_ID = - CacheBuilder.newBuilder() - .build(new CacheLoader>() { - @Override - public ShardWriterCache load(String key) throws Exception { - return new ShardWriterCache<>(); - } - }); - } - - /** - * Opens a generic consumer that is mainly meant for metadata operations like fetching - * number of partitions for a topic rather than for fetching messages. - */ - private static Consumer openConsumer(Write spec) { - return spec.getConsumerFactoryFn().apply((ImmutableMap.of( - ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, spec - .getProducerConfig().get(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG), - ConsumerConfig.GROUP_ID_CONFIG, spec.getSinkGroupId(), - ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class, - ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class - ))); - } - - private static Producer initializeEosProducer(Write spec, - String producerName) { - - Map producerConfig = new HashMap<>(spec.getProducerConfig()); - producerConfig.putAll(ImmutableMap.of( - ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, spec.getKeySerializer(), - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, spec.getValueSerializer(), - ProducerSpEL.ENABLE_IDEMPOTENCE_CONFIG, true, - ProducerSpEL.TRANSACTIONAL_ID_CONFIG, producerName)); - - Producer producer = - spec.getProducerFactoryFn() != null - ? 
spec.getProducerFactoryFn().apply((producerConfig)) - : new KafkaProducer<>(producerConfig); - - ProducerSpEL.initTransactions(producer); - return producer; - } } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java index 9410445e2726..577fdee66efc 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java @@ -22,7 +22,6 @@ import java.io.OutputStream; import java.util.List; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.StructuredCoder; @@ -50,37 +49,22 @@ public KafkaRecordCoder(Coder keyCoder, Coder valueCoder) { } @Override - public void encode(KafkaRecord value, OutputStream outStream) - throws CoderException, IOException { - encode(value, outStream, Context.NESTED); + public void encode(KafkaRecord value, OutputStream outStream) throws IOException { + stringCoder.encode(value.getTopic(), outStream); + intCoder.encode(value.getPartition(), outStream); + longCoder.encode(value.getOffset(), outStream); + longCoder.encode(value.getTimestamp(), outStream); + kvCoder.encode(value.getKV(), outStream); } @Override - public void encode(KafkaRecord value, OutputStream outStream, Context context) - throws CoderException, IOException { - Context nested = context.nested(); - stringCoder.encode(value.getTopic(), outStream, nested); - intCoder.encode(value.getPartition(), outStream, nested); - longCoder.encode(value.getOffset(), outStream, nested); - longCoder.encode(value.getTimestamp(), outStream, nested); - kvCoder.encode(value.getKV(), outStream, context); - } - - @Override - public KafkaRecord decode(InputStream inStream) throws CoderException, IOException { - return decode(inStream, Context.NESTED); - } - - @Override - public KafkaRecord decode(InputStream inStream, Context context) - throws CoderException, IOException { - Context nested = context.nested(); + public KafkaRecord decode(InputStream inStream) throws IOException { return new KafkaRecord<>( - stringCoder.decode(inStream, nested), - intCoder.decode(inStream, nested), - longCoder.decode(inStream, nested), - longCoder.decode(inStream, nested), - kvCoder.decode(inStream, context)); + stringCoder.decode(inStream), + intCoder.decode(inStream), + longCoder.decode(inStream), + longCoder.decode(inStream), + kvCoder.decode(inStream)); } @Override diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java new file mode 100644 index 000000000000..e830b4ce06d3 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -0,0 +1,663 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterators; +import com.google.common.io.Closeables; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; +import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.io.kafka.KafkaCheckpointMark.PartitionMark; +import org.apache.beam.sdk.io.kafka.KafkaIO.Read; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Gauge; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.metrics.SourceMetrics; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.serialization.Deserializer; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An unbounded reader to read from Kafka. Each reader consumes messages from one or more Kafka + * partitions. See {@link KafkaIO} for user visible documentation and example usage. 
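// --- Hedged usage sketch (not part of this patch) of the user-visible API this reader backs.
// --- The broker address and topic are hypothetical; withTopic(), withKeyDeserializer() and
// --- withValueDeserializer() appear in KafkaIO.Read in this patch, withBootstrapServers() is
// --- assumed from the KafkaIO API.
// PCollection<KafkaRecord<Long, String>> records = pipeline.apply(
//     KafkaIO.<Long, String>read()
//         .withBootstrapServers("broker:9092")
//         .withTopic("events")
//         .withKeyDeserializer(LongDeserializer.class)
//         .withValueDeserializer(StringDeserializer.class));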
+ */ +class KafkaUnboundedReader extends UnboundedReader> { + + ///////////////////// Reader API //////////////////////////////////////////////////////////// + @Override + public boolean start() throws IOException { + final int defaultPartitionInitTimeout = 60 * 1000; + final int kafkaRequestTimeoutMultiple = 2; + + Read spec = source.getSpec(); + consumer = spec.getConsumerFactoryFn().apply(spec.getConsumerConfig()); + consumerSpEL.evaluateAssign(consumer, spec.getTopicPartitions()); + + try { + keyDeserializerInstance = spec.getKeyDeserializer().newInstance(); + valueDeserializerInstance = spec.getValueDeserializer().newInstance(); + } catch (InstantiationException | IllegalAccessException e) { + throw new IOException("Could not instantiate deserializers", e); + } + + keyDeserializerInstance.configure(spec.getConsumerConfig(), true); + valueDeserializerInstance.configure(spec.getConsumerConfig(), false); + + // Seek to start offset for each partition. This is the first interaction with the server. + // Unfortunately it can block forever in case of network issues like incorrect ACLs. + // Initialize partition in a separate thread and cancel it if takes longer than a minute. + for (final PartitionState pState : partitionStates) { + Future future = consumerPollThread.submit(() -> setupInitialOffset(pState)); + + try { + // Timeout : 1 minute OR 2 * Kafka consumer request timeout if it is set. + Integer reqTimeout = (Integer) spec.getConsumerConfig().get( + ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG); + future.get(reqTimeout != null ? kafkaRequestTimeoutMultiple * reqTimeout + : defaultPartitionInitTimeout, + TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + consumer.wakeup(); // This unblocks consumer stuck on network I/O. + // Likely reason : Kafka servers are configured to advertise internal ips, but + // those ips are not accessible from workers outside. + String msg = String.format( + "%s: Timeout while initializing partition '%s'. " + + "Kafka client may not be able to connect to servers.", + this, pState.topicPartition); + LOG.error("{}", msg); + throw new IOException(msg); + } catch (Exception e) { + throw new IOException(e); + } + LOG.info("{}: reading from {} starting at offset {}", + name, pState.topicPartition, pState.nextOffset); + } + + // Start consumer read loop. + // Note that consumer is not thread safe, should not be accessed out side consumerPollLoop(). + consumerPollThread.submit(this::consumerPollLoop); + + // offsetConsumer setup : + + Object groupId = spec.getConsumerConfig().get(ConsumerConfig.GROUP_ID_CONFIG); + // override group_id and disable auto_commit so that it does not interfere with main consumer + String offsetGroupId = String.format("%s_offset_consumer_%d_%s", name, + (new Random()).nextInt(Integer.MAX_VALUE), (groupId == null ? "none" : groupId)); + Map offsetConsumerConfig = new HashMap<>(spec.getConsumerConfig()); + offsetConsumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId); + offsetConsumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + + offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig); + consumerSpEL.evaluateAssign(offsetConsumer, spec.getTopicPartitions()); + + offsetFetcherThread.scheduleAtFixedRate( + this::updateLatestOffsets, 0, OFFSET_UPDATE_INTERVAL_SECONDS, TimeUnit.SECONDS); + + nextBatch(); + return advance(); + } + + @Override + public boolean advance() throws IOException { + /* Read first record (if any). 
we need to loop here because : + * - (a) some records initially need to be skipped if they are before consumedOffset + * - (b) if curBatch is empty, we want to fetch next batch and then advance. + * - (c) curBatch is an iterator of iterators. we interleave the records from each. + * curBatch.next() might return an empty iterator. + */ + while (true) { + if (curBatch.hasNext()) { + PartitionState pState = curBatch.next(); + + elementsRead.inc(); + elementsReadBySplit.inc(); + + if (!pState.recordIter.hasNext()) { // -- (c) + pState.recordIter = Collections.emptyIterator(); // drop ref + curBatch.remove(); + continue; + } + + ConsumerRecord rawRecord = pState.recordIter.next(); + long expected = pState.nextOffset; + long offset = rawRecord.offset(); + + if (offset < expected) { // -- (a) + // this can happen when compression is enabled in Kafka (seems to be fixed in 0.10) + // should we check if the offset is way off from consumedOffset (say > 1M)? + LOG.warn("{}: ignoring already consumed offset {} for {}", + this, offset, pState.topicPartition); + continue; + } + + long offsetGap = offset - expected; // could be > 0 when Kafka log compaction is enabled. + + if (curRecord == null) { + LOG.info("{}: first record offset {}", name, offset); + offsetGap = 0; + } + + // Apply user deserializers. User deserializers might throw, which will be propagated up + // and 'curRecord' remains unchanged. The runner should close this reader. + // TODO: write records that can't be deserialized to a "dead-letter" additional output. + KafkaRecord record = new KafkaRecord<>( + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset(), + consumerSpEL.getRecordTimestamp(rawRecord), + keyDeserializerInstance.deserialize(rawRecord.topic(), rawRecord.key()), + valueDeserializerInstance.deserialize(rawRecord.topic(), rawRecord.value())); + + curTimestamp = (source.getSpec().getTimestampFn() == null) + ? Instant.now() : source.getSpec().getTimestampFn().apply(record); + curRecord = record; + + int recordSize = (rawRecord.key() == null ? 0 : rawRecord.key().length) + + (rawRecord.value() == null ? 0 : rawRecord.value().length); + pState.recordConsumed(offset, recordSize, offsetGap); + bytesRead.inc(recordSize); + bytesReadBySplit.inc(recordSize); + return true; + + } else { // -- (b) + nextBatch(); + + if (!curBatch.hasNext()) { + return false; + } + } + } + } + + @Override + public Instant getWatermark() { + if (curRecord == null) { + LOG.debug("{}: getWatermark() : no records have been read yet.", name); + return initialWatermark; + } + + return source.getSpec().getWatermarkFn() != null + ? source.getSpec().getWatermarkFn().apply(curRecord) : curTimestamp; + } + + @Override + public CheckpointMark getCheckpointMark() { + reportBacklog(); + return new KafkaCheckpointMark( + partitionStates.stream() + .map((p) -> new PartitionMark(p.topicPartition.topic(), + p.topicPartition.partition(), + p.nextOffset)) + .collect(Collectors.toList()), + source.getSpec().isCommitOffsetsInFinalizeEnabled() ? this : null + ); + } + + @Override + public UnboundedSource, ?> getCurrentSource() { + return source; + } + + @Override + public KafkaRecord getCurrent() throws NoSuchElementException { + // should we delay updating consumed offset till this point? Mostly not required. 
+ return curRecord; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return curTimestamp; + } + + @Override + public long getSplitBacklogBytes() { + long backlogBytes = 0; + + for (PartitionState p : partitionStates) { + long pBacklog = p.approxBacklogInBytes(); + if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { + return UnboundedReader.BACKLOG_UNKNOWN; + } + backlogBytes += pBacklog; + } + + return backlogBytes; + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + private static final Logger LOG = LoggerFactory.getLogger(KafkaUnboundedSource.class); + + @VisibleForTesting + static final String METRIC_NAMESPACE = "KafkaIOReader"; + @VisibleForTesting + static final String CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC = "checkpointMarkCommitsEnqueued"; + private static final String CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC = + "checkpointMarkCommitsSkipped"; + + private final KafkaUnboundedSource source; + private final String name; + private Consumer consumer; + private final List partitionStates; + private KafkaRecord curRecord; + private Instant curTimestamp; + private Iterator curBatch = Collections.emptyIterator(); + + private Deserializer keyDeserializerInstance = null; + private Deserializer valueDeserializerInstance = null; + + private final Counter elementsRead = SourceMetrics.elementsRead(); + private final Counter bytesRead = SourceMetrics.bytesRead(); + private final Counter elementsReadBySplit; + private final Counter bytesReadBySplit; + private final Gauge backlogBytesOfSplit; + private final Gauge backlogElementsOfSplit; + private final Counter checkpointMarkCommitsEnqueued = Metrics.counter( + METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC); + // Checkpoint marks skipped in favor of newer mark (only the latest needs to be committed). + private final Counter checkpointMarkCommitsSkipped = Metrics.counter( + METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC); + + /** + * The poll timeout while reading records from Kafka. + * If option to commit reader offsets in to Kafka in + * {@link KafkaCheckpointMark#finalizeCheckpoint()} is enabled, it would be delayed until + * this poll returns. It should be reasonably low as a result. + * At the same time it probably can't be very low like 10 millis, I am not sure how it affects + * when the latency is high. Probably good to experiment. Often multiple marks would be + * finalized in a batch, it it reduce finalization overhead to wait a short while and finalize + * only the last checkpoint mark. + */ + private static final Duration KAFKA_POLL_TIMEOUT = Duration.millis(1000); + private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT = Duration.millis(10); + private static final Duration RECORDS_ENQUEUE_POLL_TIMEOUT = Duration.millis(100); + + // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including + // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout + // like 100 milliseconds does not work well. This along with large receive buffer for + // consumer achieved best throughput in tests (see `defaultConsumerProperties`). 
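+  // The batches read by that thread are handed over to advance() through 'availableRecordsQueue'
+  // below; since it is a SynchronousQueue, at most one un-consumed batch is in flight at a time.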
+ private final ExecutorService consumerPollThread = Executors.newSingleThreadExecutor(); + private final SynchronousQueue> availableRecordsQueue = + new SynchronousQueue<>(); + private AtomicReference finalizedCheckpointMark = new AtomicReference<>(); + private AtomicBoolean closed = new AtomicBoolean(false); + + // Backlog support : + // Kafka consumer does not have an API to fetch latest offset for topic. We need to seekToEnd() + // then look at position(). Use another consumer to do this so that the primary consumer does + // not need to be interrupted. The latest offsets are fetched periodically on a thread. This is + // still a bit of a hack, but so far there haven't been any issues reported by the users. + private Consumer offsetConsumer; + private final ScheduledExecutorService offsetFetcherThread = + Executors.newSingleThreadScheduledExecutor(); + private static final int OFFSET_UPDATE_INTERVAL_SECONDS = 5; + + private static final long UNINITIALIZED_OFFSET = -1; + + //Add SpEL instance to cover the interface difference of Kafka client + private transient ConsumerSpEL consumerSpEL; + + /** watermark before any records have been read. */ + private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + + @Override + public String toString() { + return name; + } + + // Maintains approximate average over last 1000 elements + private static class MovingAvg { + private static final int MOVING_AVG_WINDOW = 1000; + private double avg = 0; + private long numUpdates = 0; + + void update(double quantity) { + numUpdates++; + avg += (quantity - avg) / Math.min(MOVING_AVG_WINDOW, numUpdates); + } + + double get() { + return avg; + } + } + + // maintains state of each assigned partition (buffered records, consumed offset, etc) + private static class PartitionState { + private final TopicPartition topicPartition; + private long nextOffset; + private long latestOffset; + private Iterator> recordIter = Collections.emptyIterator(); + + private MovingAvg avgRecordSize = new MovingAvg(); + private MovingAvg avgOffsetGap = new MovingAvg(); // > 0 only when log compaction is enabled. + + PartitionState(TopicPartition partition, long nextOffset) { + this.topicPartition = partition; + this.nextOffset = nextOffset; + this.latestOffset = UNINITIALIZED_OFFSET; + } + + // Update consumedOffset, avgRecordSize, and avgOffsetGap + void recordConsumed(long offset, int size, long offsetGap) { + nextOffset = offset + 1; + + // This is always updated from single thread. Probably not worth making atomic. + avgRecordSize.update(size); + avgOffsetGap.update(offsetGap); + } + + synchronized void setLatestOffset(long latestOffset) { + this.latestOffset = latestOffset; + } + + synchronized long approxBacklogInBytes() { + // Note that is an an estimate of uncompressed backlog. 
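+      // Illustrative example: with latestOffset = 1_000_000, nextOffset = 900_000, no offset gaps
+      // and an average record size of 500 bytes, this returns 100_000 * 500 bytes, i.e. ~50 MB.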
+ long backlogMessageCount = backlogMessageCount(); + if (backlogMessageCount == UnboundedReader.BACKLOG_UNKNOWN) { + return UnboundedReader.BACKLOG_UNKNOWN; + } + return (long) (backlogMessageCount * avgRecordSize.get()); + } + + synchronized long backlogMessageCount() { + if (latestOffset < 0 || nextOffset < 0) { + return UnboundedReader.BACKLOG_UNKNOWN; + } + double remaining = (latestOffset - nextOffset) / (1 + avgOffsetGap.get()); + return Math.max(0, (long) Math.ceil(remaining)); + } + } + + KafkaUnboundedReader( + KafkaUnboundedSource source, + @Nullable KafkaCheckpointMark checkpointMark) { + this.consumerSpEL = new ConsumerSpEL(); + this.source = source; + this.name = "Reader-" + source.getId(); + + List partitions = source.getSpec().getTopicPartitions(); + partitionStates = + ImmutableList.copyOf( + partitions + .stream() + .map(tp -> new PartitionState(tp, UNINITIALIZED_OFFSET)) + .collect(Collectors.toList())); + + if (checkpointMark != null) { + // a) verify that assigned and check-pointed partitions match exactly + // b) set consumed offsets + + checkState(checkpointMark.getPartitions().size() == partitions.size(), + "checkPointMark and assignedPartitions should match"); + + for (int i = 0; i < partitions.size(); i++) { + PartitionMark ckptMark = checkpointMark.getPartitions().get(i); + TopicPartition assigned = partitions.get(i); + TopicPartition partition = new TopicPartition(ckptMark.getTopic(), + ckptMark.getPartition()); + checkState(partition.equals(assigned), + "checkpointed partition %s and assigned partition %s don't match", + partition, assigned); + + partitionStates.get(i).nextOffset = ckptMark.getNextOffset(); + } + } + + String splitId = String.valueOf(source.getId()); + + elementsReadBySplit = SourceMetrics.elementsReadBySplit(splitId); + bytesReadBySplit = SourceMetrics.bytesReadBySplit(splitId); + backlogBytesOfSplit = SourceMetrics.backlogBytesOfSplit(splitId); + backlogElementsOfSplit = SourceMetrics.backlogElementsOfSplit(splitId); + } + + private void consumerPollLoop() { + // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. + + ConsumerRecords records = ConsumerRecords.empty(); + while (!closed.get()) { + try { + if (records.isEmpty()) { + records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); + } else if (availableRecordsQueue.offer(records, + RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), + TimeUnit.MILLISECONDS)) { + records = ConsumerRecords.empty(); + } + KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); + if (checkpointMark != null) { + commitCheckpointMark(checkpointMark); + } + } catch (InterruptedException e) { + LOG.warn("{}: consumer thread is interrupted", this, e); // not expected + break; + } catch (WakeupException e) { + break; + } + } + + LOG.info("{}: Returning from consumer pool loop", this); + } + + private void commitCheckpointMark(KafkaCheckpointMark checkpointMark) { + consumer.commitSync( + checkpointMark + .getPartitions() + .stream() + .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) + .collect(Collectors.toMap( + p -> new TopicPartition(p.getTopic(), p.getPartition()), + p -> new OffsetAndMetadata(p.getNextOffset()) + )) + ); + } + + /** + * Enqueue checkpoint mark to be committed to Kafka. This does not block until + * it is committed. There could be a delay of up to KAFKA_POLL_TIMEOUT (1 second). + * Any checkpoint mark enqueued earlier is dropped in favor of this checkpoint mark. 
+ * Documentation for {@link CheckpointMark#finalizeCheckpoint()} says these are finalized + * in order. Only the latest offsets need to be committed. + */ + void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { + if (finalizedCheckpointMark.getAndSet(checkpointMark) != null) { + checkpointMarkCommitsSkipped.inc(); + } + checkpointMarkCommitsEnqueued.inc(); + } + + private void nextBatch() { + curBatch = Collections.emptyIterator(); + + ConsumerRecords records; + try { + // poll available records, wait (if necessary) up to the specified timeout. + records = availableRecordsQueue.poll(RECORDS_DEQUEUE_POLL_TIMEOUT.getMillis(), + TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("{}: Unexpected", this, e); + return; + } + + if (records == null) { + return; + } + + List nonEmpty = new LinkedList<>(); + + for (PartitionState p : partitionStates) { + p.recordIter = records.records(p.topicPartition).iterator(); + if (p.recordIter.hasNext()) { + nonEmpty.add(p); + } + } + + // cycle through the partitions in order to interleave records from each. + curBatch = Iterators.cycle(nonEmpty); + } + + private void setupInitialOffset(PartitionState pState) { + Read spec = source.getSpec(); + + if (pState.nextOffset != UNINITIALIZED_OFFSET) { + consumer.seek(pState.topicPartition, pState.nextOffset); + } else { + // nextOffset is unininitialized here, meaning start reading from latest record as of now + // ('latest' is the default, and is configurable) or 'look up offset by startReadTime. + // Remember the current position without waiting until the first record is read. This + // ensures checkpoint is accurate even if the reader is closed before reading any records. + Instant startReadTime = spec.getStartReadTime(); + if (startReadTime != null) { + pState.nextOffset = + consumerSpEL.offsetForTime(consumer, pState.topicPartition, spec.getStartReadTime()); + consumer.seek(pState.topicPartition, pState.nextOffset); + } else { + pState.nextOffset = consumer.position(pState.topicPartition); + } + } + } + + // update latest offset for each partition. + // called from offsetFetcher thread + private void updateLatestOffsets() { + for (PartitionState p : partitionStates) { + try { + // If "read_committed" is enabled in the config, this seeks to 'Last Stable Offset'. + // As a result uncommitted messages are not counted in backlog. It is correct since + // the reader can not read them anyway. + consumerSpEL.evaluateSeek2End(offsetConsumer, p.topicPartition); + long offset = offsetConsumer.position(p.topicPartition); + p.setLatestOffset(offset); + } catch (Exception e) { + // An exception is expected if we've closed the reader in another thread. Ignore and exit. + if (closed.get()) { + break; + } + LOG.warn("{}: exception while fetching latest offset for partition {}. 
will be retried.", + this, p.topicPartition, e); + p.setLatestOffset(UNINITIALIZED_OFFSET); // reset + } + + LOG.debug("{}: latest offset update for {} : {} (consumer offset {}, avg record size {})", + this, p.topicPartition, p.latestOffset, p.nextOffset, p.avgRecordSize); + } + + LOG.debug("{}: backlog {}", this, getSplitBacklogBytes()); + } + + private void reportBacklog() { + long splitBacklogBytes = getSplitBacklogBytes(); + if (splitBacklogBytes < 0) { + splitBacklogBytes = UnboundedReader.BACKLOG_UNKNOWN; + } + backlogBytesOfSplit.set(splitBacklogBytes); + long splitBacklogMessages = getSplitBacklogMessageCount(); + if (splitBacklogMessages < 0) { + splitBacklogMessages = UnboundedReader.BACKLOG_UNKNOWN; + } + backlogElementsOfSplit.set(splitBacklogMessages); + } + + private long getSplitBacklogMessageCount() { + long backlogCount = 0; + + for (PartitionState p : partitionStates) { + long pBacklog = p.backlogMessageCount(); + if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { + return UnboundedReader.BACKLOG_UNKNOWN; + } + backlogCount += pBacklog; + } + + return backlogCount; + } + + @Override + public void close() throws IOException { + closed.set(true); + consumerPollThread.shutdown(); + offsetFetcherThread.shutdown(); + + boolean isShutdown = false; + + // Wait for threads to shutdown. Trying this as a loop to handle a tiny race where poll thread + // might block to enqueue right after availableRecordsQueue.poll() below. + while (!isShutdown) { + + if (consumer != null) { + consumer.wakeup(); + } + if (offsetConsumer != null) { + offsetConsumer.wakeup(); + } + availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. + try { + isShutdown = consumerPollThread.awaitTermination(10, TimeUnit.SECONDS) + && offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); // not expected + } + + if (!isShutdown) { + LOG.warn("An internal thread is taking a long time to shutdown. will retry."); + } + } + + Closeables.close(keyDeserializerInstance, true); + Closeables.close(valueDeserializerInstance, true); + + Closeables.close(offsetConsumer, true); + Closeables.close(consumer, true); + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java new file mode 100644 index 000000000000..e46d0a585c47 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedSource.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.base.Joiner; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import org.apache.beam.sdk.coders.AvroCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.kafka.KafkaIO.Read; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * An {@link UnboundedSource} to read from Kafka, used by {@link Read} transform in KafkaIO. + * See {@link KafkaIO} for user visible documentation and example usage. + */ +class KafkaUnboundedSource extends UnboundedSource, KafkaCheckpointMark> { + + /** + * The partitions are evenly distributed among the splits. The number of splits returned is + * {@code min(desiredNumSplits, totalNumPartitions)}, though better not to depend on the exact + * count. + * + *

    It is important to assign the partitions deterministically so that we can support + * resuming a split from last checkpoint. The Kafka partitions are sorted by + * {@code } and then assigned to splits in round-robin order. + */ + @Override + public List> split( + int desiredNumSplits, PipelineOptions options) throws Exception { + + List partitions = new ArrayList<>(spec.getTopicPartitions()); + + // (a) fetch partitions for each topic + // (b) sort by + // (c) round-robin assign the partitions to splits + + if (partitions.isEmpty()) { + try (Consumer consumer = + spec.getConsumerFactoryFn().apply(spec.getConsumerConfig())) { + for (String topic : spec.getTopics()) { + for (PartitionInfo p : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(p.topic(), p.partition())); + } + } + } + } + + partitions.sort( + Comparator.comparing(TopicPartition::topic) + .thenComparing(Comparator.comparingInt(TopicPartition::partition))); + + checkArgument(desiredNumSplits > 0); + checkState(partitions.size() > 0, + "Could not find any partitions. Please check Kafka configuration and topic names"); + + int numSplits = Math.min(desiredNumSplits, partitions.size()); + List> assignments = new ArrayList<>(numSplits); + + for (int i = 0; i < numSplits; i++) { + assignments.add(new ArrayList<>()); + } + for (int i = 0; i < partitions.size(); i++) { + assignments.get(i % numSplits).add(partitions.get(i)); + } + + List> result = new ArrayList<>(numSplits); + + for (int i = 0; i < numSplits; i++) { + List assignedToSplit = assignments.get(i); + + LOG.info("Partitions assigned to split {} (total {}): {}", + i, assignedToSplit.size(), Joiner.on(",").join(assignedToSplit)); + + result.add( + new KafkaUnboundedSource<>( + spec.toBuilder() + .setTopics(Collections.emptyList()) + .setTopicPartitions(assignedToSplit) + .build(), + i)); + } + + return result; + } + + @Override + public KafkaUnboundedReader createReader(PipelineOptions options, + KafkaCheckpointMark checkpointMark) { + if (spec.getTopicPartitions().isEmpty()) { + LOG.warn("Looks like generateSplits() is not called. Generate single split."); + try { + return new KafkaUnboundedReader<>(split(1, options).get(0), checkpointMark); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return new KafkaUnboundedReader<>(this, checkpointMark); + } + + @Override + public Coder getCheckpointMarkCoder() { + return AvroCoder.of(KafkaCheckpointMark.class); + } + + @Override + public boolean requiresDeduping() { + // Kafka records are ordered with in partitions. In addition checkpoint guarantees + // records are not consumed twice. + return false; + } + + @Override + public Coder> getOutputCoder() { + return KafkaRecordCoder.of(spec.getKeyCoder(), spec.getValueCoder()); + } + + ///////////////////////////////////////////////////////////////////////////////////////////// + + private static final Logger LOG = LoggerFactory.getLogger(KafkaUnboundedSource.class); + + private final Read spec; // Contains all the relevant configuratiton of the source. 
+ private final int id; // split id, mainly for debugging + + public KafkaUnboundedSource(Read spec, int id) { + this.spec = spec; + this.id = id; + } + + Read getSpec() { + return spec; + } + + int getId() { + return id; + } +} + diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java new file mode 100644 index 000000000000..00b76e503b8d --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.io.kafka.KafkaIO.Write; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.SinkMetrics; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A DoFn to write to Kafka, used in KafkaIO Write transform. + * See {@link KafkaIO} for user visible documentation and example usage. 
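+ *
+ * Send failures reported asynchronously by the producer callback are recorded and rethrown as an
+ * {@code IOException} from the next {@code processElement()} or {@code finishBundle()} call via
+ * {@code checkForFailures()}.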
+ */ +class KafkaWriter extends DoFn, Void> { + + @Setup + public void setup() { + if (spec.getProducerFactoryFn() != null) { + producer = spec.getProducerFactoryFn().apply(producerConfig); + } else { + producer = new KafkaProducer<>(producerConfig); + } + } + + @ProcessElement + public void processElement(ProcessContext ctx) throws Exception { + checkForFailures(); + + KV kv = ctx.element(); + producer.send( + new ProducerRecord<>(spec.getTopic(), kv.getKey(), kv.getValue()), new SendCallback()); + + elementsWritten.inc(); + } + + @FinishBundle + public void finishBundle() throws IOException { + producer.flush(); + checkForFailures(); + } + + @Teardown + public void teardown() { + producer.close(); + } + + /////////////////////////////////////////////////////////////////////////////////// + + private static final Logger LOG = LoggerFactory.getLogger(KafkaWriter.class); + + private final Write spec; + private final Map producerConfig; + + private transient Producer producer = null; + // first exception and number of failures since last invocation of checkForFailures(): + private transient Exception sendException = null; + private transient long numSendFailures = 0; + + private final Counter elementsWritten = SinkMetrics.elementsWritten(); + + KafkaWriter(Write spec) { + this.spec = spec; + + this.producerConfig = new HashMap<>(spec.getProducerConfig()); + + this.producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, + spec.getKeySerializer()); + this.producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, + spec.getValueSerializer()); + } + + private synchronized void checkForFailures() throws IOException { + if (numSendFailures == 0) { + return; + } + + String msg = String.format( + "KafkaWriter : failed to send %d records (since last report)", numSendFailures); + + Exception e = sendException; + sendException = null; + numSendFailures = 0; + + LOG.warn(msg); + throw new IOException(msg, e); + } + + private class SendCallback implements Callback { + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + if (exception == null) { + return; + } + + synchronized (KafkaWriter.this) { + if (sendException == null) { + sendException = exception; + } + numSendFailures++; + } + // don't log exception stacktrace here, exception will be propagated up. 
+ LOG.warn("send failed : '{}'", exception.getMessage()); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ProducerSpEL.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ProducerSpEL.java index 08674e0f904e..f7ad7aacb550 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ProducerSpEL.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ProducerSpEL.java @@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Map; - import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.common.TopicPartition; @@ -46,8 +45,8 @@ class ProducerSpEL { static final String ENABLE_IDEMPOTENCE_CONFIG = "enable.idempotence"; static final String TRANSACTIONAL_ID_CONFIG = "transactional.id"; - private static Class producerFencedExceptionClass; - private static Class outOfOrderSequenceExceptionClass; + private static Class producerFencedExceptionClass; + private static Class outOfOrderSequenceExceptionClass; static { try { @@ -90,13 +89,13 @@ private static void ensureTransactionsSupport() { "Please used version 0.11 or later."); } - private static Object invoke(Method method, Object obj, Object... args) { + private static void invoke(Method method, Object obj, Object... args) { try { - return method.invoke(obj, args); + method.invoke(obj, args); } catch (IllegalAccessException | InvocationTargetException e) { - return new RuntimeException(e); + throw new RuntimeException(e); } catch (ApiException e) { - Class eClass = e.getClass(); + Class eClass = e.getClass(); if (producerFencedExceptionClass.isAssignableFrom(eClass) || outOfOrderSequenceExceptionClass.isAssignableFrom(eClass) || AuthorizationException.class.isAssignableFrom(eClass)) { diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index 08338d842bab..ebdd1da7028c 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -31,11 +31,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import java.io.IOException; -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.nio.ByteBuffer; +import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -49,6 +49,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; @@ -90,6 +91,7 @@ import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; import org.apache.kafka.clients.consumer.OffsetResetStrategy; import org.apache.kafka.clients.producer.MockProducer; import org.apache.kafka.clients.producer.Producer; @@ -119,8 +121,7 @@ /** * Tests of {@link KafkaIO}. 
- * Run with 'mvn test -Dkafka.clients.version=0.10.1.1', - * or 'mvn test -Dkafka.clients.version=0.9.0.1' for either Kafka client version. + * Run with 'mvn test -Dkafka.clients.version=0.10.1.1', to test with a specific Kafka version. */ @RunWith(JUnit4.class) public class KafkaIOTest { @@ -184,12 +185,8 @@ private static MockConsumer mkMockConsumer( final MockConsumer consumer = new MockConsumer(offsetResetStrategy) { - // override assign() in order to set offset limits & to save assigned partitions. - //remove keyword '@Override' here, it can work with Kafka client 0.9 and 0.10 as: - //1. SpEL can find this function, either input is List or Collection; - //2. List extends Collection, so super.assign() could find either assign(List) - // or assign(Collection). - public void assign(final List assigned) { + @Override + public void assign(final Collection assigned) { super.assign(assigned); assignedPartitions.set(ImmutableList.copyOf(assigned)); for (TopicPartition tp : assigned) { @@ -198,34 +195,21 @@ public void assign(final List assigned) { } } // Override offsetsForTimes() in order to look up the offsets by timestamp. - // Remove keyword '@Override' here, Kafka client 0.10.1.0 previous versions does not have - // this method. - // Should return Map, but 0.10.1.0 previous versions - // does not have the OffsetAndTimestamp class. So return a raw type and use reflection - // here. - @SuppressWarnings("unchecked") - public Map offsetsForTimes(Map timestampsToSearch) { - HashMap result = new HashMap<>(); - try { - Class cls = Class.forName("org.apache.kafka.clients.consumer.OffsetAndTimestamp"); - // OffsetAndTimestamp(long offset, long timestamp) - Constructor constructor = cls.getDeclaredConstructor(long.class, long.class); - - // In test scope, timestamp == offset. - for (Map.Entry entry : timestampsToSearch.entrySet()) { - long maxOffset = offsets[partitions.indexOf(entry.getKey())]; - Long offset = entry.getValue(); - if (offset >= maxOffset) { - offset = null; - } - result.put( - entry.getKey(), constructor.newInstance(entry.getValue(), offset)); - } - return result; - } catch (ClassNotFoundException | IllegalAccessException - | InstantiationException | NoSuchMethodException | InvocationTargetException e) { - throw new RuntimeException(e); - } + @Override + public Map offsetsForTimes( + Map timestampsToSearch) { + return timestampsToSearch + .entrySet() + .stream() + .map(e -> { + // In test scope, timestamp == offset. + long maxOffset = offsets[partitions.indexOf(e.getKey())]; + long offset = e.getValue(); + OffsetAndTimestamp value = (offset >= maxOffset) + ? null : new OffsetAndTimestamp(offset, offset); + return new SimpleEntry<>(e.getKey(), value); + }) + .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); } }; @@ -722,8 +706,8 @@ public void testUnboundedSourceMetrics() { MetricsFilter.builder() .addNameFilter( MetricNameFilter.named( - KafkaIO.UnboundedKafkaReader.METRIC_NAMESPACE, - KafkaIO.UnboundedKafkaReader.CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC)) + KafkaUnboundedReader.METRIC_NAMESPACE, + KafkaUnboundedReader.CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC)) .build()); assertThat(commitsEnqueuedMetrics.counters(), IsIterableWithSize.iterableWithSize(1)); @@ -794,16 +778,17 @@ public void testValuesSink() throws Exception { } @Test - public void testEOSink() { + public void testExactlyOnceSink() { // testSink() with EOS enabled. // This does not actually inject retries in a stage to test exactly-once-semantics. 
// It mainly exercises the code in normal flow without retries. // Ideally we should test EOS Sink by triggering replays of a messages between stages. // It is not feasible to test such retries with direct runner. When DoFnTester supports - // state, we can test KafkaEOWriter DoFn directly to ensure it handles retries correctly. + // state, we can test ExactlyOnceWriter DoFn directly to ensure it handles retries correctly. if (!ProducerSpEL.supportsTransactions()) { - LOG.warn("testEOSink() is disabled as Kafka client version does not support transactions."); + LOG.warn( + "testExactlyOnceSink() is disabled as Kafka client version does not support transactions."); return; } diff --git a/sdks/java/io/kinesis/build.gradle b/sdks/java/io/kinesis/build.gradle index 73824f612d20..f4123a831ce4 100644 --- a/sdks/java/io/kinesis/build.gradle +++ b/sdks/java/io/kinesis/build.gradle @@ -39,6 +39,7 @@ dependencies { shadow "com.amazonaws:aws-java-sdk-kinesis:$aws_version" shadow "com.amazonaws:aws-java-sdk-cloudwatch:$aws_version" shadow "com.amazonaws:amazon-kinesis-client:1.8.8" + shadow "com.amazonaws:amazon-kinesis-producer:0.12.8" shadow "commons-lang:commons-lang:2.6" testCompile project(path: ":runners:direct-java", configuration: "shadow") testCompile library.java.junit diff --git a/sdks/java/io/kinesis/pom.xml b/sdks/java/io/kinesis/pom.xml index ef47a7280d74..acc462a90dfa 100644 --- a/sdks/java/io/kinesis/pom.xml +++ b/sdks/java/io/kinesis/pom.xml @@ -31,14 +31,16 @@ + org.apache.maven.plugins - maven-surefire-plugin - - - 1 - false - + maven-shade-plugin + + + bundle-and-repackage + none + + @@ -46,6 +48,7 @@ 1.11.255 1.8.8 + 0.12.8 @@ -66,6 +69,12 @@ ${aws.version} + + com.amazonaws + aws-java-sdk-core + ${aws.version} + + com.amazonaws amazon-kinesis-client @@ -80,8 +89,9 @@ - org.slf4j - slf4j-api + com.amazonaws + amazon-kinesis-producer + ${amazon-kinesis-producer.version} @@ -100,12 +110,6 @@ 2.6 - - com.amazonaws - aws-java-sdk-core - ${aws.version} - - com.google.code.findbugs jsr305 @@ -117,6 +121,11 @@ provided + + org.slf4j + slf4j-api + + junit junit @@ -125,7 +134,7 @@ org.mockito - mockito-all + mockito-core test @@ -144,10 +153,14 @@ org.hamcrest - hamcrest-all + hamcrest-core + test + + + org.hamcrest + hamcrest-library test - org.apache.beam beam-runners-direct-java diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java index c82e4b14d7d1..fa3351ccf778 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java @@ -19,18 +19,20 @@ import com.amazonaws.services.cloudwatch.AmazonCloudWatch; import com.amazonaws.services.kinesis.AmazonKinesis; - +import com.amazonaws.services.kinesis.producer.IKinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; import java.io.Serializable; /** * Provides instances of AWS clients. * - *

    Please note, that any instance of {@link AWSClientsProvider} must be - * {@link Serializable} to ensure it can be sent to worker machines. + *

    Please note, that any instance of {@link AWSClientsProvider} must be {@link Serializable} to + * ensure it can be sent to worker machines. */ public interface AWSClientsProvider extends Serializable { - AmazonKinesis getKinesisClient(); AmazonCloudWatch getCloudWatchClient(); + + IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config); } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/BasicKinesisProvider.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/BasicKinesisProvider.java new file mode 100644 index 000000000000..247e9f10f7a1 --- /dev/null +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/BasicKinesisProvider.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.client.builder.AwsClientBuilder; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.cloudwatch.AmazonCloudWatch; +import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; +import com.amazonaws.services.kinesis.AmazonKinesis; +import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; +import com.amazonaws.services.kinesis.producer.IKinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; +import javax.annotation.Nullable; + +/** Basic implementation of {@link AWSClientsProvider} used by default in {@link KinesisIO}. 
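+ * It builds the {@link AmazonKinesis} and {@link AmazonCloudWatch} clients, and the
+ * {@link IKinesisProducer}, from static credentials and a region, optionally pointing the clients
+ * at an alternative {@code serviceEndpoint}.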
*/ +class BasicKinesisProvider implements AWSClientsProvider { + private final String accessKey; + private final String secretKey; + private final Regions region; + @Nullable private final String serviceEndpoint; + + BasicKinesisProvider( + String accessKey, String secretKey, Regions region, @Nullable String serviceEndpoint) { + checkArgument(accessKey != null, "accessKey can not be null"); + checkArgument(secretKey != null, "secretKey can not be null"); + checkArgument(region != null, "region can not be null"); + this.accessKey = accessKey; + this.secretKey = secretKey; + this.region = region; + this.serviceEndpoint = serviceEndpoint; + } + + private AWSCredentialsProvider getCredentialsProvider() { + return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)); + } + + @Override + public AmazonKinesis getKinesisClient() { + AmazonKinesisClientBuilder clientBuilder = + AmazonKinesisClientBuilder.standard().withCredentials(getCredentialsProvider()); + if (serviceEndpoint == null) { + clientBuilder.withRegion(region); + } else { + clientBuilder.withEndpointConfiguration( + new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName())); + } + return clientBuilder.build(); + } + + @Override + public AmazonCloudWatch getCloudWatchClient() { + AmazonCloudWatchClientBuilder clientBuilder = + AmazonCloudWatchClientBuilder.standard().withCredentials(getCredentialsProvider()); + if (serviceEndpoint == null) { + clientBuilder.withRegion(region); + } else { + clientBuilder.withEndpointConfiguration( + new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName())); + } + return clientBuilder.build(); + } + + @Override + public IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config) { + config.setRegion(region.getName()); + config.setCredentialsProvider(getCredentialsProvider()); + return new KinesisProducer(config); + } +} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java index bf229469f491..56af385268c6 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java @@ -20,10 +20,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.amazonaws.services.kinesis.model.Shard; - import java.util.Set; import java.util.stream.Collectors; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java index 60e695ac4ea5..0fe51406502b 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java @@ -19,33 +19,46 @@ import static com.google.common.base.Preconditions.checkArgument; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder; import com.amazonaws.regions.Regions; import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; import com.amazonaws.services.kinesis.AmazonKinesis; -import 
com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.amazonaws.services.kinesis.model.DescribeStreamResult; +import com.amazonaws.services.kinesis.producer.Attempt; +import com.amazonaws.services.kinesis.producer.IKinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; +import com.amazonaws.services.kinesis.producer.UserRecordFailedException; +import com.amazonaws.services.kinesis.producer.UserRecordResult; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.LinkedBlockingDeque; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.io.Read.Unbounded; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * {@link PTransform}s for reading from + * {@link PTransform}s for reading from and writing to * Kinesis streams. * + *

    Reading from Kinesis

    + * *

    Example usage: * *

    {@code
    @@ -111,11 +124,64 @@
      *  .apply( ... ) // other transformations
      * }
    * + *

    Writing to Kinesis

    + * + *

    Example usage: + * + *

    {@code
    + * PCollection data = ...;
    + *
    + * data.apply(KinesisIO.write()
    + *     .withStreamName("streamName")
    + *     .withPartitionKey("partitionKey")
    + *     .withAWSClientsProvider(AWS_KEY, AWS_SECRET, STREAM_REGION));
    + * }
    + * + *

As a client, you need to provide at least 3 things: + *
    • name of the stream where you're going to write
    • partition key (or implementation of {@link KinesisPartitioner}) that defines which + * partition will be used for writing
    • data used to initialize {@link AmazonKinesis} and {@link AmazonCloudWatch} clients: + *
        • credentials (aws key, aws secret)
        • region where the stream is located
    + *

In case you need more complicated logic for key partitioning, you can + * create your own implementation of {@link KinesisPartitioner} and set it via + * {@link KinesisIO.Write#withPartitioner(KinesisPartitioner)}
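+ *
+ * For illustration, assuming the partitioner only needs the {@code getPartitionKey(byte[])} and
+ * {@code getExplicitHashKey(byte[])} methods invoked by the writer, a minimal sketch (the
+ * {@code byFirstByte} name is purely illustrative) could look like:
+ *
+ * {@code
+ * KinesisPartitioner byFirstByte = new KinesisPartitioner() {
+ *   public String getPartitionKey(byte[] value) {
+ *     return String.valueOf(value.length > 0 ? value[0] : 0);
+ *   }
+ *   public String getExplicitHashKey(byte[] value) {
+ *     return null;  // let Kinesis derive the hash from the partition key
+ *   }
+ * };
+ * }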

    + * + *

Internally, {@link KinesisIO.Write} relies on the Amazon Kinesis Producer Library (KPL). + * This library can be configured with a set of {@link Properties} if needed. + * + *

    Example usage of KPL configuration: + * + *

    {@code
    + * Properties properties = new Properties();
    + * properties.setProperty("KinesisEndpoint", "localhost");
    + * properties.setProperty("KinesisPort", "4567");
    + *
    + * PCollection data = ...;
    + *
    + * data.apply(KinesisIO.write()
    + *     .withStreamName("streamName")
    + *     .withPartitionKey("partitionKey")
    + *     .withAWSClientsProvider(AWS_KEY, AWS_SECRET, STREAM_REGION)
    + *     .withProducerProperties(properties));
    + * }
    + * + *

    For more information about configuratiom parameters, see the + * sample of configuration file. */ @Experimental(Experimental.Kind.SOURCE_SINK) public final class KinesisIO { + private static final Logger LOG = LoggerFactory.getLogger(KinesisIO.class); + private static final int DEFAULT_NUM_RETRIES = 6; + /** Returns a new {@link Read} transform for reading from Kinesis. */ public static Read read() { return new AutoValue_KinesisIO_Read.Builder() @@ -124,6 +190,11 @@ public static Read read() { .build(); } + /** A {@link PTransform} writing data to Kinesis. */ + public static Write write() { + return new AutoValue_KinesisIO_Write.Builder().setRetries(DEFAULT_NUM_RETRIES).build(); + } + /** Implementation of {@link #read}. */ @AutoValue public abstract static class Read extends PTransform> { @@ -250,11 +321,6 @@ public Read withUpToDateThreshold(Duration upToDateThreshold) { @Override public PCollection expand(PBegin input) { - checkArgument( - streamExists(getAWSClientsProvider().getKinesisClient(), getStreamName()), - "Stream %s does not exist", - getStreamName()); - Unbounded unbounded = org.apache.beam.sdk.io.Read.from( new KinesisSource( @@ -272,52 +338,321 @@ public PCollection expand(PBegin input) { return input.apply(transform); } + } + + /** Implementation of {@link #write}. */ + @AutoValue + public abstract static class Write extends PTransform, PDone> { + @Nullable + abstract String getStreamName(); + @Nullable + abstract String getPartitionKey(); + @Nullable + abstract KinesisPartitioner getPartitioner(); + @Nullable + abstract Properties getProducerProperties(); + @Nullable + abstract AWSClientsProvider getAWSClientsProvider(); + abstract int getRetries(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setStreamName(String streamName); + abstract Builder setPartitionKey(String partitionKey); + abstract Builder setPartitioner(KinesisPartitioner partitioner); + abstract Builder setProducerProperties(Properties properties); + abstract Builder setAWSClientsProvider(AWSClientsProvider clientProvider); + abstract Builder setRetries(int retries); + abstract Write build(); + } + + /** Specify Kinesis stream name which will be used for writing, this name is required. */ + public Write withStreamName(String streamName) { + return builder().setStreamName(streamName).build(); + } + + /** + * Specify default partition key. + * + *

In case you need more complicated logic for key partitioning, you can + * create your own implementation of {@link KinesisPartitioner} and specify it via + * {@link KinesisIO.Write#withPartitioner(KinesisPartitioner)} + * + *

Using one of the methods {@link KinesisIO.Write#withPartitioner(KinesisPartitioner)} or + * {@link KinesisIO.Write#withPartitionKey(String)} is required, but not both at the same time. + */ + public Write withPartitionKey(String partitionKey) { + return builder().setPartitionKey(partitionKey).build(); + } + + /** + * Allows specifying a custom implementation of {@link KinesisPartitioner}. + * + *

This method should be used to balance the distribution of newly written records among all + * stream shards. + * + *

Using one of the methods {@link KinesisIO.Write#withPartitioner(KinesisPartitioner)} or + * {@link KinesisIO.Write#withPartitionKey(String)} is required, but not both at the same time. + */ + public Write withPartitioner(KinesisPartitioner partitioner) { + return builder().setPartitioner(partitioner).build(); + } + + /** + * Specify the configuration properties for Kinesis Producer Library (KPL). + * + *

Example of creating a new KPL configuration: + * + * {@code + * Properties properties = new Properties(); + * properties.setProperty("CollectionMaxCount", "1000"); + * properties.setProperty("ConnectTimeout", "10000");} + */ + public Write withProducerProperties(Properties properties) { + return builder().setProducerProperties(properties).build(); + } + + /** + * Allows specifying a custom {@link AWSClientsProvider}. The {@link AWSClientsProvider} creates a new + * {@link IKinesisProducer} which is later used for writing to Kinesis. + * + *

This method should be used if + * {@link Write#withAWSClientsProvider(String, String, Regions)} does not suit your needs well. + */ + public Write withAWSClientsProvider(AWSClientsProvider awsClientsProvider) { + return builder().setAWSClientsProvider(awsClientsProvider).build(); + } + + /** + * Specify credential details and region to be used to write to Kinesis. If you need a more + * sophisticated credential protocol, then you should look at {@link + * Write#withAWSClientsProvider(AWSClientsProvider)}. + */ + public Write withAWSClientsProvider(String awsAccessKey, String awsSecretKey, Regions region) { + return withAWSClientsProvider(awsAccessKey, awsSecretKey, region, null); + } + + /** + * Specify credential details and region to be used to write to Kinesis. If you need a more + * sophisticated credential protocol, then you should look at {@link + * Write#withAWSClientsProvider(AWSClientsProvider)}. + * + *

The {@code serviceEndpoint} sets an alternative service host. This is useful when executing + * the tests against a Kinesis service emulator. + */ + public Write withAWSClientsProvider( + String awsAccessKey, String awsSecretKey, Regions region, String serviceEndpoint) { + return withAWSClientsProvider( + new BasicKinesisProvider(awsAccessKey, awsSecretKey, region, serviceEndpoint)); + } + + /** + * Specify the number of retries that will be used to flush the outstanding records in + * case they were not flushed on the first attempt. The default number of retries is + * {@code DEFAULT_NUM_RETRIES = 6}. + * + *
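+ * With the default of 6 retries and the doubling backoff used by the writer's {@code flush()},
+ * the writer sleeps at most 1 + 2 + 4 + 8 + 16 + 32 = 63 seconds in total before giving up and
+ * throwing an {@code IOException}.
+ *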

    This is used for testing. + */ + @VisibleForTesting + Write withRetries(int retries) { + return builder().setRetries(retries).build(); + } + + @Override + public PDone expand(PCollection input) { + checkArgument(getStreamName() != null, "withStreamName() is required"); + checkArgument( + (getPartitionKey() != null) || (getPartitioner() != null), + "withPartitionKey() or withPartitioner() is required"); + checkArgument( + getPartitionKey() == null || (getPartitioner() == null), + "only one of either withPartitionKey() or withPartitioner() is possible"); + checkArgument(getAWSClientsProvider() != null, "withAWSClientsProvider() is required"); + input.apply(ParDo.of(new KinesisWriterFn(this))); + return PDone.in(input.getPipeline()); + } + + private static class KinesisWriterFn extends DoFn { - private static final class BasicKinesisProvider implements AWSClientsProvider { - private final String accessKey; - private final String secretKey; - private final Regions region; - @Nullable private final String serviceEndpoint; - - private BasicKinesisProvider( - String accessKey, String secretKey, Regions region, @Nullable String serviceEndpoint) { - checkArgument(accessKey != null, "accessKey can not be null"); - checkArgument(secretKey != null, "secretKey can not be null"); - checkArgument(region != null, "region can not be null"); - this.accessKey = accessKey; - this.secretKey = secretKey; - this.region = region; - this.serviceEndpoint = serviceEndpoint; + private static final int MAX_NUM_RECORDS = 100 * 1000; + private static final int MAX_NUM_FAILURES = 10; + + private final KinesisIO.Write spec; + private transient IKinesisProducer producer; + private transient KinesisPartitioner partitioner; + private transient LinkedBlockingDeque failures; + + public KinesisWriterFn(KinesisIO.Write spec) { + this.spec = spec; } - private AWSCredentialsProvider getCredentialsProvider() { - return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)); + @Setup + public void setup() throws Exception { + checkArgument( + streamExists(spec.getAWSClientsProvider().getKinesisClient(), spec.getStreamName()), + "Stream %s does not exist", spec.getStreamName()); + + // Init producer config + Properties props = spec.getProducerProperties(); + if (props == null) { + props = new Properties(); + } + KinesisProducerConfiguration config = KinesisProducerConfiguration.fromProperties(props); + // Fix to avoid the following message "WARNING: Exception during updateCredentials" during + // producer.destroy() call. 
More details can be found in this thread:
+      // https://github.com/awslabs/amazon-kinesis-producer/issues/10
+      config.setCredentialsRefreshDelay(100);
+
+      // Init Kinesis producer
+      producer = spec.getAWSClientsProvider().createKinesisProducer(config);
+      // Use custom partitioner if it exists
+      if (spec.getPartitioner() != null) {
+        partitioner = spec.getPartitioner();
+      }
+
+      /** Keep only the first {@link MAX_NUM_FAILURES} exceptions that occurred. */
+      failures = new LinkedBlockingDeque<>(MAX_NUM_FAILURES);
+    }
-    @Override
-    public AmazonKinesis getKinesisClient() {
-      AmazonKinesisClientBuilder clientBuilder =
-          AmazonKinesisClientBuilder.standard().withCredentials(getCredentialsProvider());
-      if (serviceEndpoint == null) {
-        clientBuilder.withRegion(region);
-      } else {
-        clientBuilder.withEndpointConfiguration(
-            new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName()));
+    /**
+     * Adds a record asynchronously; it is then delivered by the Kinesis producer in the
+     * background (the Kinesis producer forks native processes to do this job).
+     *
+     * <p>The records can be batched and then they will be sent in one HTTP request. Amazon KPL
+     * supports two types of batching - aggregation and collection - and they can be configured by
+     * producer properties.
+     *
+     * <p>
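For example, batching could be tuned through the properties later passed to
+     * {@code withProducerProperties(...)}; a minimal sketch, assuming the standard KPL property
+     * names {@code AggregationEnabled} and {@code CollectionMaxCount}:
+     * <pre>{@code
+     * Properties properties = new Properties();
+     * properties.setProperty("AggregationEnabled", "true");  // enable KPL record aggregation
+     * properties.setProperty("CollectionMaxCount", "500");   // max records per PutRecords call
+     * KinesisIO.write().withProducerProperties(properties);
+     * }</pre>
+     *
+     * <p>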
    More details can be found here: + * KPL Key Concepts and + * Configuring the KPL + */ + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + checkForFailures(); + + // Need to avoid keeping too many futures in producer's map to prevent OOM. + // In usual case, it should exit immediately. + flush(MAX_NUM_RECORDS); + + ByteBuffer data = ByteBuffer.wrap(c.element()); + String partitionKey = spec.getPartitionKey(); + String explicitHashKey = null; + + // Use custom partitioner + if (partitioner != null) { + partitionKey = partitioner.getPartitionKey(c.element()); + explicitHashKey = partitioner.getExplicitHashKey(c.element()); } - return clientBuilder.build(); + + ListenableFuture f = + producer.addUserRecord(spec.getStreamName(), partitionKey, explicitHashKey, data); + Futures.addCallback(f, new UserRecordResultFutureCallback()); } - @Override - public AmazonCloudWatch getCloudWatchClient() { - AmazonCloudWatchClientBuilder clientBuilder = - AmazonCloudWatchClientBuilder.standard().withCredentials(getCredentialsProvider()); - if (serviceEndpoint == null) { - clientBuilder.withRegion(region); - } else { - clientBuilder.withEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName())); + @FinishBundle + public void finishBundle() throws Exception { + // Flush all outstanding records, blocking call + flushAll(); + + checkForFailures(); + } + + @Teardown + public void tearDown() throws Exception { + if (producer != null) { + producer.destroy(); + producer = null; + } + } + + /** + * Flush outstanding records until the total number will be less than required or + * the number of retries will be exhausted. The retry timeout starts from 1 second and it + * doubles on every iteration. + */ + private void flush(int numMax) throws InterruptedException, IOException { + int retries = spec.getRetries(); + int numOutstandingRecords = producer.getOutstandingRecordsCount(); + int retryTimeout = 1000; // initial timeout, 1 sec + + while (numOutstandingRecords > numMax && retries-- > 0) { + producer.flush(); + // wait until outstanding records will be flushed + Thread.sleep(retryTimeout); + numOutstandingRecords = producer.getOutstandingRecordsCount(); + retryTimeout *= 2; // exponential backoff + } + + if (numOutstandingRecords > numMax) { + String message = String.format( + "After [%d] retries, number of outstanding records [%d] is still greater than " + + "required [%d].", + spec.getRetries(), numOutstandingRecords, numMax); + LOG.error(message); + throw new IOException(message); + } + } + + private void flushAll() throws InterruptedException, IOException { + flush(0); + } + + /** + * If any write has asynchronously failed, fail the bundle with a useful error. + */ + private void checkForFailures() throws IOException { + // Note that this function is never called by multiple threads and is the only place that + // we remove from failures, so this code is safe. 
+ if (failures.isEmpty()) { + return; + } + + StringBuilder logEntry = new StringBuilder(); + int i = 0; + while (!failures.isEmpty()) { + i++; + KinesisWriteException exc = failures.remove(); + + logEntry.append("\n").append(exc.getMessage()); + Throwable cause = exc.getCause(); + if (cause != null) { + logEntry.append(": ").append(cause.getMessage()); + + if (cause instanceof UserRecordFailedException) { + List attempts = ((UserRecordFailedException) cause).getResult() + .getAttempts(); + for (Attempt attempt : attempts) { + if (attempt.getErrorMessage() != null) { + logEntry.append("\n").append(attempt.getErrorMessage()); + } + } + } + } + } + failures.clear(); + + String message = + String.format( + "Some errors occurred writing to Kinesis. First %d errors: %s", + i, + logEntry.toString()); + throw new IOException(message); + } + + private class UserRecordResultFutureCallback implements FutureCallback { + + @Override public void onFailure(Throwable cause) { + failures.offer(new KinesisWriteException(cause)); + } + + @Override public void onSuccess(UserRecordResult result) { + if (!result.isSuccessful()) { + failures.offer(new KinesisWriteException("Put record was not successful.", + new UserRecordFailedException(result))); + } } - return clientBuilder.build(); } } } @@ -332,4 +667,17 @@ private static boolean streamExists(AmazonKinesis client, String streamName) { } return false; } + + /** + * An exception that puts information about the failed record. + */ + static class KinesisWriteException extends IOException { + KinesisWriteException(String message, Throwable cause) { + super(message, cause); + } + + KinesisWriteException(Throwable cause) { + super(cause); + } + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisPartitioner.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisPartitioner.java new file mode 100644 index 000000000000..9bd46eaef682 --- /dev/null +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisPartitioner.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import java.io.Serializable; + +/** Kinesis interface for custom partitioner. 
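+ *
+ * <p>A minimal illustrative implementation (a sketch only, the mapping below is hypothetical;
+ * any deterministic mapping from the record payload to a partition key works):
+ * <pre>{@code
+ * KinesisPartitioner lengthPartitioner = new KinesisPartitioner() {
+ *   @Override
+ *   public String getPartitionKey(byte[] value) {
+ *     // derive a partition key from the payload (records with the same key land on the same shard)
+ *     return String.valueOf(value.length % 10);
+ *   }
+ *
+ *   @Override
+ *   public String getExplicitHashKey(byte[] value) {
+ *     // returning null lets Kinesis derive the hash key from the partition key
+ *     return null;
+ *   }
+ * };
+ * }</pre>
+ *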
*/ +public interface KinesisPartitioner extends Serializable { + String getPartitionKey(byte[] value); + + String getExplicitHashKey(byte[] value); +} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java index fb19a1ba59d9..be2346101c1b 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.NoSuchElementException; - import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.transforms.Min; import org.apache.beam.sdk.util.MovingFunction; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java index eca879101d8f..c096cee8026b 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java @@ -21,12 +21,10 @@ import static com.google.common.collect.Lists.partition; import com.google.common.collect.ImmutableList; - import java.io.IOException; import java.io.Serializable; import java.util.Iterator; import java.util.List; - import org.apache.beam.sdk.io.UnboundedSource; /** diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java index e980c7597ee9..06759a216bc1 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java @@ -21,10 +21,8 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; - import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; - import org.apache.commons.lang.builder.EqualsBuilder; import org.joda.time.Instant; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java index dcf564d3ec73..53b3f893cf57 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java @@ -21,7 +21,6 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; - import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java index b1a6c193af6f..88cfd472fc54 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java @@ -21,7 +21,6 @@ import static com.google.common.collect.Lists.newArrayList; import java.util.List; - import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; import 
org.apache.beam.sdk.io.UnboundedSource; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java index 94e3b96cf0f1..5cba7320491a 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java @@ -26,9 +26,7 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.model.Record; import com.amazonaws.services.kinesis.model.ShardIteratorType; - import java.io.Serializable; - import org.joda.time.Instant; /** diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java index 9fe990bf9fb6..67611a935a0f 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java @@ -21,7 +21,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; - import java.util.List; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -34,7 +33,6 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -108,7 +106,7 @@ void start() throws TransientKinesisException { } } - private void startReadingShards(Iterable shardRecordsIterators) { + void startReadingShards(Iterable shardRecordsIterators) { for (final ShardRecordsIterator recordsIterator : shardRecordsIterators) { numberOfRecordsInAQueueByShard.put(recordsIterator.getShardId(), new AtomicInteger()); executorService.submit(() -> readLoop(recordsIterator)); diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java index b59882377451..c70cde8578c2 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java @@ -22,12 +22,10 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import com.amazonaws.services.kinesis.model.Shard; - import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java index f9298fa54deb..66064bca3a23 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java @@ -21,10 +21,8 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.amazonaws.services.kinesis.model.ShardIteratorType; - import java.io.Serializable; import java.util.Objects; - import 
org.joda.time.Instant; /** diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinder.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinder.java index 2ddbe1192604..cd5905c46563 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinder.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinder.java @@ -20,13 +20,11 @@ import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.google.common.collect.Sets; - import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java index 39fe7b28b3ac..f4906bdad86e 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java @@ -77,6 +77,8 @@ import com.amazonaws.services.kinesis.model.StreamDescription; import com.amazonaws.services.kinesis.model.UpdateShardCountRequest; import com.amazonaws.services.kinesis.model.UpdateShardCountResult; +import com.amazonaws.services.kinesis.producer.IKinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; import com.amazonaws.services.kinesis.waiters.AmazonKinesisWaiters; import java.io.Serializable; import java.nio.ByteBuffer; @@ -127,6 +129,11 @@ public boolean equals(Object obj) { public int hashCode() { return reflectionHashCode(this); } + + @Override public String toString() { + return "TestData{" + "data='" + data + '\'' + ", arrivalTimestamp=" + arrivalTimestamp + + ", sequenceNumber='" + sequenceNumber + '\'' + '}'; + } } static class Provider implements AWSClientsProvider { @@ -153,6 +160,10 @@ public AmazonKinesis getKinesisClient() { public AmazonCloudWatch getCloudWatchClient() { return Mockito.mock(AmazonCloudWatch.class); } + + @Override public IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config) { + throw new RuntimeException("Not implemented"); + } } private final List> shardedData; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java index 0b16bb77ba0d..1a3955bfb3aa 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java @@ -18,9 +18,7 @@ package org.apache.beam.sdk.io.kinesis; import com.google.common.testing.EqualsTester; - import java.util.NoSuchElementException; - import org.junit.Test; /** diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java index c31cd6811f1c..f5731666489f 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java @@ -23,9 +23,7 
@@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.amazonaws.services.kinesis.model.Shard; import com.google.common.collect.Sets; - import java.util.Set; - import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOIT.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOIT.java new file mode 100644 index 000000000000..6b8cbdfbeb2d --- /dev/null +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOIT.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import static com.google.common.collect.Lists.newArrayList; + +import com.amazonaws.regions.Regions; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import java.io.Serializable; +import java.util.List; +import java.util.Random; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + +/** + * Integration test, that writes and reads data to and from real Kinesis. You need to provide all + * {@link KinesisTestOptions} in order to run this. 
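+ *
+ * <p>The options read by this test (see the accessors used below; the full set is defined by
+ * {@link KinesisTestOptions}) include {@code --awsKinesisStream}, {@code --awsKinesisRegion},
+ * {@code --awsAccessKey} and {@code --awsSecretKey}.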
+ */ +public class KinesisIOIT implements Serializable { + public static final int NUM_RECORDS = 1000; + public static final int NUM_SHARDS = 2; + + @Rule public final transient TestPipeline p = TestPipeline.create(); + @Rule public final transient TestPipeline p2 = TestPipeline.create(); + + private static KinesisTestOptions options; + + @BeforeClass + public static void setup() { + PipelineOptionsFactory.register(KinesisTestOptions.class); + options = TestPipeline.testingPipelineOptions().as(KinesisTestOptions.class); + } + + @Test + public void testWriteThenRead() throws Exception { + Instant now = Instant.now(); + List inputData = prepareData(); + + // Write data into stream + p.apply(Create.of(inputData)) + .apply( + KinesisIO.write() + .withStreamName(options.getAwsKinesisStream()) + .withPartitioner(new RandomPartitioner()) + .withAWSClientsProvider( + options.getAwsAccessKey(), + options.getAwsSecretKey(), + Regions.fromName(options.getAwsKinesisRegion()))); + p.run().waitUntilFinish(); + + // Read new data from stream that was just written before + PCollection output = + p2.apply( + KinesisIO.read() + .withStreamName(options.getAwsKinesisStream()) + .withAWSClientsProvider( + options.getAwsAccessKey(), + options.getAwsSecretKey(), + Regions.fromName(options.getAwsKinesisRegion())) + .withMaxNumRecords(inputData.size()) + // to prevent endless running in case of error + .withMaxReadTime(Duration.standardMinutes(5)) + .withInitialPositionInStream(InitialPositionInStream.AT_TIMESTAMP) + .withInitialTimestampInStream(now)) + .apply( + ParDo.of( + new DoFn() { + + @ProcessElement + public void processElement(ProcessContext c) { + KinesisRecord record = c.element(); + byte[] data = record.getData().array(); + c.output(data); + } + })); + PAssert.that(output).containsInAnyOrder(inputData); + p2.run().waitUntilFinish(); + } + + private List prepareData() { + List data = newArrayList(); + for (int i = 0; i < NUM_RECORDS; i++) { + data.add(String.valueOf(i).getBytes()); + } + return data; + } + + private static final class RandomPartitioner implements KinesisPartitioner { + @Override + public String getPartitionKey(byte[] value) { + Random rand = new Random(); + int n = rand.nextInt(NUM_SHARDS) + 1; + return String.valueOf(n); + } + + @Override + public String getExplicitHashKey(byte[] value) { + return null; + } + } +} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java index 73554bb27eb1..42c4df4cb9fa 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java @@ -21,9 +21,7 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.google.common.collect.Iterables; - import java.util.List; - import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.DoFn; @@ -60,7 +58,7 @@ public void readsDataFromMockKinesis() { p.run(); } - private static class KinesisRecordToTestData extends + static class KinesisRecordToTestData extends DoFn { @ProcessElement diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockWriteTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockWriteTest.java new file mode 100644 index 000000000000..4227166d5eda 
--- /dev/null +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockWriteTest.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.amazonaws.http.SdkHttpMetadata; +import com.amazonaws.services.cloudwatch.AmazonCloudWatch; +import com.amazonaws.services.kinesis.AmazonKinesis; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import com.amazonaws.services.kinesis.model.DescribeStreamResult; +import com.amazonaws.services.kinesis.producer.IKinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; +import com.google.common.collect.Iterables; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KinesisIO.Write}. 
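+ *
+ * <p>The write path is exercised against {@code KinesisProducerMock} and
+ * {@code KinesisServiceMock} (wired in through {@code FakeKinesisProvider}) rather than a real
+ * Kinesis endpoint, so the tests need no AWS credentials.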
*/ +@RunWith(JUnit4.class) +public class KinesisMockWriteTest { + private static final String STREAM = "BEAM"; + private static final String PARTITION_KEY = "partitionKey"; + + @Rule public final transient TestPipeline p = TestPipeline.create(); + @Rule public final transient TestPipeline p2 = TestPipeline.create(); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Before + public void beforeTest() { + KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); + kinesisService.init(STREAM, 1); + } + + @Test + public void testWriteBuildsCorrectly() { + Properties properties = new Properties(); + properties.setProperty("KinesisEndpoint", "localhost"); + properties.setProperty("KinesisPort", "4567"); + + KinesisIO.Write write = + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withPartitioner(new BasicKinesisPartitioner()) + .withAWSClientsProvider(new FakeKinesisProvider()) + .withProducerProperties(properties) + .withRetries(10); + + assertEquals(STREAM, write.getStreamName()); + assertEquals(PARTITION_KEY, write.getPartitionKey()); + assertEquals(properties, write.getProducerProperties()); + assertEquals(FakeKinesisProvider.class, write.getAWSClientsProvider().getClass()); + assertEquals(BasicKinesisPartitioner.class, write.getPartitioner().getClass()); + assertEquals(10, write.getRetries()); + + assertEquals("localhost", write.getProducerProperties().getProperty("KinesisEndpoint")); + assertEquals("4567", write.getProducerProperties().getProperty("KinesisPort")); + } + + @Test + public void testWriteValidationFailsMissingStreamName() { + KinesisIO.Write write = + KinesisIO.write() + .withPartitionKey(PARTITION_KEY) + .withAWSClientsProvider(new FakeKinesisProvider()); + + thrown.expect(IllegalArgumentException.class); + write.expand(null); + } + + @Test + public void testWriteValidationFailsMissingPartitioner() { + KinesisIO.Write write = + KinesisIO.write().withStreamName(STREAM).withAWSClientsProvider(new FakeKinesisProvider()); + + thrown.expect(IllegalArgumentException.class); + write.expand(null); + } + + @Test + public void testWriteValidationFailsPartitionerAndPartitioneKey() { + KinesisIO.Write write = + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withPartitioner(new BasicKinesisPartitioner()) + .withAWSClientsProvider(new FakeKinesisProvider()); + + thrown.expect(IllegalArgumentException.class); + write.expand(null); + } + + @Test + public void testWriteValidationFailsMissingAWSClientsProvider() { + KinesisIO.Write write = + KinesisIO.write().withPartitionKey(PARTITION_KEY).withStreamName(STREAM); + + thrown.expect(IllegalArgumentException.class); + write.expand(null); + } + + @Test + public void testNotExistedStream() { + Iterable data = Arrays.asList("1".getBytes()); + p.apply(Create.of(data)) + .apply( + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withAWSClientsProvider(new FakeKinesisProvider(false)) + ); + + thrown.expect(RuntimeException.class); + p.run().waitUntilFinish(); + } + + @Test + public void testSetInvalidProperty() { + Properties properties = new Properties(); + properties.setProperty("KinesisPort", "qwe"); + + Iterable data = Arrays.asList("1".getBytes()); + p.apply(Create.of(data)) + .apply( + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withAWSClientsProvider(new FakeKinesisProvider()) + .withProducerProperties(properties)); + + thrown.expect(RuntimeException.class); + 
p.run().waitUntilFinish(); + } + + @Test + public void testWrite() { + KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); + + Properties properties = new Properties(); + properties.setProperty("KinesisEndpoint", "localhost"); + properties.setProperty("KinesisPort", "4567"); + properties.setProperty("VerifyCertificate", "false"); + + Iterable data = Arrays.asList("1".getBytes(), "2".getBytes(), "3".getBytes()); + p.apply(Create.of(data)) + .apply( + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withAWSClientsProvider(new FakeKinesisProvider()) + .withProducerProperties(properties)); + p.run().waitUntilFinish(); + + assertEquals(3, kinesisService.getAddedRecords().get()); + } + + @Test + public void testWriteFailed() { + Iterable data = Arrays.asList("1".getBytes()); + p.apply(Create.of(data)) + .apply( + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withAWSClientsProvider(new FakeKinesisProvider().setFailedFlush(true)) + .withRetries(1)); + + thrown.expect(RuntimeException.class); + p.run().waitUntilFinish(); + } + + @Test + public void testWriteAndReadFromMockKinesis() { + KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); + + Iterable data = Arrays.asList("1".getBytes(), "2".getBytes()); + p.apply(Create.of(data)) + .apply( + KinesisIO.write() + .withStreamName(STREAM) + .withPartitionKey(PARTITION_KEY) + .withAWSClientsProvider(new FakeKinesisProvider())); + p.run().waitUntilFinish(); + assertEquals(2, kinesisService.getAddedRecords().get()); + + List> testData = kinesisService.getShardedData(); + + int noOfShards = 1; + int noOfEventsPerShard = 2; + PCollection result = + p2.apply( + KinesisIO.read() + .withStreamName(STREAM) + .withInitialPositionInStream(InitialPositionInStream.TRIM_HORIZON) + .withAWSClientsProvider(new AmazonKinesisMock.Provider(testData, 10)) + .withMaxNumRecords(noOfShards * noOfEventsPerShard)) + .apply(ParDo.of(new KinesisMockReadTest.KinesisRecordToTestData())); + PAssert.that(result).containsInAnyOrder(Iterables.concat(testData)); + p2.run().waitUntilFinish(); + } + + private static final class BasicKinesisPartitioner implements KinesisPartitioner { + @Override + public String getPartitionKey(byte[] value) { + return String.valueOf(value.length); + } + + @Override + public String getExplicitHashKey(byte[] value) { + return null; + } + } + + private static final class FakeKinesisProvider implements AWSClientsProvider { + private boolean isExistingStream = true; + private boolean isFailedFlush = false; + + public FakeKinesisProvider() { + } + + public FakeKinesisProvider(boolean isExistingStream) { + this.isExistingStream = isExistingStream; + } + + public FakeKinesisProvider setFailedFlush(boolean failedFlush) { + isFailedFlush = failedFlush; + return this; + } + + @Override + public AmazonKinesis getKinesisClient() { + return getMockedAmazonKinesisClient(); + } + + @Override + public AmazonCloudWatch getCloudWatchClient() { + throw new RuntimeException("Not implemented"); + } + + @Override + public IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config) { + return new KinesisProducerMock(config, isFailedFlush); + } + + private AmazonKinesis getMockedAmazonKinesisClient() { + int statusCode = isExistingStream ? 
200 : 404; + SdkHttpMetadata httpMetadata = mock(SdkHttpMetadata.class); + when(httpMetadata.getHttpStatusCode()).thenReturn(statusCode); + + DescribeStreamResult streamResult = mock(DescribeStreamResult.class); + when(streamResult.getSdkHttpMetadata()).thenReturn(httpMetadata); + + AmazonKinesis client = mock(AmazonKinesis.class); + when(client.describeStream(any(String.class))).thenReturn(streamResult); + + return client; + } + } +} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisProducerMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisProducerMock.java new file mode 100644 index 000000000000..c3aa5a63222c --- /dev/null +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisProducerMock.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import static com.google.common.collect.Lists.newArrayList; + +import com.amazonaws.services.kinesis.producer.IKinesisProducer; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; +import com.amazonaws.services.kinesis.producer.Metric; +import com.amazonaws.services.kinesis.producer.UserRecord; +import com.amazonaws.services.kinesis.producer.UserRecordResult; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.ExecutionException; +import org.joda.time.DateTime; + +/** + * Simple mock implementation of {@link IKinesisProducer} for testing. 
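+ *
+ * <p>Only the calls used by {@code KinesisIO.Write} are implemented: the three-argument
+ * {@code addUserRecord}, {@code getOutstandingRecordsCount()}, {@code flush()},
+ * {@code flushSync()} and {@code destroy()}. The remaining {@link IKinesisProducer} methods throw
+ * {@code RuntimeException("Not implemented")}.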
+ */ +public class KinesisProducerMock implements IKinesisProducer { + + private boolean isFailedFlush = false; + + private List addedRecords = newArrayList(); + + private KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); + + public KinesisProducerMock(){} + + public KinesisProducerMock(KinesisProducerConfiguration config, boolean isFailedFlush) { + this.isFailedFlush = isFailedFlush; + } + + @Override public ListenableFuture addUserRecord(String stream, + String partitionKey, ByteBuffer data) { + throw new RuntimeException("Not implemented"); + } + + @Override public ListenableFuture addUserRecord(UserRecord userRecord) { + throw new RuntimeException("Not implemented"); + } + + @Override public ListenableFuture addUserRecord(String stream, + String partitionKey, String explicitHashKey, ByteBuffer data) { + SettableFuture f = SettableFuture.create(); + if (kinesisService.getExistedStream().equals(stream)) { + addedRecords.add(new UserRecord(stream, partitionKey, explicitHashKey, data)); + } + return f; + } + + @Override + public int getOutstandingRecordsCount() { + return addedRecords.size(); + } + + @Override public List getMetrics(String metricName, int windowSeconds) + throws InterruptedException, ExecutionException { + throw new RuntimeException("Not implemented"); + } + + @Override public List getMetrics(String metricName) + throws InterruptedException, ExecutionException { + throw new RuntimeException("Not implemented"); + } + + @Override public List getMetrics() throws InterruptedException, ExecutionException { + throw new RuntimeException("Not implemented"); + } + + @Override public List getMetrics(int windowSeconds) + throws InterruptedException, ExecutionException { + throw new RuntimeException("Not implemented"); + } + + @Override public void destroy() { + } + + @Override public void flush(String stream) { + throw new RuntimeException("Not implemented"); + } + + @Override public void flush() { + if (isFailedFlush) { + // don't flush + return; + } + + DateTime arrival = DateTime.now(); + for (int i = 0; i < addedRecords.size(); i++) { + UserRecord record = addedRecords.get(i); + arrival = arrival.plusSeconds(1); + kinesisService.addShardedData(record.getData(), arrival); + addedRecords.remove(i); + } + } + + @Override public synchronized void flushSync() { + if (getOutstandingRecordsCount() > 0) { + flush(); + } + } +} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java index 1038a47bccb0..9fa0a8628fb1 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java @@ -21,10 +21,8 @@ import static org.assertj.core.api.Assertions.assertThat; import com.google.common.collect.Iterables; - import java.util.Iterator; import java.util.List; - import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java deleted file mode 100644 index 816af85dcaf0..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * 
or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static com.google.common.base.Preconditions.checkNotNull; -import static java.util.concurrent.Executors.newSingleThreadExecutor; -import static org.assertj.core.api.Assertions.assertThat; - -import com.amazonaws.regions.Regions; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.commons.lang.RandomStringUtils; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; - -/** - * Integration test, that reads from the real Kinesis. You need to provide all {@link - * KinesisTestOptions} in order to run this. 
- */ -public class KinesisReaderIT { - - private static final long PIPELINE_STARTUP_TIME = TimeUnit.SECONDS.toMillis(10); - private static KinesisTestOptions options; - private ExecutorService singleThreadExecutor = newSingleThreadExecutor(); - - @Rule public final transient TestPipeline p = TestPipeline.create(); - - @BeforeClass - public static void setup() { - PipelineOptionsFactory.register(KinesisTestOptions.class); - options = TestPipeline.testingPipelineOptions().as(KinesisTestOptions.class); - } - - @Test - public void readsDataFromRealKinesisStream() - throws IOException, InterruptedException, ExecutionException { - List testData = prepareTestData(1000); - - KinesisIO.Read read = - KinesisIO.read() - .withStreamName(options.getAwsKinesisStream()) - .withInitialTimestampInStream(Instant.now()) - .withAWSClientsProvider( - options.getAwsAccessKey(), - options.getAwsSecretKey(), - Regions.fromName(options.getAwsKinesisRegion())) - .withMaxReadTime(Duration.standardMinutes(3)); - - Future future = runReadTest(read, testData); - KinesisUploader.uploadAll(testData, options); - future.get(); - } - - private static List prepareTestData(int count) { - List data = new ArrayList<>(); - for (int i = 0; i < count; ++i) { - data.add(RandomStringUtils.randomAlphabetic(32)); - } - return data; - } - - private Future runReadTest(KinesisIO.Read read, List testData) - throws InterruptedException { - PCollection result = p.apply(read).apply(ParDo.of(new RecordDataToString())); - PAssert.that(result).containsInAnyOrder(testData); - - Future future = - singleThreadExecutor.submit( - () -> { - PipelineResult result1 = p.run(); - PipelineResult.State state = result1.getState(); - while (state != PipelineResult.State.DONE && state != PipelineResult.State.FAILED) { - Thread.sleep(1000); - state = result1.getState(); - } - assertThat(state).isEqualTo(PipelineResult.State.DONE); - return null; - }); - Thread.sleep(PIPELINE_STARTUP_TIME); - return future; - } - - private static class RecordDataToString extends DoFn { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - checkNotNull(c.element(), "Null record given"); - c.output(new String(c.element().getData().array(), StandardCharsets.UTF_8)); - } - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java index c9f01bb11d3f..ff6232aee2c8 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.kinesis; import java.nio.ByteBuffer; - import org.apache.beam.sdk.testing.CoderProperties; import org.joda.time.Instant; import org.junit.Test; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisServiceMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisServiceMock.java new file mode 100644 index 000000000000..1ff0291d9b4f --- /dev/null +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisServiceMock.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import static com.google.common.collect.Lists.newArrayList; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.joda.time.DateTime; + +/** Simple mock implementation of Kinesis service for testing, singletone. */ +public class KinesisServiceMock { + private static KinesisServiceMock instance; + + // Mock stream where client is supposed to write + private String existedStream; + + private AtomicInteger addedRecords = new AtomicInteger(0); + private AtomicInteger seqNumber = new AtomicInteger(0); + private List> shardedData; + + private KinesisServiceMock() {} + + public static KinesisServiceMock getInstance() { + if (instance == null) { + synchronized (KinesisServiceMock.class) { + if (instance == null) { + instance = new KinesisServiceMock(); + } + } + } + return instance; + } + + public synchronized void init(String stream, int shardsNum) { + existedStream = stream; + addedRecords.set(0); + seqNumber.set(0); + shardedData = newArrayList(); + for (int i = 0; i < shardsNum; i++) { + List shardData = newArrayList(); + shardedData.add(shardData); + } + } + + public AtomicInteger getAddedRecords() { + return addedRecords; + } + + public String getExistedStream() { + return existedStream; + } + + public synchronized void addShardedData(ByteBuffer data, DateTime arrival) { + String dataString = StandardCharsets.UTF_8.decode(data).toString(); + + List shardData = shardedData.get(0); + + seqNumber.incrementAndGet(); + AmazonKinesisMock.TestData testData = + new AmazonKinesisMock.TestData( + dataString, arrival.toInstant(), Integer.toString(seqNumber.get())); + shardData.add(testData); + + addedRecords.incrementAndGet(); + } + + public synchronized List> getShardedData() { + return shardedData; + } +} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java index cb325620abfc..0df7bdcfc9af 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java @@ -20,10 +20,8 @@ import static org.mockito.BDDMockito.given; import com.google.common.collect.Lists; - import java.util.Collections; import java.util.List; - import org.assertj.core.api.Assertions; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java index d4784c48f11e..ec164d999d61 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java +++ 
b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java @@ -32,9 +32,7 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.model.ShardIteratorType; - import java.io.IOException; - import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Before; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java index 2c9181cb7f9e..ab5b89efad55 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java @@ -22,20 +22,18 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; - +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -50,6 +48,8 @@ @RunWith(MockitoJUnitRunner.class) public class ShardReadersPoolTest { + private static final int TIMEOUT_IN_MILLIS = (int) TimeUnit.SECONDS.toMillis(10); + @Mock private ShardRecordsIterator firstIterator, secondIterator, thirdIterator, fourthIterator; @Mock @@ -82,6 +82,11 @@ public void setUp() throws TransientKinesisException { doReturn(secondIterator).when(shardReadersPool).createShardIterator(kinesis, secondCheckpoint); } + @After + public void clean() { + shardReadersPool.stop(); + } + @Test public void shouldReturnAllRecords() throws TransientKinesisException, KinesisShardClosedException { @@ -156,7 +161,7 @@ public void shouldInterruptKinesisReadingAndStopShortly() Stopwatch stopwatch = Stopwatch.createStarted(); shardReadersPool.stop(); - assertThat(stopwatch.elapsed(TimeUnit.MILLISECONDS)).isLessThan(TimeUnit.SECONDS.toMillis(1)); + assertThat(stopwatch.elapsed(TimeUnit.MILLISECONDS)).isLessThan(TIMEOUT_IN_MILLIS); } @Test @@ -170,7 +175,7 @@ public void shouldInterruptPuttingRecordsToQueueAndStopShortly() Stopwatch stopwatch = Stopwatch.createStarted(); shardReadersPool.stop(); - assertThat(stopwatch.elapsed(TimeUnit.MILLISECONDS)).isLessThan(TimeUnit.SECONDS.toMillis(1)); + assertThat(stopwatch.elapsed(TimeUnit.MILLISECONDS)).isLessThan(TIMEOUT_IN_MILLIS); } @@ -199,10 +204,9 @@ public void shouldStopReadingShardAfterReceivingShardClosedException() throws Ex .thenReturn(Collections.emptyList()); shardReadersPool.start(); - Thread.sleep(200); - verify(firstIterator, times(1)).readNextBatch(); - verify(secondIterator, atLeast(2)).readNextBatch(); + verify(firstIterator, timeout(TIMEOUT_IN_MILLIS).times(1)).readNextBatch(); + verify(secondIterator, timeout(TIMEOUT_IN_MILLIS).atLeast(2)).readNextBatch(); } @Test @@ -213,10 +217,9 @@ public void shouldStartReadingSuccessiveShardsAfterReceivingShardClosedException .thenReturn(asList(thirdIterator, fourthIterator)); shardReadersPool.start(); - Thread.sleep(1500); - verify(thirdIterator, atLeast(2)).readNextBatch(); - 
verify(fourthIterator, atLeast(2)).readNextBatch(); + verify(thirdIterator, timeout(TIMEOUT_IN_MILLIS).atLeast(2)).readNextBatch(); + verify(fourthIterator, timeout(TIMEOUT_IN_MILLIS).atLeast(2)).readNextBatch(); } @Test @@ -226,9 +229,8 @@ public void shouldStopReadersPoolWhenLastShardReaderStopped() throws Exception { .thenReturn(Collections.emptyList()); shardReadersPool.start(); - Thread.sleep(200); - verify(firstIterator, times(1)).readNextBatch(); + verify(firstIterator, timeout(TIMEOUT_IN_MILLIS).times(1)).readNextBatch(); } @Test @@ -239,9 +241,8 @@ public void shouldStopReadersPoolAlsoWhenExceptionsOccurDuringStopping() throws .thenReturn(Collections.emptyList()); shardReadersPool.start(); - Thread.sleep(200); - verify(firstIterator, times(2)).readNextBatch(); + verify(firstIterator, timeout(TIMEOUT_IN_MILLIS).times(2)).readNextBatch(); } @Test @@ -260,11 +261,12 @@ public void shouldReturnAbsentOptionalWhenStartedWithNoIterators() throws Except @Test public void shouldForgetClosedShardIterator() throws Exception { when(firstIterator.readNextBatch()).thenThrow(KinesisShardClosedException.class); - when(firstIterator.findSuccessiveShardRecordIterators()) - .thenReturn(Collections.emptyList()); + List emptyList = Collections.emptyList(); + when(firstIterator.findSuccessiveShardRecordIterators()).thenReturn(emptyList); shardReadersPool.start(); - Thread.sleep(200); + verify(shardReadersPool).startReadingShards(Arrays.asList(firstIterator, secondIterator)); + verify(shardReadersPool, timeout(TIMEOUT_IN_MILLIS)).startReadingShards(emptyList); KinesisReaderCheckpoint checkpointMark = shardReadersPool.getCheckpointMark(); assertThat(checkpointMark.iterator()) diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java index 75c0ae018f15..43993f410fe4 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java @@ -39,9 +39,7 @@ import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.amazonaws.services.kinesis.model.StreamDescription; - import java.util.List; - import org.joda.time.Instant; import org.joda.time.Minutes; import org.junit.Test; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinderTest.java index 25e6711d95c1..ae96675caa28 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinderTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/StartingPointShardsFinderTest.java @@ -26,10 +26,8 @@ import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.google.common.collect.ImmutableList; - import java.util.Collections; import java.util.List; - import org.joda.time.Instant; import org.junit.Test; diff --git a/sdks/java/io/mongodb/pom.xml b/sdks/java/io/mongodb/pom.xml index 666b4b06150d..0fab9b4cc270 100644 --- a/sdks/java/io/mongodb/pom.xml +++ b/sdks/java/io/mongodb/pom.xml @@ -101,9 +101,14 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + diff --git 
a/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java b/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java index 880cdf0c25ba..07b4821b34b3 100644 --- a/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java +++ b/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java @@ -40,11 +40,11 @@ import java.net.ServerSocket; import java.util.ArrayList; import java.util.List; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.values.KV; @@ -272,12 +272,25 @@ public void testWrite() throws Exception { @Test public void testWriteEmptyCollection() throws Exception { - MongoDbIO.Write write = + final String emptyCollection = "empty"; + + final PCollection emptyInput = + pipeline.apply( + Create.empty( + SerializableCoder.of(Document.class))); + + emptyInput.apply( MongoDbIO.write() .withUri("mongodb://localhost:" + port) - .withDatabase("test") - .withCollection("empty"); - DoFnTester fnTester = DoFnTester.of(new MongoDbIO.Write.WriteFn(write)); - fnTester.processBundle(new ArrayList<>()); + .withDatabase(DATABASE) + .withCollection(emptyCollection)); + + pipeline.run(); + + final MongoClient client = new MongoClient("localhost", port); + final MongoDatabase database = client.getDatabase(DATABASE); + final MongoCollection collection = database.getCollection(emptyCollection); + + Assert.assertEquals(0, collection.count()); } } diff --git a/sdks/java/io/mqtt/pom.xml b/sdks/java/io/mqtt/pom.xml index 646dd50529fb..ea58135336c0 100644 --- a/sdks/java/io/mqtt/pom.xml +++ b/sdks/java/io/mqtt/pom.xml @@ -109,9 +109,14 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + org.slf4j slf4j-jdk14 diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java index 5d204b6a3efd..c23fea68ffb0 100644 --- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java @@ -30,7 +30,6 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; - import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; diff --git a/sdks/java/io/redis/pom.xml b/sdks/java/io/redis/pom.xml index 056668317fde..fdc177ff5417 100644 --- a/sdks/java/io/redis/pom.xml +++ b/sdks/java/io/redis/pom.xml @@ -76,9 +76,14 @@ org.hamcrest - hamcrest-all + hamcrest-core test + + org.hamcrest + hamcrest-library + test + com.github.kstyrc embedded-redis diff --git a/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisConnectionConfiguration.java b/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisConnectionConfiguration.java index efcc77b540ce..1a66c6d82748 100644 --- a/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisConnectionConfiguration.java +++ b/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisConnectionConfiguration.java @@ -20,13 +20,9 @@ import 
static com.google.common.base.Preconditions.checkArgument; import com.google.auto.value.AutoValue; - import java.io.Serializable; - import javax.annotation.Nullable; - import org.apache.beam.sdk.transforms.display.DisplayData; - import redis.clients.jedis.Jedis; import redis.clients.jedis.Protocol; diff --git a/sdks/java/io/solr/pom.xml b/sdks/java/io/solr/pom.xml index 9967d1f587f5..b6027d597957 100644 --- a/sdks/java/io/solr/pom.xml +++ b/sdks/java/io/solr/pom.xml @@ -84,10 +84,10 @@ org.hamcrest - hamcrest-all + hamcrest-library test - + junit junit @@ -112,6 +112,13 @@ solr-test-framework 5.5.4 test + + + + jdk.tools + jdk.tools + + @@ -142,4 +149,4 @@ - \ No newline at end of file + diff --git a/sdks/java/io/solr/src/main/java/org/apache/beam/sdk/io/solr/AuthorizedSolrClient.java b/sdks/java/io/solr/src/main/java/org/apache/beam/sdk/io/solr/AuthorizedSolrClient.java index 44d7b88d9fde..2c89c870edb5 100644 --- a/sdks/java/io/solr/src/main/java/org/apache/beam/sdk/io/solr/AuthorizedSolrClient.java +++ b/sdks/java/io/solr/src/main/java/org/apache/beam/sdk/io/solr/AuthorizedSolrClient.java @@ -21,7 +21,6 @@ import java.io.Closeable; import java.io.IOException; - import org.apache.beam.sdk.io.solr.SolrIO.ConnectionConfiguration; import org.apache.solr.client.solrj.SolrClient; import org.apache.solr.client.solrj.SolrRequest; diff --git a/sdks/java/io/solr/src/test/java/org/apache/beam/sdk/io/solr/JavaBinCodecCoderTest.java b/sdks/java/io/solr/src/test/java/org/apache/beam/sdk/io/solr/JavaBinCodecCoderTest.java index 1fb435d1e20a..aad9a1db656c 100644 --- a/sdks/java/io/solr/src/test/java/org/apache/beam/sdk/io/solr/JavaBinCodecCoderTest.java +++ b/sdks/java/io/solr/src/test/java/org/apache/beam/sdk/io/solr/JavaBinCodecCoderTest.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.List; - import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.testing.CoderProperties; diff --git a/sdks/java/io/tika/pom.xml b/sdks/java/io/tika/pom.xml index 92173d17cc02..4aa481b9f875 100644 --- a/sdks/java/io/tika/pom.xml +++ b/sdks/java/io/tika/pom.xml @@ -72,6 +72,12 @@ test + + org.hamcrest + hamcrest-library + test + + org.apache.beam beam-sdks-java-core @@ -85,11 +91,6 @@ test - - org.hamcrest - hamcrest-all - test - org.apache.tika @@ -98,16 +99,4 @@ test - - - - - maven-compiler-plugin - - 1.8 - 1.8 - - - - diff --git a/sdks/java/io/xml/pom.xml b/sdks/java/io/xml/pom.xml index 99d74d4caaaf..f4783442d75e 100644 --- a/sdks/java/io/xml/pom.xml +++ b/sdks/java/io/xml/pom.xml @@ -103,11 +103,50 @@ hamcrest-core test - + org.hamcrest - hamcrest-all + hamcrest-library test + + + + + java-9 + + 9 + + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + + javax.activation + javax.activation-api + 1.2.0 + + + + diff --git a/sdks/java/java8tests/pom.xml b/sdks/java/java8tests/pom.xml deleted file mode 100644 index 195cb91cfddf..000000000000 --- a/sdks/java/java8tests/pom.xml +++ /dev/null @@ -1,83 +0,0 @@ - - - - - 4.0.0 - - - org.apache.beam - beam-sdks-java-parent - 2.4.0-SNAPSHOT - ../pom.xml - - - beam-sdks-java-java8tests - Apache Beam :: SDKs :: Java :: Java 8 Tests - Apache Beam Java SDK provides a simple, Java-based - interface for processing virtually any size data. - This artifact includes tests of the SDK from a Java 8 - user. 
- - - - - - org.jacoco - jacoco-maven-plugin - - - - - - - org.apache.beam - beam-sdks-java-core - test - - - - org.apache.beam - beam-runners-direct-java - test - - - - com.google.guava - guava - test - - - - joda-time - joda-time - test - - - - org.hamcrest - hamcrest-all - test - - - - junit - junit - test - - - diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryJava8Test.java deleted file mode 100644 index bc0c70bb1446..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryJava8Test.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.options; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.nullValue; -import static org.junit.Assert.assertThat; - -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 tests for {@link PipelineOptionsFactory}. 
- */ -@RunWith(JUnit4.class) -public class PipelineOptionsFactoryJava8Test { - @Rule public ExpectedException thrown = ExpectedException.none(); - - private interface OptionsWithDefaultMethod extends PipelineOptions { - default Number getValue() { - return 1024; - } - - void setValue(Number value); - } - - @Test - public void testDefaultMethodIgnoresDefaultImplementation() { - OptionsWithDefaultMethod optsWithDefault = - PipelineOptionsFactory.as(OptionsWithDefaultMethod.class); - assertThat(optsWithDefault.getValue(), nullValue()); - - optsWithDefault.setValue(12.25); - assertThat(optsWithDefault.getValue(), equalTo(12.25)); - } - - private interface ExtendedOptionsWithDefault extends OptionsWithDefaultMethod {} - - @Test - public void testDefaultMethodInExtendedClassIgnoresDefaultImplementation() { - OptionsWithDefaultMethod extendedOptsWithDefault = - PipelineOptionsFactory.as(ExtendedOptionsWithDefault.class); - assertThat(extendedOptsWithDefault.getValue(), nullValue()); - - extendedOptsWithDefault.setValue(Double.NEGATIVE_INFINITY); - assertThat(extendedOptsWithDefault.getValue(), equalTo(Double.NEGATIVE_INFINITY)); - } - - private interface Options extends PipelineOptions { - Number getValue(); - - void setValue(Number value); - } - - private interface SubtypeReturingOptions extends Options { - @Override - Integer getValue(); - void setValue(Integer value); - } - - @Test - public void testReturnTypeConflictThrows() throws Exception { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage( - "Method [getValue] has multiple definitions [public abstract java.lang.Integer " - + "org.apache.beam.sdk.options.PipelineOptionsFactoryJava8Test$" - + "SubtypeReturingOptions.getValue(), public abstract java.lang.Number " - + "org.apache.beam.sdk.options.PipelineOptionsFactoryJava8Test$Options" - + ".getValue()] with different return types for [" - + "org.apache.beam.sdk.options.PipelineOptionsFactoryJava8Test$" - + "SubtypeReturingOptions]."); - PipelineOptionsFactory.as(SubtypeReturingOptions.class); - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/CombineJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/CombineJava8Test.java deleted file mode 100644 index a0f7ce65f87a..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/CombineJava8Test.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.transforms; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.not; - -import com.google.common.collect.Iterables; -import java.io.Serializable; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.junit.Assume; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 Tests for {@link Combine}. - */ -@RunWith(JUnit4.class) -@SuppressWarnings("serial") -public class CombineJava8Test implements Serializable { - - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); - - @Rule - public transient ExpectedException thrown = ExpectedException.none(); - - /** - * Class for use in testing use of Java 8 method references. - */ - private static class Summer implements Serializable { - public int sum(Iterable integers) { - int sum = 0; - for (int i : integers) { - sum += i; - } - return sum; - } - } - - /** - * Tests creation of a global {@link Combine} via Java 8 lambda. - */ - @Test - public void testCombineGloballyLambda() { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3, 4)) - .apply(Combine.globally(integers -> { - int sum = 0; - for (int i : integers) { - sum += i; - } - return sum; - })); - - PAssert.that(output).containsInAnyOrder(10); - pipeline.run(); - } - - /** - * Tests creation of a global {@link Combine} via a Java 8 method reference. - */ - @Test - public void testCombineGloballyInstanceMethodReference() { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3, 4)) - .apply(Combine.globally(new Summer()::sum)); - - PAssert.that(output).containsInAnyOrder(10); - pipeline.run(); - } - - /** - * Tests creation of a per-key {@link Combine} via a Java 8 lambda. - */ - @Test - public void testCombinePerKeyLambda() { - - PCollection> output = pipeline - .apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3), KV.of("c", 4))) - .apply(Combine.perKey(integers -> { - int sum = 0; - for (int i : integers) { - sum += i; - } - return sum; - })); - - PAssert.that(output).containsInAnyOrder( - KV.of("a", 4), - KV.of("b", 2), - KV.of("c", 4)); - pipeline.run(); - } - - /** - * Tests creation of a per-key {@link Combine} via a Java 8 method reference. - */ - @Test - public void testCombinePerKeyInstanceMethodReference() { - - PCollection> output = pipeline - .apply(Create.of(KV.of("a", 1), KV.of("b", 2), KV.of("a", 3), KV.of("c", 4))) - .apply(Combine.perKey(new Summer()::sum)); - - PAssert.that(output).containsInAnyOrder( - KV.of("a", 4), - KV.of("b", 2), - KV.of("c", 4)); - pipeline.run(); - } - - /** - * Tests that we can serialize {@link Combine.CombineFn CombineFns} constructed from a lambda. - * Lambdas can be problematic because the {@link Class} object is synthetic and cannot be - * deserialized. 
- */ - @Test - public void testLambdaSerialization() { - SerializableFunction, Object> combiner = xs -> Iterables.getFirst(xs, 0); - - boolean lambdaClassSerializationThrows; - try { - SerializableUtils.clone(combiner.getClass()); - lambdaClassSerializationThrows = false; - } catch (IllegalArgumentException e) { - // Expected - lambdaClassSerializationThrows = true; - } - Assume.assumeTrue("Expected lambda class serialization to fail. " - + "If it's fixed, we can remove special behavior in Combine.", - lambdaClassSerializationThrows); - - - Combine.Globally combine = Combine.globally(combiner); - SerializableUtils.clone(combine); // should not throw. - } - - @Test - public void testLambdaDisplayData() { - Combine.Globally combine = Combine.globally(xs -> Iterables.getFirst(xs, 0)); - DisplayData displayData = DisplayData.from(combine); - assertThat(displayData.items(), not(empty())); - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/DistinctJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/DistinctJava8Test.java deleted file mode 100644 index 4b71a40946f3..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/DistinctJava8Test.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.transforms; - -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.hasItem; -import static org.hamcrest.Matchers.not; -import static org.junit.Assert.assertThat; - -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import java.util.HashSet; -import java.util.Set; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 tests for {@link Distinct}. 
- */ -@RunWith(JUnit4.class) -public class DistinctJava8Test { - - @Rule - public final transient TestPipeline p = TestPipeline.create(); - - @Rule - public ExpectedException thrown = ExpectedException.none(); - - @Test - public void withLambdaRepresentativeValuesFnAndTypeDescriptorShouldApplyFn() { - - Multimap predupedContents = HashMultimap.create(); - predupedContents.put(3, "foo"); - predupedContents.put(4, "foos"); - predupedContents.put(6, "barbaz"); - predupedContents.put(6, "bazbar"); - PCollection dupes = - p.apply(Create.of("foo", "foos", "barbaz", "barbaz", "bazbar", "foo")); - PCollection deduped = - dupes.apply( - Distinct.withRepresentativeValueFn(String::length) - .withRepresentativeType(TypeDescriptor.of(Integer.class))); - - PAssert.that(deduped).satisfies((Iterable strs) -> { - Set seenLengths = new HashSet<>(); - for (String s : strs) { - assertThat(predupedContents.values(), hasItem(s)); - assertThat(seenLengths, not(contains(s.length()))); - seenLengths.add(s.length()); - } - return null; - }); - - p.run(); - } - - @Test - public void withLambdaRepresentativeValuesFnNoTypeDescriptorShouldThrow() { - - Multimap predupedContents = HashMultimap.create(); - predupedContents.put(3, "foo"); - predupedContents.put(4, "foos"); - predupedContents.put(6, "barbaz"); - predupedContents.put(6, "bazbar"); - PCollection dupes = - p.apply(Create.of("foo", "foos", "barbaz", "barbaz", "bazbar", "foo")); - - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Unable to return a default Coder for RemoveRepresentativeDupes"); - - // Thrown when applying a transform to the internal WithKeys that withRepresentativeValueFn is - // implemented with - dupes.apply("RemoveRepresentativeDupes", Distinct.withRepresentativeValueFn(String::length)); - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/FilterJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/FilterJava8Test.java deleted file mode 100644 index b38250e6147a..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/FilterJava8Test.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.transforms; - -import java.io.Serializable; -import org.apache.beam.sdk.coders.CannotProvideCoderException; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; -import org.apache.beam.sdk.values.PCollection; -import org.junit.Rule; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 Tests for {@link Filter}. - */ -@RunWith(JUnit4.class) -@SuppressWarnings("serial") -public class FilterJava8Test implements Serializable { - - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); - - @Rule - public transient ExpectedException thrown = ExpectedException.none(); - - @Test - @Category(ValidatesRunner.class) - public void testIdentityFilterByPredicate() { - - PCollection output = pipeline - .apply(Create.of(591, 11789, 1257, 24578, 24799, 307)) - .apply(Filter.by(i -> true)); - - PAssert.that(output).containsInAnyOrder(591, 11789, 1257, 24578, 24799, 307); - pipeline.run(); - } - - @Test - public void testNoFilterByPredicate() { - - PCollection output = pipeline - .apply(Create.of(1, 2, 4, 5)) - .apply(Filter.by(i -> false)); - - PAssert.that(output).empty(); - pipeline.run(); - } - - @Test - @Category(ValidatesRunner.class) - public void testFilterByPredicate() { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) - .apply(Filter.by(i -> i % 2 == 0)); - - PAssert.that(output).containsInAnyOrder(2, 4, 6); - pipeline.run(); - } - - /** - * Confirms that in Java 8 style, where a lambda results in a rawtype, the output type token is - * not useful. If this test ever fails there may be simplifications available to us. - */ - @Test - public void testFilterParDoOutputTypeDescriptorRaw() throws Exception { - - @SuppressWarnings({"unchecked", "rawtypes"}) - PCollection output = pipeline - .apply(Create.of("hello")) - .apply(Filter.by(s -> true)); - - thrown.expect(CannotProvideCoderException.class); - pipeline.getCoderRegistry().getCoder(output.getTypeDescriptor()); - } - - @Test - @Category(ValidatesRunner.class) - public void testFilterByMethodReference() { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3, 4, 5, 6, 7)) - .apply(Filter.by(new EvenFilter()::isEven)); - - PAssert.that(output).containsInAnyOrder(2, 4, 6); - pipeline.run(); - } - - private static class EvenFilter implements Serializable { - public boolean isEven(int i) { - return i % 2 == 0; - } - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsJava8Test.java deleted file mode 100644 index 501b0d1bd8aa..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsJava8Test.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.transforms; - -import com.google.common.collect.ImmutableList; -import java.io.Serializable; -import java.util.List; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TypeDescriptors; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 Tests for {@link FlatMapElements}. - */ -@RunWith(JUnit4.class) -public class FlatMapElementsJava8Test implements Serializable { - - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); - - @Rule - public transient ExpectedException thrown = ExpectedException.none(); - - /** - * Basic test of {@link FlatMapElements} with a lambda (which is instantiated as a - * {@link SerializableFunction}). - */ - @Test - public void testFlatMapBasic() throws Exception { - PCollection output = pipeline - .apply(Create.of(1, 2, 3)) - .apply(FlatMapElements - // Note that the input type annotation is required. - .into(TypeDescriptors.integers()) - .via((Integer i) -> ImmutableList.of(i, -i))); - - PAssert.that(output).containsInAnyOrder(1, 3, -1, -3, 2, -2); - pipeline.run(); - } - - /** - * Basic test of {@link FlatMapElements} with a method reference. - */ - @Test - public void testFlatMapMethodReference() throws Exception { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3)) - .apply(FlatMapElements - // Note that the input type annotation is required. - .into(TypeDescriptors.integers()) - .via(new Negater()::numAndNegation)); - - PAssert.that(output).containsInAnyOrder(1, 3, -1, -3, 2, -2); - pipeline.run(); - } - - private static class Negater implements Serializable { - public List numAndNegation(int input) { - return ImmutableList.of(input, -input); - } - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/MapElementsJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/MapElementsJava8Test.java deleted file mode 100644 index dbd5ef3d209e..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/MapElementsJava8Test.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.transforms; - -import java.io.Serializable; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TypeDescriptors; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 tests for {@link MapElements}. - */ -@RunWith(JUnit4.class) -public class MapElementsJava8Test implements Serializable { - - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); - - /** - * Basic test of {@link MapElements} with a lambda (which is instantiated as a {@link - * SerializableFunction}). - */ - @Test - public void testMapLambda() throws Exception { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3)) - .apply(MapElements - // Note that the type annotation is required. - .into(TypeDescriptors.integers()) - .via((Integer i) -> i * 2)); - - PAssert.that(output).containsInAnyOrder(6, 2, 4); - pipeline.run(); - } - - /** - * Basic test of {@link MapElements} with a lambda wrapped into a {@link SimpleFunction} to - * remember its type. - */ - @Test - public void testMapWrappedLambda() throws Exception { - - PCollection output = - pipeline - .apply(Create.of(1, 2, 3)) - .apply( - MapElements - .via(new SimpleFunction((Integer i) -> i * 2) {})); - - PAssert.that(output).containsInAnyOrder(6, 2, 4); - pipeline.run(); - } - - /** - * Basic test of {@link MapElements} with a method reference. - */ - @Test - public void testMapMethodReference() throws Exception { - - PCollection output = pipeline - .apply(Create.of(1, 2, 3)) - .apply(MapElements - // Note that the type annotation is required. - .into(TypeDescriptors.integers()) - .via(new Doubler()::doubleIt)); - - PAssert.that(output).containsInAnyOrder(6, 2, 4); - pipeline.run(); - } - - private static class Doubler implements Serializable { - public int doubleIt(int val) { - return val * 2; - } - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/PartitionJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/PartitionJava8Test.java deleted file mode 100644 index 94353a5bf7df..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/PartitionJava8Test.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.transforms; - -import static org.junit.Assert.assertEquals; - -import java.io.Serializable; -import org.apache.beam.sdk.coders.CannotProvideCoderException; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.PCollectionList; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 Tests for {@link Filter}. - */ -@RunWith(JUnit4.class) -@SuppressWarnings("serial") -public class PartitionJava8Test implements Serializable { - - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); - - @Rule - public transient ExpectedException thrown = ExpectedException.none(); - - @Test - public void testModPartition() { - - - PCollectionList outputs = pipeline - .apply(Create.of(1, 2, 4, 5)) - .apply(Partition.of(3, (element, numPartitions) -> element % numPartitions)); - assertEquals(3, outputs.size()); - PAssert.that(outputs.get(0)).empty(); - PAssert.that(outputs.get(1)).containsInAnyOrder(1, 4); - PAssert.that(outputs.get(2)).containsInAnyOrder(2, 5); - pipeline.run(); - } - - /** - * Confirms that in Java 8 style, where a lambda results in a rawtype, the output type token is - * not useful. If this test ever fails there may be simplifications available to us. - */ - @Test - public void testPartitionFnOutputTypeDescriptorRaw() throws Exception { - - PCollectionList output = pipeline - .apply(Create.of("hello")) - .apply(Partition.of(1, (element, numPartitions) -> 0)); - - thrown.expect(CannotProvideCoderException.class); - pipeline.getCoderRegistry().getCoder(output.get(0).getTypeDescriptor()); - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionJava8Test.java deleted file mode 100644 index 327fa589536d..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/SimpleFunctionJava8Test.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.transforms; - -import static org.hamcrest.Matchers.equalTo; -import static org.junit.Assert.assertThat; - -import java.io.Serializable; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.TypeDescriptors; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 tests for {@link SimpleFunction}. 
- */ -@RunWith(JUnit4.class) -public class SimpleFunctionJava8Test implements Serializable { - - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); - - /** - * Basic test of {@link MapElements} with a lambda (which is instantiated as a {@link - * SerializableFunction}). - */ - @Test - public void testGoodTypeForLambda() throws Exception { - SimpleFunction fn = new SimpleFunction(Object::toString) {}; - - assertThat(fn.getInputTypeDescriptor(), equalTo(TypeDescriptors.integers())); - assertThat(fn.getOutputTypeDescriptor(), equalTo(TypeDescriptors.strings())); - } - - /** - * Basic test of {@link MapElements} with a lambda wrapped into a {@link SimpleFunction} to - * remember its type. - */ - @Test - public void testGoodTypeForMethodRef() throws Exception { - SimpleFunction fn = - new SimpleFunction(SimpleFunctionJava8Test::toStringThisThing) {}; - - assertThat(fn.getInputTypeDescriptor(), equalTo(TypeDescriptors.integers())); - assertThat(fn.getOutputTypeDescriptor(), equalTo(TypeDescriptors.strings())); - } - - private static String toStringThisThing(Integer i) { - return i.toString(); - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/WithKeysJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/WithKeysJava8Test.java deleted file mode 100644 index 34e42aca2555..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/WithKeysJava8Test.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.transforms; - -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.junit.Rule; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 Tests for {@link WithKeys}. 
- */ -@RunWith(JUnit4.class) -public class WithKeysJava8Test { - - @Rule - public final transient TestPipeline p = TestPipeline.create(); - - @Rule - public ExpectedException thrown = ExpectedException.none(); - - @Test - @Category(ValidatesRunner.class) - public void withLambdaAndTypeDescriptorShouldSucceed() { - - - PCollection values = p.apply(Create.of("1234", "3210", "0", "-12")); - PCollection> kvs = values.apply( - WithKeys.of((SerializableFunction) Integer::valueOf) - .withKeyType(TypeDescriptor.of(Integer.class))); - - PAssert.that(kvs).containsInAnyOrder( - KV.of(1234, "1234"), KV.of(0, "0"), KV.of(-12, "-12"), KV.of(3210, "3210")); - - p.run(); - } - - @Test - public void withLambdaAndNoTypeDescriptorShouldThrow() { - - PCollection values = p.apply(Create.of("1234", "3210", "0", "-12")); - - values.apply("ApplyKeysWithWithKeys", WithKeys.of(Integer::valueOf)); - - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Unable to return a default Coder for ApplyKeysWithWithKeys"); - - p.run(); - } -} diff --git a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsJava8Test.java b/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsJava8Test.java deleted file mode 100644 index ee23d95dabf0..000000000000 --- a/sdks/java/java8tests/src/test/java/org/apache/beam/sdk/transforms/WithTimestampsJava8Test.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.transforms; - -import java.io.Serializable; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.joda.time.Instant; -import org.junit.Rule; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Java 8 tests for {@link WithTimestamps}. 
- */ -@RunWith(JUnit4.class) -public class WithTimestampsJava8Test implements Serializable { - - @Rule - public final transient TestPipeline p = TestPipeline.create(); - - @Test - @Category(ValidatesRunner.class) - public void withTimestampsLambdaShouldApplyTimestamps() { - - final String yearTwoThousand = "946684800000"; - PCollection timestamped = - p.apply(Create.of("1234", "0", Integer.toString(Integer.MAX_VALUE), yearTwoThousand)) - .apply(WithTimestamps.of((String input) -> new Instant(Long.valueOf(input)))); - - PCollection> timestampedVals = - timestamped.apply(ParDo.of(new DoFn>() { - @ProcessElement - public void processElement(ProcessContext c) - throws Exception { - c.output(KV.of(c.element(), c.timestamp())); - } - })); - - PAssert.that(timestamped) - .containsInAnyOrder(yearTwoThousand, "0", "1234", Integer.toString(Integer.MAX_VALUE)); - PAssert.that(timestampedVals) - .containsInAnyOrder( - KV.of("0", new Instant(0)), - KV.of("1234", new Instant(Long.valueOf("1234"))), - KV.of(Integer.toString(Integer.MAX_VALUE), new Instant(Integer.MAX_VALUE)), - KV.of(yearTwoThousand, new Instant(Long.valueOf(yearTwoThousand)))); - - p.run(); - } -} diff --git a/sdks/java/javadoc/pom.xml b/sdks/java/javadoc/pom.xml index ec9369db925b..17b11cd93b25 100644 --- a/sdks/java/javadoc/pom.xml +++ b/sdks/java/javadoc/pom.xml @@ -87,11 +87,26 @@ beam-sdks-java-core + + org.apache.beam + beam-sdks-java-extensions-google-cloud-platform-core + + org.apache.beam beam-sdks-java-extensions-join-library + + org.apache.beam + beam-sdks-java-extensions-json-jackson + + + + org.apache.beam + beam-sdks-java-extensions-protobuf + + org.apache.beam beam-sdks-java-extensions-sketching @@ -127,6 +142,11 @@ beam-sdks-java-io-cassandra + + org.apache.beam + beam-sdks-java-io-elasticsearch + + org.apache.beam beam-sdks-java-io-elasticsearch-tests-2 @@ -139,7 +159,8 @@ org.apache.beam - beam-sdks-java-io-elasticsearch + beam-sdks-java-io-elasticsearch-tests-common + tests @@ -157,6 +178,11 @@ beam-sdks-java-io-hadoop-file-system + + org.apache.beam + beam-sdks-java-io-hadoop-input-format + + org.apache.beam beam-sdks-java-io-hbase @@ -197,11 +223,26 @@ beam-sdks-java-io-mqtt + + org.apache.beam + beam-sdks-java-io-redis + + org.apache.beam beam-sdks-java-io-solr + + org.apache.beam + beam-sdks-java-io-tika + + + + org.apache.beam + beam-sdks-java-io-xml + + com.google.auto.service @@ -232,10 +273,16 @@ org.hamcrest - hamcrest-all + hamcrest-core compile + + org.hamcrest + hamcrest-library + compile + + junit junit diff --git a/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml b/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml index 8efac51398ad..e71461689835 100644 --- a/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml +++ b/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml @@ -243,6 +243,9 @@ spark-runner + + 4.0.43.Final + org.apache.beam @@ -274,6 +277,22 @@ ${jackson.version} runtime + + + org.apache.beam + beam-sdks-java-io-google-cloud-platform + ${beam.version} + + + io.grpc + grpc-netty + + + io.netty + netty-handler + + + @@ -381,7 +400,13 @@ which is used in the main code of DebuggingWordCount example. 
--> org.hamcrest - hamcrest-all + hamcrest-core + ${hamcrest.version} + + + + org.hamcrest + hamcrest-library ${hamcrest.version} @@ -401,7 +426,7 @@ org.mockito - mockito-all + mockito-core ${mockito.version} test diff --git a/sdks/java/nexmark/pom.xml b/sdks/java/nexmark/pom.xml index f9dfb7bfa921..77f88f2a1334 100644 --- a/sdks/java/nexmark/pom.xml +++ b/sdks/java/nexmark/pom.xml @@ -73,6 +73,10 @@ spark-runner + + + 4.0.43.Final + org.apache.beam @@ -114,6 +118,22 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + analyze-only + + + + org.hamcrest:hamcrest-library:jar:${hamcrest.version} + + + + + org.apache.maven.plugins maven-shade-plugin @@ -253,12 +273,13 @@ hamcrest-core compile - + org.hamcrest - hamcrest-all + hamcrest-library + compile - + org.apache.beam beam-runners-direct-java diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkLauncher.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkLauncher.java index 9b8d09c3fe71..d634260dc1d1 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkLauncher.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkLauncher.java @@ -82,6 +82,8 @@ import org.apache.beam.sdk.nexmark.queries.sql.SqlQuery1; import org.apache.beam.sdk.nexmark.queries.sql.SqlQuery2; import org.apache.beam.sdk.nexmark.queries.sql.SqlQuery3; +import org.apache.beam.sdk.nexmark.queries.sql.SqlQuery5; +import org.apache.beam.sdk.nexmark.queries.sql.SqlQuery7; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; @@ -1206,7 +1208,11 @@ private List createSqlQueries() { new NexmarkSqlQuery(configuration, new SqlQuery0()), new NexmarkSqlQuery(configuration, new SqlQuery1()), new NexmarkSqlQuery(configuration, new SqlQuery2(configuration.auctionSkip)), - new NexmarkSqlQuery(configuration, new SqlQuery3(configuration))); + new NexmarkSqlQuery(configuration, new SqlQuery3(configuration)), + null, + new NexmarkSqlQuery(configuration, new SqlQuery5(configuration)), + null, + new NexmarkSqlQuery(configuration, new SqlQuery7(configuration))); } private List createJavaQueries() { diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkOptions.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkOptions.java index b9c8861b680e..0c5c1c1c3685 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkOptions.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkOptions.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.nexmark; import javax.annotation.Nullable; - import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.options.ApplicationNameOptions; import org.apache.beam.sdk.options.Default; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java index fc0ab9f7fe52..3eb6f79c3ea3 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java @@ -282,11 +282,6 @@ public int stepLengthSec(int ratePeriodSec) { */ private static final boolean LOG_INFO = false; - /** - * Set to true to capture all error messages. The logging level flags don't currently work. 
- */ - private static final boolean LOG_ERROR = true; - /** * Set to true to log directly to stdout. If run using Google Dataflow, you can watch the results * in real-time with: tail -f /var/log/dataflow/streaming-harness/harness-stdout.log diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/Event.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/Event.java index 880cfe4d405c..a07cbb2ba266 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/Event.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/Event.java @@ -31,6 +31,7 @@ * {@link Bid}. */ public class Event implements KnownSize, Serializable { + private enum Tag { PERSON(0), AUCTION(1), @@ -42,6 +43,7 @@ private enum Tag { this.value = value; } } + private static final Coder INT_CODER = VarIntCoder.of(); public static final Coder CODER = diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/BeamRecordSize.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/RowSize.java similarity index 53% rename from sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/BeamRecordSize.java rename to sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/RowSize.java index e0a5f3c4c969..179fca03aa71 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/BeamRecordSize.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/RowSize.java @@ -22,78 +22,80 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.sql.Types; import java.util.Map; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.extensions.sql.RowSqlType; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoder; +import org.apache.beam.sdk.extensions.sql.SqlTypeCoders; import org.apache.beam.sdk.nexmark.model.KnownSize; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; /** - * {@link KnownSize} implementation to estimate the size of a {@link BeamRecord}, + * {@link KnownSize} implementation to estimate the size of a {@link Row}, * similar to Java model. NexmarkLauncher/Queries infrastructure expects the events to * be able to quickly provide the estimates of their sizes. * - *
<p>The {@link BeamRecord} size is calculated at creation time.
+ * <p>The {@link Row} size is calculated at creation time.
  *
- * <p>Field sizes are sizes of Java types described in {@link BeamRecordSqlType}. Except strings,
+ * <p>
    Field sizes are sizes of Java types described in {@link RowSqlType}. Except strings, * which are assumed to be taking 1-byte per character plus 1 byte size. */ -public class BeamRecordSize implements KnownSize { +public class RowSize implements KnownSize { private static final Coder LONG_CODER = VarLongCoder.of(); - public static final Coder CODER = new CustomCoder() { + public static final Coder CODER = new CustomCoder() { @Override - public void encode(BeamRecordSize beamRecordSize, OutputStream outStream) + public void encode(RowSize rowSize, OutputStream outStream) throws CoderException, IOException { - LONG_CODER.encode(beamRecordSize.sizeInBytes(), outStream); + LONG_CODER.encode(rowSize.sizeInBytes(), outStream); } @Override - public BeamRecordSize decode(InputStream inStream) throws CoderException, IOException { - return new BeamRecordSize(LONG_CODER.decode(inStream)); + public RowSize decode(InputStream inStream) throws CoderException, IOException { + return new RowSize(LONG_CODER.decode(inStream)); } }; - private static final Map ESTIMATED_FIELD_SIZES = - ImmutableMap.builder() - .put(Types.TINYINT, bytes(Byte.SIZE)) - .put(Types.SMALLINT, bytes(Short.SIZE)) - .put(Types.INTEGER, bytes(Integer.SIZE)) - .put(Types.BIGINT, bytes(Long.SIZE)) - .put(Types.FLOAT, bytes(Float.SIZE)) - .put(Types.DOUBLE, bytes(Double.SIZE)) - .put(Types.DECIMAL, 32) - .put(Types.BOOLEAN, 1) - .put(Types.TIME, bytes(Long.SIZE)) - .put(Types.DATE, bytes(Long.SIZE)) - .put(Types.TIMESTAMP, bytes(Long.SIZE)) + private static final Map ESTIMATED_FIELD_SIZES = + ImmutableMap.builder() + .put(SqlTypeCoders.TINYINT, bytes(Byte.SIZE)) + .put(SqlTypeCoders.SMALLINT, bytes(Short.SIZE)) + .put(SqlTypeCoders.INTEGER, bytes(Integer.SIZE)) + .put(SqlTypeCoders.BIGINT, bytes(Long.SIZE)) + .put(SqlTypeCoders.FLOAT, bytes(Float.SIZE)) + .put(SqlTypeCoders.DOUBLE, bytes(Double.SIZE)) + .put(SqlTypeCoders.DECIMAL, 32) + .put(SqlTypeCoders.BOOLEAN, 1) + .put(SqlTypeCoders.TIME, bytes(Long.SIZE)) + .put(SqlTypeCoders.DATE, bytes(Long.SIZE)) + .put(SqlTypeCoders.TIMESTAMP, bytes(Long.SIZE)) .build(); - public static ParDo.SingleOutput parDo() { - return ParDo.of(new DoFn() { + public static ParDo.SingleOutput parDo() { + return ParDo.of(new DoFn() { @ProcessElement public void processElement(ProcessContext c) { - c.output(BeamRecordSize.of(c.element())); + c.output(RowSize.of(c.element())); } }); } - public static BeamRecordSize of(BeamRecord beamRecord) { - return new BeamRecordSize(sizeInBytes(beamRecord)); + public static RowSize of(Row row) { + return new RowSize(sizeInBytes(row)); } - private static long sizeInBytes(BeamRecord beamRecord) { - BeamRecordSqlType recordType = (BeamRecordSqlType) beamRecord.getDataType(); + private static long sizeInBytes(Row row) { + RowType rowType = row.getRowType(); long size = 1; // nulls bitset - for (int fieldIndex = 0; fieldIndex < recordType.getFieldCount(); fieldIndex++) { - Integer fieldType = recordType.getFieldTypeByIndex(fieldIndex); + for (int fieldIndex = 0; fieldIndex < rowType.getFieldCount(); fieldIndex++) { + Coder fieldType = rowType.getFieldCoder(fieldIndex); Integer estimatedSize = ESTIMATED_FIELD_SIZES.get(fieldType); @@ -103,7 +105,7 @@ private static long sizeInBytes(BeamRecord beamRecord) { } if (isString(fieldType)) { - size += beamRecord.getString(fieldIndex).length() + 1; + size += row.getString(fieldIndex).length() + 1; continue; } @@ -115,7 +117,7 @@ private static long sizeInBytes(BeamRecord beamRecord) { private long sizeInBytes; - private 
BeamRecordSize(long sizeInBytes) { + private RowSize(long sizeInBytes) { this.sizeInBytes = sizeInBytes; } @@ -124,8 +126,9 @@ public long sizeInBytes() { return sizeInBytes; } - private static boolean isString(Integer fieldType) { - return fieldType == Types.CHAR || fieldType == Types.VARCHAR; + private static boolean isString(Coder fieldType) { + return SqlTypeCoders.CHAR.equals(fieldType) + || SqlTypeCoders.VARCHAR.equals(fieldType); } private static Integer bytes(int size) { diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/ToBeamRecord.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/ToRow.java similarity index 77% rename from sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/ToBeamRecord.java rename to sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/ToRow.java index 942bb50d8943..606fb02a3768 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/ToBeamRecord.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/ToRow.java @@ -25,22 +25,22 @@ import org.apache.beam.sdk.nexmark.model.sql.adapter.ModelFieldsAdapter; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; +import org.apache.beam.sdk.values.Row; /** - * Convert Java model object to BeamRecord. + * Convert Java model object to Row. */ -public class ToBeamRecord { +public class ToRow { - static final ToBeamRecord INSTANCE = new ToBeamRecord(ModelAdaptersMapping.ADAPTERS); + static final ToRow INSTANCE = new ToRow(ModelAdaptersMapping.ADAPTERS); private Map modelTypeAdapters; - private ToBeamRecord(Map modelTypeAdapters) { + private ToRow(Map modelTypeAdapters) { this.modelTypeAdapters = modelTypeAdapters; } - private BeamRecord toRecord(Event event) { + private Row toRow(Event event) { if (event == null) { return null; } @@ -55,7 +55,10 @@ private BeamRecord toRecord(Event event) { ModelFieldsAdapter adapter = modelTypeAdapters.get(modelClass); - return new BeamRecord(adapter.getRecordType(), adapter.getFieldsValues(model)); + return Row + .withRowType(adapter.getRowType()) + .addValues(adapter.getFieldsValues(model)) + .build(); } private KnownSize getModel(Event event) { @@ -70,12 +73,12 @@ private KnownSize getModel(Event event) { throw new IllegalStateException("Unsupported event type " + event); } - public static ParDo.SingleOutput parDo() { - return ParDo.of(new DoFn() { + public static ParDo.SingleOutput parDo() { + return ParDo.of(new DoFn() { @ProcessElement public void processElement(ProcessContext c) { - BeamRecord beamRecord = INSTANCE.toRecord(c.element()); - c.output(beamRecord); + Row row = INSTANCE.toRow(c.element()); + c.output(row); } }); } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMapping.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMapping.java index 177d5891bdf4..cf88dd36d452 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMapping.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMapping.java @@ -21,9 +21,10 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; +import java.util.Date; import java.util.List; import java.util.Map; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import 
org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.nexmark.model.Auction; import org.apache.beam.sdk.nexmark.model.Bid; import org.apache.beam.sdk.nexmark.model.Person; @@ -42,52 +43,64 @@ public class ModelAdaptersMapping { private static ModelFieldsAdapter personAdapter() { return new ModelFieldsAdapter( - BeamRecordSqlType.builder() + RowSqlType.builder() .withBigIntField("id") .withVarcharField("name") .withVarcharField("emailAddress") .withVarcharField("creditCard") .withVarcharField("city") .withVarcharField("state") - .withBigIntField("dateTime") + .withTimestampField("dateTime") .withVarcharField("extra") .build()) { @Override public List getFieldsValues(Person p) { return Collections.unmodifiableList( Arrays.asList( - p.id, p.name, p.emailAddress, p.creditCard, p.city, p.state, p.dateTime, p.extra)); + p.id, + p.name, + p.emailAddress, + p.creditCard, + p.city, + p.state, + new Date(p.dateTime), + p.extra)); } }; } private static ModelFieldsAdapter bidAdapter() { return new ModelFieldsAdapter( - BeamRecordSqlType.builder() + RowSqlType.builder() .withBigIntField("auction") .withBigIntField("bidder") .withBigIntField("price") - .withBigIntField("dateTime") + .withTimestampField("dateTime") .withVarcharField("extra") .build()) { @Override public List getFieldsValues(Bid b) { return Collections.unmodifiableList( - Arrays.asList(b.auction, b.bidder, b.price, b.dateTime, b.extra)); + Arrays.asList( + b.auction, + b.bidder, + b.price, + new Date(b.dateTime), + b.extra)); } }; } private static ModelFieldsAdapter auctionAdapter() { return new ModelFieldsAdapter( - BeamRecordSqlType.builder() + RowSqlType.builder() .withBigIntField("id") .withVarcharField("itemName") .withVarcharField("description") .withBigIntField("initialBid") .withBigIntField("reserve") - .withBigIntField("dateTime") - .withBigIntField("expires") + .withTimestampField("dateTime") + .withTimestampField("expires") .withBigIntField("seller") .withBigIntField("category") .withVarcharField("extra") @@ -101,8 +114,8 @@ public List getFieldsValues(Auction a) { a.description, a.initialBid, a.reserve, - a.dateTime, - a.expires, + new Date(a.dateTime), + new Date(a.expires), a.seller, a.category, a.extra)); diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelFieldsAdapter.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelFieldsAdapter.java index cf43cc324895..c610380a1b7d 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelFieldsAdapter.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelFieldsAdapter.java @@ -19,23 +19,21 @@ package org.apache.beam.sdk.nexmark.model.sql.adapter; import java.util.List; - -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; -import org.apache.beam.sdk.values.BeamRecordType; +import org.apache.beam.sdk.values.RowType; /** - * Helper class to help map Java model fields to Beam SQL Record Type fields. + * Helper class to help map Java model fields to {@link RowType} fields. 
*/ public abstract class ModelFieldsAdapter { - private BeamRecordSqlType recordType; + private RowType rowType; - ModelFieldsAdapter(BeamRecordSqlType recordType) { - this.recordType = recordType; + ModelFieldsAdapter(RowType rowType) { + this.rowType = rowType; } - public BeamRecordType getRecordType() { - return recordType; + public RowType getRowType() { + return rowType; } public abstract List getFieldsValues(T model); diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/package-info.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/package-info.java index b9554a8a92ea..6c79ad3edbb2 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/package-info.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/model/sql/adapter/package-info.java @@ -21,6 +21,6 @@ */ /** - * Model adapter which contains a mapping between specific Java model to a BeamRecord. + * Model adapter which contains a mapping between specific Java model to a Row. */ package org.apache.beam.sdk.nexmark.model.sql.adapter; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AbstractSimulator.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AbstractSimulator.java index 6f4ad568a7db..6bb4e978d6cb 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AbstractSimulator.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AbstractSimulator.java @@ -22,7 +22,6 @@ import java.util.Iterator; import java.util.List; import javax.annotation.Nullable; - import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.TimestampedValue; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AuctionOrBid.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AuctionOrBid.java new file mode 100644 index 000000000000..2c8b12fd2fc1 --- /dev/null +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/AuctionOrBid.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.nexmark.queries; + +import org.apache.beam.sdk.nexmark.model.Event; +import org.apache.beam.sdk.transforms.SerializableFunction; + +/** A predicate to filter for only auctions and bids. 
*/ +public class AuctionOrBid implements SerializableFunction { + @Override + public Boolean apply(Event input) { + return input.bid != null || input.newAuction != null; + } +} diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query0Model.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query0Model.java index 0e73a21079d3..29c89d8fff1c 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query0Model.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query0Model.java @@ -19,7 +19,6 @@ import java.util.Collection; import java.util.Iterator; - import org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.nexmark.model.Event; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java index 8d13a20bac53..24d4d044f4e7 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java @@ -62,7 +62,6 @@ */ public class Query10 extends NexmarkQuery { private static final Logger LOG = LoggerFactory.getLogger(Query10.class); - private static final int CHANNEL_BUFFER = 8 << 20; // 8MB private static final int NUM_SHARDS_PER_WORKER = 5; private static final Duration LATE_BATCHING_PERIOD = Duration.standardSeconds(10); diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query1Model.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query1Model.java index 76c182ade4a9..6d8f5b3c3c87 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query1Model.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query1Model.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.Collection; import java.util.Iterator; - import org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.nexmark.model.Bid; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query2Model.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query2Model.java index 33a1f8d2ccd0..54ff947bea60 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query2Model.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query2Model.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.Collection; import java.util.Iterator; - import org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.nexmark.model.AuctionPrice; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3Model.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3Model.java index 94f24cb54530..e05af4f1df2a 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3Model.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3Model.java @@ -19,13 +19,11 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Multimap; - import java.io.Serializable; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; - import 
org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.nexmark.model.Auction; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query4.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query4.java index b59d173dd66b..d3b1e233b092 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query4.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query4.java @@ -27,6 +27,7 @@ import org.apache.beam.sdk.nexmark.model.Event; import org.apache.beam.sdk.nexmark.model.KnownSize; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.Mean; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.SlidingWindows; @@ -50,13 +51,14 @@ * } * *

 * <p>For extra spiciness our implementation differs slightly from the above:
+ *
 * <ul>
- * <li>We select both the average winning price and the category.
- * <li>We don't bother joining with a static category table, since it's contents are never used.
- * <li>We only consider bids which are above the auction's reserve price.
- * <li>We accept the highest-price, earliest valid bid as the winner.
- * <li>We calculate the averages oven a sliding window of size {@code windowSizeSec} and
- * period {@code windowPeriodSec}.
+ *   <li>We select both the average winning price and the category.
+ *   <li>We don't bother joining with a static category table, since it's contents are never used.
+ *   <li>We only consider bids which are above the auction's reserve price.
+ *   <li>We accept the highest-price, earliest valid bid as the winner.
+ *   <li>We calculate the averages oven a sliding window of size {@code windowSizeSec} and period
+ *       {@code windowPeriodSec}.
 * </ul>
    */ public class Query4 extends NexmarkQuery { @@ -70,12 +72,13 @@ public Query4(NexmarkConfiguration configuration) { private PCollection applyTyped(PCollection events) { PCollection winningBids = events + .apply(Filter.by(new AuctionOrBid())) // Find the winning bid for each closed auction. .apply(new WinningBids(name + ".WinningBids", configuration)); // Monitor winning bids - winningBids = winningBids.apply(name + ".WinningBidsMonitor", - winningBidsMonitor.getTransform()); + winningBids = + winningBids.apply(name + ".WinningBidsMonitor", winningBidsMonitor.getTransform()); return winningBids // Key the winning bid price by the auction category. diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query6.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query6.java index f7bb38656cf1..eeae79acc92a 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query6.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query6.java @@ -30,6 +30,7 @@ import org.apache.beam.sdk.nexmark.model.SellerPrice; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.AfterPane; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; @@ -40,8 +41,8 @@ import org.joda.time.Duration; /** - * Query 6, 'Average Selling Price by Seller'. Select the average selling price over the - * last 10 closed auctions by the same seller. In CQL syntax: + * Query 6, 'Average Selling Price by Seller'. Select the average selling price over the last 10 + * closed auctions by the same seller. In CQL syntax: * *
    {@code
      * SELECT Istream(AVG(Q.final), Q.seller)
    @@ -113,6 +114,7 @@ public Query6(NexmarkConfiguration configuration) {
     
       private PCollection applyTyped(PCollection events) {
         return events
    +        .apply(Filter.by(new AuctionOrBid()))
             // Find the winning bid for each closed auction.
             .apply(new WinningBids(name + ".WinningBids", configuration))
     
    diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query7Model.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query7Model.java
    index 401174697e9c..19aa23d4d97e 100644
    --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query7Model.java
    +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query7Model.java
    @@ -22,7 +22,6 @@
     import java.util.Collection;
     import java.util.Iterator;
     import java.util.List;
    -
     import org.apache.beam.sdk.nexmark.NexmarkConfiguration;
     import org.apache.beam.sdk.nexmark.NexmarkUtils;
     import org.apache.beam.sdk.nexmark.model.Bid;
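A minimal sketch (not part of the patch) of what the new AuctionOrBid predicate accepts. The wrapper class below is hypothetical; the constructor arguments are copied from test fixtures that appear later in this diff, and it assumes, as those tests imply, that new Event(bid) populates Event.bid while new Event(person) populates Event.newPerson.

import org.apache.beam.sdk.nexmark.model.Bid;
import org.apache.beam.sdk.nexmark.model.Event;
import org.apache.beam.sdk.nexmark.model.Person;
import org.apache.beam.sdk.nexmark.queries.AuctionOrBid;

class AuctionOrBidSketch {
  public static void main(String[] args) {
    AuctionOrBid auctionOrBid = new AuctionOrBid();

    // Wrapping a Bid sets Event.bid, so the predicate returns true and the event reaches
    // WinningBids; wrapping a Person sets Event.newPerson, so the predicate returns false
    // and the event is dropped before WinningBids ever sees it.
    Event bidEvent = new Event(new Bid(5L, 3L, 123123L, 43234234L, "extra2"));
    Event personEvent =
        new Event(new Person(3L, "name", "email", "cc", "city", "state", 329823L, "extra"));

    System.out.println(auctionOrBid.apply(bidEvent));    // expected: true
    System.out.println(auctionOrBid.apply(personEvent)); // expected: false
  }
}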
    diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9.java
    index 5f11e4e6a106..af0f514a5c0a 100644
    --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9.java
    +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9.java
    @@ -22,11 +22,12 @@
     import org.apache.beam.sdk.nexmark.model.AuctionBid;
     import org.apache.beam.sdk.nexmark.model.Event;
     import org.apache.beam.sdk.nexmark.model.KnownSize;
    +import org.apache.beam.sdk.transforms.Filter;
     import org.apache.beam.sdk.values.PCollection;
     
     /**
    - * Query "9", 'Winning bids'. Select just the winning bids. Not in original NEXMark suite, but
    - * handy for testing. See {@link WinningBids} for the details.
    + * Query "9", 'Winning bids'. Select just the winning bids. Not in original NEXMark suite, but handy
    + * for testing. See {@link WinningBids} for the details.
      */
     public class Query9 extends NexmarkQuery {
       public Query9(NexmarkConfiguration configuration) {
    @@ -34,7 +35,9 @@ public Query9(NexmarkConfiguration configuration) {
       }
     
       private PCollection applyTyped(PCollection events) {
    -    return events.apply(new WinningBids(name, configuration));
    +    return events
    +        .apply(Filter.by(new AuctionOrBid()))
    +        .apply(new WinningBids(name, configuration));
       }
     
       @Override
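The same composition, pulled out of the query classes: a minimal sketch (not part of the patch) of how the filter and WinningBids are chained. The class and step names here are hypothetical, and it assumes, as in the Beam Nexmark code, that WinningBids produces a PCollection of AuctionBid.

import org.apache.beam.sdk.nexmark.NexmarkConfiguration;
import org.apache.beam.sdk.nexmark.model.AuctionBid;
import org.apache.beam.sdk.nexmark.model.Event;
import org.apache.beam.sdk.nexmark.queries.AuctionOrBid;
import org.apache.beam.sdk.nexmark.queries.WinningBids;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.values.PCollection;

class WinningBidsUsageSketch {
  static PCollection<AuctionBid> winningBids(
      PCollection<Event> events, NexmarkConfiguration configuration) {
    return events
        // Keep only auction and bid events; WinningBids' window assignment now throws for
        // anything else instead of silently assigning it to no window (see the WinningBids
        // change further down in this diff).
        .apply("JustAuctionsAndBids", Filter.by(new AuctionOrBid()))
        // Find the winning bid for each closed auction.
        .apply(new WinningBids("WinningBidsSketch", configuration));
  }
}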
    diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9Model.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9Model.java
    index 48d792ed5446..deb0096893fc 100644
    --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9Model.java
    +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query9Model.java
    @@ -20,7 +20,6 @@
     import java.io.Serializable;
     import java.util.Collection;
     import java.util.Iterator;
    -
     import org.apache.beam.sdk.nexmark.NexmarkConfiguration;
     import org.apache.beam.sdk.values.TimestampedValue;
     
    diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBids.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBids.java
    index 7ccdc951aa1e..fea096be18ae 100644
    --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBids.java
    +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBids.java
    @@ -69,12 +69,11 @@
      * GROUP BY A.id
      * }
 *
- * <p>We will also check that the winning bid is above the auction reserve. Note that
- * we ignore the auction opening bid value since it has no impact on which bid eventually wins,
- * if any.
+ * <p>We will also check that the winning bid is above the auction reserve. Note that we ignore the
+ * auction opening bid value since it has no impact on which bid eventually wins, if any.
 *
- * <p>Our implementation will use a custom windowing function in order to bring bids and
- * auctions together without requiring global state.
+ * <p>
    Our implementation will use a custom windowing function in order to bring bids and auctions + * together without requiring global state. */ public class WinningBids extends PTransform, PCollection> { /** Windows for open auctions and bids. */ @@ -83,9 +82,9 @@ private static class AuctionOrBidWindow extends IntervalWindow { public final long auction; /** - * True if this window represents an actual auction, and thus has a start/end - * time matching that of the auction. False if this window represents a bid, and - * thus has an unbounded start/end time. + * True if this window represents an actual auction, and thus has a start/end time matching that + * of the auction. False if this window represents a bid, and thus has an unbounded start/end + * time. */ public final boolean isAuctionWindow; @@ -109,10 +108,9 @@ public static AuctionOrBidWindow forAuction(Instant timestamp, Auction auction) } /** - * Return a bid window for {@code bid}. It should later be merged into - * the corresponding auction window. However, it is possible this bid is for an already - * expired auction, or for an auction which the system has not yet seen. So we - * give the bid a bit of wiggle room in its interval. + * Return a bid window for {@code bid}. It should later be merged into the corresponding auction + * window. However, it is possible this bid is for an already expired auction, or for an auction + * which the system has not yet seen. So we give the bid a bit of wiggle room in its interval. */ public static AuctionOrBidWindow forBid( long expectedAuctionDurationMs, Instant timestamp, Bid bid) { @@ -137,11 +135,13 @@ public boolean isAuctionWindow() { @Override public String toString() { - return String.format("AuctionOrBidWindow{start:%s; end:%s; auction:%d; isAuctionWindow:%s}", + return String.format( + "AuctionOrBidWindow{start:%s; end:%s; auction:%d; isAuctionWindow:%s}", start(), end(), auction, isAuctionWindow); } - @Override public boolean equals(Object o) { + @Override + public boolean equals(Object o) { if (this == o) { return true; } @@ -155,14 +155,13 @@ public String toString() { return (isAuctionWindow == that.isAuctionWindow) && (auction == that.auction); } - @Override public int hashCode() { + @Override + public int hashCode() { return Objects.hash(super.hashCode(), isAuctionWindow, auction); } } - /** - * Encodes an {@link AuctionOrBidWindow} as an {@link IntervalWindow} and an auction id long. - */ + /** Encodes an {@link AuctionOrBidWindow} as an {@link IntervalWindow} and an auction id long. 
*/ private static class AuctionOrBidWindowCoder extends CustomCoder { private static final AuctionOrBidWindowCoder INSTANCE = new AuctionOrBidWindowCoder(); private static final Coder SUPER_CODER = IntervalWindow.getCoder(); @@ -183,8 +182,7 @@ public void encode(AuctionOrBidWindow window, OutputStream outStream) } @Override - public AuctionOrBidWindow decode(InputStream inStream) - throws IOException, CoderException { + public AuctionOrBidWindow decode(InputStream inStream) throws IOException, CoderException { IntervalWindow superWindow = SUPER_CODER.decode(inStream); long auction = ID_CODER.decode(inStream); boolean isAuctionWindow = INT_CODER.decode(inStream) != 0; @@ -192,7 +190,8 @@ public AuctionOrBidWindow decode(InputStream inStream) superWindow.start(), superWindow.end(), auction, isAuctionWindow); } - @Override public void verifyDeterministic() throws NonDeterministicException {} + @Override + public void verifyDeterministic() throws NonDeterministicException {} @Override public Object structuralValue(AuctionOrBidWindow value) { @@ -214,16 +213,18 @@ public Collection assignWindows(AssignContext c) { Event event = c.element(); if (event.newAuction != null) { // Assign auctions to an auction window which expires at the auction's close. - return Collections - .singletonList(AuctionOrBidWindow.forAuction(c.timestamp(), event.newAuction)); + return Collections.singletonList( + AuctionOrBidWindow.forAuction(c.timestamp(), event.newAuction)); } else if (event.bid != null) { // Assign bids to a temporary bid window which will later be merged into the appropriate // auction window. return Collections.singletonList( AuctionOrBidWindow.forBid(expectedAuctionDurationMs, c.timestamp(), event.bid)); } else { - // Don't assign people to any window. They will thus be dropped. - return Collections.emptyList(); + throw new IllegalArgumentException( + String.format( + "%s can only assign windows to auctions and bids, but received %s", + getClass().getSimpleName(), c.element())); } } @@ -281,27 +282,25 @@ public WindowMappingFn getDefaultWindowMappingFn() { } /** - * Below we will GBK auctions and bids on their auction ids. Then we will reduce those - * per id to emit {@code (auction, winning bid)} pairs for auctions which have expired with at - * least one valid bid. We would like those output pairs to have a timestamp of the auction's - * expiry (since that's the earliest we know for sure we have the correct winner). We would - * also like to make that winning results are available to following stages at the auction's - * expiry. + * Below we will GBK auctions and bids on their auction ids. Then we will reduce those per id to + * emit {@code (auction, winning bid)} pairs for auctions which have expired with at least one + * valid bid. We would like those output pairs to have a timestamp of the auction's expiry + * (since that's the earliest we know for sure we have the correct winner). We would also like + * to make that winning results are available to following stages at the auction's expiry. * *

 * <p>Each result of the GBK will have a timestamp of the min of the result of this object's
 * assignOutputTime over all records which end up in one of its iterables. Thus we get the
 * desired behavior if we ignore each record's timestamp and always return the auction window's
 * 'maxTimestamp', which will correspond to the auction's expiry.
 *
- * <p>In contrast, if this object's assignOutputTime were to return 'inputTimestamp'
- * (the usual implementation), then each GBK record will take as its timestamp the minimum of
- * the timestamps of all bids and auctions within it, which will always be the auction's
- * timestamp. An auction which expires well into the future would thus hold up the watermark
- * of the GBK results until that auction expired. That in turn would hold up all winning pairs.
+ * <p>
    In contrast, if this object's assignOutputTime were to return 'inputTimestamp' (the usual + * implementation), then each GBK record will take as its timestamp the minimum of the + * timestamps of all bids and auctions within it, which will always be the auction's timestamp. + * An auction which expires well into the future would thus hold up the watermark of the GBK + * results until that auction expired. That in turn would hold up all winning pairs. */ @Override - public Instant getOutputTime( - Instant inputTimestamp, AuctionOrBidWindow window) { + public Instant getOutputTime(Instant inputTimestamp, AuctionOrBidWindow window) { return window.maxTimestamp(); } } @@ -311,9 +310,10 @@ public Instant getOutputTime( public WinningBids(String name, NexmarkConfiguration configuration) { super(name); // What's the expected auction time (when the system is running at the lowest event rate). - long[] interEventDelayUs = configuration.rateShape.interEventDelayUs( - configuration.firstEventRate, configuration.nextEventRate, - configuration.rateUnit, configuration.numEventGenerators); + long[] interEventDelayUs = + configuration.rateShape.interEventDelayUs( + configuration.firstEventRate, configuration.nextEventRate, + configuration.rateUnit, configuration.numEventGenerators); long longestDelayUs = 0; for (long interEventDelayU : interEventDelayUs) { longestDelayUs = Math.max(longestDelayUs, interEventDelayU); @@ -321,7 +321,7 @@ public WinningBids(String name, NexmarkConfiguration configuration) { // Adjust for proportion of auction events amongst all events. longestDelayUs = (longestDelayUs * GeneratorConfig.PROPORTION_DENOMINATOR) - / GeneratorConfig.AUCTION_PROPORTION; + / GeneratorConfig.AUCTION_PROPORTION; // Adjust for number of in-flight auctions. longestDelayUs = longestDelayUs * configuration.numInFlightAuctions; long expectedAuctionDurationMs = (longestDelayUs + 999) / 1000; @@ -338,8 +338,9 @@ public PCollection expand(PCollection events) { // Key auctions by their id. PCollection> auctionsById = - events.apply(NexmarkQuery.JUST_NEW_AUCTIONS) - .apply("AuctionById:", NexmarkQuery.AUCTION_BY_ID); + events + .apply(NexmarkQuery.JUST_NEW_AUCTIONS) + .apply("AuctionById:", NexmarkQuery.AUCTION_BY_ID); // Key bids by their auction id. 
PCollection> bidsByAuctionId = @@ -403,7 +404,8 @@ public int hashCode() { return Objects.hash(auctionOrBidWindowFn); } - @Override public boolean equals(Object o) { + @Override + public boolean equals(Object o) { if (this == o) { return true; } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBidsSimulator.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBidsSimulator.java index 1cf3a54c435f..a42f3857ca2d 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBidsSimulator.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/WinningBidsSimulator.java @@ -25,7 +25,6 @@ import java.util.TreeMap; import java.util.TreeSet; import javax.annotation.Nullable; - import org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.nexmark.model.Auction; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/NexmarkSqlQuery.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/NexmarkSqlQuery.java index 229ed425a693..434651938155 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/NexmarkSqlQuery.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/NexmarkSqlQuery.java @@ -22,11 +22,11 @@ import org.apache.beam.sdk.nexmark.NexmarkUtils; import org.apache.beam.sdk.nexmark.model.Event; import org.apache.beam.sdk.nexmark.model.KnownSize; -import org.apache.beam.sdk.nexmark.model.sql.BeamRecordSize; +import org.apache.beam.sdk.nexmark.model.sql.RowSize; import org.apache.beam.sdk.nexmark.queries.NexmarkQuery; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; /** * Executor for Nexmark queries. 
Allows to decouple from NexmarkQuery @@ -34,21 +34,21 @@ */ public class NexmarkSqlQuery extends NexmarkQuery { - private PTransform, PCollection> queryTransform; + private PTransform, PCollection> queryTransform; public NexmarkSqlQuery(NexmarkConfiguration configuration, - PTransform, PCollection> queryTransform) { + PTransform, PCollection> queryTransform) { super(configuration, queryTransform.getName()); this.queryTransform = queryTransform; } @Override protected PCollection applyPrim(PCollection events) { - PCollection queryResults = events.apply(queryTransform); + PCollection queryResults = events.apply(queryTransform); PCollection resultRecordsSizes = queryResults - .apply(BeamRecordSize.parDo()) - .setCoder(BeamRecordSize.CODER); + .apply(RowSize.parDo()) + .setCoder(RowSize.CODER); return NexmarkUtils.castToKnownSize(name, resultRecordsSizes); } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0.java index 17e5f0cefcf6..10f5cafca4e4 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0.java @@ -23,22 +23,23 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import org.apache.beam.sdk.coders.BeamRecordCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.BeamSql; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.nexmark.model.Bid; import org.apache.beam.sdk.nexmark.model.Event; -import org.apache.beam.sdk.nexmark.model.sql.ToBeamRecord; +import org.apache.beam.sdk.nexmark.model.sql.ToRow; import org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.Row; /** * Query 0: Pass events through unchanged. @@ -49,9 +50,9 @@ *

    {@link Bid} events are used here at the moment, ås they are most numerous * with default configuration. */ -public class SqlQuery0 extends PTransform, PCollection> { +public class SqlQuery0 extends PTransform, PCollection> { - private static final BeamSql.SimpleQueryTransform QUERY = + private static final PTransform> QUERY = BeamSql.query("SELECT * FROM PCOLLECTION"); public SqlQuery0() { @@ -59,23 +60,23 @@ public SqlQuery0() { } @Override - public PCollection expand(PCollection allEvents) { + public PCollection expand(PCollection allEvents) { - BeamRecordCoder bidRecordCoder = getBidRecordCoder(); + RowCoder bidRowCoder = getBidRowCoder(); - PCollection bidEventsRecords = allEvents + PCollection bidEventsRows = allEvents .apply(Filter.by(IS_BID)) - .apply(ToBeamRecord.parDo()) - .apply(getName() + ".Serialize", logBytesMetric(bidRecordCoder)) - .setCoder(bidRecordCoder); + .apply(ToRow.parDo()) + .apply(getName() + ".Serialize", logBytesMetric(bidRowCoder)) + .setCoder(bidRowCoder); - return bidEventsRecords.apply(QUERY).setCoder(bidRecordCoder); + return bidEventsRows.apply(QUERY).setCoder(bidRowCoder); } - private PTransform, PCollection> logBytesMetric( - final BeamRecordCoder coder) { + private PTransform, PCollection> logBytesMetric( + final RowCoder coder) { - return ParDo.of(new DoFn() { + return ParDo.of(new DoFn() { private final Counter bytesMetric = Metrics.counter(name , "bytes"); @ProcessElement @@ -85,13 +86,13 @@ public void processElement(ProcessContext c) throws CoderException, IOException byte[] byteArray = outStream.toByteArray(); bytesMetric.inc((long) byteArray.length); ByteArrayInputStream inStream = new ByteArrayInputStream(byteArray); - BeamRecord record = coder.decode(inStream, Coder.Context.OUTER); - c.output(record); + Row row = coder.decode(inStream, Coder.Context.OUTER); + c.output(row); } }); } - private BeamRecordCoder getBidRecordCoder() { - return ModelAdaptersMapping.ADAPTERS.get(Bid.class).getRecordType().getRecordCoder(); + private RowCoder getBidRowCoder() { + return ModelAdaptersMapping.ADAPTERS.get(Bid.class).getRowType().getRowCoder(); } } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1.java index 1c44558545b0..aa23c3b25f4e 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1.java @@ -20,16 +20,17 @@ import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; import static org.apache.beam.sdk.nexmark.queries.NexmarkQuery.IS_BID; -import org.apache.beam.sdk.coders.BeamRecordCoder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.BeamSql; import org.apache.beam.sdk.nexmark.model.Bid; import org.apache.beam.sdk.nexmark.model.Event; -import org.apache.beam.sdk.nexmark.model.sql.ToBeamRecord; +import org.apache.beam.sdk.nexmark.model.sql.ToRow; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.Row; /** * Query 1, 'Currency Conversion'. Convert each bid value from dollars to euros. @@ -43,11 +44,11 @@ *

    To make things more interesting, allow the 'currency conversion' to be arbitrarily * slowed down. */ -public class SqlQuery1 extends PTransform, PCollection> { +public class SqlQuery1 extends PTransform, PCollection> { - private static final BeamSql.SimpleQueryTransform QUERY = BeamSql + private static final PTransform> QUERY = BeamSql .query("SELECT auction, bidder, DolToEur(price) as price, dateTime, extra FROM PCOLLECTION") - .withUdf("DolToEur", new DolToEur()); + .registerUdf("DolToEur", new DolToEur()); /** * Dollar to Euro conversion. @@ -64,18 +65,18 @@ public SqlQuery1() { } @Override - public PCollection expand(PCollection allEvents) { - BeamRecordCoder bidRecordCoder = getBidRecordCoder(); + public PCollection expand(PCollection allEvents) { + RowCoder bidRecordCoder = getBidRowCoder(); - PCollection bidEventsRecords = allEvents + PCollection bidEventsRows = allEvents .apply(Filter.by(IS_BID)) - .apply(ToBeamRecord.parDo()) + .apply(ToRow.parDo()) .setCoder(bidRecordCoder); - return bidEventsRecords.apply(QUERY); + return bidEventsRows.apply(QUERY); } - private BeamRecordCoder getBidRecordCoder() { - return ADAPTERS.get(Bid.class).getRecordType().getRecordCoder(); + private RowCoder getBidRowCoder() { + return ADAPTERS.get(Bid.class).getRowType().getRowCoder(); } } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2.java index 959c71216aca..1700206555fe 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2.java @@ -20,15 +20,16 @@ import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; import static org.apache.beam.sdk.nexmark.queries.NexmarkQuery.IS_BID; -import org.apache.beam.sdk.coders.BeamRecordCoder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.BeamSql; import org.apache.beam.sdk.nexmark.model.Bid; import org.apache.beam.sdk.nexmark.model.Event; -import org.apache.beam.sdk.nexmark.model.sql.ToBeamRecord; +import org.apache.beam.sdk.nexmark.model.sql.ToRow; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.Row; /** * Query 2, 'Filtering. Find bids with specific auction ids and show their bid price. @@ -44,13 +45,13 @@ * arbitrary size. To make it more interesting we instead choose bids for every * {@code skipFactor}'th auction. 
*/ -public class SqlQuery2 extends PTransform, PCollection> { +public class SqlQuery2 extends PTransform, PCollection> { private static final String QUERY_TEMPLATE = "SELECT auction, bidder, price, dateTime, extra FROM PCOLLECTION " + " WHERE MOD(auction, %d) = 0"; - private final BeamSql.SimpleQueryTransform query; + private final PTransform> query; public SqlQuery2(long skipFactor) { super("SqlQuery2"); @@ -60,18 +61,18 @@ public SqlQuery2(long skipFactor) { } @Override - public PCollection expand(PCollection allEvents) { - BeamRecordCoder bidRecordCoder = getBidRecordCoder(); + public PCollection expand(PCollection allEvents) { + RowCoder bidRecordCoder = getBidRowCoder(); - PCollection bidEventsRecords = allEvents + PCollection bidEventsRows = allEvents .apply(Filter.by(IS_BID)) - .apply(ToBeamRecord.parDo()) + .apply(ToRow.parDo()) .setCoder(bidRecordCoder); - return bidEventsRecords.apply(query); + return bidEventsRows.apply(query); } - private BeamRecordCoder getBidRecordCoder() { - return ADAPTERS.get(Bid.class).getRecordType().getRecordCoder(); + private RowCoder getBidRowCoder() { + return ADAPTERS.get(Bid.class).getRowType().getRowCoder(); } } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3.java index 9f5b29b33aa5..b233c65f79b9 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3.java @@ -19,23 +19,23 @@ import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; -import org.apache.beam.sdk.coders.BeamRecordCoder; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.sql.BeamSql; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.model.Auction; import org.apache.beam.sdk.nexmark.model.Event; import org.apache.beam.sdk.nexmark.model.Person; -import org.apache.beam.sdk.nexmark.model.sql.ToBeamRecord; +import org.apache.beam.sdk.nexmark.model.sql.ToRow; import org.apache.beam.sdk.nexmark.queries.Query3; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Duration; @@ -69,7 +69,7 @@ * *

    Correct join semantics implementation is tracked in BEAM-3190, BEAM-3191 */ -public class SqlQuery3 extends PTransform, PCollection> { +public class SqlQuery3 extends PTransform, PCollection> { private static final String QUERY_NAME = SqlQuery3.class.getSimpleName(); @@ -82,7 +82,7 @@ public class SqlQuery3 extends PTransform, PCollection expand(PCollection allEvents) { + public PCollection expand(PCollection allEvents) { PCollection windowed = fixedWindows(allEvents); - PCollection auctions = filter(windowed, e -> e.newAuction != null, Auction.class); - PCollection people = filter(windowed, e -> e.newPerson != null, Person.class); + PCollection auctions = filter(windowed, e -> e.newAuction != null, Auction.class); + PCollection people = filter(windowed, e -> e.newPerson != null, Person.class); PCollectionTuple inputStreams = createStreamsTuple(auctions, people); - return inputStreams.apply(BeamSql.queryMulti(QUERY_STRING)).setCoder(OUTPUT_RECORD_CODER); + return + inputStreams + .apply(BeamSql.query(QUERY_STRING)) + .setCoder(OUTPUT_RECORD_CODER); } private PCollection fixedWindows(PCollection events) { @@ -110,15 +113,15 @@ private PCollection fixedWindows(PCollection events) { } private PCollectionTuple createStreamsTuple( - PCollection auctions, - PCollection people) { + PCollection auctions, + PCollection people) { return PCollectionTuple .of(new TupleTag<>("Auction"), auctions) .and(new TupleTag<>("Person"), people); } - private PCollection filter( + private PCollection filter( PCollection allEvents, SerializableFunction filter, Class clazz) { @@ -127,22 +130,23 @@ private PCollection filter( return allEvents .apply(QUERY_NAME + ".Filter." + modelName, Filter.by(filter)) - .apply(QUERY_NAME + ".ToRecords." + modelName, ToBeamRecord.parDo()) + .apply(QUERY_NAME + ".ToRecords." + modelName, ToRow.parDo()) .setCoder(getRecordCoder(clazz)); } - private BeamRecordCoder getRecordCoder(Class modelClass) { - return ADAPTERS.get(modelClass).getRecordType().getRecordCoder(); + private RowCoder getRecordCoder(Class modelClass) { + return ADAPTERS.get(modelClass).getRowType().getRowCoder(); } - private static BeamRecordCoder createOutputRecordCoder() { - BeamRecordSqlType outputRecordType = BeamRecordSqlType.builder() - .withVarcharField("name") - .withVarcharField("city") - .withVarcharField("state") - .withBigIntField("id") - .build(); - - return outputRecordType.getRecordCoder(); + private static RowCoder createOutputRecordCoder() { + return + RowSqlType + .builder() + .withVarcharField("name") + .withVarcharField("city") + .withVarcharField("state") + .withBigIntField("id") + .build() + .getRowCoder(); } } diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery5.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery5.java new file mode 100644 index 000000000000..fad41f9e9016 --- /dev/null +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery5.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.nexmark.queries.sql; + +import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; +import static org.apache.beam.sdk.nexmark.queries.NexmarkQuery.IS_BID; + +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.extensions.sql.BeamSql; +import org.apache.beam.sdk.nexmark.NexmarkConfiguration; +import org.apache.beam.sdk.nexmark.model.Bid; +import org.apache.beam.sdk.nexmark.model.Event; +import org.apache.beam.sdk.nexmark.model.sql.ToRow; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; + +/** + * Query 5, 'Hot Items'. Which auctions have seen the most bids in the last hour (updated every + * minute). In CQL syntax: + * + *

    {@code
    + * SELECT Rstream(auction)
    + * FROM (SELECT B1.auction, count(*) AS num
    + *       FROM Bid [RANGE 60 MINUTE SLIDE 1 MINUTE] B1
    + *       GROUP BY B1.auction)
    + * WHERE num >= ALL (SELECT count(*)
    + *                   FROM Bid [RANGE 60 MINUTE SLIDE 1 MINUTE] B2
    + *                   GROUP BY B2.auction);
    + * }
    + * + *

    To make things a bit more dynamic and easier to test we use much shorter windows, and + * we'll also preserve the bid counts.

    + */ +public class SqlQuery5 extends PTransform, PCollection> { + + private static final String QUERY_TEMPLATE = "" + + " SELECT auction " + + " FROM (SELECT B1.auction, count(*) AS num, " + + " HOP_START(B1.dateTime, INTERVAL '%1$d' SECOND, " + + " INTERVAL '%2$d' SECOND) AS starttime " + + " FROM Bid B1 " + + " GROUP BY B1.auction, " + + " HOP(B1.dateTime, INTERVAL '%1$d' SECOND, " + + " INTERVAL '%2$d' SECOND)) B1 " + + " JOIN (SELECT max(B2.num) AS maxnum, B2.starttime " + + " FROM (SELECT count(*) AS num, " + + " HOP_START(B2.dateTime, INTERVAL '%1$d' SECOND, " + + " INTERVAL '%2$d' SECOND) AS starttime " + + " FROM Bid B2 " + + " GROUP BY B2.auction, " + + " HOP(B2.dateTime, INTERVAL '%1$d' SECOND, " + + " INTERVAL '%2$d' SECOND)) B2 " + + " GROUP BY B2.starttime) B2 " + + " ON B1.starttime = B2.starttime AND B1.num >= B2.maxnum "; + + private final PTransform> query; + + public SqlQuery5(NexmarkConfiguration configuration) { + super("SqlQuery5"); + + String queryString = String.format(QUERY_TEMPLATE, + configuration.windowPeriodSec, + configuration.windowSizeSec); + query = BeamSql.query(queryString); + } + + @Override + public PCollection expand(PCollection allEvents) { + RowCoder bidRecordCoder = getBidRowCoder(); + + PCollection bids = allEvents + .apply(Filter.by(IS_BID)) + .apply(ToRow.parDo()) + .setCoder(bidRecordCoder); + + return PCollectionTuple.of(new TupleTag<>("Bid"), bids) + .apply(query); + } + + private RowCoder getBidRowCoder() { + return ADAPTERS.get(Bid.class).getRowType().getRowCoder(); + } +} diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery7.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery7.java new file mode 100644 index 000000000000..33cb4ebf8934 --- /dev/null +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery7.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.nexmark.queries.sql; + +import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; +import static org.apache.beam.sdk.nexmark.queries.NexmarkQuery.IS_BID; + +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.extensions.sql.BeamSql; +import org.apache.beam.sdk.nexmark.NexmarkConfiguration; +import org.apache.beam.sdk.nexmark.model.Bid; +import org.apache.beam.sdk.nexmark.model.Event; +import org.apache.beam.sdk.nexmark.model.sql.ToRow; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; + +/** + * Query 7, 'Highest Bid'. Select the bids with the highest bid + * price in the last minute. In CQL syntax: + * + *
    + * SELECT Rstream(B.auction, B.price, B.bidder)
    + * FROM Bid [RANGE 1 MINUTE SLIDE 1 MINUTE] B
    + * WHERE B.price = (SELECT MAX(B1.price)
    + *                  FROM BID [RANGE 1 MINUTE SLIDE 1 MINUTE] B1);
    + * 
    + * + *

    We will use a shorter window to help make testing easier.

    + */ +public class SqlQuery7 extends PTransform, PCollection> { + + private static final String QUERY_TEMPLATE = "" + + " SELECT B.auction, B.price, B.bidder " + + " FROM (SELECT B.auction, B.price, B.bidder, " + + " TUMBLE_START(B.dateTime, INTERVAL '%1$d' SECOND) AS starttime " + + " FROM Bid B " + + " GROUP BY B.auction, B.price, B.bidder, " + + " TUMBLE(B.dateTime, INTERVAL '%1$d' SECOND)) B " + + " JOIN (SELECT MAX(B1.price) AS maxprice, " + + " TUMBLE_START(B1.dateTime, INTERVAL '%1$d' SECOND) AS starttime " + + " FROM Bid B1 " + + " GROUP BY TUMBLE(B1.dateTime, INTERVAL '%1$d' SECOND)) B1 " + + " ON B.starttime = B1.starttime AND B.price = B1.maxprice "; + + private final PTransform> query; + + public SqlQuery7(NexmarkConfiguration configuration) { + super("SqlQuery7"); + + String queryString = String.format(QUERY_TEMPLATE, + configuration.windowSizeSec); + query = BeamSql.query(queryString); + } + + @Override + public PCollection expand(PCollection allEvents) { + RowCoder bidRecordCoder = getBidRowCoder(); + + PCollection bids = allEvents + .apply(Filter.by(IS_BID)) + .apply(ToRow.parDo()) + .setCoder(bidRecordCoder); + + return PCollectionTuple.of(new TupleTag<>("Bid"), bids) + .apply(query); + } + + private RowCoder getBidRowCoder() { + return ADAPTERS.get(Bid.class).getRowType().getRowCoder(); + } +} diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java index f43486dc0ddd..741e3a882a70 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java @@ -23,9 +23,7 @@ import java.util.PriorityQueue; import java.util.Queue; import java.util.concurrent.ThreadLocalRandom; - import javax.annotation.Nullable; - import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.nexmark.NexmarkUtils; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorCheckpoint.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorCheckpoint.java index fa4173905e5f..dfaf113efbc4 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorCheckpoint.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorCheckpoint.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; - import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CustomCoder; diff --git a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorConfig.java b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorConfig.java index 7c862faa652f..135b97cf9468 100644 --- a/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorConfig.java +++ b/sdks/java/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/generator/GeneratorConfig.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; - import org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.model.Event; import org.apache.beam.sdk.values.KV; diff --git 
a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/BeamRecordSizeTest.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/RowSizeTest.java similarity index 65% rename from sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/BeamRecordSizeTest.java rename to sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/RowSizeTest.java index 2c38b15c5ca8..54b88bbaf47e 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/BeamRecordSizeTest.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/RowSizeTest.java @@ -22,29 +22,28 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.math.BigDecimal; import java.util.Date; import java.util.GregorianCalendar; -import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; /** - * Unit tests for {@link BeamRecordSize}. + * Unit tests for {@link RowSize}. */ -public class BeamRecordSizeTest { +public class RowSizeTest { - private static final BeamRecordSqlType RECORD_TYPE = BeamRecordSqlType.builder() + private static final RowType ROW_TYPE = RowSqlType.builder() .withTinyIntField("f_tinyint") .withSmallIntField("f_smallint") .withIntegerField("f_int") @@ -60,53 +59,55 @@ public class BeamRecordSizeTest { .withVarcharField("f_varchar") .build(); - private static final List VALUES = - ImmutableList.of( - (byte) 1, - (short) 2, - (int) 3, - (long) 4, - (float) 5.12, - (double) 6.32, - new BigDecimal(7), - false, - new GregorianCalendar(2019, 03, 02), - new Date(10L), - new Date(11L), - "12", - "13"); + private static final long ROW_SIZE = 91L; - private static final long RECORD_SIZE = 91L; - - private static final BeamRecord RECORD = new BeamRecord(RECORD_TYPE, VALUES); + private static final Row ROW = + Row + .withRowType(ROW_TYPE) + .addValues( + (byte) 1, + (short) 2, + (int) 3, + (long) 4, + (float) 5.12, + (double) 6.32, + new BigDecimal(7), + false, + new GregorianCalendar(2019, 03, 02), + new Date(10L), + new Date(11L), + "12", + "13") + .build(); @Rule public TestPipeline testPipeline = TestPipeline.create(); @Rule public ExpectedException thrown = ExpectedException.none(); @Test public void testCalculatesCorrectSize() throws Exception { - assertEquals(RECORD_SIZE, BeamRecordSize.of(RECORD).sizeInBytes()); + assertEquals(ROW_SIZE, RowSize.of(ROW).sizeInBytes()); } @Test public void testParDoConvertsToRecordSize() throws Exception { - PCollection records = testPipeline.apply( - TestStream.create(RECORD_TYPE.getRecordCoder()) - .addElements(RECORD) + PCollection rows = testPipeline.apply( + TestStream + .create(ROW_TYPE.getRowCoder()) + .addElements(ROW) .advanceWatermarkToInfinity()); PAssert - .that(records) + .that(rows) .satisfies(new CorrectSize()); testPipeline.run(); } - static class CorrectSize implements SerializableFunction, Void> { + 
static class CorrectSize implements SerializableFunction, Void> { @Override - public Void apply(Iterable input) { - BeamRecordSize recordSize = BeamRecordSize.of(Iterables.getOnlyElement(input)); - assertThat(recordSize.sizeInBytes(), equalTo(RECORD_SIZE)); + public Void apply(Iterable input) { + RowSize recordSize = RowSize.of(Iterables.getOnlyElement(input)); + assertThat(recordSize.sizeInBytes(), equalTo(ROW_SIZE)); return null; } } diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/ToBeamRecordTest.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/ToRowTest.java similarity index 69% rename from sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/ToBeamRecordTest.java rename to sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/ToRowTest.java index c13138d85cfe..3a08698c849a 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/ToBeamRecordTest.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/ToRowTest.java @@ -18,24 +18,26 @@ package org.apache.beam.sdk.nexmark.model.sql; +import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; + import org.apache.beam.sdk.nexmark.model.Auction; import org.apache.beam.sdk.nexmark.model.Bid; import org.apache.beam.sdk.nexmark.model.Event; import org.apache.beam.sdk.nexmark.model.Person; -import org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping; +import org.apache.beam.sdk.nexmark.model.sql.adapter.ModelFieldsAdapter; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; /** - * Unit tests for {@link ToBeamRecord}. + * Unit tests for {@link ToRow}. 
*/ -public class ToBeamRecordTest { +public class ToRowTest { private static final Person PERSON = new Person(3L, "name", "email", "cc", "city", "state", 329823L, "extra"); @@ -57,14 +59,11 @@ public void testConvertsBids() throws Exception { .addElements(new Event(BID)) .advanceWatermarkToInfinity()); - BeamRecord expectedBidRecord = - new BeamRecord( - ModelAdaptersMapping.ADAPTERS.get(Bid.class).getRecordType(), - ModelAdaptersMapping.ADAPTERS.get(Bid.class).getFieldsValues(BID)); + Row expectedBidRow = toRow(BID); PAssert - .that(bids.apply(ToBeamRecord.parDo())) - .containsInAnyOrder(expectedBidRecord); + .that(bids.apply(ToRow.parDo())) + .containsInAnyOrder(expectedBidRow); testPipeline.run(); } @@ -76,14 +75,11 @@ public void testConvertsPeople() throws Exception { .addElements(new Event(PERSON)) .advanceWatermarkToInfinity()); - BeamRecord expectedPersonRecord = - new BeamRecord( - ModelAdaptersMapping.ADAPTERS.get(Person.class).getRecordType(), - ModelAdaptersMapping.ADAPTERS.get(Person.class).getFieldsValues(PERSON)); + Row expectedPersonRow = toRow(PERSON); PAssert - .that(people.apply(ToBeamRecord.parDo())) - .containsInAnyOrder(expectedPersonRecord); + .that(people.apply(ToRow.parDo())) + .containsInAnyOrder(expectedPersonRow); testPipeline.run(); } @@ -95,15 +91,20 @@ public void testConvertsAuctions() throws Exception { .addElements(new Event(AUCTION)) .advanceWatermarkToInfinity()); - BeamRecord expectedAuctionRecord = - new BeamRecord( - ModelAdaptersMapping.ADAPTERS.get(Auction.class).getRecordType(), - ModelAdaptersMapping.ADAPTERS.get(Auction.class).getFieldsValues(AUCTION)); + Row expectedAuctionRow = toRow(AUCTION); PAssert - .that(auctions.apply(ToBeamRecord.parDo())) - .containsInAnyOrder(expectedAuctionRecord); + .that(auctions.apply(ToRow.parDo())) + .containsInAnyOrder(expectedAuctionRow); testPipeline.run(); } + + private static Row toRow(Object obj) { + ModelFieldsAdapter adapter = ADAPTERS.get(obj.getClass()); + return Row + .withRowType(adapter.getRowType()) + .addValues(adapter.getFieldsValues(obj)) + .build(); + } } diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMappingTest.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMappingTest.java index 1eccfffc804f..1506afa528b8 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMappingTest.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/model/sql/adapter/ModelAdaptersMappingTest.java @@ -22,11 +22,13 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import java.util.Date; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import org.apache.beam.sdk.nexmark.model.Auction; import org.apache.beam.sdk.nexmark.model.Bid; import org.apache.beam.sdk.nexmark.model.Person; +import org.apache.beam.sdk.values.RowType; import org.junit.Test; /** @@ -37,39 +39,39 @@ public class ModelAdaptersMappingTest { private static final Person PERSON = new Person(3L, "name", "email", "cc", "city", "state", 329823L, "extra"); - private static final BeamRecordSqlType PERSON_RECORD_TYPE = BeamRecordSqlType.builder() + private static final RowType PERSON_ROW_TYPE = RowSqlType.builder() .withBigIntField("id") .withVarcharField("name") .withVarcharField("emailAddress") .withVarcharField("creditCard") .withVarcharField("city") 
.withVarcharField("state") - .withBigIntField("dateTime") + .withTimestampField("dateTime") .withVarcharField("extra") .build(); private static final Bid BID = new Bid(5L, 3L, 123123L, 43234234L, "extra2"); - private static final BeamRecordSqlType BID_RECORD_TYPE = BeamRecordSqlType.builder() + private static final RowType BID_ROW_TYPE = RowSqlType.builder() .withBigIntField("auction") .withBigIntField("bidder") .withBigIntField("price") - .withBigIntField("dateTime") + .withTimestampField("dateTime") .withVarcharField("extra") .build(); private static final Auction AUCTION = new Auction(5L, "item", "desc", 342L, 321L, 3423342L, 2349234L, 3L, 1L, "extra3"); - private static final BeamRecordSqlType AUCTION_RECORD_TYPE = BeamRecordSqlType.builder() + private static final RowType AUCTION_ROW_TYPE = RowSqlType.builder() .withBigIntField("id") .withVarcharField("itemName") .withVarcharField("description") .withBigIntField("initialBid") .withBigIntField("reserve") - .withBigIntField("dateTime") - .withBigIntField("expires") + .withTimestampField("dateTime") + .withTimestampField("expires") .withBigIntField("seller") .withBigIntField("category") .withVarcharField("extra") @@ -90,30 +92,27 @@ public void hasAdaptersForSupportedModels() throws Exception { public void testBidAdapterRecordType() { ModelFieldsAdapter adapter = ModelAdaptersMapping.ADAPTERS.get(Bid.class); - BeamRecordSqlType bidRecordType = (BeamRecordSqlType) adapter.getRecordType(); + RowType bidRowType = adapter.getRowType(); - assertEquals(BID_RECORD_TYPE.getFieldNames(), bidRecordType.getFieldNames()); - assertEquals(BID_RECORD_TYPE.getFieldTypes(), bidRecordType.getFieldTypes()); + assertEquals(BID_ROW_TYPE, bidRowType); } @Test public void testPersonAdapterRecordType() { ModelFieldsAdapter adapter = ModelAdaptersMapping.ADAPTERS.get(Person.class); - BeamRecordSqlType personRecordType = (BeamRecordSqlType) adapter.getRecordType(); + RowType personRowType = adapter.getRowType(); - assertEquals(PERSON_RECORD_TYPE.getFieldNames(), personRecordType.getFieldNames()); - assertEquals(PERSON_RECORD_TYPE.getFieldTypes(), personRecordType.getFieldTypes()); + assertEquals(PERSON_ROW_TYPE, personRowType); } @Test public void testAuctionAdapterRecordType() { ModelFieldsAdapter adapter = ModelAdaptersMapping.ADAPTERS.get(Auction.class); - BeamRecordSqlType auctionRecordType = (BeamRecordSqlType) adapter.getRecordType(); + RowType auctionRowType = adapter.getRowType(); - assertEquals(AUCTION_RECORD_TYPE.getFieldNames(), auctionRecordType.getFieldNames()); - assertEquals(AUCTION_RECORD_TYPE.getFieldTypes(), auctionRecordType.getFieldTypes()); + assertEquals(AUCTION_ROW_TYPE, auctionRowType); } @Test @@ -126,7 +125,7 @@ public void testPersonAdapterGetsFieldValues() throws Exception { assertEquals(PERSON.creditCard, values.get(3)); assertEquals(PERSON.city, values.get(4)); assertEquals(PERSON.state, values.get(5)); - assertEquals(PERSON.dateTime, values.get(6)); + assertEquals(new Date(PERSON.dateTime), values.get(6)); assertEquals(PERSON.extra, values.get(7)); } @@ -137,7 +136,7 @@ public void testBidAdapterGetsFieldValues() throws Exception { assertEquals(BID.auction, values.get(0)); assertEquals(BID.bidder, values.get(1)); assertEquals(BID.price, values.get(2)); - assertEquals(BID.dateTime, values.get(3)); + assertEquals(new Date(BID.dateTime), values.get(3)); assertEquals(BID.extra, values.get(4)); } @@ -150,8 +149,8 @@ public void testAuctionAdapterGetsFieldValues() throws Exception { assertEquals(AUCTION.description, values.get(2)); 
assertEquals(AUCTION.initialBid, values.get(3)); assertEquals(AUCTION.reserve, values.get(4)); - assertEquals(AUCTION.dateTime, values.get(5)); - assertEquals(AUCTION.expires, values.get(6)); + assertEquals(new Date(AUCTION.dateTime), values.get(5)); + assertEquals(new Date(AUCTION.expires), values.get(6)); assertEquals(AUCTION.seller, values.get(7)); assertEquals(AUCTION.category, values.get(8)); assertEquals(AUCTION.extra, values.get(9)); diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0Test.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0Test.java index d5889d8903c6..66b0302c3df8 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0Test.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery0Test.java @@ -25,8 +25,8 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.Rule; import org.junit.Test; @@ -44,11 +44,17 @@ public class SqlQuery0Test { private static final ModelFieldsAdapter BID_ADAPTER = ModelAdaptersMapping.ADAPTERS.get(Bid.class); - private static final BeamRecord BID1_RECORD = - new BeamRecord(BID_ADAPTER.getRecordType(), BID_ADAPTER.getFieldsValues(BID1)); + private static final Row BID1_ROW = + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(BID1)) + .build(); - private static final BeamRecord BID2_RECORD = - new BeamRecord(BID_ADAPTER.getRecordType(), BID_ADAPTER.getFieldsValues(BID2)); + private static final Row BID2_ROW = + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(BID2)) + .build(); @Rule public TestPipeline testPipeline = TestPipeline.create(); @@ -63,7 +69,7 @@ public void testPassesBidsThrough() throws Exception { PAssert .that(bids.apply(new SqlQuery0())) - .containsInAnyOrder(BID1_RECORD, BID2_RECORD); + .containsInAnyOrder(BID1_ROW, BID2_ROW); testPipeline.run(); } diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1Test.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1Test.java index 2d90c9bf22d4..c85c3ece7bb2 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1Test.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery1Test.java @@ -27,8 +27,8 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.Rule; import org.junit.Test; @@ -52,11 +52,17 @@ public class SqlQuery1Test { private static final ModelFieldsAdapter BID_ADAPTER = ADAPTERS.get(Bid.class); - private static final BeamRecord BID1_EUR_RECORD = - new BeamRecord(BID_ADAPTER.getRecordType(), BID_ADAPTER.getFieldsValues(BID1_EUR)); + private static final Row BID1_EUR_ROW = + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(BID1_EUR)) + .build(); - private static final BeamRecord BID2_EUR_RECORD = - new BeamRecord(BID_ADAPTER.getRecordType(), BID_ADAPTER.getFieldsValues(BID2_EUR)); + private static 
final Row BID2_EUR_ROW = + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(BID2_EUR)) + .build(); @Rule public TestPipeline testPipeline = TestPipeline.create(); @@ -77,7 +83,7 @@ public void testConvertsPriceToEur() throws Exception { PAssert .that(bids.apply(new SqlQuery1())) - .containsInAnyOrder(BID1_EUR_RECORD, BID2_EUR_RECORD); + .containsInAnyOrder(BID1_EUR_ROW, BID2_EUR_ROW); testPipeline.run(); } diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2Test.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2Test.java index 2dc572093956..f837f87875e8 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2Test.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery2Test.java @@ -28,9 +28,9 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.Rule; import org.junit.Test; @@ -61,15 +61,15 @@ public class SqlQuery2Test { new Event(BIDS.get(6)), new Event(BIDS.get(7))); - private static final List BIDS_EVEN_RECORDS = ImmutableList.of( - newBidRecord(BIDS.get(1)), - newBidRecord(BIDS.get(3)), - newBidRecord(BIDS.get(5)), - newBidRecord(BIDS.get(7))); + private static final List BIDS_EVEN_ROWS = ImmutableList.of( + newBidRow(BIDS.get(1)), + newBidRow(BIDS.get(3)), + newBidRow(BIDS.get(5)), + newBidRow(BIDS.get(7))); - private static final List BIDS_EVERY_THIRD_RECORD = ImmutableList.of( - newBidRecord(BIDS.get(2)), - newBidRecord(BIDS.get(5))); + private static final List BIDS_EVERY_THIRD_ROW = ImmutableList.of( + newBidRow(BIDS.get(2)), + newBidRow(BIDS.get(5))); @Rule public TestPipeline testPipeline = TestPipeline.create(); @@ -83,7 +83,7 @@ public void testSkipsEverySecondElement() throws Exception { PAssert .that(bids.apply(new SqlQuery2(2))) - .containsInAnyOrder(BIDS_EVEN_RECORDS); + .containsInAnyOrder(BIDS_EVEN_ROWS); testPipeline.run(); } @@ -97,7 +97,7 @@ public void testSkipsEveryThirdElement() throws Exception { PAssert .that(bids.apply(new SqlQuery2(3))) - .containsInAnyOrder(BIDS_EVERY_THIRD_RECORD); + .containsInAnyOrder(BIDS_EVERY_THIRD_ROW); testPipeline.run(); } @@ -106,8 +106,11 @@ private static Bid newBid(long id) { return new Bid(id, 3L, 100L, 432342L + id, "extra_" + id); } - private static BeamRecord newBidRecord(Bid bid) { - return new BeamRecord(BID_ADAPTER.getRecordType(), BID_ADAPTER.getFieldsValues(bid)); + private static Row newBidRow(Bid bid) { + return + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(bid)) + .build(); } - } diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3Test.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3Test.java index a20be900f3bb..873cbb1de125 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3Test.java +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery3Test.java @@ -20,7 +20,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; -import org.apache.beam.sdk.extensions.sql.BeamRecordSqlType; +import org.apache.beam.sdk.extensions.sql.RowSqlType; import 
org.apache.beam.sdk.nexmark.NexmarkConfiguration; import org.apache.beam.sdk.nexmark.model.Auction; import org.apache.beam.sdk.nexmark.model.Event; @@ -28,9 +28,10 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.BeamRecord; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; import org.junit.Rule; import org.junit.Test; @@ -39,12 +40,14 @@ */ public class SqlQuery3Test { - private static final BeamRecordSqlType RESULT_RECORD_TYPE = BeamRecordSqlType.builder() - .withVarcharField("name") - .withVarcharField("city") - .withVarcharField("state") - .withBigIntField("id") - .build(); + private static final RowType RESULT_ROW_TYPE = + RowSqlType + .builder() + .withVarcharField("name") + .withVarcharField("city") + .withVarcharField("state") + .withBigIntField("id") + .build(); private static final List PEOPLE = ImmutableList.of( newPerson(0L, "WA"), @@ -83,11 +86,11 @@ public class SqlQuery3Test { new Event(AUCTIONS.get(8)), new Event(AUCTIONS.get(9))); - public static final List RESULTS = ImmutableList.of( - newResultRecord("name_1", "city_1", "CA", 1L), - newResultRecord("name_3", "city_3", "ID", 3L), - newResultRecord("name_1", "city_1", "CA", 6L), - newResultRecord("name_3", "city_3", "ID", 8L)); + public static final List RESULTS = ImmutableList.of( + newResultRow("name_1", "city_1", "CA", 1L), + newResultRow("name_3", "city_3", "ID", 3L), + newResultRow("name_1", "city_1", "CA", 6L), + newResultRow("name_3", "city_3", "ID", 8L)); @Rule public TestPipeline testPipeline = TestPipeline.create(); @@ -131,17 +134,20 @@ private static Auction newAuction(long id, long seller, long category) { "extra_" + id); } - private static BeamRecord newResultRecord( + private static Row newResultRow( String personName, String personCity, String personState, long auctionId) { - return new BeamRecord( - RESULT_RECORD_TYPE, - personName, - personCity, - personState, - auctionId); + return + Row + .withRowType(RESULT_ROW_TYPE) + .addValues( + personName, + personCity, + personState, + auctionId) + .build(); } } diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery5Test.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery5Test.java new file mode 100644 index 000000000000..1e7542fd3a21 --- /dev/null +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery5Test.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.sdk.nexmark.queries.sql; + +import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.apache.beam.sdk.extensions.sql.RowSqlType; +import org.apache.beam.sdk.nexmark.NexmarkConfiguration; +import org.apache.beam.sdk.nexmark.model.Bid; +import org.apache.beam.sdk.nexmark.model.Event; +import org.apache.beam.sdk.nexmark.model.sql.adapter.ModelFieldsAdapter; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.junit.Rule; +import org.junit.Test; + +/** + * Unit tests for {@link SqlQuery5}. + */ +public class SqlQuery5Test { + + private static final NexmarkConfiguration config = new NexmarkConfiguration(); + + private static final ModelFieldsAdapter BID_ADAPTER = ADAPTERS.get(Bid.class); + + private static final List BIDS = ImmutableList.of( + newBid(1L, 1L), + newBid(1L, 3L), + newBid(1L, 4L), + newBid(2L, 4L)); + + private static final List BIDS_EVENTS = ImmutableList.of( + new Event(BIDS.get(0)), + new Event(BIDS.get(1)), + new Event(BIDS.get(2)), + new Event(BIDS.get(3))); + + private static final RowType RESULT_ROW_TYPE = + RowSqlType + .builder() + .withBigIntField("auction") + .build(); + + public static final List RESULTS = ImmutableList.of( + newResultRow(1L), + newResultRow(1L), + newResultRow(1L), + newResultRow(1L), + newResultRow(1L), + newResultRow(2L)); + + @Rule public TestPipeline testPipeline = TestPipeline.create(); + + @Test + public void testBids() throws Exception { + assertEquals(Long.valueOf(config.windowSizeSec), + Long.valueOf(config.windowPeriodSec * 2)); + + PCollection bids = + PBegin + .in(testPipeline) + .apply(Create.of(BIDS_EVENTS).withCoder(Event.CODER)); + + PAssert + .that(bids.apply(new SqlQuery5(config))) + .containsInAnyOrder(RESULTS); + + testPipeline.run(); + } + + private static Bid newBid(long auction, long index) { + return new Bid(auction, + 3L, + 100L, + 432342L + index * config.windowPeriodSec * 1000, + "extra_" + auction); + } + + private static Row newBidRow(Bid bid) { + return + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(bid)) + .build(); + } + + private static Row newResultRow( + long auctionId) { + + return + Row + .withRowType(RESULT_ROW_TYPE) + .addValues( + auctionId) + .build(); + } +} diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery7Test.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery7Test.java new file mode 100644 index 000000000000..fb89e8d4f389 --- /dev/null +++ b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/queries/sql/SqlQuery7Test.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.nexmark.queries.sql; + +import static org.apache.beam.sdk.nexmark.model.sql.adapter.ModelAdaptersMapping.ADAPTERS; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.apache.beam.sdk.extensions.sql.RowSqlType; +import org.apache.beam.sdk.nexmark.NexmarkConfiguration; +import org.apache.beam.sdk.nexmark.model.Bid; +import org.apache.beam.sdk.nexmark.model.Event; +import org.apache.beam.sdk.nexmark.model.sql.adapter.ModelFieldsAdapter; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowType; +import org.junit.Rule; +import org.junit.Test; + +/** + * Unit tests for {@link SqlQuery7}. + */ +public class SqlQuery7Test { + + private static final NexmarkConfiguration config = new NexmarkConfiguration(); + + private static final ModelFieldsAdapter BID_ADAPTER = ADAPTERS.get(Bid.class); + + private static final List BIDS = ImmutableList.of( + newBid(4L, 3L, 2L, 1L), + newBid(1L, 2L, 3L, 2L), + newBid(2L, 2L, 3L, 2L), + newBid(2L, 2L, 4L, 3L), + newBid(2L, 2L, 5L, 3L)); + + private static final List BIDS_EVENTS = ImmutableList.of( + new Event(BIDS.get(0)), + new Event(BIDS.get(1)), + new Event(BIDS.get(2)), + new Event(BIDS.get(3)), + new Event(BIDS.get(4))); + + private static final RowType RESULT_ROW_TYPE = + RowSqlType + .builder() + .withBigIntField("auction") + .withBigIntField("price") + .withBigIntField("bidder") + .build(); + + public static final List RESULTS = ImmutableList.of( + newResultRow(4L, 3L, 2L), + newResultRow(1L, 2L, 3L), + newResultRow(2L, 2L, 3L), + newResultRow(2L, 2L, 5L)); + + @Rule public TestPipeline testPipeline = TestPipeline.create(); + + @Test + public void testBids() throws Exception { + PCollection bids = + PBegin + .in(testPipeline) + .apply(Create.of(BIDS_EVENTS).withCoder(Event.CODER)); + + PAssert + .that(bids.apply(new SqlQuery7(config))) + .containsInAnyOrder(RESULTS); + + testPipeline.run(); + } + + private static Bid newBid(long auction, long bidder, long price, long index) { + return new Bid(auction, + bidder, + price, + 432342L + index * config.windowSizeSec * 1000, + "extra_" + auction); + } + + private static Row newBidRow(Bid bid) { + return + Row + .withRowType(BID_ADAPTER.getRowType()) + .addValues(BID_ADAPTER.getFieldsValues(bid)) + .build(); + } + + private static Row newResultRow( + long auctionId, long bidder, long price) { + + return + Row + .withRowType(RESULT_ROW_TYPE) + .addValues( + auctionId, price, bidder) + .build(); + } +} diff --git a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSourceTest.java b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSourceTest.java index 5c9bf5f9babc..5033f7f1b2de 100644 --- a/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSourceTest.java +++ 
b/sdks/java/nexmark/src/test/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSourceTest.java @@ -25,7 +25,6 @@ import java.util.HashSet; import java.util.Random; import java.util.Set; - import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.nexmark.NexmarkConfiguration; diff --git a/sdks/java/pom.xml b/sdks/java/pom.xml index e037e946b195..78b7c21a819e 100644 --- a/sdks/java/pom.xml +++ b/sdks/java/pom.xml @@ -40,7 +40,6 @@ container core io - java8tests maven-archetypes extensions fn-execution diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 172ee74d4c83..9432e5347899 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -52,6 +52,11 @@ from .slow_stream import get_varint_size # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports +try: + long # Python 2 +except NameError: + long = int # Python 3 + class CoderImpl(object): """For internal use only; no backwards-compatibility guarantees.""" diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 64902b592e31..f76625869879 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -25,11 +25,13 @@ import cPickle as pickle import google.protobuf +from google.protobuf import wrappers_pb2 from apache_beam.coders import coder_impl +from apache_beam.portability import common_urns +from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.utils import proto_utils -from apache_beam.utils import urns # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: @@ -263,8 +265,8 @@ def from_runner_api(cls, coder_proto, context): def to_runner_api_parameter(self, context): return ( - urns.PICKLED_CODER, - google.protobuf.wrappers_pb2.BytesValue(value=serialize_coder(self)), + python_urns.PICKLED_CODER, + wrappers_pb2.BytesValue(value=serialize_coder(self)), ()) @staticmethod @@ -284,7 +286,8 @@ def from_runner_api_parameter(unused_payload, components, unused_context): return cls() -@Coder.register_urn(urns.PICKLED_CODER, google.protobuf.wrappers_pb2.BytesValue) +@Coder.register_urn( + python_urns.PICKLED_CODER, google.protobuf.wrappers_pb2.BytesValue) def _pickle_from_runner_api_parameter(payload, components, context): return deserialize_coder(payload.value) @@ -363,7 +366,7 @@ def __hash__(self): return hash(type(self)) -Coder.register_structured_urn(urns.BYTES_CODER, BytesCoder) +Coder.register_structured_urn(common_urns.BYTES_CODER, BytesCoder) class VarIntCoder(FastCoder): @@ -382,7 +385,7 @@ def __hash__(self): return hash(type(self)) -Coder.register_structured_urn(urns.VAR_INT_CODER, VarIntCoder) +Coder.register_structured_urn(common_urns.VARINT_CODER, VarIntCoder) class FloatCoder(FastCoder): @@ -736,11 +739,11 @@ def __hash__(self): def to_runner_api_parameter(self, context): if self.is_kv_coder(): - return urns.KV_CODER, None, self.coders() + return common_urns.KV_CODER, None, self.coders() else: return super(TupleCoder, self).to_runner_api_parameter(context) - @Coder.register_urn(urns.KV_CODER, None) + @Coder.register_urn(common_urns.KV_CODER, None) def from_runner_api_parameter(unused_payload, components, unused_context): return TupleCoder(components) @@ -829,7 +832,7 @@ def __hash__(self): return hash((type(self), 
self._elem_coder)) -Coder.register_structured_urn(urns.ITERABLE_CODER, IterableCoder) +Coder.register_structured_urn(common_urns.ITERABLE_CODER, IterableCoder) class GlobalWindowCoder(SingletonCoder): @@ -845,7 +848,8 @@ def as_cloud_object(self): } -Coder.register_structured_urn(urns.GLOBAL_WINDOW_CODER, GlobalWindowCoder) +Coder.register_structured_urn( + common_urns.GLOBAL_WINDOW_CODER, GlobalWindowCoder) class IntervalWindowCoder(FastCoder): @@ -869,7 +873,8 @@ def __hash__(self): return hash(type(self)) -Coder.register_structured_urn(urns.INTERVAL_WINDOW_CODER, IntervalWindowCoder) +Coder.register_structured_urn( + common_urns.INTERVAL_WINDOW_CODER, IntervalWindowCoder) class WindowedValueCoder(FastCoder): @@ -928,7 +933,8 @@ def __hash__(self): (self.wrapped_value_coder, self.timestamp_coder, self.window_coder)) -Coder.register_structured_urn(urns.WINDOWED_VALUE_CODER, WindowedValueCoder) +Coder.register_structured_urn( + common_urns.WINDOWED_VALUE_CODER, WindowedValueCoder) class LengthPrefixCoder(FastCoder): @@ -972,4 +978,5 @@ def __hash__(self): return hash((type(self), self._value_coder)) -Coder.register_structured_urn(urns.LENGTH_PREFIX_CODER, LengthPrefixCoder) +Coder.register_structured_urn( + common_urns.LENGTH_PREFIX_CODER, LengthPrefixCoder) diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 82d296ddeb72..dd071d7a9331 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -64,8 +64,6 @@ def MakeXyzs(v): See apache_beam.typehints.decorators module for more details. """ -import warnings - from apache_beam.coders import coders from apache_beam.typehints import typehints @@ -123,14 +121,16 @@ def get_coder(self, typehint): # In some old code, None is used for Any. # TODO(robertwb): Clean this up. pass - elif typehint is object: + elif typehint is object or typehint == typehints.Any: # We explicitly want the fallback coder. pass elif isinstance(typehint, typehints.TypeVariable): # TODO(robertwb): Clean this up when type inference is fully enabled. pass else: - warnings.warn('Using fallback coder for typehint: %r.' % typehint) + # TODO(robertwb): Re-enable this warning when it's actionable. + # warnings.warn('Using fallback coder for typehint: %r.' % typehint) + pass coder = self._fallback_coder return coder.from_type_hint(typehint, self) diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi.py b/sdks/python/apache_beam/examples/complete/estimate_pi.py index d0a5fb74f3fc..982faaa0f475 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi.py @@ -55,7 +55,7 @@ def run_trials(runs): has same type for inputs and outputs (a requirement for combiner functions). 
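
Aside from the URN renames above, the Python hunks in this patch keep applying the same Python 2/3 compatibility idioms: coder_impl.py aliases the missing `long` builtin, the examples below swap `xrange` for `range`, and later files move from `iteritems`/`basestring` to `items`/`six`. The following is a consolidated illustrative sketch of those idioms, not code from the patch; the sample values are made up.

import six

try:
    long  # present on Python 2
except NameError:
    long = int  # Python 3: alias it so isinstance checks keep working

def count_even(n):
    # range() is iterable on both major versions, unlike the removed xrange().
    return sum(1 for i in range(n) if i % 2 == 0)

def describe(value):
    # six.string_types / six.integer_types cover str/unicode and int/long on
    # Python 2 and collapse to (str,) and (int,) on Python 3.
    if isinstance(value, six.string_types):
        return 'text'
    if isinstance(value, six.integer_types):
        return 'integer'
    return type(value).__name__

settings = {'runner': 'DirectRunner', 'streaming': True}
for key, value in sorted(settings.items()):  # iteritems() is Python 2 only
    print('%s=%s' % (key, value))
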
""" inside_runs = 0 - for _ in xrange(runs): + for _ in range(runs): x = random.uniform(0, 1) y = random.uniform(0, 1) inside_runs += 1 if x * x + y * y <= 1.0 else 0 diff --git a/sdks/python/apache_beam/examples/complete/game/game_stats.py b/sdks/python/apache_beam/examples/complete/game/game_stats.py index d8c60dd67662..f9ccdc065e88 100644 --- a/sdks/python/apache_beam/examples/complete/game/game_stats.py +++ b/sdks/python/apache_beam/examples/complete/game/game_stats.py @@ -163,43 +163,6 @@ def process(self, team_score, window=beam.DoFn.WindowParam): } -class WriteToBigQuery(beam.PTransform): - """Generate, format, and write BigQuery table row information.""" - def __init__(self, table_name, dataset, schema): - """Initializes the transform. - Args: - table_name: Name of the BigQuery table to use. - dataset: Name of the dataset to use. - schema: Dictionary in the format {'column_name': 'bigquery_type'} - """ - super(WriteToBigQuery, self).__init__() - self.table_name = table_name - self.dataset = dataset - self.schema = schema - - def get_schema(self): - """Build the output table schema.""" - return ', '.join( - '%s:%s' % (col, self.schema[col]) for col in self.schema) - - def get_table(self, pipeline): - """Utility to construct an output table reference.""" - project = pipeline.options.view_as(GoogleCloudOptions).project - return '%s:%s.%s' % (project, self.dataset, self.table_name) - - def expand(self, pcoll): - table = self.get_table(pcoll.pipeline) - return ( - pcoll - | 'ConvertToRow' >> beam.Map( - lambda elem: {col: elem[col] for col in self.schema}) - | beam.io.Write(beam.io.BigQuerySink( - table, - schema=self.get_schema(), - create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, - write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND))) - - # [START abuse_detect] class CalculateSpammyUsers(beam.PTransform): """Filter out all but those users with a high clickrate, which we will @@ -280,7 +243,8 @@ def run(argv=None): options = PipelineOptions(pipeline_args) # We also require the --project option to access --dataset - if options.view_as(GoogleCloudOptions).project is None: + project = options.view_as(GoogleCloudOptions).project + if project is None: parser.print_usage() print(sys.argv[0] + ': error: argument --project is required') sys.exit(1) @@ -296,6 +260,8 @@ def run(argv=None): # Enforce that this pipeline is always run in streaming mode options.view_as(StandardOptions).streaming = True + table_spec_prefix = '{}:{}.{}'.format(project, args.dataset, args.table_name) + with beam.Pipeline(options=options) as p: # Read events from Pub/Sub using custom timestamps raw_events = ( @@ -332,6 +298,13 @@ def run(argv=None): # updates for late data. Uses the side input derived above --the set of # suspected robots-- to filter out scores from those users from the sum. # Write the results to BigQuery. 
+ team_table_spec = table_spec_prefix + '_teams' + team_table_schema = ( + 'team:STRING, ' + 'total_score:INTEGER, ' + 'window_start:STRING, ' + 'processing_time: STRING') + (raw_events # pylint: disable=expression-not-assigned | 'WindowIntoFixedWindows' >> beam.WindowInto( beam.window.FixedWindows(fixed_window_duration)) @@ -344,19 +317,20 @@ def run(argv=None): | 'ExtractAndSumScore' >> ExtractAndSumScore('team') # [END filter_and_calc] | 'TeamScoresDict' >> beam.ParDo(TeamScoresDict()) - | 'WriteTeamScoreSums' >> WriteToBigQuery( - args.table_name + '_teams', args.dataset, { - 'team': 'STRING', - 'total_score': 'INTEGER', - 'window_start': 'STRING', - 'processing_time': 'STRING', - })) + | 'WriteTeamScoreSums' >> beam.io.WriteToBigQuery( + team_table_spec, + schema=team_table_schema, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND)) # [START session_calc] # Detect user sessions-- that is, a burst of activity separated by a gap # from further activity. Find and record the mean session lengths. # This information could help the game designers track the changing user # engagement as their set of game changes. + session_table_spec = table_spec_prefix + '_sessions' + session_table_schema = 'mean_duration:FLOAT' + (user_events # pylint: disable=expression-not-assigned | 'WindowIntoSessions' >> beam.WindowInto( beam.window.Sessions(session_gap), @@ -381,10 +355,11 @@ def run(argv=None): | beam.CombineGlobally(beam.combiners.MeanCombineFn()).without_defaults() | 'FormatAvgSessionLength' >> beam.Map( lambda elem: {'mean_duration': float(elem)}) - | 'WriteAvgSessionLength' >> WriteToBigQuery( - args.table_name + '_sessions', args.dataset, { - 'mean_duration': 'FLOAT', - })) + | 'WriteAvgSessionLength' >> beam.io.WriteToBigQuery( + session_table_spec, + schema=session_table_schema, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND)) # [END rewindow] diff --git a/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py b/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py index b286a6a5ddf4..6e826d45feed 100644 --- a/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py +++ b/sdks/python/apache_beam/examples/complete/game/hourly_team_score.py @@ -156,43 +156,6 @@ def process(self, team_score, window=beam.DoFn.WindowParam): } -class WriteToBigQuery(beam.PTransform): - """Generate, format, and write BigQuery table row information.""" - def __init__(self, table_name, dataset, schema): - """Initializes the transform. - Args: - table_name: Name of the BigQuery table to use. - dataset: Name of the dataset to use. 
- schema: Dictionary in the format {'column_name': 'bigquery_type'} - """ - super(WriteToBigQuery, self).__init__() - self.table_name = table_name - self.dataset = dataset - self.schema = schema - - def get_schema(self): - """Build the output table schema.""" - return ', '.join( - '%s:%s' % (col, self.schema[col]) for col in self.schema) - - def get_table(self, pipeline): - """Utility to construct an output table reference.""" - project = pipeline.options.view_as(GoogleCloudOptions).project - return '%s:%s.%s' % (project, self.dataset, self.table_name) - - def expand(self, pcoll): - table = self.get_table(pcoll.pipeline) - return ( - pcoll - | 'ConvertToRow' >> beam.Map( - lambda elem: {col: elem[col] for col in self.schema}) - | beam.io.Write(beam.io.BigQuerySink( - table, - schema=self.get_schema(), - create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, - write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND))) - - # [START main] class HourlyTeamScore(beam.PTransform): def __init__(self, start_min, stop_min, window_duration): @@ -278,7 +241,8 @@ def run(argv=None): options = PipelineOptions(pipeline_args) # We also require the --project option to access --dataset - if options.view_as(GoogleCloudOptions).project is None: + project = options.view_as(GoogleCloudOptions).project + if project is None: parser.print_usage() print(sys.argv[0] + ': error: argument --project is required') sys.exit(1) @@ -287,18 +251,23 @@ def run(argv=None): # workflow rely on global context (e.g., a module imported at module level). options.view_as(SetupOptions).save_main_session = True + table_spec = '{}:{}.{}'.format(project, args.dataset, args.table_name) + table_schema = ( + 'team:STRING, ' + 'total_score:INTEGER, ' + 'window_start:STRING') + with beam.Pipeline(options=options) as p: (p # pylint: disable=expression-not-assigned | 'ReadInputText' >> beam.io.ReadFromText(args.input) | 'HourlyTeamScore' >> HourlyTeamScore( args.start_min, args.stop_min, args.window_duration) | 'TeamScoresDict' >> beam.ParDo(TeamScoresDict()) - | 'WriteTeamScoreSums' >> WriteToBigQuery( - args.table_name, args.dataset, { - 'team': 'STRING', - 'total_score': 'INTEGER', - 'window_start': 'STRING', - })) + | 'WriteTeamScoreSums' >> beam.io.WriteToBigQuery( + table_spec, + schema=table_schema, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND)) # [END main] diff --git a/sdks/python/apache_beam/examples/complete/game/leader_board.py b/sdks/python/apache_beam/examples/complete/game/leader_board.py index e207f26712e3..0d1fce47663d 100644 --- a/sdks/python/apache_beam/examples/complete/game/leader_board.py +++ b/sdks/python/apache_beam/examples/complete/game/leader_board.py @@ -171,43 +171,6 @@ def process(self, team_score, window=beam.DoFn.WindowParam): } -class WriteToBigQuery(beam.PTransform): - """Generate, format, and write BigQuery table row information.""" - def __init__(self, table_name, dataset, schema): - """Initializes the transform. - Args: - table_name: Name of the BigQuery table to use. - dataset: Name of the dataset to use. 
- schema: Dictionary in the format {'column_name': 'bigquery_type'} - """ - super(WriteToBigQuery, self).__init__() - self.table_name = table_name - self.dataset = dataset - self.schema = schema - - def get_schema(self): - """Build the output table schema.""" - return ', '.join( - '%s:%s' % (col, self.schema[col]) for col in self.schema) - - def get_table(self, pipeline): - """Utility to construct an output table reference.""" - project = pipeline.options.view_as(GoogleCloudOptions).project - return '%s:%s.%s' % (project, self.dataset, self.table_name) - - def expand(self, pcoll): - table = self.get_table(pcoll.pipeline) - return ( - pcoll - | 'ConvertToRow' >> beam.Map( - lambda elem: {col: elem[col] for col in self.schema}) - | beam.io.Write(beam.io.BigQuerySink( - table, - schema=self.get_schema(), - create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, - write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND))) - - # [START window_and_trigger] class CalculateTeamScores(beam.PTransform): """Calculates scores for each team within the configured window duration. @@ -294,7 +257,8 @@ def run(argv=None): options = PipelineOptions(pipeline_args) # We also require the --project option to access --dataset - if options.view_as(GoogleCloudOptions).project is None: + project = options.view_as(GoogleCloudOptions).project + if project is None: parser.print_usage() print(sys.argv[0] + ': error: argument --project is required') sys.exit(1) @@ -306,6 +270,8 @@ def run(argv=None): # Enforce that this pipeline is always run in streaming mode options.view_as(StandardOptions).streaming = True + table_spec_prefix = '{}:{}.{}'.format(project, args.dataset, args.table_name) + with beam.Pipeline(options=options) as p: # Read game events from Pub/Sub using custom timestamps, which are extracted # from the pubsub data elements, and parse the data. 
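
The comment above describes reading game events from Pub/Sub with timestamps taken from the payload rather than the publish time; the next hunk shows the idiom used for that, beam.window.TimestampedValue. Below is a small self-contained sketch of the same idiom, with an in-memory source and a made-up element shape standing in for the Pub/Sub read.

import apache_beam as beam

def add_event_timestamp(elem):
    # Re-emit the element with its own 'timestamp' field (seconds since epoch)
    # as the event-time timestamp.
    return beam.window.TimestampedValue(elem, elem['timestamp'])

with beam.Pipeline() as p:
    (p
     | 'FakeEvents' >> beam.Create([
         {'user': 'alice', 'score': 5, 'timestamp': 1514764800.0},
         {'user': 'bob', 'score': 3, 'timestamp': 1514764860.0}])
     | 'AddEventTimestamps' >> beam.Map(add_event_timestamp))
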
@@ -316,32 +282,37 @@ def run(argv=None): | 'AddEventTimestamps' >> beam.Map( lambda elem: beam.window.TimestampedValue(elem, elem['timestamp']))) + team_table_spec = table_spec_prefix + '_teams' + team_table_schema = ( + 'team:STRING, ' + 'total_score:INTEGER, ' + 'window_start:STRING, ' + 'processing_time: STRING') + # Get team scores and write the results to BigQuery (events # pylint: disable=expression-not-assigned | 'CalculateTeamScores' >> CalculateTeamScores( args.team_window_duration, args.allowed_lateness) | 'TeamScoresDict' >> beam.ParDo(TeamScoresDict()) - | 'WriteTeamScoreSums' >> WriteToBigQuery( - args.table_name + '_teams', args.dataset, { - 'team': 'STRING', - 'total_score': 'INTEGER', - 'window_start': 'STRING', - 'processing_time': 'STRING', - })) - - def format_user_score_sums(user_score): - (user, score) = user_score - return {'user': user, 'total_score': score} + | 'WriteTeamScoreSums' >> beam.io.WriteToBigQuery( + team_table_spec, + schema=team_table_schema, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND)) + + user_table_spec = table_spec_prefix + '_users' + user_table_schema = 'user:STRING, total_score:INTEGER' # Get user scores and write the results to BigQuery (events # pylint: disable=expression-not-assigned | 'CalculateUserScores' >> CalculateUserScores(args.allowed_lateness) - | 'FormatUserScoreSums' >> beam.Map(format_user_score_sums) - | 'WriteUserScoreSums' >> WriteToBigQuery( - args.table_name + '_users', args.dataset, { - 'user': 'STRING', - 'total_score': 'INTEGER', - })) + | 'FormatUserScoreSums' >> beam.Map( + lambda elem: {'user': elem[0], 'total_score': elem[1]}) + | 'WriteUserScoreSums' >> beam.io.WriteToBigQuery( + user_table_spec, + schema=user_table_schema, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND)) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py b/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py index 3f3ef031cd73..1beffeef9b17 100644 --- a/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py +++ b/sdks/python/apache_beam/examples/complete/juliaset/juliaset/juliaset.py @@ -37,7 +37,7 @@ def get_julia_set_point_color(element, c, n, max_iterations): """Given an pixel, convert it into a point in our julia set.""" x, y = element z = from_pixel(x, y, n) - for i in xrange(max_iterations): + for i in range(max_iterations): if z.real * z.real + z.imag * z.imag > 2.0: break z = z * z + c diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py index e16ae7391b6e..91ddf51827b3 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py +++ b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py @@ -90,7 +90,7 @@ def run(argv=None): with beam.Pipeline(options=pipeline_options) as p: group_ids = [] - for i in xrange(0, int(known_args.num_groups)): + for i in range(0, int(known_args.num_groups)): group_ids.append('id' + str(i)) query_corpus = 'select UNIQUE(corpus) from publicdata:samples.shakespeare' diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index 6cc96efe79d9..b2c5bb926b34 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ 
b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -711,6 +711,26 @@ def split(self, desired_bundle_size, start_position=None, # [END model_custom_source_new_source] +# We recommend users to start Source classes with an underscore to discourage +# using the Source class directly when a PTransform for the source is +# available. We simulate that here by simply extending the previous Source +# class. +class _CountingSource(CountingSource): + pass + + +# [START model_custom_source_new_ptransform] +class ReadFromCountingSource(PTransform): + + def __init__(self, count): + super(ReadFromCountingSource, self).__init__() + self._count = count + + def expand(self, pcoll): + return pcoll | iobase.Read(_CountingSource(self._count)) +# [END model_custom_source_new_ptransform] + + def model_custom_source(count): """Demonstrates creating a new custom source and using it in a pipeline. @@ -747,24 +767,6 @@ def model_custom_source(count): lines, equal_to( ['line ' + str(number) for number in range(0, count)])) - # We recommend users to start Source classes with an underscore to discourage - # using the Source class directly when a PTransform for the source is - # available. We simulate that here by simply extending the previous Source - # class. - class _CountingSource(CountingSource): - pass - - # [START model_custom_source_new_ptransform] - class ReadFromCountingSource(PTransform): - - def __init__(self, count, **kwargs): - super(ReadFromCountingSource, self).__init__(**kwargs) - self._count = count - - def expand(self, pcoll): - return pcoll | iobase.Read(_CountingSource(count)) - # [END model_custom_source_new_ptransform] - # [START model_custom_source_use_ptransform] p = beam.Pipeline(options=PipelineOptions()) numbers = p | 'ProduceNumbers' >> ReadFromCountingSource(count) @@ -775,27 +777,95 @@ def expand(self, pcoll): lines, equal_to( ['line ' + str(number) for number in range(0, count)])) - # Don't test runner api due to pickling errors. - p.run(test_runner_api=False).wait_until_finish() + p.run().wait_until_finish() + + +# Defining the new sink. +# +# Defines a new sink ``SimpleKVSink`` that demonstrates writing to a simple +# key-value based storage system which has following API. +# +# simplekv.connect(url) - +# connects to the storage system and returns an access token which can be +# used to perform further operations +# simplekv.open_table(access_token, table_name) - +# creates a table named 'table_name'. Returns a table object. +# simplekv.write_to_table(access_token, table, key, value) - +# writes a key-value pair to the given table. +# simplekv.rename_table(access_token, old_name, new_name) - +# renames the table named 'old_name' to 'new_name'. +# +# [START model_custom_sink_new_sink] +class SimpleKVSink(iobase.Sink): + + def __init__(self, simplekv, url, final_table_name): + self._simplekv = simplekv + self._url = url + self._final_table_name = final_table_name + + def initialize_write(self): + access_token = self._simplekv.connect(self._url) + return access_token + + def open_writer(self, access_token, uid): + table_name = 'table' + uid + return SimpleKVWriter(self._simplekv, access_token, table_name) + + def finalize_write(self, access_token, table_names): + for i, table_name in enumerate(table_names): + self._simplekv.rename_table( + access_token, table_name, self._final_table_name + str(i)) +# [END model_custom_sink_new_sink] + + +# Defining a writer for the new sink. 
+# [START model_custom_sink_new_writer] +class SimpleKVWriter(iobase.Writer): + + def __init__(self, simplekv, access_token, table_name): + self._simplekv = simplekv + self._access_token = access_token + self._table_name = table_name + self._table = self._simplekv.open_table(access_token, table_name) + + def write(self, record): + key, value = record + + self._simplekv.write_to_table(self._access_token, self._table, key, value) + + def close(self): + return self._table_name +# [END model_custom_sink_new_writer] + + +# [START model_custom_sink_new_ptransform] +class WriteToKVSink(PTransform): + + def __init__(self, simplekv, url, final_table_name, **kwargs): + self._simplekv = simplekv + super(WriteToKVSink, self).__init__(**kwargs) + self._url = url + self._final_table_name = final_table_name + + def expand(self, pcoll): + return pcoll | iobase.Write(_SimpleKVSink(self._simplekv, + self._url, + self._final_table_name)) +# [END model_custom_sink_new_ptransform] + + +# We recommend users to start Sink class names with an underscore to +# discourage using the Sink class directly when a PTransform for the sink is +# available. We simulate that here by simply extending the previous Sink +# class. +class _SimpleKVSink(SimpleKVSink): + pass def model_custom_sink(simplekv, KVs, final_table_name_no_ptransform, final_table_name_with_ptransform): """Demonstrates creating a new custom sink and using it in a pipeline. - Defines a new sink ``SimpleKVSink`` that demonstrates writing to a simple - key-value based storage system which has following API. - - simplekv.connect(url) - - connects to the storage system and returns an access token which can be - used to perform further operations - simplekv.open_table(access_token, table_name) - - creates a table named 'table_name'. Returns a table object. - simplekv.write_to_table(access_token, table, key, value) - - writes a key-value pair to the given table. - simplekv.rename_table(access_token, old_name, new_name) - - renames the table named 'old_name' to 'new_name'. - Uses the new sink in an example pipeline. Additionally demonstrates how a sink should be implemented using a @@ -824,51 +894,6 @@ def model_custom_sink(simplekv, KVs, final_table_name_no_ptransform, ``SimpleKVSink``. """ - import apache_beam as beam - from apache_beam.io import iobase - from apache_beam.transforms.core import PTransform - from apache_beam.options.pipeline_options import PipelineOptions - - # Defining the new sink. - # [START model_custom_sink_new_sink] - class SimpleKVSink(iobase.Sink): - - def __init__(self, url, final_table_name): - self._url = url - self._final_table_name = final_table_name - - def initialize_write(self): - access_token = simplekv.connect(self._url) - return access_token - - def open_writer(self, access_token, uid): - table_name = 'table' + uid - return SimpleKVWriter(access_token, table_name) - - def finalize_write(self, access_token, table_names): - for i, table_name in enumerate(table_names): - simplekv.rename_table( - access_token, table_name, self._final_table_name + str(i)) - # [END model_custom_sink_new_sink] - - # Defining a writer for the new sink. 
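
The comment block introducing this sink only sketches the simplekv API (connect, open_table, write_to_table, rename_table); the service itself is fictional. A hypothetical in-memory stand-in that satisfies that contract could look like the class below, and an instance of it could be passed as the simplekv argument of SimpleKVSink or WriteToKVSink above, e.g. kvs | 'WriteToSimpleKV' >> WriteToKVSink(FakeSimpleKV(), 'http://fake-url', 'results'). None of this is part of the patch.

class FakeSimpleKV(object):
    """Hypothetical in-memory implementation of the simplekv API above."""

    def __init__(self):
        self.tables = {}  # table_name -> dict of key/value pairs

    def connect(self, url):
        # Pretend to authenticate against `url` and hand back an access token.
        return 'fake-access-token'

    def open_table(self, access_token, table_name):
        assert access_token == 'fake-access-token'
        return self.tables.setdefault(table_name, {})

    def write_to_table(self, access_token, table, key, value):
        table[key] = value

    def rename_table(self, access_token, old_name, new_name):
        self.tables[new_name] = self.tables.pop(old_name)
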
- # [START model_custom_sink_new_writer] - class SimpleKVWriter(iobase.Writer): - - def __init__(self, access_token, table_name): - self._access_token = access_token - self._table_name = table_name - self._table = simplekv.open_table(access_token, table_name) - - def write(self, record): - key, value = record - - simplekv.write_to_table(self._access_token, self._table, key, value) - - def close(self): - return self._table_name - # [END model_custom_sink_new_writer] - final_table_name = final_table_name_no_ptransform # Using the new sink in an example pipeline. @@ -877,36 +902,16 @@ def close(self): kvs = p | 'CreateKVs' >> beam.Create(KVs) kvs | 'WriteToSimpleKV' >> beam.io.Write( - SimpleKVSink('http://url_to_simple_kv/', final_table_name)) + SimpleKVSink(simplekv, 'http://url_to_simple_kv/', final_table_name)) # [END model_custom_sink_use_new_sink] - # We recommend users to start Sink class names with an underscore to - # discourage using the Sink class directly when a PTransform for the sink is - # available. We simulate that here by simply extending the previous Sink - # class. - class _SimpleKVSink(SimpleKVSink): - pass - - # [START model_custom_sink_new_ptransform] - class WriteToKVSink(PTransform): - - def __init__(self, url, final_table_name, **kwargs): - super(WriteToKVSink, self).__init__(**kwargs) - self._url = url - self._final_table_name = final_table_name - - def expand(self, pcoll): - return pcoll | iobase.Write(_SimpleKVSink(self._url, - self._final_table_name)) - # [END model_custom_sink_new_ptransform] - final_table_name = final_table_name_with_ptransform # [START model_custom_sink_use_ptransform] with beam.Pipeline(options=PipelineOptions()) as p: kvs = p | 'CreateKVs' >> beam.core.Create(KVs) kvs | 'WriteToSimpleKV' >> WriteToKVSink( - 'http://url_to_simple_kv/', final_table_name) + simplekv, 'http://url_to_simple_kv/', final_table_name) # [END model_custom_sink_use_ptransform] @@ -916,9 +921,6 @@ def filter_words(x): import re return re.findall(r'[A-Za-z\']+', x) - import apache_beam as beam - from apache_beam.options.pipeline_options import PipelineOptions - # [START model_textio_read] with beam.Pipeline(options=PipelineOptions()) as p: # [START model_pipelineio_read] @@ -989,46 +991,96 @@ def to_entity(content): # [END model_datastoreio_write] -def model_bigqueryio(): - """Using a Read and Write transform to read/write to BigQuery.""" - import apache_beam as beam - from apache_beam.options.pipeline_options import PipelineOptions - - # [START model_bigqueryio_read] - p = beam.Pipeline(options=PipelineOptions()) - weather_data = p | 'ReadWeatherStations' >> beam.io.Read( - beam.io.BigQuerySource( - 'clouddataflow-readonly:samples.weather_stations')) - # [END model_bigqueryio_read] - - # [START model_bigqueryio_query] - p = beam.Pipeline(options=PipelineOptions()) - weather_data = p | 'ReadYearAndTemp' >> beam.io.Read( - beam.io.BigQuerySource( - query='SELECT year, mean_temp FROM samples.weather_stations')) - # [END model_bigqueryio_query] - - # [START model_bigqueryio_query_standard_sql] - p = beam.Pipeline(options=PipelineOptions()) - weather_data = p | 'ReadYearAndTemp' >> beam.io.Read( - beam.io.BigQuerySource( - query='SELECT year, mean_temp FROM `samples.weather_stations`', +def model_bigqueryio(p, write_project='', write_dataset='', write_table=''): + """Using a Read and Write transform to read/write from/to BigQuery.""" + + # [START model_bigqueryio_table_spec] + # project-id:dataset_id.table_id + table_spec = 'clouddataflow-readonly:samples.weather_stations' + 
# [END model_bigqueryio_table_spec] + + # [START model_bigqueryio_table_spec_without_project] + # dataset_id.table_id + table_spec = 'samples.weather_stations' + # [END model_bigqueryio_table_spec_without_project] + + # [START model_bigqueryio_table_spec_object] + from apache_beam.io.gcp.internal.clients import bigquery + + table_spec = bigquery.TableReference( + projectId='clouddataflow-readonly', + datasetId='samples', + tableId='weather_stations') + # [END model_bigqueryio_table_spec_object] + + # [START model_bigqueryio_read_table] + max_temperatures = ( + p + | 'ReadTable' >> beam.io.Read(beam.io.BigQuerySource(table_spec)) + # Each row is a dictionary where the keys are the BigQuery columns + | beam.Map(lambda elem: elem['max_temperature'])) + # [END model_bigqueryio_read_table] + + # [START model_bigqueryio_read_query] + max_temperatures = ( + p + | 'QueryTable' >> beam.io.Read(beam.io.BigQuerySource( + query='SELECT max_temperature FROM '\ + '[clouddataflow-readonly:samples.weather_stations]')) + # Each row is a dictionary where the keys are the BigQuery columns + | beam.Map(lambda elem: elem['max_temperature'])) + # [END model_bigqueryio_read_query] + + # [START model_bigqueryio_read_query_std_sql] + max_temperatures = ( + p + | 'QueryTableStdSQL' >> beam.io.Read(beam.io.BigQuerySource( + query='SELECT max_temperature FROM '\ + '`clouddataflow-readonly.samples.weather_stations`', use_standard_sql=True)) - # [END model_bigqueryio_query_standard_sql] + # Each row is a dictionary where the keys are the BigQuery columns + | beam.Map(lambda elem: elem['max_temperature'])) + # [END model_bigqueryio_read_query_std_sql] # [START model_bigqueryio_schema] - schema = 'source:STRING, quote:STRING' + # column_name:BIGQUERY_TYPE, ... + table_schema = 'source:STRING, quote:STRING' # [END model_bigqueryio_schema] + # [START model_bigqueryio_schema_object] + from apache_beam.io.gcp.internal.clients import bigquery + + table_schema = bigquery.TableSchema() + + source_field = bigquery.TableFieldSchema() + source_field.name = 'source' + source_field.type = 'STRING' + source_field.mode = 'NULLABLE' + table_schema.fields.append(source_field) + + quote_field = bigquery.TableFieldSchema() + quote_field.name = 'quote' + quote_field.type = 'STRING' + quote_field.mode = 'REQUIRED' + table_schema.fields.append(quote_field) + # [END model_bigqueryio_schema_object] + + if write_project and write_dataset and write_table: + table_spec = '{}:{}.{}'.format(write_project, write_dataset, write_table) + + # [START model_bigqueryio_write_input] + quotes = p | beam.Create([ + {'source': 'Mahatma Ghandi', 'quote': 'My life is my message.'}, + {'source': 'Yoda', 'quote': "Do, or do not. 
There is no 'try'."}, + ]) + # [END model_bigqueryio_write_input] + # [START model_bigqueryio_write] - quotes = p | beam.Create( - [{'source': 'Mahatma Ghandi', 'quote': 'My life is my message.'}]) - quotes | 'Write' >> beam.io.Write( - beam.io.BigQuerySink( - 'my-project:output.output_table', - schema=schema, - write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE, - create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED)) + quotes | beam.io.WriteToBigQuery( + table_spec, + schema=table_schema, + write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED) # [END model_bigqueryio_write] @@ -1116,7 +1168,7 @@ def partition_fn(student, num_partitions): fortieth_percentile = by_decile[4] # [END model_multiple_pcollections_partition_40th] - ([by_decile[d] for d in xrange(10) if d != 4] + [fortieth_percentile] + ([by_decile[d] for d in range(10) if d != 4] + [fortieth_percentile] | beam.Flatten() | beam.io.WriteToText(output_path)) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index e731236c42e1..19d77d9c4a97 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -33,6 +33,7 @@ from apache_beam.examples.snippets import snippets from apache_beam.metrics import Metrics from apache_beam.metrics.metric import MetricsFilter +from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that @@ -308,10 +309,14 @@ def expand(self, pcoll): beam.typehints.Tuple[int, int]) def test_runtime_checks_off(self): + # We do not run the following pipeline, as it has incorrect type + # information, and may fail with obscure errors, depending on the runner + # implementation. + # pylint: disable=expression-not-assigned - with TestPipeline() as p: - # [START type_hints_runtime_off] - p | beam.Create(['a']) | beam.Map(lambda x: 3).with_output_types(str) + p = TestPipeline() + # [START type_hints_runtime_off] + p | beam.Create(['a']) | beam.Map(lambda x: 3).with_output_types(str) # [END type_hints_runtime_off] def test_runtime_checks_on(self): @@ -328,7 +333,7 @@ def test_deterministic_key(self): lines = (p | beam.Create( ['banana,fruit,3', 'kiwi,fruit,2', 'kiwi,fruit,2', 'zucchini,veg,3'])) - # For pickling + # For pickling. global Player # pylint: disable=global-variable-not-assigned # [START type_hints_deterministic_key] @@ -454,7 +459,7 @@ def setUp(self): def tearDown(self): beam.io.ReadFromText = self.old_read_from_text beam.io.WriteToText = self.old_write_to_text - # Cleanup all the temporary files created in the test + # Cleanup all the temporary files created in the test. map(os.remove, self.temp_files) def create_temp_file(self, contents=''): @@ -592,7 +597,7 @@ def test_model_textio_compressed(self): @unittest.skipIf(datastore_pb2 is None, 'GCP dependencies are not installed') def test_model_datastoreio(self): - # We cannot test datastoreio functionality in unit tests therefore we limit + # We cannot test DatastoreIO functionality in unit tests, therefore we limit # ourselves to making sure the pipeline containing Datastore read and write # transforms can be built. # TODO(vikasrk): Expore using Datastore Emulator. 
@@ -600,10 +605,26 @@ def test_model_datastoreio(self): @unittest.skipIf(base_api is None, 'GCP dependencies are not installed') def test_model_bigqueryio(self): - # We cannot test BigQueryIO functionality in unit tests therefore we limit + # We cannot test BigQueryIO functionality in unit tests, therefore we limit # ourselves to making sure the pipeline containing BigQuery sources and # sinks can be built. - snippets.model_bigqueryio() + # + # To run locally, set `run_locally` to `True`. You will also have to set + # `project`, `dataset` and `table` to the BigQuery table the test will write + # to. + run_locally = False + if run_locally: + project = 'my-project' + dataset = 'samples' # this must already exist + table = 'model_bigqueryio' # this will be created if needed + + options = PipelineOptions().view_as(GoogleCloudOptions) + options.project = project + with beam.Pipeline(options=options) as p: + snippets.model_bigqueryio(p, project, dataset, table) + else: + p = TestPipeline() + snippets.model_bigqueryio(p) def _run_test_pipeline_for_options(self, fn): temp_path = self.create_temp_file('aa\nbb\ncc') @@ -1027,7 +1048,7 @@ def test_composite(self): # [START model_composite_transform] class ComputeWordLengths(beam.PTransform): def expand(self, pcoll): - # transform logic goes here + # Transform logic goes here. return pcoll | beam.Map(lambda x: len(x)) # [END model_composite_transform] diff --git a/sdks/python/apache_beam/examples/wordcount_it_test.py b/sdks/python/apache_beam/examples/wordcount_it_test.py index 8532f49e186b..fe42a4fa5f01 100644 --- a/sdks/python/apache_beam/examples/wordcount_it_test.py +++ b/sdks/python/apache_beam/examples/wordcount_it_test.py @@ -65,7 +65,7 @@ def test_wordcount_it(self): # and start pipeline job by calling pipeline main function. wordcount.run(test_pipeline.get_full_options_as_args(**extra_opts)) - @attr('IT') + @attr('IT', 'ValidatesContainer') def test_wordcount_fnapi_it(self): test_pipeline = TestPipeline(is_integration_test=True) diff --git a/sdks/python/apache_beam/internal/gcp/auth.py b/sdks/python/apache_beam/internal/gcp/auth.py index 8478e1b475c0..4c61775a7a69 100644 --- a/sdks/python/apache_beam/internal/gcp/auth.py +++ b/sdks/python/apache_beam/internal/gcp/auth.py @@ -21,12 +21,13 @@ import json import logging import os -import urllib2 from oauth2client.client import GoogleCredentials from oauth2client.client import OAuth2Credentials from apache_beam.utils import retry +from six.moves.urllib.request import Request +from six.moves.urllib.request import urlopen # When we are running in GCE, we can authenticate with VM credentials. 
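
auth.py above now goes through six.moves.urllib.request, which resolves to urllib2 on Python 2 and urllib.request on Python 3. A small illustrative sketch of that indirection with the same request shape; the URL and usage here are placeholders, not the GCE metadata endpoint.

import json
from six.moves.urllib.request import Request, urlopen

def fetch_json(url, headers=None):
    # Request/urlopen come from urllib2 (Py2) or urllib.request (Py3).
    req = Request(url, headers=headers or {})
    return json.loads(urlopen(req).read().decode('utf-8'))

# Example (placeholder URL):
# fetch_json('http://example.com/token.json', headers={'Metadata-Flavor': 'Google'})
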
is_running_in_gce = False @@ -89,8 +90,8 @@ def _refresh(self, http_request): 'GCE_METADATA_ROOT', 'metadata.google.internal') token_url = ('http://{}/computeMetadata/v1/instance/service-accounts/' 'default/token').format(metadata_root) - req = urllib2.Request(token_url, headers={'Metadata-Flavor': 'Google'}) - token_data = json.loads(urllib2.urlopen(req).read()) + req = Request(token_url, headers={'Metadata-Flavor': 'Google'}) + token_data = json.loads(urlopen(req).read()) self.access_token = token_data['access_token'] self.token_expiry = (refresh_time + datetime.timedelta(seconds=token_data['expires_in'])) diff --git a/sdks/python/apache_beam/internal/gcp/json_value.py b/sdks/python/apache_beam/internal/gcp/json_value.py index 7a5dc543eefa..c4f3d7ba4da6 100644 --- a/sdks/python/apache_beam/internal/gcp/json_value.py +++ b/sdks/python/apache_beam/internal/gcp/json_value.py @@ -25,6 +25,8 @@ extra_types = None # pylint: enable=wrong-import-order, wrong-import-position +import six + from apache_beam.options.value_provider import ValueProvider _MAXINT64 = (1 << 63) - 1 @@ -47,7 +49,7 @@ def get_typed_value_descriptor(obj): ~exceptions.TypeError: if the Python object has a type that is not supported. """ - if isinstance(obj, basestring): + if isinstance(obj, six.string_types): type_name = 'Text' elif isinstance(obj, bool): type_name = 'Boolean' @@ -92,20 +94,18 @@ def to_json_value(obj, with_type=False): entries=[to_json_value(o, with_type=with_type) for o in obj])) elif isinstance(obj, dict): json_object = extra_types.JsonObject() - for k, v in obj.iteritems(): + for k, v in obj.items(): json_object.properties.append( extra_types.JsonObject.Property( key=k, value=to_json_value(v, with_type=with_type))) return extra_types.JsonValue(object_value=json_object) elif with_type: return to_json_value(get_typed_value_descriptor(obj), with_type=False) - elif isinstance(obj, basestring): + elif isinstance(obj, six.string_types): return extra_types.JsonValue(string_value=obj) elif isinstance(obj, bool): return extra_types.JsonValue(boolean_value=obj) - elif isinstance(obj, int): - return extra_types.JsonValue(integer_value=obj) - elif isinstance(obj, long): + elif isinstance(obj, six.integer_types): if _MININT64 <= obj <= _MAXINT64: return extra_types.JsonValue(integer_value=obj) else: diff --git a/sdks/python/apache_beam/internal/gcp/json_value_test.py b/sdks/python/apache_beam/internal/gcp/json_value_test.py index 14223f11c786..c22d067beea0 100644 --- a/sdks/python/apache_beam/internal/gcp/json_value_test.py +++ b/sdks/python/apache_beam/internal/gcp/json_value_test.py @@ -89,14 +89,14 @@ def test_none_from(self): def test_large_integer(self): num = 1 << 35 self.assertEquals(num, from_json_value(to_json_value(num))) - self.assertEquals(long(num), from_json_value(to_json_value(long(num)))) def test_long_value(self): - self.assertEquals(long(27), from_json_value(to_json_value(long(27)))) + num = 1 << 63 - 1 + self.assertEquals(num, from_json_value(to_json_value(num))) def test_too_long_value(self): with self.assertRaises(TypeError): - to_json_value(long(1 << 64)) + to_json_value(1 << 64) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/internal/util.py b/sdks/python/apache_beam/internal/util.py index e4f230b8eb15..e74dd4333ac6 100644 --- a/sdks/python/apache_beam/internal/util.py +++ b/sdks/python/apache_beam/internal/util.py @@ -79,7 +79,7 @@ def swapper(value): # by sorting the entries first. This will be important when putting back # PValues. 
new_kwargs = dict((k, swapper(v)) if isinstance(v, pvalue_classes) else (k, v) - for k, v in sorted(kwargs.iteritems())) + for k, v in sorted(kwargs.items())) return (new_args, new_kwargs, pvals) @@ -104,7 +104,7 @@ def insert_values_in_args(args, kwargs, values): for arg in args] new_kwargs = dict( (k, next(v_iter)) if isinstance(v, ArgumentPlaceholder) else (k, v) - for k, v in sorted(kwargs.iteritems())) + for k, v in sorted(kwargs.items())) return (new_args, new_kwargs) diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py index ba1a49503260..5b3bbf200a14 100644 --- a/sdks/python/apache_beam/io/filebasedsink.py +++ b/sdks/python/apache_beam/io/filebasedsink.py @@ -198,11 +198,11 @@ def finalize_write(self, init_result, writer_results): destination_files.append(final_name) source_file_batch = [source_files[i:i + chunk_size] - for i in xrange(0, len(source_files), - chunk_size)] + for i in range(0, len(source_files), + chunk_size)] destination_file_batch = [destination_files[i:i + chunk_size] - for i in xrange(0, len(destination_files), - chunk_size)] + for i in range(0, len(destination_files), + chunk_size)] logging.info( 'Starting finalize_write threads with num_shards: %d, ' diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 900825043b8e..a80896c78181 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -25,7 +25,8 @@ For an example implementation of :class:`FileBasedSource` see :class:`~apache_beam.io._AvroSource`. """ -import uuid + +from six import integer_types from apache_beam.internal import pickler from apache_beam.io import concat_source @@ -38,13 +39,10 @@ from apache_beam.options.value_provider import ValueProvider from apache_beam.options.value_provider import check_accessible from apache_beam.transforms.core import DoFn -from apache_beam.transforms.core import FlatMap -from apache_beam.transforms.core import GroupByKey -from apache_beam.transforms.core import Map from apache_beam.transforms.core import ParDo from apache_beam.transforms.core import PTransform from apache_beam.transforms.display import DisplayDataItem -from apache_beam.transforms.trigger import DefaultTrigger +from apache_beam.transforms.util import Reshuffle MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25 @@ -236,11 +234,11 @@ class _SingleFileSource(iobase.BoundedSource): def __init__(self, file_based_source, file_name, start_offset, stop_offset, min_bundle_size=0, splittable=True): - if not isinstance(start_offset, (int, long)): + if not isinstance(start_offset, integer_types): raise TypeError( 'start_offset must be a number. Received: %r' % start_offset) if stop_offset != range_trackers.OffsetRangeTracker.OFFSET_INFINITY: - if not isinstance(stop_offset, (int, long)): + if not isinstance(stop_offset, integer_types): raise TypeError( 'stop_offset must be a number. Received: %r' % stop_offset) if start_offset >= stop_offset: @@ -354,25 +352,6 @@ def process(self, element, *args, **kwargs): 0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY)) -# Replace following with a generic reshard transform once -# https://issues.apache.org/jira/browse/BEAM-1872 is implemented. 
-class _Reshard(PTransform): - - def expand(self, pvalue): - keyed_pc = (pvalue - | 'AssignKey' >> Map(lambda x: (uuid.uuid4(), x))) - if keyed_pc.windowing.windowfn.is_merging(): - raise ValueError('Transform ReadAllFiles cannot be used in the presence ' - 'of merging windows') - if not isinstance(keyed_pc.windowing.triggerfn, DefaultTrigger): - raise ValueError('Transform ReadAllFiles cannot be used in the presence ' - 'of non-trivial triggers') - - return (keyed_pc | 'GroupByKey' >> GroupByKey() - # Using FlatMap below due to the possibility of key collisions. - | 'DropKey' >> FlatMap(lambda k_values: k_values[1])) - - class _ReadRange(DoFn): def __init__(self, source_from_file): @@ -432,5 +411,5 @@ def expand(self, pvalue): | 'ExpandIntoRanges' >> ParDo(_ExpandIntoRanges( self._splittable, self._compression_type, self._desired_bundle_size, self._min_bundle_size)) - | 'Reshard' >> _Reshard() + | 'Reshard' >> Reshuffle() | 'ReadRange' >> ParDo(_ReadRange(self._source_from_file))) diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py index 0999510522f5..0110c3f683c8 100644 --- a/sdks/python/apache_beam/io/filebasedsource_test.py +++ b/sdks/python/apache_beam/io/filebasedsource_test.py @@ -330,7 +330,7 @@ def test_estimate_size_with_sampling_different_sizes(self): variance = 5 sizes = [] - for _ in xrange(num_files): + for _ in range(num_files): sizes.append(int(random.uniform(base_size - variance, base_size + variance))) pattern, _ = write_pattern(sizes) @@ -452,7 +452,7 @@ def test_read_file_gzip(self): def test_read_pattern_bzip2(self): _, lines = write_data(200) splits = [0, 34, 100, 140, 164, 188, 200] - chunks = [lines[splits[i-1]:splits[i]] for i in xrange(1, len(splits))] + chunks = [lines[splits[i-1]:splits[i]] for i in range(1, len(splits))] compressed_chunks = [] for c in chunks: compressobj = bz2.BZ2Compressor() @@ -470,7 +470,7 @@ def test_read_pattern_bzip2(self): def test_read_pattern_gzip(self): _, lines = write_data(200) splits = [0, 34, 100, 140, 164, 188, 200] - chunks = [lines[splits[i-1]:splits[i]] for i in xrange(1, len(splits))] + chunks = [lines[splits[i-1]:splits[i]] for i in range(1, len(splits))] compressed_chunks = [] for c in chunks: out = cStringIO.StringIO() @@ -517,7 +517,7 @@ def test_read_auto_single_file_gzip(self): def test_read_auto_pattern(self): _, lines = write_data(200) splits = [0, 34, 100, 140, 164, 188, 200] - chunks = [lines[splits[i - 1]:splits[i]] for i in xrange(1, len(splits))] + chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))] compressed_chunks = [] for c in chunks: out = cStringIO.StringIO() @@ -536,7 +536,7 @@ def test_read_auto_pattern(self): def test_read_auto_pattern_compressed_and_uncompressed(self): _, lines = write_data(200) splits = [0, 34, 100, 140, 164, 188, 200] - chunks = [lines[splits[i - 1]:splits[i]] for i in xrange(1, len(splits))] + chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))] chunks_to_write = [] for i, c in enumerate(chunks): if i%2 == 0: diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index 424462ab30c9..09739dc94454 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -26,6 +26,8 @@ import time import zlib +from six import integer_types + from apache_beam.utils.plugin import BeamPlugin logger = logging.getLogger(__name__) @@ -372,7 +374,7 @@ class FileMetadata(object): """ def __init__(self, path, 
size_in_bytes): assert isinstance(path, basestring) and path, "Path should be a string" - assert isinstance(size_in_bytes, (int, long)) and size_in_bytes >= 0, \ + assert isinstance(size_in_bytes, integer_types) and size_in_bytes >= 0, \ "Invalid value for size_in_bytes should %s (of type %s)" % ( size_in_bytes, type(size_in_bytes)) self.path = path diff --git a/sdks/python/apache_beam/io/filesystemio.py b/sdks/python/apache_beam/io/filesystemio.py new file mode 100644 index 000000000000..35e141bb7566 --- /dev/null +++ b/sdks/python/apache_beam/io/filesystemio.py @@ -0,0 +1,267 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Utilities for ``FileSystem`` implementations.""" + +import abc +import io +import os + +__all__ = ['Downloader', 'Uploader', 'DownloaderStream', 'UploaderStream', + 'PipeStream'] + + +class Downloader(object): + """Download interface for a single file. + + Implementations should support random access reads. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractproperty + def size(self): + """Size of file to download.""" + + @abc.abstractmethod + def get_range(self, start, end): + """Retrieve a given byte range [start, end) from this download. + + Range must be in this form: + 0 <= start < end: Fetch the bytes from start to end. + + Args: + start: (int) Initial byte offset. + end: (int) Final byte offset, exclusive. + + Returns: + (string) A buffer containing the requested data. + """ + + +class Uploader(object): + """Upload interface for a single file.""" + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def put(self, data): + """Write data to file sequentially. + + Args: + data: (memoryview) Data to write. + """ + + @abc.abstractmethod + def finish(self): + """Signal to upload any remaining data and close the file. + + File should be fully written upon return from this method. + + Raises: + Any error encountered during the upload. + """ + + +class DownloaderStream(io.RawIOBase): + """Provides a stream interface for Downloader objects.""" + + def __init__(self, downloader, mode='r'): + """Initializes the stream. + + Args: + downloader: (Downloader) Filesystem dependent implementation. + mode: (string) Python mode attribute for this stream. + """ + self._downloader = downloader + self.mode = mode + self._position = 0 + + def readinto(self, b): + """Read up to len(b) bytes into b. + + Returns number of bytes read (0 for EOF). + + Args: + b: (bytearray/memoryview) Buffer to read into. 
+ """ + self._checkClosed() + if self._position >= self._downloader.size: + return 0 + + start = self._position + end = min(self._position + len(b), self._downloader.size) + data = self._downloader.get_range(start, end) + self._position += len(data) + b[:len(data)] = data + return len(data) + + def seek(self, offset, whence=os.SEEK_SET): + """Set the stream's current offset. + + Note if the new offset is out of bound, it is adjusted to either 0 or EOF. + + Args: + offset: seek offset as number. + whence: seek mode. Supported modes are os.SEEK_SET (absolute seek), + os.SEEK_CUR (seek relative to the current position), and os.SEEK_END + (seek relative to the end, offset should be negative). + + Raises: + ``ValueError``: When this stream is closed or if whence is invalid. + """ + self._checkClosed() + + if whence == os.SEEK_SET: + self._position = offset + elif whence == os.SEEK_CUR: + self._position += offset + elif whence == os.SEEK_END: + self._position = self._downloader.size + offset + else: + raise ValueError('Whence mode %r is invalid.' % whence) + + self._position = min(self._position, self._downloader.size) + self._position = max(self._position, 0) + return self._position + + def tell(self): + """Tell the stream's current offset. + + Returns: + current offset in reading this stream. + + Raises: + ``ValueError``: When this stream is closed. + """ + self._checkClosed() + return self._position + + def seekable(self): + return True + + def readable(self): + return True + + +class UploaderStream(io.RawIOBase): + """Provides a stream interface for Uploader objects.""" + + def __init__(self, uploader, mode='w'): + """Initializes the stream. + + Args: + uploader: (Uploader) Filesystem dependent implementation. + mode: (string) Python mode attribute for this stream. + """ + self._uploader = uploader + self.mode = mode + self._position = 0 + + def tell(self): + return self._position + + def write(self, b): + """Write bytes from b. + + Returns number of bytes written (<= len(b)). + + Args: + b: (memoryview) Buffer with data to write. + """ + self._checkClosed() + self._uploader.put(b) + + bytes_written = len(b) + self._position += bytes_written + return bytes_written + + def close(self): + """Complete the upload and close this stream. + + This method has no effect if the stream is already closed. + + Raises: + Any error encountered by the uploader. + """ + if not self.closed: + self._uploader.finish() + + super(UploaderStream, self).close() + + def writable(self): + return True + + +class PipeStream(object): + """A class that presents a pipe connection as a readable stream.""" + + def __init__(self, recv_pipe): + self.conn = recv_pipe + self.closed = False + self.position = 0 + self.remaining = '' + + def read(self, size): + """Read data from the wrapped pipe connection. + + Args: + size: Number of bytes to read. Actual number of bytes read is always + equal to size unless EOF is reached. + + Returns: + data read as str. + """ + data_list = [] + bytes_read = 0 + while bytes_read < size: + bytes_from_remaining = min(size - bytes_read, len(self.remaining)) + data_list.append(self.remaining[0:bytes_from_remaining]) + self.remaining = self.remaining[bytes_from_remaining:] + self.position += bytes_from_remaining + bytes_read += bytes_from_remaining + if not self.remaining: + try: + self.remaining = self.conn.recv_bytes() + except EOFError: + break + return ''.join(data_list) + + def tell(self): + """Tell the file's current offset. + + Returns: + current offset in reading this file. 
+ + Raises: + ``ValueError``: When this stream is closed. + """ + self._check_open() + return self.position + + def seek(self, offset, whence=os.SEEK_SET): + # The apitools library used by the gcsio.Uploader class insists on seeking + # to the end of a stream to do a check before completing an upload, so we + # must have this no-op method here in that case. + if whence == os.SEEK_END and offset == 0: + return + elif whence == os.SEEK_SET and offset == self.position: + return + raise NotImplementedError + + def _check_open(self): + if self.closed: + raise IOError('Stream is closed.') diff --git a/sdks/python/apache_beam/io/filesystemio_test.py b/sdks/python/apache_beam/io/filesystemio_test.py new file mode 100644 index 000000000000..2f1de9dcedac --- /dev/null +++ b/sdks/python/apache_beam/io/filesystemio_test.py @@ -0,0 +1,185 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for filesystemio.""" + +import io +import multiprocessing +import os +import threading +import unittest + +from apache_beam.io import filesystemio + + +class FakeDownloader(filesystemio.Downloader): + + def __init__(self, data): + self._data = data + self.last_read_size = -1 + + @property + def size(self): + return len(self._data) + + def get_range(self, start, end): + self.last_read_size = end - start + return self._data[start:end] + + +class FakeUploader(filesystemio.Uploader): + + def __init__(self): + self.data = '' + self.last_write_size = -1 + self.finished = False + + def last_error(self): + return None + + def put(self, data): + assert not self.finished + self.data += data.tobytes() + self.last_write_size = len(data) + + def finish(self): + self.finished = True + + +class TestDownloaderStream(unittest.TestCase): + + def test_file_attributes(self): + downloader = FakeDownloader(data=None) + stream = filesystemio.DownloaderStream(downloader) + self.assertEqual(stream.mode, 'r') + self.assertTrue(stream.readable()) + self.assertFalse(stream.writable()) + self.assertTrue(stream.seekable()) + + def test_read_empty(self): + downloader = FakeDownloader(data='') + stream = filesystemio.DownloaderStream(downloader) + self.assertEqual(stream.read(), '') + + def test_read(self): + data = 'abcde' + downloader = FakeDownloader(data) + stream = filesystemio.DownloaderStream(downloader) + + # Read size is exactly what was passed to read() (unbuffered). 
+ self.assertEqual(stream.read(1), data[0]) + self.assertEqual(downloader.last_read_size, 1) + self.assertEqual(stream.read(), data[1:]) + self.assertEqual(downloader.last_read_size, len(data) - 1) + + def test_read_buffered(self): + data = 'abcde' + downloader = FakeDownloader(data) + buffer_size = 2 + stream = io.BufferedReader(filesystemio.DownloaderStream(downloader), + buffer_size) + + # Verify that buffering works and is reading ahead. + self.assertEqual(stream.read(1), data[0]) + self.assertEqual(downloader.last_read_size, buffer_size) + self.assertEqual(stream.read(), data[1:]) + + +class TestUploaderStream(unittest.TestCase): + + def test_file_attributes(self): + uploader = FakeUploader() + stream = filesystemio.UploaderStream(uploader) + self.assertEqual(stream.mode, 'w') + self.assertFalse(stream.readable()) + self.assertTrue(stream.writable()) + self.assertFalse(stream.seekable()) + + def test_write_empty(self): + uploader = FakeUploader() + stream = filesystemio.UploaderStream(uploader) + data = '' + stream.write(memoryview(data)) + self.assertEqual(uploader.data, data) + + def test_write(self): + data = 'abcde' + uploader = FakeUploader() + stream = filesystemio.UploaderStream(uploader) + + # Unbuffered writes. + stream.write(memoryview(data[0])) + self.assertEqual(uploader.data[0], data[0]) + self.assertEqual(uploader.last_write_size, 1) + stream.write(memoryview(data[1:])) + self.assertEqual(uploader.data, data) + self.assertEqual(uploader.last_write_size, len(data) - 1) + + def test_write_buffered(self): + data = 'abcde' + uploader = FakeUploader() + buffer_size = 2 + stream = io.BufferedWriter(filesystemio.UploaderStream(uploader), + buffer_size) + + # Verify that buffering works: doesn't write to uploader until buffer is + # filled. + stream.write(data[0]) + self.assertEqual(-1, uploader.last_write_size) + stream.write(data[1:]) + stream.close() + self.assertEqual(data, uploader.data) + + +class TestPipeStream(unittest.TestCase): + + def _read_and_verify(self, stream, expected, buffer_size): + data_list = [] + bytes_read = 0 + seen_last_block = False + while True: + data = stream.read(buffer_size) + self.assertLessEqual(len(data), buffer_size) + if len(data) < buffer_size: + # Test the constraint that the pipe stream returns less than the buffer + # size only when at the end of the stream. + if data: + self.assertFalse(seen_last_block) + seen_last_block = True + if not data: + break + data_list.append(data) + bytes_read += len(data) + self.assertEqual(stream.tell(), bytes_read) + self.assertEqual(''.join(data_list), expected) + + def test_pipe_stream(self): + block_sizes = list(4**i for i in range(0, 12)) + data_blocks = list(os.urandom(size) for size in block_sizes) + expected = ''.join(data_blocks) + + buffer_sizes = [100001, 512 * 1024, 1024 * 1024] + + for buffer_size in buffer_sizes: + parent_conn, child_conn = multiprocessing.Pipe() + stream = filesystemio.PipeStream(child_conn) + child_thread = threading.Thread( + target=self._read_and_verify, args=(stream, expected, buffer_size)) + child_thread.start() + for data in data_blocks: + parent_conn.send_bytes(data) + parent_conn.close() + child_thread.join() diff --git a/sdks/python/apache_beam/io/filesystems.py b/sdks/python/apache_beam/io/filesystems.py index dad4e5f9f27d..6150907631da 100644 --- a/sdks/python/apache_beam/io/filesystems.py +++ b/sdks/python/apache_beam/io/filesystems.py @@ -51,7 +51,7 @@ def set_options(cls, pipeline_options): Args: pipeline_options: Instance of ``PipelineOptions``. 
""" - cls._options = pipeline_options + cls._pipeline_options = pipeline_options @staticmethod def get_scheme(path): diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 9afb75ec035d..3bdf2e64ca2e 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -23,6 +23,7 @@ import cStringIO import errno import fnmatch +import io import logging import multiprocessing import os @@ -33,6 +34,11 @@ import httplib2 +from apache_beam.io.filesystemio import Downloader +from apache_beam.io.filesystemio import DownloaderStream +from apache_beam.io.filesystemio import PipeStream +from apache_beam.io.filesystemio import Uploader +from apache_beam.io.filesystemio import UploaderStream from apache_beam.utils import retry __all__ = ['GcsIO'] @@ -188,11 +194,14 @@ def open(self, ~exceptions.ValueError: Invalid open file mode. """ if mode == 'r' or mode == 'rb': - return GcsBufferedReader(self.client, filename, mode=mode, + downloader = GcsDownloader(self.client, filename, + buffer_size=read_buffer_size) + return io.BufferedReader(DownloaderStream(downloader, mode=mode), buffer_size=read_buffer_size) elif mode == 'w' or mode == 'wb': - return GcsBufferedWriter(self.client, filename, mode=mode, - mime_type=mime_type) + uploader = GcsUploader(self.client, filename, mime_type) + return io.BufferedWriter(UploaderStream(uploader, mode=mode), + buffer_size=128 * 1024) else: raise ValueError('Invalid file open mode: %s.' % mode) @@ -456,364 +465,77 @@ def size_of_files_in_glob(self, pattern, limit=None): return file_sizes -# TODO: Consider using cStringIO instead of buffers and data_lists when reading. -class GcsBufferedReader(object): - """A class for reading Google Cloud Storage files.""" - - def __init__(self, - client, - path, - mode='r', - buffer_size=DEFAULT_READ_BUFFER_SIZE): - self.client = client - self.path = path - self.bucket, self.name = parse_gcs_path(path) - self.mode = mode - self.buffer_size = buffer_size +class GcsDownloader(Downloader): + def __init__(self, client, path, buffer_size): + self._client = client + self._path = path + self._bucket, self._name = parse_gcs_path(path) + self._buffer_size = buffer_size # Get object state. - self.get_request = (storage.StorageObjectsGetRequest( - bucket=self.bucket, object=self.name)) + self._get_request = (storage.StorageObjectsGetRequest( + bucket=self._bucket, object=self._name)) try: - metadata = self._get_object_metadata(self.get_request) + metadata = self._get_object_metadata(self._get_request) except HttpError as http_error: if http_error.status_code == 404: - raise IOError(errno.ENOENT, 'Not found: %s' % self.path) + raise IOError(errno.ENOENT, 'Not found: %s' % self._path) else: - logging.error('HTTP error while requesting file %s: %s', self.path, + logging.error('HTTP error while requesting file %s: %s', self._path, http_error) raise - self.size = metadata.size + self._size = metadata.size # Ensure read is from file of the correct generation. - self.get_request.generation = metadata.generation + self._get_request.generation = metadata.generation # Initialize read buffer state. 
- self.download_stream = cStringIO.StringIO() - self.downloader = transfer.Download( - self.download_stream, auto_transfer=False, chunksize=self.buffer_size) - self.client.objects.Get(self.get_request, download=self.downloader) - self.position = 0 - self.buffer = '' - self.buffer_start_position = 0 - self.closed = False + self._download_stream = cStringIO.StringIO() + self._downloader = transfer.Download( + self._download_stream, auto_transfer=False, chunksize=self._buffer_size) + self._client.objects.Get(self._get_request, download=self._downloader) @retry.with_exponential_backoff( retry_filter=retry.retry_on_server_errors_and_timeout_filter) def _get_object_metadata(self, get_request): - return self.client.objects.Get(get_request) - - def __iter__(self): - return self - - def __next__(self): - """Read one line delimited by '\\n' from the file. - """ - return next(self) - - def next(self): - """Read one line delimited by '\\n' from the file. - """ - line = self.readline() - if not line: - raise StopIteration - return line - - def read(self, size=-1): - """Read data from a GCS file. + return self._client.objects.Get(get_request) - Args: - size: Number of bytes to read. Actual number of bytes read is always - equal to size unless EOF is reached. If size is negative or - unspecified, read the entire file. - - Returns: - data read as str. - - Raises: - IOError: When this buffer is closed. - """ - return self._read_inner(size=size, readline=False) + @property + def size(self): + return self._size - def readline(self, size=-1): - """Read one line delimited by '\\n' from the file. + def get_range(self, start, end): + self._download_stream.truncate(0) + self._downloader.GetRange(start, end - 1) + return self._download_stream.getvalue() - Mimics behavior of the readline() method on standard file objects. - A trailing newline character is kept in the string. It may be absent when a - file ends with an incomplete line. If the size argument is non-negative, - it specifies the maximum string size (counting the newline) to return. - A negative size is the same as unspecified. Empty string is returned - only when EOF is encountered immediately. +class GcsUploader(Uploader): + def __init__(self, client, path, mime_type): + self._client = client + self._path = path + self._bucket, self._name = parse_gcs_path(path) + self._mime_type = mime_type - Args: - size: Maximum number of bytes to read. If not specified, readline stops - only on '\\n' or EOF. - - Returns: - The data read as a string. - - Raises: - IOError: When this buffer is closed. - """ - return self._read_inner(size=size, readline=True) - - def _read_inner(self, size=-1, readline=False): - """Shared implementation of read() and readline().""" - self._check_open() - if not self._remaining(): - return '' - - # Prepare to read. - data_list = [] - if size is None: - size = -1 - to_read = min(size, self._remaining()) - if to_read < 0: - to_read = self._remaining() - break_after = False - - while to_read > 0: - # If we have exhausted the buffer, get the next segment. - # TODO(ccy): We should consider prefetching the next block in another - # thread. - self._fetch_next_if_buffer_exhausted() - - # Determine number of bytes to read from buffer. - buffer_bytes_read = self.position - self.buffer_start_position - bytes_to_read_from_buffer = min( - len(self.buffer) - buffer_bytes_read, to_read) - - # If readline is set, we only want to read up to and including the next - # newline character. 
- if readline: - next_newline_position = self.buffer.find('\n', buffer_bytes_read, - len(self.buffer)) - if next_newline_position != -1: - bytes_to_read_from_buffer = ( - 1 + next_newline_position - buffer_bytes_read) - break_after = True - - # Read bytes. - data_list.append(self.buffer[buffer_bytes_read:buffer_bytes_read + - bytes_to_read_from_buffer]) - self.position += bytes_to_read_from_buffer - to_read -= bytes_to_read_from_buffer - - if break_after: - break - - return ''.join(data_list) - - def _fetch_next_if_buffer_exhausted(self): - if not self.buffer or ( - self.buffer_start_position + len(self.buffer) <= self.position): - bytes_to_request = min(self._remaining(), self.buffer_size) - self.buffer_start_position = self.position - try: - result = self._get_segment(self.position, bytes_to_request) - except Exception as e: # pylint: disable=broad-except - tb = traceback.format_exc() - logging.error( - ('Exception while fetching %d bytes from position %d of %s: ' - '%s\n%s'), - bytes_to_request, self.position, self.path, e, tb) - raise - - self.buffer = result - return - - def _remaining(self): - return self.size - self.position - - def close(self): - """Close the current GCS file.""" - self.closed = True - self.download_stream = None - self.downloader = None - self.buffer = None - - def _get_segment(self, start, size): - """Get the given segment of the current GCS file.""" - if size == 0: - return '' - # The objects self.downloader and self.download_stream may be recreated if - # this call times out; we save them locally to avoid any threading issues. - downloader = self.downloader - download_stream = self.download_stream - end = start + size - 1 - downloader.GetRange(start, end) - value = download_stream.getvalue() - # Clear the cStringIO object after we've read its contents. - download_stream.truncate(0) - assert len(value) == size - return value - - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - self.close() - - def seek(self, offset, whence=os.SEEK_SET): - """Set the file's current offset. - - Note if the new offset is out of bound, it is adjusted to either 0 or EOF. - - Args: - offset: seek offset as number. - whence: seek mode. Supported modes are os.SEEK_SET (absolute seek), - os.SEEK_CUR (seek relative to the current position), and os.SEEK_END - (seek relative to the end, offset should be negative). - - Raises: - IOError: When this buffer is closed. - ValueError: When whence is invalid. - """ - self._check_open() - - self.buffer = '' - self.buffer_start_position = -1 - - if whence == os.SEEK_SET: - self.position = offset - elif whence == os.SEEK_CUR: - self.position += offset - elif whence == os.SEEK_END: - self.position = self.size + offset - else: - raise ValueError('Whence mode %r is invalid.' % whence) - - self.position = min(self.position, self.size) - self.position = max(self.position, 0) - - def tell(self): - """Tell the file's current offset. - - Returns: - current offset in reading this file. - - Raises: - IOError: When this buffer is closed. - """ - self._check_open() - return self.position - - def _check_open(self): - if self.closed: - raise IOError('Buffer is closed.') - - def seekable(self): - return True - - def readable(self): - return True - - def writable(self): - return False - - -# TODO: Consider using cStringIO instead of buffers and data_lists when reading -# and writing. 
-class GcsBufferedWriter(object): - """A class for writing Google Cloud Storage files.""" - - class PipeStream(object): - """A class that presents a pipe connection as a readable stream.""" - - def __init__(self, recv_pipe): - self.conn = recv_pipe - self.closed = False - self.position = 0 - self.remaining = '' - - def read(self, size): - """Read data from the wrapped pipe connection. - - Args: - size: Number of bytes to read. Actual number of bytes read is always - equal to size unless EOF is reached. - - Returns: - data read as str. - """ - data_list = [] - bytes_read = 0 - while bytes_read < size: - bytes_from_remaining = min(size - bytes_read, len(self.remaining)) - data_list.append(self.remaining[0:bytes_from_remaining]) - self.remaining = self.remaining[bytes_from_remaining:] - self.position += bytes_from_remaining - bytes_read += bytes_from_remaining - if not self.remaining: - try: - self.remaining = self.conn.recv_bytes() - except EOFError: - break - return ''.join(data_list) - - def tell(self): - """Tell the file's current offset. - - Returns: - current offset in reading this file. - - Raises: - IOError: When this stream is closed. - """ - self._check_open() - return self.position - - def seek(self, offset, whence=os.SEEK_SET): - # The apitools.base.py.transfer.Upload class insists on seeking to the end - # of a stream to do a check before completing an upload, so we must have - # this no-op method here in that case. - if whence == os.SEEK_END and offset == 0: - return - elif whence == os.SEEK_SET and offset == self.position: - return - raise NotImplementedError - - def _check_open(self): - if self.closed: - raise IOError('Stream is closed.') - - def __init__(self, - client, - path, - mode='w', - mime_type='application/octet-stream'): - self.client = client - self.path = path - self.mode = mode - self.bucket, self.name = parse_gcs_path(path) - - self.closed = False - self.position = 0 - - # A small buffer to avoid CPU-heavy per-write pipe calls. - self.write_buffer = bytearray() - self.write_buffer_size = 128 * 1024 - - # Set up communication with uploading thread. + # Set up communication with child thread. parent_conn, child_conn = multiprocessing.Pipe() - self.child_conn = child_conn - self.conn = parent_conn + self._child_conn = child_conn + self._conn = parent_conn # Set up uploader. - self.insert_request = (storage.StorageObjectsInsertRequest( - bucket=self.bucket, name=self.name)) - self.upload = transfer.Upload( - GcsBufferedWriter.PipeStream(child_conn), - mime_type, + self._insert_request = (storage.StorageObjectsInsertRequest( + bucket=self._bucket, name=self._name)) + self._upload = transfer.Upload( + PipeStream(self._child_conn), + self._mime_type, chunksize=WRITE_CHUNK_SIZE) - self.upload.strategy = transfer.RESUMABLE_UPLOAD + self._upload.strategy = transfer.RESUMABLE_UPLOAD # Start uploading thread. - self.upload_thread = threading.Thread(target=self._start_upload) - self.upload_thread.daemon = True - self.upload_thread.last_error = None - self.upload_thread.start() + self._upload_thread = threading.Thread(target=self._start_upload) + self._upload_thread.daemon = True + self._upload_thread.last_error = None + self._upload_thread.start() # TODO(silviuc): Refactor so that retry logic can be applied. # There is retry logic in the underlying transfer library but we should make @@ -827,79 +549,27 @@ def _start_upload(self): # The uploader by default transfers data in chunks of 1024 * 1024 bytes at # a time, buffering writes until that size is reached. 
try: - self.client.objects.Insert(self.insert_request, upload=self.upload) + self._client.objects.Insert(self._insert_request, upload=self._upload) except Exception as e: # pylint: disable=broad-except logging.error('Error in _start_upload while inserting file %s: %s', - self.path, traceback.format_exc()) - self.upload_thread.last_error = e + self._path, traceback.format_exc()) + self._upload_thread.last_error = e finally: - self.child_conn.close() - - def write(self, data): - """Write data to a GCS file. - - Args: - data: data to write as str. + self._child_conn.close() - Raises: - IOError: When this buffer is closed. - """ - self._check_open() - if not data: - return - self.write_buffer.extend(data) - if len(self.write_buffer) > self.write_buffer_size: - self._flush_write_buffer() - self.position += len(data) - - def flush(self): - """Flushes any internal buffer to the underlying GCS file.""" - self._check_open() - self._flush_write_buffer() - - def tell(self): - """Return the total number of bytes passed to write() so far.""" - return self.position - - def close(self): - """Close the current GCS file.""" - if self.closed: - logging.warn('Channel for %s is not open.', self.path) - return - - self._flush_write_buffer() - self.closed = True - self.conn.close() - self.upload_thread.join() - # Check for exception since the last _flush_write_buffer() call. - if self.upload_thread.last_error: - raise self.upload_thread.last_error # pylint: disable=raising-bad-type - - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - self.close() - - def _check_open(self): - if self.closed: - raise IOError('Buffer is closed.') - - def seekable(self): - return False - - def readable(self): - return False - - def writable(self): - return True - - def _flush_write_buffer(self): + def put(self, data): try: - self.conn.send_bytes(buffer(self.write_buffer)) - self.write_buffer = bytearray() - except IOError: - if self.upload_thread.last_error: - raise self.upload_thread.last_error # pylint: disable=raising-bad-type - else: - raise + self._conn.send_bytes(data.tobytes()) + except EOFError: + if self._upload_thread.last_error is not None: + raise self._upload_thread.last_error # pylint: disable=raising-bad-type + raise + + def finish(self): + self._conn.close() + # TODO(udim): Add timeout=DEFAULT_HTTP_TIMEOUT_SECONDS * 2 and raise if + # isAlive is True. + self._upload_thread.join() + # Check for exception since the last put() call. 
+ if self._upload_thread.last_error is not None: + raise self._upload_thread.last_error # pylint: disable=raising-bad-type diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 6994c523032a..63db8d56cb60 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -18,10 +18,8 @@ import errno import logging -import multiprocessing import os import random -import threading import unittest import httplib2 @@ -97,7 +95,9 @@ def Get(self, get_request, download=None): # pylint: disable=invalid-name stream = download.stream def get_range_callback(start, end): - assert start >= 0 and end >= start and end < len(f.contents) + if not (start >= 0 and end >= start and end < len(f.contents)): + raise ValueError( + 'start=%d end=%d len=%s' % (start, end, len(f.contents))) stream.write(f.contents[start:end + 1]) download.GetRange = get_range_callback @@ -769,48 +769,6 @@ def test_size_of_files_in_glob_limited(self): len(self.gcs.glob(file_pattern, limit)), expected_num_items) -@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') -class TestPipeStream(unittest.TestCase): - - def _read_and_verify(self, stream, expected, buffer_size): - data_list = [] - bytes_read = 0 - seen_last_block = False - while True: - data = stream.read(buffer_size) - self.assertLessEqual(len(data), buffer_size) - if len(data) < buffer_size: - # Test the constraint that the pipe stream returns less than the buffer - # size only when at the end of the stream. - if data: - self.assertFalse(seen_last_block) - seen_last_block = True - if not data: - break - data_list.append(data) - bytes_read += len(data) - self.assertEqual(stream.tell(), bytes_read) - self.assertEqual(''.join(data_list), expected) - - def test_pipe_stream(self): - block_sizes = list(4**i for i in range(0, 12)) - data_blocks = list(os.urandom(size) for size in block_sizes) - expected = ''.join(data_blocks) - - buffer_sizes = [100001, 512 * 1024, 1024 * 1024] - - for buffer_size in buffer_sizes: - parent_conn, child_conn = multiprocessing.Pipe() - stream = gcsio.GcsBufferedWriter.PipeStream(child_conn) - child_thread = threading.Thread( - target=self._read_and_verify, args=(stream, expected, buffer_size)) - child_thread.start() - for data in data_blocks: - parent_conn.send_bytes(data) - parent_conn.close() - child_thread.join() - - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 0c4ba02db87a..8bd9fa4f41aa 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -22,10 +22,13 @@ import hamcrest as hc +import apache_beam as beam from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub from apache_beam.io.gcp.pubsub import WriteStringsToPubSub from apache_beam.io.gcp.pubsub import _PubSubPayloadSink from apache_beam.io.gcp.pubsub import _PubSubPayloadSource +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.direct.direct_runner import _get_transform_overrides from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher @@ -40,28 +43,51 @@ @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') -class TestReadStringsFromPubSub(unittest.TestCase): +class 
TestReadStringsFromPubSubOverride(unittest.TestCase): def test_expand_with_topic(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic', - None, 'a_label') - # Ensure that the output type is str + p.options.view_as(StandardOptions).streaming = True + pcoll = (p + | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic', + None, 'a_label') + | beam.Map(lambda x: x)) + # Ensure that the output type is str. self.assertEqual(unicode, pcoll.element_type) + # Apply the necessary PTransformOverrides. + overrides = _get_transform_overrides(p.options) + p.replace_all(overrides) + + # Note that the direct output of ReadStringsFromPubSub will be replaced + # by a PTransformOverride, so we use a no-op Map. + read_transform = pcoll.producer.inputs[0].producer.transform + # Ensure that the properties passed through correctly - source = pcoll.producer.transform._source + source = read_transform._source self.assertEqual('a_topic', source.topic_name) self.assertEqual('a_label', source.id_label) def test_expand_with_subscription(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub( - None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label') + p.options.view_as(StandardOptions).streaming = True + pcoll = (p + | ReadStringsFromPubSub( + None, 'projects/fakeprj/subscriptions/a_subscription', + 'a_label') + | beam.Map(lambda x: x)) # Ensure that the output type is str self.assertEqual(unicode, pcoll.element_type) + # Apply the necessary PTransformOverrides. + overrides = _get_transform_overrides(p.options) + p.replace_all(overrides) + + # Note that the direct output of ReadStringsFromPubSub will be replaced + # by a PTransformOverride, so we use a no-op Map. + read_transform = pcoll.producer.inputs[0].producer.transform + # Ensure that the properties passed through correctly - source = pcoll.producer.transform._source + source = read_transform._source self.assertEqual('a_subscription', source.subscription_name) self.assertEqual('a_label', source.id_label) @@ -80,12 +106,22 @@ def test_expand_with_both_topic_and_subscription(self): class TestWriteStringsToPubSub(unittest.TestCase): def test_expand(self): p = TestPipeline() - pdone = (p + p.options.view_as(StandardOptions).streaming = True + pcoll = (p | ReadStringsFromPubSub('projects/fakeprj/topics/baz') - | WriteStringsToPubSub('projects/fakeprj/topics/a_topic')) + | WriteStringsToPubSub('projects/fakeprj/topics/a_topic') + | beam.Map(lambda x: x)) + + # Apply the necessary PTransformOverrides. + overrides = _get_transform_overrides(p.options) + p.replace_all(overrides) + + # Note that the direct output of ReadStringsFromPubSub will be replaced + # by a PTransformOverride, so we use a no-op Map. 
+ write_transform = pcoll.producer.inputs[0].producer.transform # Ensure that the properties passed through correctly - self.assertEqual('a_topic', pdone.producer.transform.dofn.topic_name) + self.assertEqual('a_topic', write_transform.dofn.topic_name) @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') diff --git a/sdks/python/apache_beam/io/hadoopfilesystem.py b/sdks/python/apache_beam/io/hadoopfilesystem.py index a761068f413b..054d56df6431 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem.py @@ -20,28 +20,72 @@ from __future__ import absolute_import +import io import logging import posixpath import re -from hdfs3 import HDFileSystem +import hdfs +from apache_beam.io import filesystemio from apache_beam.io.filesystem import BeamIOError from apache_beam.io.filesystem import CompressedFile from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.filesystem import FileMetadata from apache_beam.io.filesystem import FileSystem from apache_beam.io.filesystem import MatchResult +from apache_beam.options.pipeline_options import HadoopFileSystemOptions __all__ = ['HadoopFileSystem'] _HDFS_PREFIX = 'hdfs:/' _URL_RE = re.compile(r'^' + _HDFS_PREFIX + r'(/.*)') _COPY_BUFFER_SIZE = 2 ** 16 +_DEFAULT_BUFFER_SIZE = 20 * 1024 * 1024 +# WebHDFS FileStatus property constants. +_FILE_STATUS_NAME = 'name' +_FILE_STATUS_PATH_SUFFIX = 'pathSuffix' +_FILE_STATUS_TYPE = 'type' +_FILE_STATUS_TYPE_DIRECTORY = 'DIRECTORY' +_FILE_STATUS_TYPE_FILE = 'FILE' +_FILE_STATUS_SIZE = 'size' -# TODO(udim): Add @retry.with_exponential_backoff to some functions, like in -# gcsio.py. + +class HdfsDownloader(filesystemio.Downloader): + + def __init__(self, hdfs_client, path): + self._hdfs_client = hdfs_client + self._path = path + self._size = self._hdfs_client.status(path)[_FILE_STATUS_SIZE] + + @property + def size(self): + return self._size + + def get_range(self, start, end): + with self._hdfs_client.read( + self._path, offset=start, length=end - start + 1) as reader: + return reader.read() + + +class HdfsUploader(filesystemio.Uploader): + + def __init__(self, hdfs_client, path): + self._hdfs_client = hdfs_client + if self._hdfs_client.status(path, strict=False) is not None: + raise BeamIOError('Path already exists: %s' % path) + + self._handle_context = self._hdfs_client.write(path) + self._handle = self._handle_context.__enter__() + + def put(self, data): + self._handle.write(data) + + def finish(self): + self._handle.__exit__(None, None, None) + self._handle = None + self._handle_context = None class HadoopFileSystem(FileSystem): @@ -49,16 +93,31 @@ class HadoopFileSystem(FileSystem): URL arguments to methods expect strings starting with ``hdfs://``. - Uses client library :class:`hdfs3.core.HDFileSystem`. + Experimental; TODO(BEAM-3600): Writes are experimental until file rename + retries are better handled. """ def __init__(self, pipeline_options): """Initializes a connection to HDFS. - Connection configuration is done using :doc:`hdfs`. + Connection configuration is done by passing pipeline options. + See :class:`~apache_beam.options.pipeline_options.HadoopFileSystemOptions`. 
""" super(HadoopFileSystem, self).__init__(pipeline_options) - self._hdfs_client = HDFileSystem() + + if pipeline_options is None: + raise ValueError('pipeline_options is not set') + hdfs_options = pipeline_options.view_as(HadoopFileSystemOptions) + if hdfs_options.hdfs_host is None: + raise ValueError('hdfs_host is not set') + if hdfs_options.hdfs_port is None: + raise ValueError('hdfs_port is not set') + if hdfs_options.hdfs_user is None: + raise ValueError('hdfs_user is not set') + self._hdfs_client = hdfs.InsecureClient( + 'http://%s:%s' % ( + hdfs_options.hdfs_host, str(hdfs_options.hdfs_port)), + user=hdfs_options.hdfs_user) @classmethod def scheme(cls): @@ -108,7 +167,7 @@ def split(self, url): def mkdirs(self, url): path = self._parse_url(url) if self._exists(path): - raise IOError('Path already exists: %s' % path) + raise BeamIOError('Path already exists: %s' % path) return self._mkdirs(path) def _mkdirs(self, path): @@ -123,12 +182,17 @@ def match(self, url_patterns, limits=None): 'Patterns and limits should be equal in length: %d != %d' % ( len(url_patterns), len(limits))) - # TODO(udim): Update client to allow batched results. def _match(path_pattern, limit): """Find all matching paths to the pattern provided.""" - file_infos = self._hdfs_client.ls(path_pattern, detail=True)[:limit] - metadata_list = [FileMetadata(file_info['name'], file_info['size']) - for file_info in file_infos] + fs = self._hdfs_client.status(path_pattern, strict=False) + if fs and fs[_FILE_STATUS_TYPE] == _FILE_STATUS_TYPE_FILE: + file_statuses = [(fs[_FILE_STATUS_PATH_SUFFIX], fs)][:limit] + else: + file_statuses = self._hdfs_client.list(path_pattern, + status=True)[:limit] + metadata_list = [FileMetadata(file_status[1][_FILE_STATUS_NAME], + file_status[1][_FILE_STATUS_SIZE]) + for file_status in file_statuses] return MatchResult(path_pattern, metadata_list) exceptions = {} @@ -144,46 +208,55 @@ def _match(path_pattern, limit): raise BeamIOError('Match operation failed', exceptions) return result - def _open_hdfs(self, path, mode, mime_type, compression_type): + @staticmethod + def _add_compression(stream, path, mime_type, compression_type): if mime_type != 'application/octet-stream': logging.warning('Mime types are not supported. Got non-default mime_type:' ' %s', mime_type) if compression_type == CompressionTypes.AUTO: compression_type = CompressionTypes.detect_compression_type(path) - res = self._hdfs_client.open(path, mode) if compression_type != CompressionTypes.UNCOMPRESSED: - res = CompressedFile(res) - return res + return CompressedFile(stream) + + return stream def create(self, url, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): """ Returns: - *hdfs3.core.HDFile*: An Python File-like object. + A Python File-like object. """ path = self._parse_url(url) return self._create(path, mime_type, compression_type) def _create(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): - return self._open_hdfs(path, 'wb', mime_type, compression_type) + stream = io.BufferedWriter( + filesystemio.UploaderStream( + HdfsUploader(self._hdfs_client, path)), + buffer_size=_DEFAULT_BUFFER_SIZE) + return self._add_compression(stream, path, mime_type, compression_type) def open(self, url, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): """ Returns: - *hdfs3.core.HDFile*: An Python File-like object. + A Python File-like object. 
""" path = self._parse_url(url) return self._open(path, mime_type, compression_type) def _open(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): - return self._open_hdfs(path, 'rb', mime_type, compression_type) + stream = io.BufferedReader( + filesystemio.DownloaderStream( + HdfsDownloader(self._hdfs_client, path)), + buffer_size=_DEFAULT_BUFFER_SIZE) + return self._add_compression(stream, path, mime_type, compression_type) def copy(self, source_file_names, destination_file_names): """ - Will overwrite files and directories in destination_file_names. + It is an error if any file to copy already exists at the destination. Raises ``BeamIOError`` if any error occurred. @@ -208,7 +281,8 @@ def _copy_file(source, destination): def _copy_path(source, destination): """Recursively copy the file tree from the source to the destination.""" - if not self._hdfs_client.isdir(source): + if self._hdfs_client.status( + source)[_FILE_STATUS_TYPE] != _FILE_STATUS_TYPE_DIRECTORY: _copy_file(source, destination) return @@ -243,9 +317,11 @@ def rename(self, source_file_names, destination_file_names): try: rel_source = self._parse_url(source) rel_destination = self._parse_url(destination) - if not self._hdfs_client.mv(rel_source, rel_destination): + try: + self._hdfs_client.rename(rel_source, rel_destination) + except hdfs.HdfsError as e: raise BeamIOError( - 'libhdfs error in renaming %s to %s' % (source, destination)) + 'libhdfs error in renaming %s to %s' % (source, destination), e) except Exception as e: # pylint: disable=broad-except exceptions[(source, destination)] = e @@ -270,14 +346,14 @@ def _exists(self, path): Args: path: String in the form /... """ - return self._hdfs_client.exists(path) + return self._hdfs_client.status(path, strict=False) is not None def delete(self, urls): exceptions = {} for url in urls: try: path = self._parse_url(url) - self._hdfs_client.rm(path, recursive=True) + self._hdfs_client.delete(path, recursive=True) except Exception as e: # pylint: disable=broad-except exceptions[url] = e diff --git a/sdks/python/apache_beam/io/hadoopfilesystem_test.py b/sdks/python/apache_beam/io/hadoopfilesystem_test.py index 8a1c0f1b3d29..2ba1da26468b 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem_test.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem_test.py @@ -19,23 +19,26 @@ from __future__ import absolute_import +import io import posixpath -import StringIO import unittest -from apache_beam.io import hadoopfilesystem +from apache_beam.io import hadoopfilesystem as hdfs from apache_beam.io.filesystem import BeamIOError +from apache_beam.options.pipeline_options import HadoopFileSystemOptions from apache_beam.options.pipeline_options import PipelineOptions -class FakeFile(StringIO.StringIO): +class FakeFile(io.BytesIO): """File object for FakeHdfs""" - def __init__(self, path, mode): - StringIO.StringIO.__init__(self) + def __init__(self, path, mode='', type='FILE'): + io.BytesIO.__init__(self) + self.stat = { 'path': path, 'mode': mode, + 'type': type, } self.saved_data = None @@ -44,7 +47,7 @@ def __eq__(self, other): def close(self): self.saved_data = self.getvalue() - StringIO.StringIO.close(self) + io.BytesIO.close(self) def __enter__(self): return self @@ -60,74 +63,87 @@ def size(self): return len(self.saved_data) return len(self.getvalue()) + def get_file_status(self): + """Returns a partial WebHDFS FileStatus object.""" + return { + hdfs._FILE_STATUS_NAME: self.stat['path'], + hdfs._FILE_STATUS_PATH_SUFFIX: 
posixpath.basename(self.stat['path']), + hdfs._FILE_STATUS_SIZE: self.size, + hdfs._FILE_STATUS_TYPE: self.stat['type'], + } + class FakeHdfsError(Exception): """Generic error for FakeHdfs methods.""" class FakeHdfs(object): - """Fake implementation of hdfs3.HadoopFileSystem.""" + """Fake implementation of ``hdfs.Client``.""" def __init__(self): self.files = {} - def open(self, path, mode='rb'): - if mode == 'rb' and not self.exists(path): + def write(self, path): + if self.status(path, strict=False) is not None: + raise FakeHdfsError('Path already exists: %s' % path) + + new_file = FakeFile(path, 'wb') + self.files[path] = new_file + return new_file + + def read(self, path, offset=0, length=None): + old_file = self.files.get(path, None) + if old_file is None: raise FakeHdfsError('Path not found: %s' % path) + if old_file.stat['type'] == 'DIRECTORY': + raise FakeHdfsError('Cannot open a directory: %s' % path) + if not old_file.closed: + raise FakeHdfsError('File already opened: %s' % path) + + # old_file is closed and can't be operated upon. Return a copy instead. + new_file = FakeFile(path, 'rb') + if old_file.saved_data: + new_file.write(old_file.saved_data) + new_file.seek(0) + return new_file + + def list(self, path, status=False): + if not status: + raise ValueError('status must be True') + fs = self.status(path, strict=False) + if (fs is not None and + fs[hdfs._FILE_STATUS_TYPE] == hdfs._FILE_STATUS_TYPE_FILE): + raise ValueError('list must be called on a directory, got file: %s', path) - if mode in ['rb', 'wb']: - new_file = FakeFile(path, mode) - # Required to support read and write operations with CompressedFile. - new_file.mode = 'rw' - - if mode == 'rb': - old_file = self.files.get(path, None) - if old_file is not None: - if old_file.stat['mode'] == 'dir': - raise FakeHdfsError('Cannot open a directory: %s' % path) - if old_file.saved_data: - old_file = self.files[path] - new_file.write(old_file.saved_data) - new_file.seek(0) - - self.files[path] = new_file - return new_file - else: - raise FakeHdfsError('Unknown mode: %s' % mode) - - def ls(self, path, detail=False): result = [] for file in self.files.itervalues(): if file.stat['path'].startswith(path): - result.append({ - 'name': file.stat['path'], - 'size': file.size, - }) + fs = file.get_file_status() + result.append((fs[hdfs._FILE_STATUS_PATH_SUFFIX], fs)) return result def makedirs(self, path): - self.files[path] = FakeFile(path, 'dir') - - def exists(self, path): - return path in self.files - - def rm(self, path, recursive=True): + self.files[path] = FakeFile(path, type='DIRECTORY') + + def status(self, path, strict=True): + f = self.files.get(path) + if f is None: + if strict: + raise FakeHdfsError('Path not found: %s' % path) + else: + return f + return f.get_file_status() + + def delete(self, path, recursive=True): if not recursive: raise FakeHdfsError('Non-recursive mode not implemented') - if not self.exists(path): - raise FakeHdfsError('Path not found: %s' % path) + _ = self.status(path) for filepath in self.files.keys(): # pylint: disable=consider-iterating-dictionary if filepath.startswith(path): del self.files[filepath] - def isdir(self, path): - if not self.exists(path): - raise FakeHdfsError('Path not found: %s' % path) - - return self.files[path].stat['mode'] == 'dir' - def walk(self, path): paths = [path] while paths: @@ -139,7 +155,7 @@ def walk(self, path): continue short_path = posixpath.relpath(full_path, path) if '/' not in short_path: - if self.isdir(full_path): + if 
self.status(full_path)[hdfs._FILE_STATUS_TYPE] == 'DIRECTORY': if short_path != '.': dirs.append(short_path) else: @@ -148,8 +164,8 @@ def walk(self, path): yield path, dirs, files paths = [posixpath.join(path, dir) for dir in dirs] - def mv(self, path1, path2): - if not self.exists(path1): + def rename(self, path1, path2): + if self.status(path1, strict=False) is None: raise FakeHdfsError('Path1 not found: %s' % path1) for fullpath in self.files.keys(): # pylint: disable=consider-iterating-dictionary @@ -159,16 +175,20 @@ def mv(self, path1, path2): f.stat['path'] = newpath self.files[newpath] = f - return True - class HadoopFileSystemTest(unittest.TestCase): def setUp(self): self._fake_hdfs = FakeHdfs() - hadoopfilesystem.HDFileSystem = lambda *args, **kwargs: self._fake_hdfs + hdfs.hdfs.InsecureClient = ( + lambda *args, **kwargs: self._fake_hdfs) pipeline_options = PipelineOptions() - self.fs = hadoopfilesystem.HadoopFileSystem(pipeline_options) + hdfs_options = pipeline_options.view_as(HadoopFileSystemOptions) + hdfs_options.hdfs_host = '' + hdfs_options.hdfs_port = 0 + hdfs_options.hdfs_user = '' + + self.fs = hdfs.HadoopFileSystem(pipeline_options) self.tmpdir = 'hdfs://test_dir' for filename in ['old_file1', 'old_file2']: @@ -177,7 +197,7 @@ def setUp(self): def test_scheme(self): self.assertEqual(self.fs.scheme(), 'hdfs') - self.assertEqual(hadoopfilesystem.HadoopFileSystem.scheme(), 'hdfs') + self.assertEqual(hdfs.HadoopFileSystem.scheme(), 'hdfs') def test_url_join(self): self.assertEqual('hdfs://tmp/path/to/file', @@ -303,7 +323,9 @@ def test_create_write_read_compressed(self): def test_open(self): url = self.fs.join(self.tmpdir, 'old_file1') handle = self.fs.open(url) - self.assertEqual(handle, self._fake_hdfs.files[self.fs._parse_url(url)]) + expected_data = '' + data = handle.read() + self.assertEqual(data, expected_data) def test_open_bad_path(self): with self.assertRaises(FakeHdfsError): @@ -326,15 +348,16 @@ def test_copy_file(self): self.assertTrue(self._cmpfiles(url1, url2)) self.assertTrue(self._cmpfiles(url1, url3)) - def test_copy_file_overwrite(self): + def test_copy_file_overwrite_error(self): url1 = self.fs.join(self.tmpdir, 'new_file1') url2 = self.fs.join(self.tmpdir, 'new_file2') with self.fs.create(url1) as f1: f1.write('Hello') with self.fs.create(url2) as f2: f2.write('nope') - self.fs.copy([url1], [url2]) - self.assertTrue(self._cmpfiles(url1, url2)) + with self.assertRaisesRegexp( + BeamIOError, r'already exists.*%s' % posixpath.basename(url2)): + self.fs.copy([url1], [url2]) def test_copy_file_error(self): url1 = self.fs.join(self.tmpdir, 'new_file1') @@ -366,7 +389,7 @@ def test_copy_directory(self): self.fs.copy([url_t1], [url_t2]) self.assertTrue(self._cmpfiles(url1, url2)) - def test_copy_directory_overwrite(self): + def test_copy_directory_overwrite_error(self): url_t1 = self.fs.join(self.tmpdir, 't1') url_t1_inner = self.fs.join(self.tmpdir, 't1/inner') url_t2 = self.fs.join(self.tmpdir, 't2') @@ -379,7 +402,7 @@ def test_copy_directory_overwrite(self): url1 = self.fs.join(url_t1, 'f1') url1_inner = self.fs.join(url_t1_inner, 'f2') url2 = self.fs.join(url_t2, 'f1') - url2_inner = self.fs.join(url_t2_inner, 'f2') + unused_url2_inner = self.fs.join(url_t2_inner, 'f2') url3_inner = self.fs.join(url_t2_inner, 'f3') for url in [url1, url1_inner, url3_inner]: with self.fs.create(url) as f: @@ -387,10 +410,8 @@ def test_copy_directory_overwrite(self): with self.fs.create(url2) as f: f.write('nope') - self.fs.copy([url_t1], [url_t2]) - 
self.assertTrue(self._cmpfiles(url1, url2)) - self.assertTrue(self._cmpfiles(url1_inner, url2_inner)) - self.assertTrue(self.fs.exists(url3_inner)) + with self.assertRaisesRegexp(BeamIOError, r'already exists'): + self.fs.copy([url_t1], [url_t2]) def test_rename_file(self): url1 = self.fs.join(self.tmpdir, 'f1') diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index fc7a2f3f7b14..eb79f4d3c661 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -36,6 +36,8 @@ from apache_beam import coders from apache_beam import pvalue +from apache_beam.portability import common_urns +from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsIter from apache_beam.pvalue import AsSingleton @@ -75,7 +77,7 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn): """Base class for all sources that can be passed to beam.io.Read(...). """ - urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_SOURCE) + urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE) class BoundedSource(SourceBase): @@ -832,7 +834,7 @@ def display_data(self): 'source_dd': self.source} def to_runner_api_parameter(self, context): - return (urns.READ_TRANSFORM, + return (common_urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload( source=self.source.to_runner_api(context), is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED @@ -845,7 +847,7 @@ def from_runner_api_parameter(parameter, context): ptransform.PTransform.register_urn( - urns.READ_TRANSFORM, + common_urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload, Read.from_runner_api_parameter) diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py index 7106aef057dd..2da8736b1141 100644 --- a/sdks/python/apache_beam/io/range_trackers.py +++ b/sdks/python/apache_beam/io/range_trackers.py @@ -22,6 +22,8 @@ import math import threading +from six import integer_types + from apache_beam.io import iobase __all__ = ['OffsetRangeTracker', 'LexicographicKeyRangeTracker', @@ -45,9 +47,9 @@ def __init__(self, start, end): raise ValueError('Start offset must not be \'None\'') if end is None: raise ValueError('End offset must not be \'None\'') - assert isinstance(start, (int, long)) + assert isinstance(start, integer_types) if end != self.OFFSET_INFINITY: - assert isinstance(end, (int, long)) + assert isinstance(end, integer_types) assert start <= end @@ -121,7 +123,7 @@ def set_current_position(self, record_start): self._last_record_start = record_start def try_split(self, split_offset): - assert isinstance(split_offset, (int, long)) + assert isinstance(split_offset, integer_types) with self._lock: if self._stop_offset == OffsetRangeTracker.OFFSET_INFINITY: logging.debug('refusing to split %r at %d: stop position unspecified', diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py index 3e926634c85f..6b8f7c7139c9 100644 --- a/sdks/python/apache_beam/io/range_trackers_test.py +++ b/sdks/python/apache_beam/io/range_trackers_test.py @@ -22,6 +22,8 @@ import math import unittest +from six import integer_types + from apache_beam.io import range_trackers @@ -99,7 +101,8 @@ def test_get_position_for_fraction_dense(self): tracker = range_trackers.OffsetRangeTracker(3, 6) # Position must be an integer type. 
- self.assertTrue(isinstance(tracker.position_at_fraction(0.0), (int, long))) + self.assertTrue(isinstance(tracker.position_at_fraction(0.0), + integer_types)) # [3, 3) represents 0.0 of [3, 6) self.assertEqual(3, tracker.position_at_fraction(0.0)) # [3, 4) represents up to 1/3 of [3, 6) diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 4a4bd3a1ae68..bfe4b9f30c74 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -23,6 +23,8 @@ import logging from functools import partial +from six import integer_types + from apache_beam.coders import coders from apache_beam.io import filebasedsink from apache_beam.io import filebasedsource @@ -72,7 +74,7 @@ def position(self): @position.setter def position(self, value): - assert isinstance(value, (int, long)) + assert isinstance(value, integer_types) if value > len(self._data): raise ValueError('Cannot set position to %d since it\'s larger than ' 'size of data %d.', value, len(self._data)) diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py index 5af07164b669..0290bfaf2702 100644 --- a/sdks/python/apache_beam/io/tfrecordio.py +++ b/sdks/python/apache_beam/io/tfrecordio.py @@ -20,12 +20,14 @@ import logging import struct +from functools import partial import crcmod from apache_beam import coders from apache_beam.io import filebasedsink -from apache_beam.io import filebasedsource +from apache_beam.io.filebasedsource import FileBasedSource +from apache_beam.io.filebasedsource import ReadAllFiles from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.iobase import Read from apache_beam.io.iobase import Write @@ -140,7 +142,7 @@ def read_record(cls, file_handle): return data -class _TFRecordSource(filebasedsource.FileBasedSource): +class _TFRecordSource(FileBasedSource): """A File source for reading files of TFRecords. For detailed TFRecords format description see: @@ -178,6 +180,47 @@ def read_records(self, file_name, offset_range_tracker): yield self._coder.decode(record) +def _create_tfrcordio_source( + file_pattern=None, coder=None, compression_type=None): + # We intentionally disable validation for ReadAll pattern so that reading does + # not fail for globs (elements) that are empty. + return _TFRecordSource(file_pattern, coder, compression_type, + validate=False) + + +class ReadAllFromTFRecord(PTransform): + """A ``PTransform`` for reading a ``PCollection`` of TFRecord files.""" + + def __init__( + self, + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO, + **kwargs): + """Initialize the ``ReadAllFromTFRecord`` transform. + + Args: + coder: Coder used to decode each record. + compression_type: Used to handle compressed input files. Default value + is CompressionTypes.AUTO, in which case the file_path's extension will + be used to detect the compression. + **kwargs: optional args dictionary. These are passed through to parent + constructor. + """ + super(ReadAllFromTFRecord, self).__init__(**kwargs) + source_from_file = partial( + _create_tfrcordio_source, compression_type=compression_type, + coder=coder) + # Desired and min bundle sizes do not matter since TFRecord files are + # unsplittable. 
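+      # Note: source_from_file is invoked at execution time for each matched
+      # file to build an unsplittable _TFRecordSource for that file.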
+ self._read_all_files = ReadAllFiles( + splittable=False, compression_type=compression_type, + desired_bundle_size=0, min_bundle_size=0, + source_from_file=source_from_file) + + def expand(self, pvalue): + return pvalue | 'ReadAllFiles' >> self._read_all_files + + class ReadFromTFRecord(PTransform): """Transform for reading TFRecord sources.""" diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py index fcafb712530c..c540cfae0502 100644 --- a/sdks/python/apache_beam/io/tfrecordio_test.py +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -24,21 +24,21 @@ import pickle import random import re -import shutil -import tempfile import unittest import crcmod import apache_beam as beam +from apache_beam import Create from apache_beam import coders from apache_beam.io.filesystem import CompressionTypes +from apache_beam.io.tfrecordio import ReadAllFromTFRecord from apache_beam.io.tfrecordio import ReadFromTFRecord from apache_beam.io.tfrecordio import WriteToTFRecord from apache_beam.io.tfrecordio import _TFRecordSink -from apache_beam.io.tfrecordio import _TFRecordSource from apache_beam.io.tfrecordio import _TFRecordUtil from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_utils import TempDir from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -63,6 +63,18 @@ FOO_BAR_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=' +def _write_file(path, base64_records): + record = binascii.a2b_base64(base64_records) + with open(path, 'wb') as f: + f.write(record) + + +def _write_file_gzip(path, base64_records): + record = binascii.a2b_base64(base64_records) + with gzip.GzipFile(path, 'wb') as f: + f.write(record) + + class TestTFRecordUtil(unittest.TestCase): def setUp(self): @@ -138,29 +150,7 @@ def test_compatibility_read_write(self): self.assertEqual(record, actual) -class _TestCaseWithTempDirCleanUp(unittest.TestCase): - """Base class for TestCases that deals with TempDir clean-up. - - Inherited test cases will call self._new_tempdir() to start a temporary dir - which will be deleted at the end of the tests (when tearDown() is called). 
- """ - - def setUp(self): - self._tempdirs = [] - - def tearDown(self): - for path in self._tempdirs: - if os.path.exists(path): - shutil.rmtree(path) - self._tempdirs = [] - - def _new_tempdir(self): - result = tempfile.mkdtemp() - self._tempdirs.append(result) - return result - - -class TestTFRecordSink(_TestCaseWithTempDirCleanUp): +class TestTFRecordSink(unittest.TestCase): def _write_lines(self, sink, path, lines): f = sink.open(path) @@ -169,240 +159,322 @@ def _write_lines(self, sink, path, lines): sink.close(f) def test_write_record_single(self): - path = os.path.join(self._new_tempdir(), 'result') - record = binascii.a2b_base64(FOO_RECORD_BASE64) - sink = _TFRecordSink( - path, - coder=coders.BytesCoder(), - file_name_suffix='', - num_shards=0, - shard_name_template=None, - compression_type=CompressionTypes.UNCOMPRESSED) - self._write_lines(sink, path, ['foo']) - - with open(path, 'r') as f: - self.assertEqual(f.read(), record) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + record = binascii.a2b_base64(FOO_RECORD_BASE64) + sink = _TFRecordSink( + path, + coder=coders.BytesCoder(), + file_name_suffix='', + num_shards=0, + shard_name_template=None, + compression_type=CompressionTypes.UNCOMPRESSED) + self._write_lines(sink, path, ['foo']) + + with open(path, 'r') as f: + self.assertEqual(f.read(), record) def test_write_record_multiple(self): - path = os.path.join(self._new_tempdir(), 'result') - record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64) - sink = _TFRecordSink( - path, - coder=coders.BytesCoder(), - file_name_suffix='', - num_shards=0, - shard_name_template=None, - compression_type=CompressionTypes.UNCOMPRESSED) - self._write_lines(sink, path, ['foo', 'bar']) - - with open(path, 'r') as f: - self.assertEqual(f.read(), record) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64) + sink = _TFRecordSink( + path, + coder=coders.BytesCoder(), + file_name_suffix='', + num_shards=0, + shard_name_template=None, + compression_type=CompressionTypes.UNCOMPRESSED) + self._write_lines(sink, path, ['foo', 'bar']) + + with open(path, 'r') as f: + self.assertEqual(f.read(), record) @unittest.skipIf(tf is None, 'tensorflow not installed.') class TestWriteToTFRecord(TestTFRecordSink): def test_write_record_gzip(self): - file_path_prefix = os.path.join(self._new_tempdir(), 'result') - with TestPipeline() as p: - input_data = ['foo', 'bar'] - _ = p | beam.Create(input_data) | WriteToTFRecord( - file_path_prefix, compression_type=CompressionTypes.GZIP) - - actual = [] - file_name = glob.glob(file_path_prefix + '-*')[0] - for r in tf.python_io.tf_record_iterator( - file_name, options=tf.python_io.TFRecordOptions( - tf.python_io.TFRecordCompressionType.GZIP)): - actual.append(r) - self.assertEqual(actual, input_data) + with TempDir() as temp_dir: + file_path_prefix = temp_dir.create_temp_file('result') + with TestPipeline() as p: + input_data = ['foo', 'bar'] + _ = p | beam.Create(input_data) | WriteToTFRecord( + file_path_prefix, compression_type=CompressionTypes.GZIP) + + actual = [] + file_name = glob.glob(file_path_prefix + '-*')[0] + for r in tf.python_io.tf_record_iterator( + file_name, options=tf.python_io.TFRecordOptions( + tf.python_io.TFRecordCompressionType.GZIP)): + actual.append(r) + self.assertEqual(actual, input_data) def test_write_record_auto(self): - file_path_prefix = os.path.join(self._new_tempdir(), 'result') - with TestPipeline() as p: - input_data = ['foo', 'bar'] - _ 
= p | beam.Create(input_data) | WriteToTFRecord( - file_path_prefix, file_name_suffix='.gz') + with TempDir() as temp_dir: + file_path_prefix = temp_dir.create_temp_file('result') + with TestPipeline() as p: + input_data = ['foo', 'bar'] + _ = p | beam.Create(input_data) | WriteToTFRecord( + file_path_prefix, file_name_suffix='.gz') - actual = [] - file_name = glob.glob(file_path_prefix + '-*.gz')[0] - for r in tf.python_io.tf_record_iterator( - file_name, options=tf.python_io.TFRecordOptions( - tf.python_io.TFRecordCompressionType.GZIP)): - actual.append(r) - self.assertEqual(actual, input_data) + actual = [] + file_name = glob.glob(file_path_prefix + '-*.gz')[0] + for r in tf.python_io.tf_record_iterator( + file_name, options=tf.python_io.TFRecordOptions( + tf.python_io.TFRecordCompressionType.GZIP)): + actual.append(r) + self.assertEqual(actual, input_data) -class TestTFRecordSource(_TestCaseWithTempDirCleanUp): - - def _write_file(self, path, base64_records): - record = binascii.a2b_base64(base64_records) - with open(path, 'wb') as f: - f.write(record) - - def _write_file_gzip(self, path, base64_records): - record = binascii.a2b_base64(base64_records) - with gzip.GzipFile(path, 'wb') as f: - f.write(record) +class TestReadFromTFRecord(unittest.TestCase): def test_process_single(self): - path = os.path.join(self._new_tempdir(), 'result') - self._write_file(path, FOO_RECORD_BASE64) - with TestPipeline() as p: - result = (p - | beam.io.Read( - _TFRecordSource( - path, - coder=coders.BytesCoder(), - compression_type=CompressionTypes.AUTO, - validate=True))) - assert_that(result, equal_to(['foo'])) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file(path, FOO_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | ReadFromTFRecord( + path, + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO, + validate=True)) + assert_that(result, equal_to(['foo'])) def test_process_multiple(self): - path = os.path.join(self._new_tempdir(), 'result') - self._write_file(path, FOO_BAR_RECORD_BASE64) - with TestPipeline() as p: - result = (p - | beam.io.Read( - _TFRecordSource( - path, - coder=coders.BytesCoder(), - compression_type=CompressionTypes.AUTO, - validate=True))) - assert_that(result, equal_to(['foo', 'bar'])) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | ReadFromTFRecord( + path, + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO, + validate=True)) + assert_that(result, equal_to(['foo', 'bar'])) def test_process_gzip(self): - path = os.path.join(self._new_tempdir(), 'result') - self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) - with TestPipeline() as p: - result = (p - | beam.io.Read( - _TFRecordSource( - path, - coder=coders.BytesCoder(), - compression_type=CompressionTypes.GZIP, - validate=True))) - assert_that(result, equal_to(['foo', 'bar'])) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file_gzip(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | ReadFromTFRecord( + path, + coder=coders.BytesCoder(), + compression_type=CompressionTypes.GZIP, + validate=True)) + assert_that(result, equal_to(['foo', 'bar'])) def test_process_auto(self): - path = os.path.join(self._new_tempdir(), 'result.gz') - self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) - with TestPipeline() as p: - result = (p - | beam.io.Read( - _TFRecordSource( - path, - 
coder=coders.BytesCoder(), - compression_type=CompressionTypes.AUTO, - validate=True))) - assert_that(result, equal_to(['foo', 'bar'])) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result.gz') + _write_file_gzip(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | ReadFromTFRecord( + path, + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO, + validate=True)) + assert_that(result, equal_to(['foo', 'bar'])) + + def test_process_gzip(self): + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file_gzip(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | ReadFromTFRecord( + path, compression_type=CompressionTypes.GZIP)) + assert_that(result, equal_to(['foo', 'bar'])) + + def test_process_gzip_auto(self): + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result.gz') + _write_file_gzip(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | ReadFromTFRecord( + path, compression_type=CompressionTypes.AUTO)) + assert_that(result, equal_to(['foo', 'bar'])) + +class TestReadAllFromTFRecord(unittest.TestCase): -class TestReadFromTFRecordSource(TestTFRecordSource): + def _write_glob(self, temp_dir, suffix): + for _ in range(3): + path = temp_dir.create_temp_file(suffix) + _write_file(path, FOO_BAR_RECORD_BASE64) + + def test_process_single(self): + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file(path, FOO_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | Create([path]) + | ReadAllFromTFRecord( + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO)) + assert_that(result, equal_to(['foo'])) + + def test_process_multiple(self): + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | Create([path]) + | ReadAllFromTFRecord( + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO)) + assert_that(result, equal_to(['foo', 'bar'])) + + def test_process_glob(self): + with TempDir() as temp_dir: + self._write_glob(temp_dir, 'result') + glob = temp_dir.get_path() + os.path.sep + '*result' + with TestPipeline() as p: + result = (p + | Create([glob]) + | ReadAllFromTFRecord( + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO)) + assert_that(result, equal_to(['foo', 'bar'] * 3)) + + def test_process_multiple_globs(self): + with TempDir() as temp_dir: + globs = [] + for i in range(3): + suffix = 'result' + str(i) + self._write_glob(temp_dir, suffix) + globs.append(temp_dir.get_path() + os.path.sep + '*' + suffix) + + with TestPipeline() as p: + result = (p + | Create(globs) + | ReadAllFromTFRecord( + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO)) + assert_that(result, equal_to(['foo', 'bar'] * 9)) def test_process_gzip(self): - path = os.path.join(self._new_tempdir(), 'result') - self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) - with TestPipeline() as p: - result = (p - | ReadFromTFRecord( - path, compression_type=CompressionTypes.GZIP)) - assert_that(result, equal_to(['foo', 'bar'])) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + _write_file_gzip(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | Create([path]) + | ReadAllFromTFRecord( + coder=coders.BytesCoder(), + compression_type=CompressionTypes.GZIP)) + assert_that(result, equal_to(['foo', 'bar'])) - def 
test_process_gzip_auto(self): - path = os.path.join(self._new_tempdir(), 'result.gz') - self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) - with TestPipeline() as p: - result = (p - | ReadFromTFRecord( - path, compression_type=CompressionTypes.AUTO)) - assert_that(result, equal_to(['foo', 'bar'])) + def test_process_auto(self): + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result.gz') + _write_file_gzip(path, FOO_BAR_RECORD_BASE64) + with TestPipeline() as p: + result = (p + | Create([path]) + | ReadAllFromTFRecord( + coder=coders.BytesCoder(), + compression_type=CompressionTypes.AUTO)) + assert_that(result, equal_to(['foo', 'bar'])) -class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp): +class TestEnd2EndWriteAndRead(unittest.TestCase): def create_inputs(self): - input_array = [[random.random() - 0.5 for _ in xrange(15)] - for _ in xrange(12)] + input_array = [[random.random() - 0.5 for _ in range(15)] + for _ in range(12)] memfile = cStringIO.StringIO() pickle.dump(input_array, memfile) return memfile.getvalue() def test_end2end(self): - file_path_prefix = os.path.join(self._new_tempdir(), 'result') + with TempDir() as temp_dir: + file_path_prefix = temp_dir.create_temp_file('result') - # Generate a TFRecord file. - with TestPipeline() as p: - expected_data = [self.create_inputs() for _ in range(0, 10)] - _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix) + # Generate a TFRecord file. + with TestPipeline() as p: + expected_data = [self.create_inputs() for _ in range(0, 10)] + _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix) - # Read the file back and compare. - with TestPipeline() as p: - actual_data = p | ReadFromTFRecord(file_path_prefix + '-*') - assert_that(actual_data, equal_to(expected_data)) + # Read the file back and compare. + with TestPipeline() as p: + actual_data = p | ReadFromTFRecord(file_path_prefix + '-*') + assert_that(actual_data, equal_to(expected_data)) def test_end2end_auto_compression(self): - file_path_prefix = os.path.join(self._new_tempdir(), 'result') + with TempDir() as temp_dir: + file_path_prefix = temp_dir.create_temp_file('result') - # Generate a TFRecord file. - with TestPipeline() as p: - expected_data = [self.create_inputs() for _ in range(0, 10)] - _ = p | beam.Create(expected_data) | WriteToTFRecord( - file_path_prefix, file_name_suffix='.gz') + # Generate a TFRecord file. + with TestPipeline() as p: + expected_data = [self.create_inputs() for _ in range(0, 10)] + _ = p | beam.Create(expected_data) | WriteToTFRecord( + file_path_prefix, file_name_suffix='.gz') - # Read the file back and compare. - with TestPipeline() as p: - actual_data = p | ReadFromTFRecord(file_path_prefix + '-*') - assert_that(actual_data, equal_to(expected_data)) + # Read the file back and compare. + with TestPipeline() as p: + actual_data = p | ReadFromTFRecord(file_path_prefix + '-*') + assert_that(actual_data, equal_to(expected_data)) def test_end2end_auto_compression_unsharded(self): - file_path_prefix = os.path.join(self._new_tempdir(), 'result') + with TempDir() as temp_dir: + file_path_prefix = temp_dir.create_temp_file('result') - # Generate a TFRecord file. - with TestPipeline() as p: - expected_data = [self.create_inputs() for _ in range(0, 10)] - _ = p | beam.Create(expected_data) | WriteToTFRecord( - file_path_prefix + '.gz', shard_name_template='') + # Generate a TFRecord file. 
+ with TestPipeline() as p: + expected_data = [self.create_inputs() for _ in range(0, 10)] + _ = p | beam.Create(expected_data) | WriteToTFRecord( + file_path_prefix + '.gz', shard_name_template='') - # Read the file back and compare. - with TestPipeline() as p: - actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz') - assert_that(actual_data, equal_to(expected_data)) + # Read the file back and compare. + with TestPipeline() as p: + actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz') + assert_that(actual_data, equal_to(expected_data)) @unittest.skipIf(tf is None, 'tensorflow not installed.') def test_end2end_example_proto(self): - file_path_prefix = os.path.join(self._new_tempdir(), 'result') + with TempDir() as temp_dir: + file_path_prefix = temp_dir.create_temp_file('result') - example = tf.train.Example() - example.features.feature['int'].int64_list.value.extend(range(3)) - example.features.feature['bytes'].bytes_list.value.extend( - [b'foo', b'bar']) + example = tf.train.Example() + example.features.feature['int'].int64_list.value.extend(range(3)) + example.features.feature['bytes'].bytes_list.value.extend( + [b'foo', b'bar']) - with TestPipeline() as p: - _ = p | beam.Create([example]) | WriteToTFRecord( - file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__)) + with TestPipeline() as p: + _ = p | beam.Create([example]) | WriteToTFRecord( + file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__)) - # Read the file back and compare. - with TestPipeline() as p: - actual_data = (p | ReadFromTFRecord( - file_path_prefix + '-*', - coder=beam.coders.ProtoCoder(example.__class__))) - assert_that(actual_data, equal_to([example])) + # Read the file back and compare. + with TestPipeline() as p: + actual_data = (p | ReadFromTFRecord( + file_path_prefix + '-*', + coder=beam.coders.ProtoCoder(example.__class__))) + assert_that(actual_data, equal_to([example])) def test_end2end_read_write_read(self): - path = os.path.join(self._new_tempdir(), 'result') - with TestPipeline() as p: - # Initial read to validate the pipeline doesn't fail before the file is - # created. - _ = p | ReadFromTFRecord(path + '-*', validate=False) - expected_data = [self.create_inputs() for _ in range(0, 10)] - _ = p | beam.Create(expected_data) | WriteToTFRecord( - path, file_name_suffix='.gz') - - # Read the file back and compare. - with TestPipeline() as p: - actual_data = p | ReadFromTFRecord(path+'-*', validate=True) - assert_that(actual_data, equal_to(expected_data)) + with TempDir() as temp_dir: + path = temp_dir.create_temp_file('result') + with TestPipeline() as p: + # Initial read to validate the pipeline doesn't fail before the file is + # created. + _ = p | ReadFromTFRecord(path + '-*', validate=False) + expected_data = [self.create_inputs() for _ in range(0, 10)] + _ = p | beam.Create(expected_data) | WriteToTFRecord( + path, file_name_suffix='.gz') + + # Read the file back and compare. 
+ with TestPipeline() as p: + actual_data = p | ReadFromTFRecord(path+'-*', validate=True) + assert_that(actual_data, equal_to(expected_data)) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/vcfio.py b/sdks/python/apache_beam/io/vcfio.py index a45861ef82ad..a0206d450762 100644 --- a/sdks/python/apache_beam/io/vcfio.py +++ b/sdks/python/apache_beam/io/vcfio.py @@ -35,6 +35,12 @@ from apache_beam.io.textio import _TextSource as TextSource from apache_beam.transforms import PTransform +try: + long # Python 2 +except NameError: + long = int # Python 3 + + __all__ = ['ReadFromVcf', 'Variant', 'VariantCall', 'VariantInfo', 'MalformedVcfRecord'] diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 2b5a4e4094d5..302d79ab5127 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -23,14 +23,19 @@ Cells depend on a 'dirty-bit' in the CellCommitState class that tracks whether a cell's updates have been committed. """ +from __future__ import division import threading +import time + +from google.protobuf import timestamp_pb2 from apache_beam.metrics.metricbase import Counter from apache_beam.metrics.metricbase import Distribution +from apache_beam.metrics.metricbase import Gauge from apache_beam.portability.api import beam_fn_api_pb2 -__all__ = ['DistributionResult'] +__all__ = ['DistributionResult', 'GaugeResult'] class CellCommitState(object): @@ -137,7 +142,7 @@ class CounterCell(Counter, MetricCell): """ def __init__(self, *args): super(CounterCell, self).__init__(*args) - self.value = 0 + self.value = CounterAggregator.identity_element() def combine(self, other): result = CounterCell() @@ -167,7 +172,7 @@ class DistributionCell(Distribution, MetricCell): """ def __init__(self, *args): super(DistributionCell, self).__init__(*args) - self.data = DistributionData(0, 0, None, None) + self.data = DistributionAggregator.identity_element() def combine(self, other): result = DistributionCell() @@ -195,14 +200,53 @@ def get_cumulative(self): return self.data.get_cumulative() -class DistributionResult(object): - """The result of a Distribution metric. +class GaugeCell(Gauge, MetricCell): + """For internal use only; no backwards-compatibility guarantees. + + Tracks the current value and delta for a gauge metric. + + Each cell tracks the state of a metric independently per context per bundle. + Therefore, each metric has a different cell in each bundle, that is later + aggregated. + + This class is thread safe. """ + def __init__(self, *args): + super(GaugeCell, self).__init__(*args) + self.data = GaugeAggregator.identity_element() + + def combine(self, other): + result = GaugeCell() + result.data = self.data.combine(other.data) + return result + + def set(self, value): + value = int(value) + with self._lock: + self.commit.after_modification() + # Set the value directly without checking timestamp, because + # this value is naturally the latest value. 
+ self.data.value = value + self.data.timestamp = time.time() + + def get_cumulative(self): + with self._lock: + return self.data.get_cumulative() + + +class DistributionResult(object): + """The result of a Distribution metric.""" def __init__(self, data): self.data = data def __eq__(self, other): - return self.data == other.data + if isinstance(other, DistributionResult): + return self.data == other.data + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) def __repr__(self): return ''.format( @@ -235,7 +279,89 @@ def mean(self): """ if self.data.count == 0: return None - return float(self.data.sum)/self.data.count + return self.data.sum / self.data.count + + +class GaugeResult(object): + def __init__(self, data): + self.data = data + + def __eq__(self, other): + if isinstance(other, GaugeResult): + return self.data == other.data + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return ''.format( + self.value, + self.timestamp) + + @property + def value(self): + return self.data.value + + @property + def timestamp(self): + return self.data.timestamp + + +class GaugeData(object): + """For internal use only; no backwards-compatibility guarantees. + + The data structure that holds data about a gauge metric. + + Gauge metrics are restricted to integers only. + + This object is not thread safe, so it's not supposed to be modified + by other than the GaugeCell that contains it. + """ + def __init__(self, value, timestamp=None): + self.value = value + self.timestamp = timestamp if timestamp is not None else 0 + + def __eq__(self, other): + return self.value == other.value and self.timestamp == other.timestamp + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return ''.format( + self.value, + self.timestamp) + + def get_cumulative(self): + return GaugeData(self.value, timestamp=self.timestamp) + + def combine(self, other): + if other is None: + return self + + if other.timestamp > self.timestamp: + return other + else: + return self + + @staticmethod + def singleton(value, timestamp=None): + return GaugeData(value, timestamp=timestamp) + + def to_runner_api(self): + seconds = int(self.timestamp) + nanos = int((self.timestamp - seconds) * 10**9) + gauge_timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + return beam_fn_api_pb2.Metrics.User.GaugeData( + value=self.value, timestamp=gauge_timestamp) + + @staticmethod + def from_runner_api(proto): + gauge_timestamp = (proto.timestamp.seconds + + float(proto.timestamp.nanos) / 10**9) + return GaugeData(proto.value, timestamp=gauge_timestamp) class DistributionData(object): @@ -260,7 +386,7 @@ def __eq__(self, other): self.min == other.min and self.max == other.max) - def __neq__(self, other): + def __ne__(self, other): return not self.__eq__(other) def __repr__(self): @@ -304,7 +430,13 @@ class MetricAggregator(object): """For internal use only; no backwards-compatibility guarantees. Base interface for aggregating metric data during pipeline execution.""" - def zero(self): + + def identity_element(self): + """Returns the identical element of an Aggregation. + + For the identity element, it must hold that + Aggregator.combine(any_element, identity_element) == any_element. + """ raise NotImplementedError def combine(self, updates): @@ -321,7 +453,8 @@ class CounterAggregator(MetricAggregator): Values aggregated should be ``int`` objects. 
""" - def zero(self): + @staticmethod + def identity_element(): return 0 def combine(self, x, y): @@ -338,7 +471,8 @@ class DistributionAggregator(MetricAggregator): Values aggregated should be ``DistributionData`` objects. """ - def zero(self): + @staticmethod + def identity_element(): return DistributionData(0, 0, None, None) def combine(self, x, y): @@ -346,3 +480,22 @@ def combine(self, x, y): def result(self, x): return DistributionResult(x.get_cumulative()) + + +class GaugeAggregator(MetricAggregator): + """For internal use only; no backwards-compatibility guarantees. + + Aggregator for Gauge metric data during pipeline execution. + + Values aggregated should be ``GaugeData`` objects. + """ + @staticmethod + def identity_element(): + return GaugeData(None, timestamp=0) + + def combine(self, x, y): + result = x.combine(y) + return result + + def result(self, x): + return GaugeResult(x.get_cumulative()) diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index c0664ab2bb0c..14e7e537ea1c 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -22,6 +22,8 @@ from apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import DistributionData +from apache_beam.metrics.cells import GaugeCell +from apache_beam.metrics.cells import GaugeData class TestCounterCell(unittest.TestCase): @@ -46,7 +48,9 @@ def test_parallel_access(self): for t in threads: t.join() - total = (self.NUM_ITERATIONS * (self.NUM_ITERATIONS-1)/2 * self.NUM_THREADS) + total = (self.NUM_ITERATIONS + * (self.NUM_ITERATIONS - 1) // 2 + * self.NUM_THREADS) self.assertEqual(c.get_cumulative(), total) def test_basic_operations(self): @@ -86,7 +90,9 @@ def test_parallel_access(self): for t in threads: t.join() - total = (self.NUM_ITERATIONS * (self.NUM_ITERATIONS-1)/2 * self.NUM_THREADS) + total = (self.NUM_ITERATIONS + * (self.NUM_ITERATIONS - 1) // 2 + * self.NUM_THREADS) count = (self.NUM_ITERATIONS * self.NUM_THREADS) @@ -117,6 +123,33 @@ def test_integer_only(self): DistributionData(9, 3, 3, 3)) +class TestGaugeCell(unittest.TestCase): + def test_basic_operations(self): + g = GaugeCell() + g.set(10) + self.assertEqual(g.get_cumulative().value, GaugeData(10).value) + + g.set(2) + self.assertEqual(g.get_cumulative().value, 2) + + def test_integer_only(self): + g = GaugeCell() + g.set(3.3) + self.assertEqual(g.get_cumulative().value, 3) + + def test_combine_appropriately(self): + g1 = GaugeCell() + g1.set(3) + + g2 = GaugeCell() + g2.set(1) + + # THe second Gauge, with value 1 was the most recent, so it should be + # the final result. 
+ result = g2.combine(g1) + self.assertEqual(result.data.value, 1) + + class TestCellCommitState(unittest.TestCase): def test_basic_path(self): ds = CellCommitState() diff --git a/sdks/python/apache_beam/metrics/execution.pxd b/sdks/python/apache_beam/metrics/execution.pxd index d89004f783ce..af0c30c137ef 100644 --- a/sdks/python/apache_beam/metrics/execution.pxd +++ b/sdks/python/apache_beam/metrics/execution.pxd @@ -22,6 +22,7 @@ cdef class MetricsContainer(object): cdef object step_name cdef public object counters cdef public object distributions + cdef public object gauges cdef class ScopedMetricsContainer(object): diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index 959424160bc2..f6c790de5d4b 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -34,7 +34,7 @@ from apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import DistributionCell -from apache_beam.metrics.metricbase import MetricName +from apache_beam.metrics.cells import GaugeCell from apache_beam.portability.api import beam_fn_api_pb2 @@ -65,14 +65,6 @@ def __repr__(self): def __hash__(self): return hash((self.step, self.metric)) - def to_runner_api(self): - return beam_fn_api_pb2.Metrics.User.MetricKey( - step=self.step, namespace=self.metric.namespace, name=self.metric.name) - - @staticmethod - def from_runner_api(proto): - return MetricKey(proto.step, MetricName(proto.namespace, proto.name)) - class MetricResult(object): """Keeps track of the status of a metric within a single bundle. @@ -160,6 +152,7 @@ def __init__(self, step_name): self.step_name = step_name self.counters = defaultdict(lambda: CounterCell()) self.distributions = defaultdict(lambda: DistributionCell()) + self.gauges = defaultdict(lambda: GaugeCell()) def get_counter(self, metric_name): return self.counters[metric_name] @@ -167,6 +160,9 @@ def get_counter(self, metric_name): def get_distribution(self, metric_name): return self.distributions[metric_name] + def get_gauge(self, metric_name): + return self.gauges[metric_name] + def _get_updates(self, filter=None): """Return cumulative values of metrics filtered according to a lambda. @@ -184,7 +180,11 @@ def _get_updates(self, filter=None): for k, v in self.distributions.items() if filter(v)} - return MetricUpdates(counters, distributions) + gauges = {MetricKey(self.step_name, k): v.get_cumulative() + for k, v in self.gauges.items() + if filter(v)} + + return MetricUpdates(counters, distributions, gauges) def get_updates(self): """Return cumulative values of metrics that changed since the last commit. 
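The hunks above and below thread the new gauge cells through MetricsContainer and its runner-API conversion. For orientation only, a minimal sketch of how the user-facing API added in metrics/metric.py would be exercised from a pipeline; the DoFn and metric names here are illustrative and not part of this patch.

import apache_beam as beam
from apache_beam.metrics import Metrics

class TrackLastBatchSize(beam.DoFn):
  def __init__(self):
    # A gauge is integer-valued and keeps only the most recently set value
    # (plus its timestamp), unlike a counter, which accumulates increments.
    self.last_batch_size = Metrics.gauge(self.__class__, 'last_batch_size')

  def process(self, batch):
    self.last_batch_size.set(len(batch))
    yield batch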
@@ -205,16 +205,19 @@ def get_cumulative(self): def to_runner_api(self): return ( [beam_fn_api_pb2.Metrics.User( - key=beam_fn_api_pb2.Metrics.User.MetricKey( - step=self.step_name, namespace=k.namespace, name=k.name), + metric_name=k.to_runner_api(), counter_data=beam_fn_api_pb2.Metrics.User.CounterData( value=v.get_cumulative())) for k, v in self.counters.items()] + [beam_fn_api_pb2.Metrics.User( - key=beam_fn_api_pb2.Metrics.User.MetricKey( - step=self.step_name, namespace=k.namespace, name=k.name), + metric_name=k.to_runner_api(), distribution_data=v.get_cumulative().to_runner_api()) - for k, v in self.distributions.items()]) + for k, v in self.distributions.items()] + + [beam_fn_api_pb2.Metrics.User( + metric_name=k.to_runner_api(), + gauge_data=v.get_cumulative().to_runner_api()) + for k, v in self.gauges.items()] + ) class ScopedMetricsContainer(object): @@ -243,12 +246,14 @@ class MetricUpdates(object): For Distribution metrics, it is DistributionData, and for Counter metrics, it's an int. """ - def __init__(self, counters=None, distributions=None): + def __init__(self, counters=None, distributions=None, gauges=None): """Create a MetricUpdates object. Args: counters: Dictionary of MetricKey:MetricUpdate updates. distributions: Dictionary of MetricKey:MetricUpdate objects. + gauges: Dictionary of MetricKey:MetricUpdate objects. """ self.counters = counters or {} self.distributions = distributions or {} + self.gauges = gauges or {} diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py index 855f54c84026..2367e35df4dc 100644 --- a/sdks/python/apache_beam/metrics/execution_test.py +++ b/sdks/python/apache_beam/metrics/execution_test.py @@ -46,14 +46,14 @@ def test_scoped_container(self): counter = Metrics.counter('ns', 'name') counter.inc(3) self.assertEqual( - c2.get_cumulative().counters.items(), + list(c2.get_cumulative().counters.items()), [(MetricKey('myinternalstep', MetricName('ns', 'name')), 3)]) self.assertEqual(c1, MetricsEnvironment.current_container()) counter = Metrics.counter('ns', 'name') counter.inc(4) self.assertEqual( - c1.get_cumulative().counters.items(), + list(c1.get_cumulative().counters.items()), [(MetricKey('mystep', MetricName('ns', 'name')), 6)]) def test_add_to_counter(self): @@ -72,8 +72,11 @@ def test_get_cumulative_or_updates(self): counter = mc.get_counter(MetricName('namespace', 'name{}'.format(i))) distribution = mc.get_distribution( MetricName('namespace', 'name{}'.format(i))) + gauge = mc.get_gauge(MetricName('namespace', 'name{}'.format(i))) + counter.inc(i) distribution.update(i) + gauge.set(i) if i % 2 == 0: # Some are left to be DIRTY (i.e. not yet committed). # Some are left to be CLEAN (i.e. already committed). @@ -82,25 +85,37 @@ def test_get_cumulative_or_updates(self): # Assert: Counter/Distribution is DIRTY or COMMITTING (not CLEAN) self.assertEqual(distribution.commit.before_commit(), True) self.assertEqual(counter.commit.before_commit(), True) + self.assertEqual(gauge.commit.before_commit(), True) distribution.commit.after_commit() counter.commit.after_commit() + gauge.commit.after_commit() # Assert: Counter/Distribution has been committed, therefore it's CLEAN self.assertEqual(counter.commit.state, CellCommitState.CLEAN) self.assertEqual(distribution.commit.state, CellCommitState.CLEAN) + self.assertEqual(gauge.commit.state, CellCommitState.CLEAN) clean_values.append(i) # Retrieve NON-COMMITTED updates. 
logical = mc.get_updates() self.assertEqual(len(logical.counters), 5) self.assertEqual(len(logical.distributions), 5) + self.assertEqual(len(logical.gauges), 5) + + self.assertEqual(set(dirty_values), + set([v.value for _, v in logical.gauges.items()])) self.assertEqual(set(dirty_values), set([v for _, v in logical.counters.items()])) + # Retrieve ALL updates. cumulative = mc.get_cumulative() self.assertEqual(len(cumulative.counters), 10) self.assertEqual(len(cumulative.distributions), 10) + self.assertEqual(len(cumulative.gauges), 10) + self.assertEqual(set(dirty_values + clean_values), set([v for _, v in cumulative.counters.items()])) + self.assertEqual(set(dirty_values + clean_values), + set([v.value for _, v in cumulative.gauges.items()])) class TestMetricsEnvironment(unittest.TestCase): @@ -115,11 +130,11 @@ def test_uses_right_container(self): MetricsEnvironment.unset_current_container() self.assertEqual( - c1.get_cumulative().counters.items(), + list(c1.get_cumulative().counters.items()), [(MetricKey('step1', MetricName('ns', 'name')), 1)]) self.assertEqual( - c2.get_cumulative().counters.items(), + list(c2.get_cumulative().counters.items()), [(MetricKey('step2', MetricName('ns', 'name')), 3)]) def test_no_container(self): diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index 8fbf9804ac86..99c435c609ac 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -29,6 +29,7 @@ from apache_beam.metrics.execution import MetricsEnvironment from apache_beam.metrics.metricbase import Counter from apache_beam.metrics.metricbase import Distribution +from apache_beam.metrics.metricbase import Gauge from apache_beam.metrics.metricbase import MetricName __all__ = ['Metrics', 'MetricsFilter'] @@ -75,6 +76,22 @@ def distribution(namespace, name): namespace = Metrics.get_namespace(namespace) return Metrics.DelegatingDistribution(MetricName(namespace, name)) + @staticmethod + def gauge(namespace, name): + """Obtains or creates a Gauge metric. + + Gauge metrics are restricted to integer-only values. + + Args: + namespace: A class or string that gives the namespace to a metric + name: A string that gives a unique name to a metric + + Returns: + A Distribution object. 
+ """ + namespace = Metrics.get_namespace(namespace) + return Metrics.DelegatingGauge(MetricName(namespace, name)) + class DelegatingCounter(Counter): def __init__(self, metric_name): self.metric_name = metric_name @@ -93,9 +110,17 @@ def update(self, value): if container is not None: container.get_distribution(self.metric_name).update(value) + class DelegatingGauge(Gauge): + def __init__(self, metric_name): + self.metric_name = metric_name -class MetricResults(object): + def set(self, value): + container = MetricsEnvironment.current_container() + if container is not None: + container.get_gauge(self.metric_name).set(value) + +class MetricResults(object): @staticmethod def _matches_name(filter, metric_key): if not filter.names and not filter.namespaces: diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py index ef98b2d655e3..eaad1574c73a 100644 --- a/sdks/python/apache_beam/metrics/metric_test.py +++ b/sdks/python/apache_beam/metrics/metric_test.py @@ -118,18 +118,23 @@ def test_create_counter_distribution(self): MetricsEnvironment.set_current_container(MetricsContainer('mystep')) counter_ns = 'aCounterNamespace' distro_ns = 'aDistributionNamespace' + gauge_ns = 'aGaugeNamespace' name = 'a_name' counter = Metrics.counter(counter_ns, name) distro = Metrics.distribution(distro_ns, name) + gauge = Metrics.gauge(gauge_ns, name) counter.inc(10) counter.dec(3) distro.update(10) distro.update(2) + gauge.set(10) self.assertTrue(isinstance(counter, Metrics.DelegatingCounter)) self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution)) + self.assertTrue(isinstance(gauge, Metrics.DelegatingGauge)) del distro del counter + del gauge container = MetricsEnvironment.current_container() self.assertEqual( @@ -138,6 +143,9 @@ def test_create_counter_distribution(self): self.assertEqual( container.distributions[MetricName(distro_ns, name)].get_cumulative(), DistributionData(12, 2, 2, 10)) + self.assertEqual( + container.gauges[MetricName(gauge_ns, name)].get_cumulative().value, + 10) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/metrics/metricbase.py b/sdks/python/apache_beam/metrics/metricbase.py index 9b1918907f6b..2ef9fe6d5350 100644 --- a/sdks/python/apache_beam/metrics/metricbase.py +++ b/sdks/python/apache_beam/metrics/metricbase.py @@ -27,10 +27,14 @@ decremented during pipeline execution. - Distribution - Distribution Metric interface. Allows statistics about the distribution of a variable to be collected during pipeline execution. +- Gauge - Gauge Metric interface. Allows to track the latest value of a + variable during pipeline execution. - MetricName - Namespace and name used to refer to a Metric. """ -__all__ = ['Metric', 'Counter', 'Distribution', 'MetricName'] +from apache_beam.portability.api import beam_fn_api_pb2 + +__all__ = ['Metric', 'Counter', 'Distribution', 'Gauge', 'MetricName'] class MetricName(object): @@ -65,6 +69,14 @@ def __str__(self): def __hash__(self): return hash((self.namespace, self.name)) + def to_runner_api(self): + return beam_fn_api_pb2.Metrics.User.MetricName( + namespace=self.namespace, name=self.name) + + @staticmethod + def from_runner_api(proto): + return MetricName(proto.namespace, proto.name) + class Metric(object): """Base interface of a metric object.""" @@ -82,7 +94,20 @@ def dec(self, n=1): class Distribution(Metric): - """Distribution Metric interface. 
Allows statistics about the - distribution of a variable to be collected during pipeline execution.""" + """Distribution Metric interface. + + Allows statistics about the distribution of a variable to be collected during + pipeline execution.""" + def update(self, value): raise NotImplementedError + + +class Gauge(Metric): + """Gauge Metric interface. + + Allows tracking of the latest value of a variable during pipeline + execution.""" + + def set(self, value): + raise NotImplementedError diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index aaac9a4fd996..7a2cd4bf1e40 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -30,6 +30,7 @@ 'TypeOptions', 'DirectOptions', 'GoogleCloudOptions', + 'HadoopFileSystemOptions', 'WorkerOptions', 'DebugOptions', 'ProfilingOptions', @@ -392,6 +393,33 @@ def validate(self, validator): return errors +class HadoopFileSystemOptions(PipelineOptions): + """``HadoopFileSystem`` connection options.""" + + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument( + '--hdfs_host', + default=None, + help= + ('Hostname or address of the HDFS namenode.')) + parser.add_argument( + '--hdfs_port', + default=None, + help= + ('Port of the HDFS namenode.')) + parser.add_argument( + '--hdfs_user', + default=None, + help= + ('HDFS username to use.')) + + def validate(self, validator): + errors = [] + errors.extend(validator.validate_optional_argument_positive(self, 'port')) + return errors + + # Command line options controlling the worker pool configuration. # TODO(silviuc): Update description when autoscaling options are in. class WorkerOptions(PipelineOptions): diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py index 291440d662fc..09b4a4479354 100644 --- a/sdks/python/apache_beam/options/value_provider.py +++ b/sdks/python/apache_beam/options/value_provider.py @@ -78,17 +78,22 @@ def __init__(self, option_name, value_type, default_value): def is_accessible(self): return RuntimeValueProvider.runtime_options is not None + @classmethod + def get_value(cls, option_name, value_type, default_value): + candidate = RuntimeValueProvider.runtime_options.get(option_name) + if candidate: + return value_type(candidate) + else: + return default_value + def get(self): if RuntimeValueProvider.runtime_options is None: raise error.RuntimeValueProviderError( '%s.get() not called from a runtime context' % self) - candidate = RuntimeValueProvider.runtime_options.get(self.option_name) - if candidate: - value = self.value_type(candidate) - else: - value = self.default_value - return value + return RuntimeValueProvider.get_value(self.option_name, + self.value_type, + self.default_value) @classmethod def set_runtime_options(cls, pipeline_options): diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 4ac5ea86bf28..71d97ba5d21f 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -61,13 +61,14 @@ from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator +from apache_beam.portability import common_urns from apache_beam.pvalue import PCollection +from apache_beam.pvalue import PDone from apache_beam.runners import PipelineRunner from 
apache_beam.runners import create_runner from apache_beam.transforms import ptransform from apache_beam.typehints import TypeCheckError from apache_beam.typehints import typehints -from apache_beam.utils import urns from apache_beam.utils.annotations import deprecated __all__ = ['Pipeline', 'PTransformOverride'] @@ -180,7 +181,6 @@ def _remove_labels_recursively(self, applied_transform): def _replace(self, override): assert isinstance(override, PTransformOverride) - matcher = override.get_matcher() output_map = {} output_replacements = {} @@ -193,10 +193,12 @@ def __init__(self, pipeline): self.pipeline = pipeline def _replace_if_needed(self, original_transform_node): - if matcher(original_transform_node): + if override.matches(original_transform_node): assert isinstance(original_transform_node, AppliedPTransform) replacement_transform = override.get_replacement_transform( original_transform_node.transform) + if replacement_transform is original_transform_node.transform: + return replacement_transform_node = AppliedPTransform( original_transform_node.parent, replacement_transform, @@ -227,6 +229,10 @@ def _replace_if_needed(self, original_transform_node): 'have a single input. Tried to replace input of ' 'AppliedPTransform %r that has %d inputs', original_transform_node, len(inputs)) + elif len(inputs) == 1: + input_node = inputs[0] + elif len(inputs) == 0: + input_node = pvalue.PBegin(self) # We have to add the new AppliedTransform to the stack before expand() # and pop it out later to make sure that parts get added correctly. @@ -239,16 +245,23 @@ def _replace_if_needed(self, original_transform_node): # with labels of the children of the original. self.pipeline._remove_labels_recursively(original_transform_node) - new_output = replacement_transform.expand(inputs[0]) + new_output = replacement_transform.expand(input_node) + + new_output.element_type = None + self.pipeline._infer_result_type(replacement_transform, inputs, + new_output) + replacement_transform_node.add_output(new_output) + if not new_output.producer: + new_output.producer = replacement_transform_node # We only support replacing transforms with a single output with # another transform that produces a single output. # TODO: Support replacing PTransforms with multiple outputs. if (len(original_transform_node.outputs) > 1 or - not isinstance( - original_transform_node.outputs[None], PCollection) or - not isinstance(new_output, PCollection)): + not isinstance(original_transform_node.outputs[None], + (PCollection, PDone)) or + not isinstance(new_output, (PCollection, PDone))): raise NotImplementedError( 'PTransform overriding is only supported for PTransforms that ' 'have a single output. Tried to replace output of ' @@ -314,11 +327,10 @@ def visit_transform(self, transform_node): transform.inputs = input_replacements[transform] def _check_replacement(self, override): - matcher = override.get_matcher() class ReplacementValidator(PipelineVisitor): def visit_transform(self, transform_node): - if matcher(transform_node): + if override.matches(transform_node): raise RuntimeError('Transform node %r was not replaced as expected.', transform_node) @@ -477,29 +489,8 @@ def apply(self, transform, pvalueish=None, label=None): # being the real producer of the result. if result.producer is None: result.producer = current - # TODO(robertwb): Multi-input, multi-output inference. - # TODO(robertwb): Ideally we'd do intersection here. 
- if (type_options is not None and type_options.pipeline_type_check - and isinstance(result, pvalue.PCollection) - and not result.element_type): - input_element_type = ( - inputs[0].element_type - if len(inputs) == 1 - else typehints.Any) - type_hints = transform.get_type_hints() - declared_output_type = type_hints.simple_output_type(transform.label) - if declared_output_type: - input_types = type_hints.input_types - if input_types and input_types[0]: - declared_input_type = input_types[0][0] - result.element_type = typehints.bind_type_variables( - declared_output_type, - typehints.match_type_variables(declared_input_type, - input_element_type)) - else: - result.element_type = declared_output_type - else: - result.element_type = transform.infer_output_type(input_element_type) + + self._infer_result_type(transform, inputs, result) assert isinstance(result.producer.inputs, tuple) current.add_output(result) @@ -516,6 +507,33 @@ def apply(self, transform, pvalueish=None, label=None): self.transforms_stack.pop() return pvalueish_result + def _infer_result_type(self, transform, inputs, result_pcollection): + # TODO(robertwb): Multi-input, multi-output inference. + # TODO(robertwb): Ideally we'd do intersection here. + type_options = self._options.view_as(TypeOptions) + if (type_options is not None and type_options.pipeline_type_check + and isinstance(result_pcollection, pvalue.PCollection) + and not result_pcollection.element_type): + input_element_type = ( + inputs[0].element_type + if len(inputs) == 1 + else typehints.Any) + type_hints = transform.get_type_hints() + declared_output_type = type_hints.simple_output_type(transform.label) + if declared_output_type: + input_types = type_hints.input_types + if input_types and input_types[0]: + declared_input_type = input_types[0][0] + result_pcollection.element_type = typehints.bind_type_variables( + declared_output_type, + typehints.match_type_variables(declared_input_type, + input_element_type)) + else: + result_pcollection.element_type = declared_output_type + else: + result_pcollection.element_type = transform.infer_output_type( + input_element_type) + def __reduce__(self): # Some transforms contain a reference to their enclosing pipeline, # which in turn reference all other transforms (resulting in quadratic @@ -828,7 +846,7 @@ def is_side_input(tag): None if tag == 'None' else tag: context.pcollections.get_by_id(id) for tag, id in proto.outputs.items()} # This annotation is expected by some runners. - if proto.spec.urn == urns.PARDO_TRANSFORM: + if proto.spec.urn == common_urns.PARDO_TRANSFORM: result.transform.output_tags = set(proto.outputs.keys()).difference( {'None'}) if not result.parts: @@ -852,12 +870,14 @@ class PTransformOverride(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def get_matcher(self): - """Gives a matcher that will be used to to perform this override. + def matches(self, applied_ptransform): + """Determines whether the given AppliedPTransform matches. + + Args: + applied_ptransform: AppliedPTransform to be matched. Returns: - a callable that takes an AppliedPTransform as a parameter and returns a - boolean as a result. + a bool indicating whether the given AppliedPTransform is a match. """ raise NotImplementedError @@ -867,6 +887,7 @@ def get_replacement_transform(self, ptransform): Args: ptransform: PTransform to be replaced. + Returns: A PTransform that will be the replacement for the PTransform given as an argument. 
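For reference, a minimal usage sketch of the reworked PTransformOverride interface introduced above: matches() now receives the AppliedPTransform directly (instead of a separate matcher callable), and the output of a replacement has its element type re-inferred. The transform names (_DoubleElements, _TripleElements, _DoubleToTripleOverride, the 'Scale' label) are illustrative only and are not part of this patch.

import apache_beam as beam
from apache_beam.pipeline import Pipeline
from apache_beam.pipeline import PTransformOverride


class _DoubleElements(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | 'Double' >> beam.Map(lambda x: x * 2)


class _TripleElements(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | 'Triple' >> beam.Map(lambda x: x * 3)


class _DoubleToTripleOverride(PTransformOverride):
  def matches(self, applied_ptransform):
    # Decide directly on the AppliedPTransform; no separate matcher callable.
    return isinstance(applied_ptransform.transform, _DoubleElements)

  def get_replacement_transform(self, ptransform):
    # Declared type hints on the replacement are propagated to its output
    # PCollection by the new type re-inference step.
    return _TripleElements().with_output_types(int)


p = Pipeline()
_ = p | beam.Create([1, 2, 3]) | 'Scale' >> _DoubleElements()
p.replace_all([_DoubleToTripleOverride()])

Note also that returning the original transform unchanged from get_replacement_transform() now leaves the node untouched, since the updated replacement logic short-circuits when the replacement is the same object.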
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 34ec48ecfd64..c3dd2296f206 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -17,6 +17,7 @@ """Unit tests for the Pipeline class.""" +import copy import logging import platform import unittest @@ -25,6 +26,7 @@ import mock import apache_beam as beam +from apache_beam import typehints from apache_beam.io import Read from apache_beam.metrics import Metrics from apache_beam.pipeline import Pipeline @@ -94,6 +96,13 @@ def expand(self, input): return input | 'Inner' >> beam.Map(lambda a: a * 3) +class ToStringParDo(beam.PTransform): + def expand(self, input): + # We use copy.copy() here to make sure the typehint mechanism doesn't + # automatically infer that the output type is str. + return input | 'Inner' >> beam.Map(lambda a: copy.copy(str(a))) + + class PipelineTest(unittest.TestCase): @staticmethod @@ -161,8 +170,7 @@ def test_create_singleton_pcollection(self): # TODO(BEAM-1555): Test is failing on the service, with FakeSource. # @attr('ValidatesRunner') def test_metrics_in_fake_source(self): - # FakeSource mock requires DirectRunner. - pipeline = TestPipeline(runner='DirectRunner') + pipeline = TestPipeline() pcoll = pipeline | Read(FakeSource([1, 2, 3, 4, 5, 6])) assert_that(pcoll, equal_to([1, 2, 3, 4, 5, 6])) res = pipeline.run() @@ -173,8 +181,7 @@ def test_metrics_in_fake_source(self): self.assertEqual(outputs_counter.committed, 6) def test_fake_read(self): - # FakeSource mock requires DirectRunner. - pipeline = TestPipeline(runner='DirectRunner') + pipeline = TestPipeline() pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3])) assert_that(pcoll, equal_to([1, 2, 3])) pipeline.run() @@ -214,7 +221,7 @@ def test_reuse_custom_transform_instance(self): with self.assertRaises(RuntimeError) as cm: pipeline.apply(transform, pcoll2) self.assertEqual( - cm.exception.message, + cm.exception.args[0], 'Transform "CustomTransform" does not have a stable unique label. ' 'This will prevent updating of pipelines. ' 'To apply a transform with a specified label write ' @@ -270,7 +277,9 @@ def check_memory(value, memory_threshold): num_elements = 10 num_maps = 100 - pipeline = TestPipeline() + # TODO(robertwb): reduce memory usage of FnApiRunner so that this test + # passes. + pipeline = TestPipeline(runner='BundleBasedDirectRunner') # Consumed memory should not be proportional to the number of maps. memory_threshold = ( @@ -310,29 +319,57 @@ def raise_exception(exn): 'apache_beam.runners.direct.direct_runner._get_transform_overrides') def test_ptransform_overrides(self, file_system_override_mock): - def my_par_do_matcher(applied_ptransform): - return isinstance(applied_ptransform.transform, DoubleParDo) - class MyParDoOverride(PTransformOverride): - def get_matcher(self): - return my_par_do_matcher + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, DoubleParDo) def get_replacement_transform(self, ptransform): if isinstance(ptransform, DoubleParDo): return TripleParDo() raise ValueError('Unsupported type of transform: %r', ptransform) - def get_overrides(): + def get_overrides(unused_pipeline_options): return [MyParDoOverride()] file_system_override_mock.side_effect = get_overrides # Specify DirectRunner as it's the one patched above. 
- with Pipeline(runner='DirectRunner') as p: + with Pipeline(runner='BundleBasedDirectRunner') as p: pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo() assert_that(pcoll, equal_to([3, 6, 9])) + def test_ptransform_override_type_hints(self): + + class NoTypeHintOverride(PTransformOverride): + + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, DoubleParDo) + + def get_replacement_transform(self, ptransform): + return ToStringParDo() + + class WithTypeHintOverride(PTransformOverride): + + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, DoubleParDo) + + def get_replacement_transform(self, ptransform): + return (ToStringParDo() + .with_input_types(int) + .with_output_types(str)) + + for override, expected_type in [(NoTypeHintOverride(), typehints.Any), + (WithTypeHintOverride(), str)]: + p = TestPipeline() + pcoll = (p + | beam.Create([1, 2, 3]) + | 'Operate' >> DoubleParDo() + | 'NoOp' >> beam.Map(lambda x: x)) + + p.replace_all([override]) + self.assertEquals(pcoll.producer.inputs[0].element_type, expected_type) + class DoFnTest(unittest.TestCase): @@ -408,6 +445,12 @@ def process(self, element, timestamp=DoFn.TimestampParam): assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP])) pipeline.run() + def test_timestamp_param_map(self): + with TestPipeline() as p: + assert_that( + p | Create([1, 2]) | beam.Map(lambda _, t=DoFn.TimestampParam: t), + equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP])) + class Bacon(PipelineOptions): @@ -529,7 +572,9 @@ def expand(self, p): class DirectRunnerRetryTests(unittest.TestCase): def test_retry_fork_graph(self): - p = beam.Pipeline(runner='DirectRunner') + # TODO(BEAM-3642): The FnApiRunner does not currently support + # retries. + p = beam.Pipeline(runner='BundleBasedDirectRunner') # TODO(mariagh): Remove the use of globals from the test. global count_b, count_c # pylint: disable=global-variable-undefined diff --git a/sdks/python/apache_beam/portability/python_urns.py b/sdks/python/apache_beam/portability/python_urns.py new file mode 100644 index 000000000000..a284b5fe66c0 --- /dev/null +++ b/sdks/python/apache_beam/portability/python_urns.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Enumeration of URNs specific to the Python SDK.
+ +For internal use only; no backwards-compatibility guarantees.""" + +PICKLED_CODER = "beam:coder:pickled_python:v1" +PICKLED_COMBINE_FN = "beam:combinefn:pickled_python:v1" +PICKLED_DOFN = "beam:dofn:pickled_python:v1" +PICKLED_DOFN_INFO = "beam:dofn:pickled_python_info:v1" +PICKLED_SOURCE = "beam:source:pickled_python:v1" +PICKLED_TRANSFORM = "beam:ptransform:pickled_python:v1" +PICKLED_WINDOW_MAPPING_FN = "beam:window_mapping_fn:pickled_python:v1" +PICKLED_WINDOWFN = "beam:windowfn:pickled_python:v1" +PICKLED_VIEWFN = "beam:view_fn:pickled_python_data:v1" diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 3987fd5dfff3..2aca33e667f6 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -31,8 +31,9 @@ from apache_beam import coders from apache_beam import typehints from apache_beam.internal import pickler +from apache_beam.portability import common_urns +from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 -from apache_beam.utils import urns __all__ = [ 'PCollection', @@ -301,7 +302,7 @@ def _side_input_data(self): view_options = self._view_options() from_runtime_iterable = type(self)._from_runtime_iterable return SideInputData( - urns.ITERABLE_ACCESS, + common_urns.ITERABLE_SIDE_INPUT, self._window_mapping_fn, lambda iterable: from_runtime_iterable(iterable, view_options), self._input_element_coder()) @@ -354,17 +355,18 @@ def to_runner_api(self, unused_context): urn=self.access_pattern), view_fn=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.PICKLED_PYTHON_VIEWFN, + urn=python_urns.PICKLED_VIEWFN, payload=pickler.dumps((self.view_fn, self.coder)))), window_mapping_fn=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.PICKLED_WINDOW_MAPPING_FN, + urn=python_urns.PICKLED_WINDOW_MAPPING_FN, payload=pickler.dumps(self.window_mapping_fn)))) @staticmethod def from_runner_api(proto, unused_context): - assert proto.view_fn.spec.urn == urns.PICKLED_PYTHON_VIEWFN - assert proto.window_mapping_fn.spec.urn == urns.PICKLED_WINDOW_MAPPING_FN + assert proto.view_fn.spec.urn == python_urns.PICKLED_VIEWFN + assert (proto.window_mapping_fn.spec.urn == + python_urns.PICKLED_WINDOW_MAPPING_FN) return SideInputData( proto.access_pattern.urn, pickler.loads(proto.window_mapping_fn.spec.payload), @@ -442,7 +444,7 @@ def _from_runtime_iterable(it, options): def _side_input_data(self): return SideInputData( - urns.ITERABLE_ACCESS, + common_urns.ITERABLE_SIDE_INPUT, self._window_mapping_fn, lambda iterable: iterable, self._input_element_coder()) @@ -473,7 +475,7 @@ def _from_runtime_iterable(it, options): def _side_input_data(self): return SideInputData( - urns.ITERABLE_ACCESS, + common_urns.ITERABLE_SIDE_INPUT, self._window_mapping_fn, list, self._input_element_coder()) @@ -501,7 +503,7 @@ def _from_runtime_iterable(it, options): def _side_input_data(self): return SideInputData( - urns.ITERABLE_ACCESS, + common_urns.ITERABLE_SIDE_INPUT, self._window_mapping_fn, dict, self._input_element_coder()) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py index c7eb88ef578d..2e0bc8209ecf 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py @@ -170,12 +170,8 @@ def _get_metric_value(self, metric): lambda x: x.key == 'min').value.integer_value 
dist_max = _get_match(metric.distribution.object_value.properties, lambda x: x.key == 'max').value.integer_value - dist_mean = _get_match(metric.distribution.object_value.properties, - lambda x: x.key == 'mean').value.integer_value - # Calculating dist_sum with a hack, as distribution sum is not yet - # available in the Dataflow API. - # TODO(pabloem) Switch to "sum" field once it's available in the API - dist_sum = dist_count * dist_mean + dist_sum = _get_match(metric.distribution.object_value.properties, + lambda x: x.key == 'sum').value.integer_value return DistributionResult( DistributionData( dist_sum, dist_count, dist_min, dist_max)) @@ -209,4 +205,4 @@ def query(self, filter=None): 'distributions': [elm for elm in metric_results if self.matches(filter, elm.key) and DataflowMetrics._is_distribution(elm)], - 'gauges': []} # Gauges are not currently supported by dataflow + 'gauges': []} # TODO(pabloem): Add Gauge support for dataflow. diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 66e62ed59992..bfec89310e9e 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -37,6 +37,7 @@ from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import TestOptions +from apache_beam.portability import common_urns from apache_beam.pvalue import AsSideInput from apache_beam.runners.dataflow.dataflow_metrics import DataflowMetrics from apache_beam.runners.dataflow.internal import names @@ -280,7 +281,8 @@ def run_pipeline(self, pipeline): 'please install apache_beam[gcp]') # Snapshot the pipeline in a portable proto before mutating it - proto_pipeline = pipeline.to_runner_api() + proto_pipeline, self.proto_context = pipeline.to_runner_api( + return_context=True) # Performing configured PTransform overrides. pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES) @@ -575,8 +577,17 @@ def run_ParDo(self, transform_node): if transform_node.side_inputs else ''), transform_node, transform_node.transform.output_tags) - fn_data = self._pardo_fn_data(transform_node, lookup_label) - step.add_property(PropertyNames.SERIALIZED_FN, pickler.dumps(fn_data)) + # Import here to avoid adding the dependency for local running scenarios. 
+ # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam.runners.dataflow.internal import apiclient + transform_proto = self.proto_context.transforms.get_proto(transform_node) + if (apiclient._use_fnapi(transform_node.inputs[0].pipeline._options) + and transform_proto.spec.urn == common_urns.PARDO_TRANSFORM): + serialized_data = self.proto_context.transforms.get_id(transform_node) + else: + serialized_data = pickler.dumps( + self._pardo_fn_data(transform_node, lookup_label)) + step.add_property(PropertyNames.SERIALIZED_FN, serialized_data) step.add_property( PropertyNames.PARALLEL_INPUT, {'@type': 'OutputReference', diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index 2d529e11d2c6..b5300a4a9f64 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -273,8 +273,10 @@ def test_group_by_key_input_visitor_with_valid_inputs(self): pcoll2.element_type = typehints.Any pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any] for pcoll in [pcoll1, pcoll2, pcoll3]: + applied = AppliedPTransform(None, transform, "label", [pcoll]) + applied.outputs[None] = PCollection(None) DataflowRunner.group_by_key_input_visitor().visit_transform( - AppliedPTransform(None, transform, "label", [pcoll])) + applied) self.assertEqual(pcoll.element_type, typehints.KV[typehints.Any, typehints.Any]) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/dependency_test.py b/sdks/python/apache_beam/runners/dataflow/internal/dependency_test.py index 68e5d8c20fe5..41afe0a8c5b1 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/dependency_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/dependency_test.py @@ -75,7 +75,7 @@ def test_no_staging_location(self): with self.assertRaises(RuntimeError) as cm: dependency.stage_job_resources(PipelineOptions()) self.assertEqual('The --staging_location option must be specified.', - cm.exception.message) + cm.exception.args[0]) def test_no_temp_location(self): staging_dir = self.make_temp_dir() @@ -87,7 +87,7 @@ def test_no_temp_location(self): with self.assertRaises(RuntimeError) as cm: dependency.stage_job_resources(options) self.assertEqual('The --temp_location option must be specified.', - cm.exception.message) + cm.exception.args[0]) def test_no_main_session(self): staging_dir = self.make_temp_dir() @@ -161,7 +161,7 @@ def test_requirements_file_not_present(self): dependency.stage_job_resources( options, populate_requirements_cache=self.populate_requirements_cache) self.assertEqual( - cm.exception.message, + cm.exception.args[0], 'The file %s cannot be found. It was specified in the ' '--requirements_file command line option.' % 'nosuchfile') @@ -229,7 +229,7 @@ def test_setup_file_not_present(self): with self.assertRaises(RuntimeError) as cm: dependency.stage_job_resources(options) self.assertEqual( - cm.exception.message, + cm.exception.args[0], 'The file %s cannot be found. It was specified in the ' '--setup_file command line option.' 
% 'nosuchfile') @@ -248,7 +248,7 @@ def test_setup_file_not_named_setup_dot_py(self): with self.assertRaises(RuntimeError) as cm: dependency.stage_job_resources(options) self.assertTrue( - cm.exception.message.startswith( + cm.exception.args[0].startswith( 'The --setup_file option expects the full path to a file named ' 'setup.py instead of ')) @@ -338,7 +338,7 @@ def test_sdk_location_local_not_present(self): 'The file "%s" cannot be found. Its ' 'location was specified by the --sdk_location command-line option.' % sdk_location, - cm.exception.message) + cm.exception.args[0]) def test_sdk_location_gcs(self): staging_dir = self.make_temp_dir() @@ -415,7 +415,7 @@ def test_with_extra_packages_missing_files(self): dependency.stage_job_resources(options) self.assertEqual( - cm.exception.message, + cm.exception.args[0], 'The file %s cannot be found. It was specified in the ' '--extra_packages command line option.' % 'nosuchfile.tar.gz') @@ -432,7 +432,7 @@ def test_with_extra_packages_invalid_file_name(self): os.path.join(source_dir, 'abc.tgz')] dependency.stage_job_resources(options) self.assertEqual( - cm.exception.message, + cm.exception.args[0], 'The --extra_package option expects a full path ending with ' '".tar", ".tar.gz", ".whl" or ".zip" ' 'instead of %s' % os.path.join(source_dir, 'abc.tgz')) diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py index ce7725728741..01fd35f9cf95 100644 --- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py +++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py @@ -182,8 +182,7 @@ def __exit__(self, *unused_args): def Write(self, value): self.written_values.append(value) - # Records in-memory writes, only works on Direct runner. - p = TestPipeline(runner='DirectRunner') + p = TestPipeline() sink = FakeSink() p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned p.run() diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py index 680a4b7de5c2..0ce212fa31bd 100644 --- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py +++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py @@ -24,11 +24,7 @@ class CreatePTransformOverride(PTransformOverride): """A ``PTransformOverride`` for ``Create`` in streaming mode.""" - def get_matcher(self): - return self.is_streaming_create - - @staticmethod - def is_streaming_create(applied_ptransform): + def matches(self, applied_ptransform): # Imported here to avoid circular dependencies. 
# pylint: disable=wrong-import-order, wrong-import-position from apache_beam import Create diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py index aa35fb7ccbb1..67f5780005fc 100644 --- a/sdks/python/apache_beam/runners/direct/direct_metrics.py +++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py @@ -25,6 +25,7 @@ from apache_beam.metrics.cells import CounterAggregator from apache_beam.metrics.cells import DistributionAggregator +from apache_beam.metrics.cells import GaugeAggregator from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metric import MetricResults @@ -36,6 +37,8 @@ def __init__(self): lambda: DirectMetric(CounterAggregator())) self._distributions = defaultdict( lambda: DirectMetric(DistributionAggregator())) + self._gauges = defaultdict( + lambda: DirectMetric(GaugeAggregator())) def _apply_operation(self, bundle, updates, op): for k, v in updates.counters.items(): @@ -44,6 +47,9 @@ def _apply_operation(self, bundle, updates, op): for k, v in updates.distributions.items(): op(self._distributions[k], bundle, v) + for k, v in updates.gauges.items(): + op(self._gauges[k], bundle, v) + def commit_logical(self, bundle, updates): op = lambda obj, bundle, update: obj.commit_logical(bundle, update) self._apply_operation(bundle, updates, op) @@ -67,9 +73,15 @@ def query(self, filter=None): v.extract_latest_attempted()) for k, v in self._distributions.items() if self.matches(filter, k)] + gauges = [MetricResult(MetricKey(k.step, k.metric), + v.extract_committed(), + v.extract_latest_attempted()) + for k, v in self._gauges.items() + if self.matches(filter, k)] return {'counters': counters, - 'distributions': distributions} + 'distributions': distributions, + 'gauges': gauges} class DirectMetric(object): @@ -81,10 +93,10 @@ class DirectMetric(object): def __init__(self, aggregator): self.aggregator = aggregator self._attempted_lock = threading.Lock() - self.finished_attempted = aggregator.zero() + self.finished_attempted = aggregator.identity_element() self.inflight_attempted = {} self._committed_lock = threading.Lock() - self.finished_committed = aggregator.zero() + self.finished_committed = aggregator.identity_element() def commit_logical(self, bundle, update): with self._committed_lock: diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 7f3200ea5f6d..d06a9a8df410 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -23,16 +23,18 @@ from __future__ import absolute_import +import itertools import logging from google.protobuf import wrappers_pb2 import apache_beam as beam +from apache_beam import coders from apache_beam import typehints +from apache_beam.internal.util import ArgumentPlaceholder from apache_beam.metrics.execution import MetricsEnvironment from apache_beam.options.pipeline_options import DirectOptions from apache_beam.options.pipeline_options import StandardOptions -from apache_beam.options.pipeline_options import TypeOptions from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.pvalue import PCollection from apache_beam.runners.direct.bundle_factory import BundleFactory @@ -41,11 +43,90 @@ from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from 
apache_beam.runners.runner import PipelineState +from apache_beam.transforms.core import CombinePerKey +from apache_beam.transforms.core import CombineValuesDoFn +from apache_beam.transforms.core import ParDo from apache_beam.transforms.core import _GroupAlsoByWindow +from apache_beam.transforms.core import _GroupAlsoByWindowDoFn from apache_beam.transforms.core import _GroupByKeyOnly from apache_beam.transforms.ptransform import PTransform -__all__ = ['DirectRunner'] +# Note that the BundleBasedDirectRunner and SwitchingDirectRunner names are +# experimental and have no backwards compatibility guarantees. +__all__ = ['BundleBasedDirectRunner', + 'DirectRunner', + 'SwitchingDirectRunner'] + + +class SwitchingDirectRunner(PipelineRunner): + """Executes a single pipeline on the local machine. + + This implementation switches between using the FnApiRunner (which has + high throughput for batch jobs) and using the BundleBasedDirectRunner, + which supports streaming execution and certain primitives not yet + implemented in the FnApiRunner. + """ + + def run_pipeline(self, pipeline): + use_fnapi_runner = True + + # Streaming mode is not yet supported on the FnApiRunner. + if pipeline.options.view_as(StandardOptions).streaming: + use_fnapi_runner = False + + from apache_beam.pipeline import PipelineVisitor + from apache_beam.runners.common import DoFnSignature + from apache_beam.runners.dataflow.native_io.iobase import NativeSource + from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite + from apache_beam.testing.test_stream import TestStream + + class _FnApiRunnerSupportVisitor(PipelineVisitor): + """Visitor determining if a Pipeline can be run on the FnApiRunner.""" + + def __init__(self): + self.supported_by_fnapi_runner = True + + def visit_transform(self, applied_ptransform): + transform = applied_ptransform.transform + # The FnApiRunner does not support streaming execution. + if isinstance(transform, TestStream): + self.supported_by_fnapi_runner = False + # The FnApiRunner does not support reads from NativeSources. + if (isinstance(transform, beam.io.Read) and + isinstance(transform.source, NativeSource)): + self.supported_by_fnapi_runner = False + # The FnApiRunner does not support the use of _NativeWrites. + if isinstance(transform, _NativeWrite): + self.supported_by_fnapi_runner = False + if isinstance(transform, beam.ParDo): + dofn = transform.dofn + # The FnApiRunner does not support execution of SplittableDoFns. + if DoFnSignature(dofn).is_splittable_dofn(): + self.supported_by_fnapi_runner = False + # The FnApiRunner does not support execution of CombineFns with + # deferred side inputs. + if isinstance(dofn, CombineValuesDoFn): + args, kwargs = transform.raw_side_inputs + args_to_check = itertools.chain(args, + kwargs.values()) + if any(isinstance(arg, ArgumentPlaceholder) + for arg in args_to_check): + self.supported_by_fnapi_runner = False + + # Check whether all transforms used in the pipeline are supported by the + # FnApiRunner. + visitor = _FnApiRunnerSupportVisitor() + pipeline.visit(visitor) + if not visitor.supported_by_fnapi_runner: + use_fnapi_runner = False + + if use_fnapi_runner: + from apache_beam.runners.portability.fn_api_runner import FnApiRunner + runner = FnApiRunner() + else: + runner = BundleBasedDirectRunner() + + return runner.run_pipeline(pipeline) # Type variables. 
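As a usage sketch (not part of the patch): with this change 'DirectRunner' resolves to the SwitchingDirectRunner above, so a plain batch pipeline is routed to the FnApiRunner, while streaming pipelines and pipelines using primitives the FnApiRunner cannot yet execute fall back to the BundleBasedDirectRunner, which can also be requested by name. The pipelines below are illustrative.

import apache_beam as beam

# Routed to the FnApiRunner, assuming every transform is supported by it.
with beam.Pipeline('DirectRunner') as p:
  _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)

# The bundle-based implementation can still be requested explicitly, as
# several tests in this patch do for features the FnApiRunner does not yet
# handle.
with beam.Pipeline('BundleBasedDirectRunner') as p:
  _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)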
@@ -87,7 +168,23 @@ def from_runner_api_parameter(payload, context): context.windowing_strategies.get_by_id(payload.value)) -def _get_transform_overrides(): +class _DirectReadStringsFromPubSub(PTransform): + def __init__(self, source): + self._source = source + + def _infer_output_coder(self, unused_input_type=None, + unused_input_coder=None): + return coders.StrUtf8Coder() + + def get_windowing(self, inputs): + return beam.Windowing(beam.window.GlobalWindows()) + + def expand(self, pvalue): + # This is handled as a native transform. + return PCollection(self.pipeline) + + +def _get_transform_overrides(pipeline_options): # A list of PTransformOverride objects to be applied before running a pipeline # using DirectRunner. # Currently this only works for overrides where the input and output types do @@ -95,131 +192,170 @@ def _get_transform_overrides(): # For internal use only; no backwards-compatibility guarantees. # Importing following locally to avoid a circular dependency. + from apache_beam.pipeline import PTransformOverride from apache_beam.runners.sdf_common import SplittableParDoOverride + from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey from apache_beam.runners.direct.sdf_direct_runner import ProcessKeyedElementsViaKeyedWorkItemsOverride - return [SplittableParDoOverride(), - ProcessKeyedElementsViaKeyedWorkItemsOverride()] + class CombinePerKeyOverride(PTransformOverride): + def matches(self, applied_ptransform): + if isinstance(applied_ptransform.transform, CombinePerKey): + return True -class DirectRunner(PipelineRunner): - """Executes a single pipeline on the local machine.""" + def get_replacement_transform(self, transform): + # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems + # with resolving imports when they are at top. + # pylint: disable=wrong-import-position + try: + return LiftedCombinePerKey(transform.fn, transform.args, + transform.kwargs) + except NotImplementedError: + return transform + + class StreamingGroupByKeyOverride(PTransformOverride): + def matches(self, applied_ptransform): + # Note: we match the exact class, since we replace it with a subclass. + return applied_ptransform.transform.__class__ == _GroupByKeyOnly + + def get_replacement_transform(self, transform): + # Use specialized streaming implementation. + transform = _StreamingGroupByKeyOnly() + return transform + + class StreamingGroupAlsoByWindowOverride(PTransformOverride): + def matches(self, applied_ptransform): + # Note: we match the exact class, since we replace it with a subclass. + transform = applied_ptransform.transform + return (isinstance(applied_ptransform.transform, ParDo) and + isinstance(transform.dofn, _GroupAlsoByWindowDoFn) and + transform.__class__ != _StreamingGroupAlsoByWindow) + + def get_replacement_transform(self, transform): + # Use specialized streaming implementation. + transform = _StreamingGroupAlsoByWindow(transform.dofn.windowing) + return transform + + overrides = [SplittableParDoOverride(), + ProcessKeyedElementsViaKeyedWorkItemsOverride(), + CombinePerKeyOverride()] + + # Add streaming overrides, if necessary. + if pipeline_options.view_as(StandardOptions).streaming: + overrides.append(StreamingGroupByKeyOverride()) + overrides.append(StreamingGroupAlsoByWindowOverride()) + + # Add PubSub overrides, if PubSub is available. 
+ try: + from apache_beam.io.gcp import pubsub as unused_pubsub + overrides += _get_pubsub_transform_overrides(pipeline_options) + except ImportError: + pass + + return overrides + + +def _get_pubsub_transform_overrides(pipeline_options): + from google.cloud import pubsub + from apache_beam.io.gcp import pubsub as beam_pubsub + from apache_beam.pipeline import PTransformOverride + + class ReadStringsFromPubSubOverride(PTransformOverride): + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, + beam_pubsub.ReadStringsFromPubSub) + + def get_replacement_transform(self, transform): + if not pipeline_options.view_as(StandardOptions).streaming: + raise Exception('PubSub I/O is only available in streaming mode ' + '(use the --streaming flag).') + return _DirectReadStringsFromPubSub(transform._source) + + class WriteStringsToPubSubOverride(PTransformOverride): + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, + beam_pubsub.WriteStringsToPubSub) + + def get_replacement_transform(self, transform): + if not pipeline_options.view_as(StandardOptions).streaming: + raise Exception('PubSub I/O is only available in streaming mode ' + '(use the --streaming flag).') + + class _DirectWriteToPubSub(beam.DoFn): + _topic = None + + def __init__(self, project, topic_name): + self.project = project + self.topic_name = topic_name + + def start_bundle(self): + if self._topic is None: + self._topic = pubsub.Client(project=self.project).topic( + self.topic_name) + self._buffer = [] - def __init__(self): - self._use_test_clock = False # use RealClock() in production - self._ptransform_overrides = _get_transform_overrides() + def process(self, elem): + self._buffer.append(elem.encode('utf-8')) + if len(self._buffer) >= 100: + self._flush() - def apply_CombinePerKey(self, transform, pcoll): - if pcoll.pipeline._options.view_as(TypeOptions).runtime_type_check: - # TODO(robertwb): This can be reenabled once expansion happens after run. - return transform.expand(pcoll) - # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems - # with resolving imports when they are at top. - # pylint: disable=wrong-import-position - from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey - try: - return pcoll | LiftedCombinePerKey( - transform.fn, transform.args, transform.kwargs) - except NotImplementedError: - return transform.expand(pcoll) - - def apply_TestStream(self, transform, pcoll): - self._use_test_clock = True # use TestClock() for testing - return transform.expand(pcoll) - - def apply__GroupByKeyOnly(self, transform, pcoll): - if (transform.__class__ == _GroupByKeyOnly and - pcoll.pipeline._options.view_as(StandardOptions).streaming): - # Use specialized streaming implementation, if requested. - type_hints = transform.get_type_hints() - return pcoll | (_StreamingGroupByKeyOnly() - .with_input_types(*type_hints.input_types[0]) - .with_output_types(*type_hints.output_types[0])) - return transform.expand(pcoll) - - def apply__GroupAlsoByWindow(self, transform, pcoll): - if (transform.__class__ == _GroupAlsoByWindow and - pcoll.pipeline._options.view_as(StandardOptions).streaming): - # Use specialized streaming implementation, if requested. 
- type_hints = transform.get_type_hints() - return pcoll | (_StreamingGroupAlsoByWindow(transform.windowing) - .with_input_types(*type_hints.input_types[0]) - .with_output_types(*type_hints.output_types[0])) - return transform.expand(pcoll) - - def apply_ReadStringsFromPubSub(self, transform, pcoll): - try: - from google.cloud import pubsub as unused_pubsub - except ImportError: - raise ImportError('Google Cloud PubSub not available, please install ' - 'apache_beam[gcp]') - # Execute this as a native transform. - output = PCollection(pcoll.pipeline) - output.element_type = unicode - return output - - def apply_WriteStringsToPubSub(self, transform, pcoll): - try: - from google.cloud import pubsub - except ImportError: - raise ImportError('Google Cloud PubSub not available, please install ' - 'apache_beam[gcp]') - project = transform._sink.project - topic_name = transform._sink.topic_name - - class DirectWriteToPubSub(beam.DoFn): - _topic = None - - def __init__(self, project, topic_name): - self.project = project - self.topic_name = topic_name - - def start_bundle(self): - if self._topic is None: - self._topic = pubsub.Client(project=self.project).topic( - self.topic_name) - self._buffer = [] - - def process(self, elem): - self._buffer.append(elem.encode('utf-8')) - if len(self._buffer) >= 100: + def finish_bundle(self): self._flush() - def finish_bundle(self): - self._flush() + def _flush(self): + if self._buffer: + with self._topic.batch() as batch: + for datum in self._buffer: + batch.publish(datum) + self._buffer = [] + + project = transform._sink.project + topic_name = transform._sink.topic_name + return beam.ParDo(_DirectWriteToPubSub(project, topic_name)) + + return [ReadStringsFromPubSubOverride(), WriteStringsToPubSubOverride()] - def _flush(self): - if self._buffer: - with self._topic.batch() as batch: - for datum in self._buffer: - batch.publish(datum) - self._buffer = [] - output = pcoll | beam.ParDo(DirectWriteToPubSub(project, topic_name)) - output.element_type = unicode - return output +class BundleBasedDirectRunner(PipelineRunner): + """Executes a single pipeline on the local machine.""" def run_pipeline(self, pipeline): """Execute the entire pipeline and returns an DirectPipelineResult.""" - # Performing configured PTransform overrides. - pipeline.replace_all(self._ptransform_overrides) - # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems # with resolving imports when they are at top. # pylint: disable=wrong-import-position + from apache_beam.pipeline import PipelineVisitor from apache_beam.runners.direct.consumer_tracking_pipeline_visitor import \ ConsumerTrackingPipelineVisitor from apache_beam.runners.direct.evaluation_context import EvaluationContext from apache_beam.runners.direct.executor import Executor from apache_beam.runners.direct.transform_evaluator import \ TransformEvaluatorRegistry + from apache_beam.testing.test_stream import TestStream + + # Performing configured PTransform overrides. + pipeline.replace_all(_get_transform_overrides(pipeline.options)) + + # If the TestStream I/O is used, use a mock test clock. 
+ class _TestStreamUsageVisitor(PipelineVisitor): + """Visitor determining whether a Pipeline uses a TestStream.""" + + def __init__(self): + self.uses_test_stream = False + + def visit_transform(self, applied_ptransform): + if isinstance(applied_ptransform.transform, TestStream): + self.uses_test_stream = True + + visitor = _TestStreamUsageVisitor() + pipeline.visit(visitor) + clock = TestClock() if visitor.uses_test_stream else RealClock() MetricsEnvironment.set_metrics_supported(True) logging.info('Running pipeline with DirectRunner.') self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor() pipeline.visit(self.consumer_tracking_visitor) - clock = TestClock() if self._use_test_clock else RealClock() evaluation_context = EvaluationContext( pipeline._options, BundleFactory(stacked=pipeline._options.view_as(DirectOptions) @@ -244,6 +380,10 @@ def run_pipeline(self, pipeline): return result +# Use the SwitchingDirectRunner as the default. +DirectRunner = SwitchingDirectRunner + + class DirectPipelineResult(PipelineResult): """A DirectPipelineResult provides access to info about a pipeline.""" diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py index 1b51d05aae0d..231cca72476a 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py @@ -27,14 +27,16 @@ class DirectPipelineResultTest(unittest.TestCase): def test_waiting_on_result_stops_executor_threads(self): pre_test_threads = set(t.ident for t in threading.enumerate()) - pipeline = test_pipeline.TestPipeline(runner='DirectRunner') - _ = (pipeline | beam.Create([{'foo': 'bar'}])) - result = pipeline.run() - result.wait_until_finish() - - post_test_threads = set(t.ident for t in threading.enumerate()) - new_threads = post_test_threads - pre_test_threads - self.assertEqual(len(new_threads), 0) + for runner in ['DirectRunner', 'BundleBasedDirectRunner', + 'SwitchingDirectRunner']: + pipeline = test_pipeline.TestPipeline(runner=runner) + _ = (pipeline | beam.Create([{'foo': 'bar'}])) + result = pipeline.run() + result.wait_until_finish() + + post_test_threads = set(t.ident for t in threading.enumerate()) + new_threads = post_test_threads - pre_test_threads + self.assertEqual(len(new_threads), 0) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index d4d9cb5ca637..34c12c345a0e 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -539,7 +539,8 @@ def call(self): def _should_shutdown(self): """Checks whether the pipeline is completed and should be shut down. - If there is anything in the queue of tasks to do, do not shut down. + If there is anything in the queue of tasks to do or + if there are any realtime timers set, do not shut down. Otherwise, check if all the transforms' watermarks are complete. If they are not, the pipeline is not progressing (stall detected). @@ -554,6 +555,12 @@ def _should_shutdown(self): if self._is_executing(): # There are some bundles still in progress. 
return False + + watermark_manager = self._executor.evaluation_context._watermark_manager + _, any_unfired_realtime_timers = watermark_manager.extract_all_timers() + if any_unfired_realtime_timers: + return False + else: if self._executor.evaluation_context.is_done(): self._executor.visible_updates.offer( @@ -594,13 +601,7 @@ def _is_executing(self): """Checks whether the job is still executing. Returns: - True if there are any timers set or if there is at least - one non-blocked TransformExecutor active.""" - - watermark_manager = self._executor.evaluation_context._watermark_manager - _, any_unfired_realtime_timers = watermark_manager.extract_all_timers() - if any_unfired_realtime_timers: - return True + True if there is at least one non-blocked TransformExecutor active.""" executors = self._executor.transform_executor_services.executors if not executors: diff --git a/sdks/python/apache_beam/runners/direct/helper_transforms.py b/sdks/python/apache_beam/runners/direct/helper_transforms.py index 26b0701bd02b..0c1da0351264 100644 --- a/sdks/python/apache_beam/runners/direct/helper_transforms.py +++ b/sdks/python/apache_beam/runners/direct/helper_transforms.py @@ -21,6 +21,7 @@ import apache_beam as beam from apache_beam import typehints from apache_beam.internal.util import ArgumentPlaceholder +from apache_beam.transforms.combiners import _CurriedFn from apache_beam.utils.windowed_value import WindowedValue @@ -28,8 +29,13 @@ class LiftedCombinePerKey(beam.PTransform): """An implementation of CombinePerKey that does mapper-side pre-combining. """ def __init__(self, combine_fn, args, kwargs): + args_to_check = itertools.chain(args, kwargs.values()) + if isinstance(combine_fn, _CurriedFn): + args_to_check = itertools.chain(args_to_check, + combine_fn.args, + combine_fn.kwargs.values()) if any(isinstance(arg, ArgumentPlaceholder) - for arg in itertools.chain(args, kwargs.values())): + for arg in args_to_check): # This isn't implemented in dataflow either... 
raise NotImplementedError('Deferred CombineFn side inputs.') self._combine_fn = beam.transforms.combiners.curry_combine_fn( diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index ddbe9649b424..aa247aa4118b 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -44,12 +44,9 @@ class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride): """A transform override for ProcessElements transform.""" - def get_matcher(self): - def _matcher(applied_ptransform): - return isinstance( - applied_ptransform.transform, ProcessKeyedElements) - - return _matcher + def matches(self, applied_ptransform): + return isinstance( + applied_ptransform.transform, ProcessKeyedElements) def get_replacement_transform(self, ptransform): return ProcessKeyedElementsViaKeyedWorkItems(ptransform) diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py index c1df7da52c48..7ab6dde93979 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py @@ -145,7 +145,7 @@ def run_sdf_read_pipeline( assert len(expected_data) > 0 - with TestPipeline(runner='DirectRunner') as p: + with TestPipeline() as p: pc1 = (p | 'Create1' >> beam.Create(file_names) | 'SDF' >> beam.ParDo(ReadFiles(resume_count))) @@ -205,7 +205,7 @@ def test_sdf_with_resume_multiple_elements(self): resume_count) def test_sdf_with_windowed_timestamped_input(self): - with TestPipeline(runner='DirectRunner') as p: + with TestPipeline() as p: result = (p | beam.Create([1, 3, 5, 10]) | beam.FlatMap(lambda t: [TimestampedValue(('A', t), t), @@ -221,7 +221,7 @@ def test_sdf_with_windowed_timestamped_input(self): assert_that(result, equal_to(expected_result)) def test_sdf_with_side_inputs(self): - with TestPipeline(runner='DirectRunner') as p: + with TestPipeline() as p: result = (p | 'create_main' >> beam.Create(['1', '3', '5']) | beam.ParDo(ExpandStrings(), side=['1', '3'])) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 67dd382fa52b..cbd2c9fbc254 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -26,11 +26,13 @@ import apache_beam.io as io from apache_beam import coders from apache_beam import pvalue +from apache_beam import typehints from apache_beam.internal import pickler from apache_beam.runners import common from apache_beam.runners.common import DoFnRunner from apache_beam.runners.common import DoFnState from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access +from apache_beam.runners.direct.direct_runner import _DirectReadStringsFromPubSub from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow from apache_beam.runners.direct.direct_runner import _StreamingGroupByKeyOnly from apache_beam.runners.direct.sdf_direct_runner import ProcessElements @@ -67,7 +69,7 @@ def __init__(self, evaluation_context): self._evaluation_context = evaluation_context self._evaluators = { io.Read: _BoundedReadEvaluator, - io.ReadStringsFromPubSub: _PubSubReadEvaluator, + _DirectReadStringsFromPubSub: _PubSubReadEvaluator, core.Flatten: _FlattenEvaluator, core.ParDo: 
_ParDoEvaluator, core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, @@ -340,8 +342,8 @@ def process_element(self, element): assert event.new_watermark >= self.watermark self.watermark = event.new_watermark elif isinstance(event, ProcessingTimeEvent): - # TODO(ccy): advance processing time in the context's mock clock. - pass + self._evaluation_context._watermark_manager._clock.advance_time( + event.advance_by) else: raise ValueError('Invalid TestStream event: %s.' % event) @@ -684,9 +686,10 @@ def start_bundle(self): self.output_pcollection = list(self._outputs)[0] # The input type of a GroupByKey will be KV[Any, Any] or more specific. - kv_type_hint = ( - self._applied_ptransform.transform.get_type_hints().input_types[0]) - self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) + kv_type_hint = self._applied_ptransform.inputs[0].element_type + key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint + else typehints.Any) + self.key_coder = coders.registry.get_coder(key_type_hint) def process_element(self, element): if (isinstance(element, WindowedValue) @@ -733,15 +736,17 @@ def start_bundle(self): self.output_pcollection = list(self._outputs)[0] self.step_context = self._execution_context.get_step_context() self.driver = create_trigger_driver( - self._applied_ptransform.transform.windowing) + self._applied_ptransform.transform.windowing, + clock=self._evaluation_context._watermark_manager._clock) self.gabw_items = [] self.keyed_holds = {} - # The input type of a GroupAlsoByWindow will be KV[Any, Iter[Any]] or more - # specific. - kv_type_hint = ( - self._applied_ptransform.transform.get_type_hints().input_types[0]) - self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) + # The input type (which is the same as the output type) of a + # GroupAlsoByWindow will be KV[Any, Iter[Any]] or more specific. 
+ kv_type_hint = self._applied_ptransform.outputs[None].element_type + key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint + else typehints.Any) + self.key_coder = coders.registry.get_coder(key_type_hint) def process_element(self, element): kwi = element.value diff --git a/sdks/python/apache_beam/runners/direct/util.py b/sdks/python/apache_beam/runners/direct/util.py index 96a6ee2cba01..797a7432644f 100644 --- a/sdks/python/apache_beam/runners/direct/util.py +++ b/sdks/python/apache_beam/runners/direct/util.py @@ -61,6 +61,11 @@ def __init__(self, encoded_key, window, name, time_domain, timestamp): self.time_domain = time_domain self.timestamp = timestamp + def __repr__(self): + return 'TimerFiring(%r, %r, %s, %s)' % (self.encoded_key, + self.name, self.time_domain, + self.timestamp) + class KeyedWorkItem(object): """A keyed item that can either be a timer firing or a list of elements.""" diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 35060999b723..dd8e0518acd0 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -60,6 +60,9 @@ def get_id(self, obj, label=None): self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context) return self._obj_to_id[obj] + def get_proto(self, obj, label=None): + return self._id_to_proto[self.get_id(obj, label)] + def get_by_id(self, id): if id not in self._id_to_obj: self._id_to_obj[id] = self._obj_type.from_runner_api( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index 341f06f9a6aa..7e89d9aa7575 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -21,6 +21,7 @@ import copy import logging import Queue as queue +import re import threading import time from concurrent import futures @@ -35,6 +36,7 @@ from apache_beam.coders.coder_impl import create_OutputStream from apache_beam.internal import pickler from apache_beam.metrics.execution import MetricsEnvironment +from apache_beam.portability import common_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_fn_api_pb2_grpc from apache_beam.portability.api import beam_runner_api_pb2 @@ -46,7 +48,6 @@ from apache_beam.transforms import trigger from apache_beam.transforms.window import GlobalWindows from apache_beam.utils import proto_utils -from apache_beam.utils import urns # This module is experimental. No backwards-compatibility guarantees. @@ -143,7 +144,7 @@ class _WindowGroupingBuffer(object): def __init__(self, side_input_data): # Here's where we would use a different type of partitioning # (e.g. also by key) for a different access pattern. 
- assert side_input_data.access_pattern == urns.ITERABLE_ACCESS + assert side_input_data.access_pattern == common_urns.ITERABLE_SIDE_INPUT self._windowed_value_coder = side_input_data.coder self._window_coder = side_input_data.coder.window_coder self._value_coder = side_input_data.coder.wrapped_value_coder @@ -251,12 +252,12 @@ def fuse(self, other): union(self.must_follow, other.must_follow)) def is_flatten(self): - return any(transform.spec.urn == urns.FLATTEN_TRANSFORM + return any(transform.spec.urn == common_urns.FLATTEN_TRANSFORM for transform in self.transforms) def side_inputs(self): for transform in self.transforms: - if transform.spec.urn == urns.PARDO_TRANSFORM: + if transform.spec.urn == common_urns.PARDO_TRANSFORM: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) for side_input in payload.side_inputs: @@ -264,7 +265,7 @@ def side_inputs(self): def has_as_main_input(self, pcoll): for transform in self.transforms: - if transform.spec.urn == urns.PARDO_TRANSFORM: + if transform.spec.urn == common_urns.PARDO_TRANSFORM: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) local_side_inputs = payload.side_inputs @@ -311,14 +312,14 @@ def windowed_coder_id(coder_id): proto = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.WINDOWED_VALUE_CODER)), + urn=common_urns.WINDOWED_VALUE_CODER)), component_coder_ids=[coder_id, window_coder_id]) return add_or_get_coder_id(proto) for stage in stages: assert len(stage.transforms) == 1 transform = stage.transforms[0] - if transform.spec.urn == urns.COMBINE_PER_KEY_TRANSFORM: + if transform.spec.urn == common_urns.COMBINE_PER_KEY_TRANSFORM: combine_payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.CombinePayload) @@ -338,14 +339,14 @@ def windowed_coder_id(coder_id): key_accumulator_coder = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.KV_CODER)), + urn=common_urns.KV_CODER)), component_coder_ids=[key_coder_id, accumulator_coder_id]) key_accumulator_coder_id = add_or_get_coder_id(key_accumulator_coder) accumulator_iter_coder = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.ITERABLE_CODER)), + urn=common_urns.ITERABLE_CODER)), component_coder_ids=[accumulator_coder_id]) accumulator_iter_coder_id = add_or_get_coder_id( accumulator_iter_coder) @@ -353,7 +354,7 @@ def windowed_coder_id(coder_id): key_accumulator_iter_coder = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.KV_CODER)), + urn=common_urns.KV_CODER)), component_coder_ids=[key_coder_id, accumulator_iter_coder_id]) key_accumulator_iter_coder_id = add_or_get_coder_id( key_accumulator_iter_coder) @@ -397,7 +398,7 @@ def make_stage(base_stage, transform): beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Precombine', spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.PRECOMBINE_TRANSFORM, + urn=common_urns.COMBINE_PGBKCV_TRANSFORM, payload=transform.spec.payload), inputs=transform.inputs, outputs={'out': precombined_pcoll_id})) @@ -407,7 +408,7 @@ def make_stage(base_stage, transform): beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Group', spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.GROUP_BY_KEY_TRANSFORM), + urn=common_urns.GROUP_BY_KEY_TRANSFORM), 
inputs={'in': precombined_pcoll_id}, outputs={'out': grouped_pcoll_id})) @@ -416,7 +417,7 @@ def make_stage(base_stage, transform): beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/Merge', spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.MERGE_ACCUMULATORS_TRANSFORM, + urn=common_urns.COMBINE_MERGE_ACCUMULATORS_TRANSFORM, payload=transform.spec.payload), inputs={'in': grouped_pcoll_id}, outputs={'out': merged_pcoll_id})) @@ -426,7 +427,7 @@ def make_stage(base_stage, transform): beam_runner_api_pb2.PTransform( unique_name=transform.unique_name + '/ExtractOutputs', spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.EXTRACT_OUTPUTS_TRANSFORM, + urn=common_urns.COMBINE_EXTRACT_OUTPUTS_TRANSFORM, payload=transform.spec.payload), inputs={'in': merged_pcoll_id}, outputs=transform.outputs)) @@ -437,12 +438,13 @@ def make_stage(base_stage, transform): def expand_gbk(stages): """Transforms each GBK into a write followed by a read. """ - good_coder_urns = set(beam.coders.Coder._known_urns.keys()) - set([ - urns.PICKLED_CODER]) + good_coder_urns = set( + value for key, value in common_urns.__dict__.items() + if re.match('[A-Z][A-Z_]*$', key)) coders = pipeline_components.coders for coder_id, coder_proto in coders.items(): - if coder_proto.spec.spec.urn == urns.BYTES_CODER: + if coder_proto.spec.spec.urn == common_urns.BYTES_CODER: bytes_coder_id = coder_id break else: @@ -456,7 +458,7 @@ def wrap_unknown_coders(coder_id, with_bytes): if (coder_id, with_bytes) not in coder_substitutions: wrapped_coder_id = None coder_proto = coders[coder_id] - if coder_proto.spec.spec.urn == urns.LENGTH_PREFIX_CODER: + if coder_proto.spec.spec.urn == common_urns.LENGTH_PREFIX_CODER: coder_substitutions[coder_id, with_bytes] = ( bytes_coder_id if with_bytes else coder_id) elif coder_proto.spec.spec.urn in good_coder_urns: @@ -483,7 +485,7 @@ def wrap_unknown_coders(coder_id, with_bytes): len_prefix_coder_proto = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.LENGTH_PREFIX_CODER)), + urn=common_urns.LENGTH_PREFIX_CODER)), component_coder_ids=[coder_id]) coders[wrapped_coder_id].CopyFrom(len_prefix_coder_proto) coder_substitutions[coder_id, with_bytes] = wrapped_coder_id @@ -500,7 +502,7 @@ def fix_pcoll_coder(pcoll): for stage in stages: assert len(stage.transforms) == 1 transform = stage.transforms[0] - if transform.spec.urn == urns.GROUP_BY_KEY_TRANSFORM: + if transform.spec.urn == common_urns.GROUP_BY_KEY_TRANSFORM: for pcoll_id in transform.inputs.values(): fix_pcoll_coder(pipeline_components.pcollections[pcoll_id]) for pcoll_id in transform.outputs.values(): @@ -547,7 +549,7 @@ def sink_flattens(stages): for stage in stages: assert len(stage.transforms) == 1 transform = stage.transforms[0] - if transform.spec.urn == urns.FLATTEN_TRANSFORM: + if transform.spec.urn == common_urns.FLATTEN_TRANSFORM: # This is used later to correlate the read and writes. 
param = str("materialize:%s" % transform.unique_name) output_pcoll_id, = transform.outputs.values() @@ -773,7 +775,8 @@ def process(stage): coders.populate_map(pipeline_components.coders) known_composites = set( - [urns.GROUP_BY_KEY_TRANSFORM, urns.COMBINE_PER_KEY_TRANSFORM]) + [common_urns.GROUP_BY_KEY_TRANSFORM, + common_urns.COMBINE_PER_KEY_TRANSFORM]) def leaf_transforms(root_ids): for root_id in root_ids: @@ -851,7 +854,7 @@ def extract_endpoints(stage): transform.spec.payload = data_operation_spec.SerializeToString() else: transform.spec.payload = "" - elif transform.spec.urn == urns.PARDO_TRANSFORM: + elif transform.spec.urn == common_urns.PARDO_TRANSFORM: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) for tag, si in payload.side_inputs.items(): @@ -1178,16 +1181,25 @@ class FnApiMetrics(metrics.metric.MetricResults): def __init__(self, step_metrics): self._counters = {} self._distributions = {} + self._gauges = {} for step_metric in step_metrics.values(): - for proto in step_metric.user: - key = metrics.execution.MetricKey.from_runner_api(proto.key) - if proto.HasField('counter_data'): - self._counters[key] = proto.counter_data.value - elif proto.HasField('distribution_data'): - self._distributions[ - key] = metrics.cells.DistributionResult( - metrics.cells.DistributionData.from_runner_api( - proto.distribution_data)) + for ptransform_id, ptransform in step_metric.ptransforms.items(): + for proto in ptransform.user: + key = metrics.execution.MetricKey( + ptransform_id, + metrics.metricbase.MetricName.from_runner_api(proto.metric_name)) + if proto.HasField('counter_data'): + self._counters[key] = proto.counter_data.value + elif proto.HasField('distribution_data'): + self._distributions[ + key] = metrics.cells.DistributionResult( + metrics.cells.DistributionData.from_runner_api( + proto.distribution_data)) + elif proto.HasField('gauge_data'): + self._gauges[ + key] = metrics.cells.GaugeResult( + metrics.cells.GaugeData.from_runner_api( + proto.gauge_data)) def query(self, filter=None): counters = [metrics.execution.MetricResult(k, v, v) @@ -1196,9 +1208,13 @@ def query(self, filter=None): distributions = [metrics.execution.MetricResult(k, v, v) for k, v in self._distributions.items() if self.matches(filter, k)] + gauges = [metrics.execution.MetricResult(k, v, v) + for k, v in self._gauges.items() + if self.matches(filter, k)] return {'counters': counters, - 'distributions': distributions} + 'distributions': distributions, + 'gauges': gauges} class RunnerResult(runner.PipelineResult): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 3b5d5de88387..e7b865cb631c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -125,11 +125,14 @@ def test_metrics(self): counter = beam.metrics.Metrics.counter('ns', 'counter') distribution = beam.metrics.Metrics.distribution('ns', 'distribution') + gauge = beam.metrics.Metrics.gauge('ns', 'gauge') + pcoll = p | beam.Create(['a', 'zzz']) # pylint: disable=expression-not-assigned pcoll | 'count1' >> beam.FlatMap(lambda x: counter.inc()) pcoll | 'count2' >> beam.FlatMap(lambda x: counter.inc(len(x))) pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x))) + pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(len(x))) res = p.run() res.wait_until_finish() @@ -141,9 +144,12 @@ def 
test_metrics(self): self.assertEqual(c2.committed, 4) dist, = res.metrics().query(beam.metrics.MetricsFilter().with_step('dist'))[ 'distributions'] + gaug, = res.metrics().query( + beam.metrics.MetricsFilter().with_step('gauge'))['gauges'] self.assertEqual( dist.committed.data, beam.metrics.cells.DistributionData(4, 2, 1, 3)) self.assertEqual(dist.committed.mean, 2.0) + self.assertEqual(gaug.committed.value, 3) def test_progress_metrics(self): p = self.create_pipeline() @@ -153,7 +159,7 @@ def test_progress_metrics(self): self.skipTest('Progress metrics not supported.') _ = (p - | beam.Create([0, 0, 0, 2.1e-3 * DEFAULT_SAMPLING_PERIOD_MS]) + | beam.Create([0, 0, 0, 5e-3 * DEFAULT_SAMPLING_PERIOD_MS]) | beam.Map(time.sleep) | beam.Map(lambda x: ('key', x)) | beam.GroupByKey() @@ -180,7 +186,7 @@ def test_progress_metrics(self): pregbk_metrics.ptransforms['Map(sleep)'] .processed_elements.measured.output_element_counts['None']) self.assertLessEqual( - 2e-3 * DEFAULT_SAMPLING_PERIOD_MS, + 4e-3 * DEFAULT_SAMPLING_PERIOD_MS, pregbk_metrics.ptransforms['Map(sleep)'] .processed_elements.measured.total_time_spent) self.assertEqual( diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index 1d0f700c66c3..22288a301896 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -45,7 +45,8 @@ def _get_runner_map(runner_names, module_path): 'python_rpc_direct_runner.') _KNOWN_PYTHON_RPC_DIRECT_RUNNER = ('PythonRPCDirectRunner',) -_KNOWN_DIRECT_RUNNERS = ('DirectRunner',) +_KNOWN_DIRECT_RUNNERS = ('DirectRunner', 'BundleBasedDirectRunner', + 'SwitchingDirectRunner') _KNOWN_DATAFLOW_RUNNERS = ('DataflowRunner',) _KNOWN_TEST_RUNNERS = ('TestDataflowRunner',) @@ -117,13 +118,18 @@ class PipelineRunner(object): """ def run(self, transform, options=None): - """Run the given transform with this runner. + """Run the given transform or callable with this runner. """ # Imported here to avoid circular dependencies. 
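
For context on the gauge support introduced here (an illustrative sketch, not part of the patch): a pipeline on the default DirectRunner can report counters, distributions and gauges and query all three kinds from the result, mirroring the test above. The namespace 'ns' and the step labels are arbitrary.

```python
import apache_beam as beam

counter = beam.metrics.Metrics.counter('ns', 'elements')
gauge = beam.metrics.Metrics.gauge('ns', 'last_len')

p = beam.Pipeline()  # DirectRunner by default
pcoll = p | beam.Create(['a', 'zzz'])
# pylint: disable=expression-not-assigned
pcoll | 'count' >> beam.FlatMap(lambda x: counter.inc())
pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(len(x)))

result = p.run()
result.wait_until_finish()

gauge_result, = result.metrics().query(
    beam.metrics.MetricsFilter().with_step('gauge'))['gauges']
print(gauge_result.committed.value)  # the most recently reported value
```
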
# pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import PTransform + from apache_beam.pvalue import PBegin from apache_beam.pipeline import Pipeline p = Pipeline(runner=self, options=options) - p | transform + if isinstance(transform, PTransform): + p | transform + else: + transform(PBegin(p)) return p.run() def run_pipeline(self, pipeline): diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py index aa615cfdd196..e3962f89b488 100644 --- a/sdks/python/apache_beam/runners/runner_test.py +++ b/sdks/python/apache_beam/runners/runner_test.py @@ -80,6 +80,8 @@ def finish_bundle(self): count.inc() def process(self, element): + gauge = Metrics.gauge(self.__class__, 'latest_element') + gauge.set(element) count = Metrics.counter(self.__class__, 'elements') count.inc() distro = Metrics.distribution(self.__class__, 'element_dist') @@ -110,6 +112,7 @@ def process(self, element): MetricResult( MetricKey('Do', MetricName(namespace, 'finished_bundles')), 1, 1))) + hc.assert_that( metrics['distributions'], hc.contains_inanyorder( @@ -118,6 +121,13 @@ def process(self, element): DistributionResult(DistributionData(15, 5, 1, 5)), DistributionResult(DistributionData(15, 5, 1, 5))))) + gauge_result = metrics['gauges'][0] + hc.assert_that( + gauge_result.key, + hc.equal_to(MetricKey('Do', MetricName(namespace, 'latest_element')))) + hc.assert_that(gauge_result.committed.value, hc.equal_to(5)) + hc.assert_that(gauge_result.attempted.value, hc.equal_to(5)) + def test_run_api(self): my_metric = Metrics.counter('namespace', 'my_metric') runner = DirectRunner() @@ -128,6 +138,20 @@ def test_run_api(self): my_metric_value = result.metrics().query()['counters'][0].committed self.assertEqual(my_metric_value, 111) + def test_run_api_with_callable(self): + my_metric = Metrics.counter('namespace', 'my_metric') + + def fn(start): + return (start + | beam.Create([1, 10, 100]) + | beam.Map(lambda x: my_metric.inc(x))) + runner = DirectRunner() + result = runner.run(fn) + result.wait_until_finish() + # Use counters to assert the pipeline actually ran. + my_metric_value = result.metrics().query()['counters'][0].committed + self.assertEqual(my_metric_value, 111) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/sdf_common.py b/sdks/python/apache_beam/runners/sdf_common.py index a7d80ac8b180..a3e141891236 100644 --- a/sdks/python/apache_beam/runners/sdf_common.py +++ b/sdks/python/apache_beam/runners/sdf_common.py @@ -37,15 +37,12 @@ class SplittableParDoOverride(PTransformOverride): SDF specific logic. 
""" - def get_matcher(self): - def _matcher(applied_ptransform): - assert isinstance(applied_ptransform, AppliedPTransform) - transform = applied_ptransform.transform - if isinstance(transform, ParDo): - signature = DoFnSignature(transform.fn) - return signature.is_splittable_dofn() - - return _matcher + def matches(self, applied_ptransform): + assert isinstance(applied_ptransform, AppliedPTransform) + transform = applied_ptransform.transform + if isinstance(transform, ParDo): + signature = DoFnSignature(transform.fn) + return signature.is_splittable_dofn() def get_replacement_transform(self, ptransform): assert isinstance(ptransform, ParDo) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 3c14a6f6781d..55fa6cbe82a1 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -26,13 +26,13 @@ import json import logging -from google.protobuf import wrappers_pb2 - import apache_beam as beam from apache_beam.coders import WindowedValueCoder from apache_beam.coders import coder_impl from apache_beam.internal import pickler from apache_beam.io import iobase +from apache_beam.portability import common_urns +from apache_beam.portability import python_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import pipeline_context @@ -42,7 +42,6 @@ from apache_beam.transforms import sideinputs from apache_beam.utils import counters from apache_beam.utils import proto_utils -from apache_beam.utils import urns # This module is experimental. No backwards-compatibility guarantees. @@ -50,9 +49,7 @@ DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1' DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1' IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1' -PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1' -PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1' -# TODO(vikasrk): Fix this once runner sends appropriate python urns. +# TODO(vikasrk): Fix this once runner sends appropriate common_urns. OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN = 'urn:beam:dofn:javasdk:0.1' OLD_DATAFLOW_RUNNER_HARNESS_READ_URN = 'urn:org.apache.beam:source:java:0.1' @@ -199,7 +196,7 @@ def create_execution_tree(self, descriptor): self.state_sampler, self.state_handler) def is_side_input(transform_proto, tag): - if transform_proto.spec.urn == urns.PARDO_TRANSFORM: + if transform_proto.spec.urn == common_urns.PARDO_TRANSFORM: return tag in proto_utils.parse_Bytes( transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload).side_inputs @@ -274,10 +271,7 @@ def metrics(self): ptransforms={ transform_id: self._fix_output_tags(transform_id, op.progress_metrics()) - for transform_id, op in self.ops.items()}, - user=sum( - [op.metrics_container.to_runner_api() for op in self.ops.values()], - [])) + for transform_id, op in self.ops.items()}) def _fix_output_tags(self, transform_id, metrics): # Outputs are still referred to by index, not by name, in many Operations. 
@@ -417,7 +411,7 @@ def create(factory, transform_id, transform_proto, parameter, consumers): @BeamTransformFactory.register_urn( - urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) + common_urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) def create(factory, transform_id, transform_proto, parameter, consumers): source = iobase.SourceBase.from_runner_api(parameter.source, factory.context) spec = operation_specs.WorkerRead( @@ -440,9 +434,9 @@ def create(factory, transform_id, transform_proto, serialized_fn, consumers): @BeamTransformFactory.register_urn( - urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload) + common_urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload) def create(factory, transform_id, transform_proto, parameter, consumers): - assert parameter.do_fn.spec.urn == urns.PICKLED_DO_FN_INFO + assert parameter.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO serialized_fn = parameter.do_fn.spec.payload return _create_pardo_operation( factory, transform_id, transform_proto, consumers, @@ -513,18 +507,7 @@ def _create_simple_pardo_operation( @BeamTransformFactory.register_urn( - urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue) -def create(factory, transform_id, transform_proto, parameter, consumers): - # Perhaps this hack can go away once all apply overloads are gone. - from apache_beam.transforms.core import _GroupAlsoByWindowDoFn - return _create_simple_pardo_operation( - factory, transform_id, transform_proto, consumers, - _GroupAlsoByWindowDoFn( - factory.context.windowing_strategies.get_by_id(parameter.value))) - - -@BeamTransformFactory.register_urn( - urns.WINDOW_INTO_TRANSFORM, beam_runner_api_pb2.WindowingStrategy) + common_urns.WINDOW_INTO_TRANSFORM, beam_runner_api_pb2.WindowingStrategy) def create(factory, transform_id, transform_proto, parameter, consumers): class WindowIntoDoFn(beam.DoFn): def __init__(self, windowing): @@ -557,7 +540,7 @@ def create(factory, transform_id, transform_proto, unused_parameter, consumers): @BeamTransformFactory.register_urn( - urns.PRECOMBINE_TRANSFORM, beam_runner_api_pb2.CombinePayload) + common_urns.COMBINE_PGBKCV_TRANSFORM, beam_runner_api_pb2.CombinePayload) def create(factory, transform_id, transform_proto, payload, consumers): # TODO: Combine side inputs. 
serialized_combine_fn = pickler.dumps( @@ -577,14 +560,16 @@ def create(factory, transform_id, transform_proto, payload, consumers): @BeamTransformFactory.register_urn( - urns.MERGE_ACCUMULATORS_TRANSFORM, beam_runner_api_pb2.CombinePayload) + common_urns.COMBINE_MERGE_ACCUMULATORS_TRANSFORM, + beam_runner_api_pb2.CombinePayload) def create(factory, transform_id, transform_proto, payload, consumers): return _create_combine_phase_operation( factory, transform_proto, payload, consumers, 'merge') @BeamTransformFactory.register_urn( - urns.EXTRACT_OUTPUTS_TRANSFORM, beam_runner_api_pb2.CombinePayload) + common_urns.COMBINE_EXTRACT_OUTPUTS_TRANSFORM, + beam_runner_api_pb2.CombinePayload) def create(factory, transform_id, transform_proto, payload, consumers): return _create_combine_phase_operation( factory, transform_proto, payload, consumers, 'extract') @@ -609,7 +594,7 @@ def _create_combine_phase_operation( consumers) -@BeamTransformFactory.register_urn(urns.FLATTEN_TRANSFORM, None) +@BeamTransformFactory.register_urn(common_urns.FLATTEN_TRANSFORM, None) def create(factory, transform_id, transform_proto, unused_parameter, consumers): return factory.augment_oldstyle_op( operations.create_operation( diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index f18ab3e4df44..2e4f2d6f69a7 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -182,7 +182,7 @@ def input_elements(self, instruction_id, expected_targets): data = received.get(timeout=1) except queue.Empty: if self._exc_info: - raise exc_info[0], exc_info[1], exc_info[2] + raise self._exc_info[0], self._exc_info[1], self._exc_info[2] else: if not data.data and data.target in expected_targets: done_targets.append(data.target) diff --git a/sdks/python/apache_beam/runners/worker/opcounters_test.py b/sdks/python/apache_beam/runners/worker/opcounters_test.py index 3a922c8db28f..41c80e87c4f7 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters_test.py +++ b/sdks/python/apache_beam/runners/worker/opcounters_test.py @@ -166,17 +166,17 @@ def test_should_sample(self): total_runs = 10 * len(buckets) # Fill the buckets. - for _ in xrange(total_runs): + for _ in range(total_runs): opcounts = OperationCounters(CounterFactory(), 'some-name', coders.PickleCoder(), 0) - for i in xrange(len(buckets)): + for i in range(len(buckets)): if opcounts.should_sample(): buckets[i] += 1 # Look at the buckets to see if they are likely. - for i in xrange(10): + for i in range(10): self.assertEqual(total_runs, buckets[i]) - for i in xrange(10, len(buckets)): + for i in range(10, len(buckets)): self.assertTrue(buckets[i] > 7 * total_runs / i, 'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % ( i, buckets[i], diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index 8a072e83d4ac..11ff909f3e9f 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -178,7 +178,8 @@ def progress_metrics(self): {'ONLY_OUTPUT': self.receivers[0].opcounter .element_counter.value()} if len(self.receivers) == 1 - else None)))) + else None))), + user=self.metrics_container.to_runner_api()) def __str__(self): """Generates a useful string for this object. 
@@ -288,8 +289,8 @@ def _read_side_inputs(self, tags_and_types): assert self.side_input_maps is None # Get experiments active in the worker to check for side input metrics exp. - experiments = set(RuntimeValueProvider( - 'experiments', str, '').get().split(',')) + experiments = set( + RuntimeValueProvider.get_value('experiments', str, '').split(',')) # We will read the side inputs in the order prescribed by the # tags_and_types argument because this is exactly the order needed to diff --git a/sdks/python/apache_beam/runners/worker/sideinputs.py b/sdks/python/apache_beam/runners/worker/sideinputs.py index 8b7e14ee8d16..cc405e0e4771 100644 --- a/sdks/python/apache_beam/runners/worker/sideinputs.py +++ b/sdks/python/apache_beam/runners/worker/sideinputs.py @@ -106,7 +106,7 @@ def _start_reader_threads(self): def _reader_thread(self): # pylint: disable=too-many-nested-blocks experiments = set( - RuntimeValueProvider('experiments', str, '').get().split(',')) + RuntimeValueProvider.get_value('experiments', str, '').split(',')) try: while True: try: diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py index 63dc6f899bfc..8b2216951dd3 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_test.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py @@ -74,6 +74,8 @@ def test_basic_sampler(self): self.assertIn(counter.name, expected_counter_values) expected_value = expected_counter_values[counter.name] actual_value = counter.value() + deviation = float(abs(actual_value - expected_value)) / expected_value + logging.info('Sampling deviation from expectation: %f', deviation) self.assertGreater(actual_value, expected_value * 0.75) self.assertLess(actual_value, expected_value * 1.25) diff --git a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py index 3b02431212de..bc8789f5b423 100644 --- a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py +++ b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py @@ -127,7 +127,7 @@ def test_file_checksum_matchcer_invalid_sleep_time(self): verifiers.FileChecksumMatcher('file_path', 'expected_checksum', 'invalid_sleep_time') - self.assertEqual(cm.exception.message, + self.assertEqual(cm.exception.args[0], 'Sleep seconds, if received, must be int. 
' 'But received: \'invalid_sleep_time\', ' '') diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py index 0f6691f3d644..a3f2413f1679 100644 --- a/sdks/python/apache_beam/testing/test_stream_test.py +++ b/sdks/python/apache_beam/testing/test_stream_test.py @@ -29,6 +29,7 @@ from apache_beam.testing.test_stream import WatermarkEvent from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms import trigger from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp @@ -117,7 +118,7 @@ def process(self, element=beam.DoFn.ElementParam, ('last', timestamp.Timestamp(310)),])) p.run() - def test_gbk_execution(self): + def test_gbk_execution_no_triggers(self): test_stream = (TestStream() .advance_watermark_to(10) .add_elements(['a', 'b', 'c']) @@ -129,6 +130,14 @@ def test_gbk_execution(self): .add_elements([TimestampedValue('late', 12)]) .add_elements([TimestampedValue('last', 310)])) + # TODO(BEAM-3377): Remove after assert_that in streaming is fixed. + global result # pylint: disable=global-variable-undefined + result = [] + + def fired_elements(elem): + result.append(elem) + return elem + options = PipelineOptions() options.view_as(StandardOptions).streaming = True p = TestPipeline(options=options) @@ -136,7 +145,8 @@ def test_gbk_execution(self): | test_stream | beam.WindowInto(FixedWindows(15)) | beam.Map(lambda x: ('k', x)) - | beam.GroupByKey()) + | beam.GroupByKey() + | beam.Map(fired_elements)) # TODO(BEAM-2519): timestamp assignment for elements from a GBK should # respect the TimestampCombiner. The test below should also verify the # timestamps of the outputted elements once this is implemented. @@ -146,6 +156,94 @@ def test_gbk_execution(self): ('k', ['late']), ('k', ['last'])])) p.run() + # TODO(BEAM-3377): Remove after assert_that in streaming is fixed. + self.assertEqual([ + ('k', ['a', 'b', 'c']), + ('k', ['d', 'e']), + ('k', ['late']), + ('k', ['last'])], result) + + def test_gbk_execution_after_watermark_trigger(self): + test_stream = (TestStream() + .advance_watermark_to(10) + .add_elements(['a']) + .advance_watermark_to(20)) + + # TODO(BEAM-3377): Remove after assert_that in streaming is fixed. + global result # pylint: disable=global-variable-undefined + result = [] + + def fired_elements(elem): + result.append(elem) + return elem + + options = PipelineOptions() + options.view_as(StandardOptions).streaming = True + p = TestPipeline(options=options) + records = (p # pylint: disable=unused-variable + | test_stream + | beam.WindowInto( + FixedWindows(15), + trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)), + accumulation_mode=trigger.AccumulationMode.DISCARDING) + | beam.Map(lambda x: ('k', x)) + | beam.GroupByKey() + | beam.Map(fired_elements)) + # TODO(BEAM-2519): timestamp assignment for elements from a GBK should + # respect the TimestampCombiner. The test below should also verify the + # timestamps of the outputted elements once this is implemented. + + # TODO(BEAM-3377): Reinstate after assert_that in streaming is fixed. + # assert_that(records, equal_to([ + # ('k', ['a']), ('k', [])])) + + p.run() + # TODO(BEAM-3377): Remove after assert_that in streaming is fixed. 
+ self.assertEqual([('k', ['a']), ('k', [])], result) + + def test_gbk_execution_after_processing_trigger_fired(self): + """Advance TestClock to (X + delta) and see the pipeline does finish.""" + # TODO(mariagh): Add test_gbk_execution_after_processing_trigger_unfired + # Advance TestClock to (X + delta) and see the pipeline does finish + # Possibly to the framework trigger_transcripts.yaml + + test_stream = (TestStream() + .advance_watermark_to(10) + .add_elements(['a']) + .advance_processing_time(5.1)) + + # TODO(BEAM-3377): Remove after assert_that in streaming is fixed. + global result # pylint: disable=global-variable-undefined + result = [] + + def fired_elements(elem): + result.append(elem) + return elem + + options = PipelineOptions() + options.view_as(StandardOptions).streaming = True + p = TestPipeline(options=options) + records = (p + | test_stream + | beam.WindowInto( + beam.window.FixedWindows(15), + trigger=trigger.AfterProcessingTime(5), + accumulation_mode=trigger.AccumulationMode.DISCARDING + ) + | beam.Map(lambda x: ('k', x)) + | beam.GroupByKey() + | beam.Map(fired_elements)) + # TODO(BEAM-2519): timestamp assignment for elements from a GBK should + # respect the TimestampCombiner. The test below should also verify the + # timestamps of the outputted elements once this is implemented. + + # TODO(BEAM-3377): Reinstate after assert_that in streaming is fixed. + assert_that(records, equal_to([ + ('k', ['a'])])) + + p.run() + # TODO(BEAM-3377): Remove after assert_that in streaming is fixed. + self.assertEqual([('k', ['a'])], result) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/testing/test_utils_test.py b/sdks/python/apache_beam/testing/test_utils_test.py index 877ee397bd39..0018c0ed1541 100644 --- a/sdks/python/apache_beam/testing/test_utils_test.py +++ b/sdks/python/apache_beam/testing/test_utils_test.py @@ -49,7 +49,7 @@ def test_delete_files_fails_with_io_error(self): with self.assertRaises(BeamIOError) as error: utils.delete_files([path]) self.assertTrue( - error.exception.message.startswith('Delete operation failed')) + error.exception.args[0].startswith('Delete operation failed')) self.assertEqual(error.exception.exception_details.keys(), [path]) def test_delete_files_fails_with_invalid_arg(self): diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 8e4188aca673..b6f19c6c03eb 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -36,6 +36,11 @@ from apache_beam.typehints import with_input_types from apache_beam.typehints import with_output_types +try: + long # Python 2 +except NameError: + long = int # Python 3 + __all__ = [ 'Count', 'Mean', @@ -536,30 +541,35 @@ def extract_output(self, accumulator): return accumulator -def curry_combine_fn(fn, args, kwargs): - if not args and not kwargs: - return fn +class _CurriedFn(core.CombineFn): + """Wrapped CombineFn with extra arguments.""" + + def __init__(self, fn, args, kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs - # Create CurriedFn class for the combiner - class CurriedFn(core.CombineFn): - """CombineFn that applies extra arguments.""" + def create_accumulator(self): + return self.fn.create_accumulator(*self.args, **self.kwargs) - def create_accumulator(self): - return fn.create_accumulator(*args, **kwargs) + def add_input(self, accumulator, element): + return self.fn.add_input(accumulator, element, *self.args, **self.kwargs) - def add_input(self, 
accumulator, element): - return fn.add_input(accumulator, element, *args, **kwargs) + def merge_accumulators(self, accumulators): + return self.fn.merge_accumulators(accumulators, *self.args, **self.kwargs) - def merge_accumulators(self, accumulators): - return fn.merge_accumulators(accumulators, *args, **kwargs) + def extract_output(self, accumulator): + return self.fn.extract_output(accumulator, *self.args, **self.kwargs) - def extract_output(self, accumulator): - return fn.extract_output(accumulator, *args, **kwargs) + def apply(self, elements): + return self.fn.apply(elements, *self.args, **self.kwargs) - def apply(self, elements): - return fn.apply(elements, *args, **kwargs) - return CurriedFn() +def curry_combine_fn(fn, args, kwargs): + if not args and not kwargs: + return fn + else: + return _CurriedFn(fn, args, kwargs) class PhasedCombineFnExecutor(object): diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index 8885d27c84d6..705106e784c5 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -221,7 +221,7 @@ def is_good_sample(actual): with TestPipeline() as pipeline: pcoll = pipeline | 'start' >> Create([1, 1, 2, 2]) - for ix in xrange(9): + for ix in range(9): assert_that( pcoll | 'sample-%d' % ix >> combine.Sample.FixedSizeGlobally(3), is_good_sample, @@ -230,7 +230,7 @@ def is_good_sample(actual): def test_per_key_sample(self): pipeline = TestPipeline() pcoll = pipeline | 'start-perkey' >> Create( - sum(([(i, 1), (i, 1), (i, 2), (i, 2)] for i in xrange(9)), [])) + sum(([(i, 1), (i, 1), (i, 2), (i, 2)] for i in range(9)), [])) result = pcoll | 'sample' >> combine.Sample.FixedSizePerKey(3) def matcher(): diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index b4b8b7f64e0a..2d411ee75331 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -23,8 +23,6 @@ import inspect import types -from google.protobuf import wrappers_pb2 - from apache_beam import coders from apache_beam import pvalue from apache_beam import typehints @@ -32,6 +30,8 @@ from apache_beam.internal import pickler from apache_beam.internal import util from apache_beam.options.pipeline_options import TypeOptions +from apache_beam.portability import common_urns +from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms import ptransform from apache_beam.transforms.display import DisplayDataItem @@ -378,7 +378,7 @@ def is_process_bounded(self): return False # Method is a classmethod return True - urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_DO_FN) + urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_DOFN) def _fn_takes_side_inputs(fn): @@ -455,6 +455,9 @@ def infer_output_type(self, input_type): def _process_argspec_fn(self): return getattr(self._fn, '_argspec_fn', self._fn) + def _inspect_process(self): + return inspect.getargspec(self._process_argspec_fn()) + class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn): """A function object used by a Combine transform with custom processing. 
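
The nested CurriedFn class becomes a module-level _CurriedFn above so that curried combine fns stay picklable. A rough usage sketch with a hypothetical CombineFn that takes an extra argument; CombinePerKey forwards the extra argument through curry_combine_fn, as the runner-API code in this patch does when args or kwargs are present:

```python
import apache_beam as beam


class TopNFn(beam.CombineFn):  # hypothetical CombineFn with an extra 'n' arg
  def create_accumulator(self, n):
    return []

  def add_input(self, accumulator, element, n):
    accumulator.append(element)
    return sorted(accumulator, reverse=True)[:n]

  def merge_accumulators(self, accumulators, n):
    merged = [x for acc in accumulators for x in acc]
    return sorted(merged, reverse=True)[:n]

  def extract_output(self, accumulator, n):
    return accumulator


# The bound 'n' travels with the (picklable) _CurriedFn wrapper.
top3_per_key = beam.CombinePerKey(TopNFn(), 3)
```
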
@@ -582,7 +585,7 @@ def maybe_from_callable(fn): def get_accumulator_coder(self): return coders.registry.get_coder(object) - urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_COMBINE_FN) + urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_COMBINE_FN) class CallableWrapperCombineFn(CombineFn): @@ -855,11 +858,11 @@ def to_runner_api_parameter(self, context): "expected instance of ParDo, but got %s" % self.__class__ picked_pardo_fn_data = pickler.dumps(self._pardo_fn_data()) return ( - urns.PARDO_TRANSFORM, + common_urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload( do_fn=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( - urn=urns.PICKLED_DO_FN_INFO, + urn=python_urns.PICKLED_DOFN_INFO, payload=picked_pardo_fn_data)), # It'd be nice to name these according to their actual # names/positions in the orignal argument list, but such a @@ -871,9 +874,9 @@ def to_runner_api_parameter(self, context): for ix, si in enumerate(self.side_inputs)})) @PTransform.register_urn( - urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload) + common_urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload) def from_runner_api_parameter(pardo_payload, context): - assert pardo_payload.do_fn.spec.urn == urns.PICKLED_DO_FN_INFO + assert pardo_payload.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO fn, args, kwargs, si_tags_and_types, windowing = pickler.loads( pardo_payload.do_fn.spec.payload) if si_tags_and_types: @@ -1228,11 +1231,11 @@ def to_runner_api_parameter(self, context): else: combine_fn = self.fn return ( - urns.COMBINE_PER_KEY_TRANSFORM, + common_urns.COMBINE_PER_KEY_TRANSFORM, _combine_payload(combine_fn, context)) @PTransform.register_urn( - urns.COMBINE_PER_KEY_TRANSFORM, beam_runner_api_pb2.CombinePayload) + common_urns.COMBINE_PER_KEY_TRANSFORM, beam_runner_api_pb2.CombinePayload) def from_runner_api_parameter(combine_payload, context): return CombinePerKey( CombineFn.from_runner_api(combine_payload.combine_fn, context)) @@ -1266,11 +1269,12 @@ def to_runner_api_parameter(self, context): else: combine_fn = self.fn return ( - urns.COMBINE_GROUPED_VALUES_TRANSFORM, + common_urns.COMBINE_GROUPED_VALUES_TRANSFORM, _combine_payload(combine_fn, context)) @PTransform.register_urn( - urns.COMBINE_GROUPED_VALUES_TRANSFORM, beam_runner_api_pb2.CombinePayload) + common_urns.COMBINE_GROUPED_VALUES_TRANSFORM, + beam_runner_api_pb2.CombinePayload) def from_runner_api_parameter(combine_payload, context): return CombineValues( CombineFn.from_runner_api(combine_payload.combine_fn, context)) @@ -1395,9 +1399,9 @@ def expand(self, pcoll): | 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing)) def to_runner_api_parameter(self, unused_context): - return urns.GROUP_BY_KEY_TRANSFORM, None + return common_urns.GROUP_BY_KEY_TRANSFORM, None - @PTransform.register_urn(urns.GROUP_BY_KEY_TRANSFORM, None) + @PTransform.register_urn(common_urns.GROUP_BY_KEY_TRANSFORM, None) def from_runner_api_parameter(unused_payload, unused_context): return GroupByKey() @@ -1414,13 +1418,6 @@ def expand(self, pcoll): self._check_pcollection(pcoll) return pvalue.PCollection(pcoll.pipeline) - def to_runner_api_parameter(self, unused_context): - return urns.GROUP_BY_KEY_ONLY_TRANSFORM, None - - @PTransform.register_urn(urns.GROUP_BY_KEY_ONLY_TRANSFORM, None) - def from_runner_api_parameter(unused_payload, unused_context): - return _GroupByKeyOnly() - @typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]]) @typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]]) @@ -1435,18 +1432,6 @@ 
def expand(self, pcoll): self._check_pcollection(pcoll) return pvalue.PCollection(pcoll.pipeline) - def to_runner_api_parameter(self, context): - return ( - urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, - wrappers_pb2.BytesValue(value=context.windowing_strategies.get_id( - self.windowing))) - - @PTransform.register_urn( - urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue) - def from_runner_api_parameter(payload, context): - return _GroupAlsoByWindow( - context.windowing_strategies.get_by_id(payload.value)) - class _GroupAlsoByWindowDoFn(DoFn): # TODO(robertwb): Support combiner lifting. @@ -1644,7 +1629,7 @@ def expand(self, pcoll): def to_runner_api_parameter(self, context): return ( - urns.WINDOW_INTO_TRANSFORM, + common_urns.WINDOW_INTO_TRANSFORM, self.windowing.to_runner_api(context)) @staticmethod @@ -1658,7 +1643,7 @@ def from_runner_api_parameter(proto, context): PTransform.register_urn( - urns.WINDOW_INTO_TRANSFORM, + common_urns.WINDOW_INTO_TRANSFORM, # TODO(robertwb): Update WindowIntoPayload to include the full strategy. # (Right now only WindowFn is used, but we need this to reconstitute the # WindowInto transform, and in the future will need it at runtime to @@ -1715,7 +1700,7 @@ def get_windowing(self, inputs): return super(Flatten, self).get_windowing(inputs) def to_runner_api_parameter(self, context): - return urns.FLATTEN_TRANSFORM, None + return common_urns.FLATTEN_TRANSFORM, None @staticmethod def from_runner_api_parameter(unused_parameter, unused_context): @@ -1723,7 +1708,7 @@ def from_runner_api_parameter(unused_parameter, unused_context): PTransform.register_urn( - urns.FLATTEN_TRANSFORM, None, Flatten.from_runner_api_parameter) + common_urns.FLATTEN_TRANSFORM, None, Flatten.from_runner_api_parameter) class Create(PTransform): diff --git a/sdks/python/apache_beam/transforms/cy_combiners.py b/sdks/python/apache_beam/transforms/cy_combiners.py index 84aee212790c..e141e7dff794 100644 --- a/sdks/python/apache_beam/transforms/cy_combiners.py +++ b/sdks/python/apache_beam/transforms/cy_combiners.py @@ -77,6 +77,7 @@ def __init__(self): self.value = 0 def add_input(self, element): + global INT64_MAX, INT64_MIN # pylint: disable=global-variable-not-assigned element = int(element) if not INT64_MIN <= element <= INT64_MAX: raise OverflowError(element) diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 24904956d4d1..c7fc641804dc 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -51,6 +51,7 @@ class and wrapper class that allows lambda functions to be used as from apache_beam import pvalue from apache_beam.internal import pickler from apache_beam.internal import util +from apache_beam.portability import python_urns from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.display import HasDisplayData from apache_beam.typehints import typehints @@ -60,7 +61,6 @@ class and wrapper class that allows lambda functions to be used as from apache_beam.typehints.trivial_inference import instance_to_type from apache_beam.typehints.typehints import validate_composite_type_param from apache_beam.utils import proto_utils -from apache_beam.utils import urns __all__ = [ 'PTransform', @@ -555,7 +555,7 @@ def from_runner_api(cls, proto, context): context) def to_runner_api_parameter(self, context): - return (urns.PICKLED_TRANSFORM, + return (python_urns.PICKLED_TRANSFORM, wrappers_pb2.BytesValue(value=pickler.dumps(self))) 
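
The pickled-transform fallback now lives under python_urns.PICKLED_TRANSFORM rather than the shared urns module. A rough round-trip sketch, assuming PipelineContext can be constructed with no arguments here: a composite without a registered URN serializes through the pickled payload and reconstitutes to the same type.

```python
import apache_beam as beam
from apache_beam.runners import pipeline_context


class Double(beam.PTransform):  # no to_runner_api_parameter override
  def expand(self, pcoll):
    return pcoll | beam.Map(lambda x: x * 2)


context = pipeline_context.PipelineContext()
proto = Double().to_runner_api(context)  # falls back to the pickled URN
restored = beam.PTransform.from_runner_api(proto, context)
assert isinstance(restored, Double)
```
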
@staticmethod @@ -564,7 +564,7 @@ def from_runner_api_parameter(spec_parameter, unused_context): PTransform.register_urn( - urns.PICKLED_TRANSFORM, + python_urns.PICKLED_TRANSFORM, wrappers_pb2.BytesValue, PTransform.from_runner_api_parameter) diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index a44903999172..299bfd8dbe79 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -154,7 +154,7 @@ def test_do_with_do_fn_returning_string_raises_warning(self): expected_error_prefix = ('Returning a str from a ParDo or FlatMap ' 'is discouraged.') - self.assertStartswith(cm.exception.message, expected_error_prefix) + self.assertStartswith(cm.exception.args[0], expected_error_prefix) def test_do_with_do_fn_returning_dict_raises_warning(self): pipeline = TestPipeline() @@ -169,7 +169,7 @@ def test_do_with_do_fn_returning_dict_raises_warning(self): expected_error_prefix = ('Returning a dict from a ParDo or FlatMap ' 'is discouraged.') - self.assertStartswith(cm.exception.message, expected_error_prefix) + self.assertStartswith(cm.exception.args[0], expected_error_prefix) def test_do_with_multiple_outputs_maintains_unique_name(self): pipeline = TestPipeline() @@ -284,7 +284,7 @@ def incorrect_par_do_fn(x): pipeline.run() expected_error_prefix = 'FlatMap and ParDo must return an iterable.' - self.assertStartswith(cm.exception.message, expected_error_prefix) + self.assertStartswith(cm.exception.args[0], expected_error_prefix) def test_do_fn_with_finish(self): class MyDoFn(beam.DoFn): @@ -601,7 +601,7 @@ def test_group_by_key_input_must_be_kv_pairs(self): pipeline.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], 'Input type hint violation at D: expected ' 'Tuple[TypeVariable[K], TypeVariable[V]]') @@ -614,7 +614,7 @@ def test_group_by_key_only_input_must_be_kv_pairs(self): expected_error_prefix = ('Input type hint violation at D: expected ' 'Tuple[TypeVariable[K], TypeVariable[V]]') - self.assertStartswith(cm.exception.message, expected_error_prefix) + self.assertStartswith(cm.exception.args[0], expected_error_prefix) def test_keys_and_values(self): pipeline = TestPipeline() @@ -934,7 +934,7 @@ def process(self, element, prefix): self.assertEqual("Type hint violation for 'Upper': " "requires but got for element", - e.exception.message) + e.exception.args[0]) def test_do_fn_pipeline_runtime_type_check_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -969,7 +969,7 @@ def process(self, element, num): self.assertEqual("Type hint violation for 'Add': " "requires but got for element", - e.exception.message) + e.exception.args[0]) def test_pardo_does_not_type_check_using_type_hint_decorators(self): @with_input_types(a=int) @@ -986,7 +986,7 @@ def int_to_str(a): self.assertEqual("Type hint violation for 'ToStr': " "requires but got for a", - e.exception.message) + e.exception.args[0]) def test_pardo_properly_type_checks_using_type_hint_decorators(self): @with_input_types(a=str) @@ -1018,7 +1018,7 @@ def test_pardo_does_not_type_check_using_type_hint_methods(self): self.assertEqual("Type hint violation for 'Upper': " "requires but got for x", - e.exception.message) + e.exception.args[0]) def test_pardo_properly_type_checks_using_type_hint_methods(self): # Pipeline should be created successfully without an error @@ -1043,7 +1043,7 @@ def test_map_does_not_type_check_using_type_hints_methods(self): 
self.assertEqual("Type hint violation for 'Upper': " "requires but got for x", - e.exception.message) + e.exception.args[0]) def test_map_properly_type_checks_using_type_hints_methods(self): # No error should be raised if this type-checks properly. @@ -1069,7 +1069,7 @@ def upper(s): self.assertEqual("Type hint violation for 'Upper': " "requires but got for s", - e.exception.message) + e.exception.args[0]) def test_map_properly_type_checks_using_type_hints_decorator(self): @with_input_types(a=bool) @@ -1096,7 +1096,7 @@ def test_filter_does_not_type_check_using_type_hints_method(self): self.assertEqual("Type hint violation for 'Below 3': " "requires but got for x", - e.exception.message) + e.exception.args[0]) def test_filter_type_checks_using_type_hints_method(self): # No error should be raised if this type-checks properly. @@ -1121,7 +1121,7 @@ def more_than_half(a): self.assertEqual("Type hint violation for 'Half': " "requires but got for a", - e.exception.message) + e.exception.args[0]) def test_filter_type_checks_using_type_hints_decorator(self): @with_input_types(b=int) @@ -1170,7 +1170,7 @@ def test_group_by_key_only_does_not_type_check(self): self.assertEqual("Input type hint violation at F: " "expected Tuple[TypeVariable[K], TypeVariable[V]], " "got ", - e.exception.message) + e.exception.args[0]) def test_group_by_does_not_type_check(self): # Create is returning a List[int, str], rather than a KV[int, str] that is @@ -1184,7 +1184,7 @@ def test_group_by_does_not_type_check(self): self.assertEqual("Input type hint violation at T: " "expected Tuple[TypeVariable[K], TypeVariable[V]], " "got Iterable[int]", - e.exception.message) + e.exception.args[0]) def test_pipeline_checking_pardo_insufficient_type_information(self): self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED' @@ -1198,7 +1198,7 @@ def test_pipeline_checking_pardo_insufficient_type_information(self): self.assertEqual('Pipeline type checking is enabled, however no output ' 'type-hint was found for the PTransform Create(Nums)', - e.exception.message) + e.exception.args[0]) def test_pipeline_checking_gbk_insufficient_type_information(self): self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED' @@ -1213,7 +1213,7 @@ def test_pipeline_checking_gbk_insufficient_type_information(self): self.assertEqual('Pipeline type checking is enabled, however no output ' 'type-hint was found for the PTransform ' 'ParDo(ModDup)', - e.exception.message) + e.exception.args[0]) def test_disable_pipeline_type_check(self): self.p._options.view_as(TypeOptions).pipeline_type_check = False @@ -1243,7 +1243,7 @@ def int_to_string(x): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within ParDo(ToStr): " "Type-hint for argument: 'x' violated. " "Expected an instance of , " @@ -1294,7 +1294,7 @@ def is_even_as_key(a): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within ParDo(IsEven): " "Tuple[bool, int] hint type-constraint violated. " "The type of element #0 in the passed tuple is incorrect. " @@ -1334,7 +1334,7 @@ def test_pipeline_runtime_checking_violation_simple_type_input(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within ParDo(ToInt): " "Type-hint for argument: 'x' violated. 
" "Expected an instance of , " @@ -1353,7 +1353,7 @@ def test_pipeline_runtime_checking_violation_composite_type_input(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within ParDo(Add): " "Type-hint for argument: 'x_y' violated: " "Tuple[int, int] hint type-constraint violated. " @@ -1380,7 +1380,7 @@ def test_pipeline_runtime_checking_violation_simple_type_output(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within " "ParDo(ToInt): " "According to type-hint expected output should be " @@ -1404,7 +1404,7 @@ def test_pipeline_runtime_checking_violation_composite_type_output(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within " "ParDo(Swap): Tuple type constraint violated. " "Valid object instance must be of type 'tuple'. Instead, " @@ -1424,7 +1424,7 @@ def add(a, b): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within ParDo(Add 1): " "Type-hint for argument: 'b' violated. " "Expected an instance of , " @@ -1443,7 +1443,7 @@ def test_pipeline_runtime_checking_violation_with_side_inputs_via_method(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within ParDo(Add 1): " "Type-hint for argument: 'one' violated. " "Expected an instance of , " @@ -1478,7 +1478,7 @@ def bad_combine(a): "All functions for a Combine PTransform must accept a " "single argument compatible with: Iterable[Any]. " "Instead a function with input type: was received.", - e.exception.message) + e.exception.args[0]) def test_combine_pipeline_type_propagation_using_decorators(self): @with_output_types(int) @@ -1532,7 +1532,7 @@ def iter_mul(ints): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within " "Mul/CombinePerKey: " "Type-hint for return type violated. " @@ -1572,7 +1572,7 @@ def test_combine_pipeline_type_check_violation_using_methods(self): self.assertEqual("Input type hint violation at SortJoin: " "expected , got ", - e.exception.message) + e.exception.args[0]) def test_combine_runtime_type_check_violation_using_methods(self): self.p._options.view_as(TypeOptions).pipeline_type_check = False @@ -1586,7 +1586,7 @@ def test_combine_runtime_type_check_violation_using_methods(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within " "ParDo(SortJoin/KeyWithVoid): " "Type-hint for argument: 'v' violated. 
" @@ -1603,7 +1603,7 @@ def test_combine_insufficient_type_hint_information(self): | 'F' >> beam.Map(lambda x: x + 1)) self.assertStartswith( - e.exception.message, + e.exception.args[0], 'Pipeline type checking is enabled, ' 'however no output type-hint was found for the PTransform ' 'ParDo(' @@ -1628,7 +1628,7 @@ def test_mean_globally_pipeline_checking_violated(self): "Type hint violation for 'CombinePerKey': " "requires Tuple[TypeVariable[K], Union[float, int, long]] " "but got Tuple[None, str] for element", - e.exception.message) + e.exception.args[0]) def test_mean_globally_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1659,7 +1659,7 @@ def test_mean_globally_runtime_checking_violated(self): "the passed Iterable is incorrect: Union[int, float] " "type-constraint violated. Expected an instance of one " "of: ('int', 'float'), received str instead.", - e.exception.message) + e.exception.args[0]) def test_mean_per_key_pipeline_checking_satisfied(self): d = (self.p @@ -1685,7 +1685,7 @@ def test_mean_per_key_pipeline_checking_violated(self): "Type hint violation for 'CombinePerKey(MeanCombineFn)': " "requires Tuple[TypeVariable[K], Union[float, int, long]] " "but got Tuple[str, str] for element", - e.exception.message) + e.exception.args[0]) def test_mean_per_key_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1713,7 +1713,7 @@ def test_mean_per_key_runtime_checking_violated(self): self.p.run() self.assertStartswith( - e.exception.message, + e.exception.args[0], "Runtime type violation detected within " "OddMean/CombinePerKey(MeanCombineFn): " "Type-hint for argument: 'element' violated: " @@ -1762,7 +1762,7 @@ def test_count_perkey_pipeline_type_checking_violated(self): "Type hint violation for 'CombinePerKey(CountCombineFn)': " "requires Tuple[TypeVariable[K], Any] " "but got for element", - e.exception.message) + e.exception.args[0]) def test_count_perkey_runtime_type_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1797,7 +1797,7 @@ def test_count_perelement_pipeline_type_checking_violated(self): self.assertEqual('Pipeline type checking is enabled, however no output ' 'type-hint was found for the PTransform ' 'Create(f)', - e.exception.message) + e.exception.args[0]) def test_count_perelement_runtime_type_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1843,7 +1843,7 @@ def test_per_key_pipeline_checking_violated(self): "Type hint violation for 'CombinePerKey(TopCombineFn)': " "requires Tuple[TypeVariable[K], TypeVariable[T]] " "but got for element", - e.exception.message) + e.exception.args[0]) def test_per_key_pipeline_checking_satisfied(self): d = (self.p @@ -1978,7 +1978,7 @@ def test_to_dict_pipeline_check_violated(self): "requires " "Tuple[TypeVariable[K], Tuple[TypeVariable[K], TypeVariable[V]]] " "but got Tuple[None, int] for element", - e.exception.message) + e.exception.args[0]) def test_to_dict_pipeline_check_satisfied(self): d = (self.p @@ -2015,7 +2015,7 @@ def test_runtime_type_check_python_type_error(self): # Instead the above pipeline should have triggered a regular Python runtime # TypeError. 
self.assertEqual("object of type 'int' has no len() [while running 'Len']", - e.exception.message) + e.exception.args[0]) self.assertFalse(isinstance(e, typehints.TypeCheckError)) def test_pardo_type_inference(self): @@ -2048,7 +2048,7 @@ def test_inferred_bad_kv_type(self): self.assertEqual('Input type hint violation at GroupByKey: ' 'expected Tuple[TypeVariable[K], TypeVariable[V]], ' 'got Tuple[str, int, float]', - e.exception.message) + e.exception.args[0]) def test_type_inference_command_line_flag_toggle(self): self.p._options.view_as(TypeOptions).pipeline_type_check = False diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 47b3ea480d95..7d03240f9416 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -23,6 +23,7 @@ import collections import copy import itertools +import logging import numbers from abc import ABCMeta from abc import abstractmethod @@ -47,6 +48,7 @@ 'TriggerFn', 'DefaultTrigger', 'AfterWatermark', + 'AfterProcessingTime', 'AfterCount', 'Repeatedly', 'AfterAny', @@ -165,11 +167,16 @@ def on_merge(self, to_be_merged, merge_result, context): pass @abstractmethod - def should_fire(self, watermark, window, context): + def should_fire(self, time_domain, timestamp, window, context): """Whether this trigger should cause the window to fire. Args: - watermark: (a lower bound on) the watermark of the system + time_domain: WATERMARK for event-time timers and REAL_TIME for + processing-time timers. + timestamp: for time_domain WATERMARK, it represents the + watermark: (a lower bound on) the watermark of the system + and for time_domain REAL_TIME, it represents the + trigger: timestamp of the processing-time timer. window: the window whose trigger is being considered context: a context (e.g. a TriggerContext instance) for managing state and setting timers @@ -207,6 +214,7 @@ def from_runner_api(proto, context): 'after_any': AfterAny, 'after_each': AfterEach, 'after_end_of_window': AfterWatermark, + 'after_processing_time': AfterProcessingTime, # after_processing_time, after_synchronized_processing_time # always 'default': DefaultTrigger, @@ -239,7 +247,7 @@ def on_merge(self, to_be_merged, merge_result, context): if window.end != merge_result.end: context.clear_timer('', TimeDomain.WATERMARK) - def should_fire(self, watermark, window, context): + def should_fire(self, time_domain, watermark, window, context): return watermark >= window.end def on_fire(self, watermark, window, context): @@ -260,6 +268,54 @@ def to_runner_api(self, unused_context): default=beam_runner_api_pb2.Trigger.Default()) +class AfterProcessingTime(TriggerFn): + """Fire exactly once after a specified delay from processing time. + + AfterProcessingTime is experimental. No backwards compatibility guarantees. 
+ """ + + def __init__(self, delay=0): + self.delay = delay + + def __repr__(self): + return 'AfterProcessingTime(delay=%d)' % self.delay + + def on_element(self, element, window, context): + context.set_timer( + '', TimeDomain.REAL_TIME, context.get_current_time() + self.delay) + + def on_merge(self, to_be_merged, merge_result, context): + # timers will be kept through merging + pass + + def should_fire(self, time_domain, timestamp, window, context): + if time_domain == TimeDomain.REAL_TIME: + return True + + def on_fire(self, timestamp, window, context): + return True + + def reset(self, window, context): + pass + + @staticmethod + def from_runner_api(proto, context): + return AfterProcessingTime( + delay=( + proto.after_processing_time + .timestamp_transforms[0] + .delay + .delay_millis)) + + def to_runner_api(self, context): + delay_proto = beam_runner_api_pb2.TimestampTransform( + delay=beam_runner_api_pb2.TimestampTransform.Delay( + delay_millis=self.delay)) + return beam_runner_api_pb2.Trigger( + after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime( + timestamp_transforms=[delay_proto])) + + class AfterWatermark(TriggerFn): """Fire exactly once when the watermark passes the end of the window. @@ -309,15 +365,15 @@ def on_merge(self, to_be_merged, merge_result, context): self.early.on_merge( to_be_merged, merge_result, NestedContext(context, 'early')) - def should_fire(self, watermark, window, context): + def should_fire(self, time_domain, watermark, window, context): if self.is_late(context): - return self.late.should_fire( - watermark, window, NestedContext(context, 'late')) + return self.late.should_fire(time_domain, watermark, + window, NestedContext(context, 'late')) elif watermark >= window.end: return True elif self.early: - return self.early.should_fire( - watermark, window, NestedContext(context, 'early')) + return self.early.should_fire(time_domain, watermark, + window, NestedContext(context, 'early')) return False def on_fire(self, watermark, window, context): @@ -396,7 +452,7 @@ def on_merge(self, to_be_merged, merge_result, context): # states automatically merged pass - def should_fire(self, watermark, window, context): + def should_fire(self, time_domain, watermark, window, context): return context.get_state(self.COUNT_TAG) >= self.count def on_fire(self, watermark, window, context): @@ -427,14 +483,14 @@ def __repr__(self): def __eq__(self, other): return type(self) == type(other) and self.underlying == other.underlying - def on_element(self, element, window, context): # get window from context? 
+ def on_element(self, element, window, context): self.underlying.on_element(element, window, context) def on_merge(self, to_be_merged, merge_result, context): self.underlying.on_merge(to_be_merged, merge_result, context) - def should_fire(self, watermark, window, context): - return self.underlying.should_fire(watermark, window, context) + def should_fire(self, time_domain, watermark, window, context): + return self.underlying.should_fire(time_domain, watermark, window, context) def on_fire(self, watermark, window, context): if self.underlying.on_fire(watermark, window, context): @@ -482,16 +538,19 @@ def on_merge(self, to_be_merged, merge_result, context): trigger.on_merge( to_be_merged, merge_result, self._sub_context(context, ix)) - def should_fire(self, watermark, window, context): + def should_fire(self, time_domain, watermark, window, context): + self._time_domain = time_domain return self.combine_op( - trigger.should_fire(watermark, window, self._sub_context(context, ix)) + trigger.should_fire(time_domain, watermark, window, + self._sub_context(context, ix)) for ix, trigger in enumerate(self.triggers)) def on_fire(self, watermark, window, context): finished = [] for ix, trigger in enumerate(self.triggers): nested_context = self._sub_context(context, ix) - if trigger.should_fire(watermark, window, nested_context): + if trigger.should_fire(TimeDomain.WATERMARK, watermark, + window, nested_context): finished.append(trigger.on_fire(watermark, window, nested_context)) return self.combine_op(finished) @@ -575,11 +634,11 @@ def on_merge(self, to_be_merged, merge_result, context): self.triggers[ix].on_merge( to_be_merged, merge_result, self._sub_context(context, ix)) - def should_fire(self, watermark, window, context): + def should_fire(self, time_domain, watermark, window, context): ix = context.get_state(self.INDEX_TAG) if ix < len(self.triggers): return self.triggers[ix].should_fire( - watermark, window, self._sub_context(context, ix)) + time_domain, watermark, window, self._sub_context(context, ix)) def on_fire(self, watermark, window, context): ix = context.get_state(self.INDEX_TAG) @@ -633,9 +692,13 @@ def to_runner_api(self, context): class TriggerContext(object): - def __init__(self, outer, window): + def __init__(self, outer, window, clock): self._outer = outer self._window = window + self._clock = clock + + def get_current_time(self): + return self._clock.time() def set_timer(self, name, time_domain, timestamp): self._outer.set_timer(self._window, name, time_domain, timestamp) @@ -709,8 +772,8 @@ def get_state(self, window, tag): def clear_state(self, window, tag): pass - def at(self, window): - return TriggerContext(self, window) + def at(self, window, clock=None): + return TriggerContext(self, window, clock) class UnmergedState(SimpleState): @@ -832,7 +895,8 @@ def __repr__(self): repr(self.raw_state).split('\n')) -def create_trigger_driver(windowing, is_batch=False, phased_combine_fn=None): +def create_trigger_driver(windowing, + is_batch=False, phased_combine_fn=None, clock=None): """Create the TriggerDriver for the given windowing and options.""" # TODO(robertwb): We can do more if we know elements are in timestamp @@ -840,7 +904,7 @@ def create_trigger_driver(windowing, is_batch=False, phased_combine_fn=None): if windowing.is_default() and is_batch: driver = DefaultGlobalBatchTriggerDriver() else: - driver = GeneralTriggerDriver(windowing) + driver = GeneralTriggerDriver(windowing, clock) if phased_combine_fn: # TODO(ccy): Refactor GeneralTriggerDriver to combine values 
eagerly using @@ -953,7 +1017,8 @@ class GeneralTriggerDriver(TriggerDriver): ELEMENTS = _ListStateTag('elements') TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn()) - def __init__(self, windowing): + def __init__(self, windowing, clock): + self.clock = clock self.window_fn = windowing.windowfn self.timestamp_combiner_impl = TimestampCombiner.get_impl( windowing.timestamp_combiner, self.window_fn) @@ -1020,14 +1085,15 @@ def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument if output_time is not None: state.add_state(window, self.WATERMARK_HOLD, output_time) - context = state.at(window) + context = state.at(window, self.clock) for value, unused_timestamp in elements: state.add_state(window, self.ELEMENTS, value) self.trigger_fn.on_element(value, window, context) # Maybe fire this window. watermark = MIN_TIMESTAMP - if self.trigger_fn.should_fire(watermark, window, context): + if self.trigger_fn.should_fire(TimeDomain.WATERMARK, watermark, + window, context): finished = self.trigger_fn.on_fire(watermark, window, context) yield self._output(window, finished, state) @@ -1038,10 +1104,12 @@ def process_timer(self, window_id, unused_name, time_domain, timestamp, window = state.get_window(window_id) if state.get_state(window, self.TOMBSTONE): return - if time_domain == TimeDomain.WATERMARK: + + if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME): if not self.is_merging or window in state.known_windows(): - context = state.at(window) - if self.trigger_fn.should_fire(timestamp, window, context): + context = state.at(window, self.clock) + if self.trigger_fn.should_fire(time_domain, timestamp, + window, context): finished = self.trigger_fn.on_fire(timestamp, window, context) yield self._output(window, finished, state) else: diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py index 3afabaf0aa71..d66736f6218f 100644 --- a/sdks/python/apache_beam/transforms/trigger_test.py +++ b/sdks/python/apache_beam/transforms/trigger_test.py @@ -26,6 +26,7 @@ import apache_beam as beam from apache_beam.runners import pipeline_context +from apache_beam.runners.direct.clock import TestClock from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -100,7 +101,7 @@ def run_trigger(self, window_fn, trigger_fn, accumulation_mode, expected_panes): actual_panes = collections.defaultdict(list) driver = GeneralTriggerDriver( - Windowing(window_fn, trigger_fn, accumulation_mode)) + Windowing(window_fn, trigger_fn, accumulation_mode), TestClock()) state = InMemoryUnmergedState() for bundle in bundles: @@ -470,7 +471,7 @@ def split_args(s): args = [] start = 0 depth = 0 - for ix in xrange(len(s)): + for ix in range(len(s)): c = s[ix] if c in '({[': depth += 1 diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index ee9d6f971871..c250e8c6d365 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -55,6 +55,8 @@ from google.protobuf import timestamp_pb2 from apache_beam.coders import coders +from apache_beam.portability import common_urns +from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import standard_window_fns_pb2 from apache_beam.transforms import timeutil @@ -172,7 +174,7 @@ def 
get_transformed_output_time(self, window, input_timestamp): # pylint: disab # By default, just return the input timestamp. return input_timestamp - urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_WINDOW_FN) + urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_WINDOWFN) class BoundedWindow(object): @@ -306,9 +308,9 @@ def __ne__(self, other): return not self == other def to_runner_api_parameter(self, context): - return urns.GLOBAL_WINDOWS_FN, None + return common_urns.GLOBAL_WINDOWS_WINDOWFN, None - @urns.RunnerApiFn.register_urn(urns.GLOBAL_WINDOWS_FN, None) + @urns.RunnerApiFn.register_urn(common_urns.GLOBAL_WINDOWS_WINDOWFN, None) def from_runner_api_parameter(unused_fn_parameter, unused_context): return GlobalWindows() @@ -349,7 +351,7 @@ def __ne__(self, other): return not self == other def to_runner_api_parameter(self, context): - return (urns.FIXED_WINDOWS_FN, + return (common_urns.FIXED_WINDOWS_WINDOWFN, standard_window_fns_pb2.FixedWindowsPayload( size=proto_utils.from_micros( duration_pb2.Duration, self.size.micros), @@ -357,7 +359,8 @@ def to_runner_api_parameter(self, context): timestamp_pb2.Timestamp, self.offset.micros))) @urns.RunnerApiFn.register_urn( - urns.FIXED_WINDOWS_FN, standard_window_fns_pb2.FixedWindowsPayload) + common_urns.FIXED_WINDOWS_WINDOWFN, + standard_window_fns_pb2.FixedWindowsPayload) def from_runner_api_parameter(fn_parameter, unused_context): return FixedWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), @@ -404,7 +407,7 @@ def __eq__(self, other): and self.period == other.period) def to_runner_api_parameter(self, context): - return (urns.SLIDING_WINDOWS_FN, + return (common_urns.SLIDING_WINDOWS_WINDOWFN, standard_window_fns_pb2.SlidingWindowsPayload( size=proto_utils.from_micros( duration_pb2.Duration, self.size.micros), @@ -414,7 +417,7 @@ def to_runner_api_parameter(self, context): duration_pb2.Duration, self.period.micros))) @urns.RunnerApiFn.register_urn( - urns.SLIDING_WINDOWS_FN, + common_urns.SLIDING_WINDOWS_WINDOWFN, standard_window_fns_pb2.SlidingWindowsPayload) def from_runner_api_parameter(fn_parameter, unused_context): return SlidingWindows( @@ -471,13 +474,14 @@ def __eq__(self, other): return self.gap_size == other.gap_size def to_runner_api_parameter(self, context): - return (urns.SESSION_WINDOWS_FN, + return (common_urns.SESSION_WINDOWS_WINDOWFN, standard_window_fns_pb2.SessionsPayload( gap_size=proto_utils.from_micros( duration_pb2.Duration, self.gap_size.micros))) @urns.RunnerApiFn.register_urn( - urns.SESSION_WINDOWS_FN, standard_window_fns_pb2.SessionsPayload) + common_urns.SESSION_WINDOWS_WINDOWFN, + standard_window_fns_pb2.SessionsPayload) def from_runner_api_parameter(fn_parameter, unused_context): return Sessions( gap_size=Duration(micros=fn_parameter.gap_size.ToMicroseconds())) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 89dc6afa34c8..1c7a92a6fa0c 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -567,11 +567,13 @@ def __getattr__(self, attr): return self.__iter__() return getattr(self.internal_gen, attr) - def next(self): + def __next__(self): next_val = next(self.internal_gen) self.interleave_func(next_val) return next_val + next = __next__ + def __iter__(self): while True: x = next(self.internal_gen) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 
8a8e07ecb4be..0be931e8fe2b 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -138,7 +138,7 @@ def convert_to_beam_type(typ): if _len_arg(typ) != arity: raise ValueError('expecting type %s to have arity %d, had arity %d ' 'instead' % (str(typ), arity, _len_arg(typ))) - typs = [convert_to_beam_type(_get_arg(typ, i)) for i in xrange(arity)] + typs = [convert_to_beam_type(_get_arg(typ, i)) for i in range(arity)] if arity == 0: # Nullary types (e.g. Any) don't accept empty tuples as arguments. return matched_entry.beam_type @@ -161,6 +161,6 @@ def convert_to_beam_types(args): a dictionary with the same keys, and values which have been converted. """ if isinstance(args, dict): - return {k: convert_to_beam_type(v) for k, v in args.iteritems()} + return {k: convert_to_beam_type(v) for k, v in args.items()} else: return [convert_to_beam_type(v) for v in args] diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 7fae11b63eab..7f552f2e4a4b 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ b/sdks/python/apache_beam/typehints/opcodes.py @@ -344,7 +344,7 @@ def load_deref(state, arg): def call_function(state, arg, has_var=False, has_kw=False): # TODO(robertwb): Recognize builtins and dataflow objects # (especially special return values). - pop_count = (arg & 0xF) + (arg & 0xF0) / 8 + 1 + has_var + has_kw + pop_count = (arg & 0xF) + (arg & 0xF0) // 8 + 1 + has_var + has_kw state.stack[-pop_count:] = [Any] diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 28bf8f5ba6f3..f1fb1d7bdb7c 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -33,7 +33,6 @@ from apache_beam.typehints import Any from apache_beam.typehints import typehints from six.moves import builtins -from six.moves import zip class TypeInferenceError(ValueError): diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index cd5b8c2f50b6..1274fb772773 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -110,6 +110,7 @@ def testTupleListComprehension(self): typehints.List[typehints.Union[int, float]], lambda xs: [x for x in xs], [typehints.Tuple[int, float]]) + # TODO(luke-zhu): This test fails in Python 3 self.assertReturnType( typehints.List[typehints.Tuple[str, int]], lambda kvs: [(kvs[0], v) for v in kvs[1]], diff --git a/sdks/python/apache_beam/typehints/typecheck.py b/sdks/python/apache_beam/typehints/typecheck.py index a8887465099f..7c7012c8aea7 100644 --- a/sdks/python/apache_beam/typehints/typecheck.py +++ b/sdks/python/apache_beam/typehints/typecheck.py @@ -25,6 +25,8 @@ import sys import types +import six + from apache_beam import pipeline from apache_beam.pvalue import TaggedOutput from apache_beam.transforms import core @@ -84,14 +86,14 @@ def wrapper(self, method, args, kwargs): except TypeCheckError as e: error_msg = ('Runtime type violation detected within ParDo(%s): ' '%s' % (self.full_label, e)) - raise TypeCheckError, error_msg, sys.exc_info()[2] + six.reraise(TypeCheckError, error_msg, sys.exc_info()[2]) else: return self._check_type(result) def _check_type(self, output): if output is None: return output - elif isinstance(output, (dict, 
basestring)): + elif isinstance(output, (dict,) + six.string_types): object_type = type(output).__name__ raise TypeCheckError('Returning a %s from a ParDo or FlatMap is ' 'discouraged. Please use list("%s") if you really ' @@ -173,12 +175,12 @@ def _type_check(self, type_constraint, datum, is_input): try: check_constraint(type_constraint, datum) except CompositeTypeHintError as e: - raise TypeCheckError, e.message, sys.exc_info()[2] + six.reraise(TypeCheckError, e.args[0], sys.exc_info()[2]) except SimpleTypeHintError: error_msg = ("According to type-hint expected %s should be of type %s. " "Instead, received '%s', an instance of type %s." % (datum_type, type_constraint, datum, type(datum))) - raise TypeCheckError, error_msg, sys.exc_info()[2] + six.reraise(TypeCheckError, error_msg, sys.exc_info()[2]) class TypeCheckCombineFn(core.CombineFn): diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 2581457e7ea1..598847e023ce 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -60,8 +60,8 @@ def test_non_function(self): [1, 2, 3] | beam.Map(str.upper) def test_loose_bounds(self): - @typehints.with_input_types(typehints.Union[int, float, long]) - @typehints.with_output_types(basestring) + @typehints.with_input_types(typehints.Union[int, float]) + @typehints.with_output_types(str) def format_number(x): return '%g' % x result = [1, 2, 3] | beam.Map(format_number) diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 6e1d8b7f276c..3455672e7a82 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -68,6 +68,8 @@ import sys import types +import six + __all__ = [ 'Any', 'Union', @@ -403,6 +405,8 @@ class AnyTypeConstraint(TypeConstraint): function arguments or return types. All other TypeConstraint's are equivalent to 'Any', and its 'type_check' method is a no-op. """ + def __eq__(self, other): + return type(self) == type(other) def __repr__(self): return 'Any' @@ -413,6 +417,9 @@ def type_check(self, instance): class TypeVariable(AnyTypeConstraint): + def __eq__(self, other): + return type(self) == type(other) and self.name == other.name + def __init__(self, name): self.name = name @@ -802,7 +809,7 @@ def type_check(self, dict_instance): 'type dict. %s is of type %s.' % (dict_instance, dict_instance.__class__.__name__)) - for key, value in dict_instance.iteritems(): + for key, value in dict_instance.items(): try: check_constraint(self.key_type, key) except CompositeTypeHintError as e: @@ -985,6 +992,7 @@ def __getitem__(self, type_param): IteratorTypeConstraint = IteratorHint.IteratorTypeConstraint +@six.add_metaclass(GetitemConstructor) class WindowedTypeConstraint(TypeConstraint): """A type constraint for WindowedValue objects. @@ -993,7 +1001,6 @@ class WindowedTypeConstraint(TypeConstraint): Attributes: inner_type: The type which the element should be an instance of. 
""" - __metaclass__ = GetitemConstructor def __init__(self, inner_type): self.inner_type = inner_type diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index af575f4aba42..2994adc0aa5f 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -134,7 +134,7 @@ def test_getitem_must_be_valid_type_param_cant_be_object_instance(self): with self.assertRaises(TypeError) as e: typehints.Union[5] self.assertEqual('Cannot create Union without a sequence of types.', - e.exception.message) + e.exception.args[0]) def test_getitem_must_be_valid_type_param(self): t = [2, 3] @@ -142,7 +142,7 @@ def test_getitem_must_be_valid_type_param(self): typehints.Union[t] self.assertEqual('All parameters to a Union hint must be a non-sequence, ' 'a type, or a TypeConstraint. 2 is an instance of int.', - e.exception.message) + e.exception.args[0]) def test_getitem_duplicates_ignored(self): # Types should be de-duplicated. @@ -219,7 +219,7 @@ def test_union_hint_enforcement_not_part_of_union(self): self.assertEqual("Union[float, int] type-constraint violated. Expected an " "instance of one of: ('float', 'int'), received str " "instead.", - e.exception.message) + e.exception.args[0]) class OptionalHintTestCase(TypeHintTestCase): @@ -227,7 +227,7 @@ class OptionalHintTestCase(TypeHintTestCase): def test_getitem_sequence_not_allowed(self): with self.assertRaises(TypeError) as e: typehints.Optional[int, str] - self.assertTrue(e.exception.message.startswith( + self.assertTrue(e.exception.args[0].startswith( 'An Option type-hint only accepts a single type parameter.')) def test_getitem_proxy_to_union(self): @@ -243,21 +243,21 @@ def test_getitem_invalid_ellipsis_type_param(self): with self.assertRaises(TypeError) as e: typehints.Tuple[int, int, ...] - self.assertEqual(error_msg, e.exception.message) + self.assertEqual(error_msg, e.exception.args[0]) with self.assertRaises(TypeError) as e: typehints.Tuple[...] - self.assertEqual(error_msg, e.exception.message) + self.assertEqual(error_msg, e.exception.args[0]) def test_getitem_params_must_be_type_or_constraint(self): expected_error_prefix = 'All parameters to a Tuple hint must be' with self.assertRaises(TypeError) as e: typehints.Tuple[5, [1, 3]] - self.assertTrue(e.exception.message.startswith(expected_error_prefix)) + self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) with self.assertRaises(TypeError) as e: typehints.Tuple[list, dict] - self.assertTrue(e.exception.message.startswith(expected_error_prefix)) + self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) def test_compatibility_arbitrary_length(self): self.assertNotCompatible( @@ -310,7 +310,7 @@ def test_type_check_must_be_tuple(self): for t in invalid_instances: with self.assertRaises(TypeError) as e: hint.type_check(t) - self.assertTrue(e.exception.message.startswith(expected_error_prefix)) + self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) def test_type_check_must_have_same_arity(self): # A 2-tuple of ints. @@ -322,7 +322,7 @@ def test_type_check_must_have_same_arity(self): self.assertEqual('Passed object instance is of the proper type, but ' 'differs in length from the hinted type. 
Expected a ' 'tuple of length 2, received a tuple of length 3.', - e.exception.message) + e.exception.args[0]) def test_type_check_invalid_simple_types(self): hint = typehints.Tuple[str, bool] @@ -332,7 +332,7 @@ def test_type_check_invalid_simple_types(self): 'type of element #0 in the passed tuple is incorrect.' ' Expected an instance of type str, instead received ' 'an instance of type int.', - e.exception.message) + e.exception.args[0]) def test_type_check_invalid_composite_type(self): hint = typehints.Tuple[DummyTestClass1, DummyTestClass2] @@ -345,7 +345,7 @@ def test_type_check_invalid_composite_type(self): 'passed tuple is incorrect. Expected an instance of type ' 'DummyTestClass1, instead received an instance of type ' 'DummyTestClass2.', - e.exception.message) + e.exception.args[0]) def test_type_check_valid_simple_types(self): hint = typehints.Tuple[float, bool] @@ -382,7 +382,7 @@ def test_type_check_invalid_simple_type_arbitrary_length(self): 'of element #2 in the passed tuple is incorrect. Expected ' 'an instance of type str, instead received an instance of ' 'type int.', - e.exception.message) + e.exception.args[0]) def test_type_check_invalid_composite_type_arbitrary_length(self): hint = typehints.Tuple[typehints.List[int], ...] @@ -396,7 +396,7 @@ def test_type_check_invalid_composite_type_arbitrary_length(self): "List type-constraint violated. Valid object instance " "must be of type 'list'. Instead, an instance of 'str' " "was received.", - e.exception.message) + e.exception.args[0]) class ListHintTestCase(TypeHintTestCase): @@ -439,7 +439,7 @@ def test_enforce_list_type_constraint_invalid_simple_type(self): 'element #0 in the passed list is incorrect. Expected an ' 'instance of type int, instead received an instance of ' 'type str.', - e.exception.message) + e.exception.args[0]) def test_enforce_list_type_constraint_invalid_composite_type(self): hint = typehints.List[typehints.Tuple[int, int]] @@ -453,7 +453,7 @@ def test_enforce_list_type_constraint_invalid_composite_type(self): 'violated. The type of element #0 in the passed tuple' ' is incorrect. Expected an instance of type int, ' 'instead received an instance of type str.', - e.exception.message) + e.exception.args[0]) class KVHintTestCase(TypeHintTestCase): @@ -464,7 +464,7 @@ def test_getitem_param_must_be_tuple(self): self.assertEqual('Parameter to KV type-hint must be a tuple of types: ' 'KV[.., ..].', - e.exception.message) + e.exception.args[0]) def test_getitem_param_must_have_length_2(self): with self.assertRaises(TypeError) as e: @@ -473,7 +473,7 @@ def test_getitem_param_must_have_length_2(self): self.assertEqual("Length of parameters to a KV type-hint must be " "exactly 2. Passed parameters: (, , ), have a length of 3.", - e.exception.message) + e.exception.args[0]) def test_getitem_proxy_to_tuple(self): hint = typehints.KV[int, str] @@ -493,7 +493,7 @@ def test_getitem_param_must_be_tuple(self): self.assertEqual('Parameter to Dict type-hint must be a tuple of ' 'types: Dict[.., ..].', - e.exception.message) + e.exception.args[0]) def test_getitem_param_must_have_length_2(self): with self.assertRaises(TypeError) as e: @@ -502,7 +502,7 @@ def test_getitem_param_must_have_length_2(self): self.assertEqual("Length of parameters to a Dict type-hint must be " "exactly 2. 
Passed parameters: (, , ), have a length of 3.", - e.exception.message) + e.exception.args[0]) def test_key_type_must_be_valid_composite_param(self): with self.assertRaises(TypeError): @@ -533,7 +533,7 @@ def test_type_checks_not_dict(self): hint.type_check(l) self.assertEqual('Dict type-constraint violated. All passed instances ' 'must be of type dict. [1, 2] is of type list.', - e.exception.message) + e.exception.args[0]) def test_type_check_invalid_key_type(self): hint = typehints.Dict[typehints.Tuple[int, int, int], @@ -548,7 +548,7 @@ def test_type_check_invalid_key_type(self): 'instance is of the proper type, but differs in ' 'length from the hinted type. Expected a tuple of ' 'length 3, received a tuple of length 2.', - e.exception.message) + e.exception.args[0]) def test_type_check_invalid_value_type(self): hint = typehints.Dict[str, typehints.Dict[int, str]] @@ -560,7 +560,7 @@ def test_type_check_invalid_value_type(self): 'Dict[int, str]. Instead: Dict type-constraint ' 'violated. All passed instances must be of type dict.' ' [1, 2, 3] is of type list.', - e.exception.message) + e.exception.args[0]) def test_type_check_valid_simple_type(self): hint = typehints.Dict[int, str] @@ -588,7 +588,7 @@ def test_getitem_invalid_composite_type_param(self): self.assertEqual("Parameter to a Set hint must be a non-sequence, a " "type, or a TypeConstraint. is an " "instance of type.", - e.exception.message) + e.exception.args[0]) def test_compatibility(self): hint1 = typehints.Set[typehints.List[str]] @@ -609,7 +609,7 @@ def test_type_check_must_be_set(self): self.assertEqual("Set type-constraint violated. Valid object instance " "must be of type 'set'. Instead, an instance of 'int'" " was received.", - e.exception.message) + e.exception.args[0]) def test_type_check_invalid_elem_type(self): hint = typehints.Set[float] @@ -635,7 +635,7 @@ def test_getitem_invalid_composite_type_param(self): self.assertEqual('Parameter to an Iterable hint must be a ' 'non-sequence, a type, or a TypeConstraint. 5 is ' 'an instance of int.', - e.exception.message) + e.exception.args[0]) def test_compatibility(self): self.assertCompatible(typehints.Iterable[int], typehints.List[int]) @@ -678,7 +678,7 @@ def test_type_check_must_be_iterable(self): self.assertEqual("Iterable type-constraint violated. Valid object " "instance must be of type 'iterable'. Instead, an " "instance of 'int' was received.", - e.exception.message) + e.exception.args[0]) def test_type_check_violation_invalid_simple_type(self): hint = typehints.Iterable[float] @@ -746,7 +746,7 @@ def all_upper(s): 'hint type-constraint violated. Expected a iterator ' 'of type int. Instead received a iterator of type ' 'str.', - e.exception.message) + e.exception.args[0]) def test_generator_argument_hint_invalid_yield_type(self): def wrong_yield_gen(): @@ -765,7 +765,7 @@ def increment(a): "hint type-constraint violated. Expected a iterator " "of type int. Instead received a iterator of type " "str.", - e.exception.message) + e.exception.args[0]) class TakesDecoratorTestCase(TypeHintTestCase): @@ -781,7 +781,7 @@ def unused_foo(a): self.assertEqual('All type hint arguments must be a non-sequence, a ' 'type, or a TypeConstraint. [1, 2] is an instance of ' 'list.', - e.exception.message) + e.exception.args[0]) with self.assertRaises(TypeError) as e: t = 5 @@ -793,7 +793,7 @@ def unused_foo(a): self.assertEqual('All type hint arguments must be a non-sequence, a type, ' 'or a TypeConstraint. 
5 is an instance of int.', - e.exception.message) + e.exception.args[0]) def test_basic_type_assertion(self): @check_type_hints @@ -807,7 +807,7 @@ def foo(a): self.assertEqual("Type-hint for argument: 'a' violated. Expected an " "instance of , instead found an " "instance of .", - e.exception.message) + e.exception.args[0]) def test_composite_type_assertion(self): @check_type_hints @@ -823,7 +823,7 @@ def foo(a): "type-constraint violated. The type of element #0 in " "the passed list is incorrect. Expected an instance of " "type int, instead received an instance of type str.", - e.exception.message) + e.exception.args[0]) def test_valid_simple_type_arguments(self): @with_input_types(a=str) @@ -861,7 +861,7 @@ def sub(a, b): self.assertEqual("Type-hint for argument: 'b' violated. Expected an " "instance of , instead found an instance " "of .", - e.exception.message) + e.exception.args[0]) def test_valid_only_positional_arguments(self): @with_input_types(int, int) @@ -907,7 +907,7 @@ def foo(a): self.assertEqual("Type-hint for return type violated. Expected an " "instance of , instead found an instance " "of .", - e.exception.message) + e.exception.args[0]) def test_type_check_simple_type(self): @with_output_types(str) diff --git a/sdks/python/apache_beam/utils/annotations.py b/sdks/python/apache_beam/utils/annotations.py index 017dd6b81a4d..036b08287df1 100644 --- a/sdks/python/apache_beam/utils/annotations.py +++ b/sdks/python/apache_beam/utils/annotations.py @@ -96,7 +96,7 @@ def inner(*args, **kwargs): message += '. Use %s instead.' % current if current else '.' if extra_message: message += '. ' + extra_message - warnings.warn(message, warning_type) + warnings.warn(message, warning_type, stacklevel=2) return fnc(*args, **kwargs) return inner return _annotate diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index a2c3f6ab1af6..9f9c8cd16296 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -24,12 +24,17 @@ import logging import os import pstats -import StringIO +import sys import tempfile import time import warnings from threading import Timer +if sys.version_info[0] < 3: + import StringIO +else: + from io import StringIO + class Profile(object): """cProfile wrapper context for saving and logging profiler results.""" @@ -66,7 +71,7 @@ def __exit__(self, *args): os.remove(filename) if self.log_results: - s = StringIO.StringIO() + s = StringIO() self.stats = pstats.Stats( self.profile, stream=s).sort_stats(Profile.SORTBY) self.stats.print_stats() diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py index 927da14678c1..97bd03798509 100644 --- a/sdks/python/apache_beam/utils/retry.py +++ b/sdks/python/apache_beam/utils/retry.py @@ -31,6 +31,8 @@ import time import traceback +import six + from apache_beam.io.filesystem import BeamIOError # Protect against environments where apitools library is not available. 
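
For reference, a minimal sketch (outside the patch) of the Python 2/3 exception pattern the typehints and retry changes above adopt: the Python-2-only `raise ExcType, msg, traceback` statement becomes `six.reraise`, and `e.exception.message` gives way to `e.exception.args[0]`. The `parse_int` helper and the local `TypeCheckError` class here are stand-ins for illustration only.

import sys
import six

class TypeCheckError(Exception):
    pass

def parse_int(value):
    try:
        return int(value)
    except ValueError:
        error_msg = 'Runtime type violation detected: %r' % (value,)
        # six.reraise keeps the original traceback and is valid syntax on
        # both Python 2 and Python 3, unlike "raise E, msg, tb".
        six.reraise(TypeCheckError, TypeCheckError(error_msg), sys.exc_info()[2])

try:
    parse_int('not a number')
except TypeCheckError as e:
    # BaseException.message no longer exists on Python 3; args[0] works on both.
    print(e.args[0])
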
@@ -80,7 +82,7 @@ def __init__(self, initial_delay_secs, num_retries, factor=2, fuzz=0.5, def __iter__(self): current_delay_secs = min(self._max_delay_secs, self._initial_delay_secs) - for _ in xrange(self._num_retries): + for _ in range(self._num_retries): fuzz_multiplier = 1 - self._fuzz + random.random() * self._fuzz yield current_delay_secs * fuzz_multiplier current_delay_secs = min( @@ -185,7 +187,7 @@ def wrapper(*args, **kwargs): sleep_interval = next(retry_intervals) except StopIteration: # Re-raise the original exception since we finished the retries. - raise exn, None, exn_traceback # pylint: disable=raising-bad-type + six.reraise(exn, None, exn_traceback) # pylint: disable=raising-bad-type logger( 'Retry with exponential backoff: waiting for %s seconds before ' diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index b3e840ee284e..c437d5a3e7c4 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -21,6 +21,7 @@ """ from __future__ import absolute_import +from __future__ import division import datetime @@ -68,7 +69,7 @@ def __repr__(self): if micros < 0: sign = '-' micros = -micros - int_part = micros / 1000000 + int_part = micros // 1000000 frac_part = micros % 1000000 if frac_part: return 'Timestamp(%s%d.%06d)' % (sign, int_part, frac_part) @@ -86,11 +87,11 @@ def isoformat(self): def __float__(self): # Note that the returned value may have lost precision. - return float(self.micros) / 1000000 + return self.micros / 1000000 def __int__(self): # Note that the returned value may have lost precision. - return self.micros / 1000000 + return self.micros // 1000000 def __cmp__(self, other): # Allow comparisons between Duration and Timestamp values. @@ -160,7 +161,7 @@ def __repr__(self): if micros < 0: sign = '-' micros = -micros - int_part = micros / 1000000 + int_part = micros // 1000000 frac_part = micros % 1000000 if frac_part: return 'Duration(%s%d.%06d)' % (sign, int_part, frac_part) @@ -168,7 +169,7 @@ def __repr__(self): def __float__(self): # Note that the returned value may have lost precision. - return float(self.micros) / 1000000 + return self.micros / 1000000 def __cmp__(self, other): # Allow comparisons between Duration and Timestamp values. 
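
For reference, a small sketch (outside the patch) of why the timestamp arithmetic above switches operators: once `from __future__ import division` is in effect, `/` always yields a float, so every conversion that must stay integral moves to `//`.

from __future__ import division

micros = 1500000

# True division: 1.5 on both Python 2 and Python 3.
as_float = micros / 1000000

# Floor division: 1, preserving the integer semantics Python 2's "/" used to have.
as_int = micros // 1000000

assert as_float == 1.5
assert as_int == 1
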
@@ -200,7 +201,7 @@ def __rsub__(self, other): def __mul__(self, other): other = Duration.of(other) - return Duration(micros=self.micros * other.micros / 1000000) + return Duration(micros=self.micros * other.micros // 1000000) def __rmul__(self, other): return self * other diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index bd02fe1dfb78..e62fbcd0948c 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -25,48 +25,6 @@ from apache_beam.internal import pickler from apache_beam.utils import proto_utils -PICKLED_WINDOW_FN = "beam:windowfn:pickled_python:v0.1" -GLOBAL_WINDOWS_FN = "beam:windowfn:global_windows:v0.1" -FIXED_WINDOWS_FN = "beam:windowfn:fixed_windows:v0.1" -SLIDING_WINDOWS_FN = "beam:windowfn:sliding_windows:v0.1" -SESSION_WINDOWS_FN = "beam:windowfn:session_windows:v0.1" - -PICKLED_DO_FN = "beam:dofn:pickled_python:v0.1" -PICKLED_DO_FN_INFO = "beam:dofn:pickled_python_info:v0.1" -PICKLED_COMBINE_FN = "beam:combinefn:pickled_python:v0.1" - -PICKLED_TRANSFORM = "beam:ptransform:pickled_python:v0.1" -PARDO_TRANSFORM = "beam:ptransform:pardo:v0.1" -GROUP_BY_KEY_TRANSFORM = "urn:beam:transform:groupbykey:v1" -GROUP_BY_KEY_ONLY_TRANSFORM = "beam:ptransform:group_by_key_only:v0.1" -GROUP_ALSO_BY_WINDOW_TRANSFORM = "beam:ptransform:group_also_by_window:v0.1" -COMBINE_PER_KEY_TRANSFORM = "beam:ptransform:combine_per_key:v0.1" -COMBINE_GROUPED_VALUES_TRANSFORM = "beam:ptransform:combine_grouped_values:v0.1" -PRECOMBINE_TRANSFORM = "beam:ptransform:combine_pre:v0.1" -MERGE_ACCUMULATORS_TRANSFORM = "beam:ptransform:combine_merge_accumulators:v0.1" -EXTRACT_OUTPUTS_TRANSFORM = "beam:ptransform:combine_extract_outputs:v0.1" -FLATTEN_TRANSFORM = "beam:ptransform:flatten:v0.1" -READ_TRANSFORM = "beam:ptransform:read:v0.1" -RESHUFFLE_TRANSFORM = "beam:ptransform:reshuffle:v0.1" -WINDOW_INTO_TRANSFORM = "beam:ptransform:window_into:v0.1" - -PICKLED_SOURCE = "beam:source:pickled_python:v0.1" - -PICKLED_CODER = "beam:coder:pickled_python:v0.1" -BYTES_CODER = "urn:beam:coders:bytes:0.1" -VAR_INT_CODER = "urn:beam:coders:varint:0.1" -INTERVAL_WINDOW_CODER = "urn:beam:coders:interval_window:0.1" -ITERABLE_CODER = "urn:beam:coders:stream:0.1" -KV_CODER = "urn:beam:coders:kv:0.1" -LENGTH_PREFIX_CODER = "urn:beam:coders:length_prefix:0.1" -GLOBAL_WINDOW_CODER = "urn:beam:coders:global_window:0.1" -WINDOWED_VALUE_CODER = "urn:beam:coders:windowed_value:0.1" - -ITERABLE_ACCESS = "urn:beam:sideinput:iterable" -MULTIMAP_ACCESS = "urn:beam:sideinput:multimap" -PICKLED_PYTHON_VIEWFN = "beam:view_fn:pickled_python_data:v0.1" -PICKLED_WINDOW_MAPPING_FN = "beam:window_mapping_fn:pickled_python:v0.1" - class RunnerApiFn(object): """Abstract base class that provides urn registration utilities. diff --git a/sdks/python/container/build.gradle b/sdks/python/container/build.gradle index 5d8c2b8fa7b8..bbf0c709c9e6 100644 --- a/sdks/python/container/build.gradle +++ b/sdks/python/container/build.gradle @@ -32,6 +32,7 @@ dependencies { // TODO(herohde): use "./" prefix to prevent gogradle use base github path, for now. // TODO(herohde): get the pkg subdirectory only, if possible. We spend mins pulling cmd/beamctl deps. build name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir + test name: './github.com/apache/beam/sdks/go', dir: project(':sdks:go').projectDir } } @@ -45,8 +46,13 @@ golang { } docker { - // TODO(herohde): make the name easier to generate for releases. 
- name System.properties['user.name'] + '-docker-apache.bintray.io/beam/python:latest' + String repositoryRoot + if (rootProject.hasProperty(["docker-repository-root"])) { + repositoryRoot = rootProject["docker-repository-root"] + } else { + repositoryRoot = "${System.properties["user.name"]}-docker-apache.bintray.io/beam" + } + name "${repositoryRoot}/python:latest" files "./build/" } // Ensure that making the docker image builds any required artifacts diff --git a/sdks/python/container/run_validatescontainer.sh b/sdks/python/container/run_validatescontainer.sh new file mode 100755 index 000000000000..ba601b5b4a3a --- /dev/null +++ b/sdks/python/container/run_validatescontainer.sh @@ -0,0 +1,125 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script will be run by Jenkins as a post commit test. In order to run +# locally make the following changes: +# +# LOCAL_PATH -> Path of tox and virtualenv if you have them already installed. +# GCS_LOCATION -> Temporary location to use for service tests. +# PROJECT -> Project name to use for dataflow and docker images. +# +# Execute from the root of the repository: sdks/python/container/run_validatescontainer.sh + +set -e +set -v + +# pip install --user installation location. +LOCAL_PATH=$HOME/.local/bin/ + +# Where to store integration test outputs. +GCS_LOCATION=gs://temp-storage-for-end-to-end-tests + +# Project for the container and integration test +PROJECT=apache-beam-testing + +# Verify in the root of the repository +test -d sdks/python/container + +# Verify docker and gcloud commands exist +command -v docker +command -v gcloud +docker -v +gcloud -v + +# ensure maven version is 3.5 or above +TMPDIR=$(mktemp -d) +MVN=$(which mvn) +mvn_ver=$($MVN -v | head -1 | awk '{print $3}') +if [[ "$mvn_ver" < "3.5" ]] +then + pushd $TMPDIR + curl http://www.apache.org/dist/maven/maven-3/3.5.2/binaries/apache-maven-3.5.2-bin.tar.gz --output maven.tar.gz + tar xf maven.tar.gz + MVN="$(pwd)/apache-maven-3.5.2/bin/mvn" + popd +fi + +# ensure gcloud is version 186 or above +gcloud_ver=$(gcloud -v | head -1 | awk '{print $4}') +if [[ "$gcloud_ver" < "186" ]] +then + pushd $TMPDIR + curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-186.0.0-linux-x86_64.tar.gz --output gcloud.tar.gz + tar xf gcloud.tar.gz + ./google-cloud-sdk/install.sh --quiet + . 
./google-cloud-sdk/path.bash.inc + popd + gcloud components update --quiet || echo 'gcloud components update failed' + gcloud -v +fi + +# Build the container +TAG=$(date +%Y%m%d-%H%M%S) +CONTAINER=us.gcr.io/$PROJECT/$USER/python +echo "Using container $CONTAINER" +$MVN clean install -DskipTests -Pbuild-containers --projects sdks/python/container -Ddocker-repository-root=us.gcr.io/$PROJECT/$USER -Ddockerfile.tag=$TAG -amd + +# Verify it exists +docker images | grep $TAG + +# Push the container +gcloud docker -- push $CONTAINER + +# INFRA does not install virtualenv +pip install virtualenv --user + +# Virtualenv for the rest of the script to run setup & e2e test +${LOCAL_PATH}/virtualenv sdks/python/container +. sdks/python/container/bin/activate +cd sdks/python +pip install -e .[gcp,test] + +# Create a tarball +python setup.py sdist +SDK_LOCATION=$(find dist/apache-beam-*.tar.gz) + +# Run ValidatesRunner tests on Google Cloud Dataflow service +echo ">>> RUNNING DATAFLOW RUNNER VALIDATESCONTAINER TEST" +python setup.py nosetests \ + --attr ValidatesContainer \ + --nocapture \ + --processes=1 \ + --process-timeout=900 \ + --test-pipeline-options=" \ + --runner=TestDataflowRunner \ + --project=$PROJECT \ + --worker_harness_container_image=$CONTAINER:$TAG \ + --staging_location=$GCS_LOCATION/staging-validatesrunner-test \ + --temp_location=$GCS_LOCATION/temp-validatesrunner-test \ + --output=$GCS_LOCATION/output \ + --sdk_location=$SDK_LOCATION \ + --num_workers=1" + +# Delete the container locally and remotely +docker rmi $CONTAINER:$TAG || echo "Failed to remove container" +gcloud container images delete $CONTAINER:$TAG || echo "Failed to delete container" + +# Clean up tempdir +rm -rf $TMPDIR + +echo ">>> SUCCESS DATAFLOW RUNNER VALIDATESCONTAINER TEST" diff --git a/sdks/python/generate_pydoc.sh b/sdks/python/generate_pydoc.sh index b4ec96990c9c..1a5f4d3ba1a1 100755 --- a/sdks/python/generate_pydoc.sh +++ b/sdks/python/generate_pydoc.sh @@ -100,7 +100,6 @@ import apache_beam as beam intersphinx_mapping = { 'python': ('https://docs.python.org/2', None), 'hamcrest': ('https://pyhamcrest.readthedocs.io/en/latest/', None), - 'hdfs3': ('https://hdfs3.readthedocs.io/en/latest/', None), } # Since private classes are skipped by sphinx, if there is any cross reference diff --git a/sdks/python/setup.py b/sdks/python/setup.py index d0034f7dee74..a069237e22be 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -23,6 +23,7 @@ import os import pkg_resources import platform +import re import shutil import subprocess import sys @@ -99,6 +100,7 @@ def get_version(): 'crcmod>=1.7,<2.0', 'dill==0.2.6', 'grpcio>=1.0,<2', + 'hdfs>=2.1.0,<3.0.0', 'httplib2>=0.8,<0.10', 'mock>=1.0.1,<3.0.0', 'oauth2client>=2.0.1,<5', @@ -109,7 +111,6 @@ def get_version(): 'six>=1.9,<1.12', 'typing>=3.6.0,<3.7.0', 'futures>=3.1.1,<4.0.0', - 'hdfs3>=0.3.0,<0.4.0', ] REQUIRED_SETUP_PACKAGES = [ @@ -147,6 +148,35 @@ def run(self): return original_cmd +def generate_common_urns(): + src = os.path.join( + os.path.dirname(__file__), + '../../' + 'model/pipeline/src/main/resources/org/apache/beam/model/common_urns.md') + out = os.path.join( + os.path.dirname(__file__), + 'apache_beam/portability/common_urns.py') + src_time = os.path.getmtime(src) if os.path.exists(src) else -1 + out_time = os.path.getmtime(out) if os.path.exists(out) else -1 + if src_time > out_time: + print 'Regenerating common_urns module.' 
+ urns = {} + for m in re.finditer( + r'\b(?:urn:)?beam:(\S+):(\S+):(v\S+)', open(src).read()): + kind, name, version = m.groups() + var_name = name.upper() + '_' + kind.upper() + if var_name in urns: + var_name += '_' + version.upper().replace('.', '_') + urns[var_name] = m.group(0) + open(out, 'w').write( + '# Autogenerated from common_urns.md\n' + + '# pylint: disable=line-too-long\n\n' + + '\n'.join('%s = "%s"' % urn + for urn in sorted(urns.items(), key=lambda kv: kv[1])) + + '\n') +generate_common_urns() + + setuptools.setup( name=PACKAGE_NAME, version=PACKAGE_VERSION, diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index f00ae6147edd..31bcc9cdb572 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -33,6 +33,7 @@ deps = whitelist_externals=find commands = python --version + pip --version # Clean up all previous python generated files. - find apache_beam -type f -name '*.pyc' -delete pip install -e .[test] @@ -56,6 +57,7 @@ whitelist_externals= time commands = python --version + pip --version # Clean up all previous python generated files. - find apache_beam -type f -name '*.pyc' -delete # Clean up all previous cython generated files. @@ -79,8 +81,9 @@ deps = nose==1.3.7 whitelist_externals=find commands = - pip install -e .[test,gcp] python --version + pip --version + pip install -e .[test,gcp] # Clean up all previous python generated files. - find apache_beam -type f -name '*.pyc' -delete python apache_beam/examples/complete/autocomplete_test.py @@ -96,6 +99,8 @@ deps= isort==4.2.15 whitelist_externals=time commands = + python --version + pip --version time pip install -e .[test] time {toxinidir}/run_pylint.sh passenv = TRAVIS* @@ -108,6 +113,8 @@ deps= sphinx_rtd_theme==0.2.4 whitelist_externals=time commands = + python --version + pip --version time pip install -e .[test,gcp,docs] time {toxinidir}/generate_pydoc.sh passenv = TRAVIS* @@ -120,6 +127,7 @@ deps = whitelist_externals=find commands = python --version + pip --version pip install -e .[test,gcp] # Clean up all previous python generated files. - find apache_beam -type f -name '*.pyc' -delete diff --git a/settings.gradle b/settings.gradle index ceaab4a07837..8446c45c6c9b 100644 --- a/settings.gradle +++ b/settings.gradle @@ -20,6 +20,7 @@ include ":examples:java" include ":model:fn-execution" include ":model:job-management" include ":model:pipeline" +include ":release" include ":runners:apex" include ":runners:core-construction-java" include ":runners:core-java" @@ -43,6 +44,7 @@ include ":sdks:java:extensions:google-cloud-platform-core" include ":sdks:java:extensions:jackson" include ":sdks:java:extensions:join-library" include ":sdks:java:extensions:protobuf" +include ":sdks:java:extensions:sketching" include ":sdks:java:extensions:sorter" include ":sdks:java:extensions:sql" include ":sdks:java:fn-execution" @@ -72,7 +74,6 @@ include ":sdks:java:io:redis" include ":sdks:java:io:solr" include ":sdks:java:io:tika" include ":sdks:java:io:xml" -include ":sdks:java:java8tests" include ":sdks:java:maven-archetypes:examples" include ":sdks:java:maven-archetypes:starter" include ":sdks:java:nexmark"
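
For reference, a small sketch (outside the patch) of the naming rule `generate_common_urns()` in setup.py applies when it scrapes common_urns.md: each `beam:<kind>:<name>:<version>` string becomes a `<NAME>_<KIND>` constant, which is why window.py can now refer to names such as `common_urns.GLOBAL_WINDOWS_WINDOWFN`. The URN string below is the legacy value deleted from urns.py, used here purely as sample input.

import re

URN_RE = re.compile(r'\b(?:urn:)?beam:(\S+):(\S+):(v\S+)')

def constant_name(urn):
    # Mirrors the var_name logic in generate_common_urns(): NAME_KIND, upper-cased.
    kind, name, _version = URN_RE.match(urn).groups()
    return name.upper() + '_' + kind.upper()

print(constant_name('beam:windowfn:global_windows:v0.1'))
# -> GLOBAL_WINDOWS_WINDOWFN
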