diff --git a/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy new file mode 100644 index 000000000000..1b11c841146a --- /dev/null +++ b/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import CommonJobProperties as commonJobProperties +import PostcommitJobBuilder + + +import static PythonTestProperties.CROSS_LANGUAGE_VALIDATES_RUNNER_PYTHON_VERSIONS + + +// This job runs end-to-end cross language GCP IO tests with DataflowRunner. +// Collects tests with the @pytest.mark.uses_gcp_java_expansion_service decorator +PostcommitJobBuilder.postCommitJob('beam_PostCommit_Python_Xlang_Gcp_Dataflow', + 'Run Python_Xlang_Gcp_Dataflow PostCommit', 'Python_Xlang_Gcp_Dataflow (\"Run Python_Xlang_Gcp_Dataflow PostCommit\")', this) { + description('Runs end-to-end cross language GCP IO tests on the Dataflow runner.') + + + // Set common parameters. + commonJobProperties.setTopLevelMainJobProperties(delegate) + + + // Publish all test results to Jenkins + publishers { + archiveJunit('**/pytest*.xml') + } + + + // Gradle goals for this job. + steps { + CROSS_LANGUAGE_VALIDATES_RUNNER_PYTHON_VERSIONS.each { pythonVersion -> + shell("echo \"Running cross language GCP IO tests with Python ${pythonVersion} on DataflowRunner.\"") + gradle { + rootBuildScriptDir(commonJobProperties.checkoutDir) + tasks(":sdks:python:test-suites:dataflow:py${pythonVersion.replace('.', '')}:gcpCrossLanguagePythonUsingJava") + commonJobProperties.setGradleSwitches(delegate) + } + } + } + } \ No newline at end of file diff --git a/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Direct.groovy b/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Direct.groovy new file mode 100644 index 000000000000..229f60161ddb --- /dev/null +++ b/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Direct.groovy @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import CommonJobProperties as commonJobProperties +import PostcommitJobBuilder + + +import static PythonTestProperties.CROSS_LANGUAGE_VALIDATES_RUNNER_PYTHON_VERSIONS + + +// This job runs end-to-end cross language GCP IO tests with DirectRunner. +// Collects tests with the @pytest.mark.uses_gcp_java_expansion_service decorator +PostcommitJobBuilder.postCommitJob('beam_PostCommit_Python_Xlang_Gcp_Direct', + 'Run Python_Xlang_Gcp_Direct PostCommit', 'Python_Xlang_Gcp_Direct (\"Run Python_Xlang_Gcp_Direct PostCommit\")', this) { + description('Runs end-to-end cross language GCP IO tests on the Direct runner.') + + + // Set common parameters. + commonJobProperties.setTopLevelMainJobProperties(delegate) + + + // Publish all test results to Jenkins + publishers { + archiveJunit('**/pytest*.xml') + } + + + // Gradle goals for this job. + steps { + CROSS_LANGUAGE_VALIDATES_RUNNER_PYTHON_VERSIONS.each { pythonVersion -> + shell("echo \"Running cross language GCP IO tests with Python ${pythonVersion} on DirectRunner.\"") + gradle { + rootBuildScriptDir(commonJobProperties.checkoutDir) + tasks(":sdks:python:test-suites:direct:py${pythonVersion.replace('.', '')}:gcpCrossLanguagePythonUsingJava") + commonJobProperties.setGradleSwitches(delegate) + } + } + } + } \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md index aa5421360e5f..b094c4ee5f6f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -60,6 +60,7 @@ ## I/Os * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* BigQuery Storage Write API is now available in Python SDK via cross-language ([#21961](https://github.com/apache/beam/issues/21961)). ## New Features / Improvements diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 53fd1807dec4..6663207fba7d 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -318,6 +318,48 @@ class BeamModulePlugin implements Plugin { } } + // A class defining the common properties in a given suite of cross-language tests + // Properties are shared across runners and are used when creating a CrossLanguageUsingJavaExpansionConfiguration object + static class CrossLanguageTaskCommon { + // Used as the task name for cross-language + String name + // The expansion service's project path (required) + String expansionProjectPath + // Collect Python pipeline tests with this marker + String collectMarker + // Job server startup task. + TaskProvider startJobServer + // Job server cleanup task. + TaskProvider cleanupJobServer + } + + // A class defining the configuration for CrossLanguageUsingJavaExpansion. + static class CrossLanguageUsingJavaExpansionConfiguration { + // Task name for cross-language tests using Java expansion. + String name = 'crossLanguageUsingJavaExpansion' + // Python pipeline options to use. + List pythonPipelineOptions = [ + "--runner=PortableRunner", + "--job_endpoint=localhost:8099", + "--environment_cache_millis=10000", + "--experiments=beam_fn_api", + ] + // Additional pytest options + List pytestOptions = [] + // Job server startup task. + TaskProvider startJobServer + // Job server cleanup task. + TaskProvider cleanupJobServer + // Number of parallel test runs. 
+ Integer numParallelTests = 1 + // Whether the pipeline needs --sdk_location option + boolean needsSdkLocation = false + // Project path for the expansion service to start up + String expansionProjectPath + // Collect Python pipeline tests with this marker + String collectMarker + } + // A class defining the configuration for CrossLanguageValidatesRunner. static class CrossLanguageValidatesRunnerConfiguration { // Task name for cross-language validate runner case. @@ -2375,6 +2417,108 @@ class BeamModulePlugin implements Plugin { } } + /** ***********************************************************************************************/ + // Method to create the createCrossLanguageUsingJavaExpansionTask. + // The method takes CrossLanguageUsingJavaExpansionConfiguration as parameter. + // This method creates a task that runs Python SDK pipeline tests that use Java transforms via an input expansion service + project.ext.createCrossLanguageUsingJavaExpansionTask = { + // This task won't work if the python build file doesn't exist. + if (!project.project(":sdks:python").buildFile.exists()) { + System.err.println 'Python build file not found. Skipping createCrossLanguageUsingJavaExpansionTask.' + return + } + def config = it ? it as CrossLanguageUsingJavaExpansionConfiguration : new CrossLanguageUsingJavaExpansionConfiguration() + + project.evaluationDependsOn(":sdks:python") + project.evaluationDependsOn(config.expansionProjectPath) + project.evaluationDependsOn(":runners:core-construction-java") + project.evaluationDependsOn(":sdks:java:extensions:python") + + // Setting up args to launch the expansion service + def envDir = project.project(":sdks:python").envdir + def pythonDir = project.project(":sdks:python").projectDir + def javaExpansionPort = -1 // will be populated in setupTask + def expansionJar = project.project(config.expansionProjectPath).shadowJar.archivePath + def javaClassLookupAllowlistFile = project.project(config.expansionProjectPath).projectDir.getPath() + def expansionServiceOpts = [ + "group_id": project.name, + "java_expansion_service_jar": expansionJar, + "java_expansion_service_allowlist_file": javaClassLookupAllowlistFile, + ] + def javaContainerSuffix + if (JavaVersion.current() == JavaVersion.VERSION_1_8) { + javaContainerSuffix = 'java8' + } else if (JavaVersion.current() == JavaVersion.VERSION_11) { + javaContainerSuffix = 'java11' + } else if (JavaVersion.current() == JavaVersion.VERSION_17) { + javaContainerSuffix = 'java17' + } else { + String exceptionMessage = "Your Java version is unsupported. You need Java version of 8 or 11 or 17 to get started, but your Java version is: " + JavaVersion.current(); + throw new GradleException(exceptionMessage) + } + + // 1. Builds the chosen expansion service jar and launches it + def setupTask = project.tasks.register(config.name+"Setup") { + dependsOn ':sdks:java:container:' + javaContainerSuffix + ':docker' + dependsOn project.project(config.expansionProjectPath).shadowJar.getPath() + dependsOn ":sdks:python:installGcpTest" + doLast { + project.exec { + // Prepare a port to use for the expansion service + javaExpansionPort = getRandomPort() + expansionServiceOpts.put("java_port", javaExpansionPort) + // setup test env + def serviceArgs = project.project(':sdks:python').mapToArgString(expansionServiceOpts) + executable 'sh' + args '-c', "$pythonDir/scripts/run_expansion_services.sh stop --group_id ${project.name} && $pythonDir/scripts/run_expansion_services.sh start $serviceArgs" + } + } + } + + // 2. 
Sets up, collects, and runs Python pipeline tests + def sdkLocationOpt = [] + if (config.needsSdkLocation) { + setupTask.configure {dependsOn ':sdks:python:sdist'} + sdkLocationOpt = [ + "--sdk_location=${pythonDir}/build/apache-beam.tar.gz" + ] + } + def beamPythonTestPipelineOptions = [ + "pipeline_opts": config.pythonPipelineOptions + sdkLocationOpt, + "test_opts": config.pytestOptions, + "suite": config.name, + "collect": config.collectMarker, + ] + def cmdArgs = project.project(':sdks:python').mapToArgString(beamPythonTestPipelineOptions) + def pythonTask = project.tasks.register(config.name+"PythonUsingJava") { + group = "Verification" + description = "Runs Python SDK pipeline tests that use a Java expansion service" + dependsOn setupTask + dependsOn config.startJobServer + doLast { + project.exec { + environment "EXPANSION_JAR", expansionJar + environment "EXPANSION_PORT", javaExpansionPort + executable 'sh' + args '-c', ". $envDir/bin/activate && cd $pythonDir && ./scripts/run_integration_test.sh $cmdArgs" + } + } + } + + // 3. Shuts down the expansion service + def cleanupTask = project.tasks.register(config.name+'Cleanup', Exec) { + // teardown test env + executable 'sh' + args '-c', "$pythonDir/scripts/run_expansion_services.sh stop --group_id ${project.name}" + } + + setupTask.configure {finalizedBy cleanupTask} + config.startJobServer.configure {finalizedBy config.cleanupJobServer} + + cleanupTask.configure{mustRunAfter pythonTask} + config.cleanupJobServer.configure{mustRunAfter pythonTask} + } + /** ***********************************************************************************************/ // Method to create the crossLanguageValidatesRunnerTask. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 12d3dfca062b..0056aa36c83f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -332,14 +332,16 @@ public static Schema of(Field... fields) { /** Returns an identical Schema with sorted fields. */ public Schema sorted() { // Create a new schema and copy over the appropriate Schema object attributes: - // {fields, uuid, encodingPositions, options} + // {fields, uuid, options} + // Note: encoding positions are not copied over because generally they should align with the + // ordering of field indices. Otherwise, problems may occur when encoding/decoding Rows of + // this schema. 
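+    // For example, sorting Schema.of(Field.of("b", FieldType.INT64), Field.of("a", FieldType.STRING))
+    // (field names are illustrative) moves "a" to index 0, so copying the original encoding
+    // positions would leave them out of step with the sorted field order.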
Schema sortedSchema = this.fields.stream() .sorted(Comparator.comparing(Field::getName)) .collect(Schema.toSchema()) .withOptions(getOptions()); sortedSchema.setUUID(getUUID()); - sortedSchema.setEncodingPositions(getEncodingPositions()); return sortedSchema; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTest.java index 47746b599259..9797556618ea 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTest.java @@ -221,7 +221,6 @@ public void testSorted() { .addStringField("d") .build() .withOptions(testOptions); - sortedSchema.setEncodingPositions(unorderedSchema.getEncodingPositions()); assertEquals(true, unorderedSchema.equivalent(unorderedSchemaAfterSorting)); assertEquals( diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java index 5b9b50b248a7..c39e2fe04464 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java @@ -70,10 +70,10 @@ public class ExpansionServiceSchemaTransformProviderTest { private static final Schema TEST_SCHEMATRANSFORM_CONFIG_SCHEMA = Schema.of( - Field.of("str1", FieldType.STRING), - Field.of("str2", FieldType.STRING), Field.of("int1", FieldType.INT32), - Field.of("int2", FieldType.INT32)); + Field.of("int2", FieldType.INT32), + Field.of("str1", FieldType.STRING), + Field.of("str2", FieldType.STRING)); private ExpansionService expansionService = new ExpansionService(); @@ -381,10 +381,10 @@ public void testSchemaTransformExpansion() { .values()); Row configRow = Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA) - .withFieldValue("str1", "aaa") - .withFieldValue("str2", "bbb") .withFieldValue("int1", 111) .withFieldValue("int2", 222) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") .build(); ByteStringOutputStream outputStream = new ByteStringOutputStream(); @@ -440,10 +440,10 @@ public void testSchemaTransformExpansionMultiInputMultiOutput() { Row configRow = Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA) - .withFieldValue("str1", "aaa") - .withFieldValue("str2", "bbb") .withFieldValue("int1", 111) .withFieldValue("int2", 222) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") .build(); ByteStringOutputStream outputStream = new ByteStringOutputStream(); diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle index c55b50ef4a63..1288d91964e1 100644 --- a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle +++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle @@ -15,7 +15,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - apply plugin: 'org.apache.beam.module' apply plugin: 'application' mainClassName = "org.apache.beam.sdk.expansion.service.ExpansionService" diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java index 6028d8b9016e..21f368c78343 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -118,7 +118,7 @@ public class BeamRowToStorageApiProto { .put( SqlTypes.DATETIME.getIdentifier(), (logicalType, value) -> - CivilTimeEncoder.encodePacked64DatetimeSeconds((LocalDateTime) value)) + CivilTimeEncoder.encodePacked64DatetimeMicros((LocalDateTime) value)) .put( SqlTypes.TIMESTAMP.getIdentifier(), (logicalType, value) -> (ChronoUnit.MICROS.between(Instant.EPOCH, (Instant) value))) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java index f0caa958df94..ce5043edc6f9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java @@ -512,7 +512,7 @@ public static TableRow convertGenericRecordToTableRow( return BigQueryAvroUtils.convertGenericRecordToTableRow(record, tableSchema); } - /** Convert a BigQuery TableRow to a Beam Row. */ + /** Convert a Beam Row to a BigQuery TableRow. */ public static TableRow toTableRow(Row row) { TableRow output = new TableRow(); for (int i = 0; i < row.getFieldCount(); i++) { @@ -686,7 +686,14 @@ public static Row toBeamRow(Schema rowSchema, TableSchema bqSchema, TableRow jso if (JSON_VALUE_PARSERS.containsKey(fieldType.getTypeName())) { return JSON_VALUE_PARSERS.get(fieldType.getTypeName()).apply(jsonBQString); } else if (fieldType.isLogicalType(SqlTypes.DATETIME.getIdentifier())) { - return LocalDateTime.parse(jsonBQString, BIGQUERY_DATETIME_FORMATTER); + try { + // Handle if datetime value is in micros ie. 123456789 + Long value = Long.parseLong(jsonBQString); + return CivilTimeEncoder.decodePacked64DatetimeMicrosAsJavaTime(value); + } catch (NumberFormatException e) { + // Handle as a String, ie. 
"2023-02-16 12:00:00" + return LocalDateTime.parse(jsonBQString, BIGQUERY_DATETIME_FORMATTER); + } } else if (fieldType.isLogicalType(SqlTypes.DATE.getIdentifier())) { return LocalDate.parse(jsonBQString); } else if (fieldType.isLogicalType(SqlTypes.TIME.getIdentifier())) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index 8cc1bddc10c9..927f9a178c1c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -24,6 +24,7 @@ import com.google.api.services.bigquery.model.TableReference; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -61,7 +62,7 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -90,7 +91,8 @@ public class BigQueryStorageWriteApiSchemaTransformProvider private static final Duration DEFAULT_TRIGGERING_FREQUENCY = Duration.standardSeconds(DEFAULT_TRIGGER_FREQUENCY_SECS); private static final String INPUT_ROWS_TAG = "input"; - private static final String OUTPUT_ERRORS_TAG = "errors"; + private static final String FAILED_ROWS_TAG = "FailedRows"; + private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors"; @Override protected Class configurationClass() { @@ -115,7 +117,7 @@ public List inputCollectionNames() { @Override public List outputCollectionNames() { - return Collections.singletonList(OUTPUT_ERRORS_TAG); + return Arrays.asList(FAILED_ROWS_TAG, FAILED_ROWS_WITH_ERRORS_TAG); } /** Configuration for writing to BigQuery with Storage Write API. */ @@ -147,17 +149,19 @@ public void validate() { // validate create and write dispositions if (!Strings.isNullOrEmpty(this.getCreateDisposition())) { - checkArgument( - CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()) != null, + checkNotNull( + CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()), invalidConfigMessage - + "Invalid create disposition was specified. Available dispositions are: ", + + "Invalid create disposition (%s) was specified. Available dispositions are: %s", + this.getCreateDisposition(), CREATE_DISPOSITIONS.keySet()); } if (!Strings.isNullOrEmpty(this.getWriteDisposition())) { checkNotNull( WRITE_DISPOSITIONS.get(this.getWriteDisposition().toUpperCase()), invalidConfigMessage - + "Invalid write disposition was specified. Available dispositions are: ", + + "Invalid write disposition (%s) was specified. 
Available dispositions are: %s", + this.getWriteDisposition(), WRITE_DISPOSITIONS.keySet()); } } @@ -329,32 +333,48 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { .setRowSchema(inputSchema) .apply(write); + Schema rowSchema = inputRows.getSchema(); Schema errorSchema = Schema.of( - Field.of("failed_row", FieldType.STRING), + Field.of("failed_row", FieldType.row(rowSchema)), Field.of("error_message", FieldType.STRING)); - // Errors consisting of failed rows along with their error message - PCollection errorRows = + // Failed rows + PCollection failedRows = + result + .getFailedStorageApiInserts() + .apply( + "Construct failed rows", + MapElements.into(TypeDescriptors.rows()) + .via( + (storageError) -> + BigQueryUtils.toBeamRow(rowSchema, storageError.getRow()))) + .setRowSchema(rowSchema); + + // Failed rows with error message + PCollection failedRowsWithErrors = result .getFailedStorageApiInserts() .apply( - "Extract Errors", - MapElements.into(TypeDescriptor.of(Row.class)) + "Construct failed rows and errors", + MapElements.into(TypeDescriptors.rows()) .via( (storageError) -> Row.withSchema(errorSchema) .withFieldValue("error_message", storageError.getErrorMessage()) - .withFieldValue("failed_row", storageError.getRow().toString()) + .withFieldValue( + "failed_row", + BigQueryUtils.toBeamRow(rowSchema, storageError.getRow())) .build())) .setRowSchema(errorSchema); - PCollection errorOutput = - errorRows + PCollection failedRowsOutput = + failedRows .apply("error-count", ParDo.of(new ElementCounterFn("BigQuery-write-error-counter"))) - .setRowSchema(errorSchema); + .setRowSchema(rowSchema); - return PCollectionRowTuple.of(OUTPUT_ERRORS_TAG, errorOutput); + return PCollectionRowTuple.of(FAILED_ROWS_TAG, failedRowsOutput) + .and(FAILED_ROWS_WITH_ERRORS_TAG, failedRowsWithErrors); } BigQueryIO.Write createStorageWriteApiTransform() { @@ -375,13 +395,13 @@ BigQueryIO.Write createStorageWriteApiTransform() { if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = BigQueryStorageWriteApiSchemaTransformConfiguration.CREATE_DISPOSITIONS.get( - configuration.getCreateDisposition()); + configuration.getCreateDisposition().toUpperCase()); write = write.withCreateDisposition(createDisposition); } if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { WriteDisposition writeDisposition = BigQueryStorageWriteApiSchemaTransformConfiguration.WRITE_DISPOSITIONS.get( - configuration.getWriteDisposition()); + configuration.getWriteDisposition().toUpperCase()); write = write.withWriteDisposition(writeDisposition); } @@ -407,7 +427,7 @@ private void validateSchema( table = BigQueryHelpers.getTable(options, tableRef); } if (table == null) { - LOG.info("Table not found and skipped schema validation: " + tableRef.getTableId()); + LOG.info("Table [{}] not found, skipping schema validation.", tableRef.getTableId()); return; } Schema outputSchema = BigQueryUtils.fromTableSchema(table.getSchema()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java index ca82dc9dae6b..c8b8a3cb6cb1 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ 
-258,7 +258,7 @@ public class BeamRowToStorageApiProtoTest { BASE_ROW.getLogicalTypeValue("sqlTimeValue", LocalTime.class))) .put( "sqldatetimevalue", - CivilTimeEncoder.encodePacked64DatetimeSeconds( + CivilTimeEncoder.encodePacked64DatetimeMicros( BASE_ROW.getLogicalTypeValue("sqlDatetimeValue", LocalDateTime.class))) .put( "sqltimestampvalue", diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index 133097e16842..fef2bb168c8f 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -20,12 +20,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import com.google.api.services.bigquery.model.Table; import com.google.api.services.bigquery.model.TableReference; import com.google.api.services.bigquery.model.TableRow; import java.io.Serializable; import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -49,13 +51,12 @@ import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; +import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -78,6 +79,24 @@ public class BigQueryStorageWriteApiSchemaTransformProviderTest { Field.of("number", FieldType.INT64), Field.of("dt", FieldType.logicalType(SqlTypes.DATETIME))); + private static final List ROWS = + Arrays.asList( + Row.withSchema(SCHEMA) + .withFieldValue("name", "a") + .withFieldValue("number", 1L) + .withFieldValue("dt", LocalDateTime.parse("2000-01-01T00:00:00")) + .build(), + Row.withSchema(SCHEMA) + .withFieldValue("name", "b") + .withFieldValue("number", 2L) + .withFieldValue("dt", LocalDateTime.parse("2000-01-02T00:00:00.123")) + .build(), + Row.withSchema(SCHEMA) + .withFieldValue("name", "c") + .withFieldValue("number", 3L) + .withFieldValue("dt", LocalDateTime.parse("2000-01-03T00:00:00.123456")) + .build()); + private static final Schema SCHEMA_WRONG = Schema.of( Field.of("name_wrong", FieldType.STRING), @@ -113,6 +132,11 @@ public void testInvalidConfig() { public PCollectionRowTuple runWithConfig( BigQueryStorageWriteApiSchemaTransformConfiguration config) { + return runWithConfig(config, ROWS); + } + + public PCollectionRowTuple runWithConfig( + BigQueryStorageWriteApiSchemaTransformConfiguration config, List inputRows) { BigQueryStorageWriteApiSchemaTransformProvider provider = new BigQueryStorageWriteApiSchemaTransformProvider(); @@ -120,28 +144,10 @@ public PCollectionRowTuple runWithConfig( 
(BigQueryStorageWriteApiPCollectionRowTupleTransform) provider.from(config).buildTransform(); - List testRows = - Arrays.asList( - Row.withSchema(SCHEMA) - .withFieldValue("name", "a") - .withFieldValue("number", 1L) - .withFieldValue("dt", LocalDateTime.parse("2000-01-01T00:00:00")) - .build(), - Row.withSchema(SCHEMA) - .withFieldValue("name", "b") - .withFieldValue("number", 2L) - .withFieldValue("dt", LocalDateTime.parse("2000-01-02T00:00:00")) - .build(), - Row.withSchema(SCHEMA) - .withFieldValue("name", "c") - .withFieldValue("number", 3L) - .withFieldValue("dt", LocalDateTime.parse("2000-01-03T00:00:00")) - .build()); - writeRowTupleTransform.setBigQueryServices(fakeBigQueryServices); String tag = provider.inputCollectionNames().get(0); - PCollection rows = p.apply(Create.of(testRows).withRowSchema(SCHEMA)); + PCollection rows = p.apply(Create.of(inputRows).withRowSchema(SCHEMA)); PCollectionRowTuple input = PCollectionRowTuple.of(tag, rows); PCollectionRowTuple result = input.apply(writeRowTupleTransform); @@ -149,17 +155,38 @@ public PCollectionRowTuple runWithConfig( return result; } + public Boolean rowsEquals(List expectedRows, List actualRows) { + if (expectedRows.size() != actualRows.size()) { + return false; + } + for (int i = 0; i < expectedRows.size(); i++) { + // Actual rows may come back out of order. For each TableRow, find its "number" column value + // and match it to the index of the expected row. + TableRow actualRow = actualRows.get(i); + Row expectedRow = expectedRows.get(Integer.parseInt(actualRow.get("number").toString()) - 1); + + if (!expectedRow.getValue("name").equals(actualRow.get("name")) + || !expectedRow + .getValue("number") + .equals(Long.parseLong(actualRow.get("number").toString()))) { + return false; + } + } + return true; + } + @Test public void testSimpleWrite() throws Exception { String tableSpec = "project:dataset.simple_write"; BigQueryStorageWriteApiSchemaTransformConfiguration config = BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); - runWithConfig(config); + runWithConfig(config, ROWS); p.run().waitUntilFinish(); assertNotNull(fakeDatasetService.getTable(BigQueryHelpers.parseTableSpec(tableSpec))); - assertEquals(3, fakeDatasetService.getAllRows("project", "dataset", "simple_write").size()); + assertTrue( + rowsEquals(ROWS, fakeDatasetService.getAllRows("project", "dataset", "simple_write"))); } @Test @@ -249,63 +276,43 @@ public void testInputElementCount() throws Exception { } } - public PCollectionRowTuple runWithError( - BigQueryStorageWriteApiSchemaTransformConfiguration config) { - BigQueryStorageWriteApiSchemaTransformProvider provider = - new BigQueryStorageWriteApiSchemaTransformProvider(); + @Test + public void testFailedRows() throws Exception { + String tableSpec = "project:dataset.write_with_fail"; + BigQueryStorageWriteApiSchemaTransformConfiguration config = + BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); - BigQueryStorageWriteApiPCollectionRowTupleTransform writeRowTupleTransform = - (BigQueryStorageWriteApiPCollectionRowTupleTransform) - provider.from(config).buildTransform(); + String failValue = "fail_me"; + + List expectedSuccessfulRows = new ArrayList<>(ROWS); + List expectedFailedRows = new ArrayList<>(); + for (long l = 1L; l <= 3L; l++) { + expectedFailedRows.add( + Row.withSchema(SCHEMA) + .withFieldValue("name", failValue) + .withFieldValue("number", l) + .withFieldValue("dt", LocalDateTime.parse("2020-01-01T00:00:00.09")) + 
.build()); + } + + List totalRows = new ArrayList<>(expectedSuccessfulRows); + totalRows.addAll(expectedFailedRows); Function shouldFailRow = - (Function & Serializable) tr -> tr.get("name").equals("a"); + (Function & Serializable) tr -> tr.get("name").equals(failValue); fakeDatasetService.setShouldFailRow(shouldFailRow); - TableRow row1 = - new TableRow() - .set("name", "a") - .set("number", 1L) - .set("dt", LocalDateTime.parse("2000-01-01T00:00:00")); - TableRow row2 = - new TableRow() - .set("name", "b") - .set("number", 2L) - .set("dt", LocalDateTime.parse("2000-01-02T00:00:00")); - TableRow row3 = - new TableRow() - .set("name", "c") - .set("number", 3L) - .set("dt", LocalDateTime.parse("2000-01-03T00:00:00")); + PCollectionRowTuple result = runWithConfig(config, totalRows); + PCollection failedRows = result.get("FailedRows"); - writeRowTupleTransform.setBigQueryServices(fakeBigQueryServices); - String tag = provider.inputCollectionNames().get(0); - - PCollection rows = - p.apply(Create.of(row1, row2, row3)) - .apply( - MapElements.into(TypeDescriptor.of(Row.class)) - .via((tableRow) -> BigQueryUtils.toBeamRow(SCHEMA, tableRow))) - .setRowSchema(SCHEMA); - - PCollectionRowTuple input = PCollectionRowTuple.of(tag, rows); - PCollectionRowTuple result = input.apply(writeRowTupleTransform); - - return result; - } - - @Test - public void testSimpleWriteWithFailure() throws Exception { - String tableSpec = "project:dataset.simple_write_with_failure"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); - - runWithError(config); + PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows); p.run().waitUntilFinish(); assertNotNull(fakeDatasetService.getTable(BigQueryHelpers.parseTableSpec(tableSpec))); - assertEquals( - 2, fakeDatasetService.getAllRows("project", "dataset", "simple_write_with_failure").size()); + assertTrue( + rowsEquals( + expectedSuccessfulRows, + fakeDatasetService.getAllRows("project", "dataset", "write_with_fail"))); } @Test @@ -314,7 +321,11 @@ public void testErrorCount() throws Exception { BigQueryStorageWriteApiSchemaTransformConfiguration config = BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); - runWithError(config); + Function shouldFailRow = + (Function & Serializable) tr -> tr.get("name").equals("a"); + fakeDatasetService.setShouldFailRow(shouldFailRow); + + runWithConfig(config); PipelineResult result = p.run(); MetricResults metrics = result.metrics(); diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 5fa10d7a6883..bbeedba0021e 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -302,6 +302,18 @@ def chain_after(result): result.destination_copy_jobid_pairs <--> result['destination_copy_jobid_pairs'] ``` +Writing with Storage Write API using Cross Language +--------------------------------------------------- +This sink is able to write with BigQuery's Storage Write API. To do so, specify +the method `WriteToBigQuery.Method.STORAGE_WRITE_API`. This will use the +StorageWriteToBigQuery() transform to discover and use the Java implementation. +Using this transform directly will require the use of beam.Row() elements. + +Similar to streaming inserts, it returns two dead-letter queue PCollections: +one containing just the failed rows and the other containing failed rows and +errors. 
They can be accessed with `failed_rows` and `failed_rows_with_errors`, +respectively. See the examples above for how to do this. + *** Short introduction to BigQuery concepts *** Tables have rows (TableRow) and each row has cells (TableCell). @@ -397,10 +409,13 @@ def chain_after(result): from apache_beam.transforms import ParDo from apache_beam.transforms import PTransform from apache_beam.transforms.display import DisplayDataItem +from apache_beam.transforms.external import BeamJarExpansionService +from apache_beam.transforms.external import SchemaAwareExternalTransform from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX from apache_beam.transforms.sideinputs import get_sideinput_index from apache_beam.transforms.util import ReshufflePerKey from apache_beam.transforms.window import GlobalWindows +from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.utils import retry from apache_beam.utils.annotations import deprecated from apache_beam.utils.annotations import experimental @@ -432,6 +447,7 @@ def chain_after(result): 'BigQueryQueryPriority', 'WriteToBigQuery', 'WriteResult', + 'StorageWriteToBigQuery', 'ReadFromBigQuery', 'ReadFromBigQueryRequest', 'ReadAllFromBigQuery', @@ -1756,6 +1772,7 @@ class Method(object): DEFAULT = 'DEFAULT' STREAMING_INSERTS = 'STREAMING_INSERTS' FILE_LOADS = 'FILE_LOADS' + STORAGE_WRITE_API = 'STORAGE_WRITE_API' def __init__( self, @@ -1777,6 +1794,7 @@ def __init__( table_side_inputs=None, schema_side_inputs=None, triggering_frequency=None, + use_at_least_once=False, validate=True, temp_file_format=None, ignore_insert_ids=False, @@ -1785,7 +1803,8 @@ def __init__( with_auto_sharding=False, ignore_unknown_columns=False, load_job_project_id=None, - num_streaming_keys=DEFAULT_SHARDS_PER_DESTINATION): + num_streaming_keys=DEFAULT_SHARDS_PER_DESTINATION, + expansion_service=None): """Initialize a WriteToBigQuery transform. Args: @@ -1853,8 +1872,9 @@ def __init__( temp_location, but for pipelines whose temp_location is not appropriate for BQ File Loads, users should pass a specific one. method: The method to use to write to BigQuery. It may be - STREAMING_INSERTS, FILE_LOADS, or DEFAULT. An introduction on loading - data to BigQuery: https://cloud.google.com/bigquery/docs/loading-data. + STREAMING_INSERTS, FILE_LOADS, STORAGE_WRITE_API or DEFAULT. An + introduction on loading data to BigQuery: + https://cloud.google.com/bigquery/docs/loading-data. DEFAULT will use STREAMING_INSERTS on Streaming pipelines and FILE_LOADS on Batch pipelines. Note: FILE_LOADS currently does not support BigQuery's JSON data type: @@ -1903,6 +1923,13 @@ def __init__( triggering_frequency seconds when data is waiting. The batch can be sent earlier if it reaches the maximum batch size set by batch_size. Default value is 0.2 seconds. + + When method is STORAGE_WRITE_API: + A stream of rows will be committed every triggering_frequency seconds. + By default, this will be 5 seconds to ensure exactly-once semantics. + use_at_least_once: Intended only for STORAGE_WRITE_API. When True, will + use at-least-once semantics. This is cheaper and provides lower + latency, but will potentially duplicate records. validate: Indicates whether to perform validation checks on inputs. This parameter is primarily used for testing. temp_file_format: The format to use for file loads into BigQuery. The @@ -1920,8 +1947,8 @@ def __init__( https://cloud.google.com/bigquery/streaming-data-into-bigquery#disabling_best_effort_de-duplication with_auto_sharding: Experimental. 
If true, enables using a dynamically determined number of shards to write to BigQuery. This can be used for - both FILE_LOADS and STREAMING_INSERTS. Only applicable to unbounded - input. + all of FILE_LOADS, STREAMING_INSERTS, and STORAGE_WRITE_API. Only + applicable to unbounded input. ignore_unknown_columns: Accept rows that contain values that do not match the schema. The unknown values are ignored. Default is False, which treats unknown values as errors. This option is only valid for @@ -1932,6 +1959,9 @@ def __init__( used. num_streaming_keys: The number of shards per destination when writing via streaming inserts. + expansion_service: The address (host:port) of the expansion service. + If no expansion service is provided, will attempt to run the default + GCP expansion service. Used for STORAGE_WRITE_API method. """ self._table = table self._dataset = dataset @@ -1956,6 +1986,8 @@ def __init__( self.max_files_per_bundle = max_files_per_bundle self.method = method or WriteToBigQuery.Method.DEFAULT self.triggering_frequency = triggering_frequency + self.use_at_least_once = use_at_least_once + self.expansion_service = expansion_service self.with_auto_sharding = with_auto_sharding self.insert_retry_strategy = insert_retry_strategy self._validate = validate @@ -2038,7 +2070,7 @@ def expand(self, pcoll): failed_rows=outputs[BigQueryWriteFn.FAILED_ROWS], failed_rows_with_errors=outputs[ BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS]) - else: + elif method_to_use == WriteToBigQuery.Method.FILE_LOADS: if self._temp_file_format == bigquery_tools.FileFormat.AVRO: if self.schema == SCHEMA_AUTODETECT: raise ValueError( @@ -2102,6 +2134,57 @@ def find_in_nested_dict(schema): BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS], destination_copy_jobid_pairs=output[ BigQueryBatchFileLoads.DESTINATION_COPY_JOBID_PAIRS]) + else: + # Storage Write API + if self.schema is None: + raise AttributeError( + "A schema is required in order to prepare rows" + "for writing with STORAGE_WRITE_API.") + if callable(self.schema): + raise NotImplementedError( + "Writing to dynamic destinations is not" + "supported for this write method.") + elif isinstance(self.schema, vp.ValueProvider): + schema = self.schema.get() + else: + schema = self.schema + + table = bigquery_tools.get_hashable_destination(self.table_reference) + # None type is not supported + triggering_frequency = self.triggering_frequency or 0 + # SchemaTransform expects Beam Rows, so map to Rows first + output_beam_rows = ( + pcoll + | + beam.Map(lambda row: bigquery_tools.beam_row_from_dict(row, schema)). 
+ with_output_types( + RowTypeConstraint.from_fields( + bigquery_tools.get_beam_typehints_from_tableschema(schema))) + | StorageWriteToBigQuery( + table=table, + create_disposition=self.create_disposition, + write_disposition=self.write_disposition, + triggering_frequency=triggering_frequency, + use_at_least_once=self.use_at_least_once, + with_auto_sharding=self.with_auto_sharding, + expansion_service=self.expansion_service)) + + # return back from Beam Rows to Python dict elements + failed_rows = ( + output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS] + | beam.Map(lambda row: row.as_dict())) + failed_rows_with_errors = ( + output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS] + | beam.Map( + lambda row: { + "error_message": row.error_message, + "failed_row": row.failed_row.as_dict() + })) + + return WriteResult( + method=WriteToBigQuery.Method.STORAGE_WRITE_API, + failed_rows=failed_rows, + failed_rows_with_errors=failed_rows_with_errors) def display_data(self): res = {} @@ -2221,12 +2304,12 @@ def __init__( destination_copy_jobid_pairs, } - def validate(self, method, attribute): - if self._method != method: + def validate(self, valid_methods, attribute): + if self._method not in valid_methods: raise AttributeError( f'Cannot get {attribute} because it is not produced ' f'by the {self._method} write method. Note: only ' - f'{method} produces this attribute.') + f'{valid_methods} produces this attribute.') @property def destination_load_jobid_pairs( @@ -2238,7 +2321,8 @@ def destination_load_jobid_pairs( Raises: AttributeError: if accessed with a write method besides ``FILE_LOADS``.""" - self.validate(WriteToBigQuery.Method.FILE_LOADS, 'DESTINATION_JOBID_PAIRS') + self.validate([WriteToBigQuery.Method.FILE_LOADS], + 'DESTINATION_JOBID_PAIRS') return self._destination_load_jobid_pairs @@ -2251,7 +2335,7 @@ def destination_file_pairs(self) -> PCollection[Tuple[str, Tuple[str, int]]]: Raises: AttributeError: if accessed with a write method besides ``FILE_LOADS``.""" - self.validate(WriteToBigQuery.Method.FILE_LOADS, 'DESTINATION_FILE_PAIRS') + self.validate([WriteToBigQuery.Method.FILE_LOADS], 'DESTINATION_FILE_PAIRS') return self._destination_file_pairs @@ -2265,26 +2349,30 @@ def destination_copy_jobid_pairs( Raises: AttributeError: if accessed with a write method besides ``FILE_LOADS``.""" - self.validate( - WriteToBigQuery.Method.FILE_LOADS, 'DESTINATION_COPY_JOBID_PAIRS') + self.validate([WriteToBigQuery.Method.FILE_LOADS], + 'DESTINATION_COPY_JOBID_PAIRS') return self._destination_copy_jobid_pairs @property def failed_rows(self) -> PCollection[Tuple[str, dict]]: - """A ``STREAMING_INSERTS`` method attribute + """A ``[STREAMING_INSERTS, STORAGE_WRITE_API]`` method attribute Returns: A PCollection of rows that failed when inserting to BigQuery. 
Raises: AttributeError: if accessed with a write method - besides ``STREAMING_INSERTS``.""" - self.validate(WriteToBigQuery.Method.STREAMING_INSERTS, 'FAILED_ROWS') + besides ``[STREAMING_INSERTS, STORAGE_WRITE_API]``.""" + self.validate([ + WriteToBigQuery.Method.STREAMING_INSERTS, + WriteToBigQuery.Method.STORAGE_WRITE_API + ], + 'FAILED_ROWS') return self._failed_rows @property def failed_rows_with_errors(self) -> PCollection[Tuple[str, dict, list]]: - """A ``STREAMING_INSERTS`` method attribute + """A ``[STREAMING_INSERTS, STORAGE_WRITE_API]`` method attribute Returns: A PCollection of rows that failed when inserting to BigQuery, @@ -2292,9 +2380,12 @@ def failed_rows_with_errors(self) -> PCollection[Tuple[str, dict, list]]: Raises: AttributeError: if accessed with a write method - besides ``STREAMING_INSERTS``.""" - self.validate( - WriteToBigQuery.Method.STREAMING_INSERTS, 'FAILED_ROWS_WITH_ERRORS') + besides ``[STREAMING_INSERTS, STORAGE_WRITE_API]``.""" + self.validate([ + WriteToBigQuery.Method.STREAMING_INSERTS, + WriteToBigQuery.Method.STORAGE_WRITE_API + ], + 'FAILED_ROWS_WITH_ERRORS') return self._failed_rows_with_errors @@ -2307,6 +2398,88 @@ def __getitem__(self, key): return self.attributes[key].__get__(self, WriteResult) +def _default_io_expansion_service(append_args=None): + return BeamJarExpansionService( + 'sdks:java:io:google-cloud-platform:expansion-service:build', + append_args=append_args) + + +class StorageWriteToBigQuery(PTransform): + """Writes data to BigQuery using Storage API. + + Experimental; no backwards compatibility guarantees. + """ + URN = "beam:schematransform:org.apache.beam:bigquery_storage_write:v1" + FAILED_ROWS = "FailedRows" + FAILED_ROWS_WITH_ERRORS = "FailedRowsWithErrors" + + def __init__( + self, + table, + create_disposition=BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=BigQueryDisposition.WRITE_APPEND, + triggering_frequency=0, + use_at_least_once=False, + with_auto_sharding=False, + expansion_service=None): + """Initialize a StorageWriteToBigQuery transform. + + :param table: + Fully-qualified table ID specified as ``'PROJECT:DATASET.TABLE'``. + :param create_disposition: + String specifying the strategy to take when the table doesn't + exist. Possible values are: + * ``'CREATE_IF_NEEDED'``: create if does not exist. + * ``'CREATE_NEVER'``: fail the write if does not exist. + :param write_disposition: + String specifying the strategy to take when the table already + contains data. Possible values are: + * ``'WRITE_TRUNCATE'``: delete existing rows. + * ``'WRITE_APPEND'``: add to existing rows. + * ``'WRITE_EMPTY'``: fail the write if table not empty. + :param triggering_frequency: + The time in seconds between write commits. Should only be specified + for streaming pipelines. Defaults to 5 seconds. + :param use_at_least_once: + Use at-least-once semantics. Is cheaper and provides lower latency, + but will potentially duplicate records. + :param with_auto_sharding: + Experimental. If true, enables using a dynamically determined number of + shards to write to BigQuery. Only applicable to unbounded input. + :param expansion_service: + The address (host:port) of the expansion service. If no expansion + service is provided, will attempt to run the default GCP expansion + service. 
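+
+    Example (an illustrative sketch only; the table spec and column name below
+    are placeholders, and a reachable expansion service plus a BigQuery table
+    with a matching schema are assumed)::
+
+      import apache_beam as beam
+
+      with beam.Pipeline() as p:
+        _ = (
+            p
+            | beam.Create([beam.Row(my_col="a")])
+            | StorageWriteToBigQuery(table="my-project:my_dataset.my_table"))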
+ """ + super().__init__() + self._table = table + self._create_disposition = create_disposition + self._write_disposition = write_disposition + self._triggering_frequency = triggering_frequency + self._use_at_least_once = use_at_least_once + self._with_auto_sharding = with_auto_sharding + self._expansion_service = ( + expansion_service or _default_io_expansion_service()) + self.schematransform_config = SchemaAwareExternalTransform.discover_config( + self._expansion_service, self.URN) + + def expand(self, input): + external_storage_write = SchemaAwareExternalTransform( + identifier=self.schematransform_config.identifier, + expansion_service=self._expansion_service, + autoSharding=self._with_auto_sharding, + createDisposition=self._create_disposition, + table=self._table, + triggeringFrequencySeconds=self._triggering_frequency, + useAtLeastOnceSemantics=self._use_at_least_once, + writeDisposition=self._write_disposition, + ) + + input_tag = self.schematransform_config.inputs[0] + + return {input_tag: input} | external_storage_write + + class ReadFromBigQuery(PTransform): """Read data from BigQuery. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index edffd86e5c17..552c22f1f770 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -36,11 +36,14 @@ import time import uuid from json.decoder import JSONDecodeError +from typing import Optional +from typing import Sequence from typing import Tuple from typing import TypeVar from typing import Union import fastavro +import numpy as np import regex import apache_beam @@ -58,6 +61,7 @@ from apache_beam.options import value_provider from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.transforms import DoFn +from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.typehints import Any from apache_beam.utils import retry from apache_beam.utils.histogram import LinearBucket @@ -103,6 +107,21 @@ _DATASET_PATTERN = r'\w{1,1024}' _TABLE_PATTERN = r'[\p{L}\p{M}\p{N}\p{Pc}\p{Pd}\p{Zs}$]{1,1024}' +# TODO(https://github.com/apache/beam/issues/25946): Add support for +# more Beam portable schema types as Python types +BIGQUERY_TYPE_TO_PYTHON_TYPE = { + "STRING": str, + "BOOL": bool, + "BOOLEAN": bool, + "BYTES": bytes, + "INT64": np.int64, + "INTEGER": np.int64, + "FLOAT64": np.float64, + "FLOAT": np.float64, + "NUMERIC": decimal.Decimal, + "TIMESTAMP": apache_beam.utils.timestamp.Timestamp, +} + class FileFormat(object): CSV = 'CSV' @@ -1514,6 +1533,42 @@ def process(self, element, *side_inputs): yield (self.destination(element, *side_inputs), element) +def beam_row_from_dict(row: dict, schema): + """Converts a dictionary row to a Beam Row. + Nested records and lists are supported. + + Args: + row (dict): + The row to convert. + schema (str, dict, ~apache_beam.io.gcp.internal.clients.bigquery.\ +bigquery_v2_messages.TableSchema): + The table schema. Will be used to help convert the row. + + Returns: + ~apache_beam.pvalue.Row: The converted row. 
+ """ + if not isinstance(schema, (bigquery.TableSchema, bigquery.TableFieldSchema)): + schema = get_bq_tableschema(schema) + schema_fields = {field.name: field for field in schema.fields} + beam_row = {} + for col_name, value in row.items(): + # get this column's schema field and handle struct types + field = schema_fields[col_name] + if field.type.upper() in ["RECORD", "STRUCT"]: + # if this is a list of records, we create a list of Beam Rows + if field.mode.upper() == "REPEATED": + list_of_beam_rows = [] + for record in value: + list_of_beam_rows.append(beam_row_from_dict(record, field)) + beam_row[col_name] = list_of_beam_rows + # otherwise, create a Beam Row from this record + else: + beam_row[col_name] = beam_row_from_dict(value, field) + else: + beam_row[col_name] = value + return apache_beam.pvalue.Row(**beam_row) + + def get_table_schema_from_string(schema): """Transform the string table schema into a :class:`~apache_beam.io.gcp.internal.clients.bigquery.\ @@ -1591,6 +1646,32 @@ def get_dict_table_schema(schema): raise TypeError('Unexpected schema argument: %s.' % schema) +def get_bq_tableschema(schema): + """Convert the table schema to a TableSchema object. + + Args: + schema (str, dict, ~apache_beam.io.gcp.internal.clients.bigquery.\ +bigquery_v2_messages.TableSchema): + The schema to be used if the BigQuery table to write has to be created. + This can either be a dict or string or in the TableSchema format. + + Returns: + ~apache_beam.io.gcp.internal.clients.bigquery.\ +bigquery_v2_messages.TableSchema: The schema as a TableSchema object. + """ + if (isinstance(schema, + (bigquery.TableSchema, value_provider.ValueProvider)) or + callable(schema) or schema is None): + return schema + elif isinstance(schema, str): + return get_table_schema_from_string(schema) + elif isinstance(schema, dict): + schema_string = json.dumps(schema) + return parse_table_schema_from_json(schema_string) + else: + raise TypeError('Unexpected schema argument: %s.' % schema) + + def get_avro_schema_from_table_schema(schema): """Transform the table schema into an Avro schema. @@ -1608,6 +1689,44 @@ def get_avro_schema_from_table_schema(schema): "root", dict_table_schema) +def get_beam_typehints_from_tableschema(schema): + """Extracts Beam Python type hints from the schema. + + Args: + schema (~apache_beam.io.gcp.internal.clients.bigquery.\ +bigquery_v2_messages.TableSchema): + The TableSchema to extract type hints from. + + Returns: + List[Tuple[str, Any]]: A list of type hints that describe the input schema. + Nested and repeated fields are supported. + """ + if not isinstance(schema, (bigquery.TableSchema, bigquery.TableFieldSchema)): + schema = get_bq_tableschema(schema) + typehints = [] + for field in schema.fields: + name, field_type, mode = field.name, field.type.upper(), field.mode.upper() + + if field_type in ["STRUCT", "RECORD"]: + # Structs can be represented as Beam Rows. 
+ typehint = RowTypeConstraint.from_fields( + get_beam_typehints_from_tableschema(field)) + elif field_type in BIGQUERY_TYPE_TO_PYTHON_TYPE: + typehint = BIGQUERY_TYPE_TO_PYTHON_TYPE[field_type] + else: + raise ValueError( + f"Converting BigQuery type [{field_type}] to " + "Python Beam type is not supported.") + + if mode == "REPEATED": + typehint = Sequence[typehint] + elif mode != "REQUIRED": + typehint = Optional[typehint] + + typehints.append((name, typehint)) + return typehints + + class BigQueryJobTypes: EXPORT = 'EXPORT' COPY = 'COPY' diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index cf533265d7be..a3e39d8e18d1 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -25,9 +25,12 @@ import math import re import unittest +from typing import Optional +from typing import Sequence import fastavro import mock +import numpy as np import pytz from parameterized import parameterized @@ -38,14 +41,18 @@ from apache_beam.io.gcp.bigquery_tools import BigQueryJobTypes from apache_beam.io.gcp.bigquery_tools import JsonRowWriter from apache_beam.io.gcp.bigquery_tools import RowAsDictJsonCoder +from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict from apache_beam.io.gcp.bigquery_tools import check_schema_equal from apache_beam.io.gcp.bigquery_tools import generate_bq_job_name +from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema from apache_beam.io.gcp.bigquery_tools import parse_table_reference from apache_beam.io.gcp.bigquery_tools import parse_table_schema_from_json from apache_beam.io.gcp.internal.clients import bigquery from apache_beam.metrics import monitoring_infos from apache_beam.metrics.execution import MetricsEnvironment from apache_beam.options.value_provider import StaticValueProvider +from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.utils.timestamp import Timestamp # Protect against environments where bigquery library is not available. 
# pylint: disable=wrong-import-order, wrong-import-position @@ -784,6 +791,237 @@ def test_descriptions(self): check_schema_equal(schema1, schema2, ignore_descriptions=True)) +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') +class TestBeamRowFromDict(unittest.TestCase): + DICT_ROW = { + "str": "a", + "bool": True, + "bytes": b'a', + "int": 1, + "float": 0.1, + "numeric": decimal.Decimal("1.11"), + "timestamp": Timestamp(1000, 100) + } + + def get_schema_fields_with_mode(self, mode): + return [{ + "name": "str", "type": "STRING", "mode": mode + }, { + "name": "bool", "type": "boolean", "mode": mode + }, { + "name": "bytes", "type": "BYTES", "mode": mode + }, { + "name": "int", "type": "INTEGER", "mode": mode + }, { + "name": "float", "type": "Float", "mode": mode + }, { + "name": "numeric", "type": "NUMERIC", "mode": mode + }, { + "name": "timestamp", "type": "TIMESTAMP", "mode": mode + }] + + def test_dict_to_beam_row_all_types_required(self): + schema = {"fields": self.get_schema_fields_with_mode("REQUIRED")} + expected_beam_row = beam.Row( + str="a", + bool=True, + bytes=b'a', + int=1, + float=0.1, + numeric=decimal.Decimal("1.11"), + timestamp=Timestamp(1000, 100)) + + self.assertEqual( + expected_beam_row, beam_row_from_dict(self.DICT_ROW, schema)) + + def test_dict_to_beam_row_all_types_repeated(self): + schema = {"fields": self.get_schema_fields_with_mode("REPEATED")} + dict_row = { + "str": ["a", "b"], + "bool": [True, False], + "bytes": [b'a', b'b'], + "int": [1, 2], + "float": [0.1, 0.2], + "numeric": [decimal.Decimal("1.11"), decimal.Decimal("2.22")], + "timestamp": [Timestamp(1000, 100), Timestamp(2000, 200)] + } + + expected_beam_row = beam.Row( + str=["a", "b"], + bool=[True, False], + bytes=[b'a', b'b'], + int=[1, 2], + float=[0.1, 0.2], + numeric=[decimal.Decimal("1.11"), decimal.Decimal("2.22")], + timestamp=[Timestamp(1000, 100), Timestamp(2000, 200)]) + + self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema)) + + def test_dict_to_beam_row_all_types_nullable(self): + schema = {"fields": self.get_schema_fields_with_mode("nullable")} + dict_row = {k: None for k in self.DICT_ROW} + + expected_beam_row = beam.Row( + str=None, + bool=None, + bytes=None, + int=None, + float=None, + numeric=None, + timestamp=None) + + self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema)) + + def test_dict_to_beam_row_nested_record(self): + schema_fields_with_nested = [{ + "name": "nested_record", + "type": "record", + "fields": self.get_schema_fields_with_mode("required") + }] + schema_fields_with_nested.extend( + self.get_schema_fields_with_mode("required")) + schema = {"fields": schema_fields_with_nested} + + dict_row = { + "nested_record": self.DICT_ROW, + "str": "a", + "bool": True, + "bytes": b'a', + "int": 1, + "float": 0.1, + "numeric": decimal.Decimal("1.11"), + "timestamp": Timestamp(1000, 100) + } + expected_beam_row = beam.Row( + nested_record=beam.Row( + str="a", + bool=True, + bytes=b'a', + int=1, + float=0.1, + numeric=decimal.Decimal("1.11"), + timestamp=Timestamp(1000, 100)), + str="a", + bool=True, + bytes=b'a', + int=1, + float=0.1, + numeric=decimal.Decimal("1.11"), + timestamp=Timestamp(1000, 100)) + + self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema)) + + def test_dict_to_beam_row_repeated_nested_record(self): + schema_fields_with_repeated_nested_record = [{ + "name": "nested_repeated_record", + "type": "record", + "mode": "repeated", + "fields": 
self.get_schema_fields_with_mode("required") + }] + schema = {"fields": schema_fields_with_repeated_nested_record} + + dict_row = { + "nested_repeated_record": [self.DICT_ROW, self.DICT_ROW, self.DICT_ROW], + } + + beam_row = beam.Row( + str="a", + bool=True, + bytes=b'a', + int=1, + float=0.1, + numeric=decimal.Decimal("1.11"), + timestamp=Timestamp(1000, 100)) + expected_beam_row = beam.Row( + nested_repeated_record=[beam_row, beam_row, beam_row]) + + self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema)) + + +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') +class TestBeamTypehintFromSchema(unittest.TestCase): + EXPECTED_TYPEHINTS = [("str", str), ("bool", bool), ("bytes", bytes), + ("int", np.int64), ("float", np.float64), + ("numeric", decimal.Decimal), ("timestamp", Timestamp)] + + def get_schema_fields_with_mode(self, mode): + return [{ + "name": "str", "type": "STRING", "mode": mode + }, { + "name": "bool", "type": "boolean", "mode": mode + }, { + "name": "bytes", "type": "BYTES", "mode": mode + }, { + "name": "int", "type": "INTEGER", "mode": mode + }, { + "name": "float", "type": "Float", "mode": mode + }, { + "name": "numeric", "type": "NUMERIC", "mode": mode + }, { + "name": "timestamp", "type": "TIMESTAMP", "mode": mode + }] + + def test_typehints_from_required_schema(self): + schema = {"fields": self.get_schema_fields_with_mode("required")} + typehints = get_beam_typehints_from_tableschema(schema) + + self.assertEqual(typehints, self.EXPECTED_TYPEHINTS) + + def test_typehints_from_repeated_schema(self): + schema = {"fields": self.get_schema_fields_with_mode("repeated")} + typehints = get_beam_typehints_from_tableschema(schema) + + expected_repeated_typehints = [ + (name, Sequence[type]) for name, type in self.EXPECTED_TYPEHINTS + ] + + self.assertEqual(typehints, expected_repeated_typehints) + + def test_typehints_from_nullable_schema(self): + schema = {"fields": self.get_schema_fields_with_mode("nullable")} + typehints = get_beam_typehints_from_tableschema(schema) + + expected_nullable_typehints = [ + (name, Optional[type]) for name, type in self.EXPECTED_TYPEHINTS + ] + + self.assertEqual(typehints, expected_nullable_typehints) + + def test_typehints_from_schema_with_struct(self): + schema = { + "fields": [{ + "name": "record", + "type": "record", + "mode": "required", + "fields": self.get_schema_fields_with_mode("required") + }] + } + typehints = get_beam_typehints_from_tableschema(schema) + + expected_typehints = [ + ("record", RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS)) + ] + + self.assertEqual(typehints, expected_typehints) + + def test_typehints_from_schema_with_repeated_struct(self): + schema = { + "fields": [{ + "name": "record", + "type": "record", + "mode": "repeated", + "fields": self.get_schema_fields_with_mode("required") + }] + } + typehints = get_beam_typehints_from_tableschema(schema) + + expected_typehints = [( + "record", + Sequence[RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS)])] + + self.assertEqual(typehints, expected_typehints) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py index a307e06ac5b8..3effe945355d 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py @@ -23,6 +23,7 @@ import base64 import datetime import logging +import os import 
secrets import time import unittest @@ -32,6 +33,7 @@ import mock import pytest import pytz +from hamcrest.core import assert_that as hamcrest_assert from parameterized import param from parameterized import parameterized @@ -44,6 +46,7 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.utils.timestamp import Timestamp # Protect against environments where bigquery library is not available. # pylint: disable=wrong-import-order, wrong-import-position @@ -543,6 +546,193 @@ def test_big_query_write_temp_table_append_schema_update(self, file_format): temp_file_format=file_format)) +class BigQueryXlangStorageWriteIT(unittest.TestCase): + BIGQUERY_DATASET = 'python_xlang_storage_write' + + ELEMENTS = [ + # (int, float, numeric, string, bool, bytes, timestamp) + { + "int": 1, + "float": 0.1, + "numeric": Decimal("1.11"), + "str": "a", + "bool": True, + "bytes": b'a', + "timestamp": Timestamp(1000, 100) + }, + { + "int": 2, + "float": 0.2, + "numeric": Decimal("2.22"), + "str": "b", + "bool": False, + "bytes": b'b', + "timestamp": Timestamp(2000, 200) + }, + { + "int": 3, + "float": 0.3, + "numeric": Decimal("3.33"), + "str": "c", + "bool": True, + "bytes": b'd', + "timestamp": Timestamp(3000, 300) + }, + { + "int": 4, + "float": 0.4, + "numeric": Decimal("4.44"), + "str": "d", + "bool": False, + "bytes": b'd', + "timestamp": Timestamp(4000, 400) + } + ] + + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.args = self.test_pipeline.get_full_options_as_args() + self.project = self.test_pipeline.get_option('project') + + self.bigquery_client = BigQueryWrapper() + self.dataset_id = '%s%s%s' % ( + self.BIGQUERY_DATASET, str(int(time.time())), secrets.token_hex(3)) + self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", self.dataset_id, self.project) + if not os.environ.get('EXPANSION_PORT'): + raise ValueError("NO EXPANSION PORT") + else: + _LOGGER.info("expansion port: %s", os.environ.get('EXPANSION_PORT')) + self.expansion_service = ('localhost:%s' % os.environ.get('EXPANSION_PORT')) + + def tearDown(self): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=self.project, datasetId=self.dataset_id, deleteContents=True) + try: + _LOGGER.info( + "Deleting dataset %s in project %s", self.dataset_id, self.project) + self.bigquery_client.client.datasets.Delete(request) + except HttpError: + _LOGGER.debug( + 'Failed to clean up dataset %s in project %s', + self.dataset_id, + self.project) + + def parse_expected_data(self, expected_elements): + data = [] + for row in expected_elements: + values = list(row.values()) + for i, val in enumerate(values): + if isinstance(val, Timestamp): + # BigQuery matcher query returns a datetime.datetime object + values[i] = val.to_utc_datetime().replace( + tzinfo=datetime.timezone.utc) + data.append(tuple(values)) + + return data + + def storage_write_test(self, table_name, items, schema): + table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table_name) + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + query="SELECT * FROM %s" % '{}.{}'.format(self.dataset_id, table_name), + data=self.parse_expected_data(items)) + + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create(items) + | beam.io.WriteToBigQuery( + table=table_id, + method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API, + 
schema=schema, + expansion_service=self.expansion_service)) + hamcrest_assert(p, bq_matcher) + + @pytest.mark.uses_gcp_java_expansion_service + def test_storage_write_all_types(self): + table_name = "python_storage_write_all_types" + schema = ( + "int:INTEGER,float:FLOAT,numeric:NUMERIC,str:STRING," + "bool:BOOLEAN,bytes:BYTES,timestamp:TIMESTAMP") + self.storage_write_test(table_name, self.ELEMENTS, schema) + + @pytest.mark.uses_gcp_java_expansion_service + def test_storage_write_nested_records_and_lists(self): + table_name = "python_storage_write_nested_records_and_lists" + schema = { + "fields": [{ + "name": "repeated_int", "type": "INTEGER", "mode": "REPEATED" + }, + { + "name": "struct", + "type": "STRUCT", + "fields": [{ + "name": "nested_int", "type": "INTEGER" + }, { + "name": "nested_str", "type": "STRING" + }] + }, + { + "name": "repeated_struct", + "type": "STRUCT", + "mode": "REPEATED", + "fields": [{ + "name": "nested_numeric", "type": "NUMERIC" + }, { + "name": "nested_bytes", "type": "BYTES" + }] + }] + } + items = [{ + "repeated_int": [1, 2, 3], + "struct": { + "nested_int": 1, "nested_str": "a" + }, + "repeated_struct": [{ + "nested_numeric": Decimal("1.23"), "nested_bytes": b'a' + }, + { + "nested_numeric": Decimal("3.21"), + "nested_bytes": b'aa' + }] + }] + + self.storage_write_test(table_name, items, schema) + + @pytest.mark.uses_gcp_java_expansion_service + def test_storage_write_beam_rows(self): + table_id = '{}:{}.python_xlang_storage_write_beam_rows'.format( + self.project, self.dataset_id) + + row_elements = [ + beam.Row( + my_int=e['int'], + my_float=e['float'], + my_numeric=e['numeric'], + my_string=e['str'], + my_bool=e['bool'], + my_bytes=e['bytes'], + my_timestamp=e['timestamp']) for e in self.ELEMENTS + ] + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + query="SELECT * FROM %s" % + '{}.python_xlang_storage_write_beam_rows'.format(self.dataset_id), + data=self.parse_expected_data(self.ELEMENTS)) + + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create(row_elements) + | beam.io.StorageWriteToBigQuery( + table=table_id, expansion_service=self.expansion_service)) + hamcrest_assert(p, bq_matcher) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 3ab0b87b09ed..7884584306c5 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -367,6 +367,36 @@ def discover(expansion_service): inputs=proto_config.input_pcollection_names, outputs=proto_config.output_pcollection_names) + @staticmethod + def discover_config(expansion_service, name): + """Discover one SchemaTransform by name in the given expansion service. 
+ + :return: one SchemaTransformConfig that represents the discovered + SchemaTransform + + :raises: + ValueError: if more than one SchemaTransform is discovered, or if none + are discovered + """ + + schematransforms = SchemaAwareExternalTransform.discover(expansion_service) + matched = [] + + for st in schematransforms: + if name in st.identifier: + matched.append(st) + + if not matched: + raise ValueError( + "Did not discover any SchemaTransforms resembling the name '%s'" % + name) + elif len(matched) > 1: + raise ValueError( + "Found multiple SchemaTransforms with the name '%s':\n%s\n" % + (name, [st.identifier for st in matched])) + + return matched[0] + class JavaExternalTransform(ptransform.PTransform): """A proxy for Java-implemented external transforms. diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index f38876367c39..650e292bfb96 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -32,7 +32,9 @@ from apache_beam import Pipeline from apache_beam.coders import RowCoder from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.portability.api import beam_expansion_api_pb2 from apache_beam.portability.api import external_transforms_pb2 +from apache_beam.portability.api import schema_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.portability import expansion_service from apache_beam.runners.portability.expansion_service_test import FibTransform @@ -475,6 +477,47 @@ def test_build_payload(self): self.assertEqual(456, schema_transform_config.object_field.int_sub_field) +class SchemaAwareExternalTransformTest(unittest.TestCase): + class MockDiscoveryService: + # define context manager enter and exit functions + def __enter__(self): + return self + + def __exit__(self, unusued1, unused2, unused3): + pass + + def DiscoverSchemaTransform(self, unused_request=None): + test_config = beam_expansion_api_pb2.SchemaTransformConfig( + config_schema=schema_pb2.Schema( + fields=[ + schema_pb2.Field( + name="test_field", + type=schema_pb2.FieldType(atomic_type="STRING")) + ], + id="test-id"), + input_pcollection_names=["input"], + output_pcollection_names=["output"]) + return beam_expansion_api_pb2.DiscoverSchemaTransformResponse( + schema_transform_configs={"test_schematransform": test_config}) + + @mock.patch("apache_beam.transforms.external.ExternalTransform.service") + def test_discover_one_config(self, mock_service): + _mock = self.MockDiscoveryService() + mock_service.return_value = _mock + config = beam.SchemaAwareExternalTransform.discover_config( + "test_service", name="test_schematransform") + self.assertEqual(config.outputs[0], "output") + self.assertEqual(config.inputs[0], "input") + self.assertEqual(config.identifier, "test_schematransform") + + @mock.patch("apache_beam.transforms.external.ExternalTransform.service") + def test_discover_one_config_fails_with_no_configs_found(self, mock_service): + mock_service.return_value = self.MockDiscoveryService() + with self.assertRaises(ValueError): + beam.SchemaAwareExternalTransform.discover_config( + "test_service", name="non_existent") + + class JavaClassLookupPayloadBuilderTest(unittest.TestCase): def _verify_row(self, schema, row_payload, expected_values): row = RowCoder(schema).decode(row_payload) diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini index 214aa88ec662..6e93c5f96e7f 100644 --- a/sdks/python/pytest.ini +++ 
b/sdks/python/pytest.ini @@ -27,6 +27,7 @@ python_functions = python_files = test_*.py *_test.py *_test_py3*.py *_test_it.py markers = + uses_gcp_java_expansion_service: collect Cross Language GCP Java transforms test runs uses_java_expansion_service: collect Cross Language Java transforms test runs uses_python_expansion_service: collect Cross Language Python transforms test runs xlang_sql_expansion_service: collect for Cross Language with SQL expansion service test runs diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 20e7e345c320..900715ccc0e9 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -17,6 +17,7 @@ */ evaluationDependsOn(':runners:google-cloud-dataflow-java:worker') +evaluationDependsOn(':sdks:python:test-suites:xlang') enablePythonPerformanceTest() String pythonVersionSuffix = project.ext.pythonVersion @@ -407,3 +408,26 @@ project.tasks.register("inferencePostCommitIT") { 'tensorRTtests', ] } + + +// Create cross-language tasks for running tests against Java expansion service(s) +def dataflowProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' +def dataflowRegion = project.findProperty('dataflowRegion') ?: 'us-central1' + +project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> + createCrossLanguageUsingJavaExpansionTask( + name: taskMetadata.name, + expansionProjectPath: taskMetadata.expansionProjectPath, + collectMarker: taskMetadata.collectMarker, + startJobServer: taskMetadata.startJobServer, + cleanupJobServer: taskMetadata.cleanupJobServer, + needsSdkLocation: true, + pythonPipelineOptions: [ + "--runner=TestDataflowRunner", + "--project=${dataflowProject}", + "--region=${dataflowRegion}", + "--sdk_harness_container_image_overrides=.*java.*,gcr.io/apache-beam-testing/beam-sdk/beam_java8_sdk:latest", + ], + pytestOptions: basicPytestOpts + ) +} \ No newline at end of file diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index 1351d5b1fc5b..e23deda8fa5c 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +evaluationDependsOn(':sdks:python:test-suites:xlang') def pythonVersionSuffix = project.ext.pythonVersion.replace('.', '') def pythonContainerVersion = project.ext.pythonVersion @@ -348,3 +349,27 @@ project.tasks.register("inferencePostCommitIT") { // 'tfxInferenceTest', ] } + +// Create cross-language tasks for running tests against Java expansion service(s) +def gcpProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' + +project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> + createCrossLanguageUsingJavaExpansionTask( + name: taskMetadata.name, + expansionProjectPath: taskMetadata.expansionProjectPath, + collectMarker: taskMetadata.collectMarker, + startJobServer: taskMetadata.startJobServer, + cleanupJobServer: taskMetadata.cleanupJobServer, + numParallelTests: 1, + pythonPipelineOptions: [ + "--runner=TestDirectRunner", + "--project=${gcpProject}", + ], + pytestOptions: [ + "--capture=no", // print stdout instantly + "--timeout=4500", // timeout of whole command execution + "--color=yes", // console color + "--log-cli-level=INFO" //log level info + ] + ) +} diff --git a/sdks/python/test-suites/xlang/build.gradle b/sdks/python/test-suites/xlang/build.gradle new file mode 100644 index 000000000000..ea407ac6f3fb --- /dev/null +++ b/sdks/python/test-suites/xlang/build.gradle @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// This is a base file to set up cross language tests for different runners +import org.apache.beam.gradle.BeamModulePlugin +import static org.apache.beam.gradle.BeamModulePlugin.CrossLanguageTaskCommon +project.evaluationDependsOn(":sdks:python") + +// Set up cross language tests +def envDir = project.project(":sdks:python").envdir +def jobPort = BeamModulePlugin.getRandomPort() +def tmpDir = System.getenv("TMPDIR") ?: System.getenv("WORKSPACE") ?: "/tmp" +def pidFile = "${tmpDir}/local_job_service_main-${jobPort}.pid" + +def setupTask = project.tasks.register("fnApiJobServerSetup", Exec) { + dependsOn ':sdks:python:installGcpTest' + + executable 'sh' + args '-c', ". ${envDir}/bin/activate && python -m apache_beam.runners.portability.local_job_service_main --job_port ${jobPort} --pid_file ${pidFile} --background --stdout_file ${tmpDir}/beam-fnapi-job-server.log" +} + +def cleanupTask = project.tasks.register("fnApiJobServerCleanup", Exec) { + executable 'sh' + args '-c', ". ${envDir}/bin/activate && python -m apache_beam.runners.portability.local_job_service_main --pid_file ${pidFile} --stop" +} + +// List of objects representing task metadata to create cross-language tasks from. +// Each object contains the minimum relevant metadata. 
+def xlangTasks = [] + +// ******** Java GCP expansion service ******** +// Note: this only runs cross-language tests that use the Java GCP expansion service +// To run tests that use another expansion service, create a new CrossLanguageTaskCommon with the +// relevant fields as done here, then add it to `xlangTasks`. +def gcpExpansionProject = project.project(':sdks:java:io:google-cloud-platform:expansion-service') +// Properties that are common across runners. +// Used to launch the expansion service, collect the right tests, and cleanup afterwards +def gcpXlangCommon = new CrossLanguageTaskCommon().tap { + name = "gcpCrossLanguage" + expansionProjectPath = gcpExpansionProject.getPath() + collectMarker = "uses_gcp_java_expansion_service" + startJobServer = setupTask + cleanupJobServer = cleanupTask +} +xlangTasks.add(gcpXlangCommon) + + +ext.xlangTasks = xlangTasks \ No newline at end of file diff --git a/settings.gradle.kts b/settings.gradle.kts index b2f8c89680d1..4800313f26a5 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -259,6 +259,7 @@ include(":sdks:python:test-suites:tox:py37") include(":sdks:python:test-suites:tox:py38") include(":sdks:python:test-suites:tox:py39") include(":sdks:python:test-suites:tox:py310") +include(":sdks:python:test-suites:xlang") include(":sdks:typescript") include(":sdks:typescript:container") include(":vendor:bytebuddy-1_12_8")
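
Usage note (not part of the patch): the integration tests added above exercise the new cross-language Storage Write API path by passing method=STORAGE_WRITE_API plus an expansion_service address to WriteToBigQuery, and the new SchemaAwareExternalTransform.discover_config helper looks up a single SchemaTransform by name. Below is a minimal sketch of the same calls from user code; the project, dataset, table, schema string, expansion-service port, and the 'bigquery_storage_write' transform name are placeholders (the ITs read the port from the EXPANSION_PORT environment variable rather than hard-coding it).

    import apache_beam as beam

    # Assumed: a Java GCP expansion service is already running on this port.
    # In the integration tests the port comes from EXPANSION_PORT.
    expansion_service = 'localhost:8091'

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([{'int': 1, 'str': 'a'}])
            | beam.io.WriteToBigQuery(
                table='my-project:my_dataset.my_table',  # placeholder table spec
                schema='int:INTEGER,str:STRING',         # simple string schema
                method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API,
                expansion_service=expansion_service))

    # The SchemaTransform behind the expansion can also be inspected with the new
    # discover_config helper; the name used here is illustrative, not the real
    # identifier registered by the Java GCP expansion service.
    config = beam.SchemaAwareExternalTransform.discover_config(
        expansion_service, name='bigquery_storage_write')
    print(config.identifier, config.inputs, config.outputs)

As in the new tests, beam.Row inputs can instead be written through beam.io.StorageWriteToBigQuery(table=..., expansion_service=...) when no explicit schema string is supplied.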