diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java
index 6ee98eb0ddfa..e44617930119 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java
@@ -33,6 +33,7 @@
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
 import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
 import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration;
@@ -94,7 +95,7 @@ protected SchemaTransform from(
 
   @Override
   public String identifier() {
-    return String.format("beam:schematransform:org.apache.beam:bigquery_storage_write:v1");
+    return String.format("beam:schematransform:org.apache.beam:bigquery_storage_write:v2");
   }
 
   @Override
@@ -125,6 +126,24 @@ public abstract static class BigQueryStorageWriteApiSchemaTransformConfiguration
             .put(WriteDisposition.WRITE_APPEND.name(), WriteDisposition.WRITE_APPEND)
             .build();
 
+    @AutoValue
+    public abstract static class ErrorHandling {
+      @SchemaFieldDescription("The name of the output PCollection containing failed writes.")
+      public abstract String getOutput();
+
+      public static Builder builder() {
+        return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration_ErrorHandling
+            .Builder();
+      }
+
+      @AutoValue.Builder
+      public abstract static class Builder {
+        public abstract Builder setOutput(String output);
+
+        public abstract ErrorHandling build();
+      }
+    }
+
     public void validate() {
       String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: ";
 
@@ -151,6 +170,12 @@ public void validate() {
             this.getWriteDisposition(),
             WRITE_DISPOSITIONS.keySet());
       }
+
+      if (this.getErrorHandling() != null) {
+        checkArgument(
+            !Strings.isNullOrEmpty(this.getErrorHandling().getOutput()),
+            invalidConfigMessage + "Output must not be empty if error handling specified.");
+      }
     }
 
     /**
@@ -198,6 +223,10 @@ public static Builder builder() {
     @Nullable
     public abstract Boolean getAutoSharding();
 
+    @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
+    @Nullable
+    public abstract ErrorHandling getErrorHandling();
+
     /** Builder for {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. */
     @AutoValue.Builder
     public abstract static class Builder {
@@ -214,6 +243,8 @@ public abstract static class Builder {
 
       public abstract Builder setAutoSharding(Boolean autoSharding);
 
+      public abstract Builder setErrorHandling(ErrorHandling errorHandling);
+
       /** Builds a {@link BigQueryStorageWriteApiSchemaTransformConfiguration} instance. */
       public abstract BigQueryStorageWriteApiSchemaTransformProvider
               .BigQueryStorageWriteApiSchemaTransformConfiguration
@@ -244,7 +275,7 @@ public void setBigQueryServices(BigQueryServices testBigQueryServices) {
 
     // A generic counter for PCollection of Row. Will be initialized with the given
    // name argument. Performs element-wise counter of the input PCollection.
-    private static class ElementCounterFn extends DoFn<Row, Row> {
+    private static class ElementCounterFn<T> extends DoFn<T, T> {
       private Counter bqGenericElementCounter;
       private Long elementsInBundle = 0L;
 
@@ -267,6 +298,18 @@ public void finish(FinishBundleContext c) {
       }
     }
 
+    private static class FailOnError extends DoFn<BigQueryStorageApiInsertError, Void> {
+      @ProcessElement
+      public void process(ProcessContext c) {
+        throw new RuntimeException(c.element().getErrorMessage());
+      }
+    }
+
+    private static class NoOutputDoFn<T> extends DoFn<T, Row> {
+      @ProcessElement
+      public void process(ProcessContext c) {}
+    }
+
     @Override
     public PCollectionRowTuple expand(PCollectionRowTuple input) {
       // Check that the input exists
@@ -294,53 +337,55 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
       WriteResult result =
           inputRows
              .apply(
-                  "element-count", ParDo.of(new ElementCounterFn("BigQuery-write-element-counter")))
+                  "element-count",
+                  ParDo.of(new ElementCounterFn<Row>("BigQuery-write-element-counter")))
              .setRowSchema(inputSchema)
              .apply(write);
 
-      Schema rowSchema = inputRows.getSchema();
-      Schema errorSchema =
-          Schema.of(
-              Field.of("failed_row", FieldType.row(rowSchema)),
-              Field.of("error_message", FieldType.STRING));
-
-      // Failed rows
-      PCollection<Row> failedRows =
-          result
-              .getFailedStorageApiInserts()
-              .apply(
-                  "Construct failed rows",
-                  MapElements.into(TypeDescriptors.rows())
-                      .via(
-                          (storageError) ->
-                              BigQueryUtils.toBeamRow(rowSchema, storageError.getRow())))
-              .setRowSchema(rowSchema);
-
-      // Failed rows with error message
-      PCollection<Row> failedRowsWithErrors =
+      // Give something that can be followed.
+      PCollection<Row> postWrite =
           result
               .getFailedStorageApiInserts()
-              .apply(
-                  "Construct failed rows and errors",
-                  MapElements.into(TypeDescriptors.rows())
-                      .via(
-                          (storageError) ->
-                              Row.withSchema(errorSchema)
-                                  .withFieldValue("error_message", storageError.getErrorMessage())
-                                  .withFieldValue(
-                                      "failed_row",
-                                      BigQueryUtils.toBeamRow(rowSchema, storageError.getRow()))
-                                  .build()))
-              .setRowSchema(errorSchema);
-
-      PCollection<Row> failedRowsOutput =
-          failedRows
-              .apply("error-count", ParDo.of(new ElementCounterFn("BigQuery-write-error-counter")))
-              .setRowSchema(rowSchema);
-
-      return PCollectionRowTuple.of(FAILED_ROWS_TAG, failedRowsOutput)
-          .and(FAILED_ROWS_WITH_ERRORS_TAG, failedRowsWithErrors)
-          .and("errors", failedRowsWithErrors);
+              .apply("post-write", ParDo.of(new NoOutputDoFn<BigQueryStorageApiInsertError>()))
+              .setRowSchema(Schema.of());
+
+      if (configuration.getErrorHandling() == null) {
+        result
+            .getFailedStorageApiInserts()
+            .apply("Error on failed inserts", ParDo.of(new FailOnError()));
+        return PCollectionRowTuple.of("post_write", postWrite);
+      } else {
+        result
+            .getFailedStorageApiInserts()
+            .apply(
+                "error-count",
+                ParDo.of(
+                    new ElementCounterFn<BigQueryStorageApiInsertError>(
+                        "BigQuery-write-error-counter")));
+
+        // Failed rows with error message
+        Schema errorSchema =
+            Schema.of(
+                Field.of("failed_row", FieldType.row(inputSchema)),
+                Field.of("error_message", FieldType.STRING));
+        PCollection<Row> failedRowsWithErrors =
+            result
+                .getFailedStorageApiInserts()
+                .apply(
+                    "Construct failed rows and errors",
+                    MapElements.into(TypeDescriptors.rows())
+                        .via(
+                            (storageError) ->
+                                Row.withSchema(errorSchema)
+                                    .withFieldValue("error_message", storageError.getErrorMessage())
+                                    .withFieldValue(
+                                        "failed_row",
+                                        BigQueryUtils.toBeamRow(inputSchema, storageError.getRow()))
+                                    .build()))
+                .setRowSchema(errorSchema);
+        return PCollectionRowTuple.of("post_write", postWrite)
+            .and(configuration.getErrorHandling().getOutput(), failedRowsWithErrors);
+      }
     }
 
     BigQueryIO.Write createStorageWriteApiTransform() {
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java
index df085bcedec3..54c636bde5fe 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java
@@ -48,9 +48,11 @@
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptors;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -211,7 +213,13 @@ public void testInputElementCount() throws Exception {
   public void testFailedRows() throws Exception {
     String tableSpec = "project:dataset.write_with_fail";
     BigQueryStorageWriteApiSchemaTransformConfiguration config =
-        BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build();
+        BigQueryStorageWriteApiSchemaTransformConfiguration.builder()
+
.setTable(tableSpec) + .setErrorHandling( + BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() + .setOutput("FailedRows") + .build()) + .build(); String failValue = "fail_me"; @@ -234,7 +242,15 @@ public void testFailedRows() throws Exception { fakeDatasetService.setShouldFailRow(shouldFailRow); PCollectionRowTuple result = runWithConfig(config, totalRows); - PCollection failedRows = result.get("FailedRows"); + PCollection failedRows = + result + .get("FailedRows") + .apply( + "ExtractFailedRows", + MapElements.into(TypeDescriptors.rows()) + .via((rowAndError) -> rowAndError.getValue("failed_row"))) + .setRowSchema(SCHEMA); + ; PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows); p.run().waitUntilFinish(); @@ -250,7 +266,13 @@ public void testFailedRows() throws Exception { public void testErrorCount() throws Exception { String tableSpec = "project:dataset.error_count"; BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + .setTable(tableSpec) + .setErrorHandling( + BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() + .setOutput("FailedRows") + .build()) + .build(); Function shouldFailRow = (Function & Serializable) tr -> tr.get("name").equals("a"); diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index afdd94740e9f..e092ad069ad0 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -417,6 +417,7 @@ def chain_after(result): from apache_beam.transforms.util import ReshufflePerKey from apache_beam.transforms.window import GlobalWindows from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.utils import retry from apache_beam.utils.annotations import deprecated @@ -2148,6 +2149,7 @@ def expand(self, pcoll): failed_rows=outputs[BigQueryWriteFn.FAILED_ROWS], failed_rows_with_errors=outputs[ BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS]) + elif method_to_use == WriteToBigQuery.Method.FILE_LOADS: if self._temp_file_format == bigquery_tools.FileFormat.AVRO: if self.schema == SCHEMA_AUTODETECT: @@ -2212,33 +2214,45 @@ def find_in_nested_dict(schema): BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS], destination_copy_jobid_pairs=output[ BigQueryBatchFileLoads.DESTINATION_COPY_JOBID_PAIRS]) - else: - # Storage Write API + + elif method_to_use == WriteToBigQuery.Method.STORAGE_WRITE_API: if self.schema is None: - raise AttributeError( - "A schema is required in order to prepare rows" - "for writing with STORAGE_WRITE_API.") - if callable(self.schema): + try: + schema = schema_from_element_type(pcoll.element_type) + is_rows = True + except TypeError as exn: + raise ValueError( + "A schema is required in order to prepare rows" + "for writing with STORAGE_WRITE_API.") from exn + elif callable(self.schema): raise NotImplementedError( "Writing to dynamic destinations is not" "supported for this write method.") elif isinstance(self.schema, vp.ValueProvider): schema = self.schema.get() + is_rows = False else: schema = self.schema + is_rows = False table = bigquery_tools.get_hashable_destination(self.table_reference) # None type is not supported triggering_frequency = self.triggering_frequency or 0 # SchemaTransform expects Beam Rows, so map to Rows first + if is_rows: + input_beam_rows = pcoll + 
else: + input_beam_rows = ( + pcoll + | "Convert dict to Beam Row" >> beam.Map( + lambda row: bigquery_tools.beam_row_from_dict(row, schema) + ).with_output_types( + RowTypeConstraint.from_fields( + bigquery_tools.get_beam_typehints_from_tableschema(schema))) + ) output_beam_rows = ( - pcoll - | "Convert dict to Beam Row" >> - beam.Map(lambda row: bigquery_tools.beam_row_from_dict(row, schema)). - with_output_types( - RowTypeConstraint.from_fields( - bigquery_tools.get_beam_typehints_from_tableschema(schema))) - | "StorageWriteToBigQuery" >> StorageWriteToBigQuery( + input_beam_rows + | StorageWriteToBigQuery( table=table, create_disposition=self.create_disposition, write_disposition=self.write_disposition, @@ -2247,23 +2261,31 @@ def find_in_nested_dict(schema): with_auto_sharding=self.with_auto_sharding, expansion_service=self.expansion_service)) - # return back from Beam Rows to Python dict elements - failed_rows = ( - output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS] - | beam.Map(lambda row: row.as_dict())) - failed_rows_with_errors = ( - output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS] - | beam.Map( - lambda row: { - "error_message": row.error_message, - "failed_row": row.failed_row.as_dict() - })) + if is_rows: + failed_rows = output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS] + failed_rows_with_errors = output_beam_rows[ + StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS] + else: + # return back from Beam Rows to Python dict elements + failed_rows = ( + output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS] + | beam.Map(lambda row: row.as_dict())) + failed_rows_with_errors = ( + output_beam_rows[StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS] + | beam.Map( + lambda row: { + "error_message": row.error_message, + "failed_row": row.failed_row.as_dict() + })) return WriteResult( method=WriteToBigQuery.Method.STORAGE_WRITE_API, failed_rows=failed_rows, failed_rows_with_errors=failed_rows_with_errors) + else: + raise ValueError(f"Unsupported method {method_to_use}") + def display_data(self): res = {} if self.table_reference is not None and isinstance(self.table_reference, @@ -2487,7 +2509,7 @@ class StorageWriteToBigQuery(PTransform): Experimental; no backwards compatibility guarantees. 
""" - URN = "beam:schematransform:org.apache.beam:bigquery_storage_write:v1" + URN = "beam:schematransform:org.apache.beam:bigquery_storage_write:v2" FAILED_ROWS = "FailedRows" FAILED_ROWS_WITH_ERRORS = "FailedRowsWithErrors" @@ -2552,11 +2574,17 @@ def expand(self, input): triggeringFrequencySeconds=self._triggering_frequency, useAtLeastOnceSemantics=self._use_at_least_once, writeDisposition=self._write_disposition, - ) + errorHandling={ + 'output': StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS + }) input_tag = self.schematransform_config.inputs[0] - return {input_tag: input} | external_storage_write + result = {input_tag: input} | external_storage_write + result[StorageWriteToBigQuery.FAILED_ROWS] = result[ + StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS] | beam.Map( + lambda row_and_error: row_and_error[0]) + return result class ReadFromBigQuery(PTransform): @@ -2791,14 +2819,14 @@ def _expand_direct_read(self, pcoll): else: project_id = pcoll.pipeline.options.view_as(GoogleCloudOptions).project + pipeline_details = {} + if temp_table_ref is not None: + pipeline_details['temp_table_ref'] = temp_table_ref + elif project_id is not None: + pipeline_details['project_id'] = project_id + pipeline_details['bigquery_dataset_labels'] = self.bigquery_dataset_labels + def _get_pipeline_details(unused_elm): - pipeline_details = {} - if temp_table_ref is not None: - pipeline_details['temp_table_ref'] = temp_table_ref - elif project_id is not None: - pipeline_details['project_id'] = project_id - pipeline_details[ - 'bigquery_dataset_labels'] = self.bigquery_dataset_labels return pipeline_details project_to_cleanup_pcoll = beam.pvalue.AsList( diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index e0944d81d1b2..18b422ed27d4 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -1330,7 +1330,10 @@ def transform_to_runner_api( # Iterate over inputs and outputs by sorted key order, so that ids are # consistently generated for multiple runs of the same pipeline. 
- transform_spec = transform_to_runner_api(self.transform, context) + try: + transform_spec = transform_to_runner_api(self.transform, context) + except Exception as exn: + raise RuntimeError(f'Unable to translate {self.full_label}') from exn environment_id = self.environment_id transform_urn = transform_spec.urn if transform_spec else None if (not environment_id and diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 4ddf0b3e64a3..4b8e708bfc5c 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -194,6 +194,56 @@ def build(self): return payload +class ExplicitSchemaTransformPayloadBuilder(PayloadBuilder): + def __init__(self, identifier, schema_proto, **kwargs): + self._identifier = identifier + self._schema_proto = schema_proto + self._kwargs = kwargs + + def build(self): + def dict_to_row_recursive(field_type, py_value): + if py_value is None: + return None + type_info = field_type.WhichOneof('type_info') + if type_info == 'row_type': + return dict_to_row(field_type.row_type.schema, py_value) + elif type_info == 'array_type': + return [ + dict_to_row_recursive(field_type.array_type.element_type, value) + for value in py_value + ] + elif type_info == 'map_type': + return { + key: dict_to_row_recursive(field_type.map_type.value_type, value) + for key, + value in py_value.items() + } + else: + return py_value + + def dict_to_row(schema_proto, py_value): + row_type = named_tuple_from_schema(schema_proto) + if isinstance(py_value, dict): + extra = set(py_value.keys()) - set(row_type._fields) + if extra: + raise ValueError( + f"Unknown fields: {extra}. Valid fields: {row_type._fields}") + return row_type( + *[ + dict_to_row_recursive( + field.type, py_value.get(field.name, None)) + for field in schema_proto.fields + ]) + else: + return row_type(py_value) + + return external_transforms_pb2.SchemaTransformPayload( + identifier=self._identifier, + configuration_schema=self._schema_proto, + configuration_row=RowCoder(self._schema_proto).encode( + dict_to_row(self._schema_proto, self._kwargs))) + + class JavaClassLookupPayloadBuilder(PayloadBuilder): """ Builds a payload for directly instantiating a Java transform using a @@ -351,37 +401,16 @@ def __init__( _kwargs = kwargs if rearrange_based_on_discovery: - _kwargs = self._rearrange_kwargs(identifier) - - self._payload_builder = SchemaTransformPayloadBuilder(identifier, **_kwargs) + config = SchemaAwareExternalTransform.discover_config( + self._expansion_service, identifier) + self._payload_builder = ExplicitSchemaTransformPayloadBuilder( + identifier, + named_tuple_to_schema(config.configuration_schema), + **_kwargs) - def _rearrange_kwargs(self, identifier): - # discover and fetch the external SchemaTransform configuration then - # use it to build an appropriate payload - schematransform_config = SchemaAwareExternalTransform.discover_config( - self._expansion_service, identifier) - - external_config_fields = schematransform_config.configuration_schema._fields - ordered_kwargs = OrderedDict() - missing_fields = [] - - for field in external_config_fields: - if field not in self._kwargs: - missing_fields.append(field) - else: - ordered_kwargs[field] = self._kwargs[field] - - extra_fields = list(set(self._kwargs.keys()) - set(external_config_fields)) - if missing_fields: - raise ValueError( - 'Input parameters are missing the following SchemaTransform config ' - 'fields: %s' % missing_fields) - elif extra_fields: - raise 
ValueError( - 'Input parameters include the following extra fields that are not ' - 'found in the SchemaTransform config schema: %s' % extra_fields) - - return ordered_kwargs + else: + self._payload_builder = SchemaTransformPayloadBuilder( + identifier, **_kwargs) def expand(self, pcolls): # Expand the transform using the expansion service. @@ -390,14 +419,18 @@ def expand(self, pcolls): self._payload_builder, self._expansion_service) - @staticmethod - def discover(expansion_service): + @classmethod + @functools.lru_cache + def discover(cls, expansion_service, ignore_errors=False): """Discover all SchemaTransforms available to the given expansion service. :return: a list of SchemaTransformsConfigs that represent the discovered SchemaTransforms. """ + return list(cls.discover_iter(expansion_service, ignore_errors)) + @staticmethod + def discover_iter(expansion_service, ignore_errors=True): with ExternalTransform.service(expansion_service) as service: discover_response = service.DiscoverSchemaTransform( beam_expansion_api_pb2.DiscoverSchemaTransformRequest()) @@ -406,8 +439,12 @@ def discover(expansion_service): proto_config = discover_response.schema_transform_configs[identifier] try: schema = named_tuple_from_schema(proto_config.config_schema) - except ValueError: - continue + except Exception as exn: + if ignore_errors: + logging.info("Bad schema for %s: %s", identifier, str(exn)[:250]) + continue + else: + raise yield SchemaTransformsConfig( identifier=identifier, @@ -427,7 +464,8 @@ def discover_config(expansion_service, name): are discovered """ - schematransforms = SchemaAwareExternalTransform.discover(expansion_service) + schematransforms = SchemaAwareExternalTransform.discover( + expansion_service, ignore_errors=True) matched = [] for st in schematransforms: diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index fd5c81de6596..b0e39b7d9a5b 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -528,15 +528,19 @@ def test_rearrange_kwargs_based_on_discovery(self, mock_service): kwargs = {"int_field": 0, "str_field": "str"} transform = beam.SchemaAwareExternalTransform( - identifier=identifier, expansion_service=expansion_service, **kwargs) - ordered_kwargs = transform._rearrange_kwargs(identifier) + identifier=identifier, + expansion_service=expansion_service, + rearrange_based_on_discovery=True, + **kwargs) + payload = transform._payload_builder.build() + ordered_fields = [f.name for f in payload.configuration_schema.fields] schematransform_config = beam.SchemaAwareExternalTransform.discover_config( expansion_service, identifier) external_config_fields = schematransform_config.configuration_schema._fields self.assertNotEqual(tuple(kwargs.keys()), external_config_fields) - self.assertEqual(tuple(ordered_kwargs.keys()), external_config_fields) + self.assertEqual(tuple(ordered_fields), external_config_fields) class JavaClassLookupPayloadBuilderTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/typehints/schema_registry.py b/sdks/python/apache_beam/typehints/schema_registry.py index 9ec7b1b65ccf..a73e97f43f70 100644 --- a/sdks/python/apache_beam/typehints/schema_registry.py +++ b/sdks/python/apache_beam/typehints/schema_registry.py @@ -40,13 +40,18 @@ def generate_new_id(self): "schemas.") def add(self, typing, schema): - self.by_id[schema.id] = (typing, schema) + if not schema.id: + self.by_id[schema.id] = (typing, schema) def 
get_typing_by_id(self, unique_id): + if not unique_id: + return None result = self.by_id.get(unique_id, None) return result[0] if result is not None else None def get_schema_by_id(self, unique_id): + if not unique_id: + return None result = self.by_id.get(unique_id, None) return result[1] if result is not None else None diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py index 730d0f0ad0d1..eb0695f337b4 100644 --- a/sdks/python/apache_beam/yaml/main.py +++ b/sdks/python/apache_beam/yaml/main.py @@ -20,8 +20,13 @@ import yaml import apache_beam as beam +from apache_beam.typehints.schemas import LogicalType +from apache_beam.typehints.schemas import MillisInstant from apache_beam.yaml import yaml_transform +# Workaround for https://github.com/apache/beam/issues/28151. +LogicalType.register_logical_type(MillisInstant) + def _configure_parser(argv): parser = argparse.ArgumentParser() @@ -57,11 +62,14 @@ def run(argv=None): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pipeline_args, pickle_library='cloudpickle', - **pipeline_spec.get('options', {}))) as p: + **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( + 'options', {})))) as p: print("Building pipeline...") yaml_transform.expand_pipeline(p, pipeline_spec) print("Running pipeline...") if __name__ == '__main__': + import logging + logging.getLogger().setLevel(logging.INFO) run() diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml new file mode 100644 index 000000000000..ea21446c2fb9 --- /dev/null +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This file enumerates the various IOs that are available by default as +# top-level transforms in Beam's YAML. +# +# Note that there may be redundant implementations. In these cases the specs +# should be kept in sync. +# TODO(yaml): See if this can be enforced programmatically. 
+ +- type: renaming + transforms: + 'ReadFromBigQuery': 'ReadFromBigQuery' + 'WriteToBigQuery': 'WriteToBigQuery' + config: + mappings: + 'ReadFromBigQuery': + query: 'query' + table: 'tableSpec' + fields: 'selectedFields' + row_restriction: 'rowRestriction' + 'WriteToBigQuery': + table: 'table' + create_disposition: 'createDisposition' + write_disposition: 'writeDisposition' + error_handling: 'errorHandling' + underlying_provider: + type: beamJar + transforms: + 'ReadFromBigQuery': 'beam:schematransform:org.apache.beam:bigquery_storage_read:v1' + 'WriteToBigQuery': 'beam:schematransform:org.apache.beam:bigquery_storage_write:v2' + config: + gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' + +- type: python + transforms: + 'ReadFromBigQuery': 'apache_beam.yaml.yaml_io.read_from_bigquery' + # Disable until https://github.com/apache/beam/issues/28162 is resolved. + # 'WriteToBigQuery': 'apache_beam.yaml.yaml_io.write_to_bigquery' diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py new file mode 100644 index 000000000000..646d5e1fbff1 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -0,0 +1,116 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This module contains the Python implementations for the builtin IOs. + +They are referenced from standard_io.py. + +Note that in the case that they overlap with other (likely Java) +implementations of the same transforms, the configs must be kept in sync. +""" + +import os + +import yaml + +import apache_beam as beam +from apache_beam.io import ReadFromBigQuery +from apache_beam.io import WriteToBigQuery +from apache_beam.io.gcp.bigquery import BigQueryDisposition +from apache_beam.yaml import yaml_provider + + +def read_from_bigquery( + query=None, table=None, row_restriction=None, fields=None): + if query is None: + assert table is not None + else: + assert table is None and row_restriction is None and fields is None + return ReadFromBigQuery( + query=query, + table=table, + row_restriction=row_restriction, + selected_fields=fields, + method='DIRECT_READ', + output_type='BEAM_ROW') + + +def write_to_bigquery( + table, + *, + create_disposition=BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=BigQueryDisposition.WRITE_APPEND, + error_handling=None): + class WriteToBigQueryHandlingErrors(beam.PTransform): + def default_label(self): + return 'WriteToBigQuery' + + def expand(self, pcoll): + write_result = pcoll | WriteToBigQuery( + table, + method=WriteToBigQuery.Method.STORAGE_WRITE_API + if error_handling else None, + create_disposition=create_disposition, + write_disposition=write_disposition, + temp_file_format='AVRO') + if error_handling and 'output' in error_handling: + # TODO: Support error rates. 
+ return { + 'post_write': write_result.failed_rows_with_errors + | beam.FlatMap(lambda x: None), + error_handling['output']: write_result.failed_rows_with_errors + } + else: + if write_result._method == WriteToBigQuery.Method.FILE_LOADS: + # Never returns errors, just fails. + return { + 'post_write': write_result.destination_load_jobid_pairs + | beam.FlatMap(lambda x: None) + } + else: + + # This should likely be pushed into the BQ read itself to avoid + # the possibility of silently ignoring errors. + def raise_exception(failed_row_with_error): + raise RuntimeError(failed_row_with_error.error_message) + + _ = write_result.failed_rows_with_errors | beam.Map(raise_exception) + return { + 'post_write': write_result.failed_rows_with_errors + | beam.FlatMap(lambda x: None) + } + + return WriteToBigQueryHandlingErrors() + + +def io_providers(): + with open(os.path.join(os.path.dirname(__file__), 'standard_io.yaml')) as fin: + explicit_ios = yaml_provider.parse_providers( + yaml.load(fin, Loader=yaml.SafeLoader)) + + # TOOD(yaml): We should make all top-level IOs explicit. + # This will be a chance to clean up the APIs and align them with their + # Java implementations. + # PythonTransform can be used to get the "raw" transforms for any others. + implicit_ios = yaml_provider.InlineProvider({ + key: getattr(beam.io, key) + for key in dir(beam.io) + if (key.startswith('ReadFrom') or key.startswith('WriteTo')) and + key not in explicit_ios + }) + + return yaml_provider.merge_providers(explicit_ios, implicit_ios) diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index aa5aa7183318..f7a237bc045d 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -72,6 +72,12 @@ def create_transform( """ raise NotImplementedError(type(self)) + def underlying_provider(self): + """If this provider is simply a proxy to another provider, return the + provider that should actually be used for affinity checking. + """ + return self + def affinity(self, other: "Provider"): """Returns a value approximating how good it would be for this provider to be used immediately following a transform from the other provider @@ -81,7 +87,9 @@ def affinity(self, other: "Provider"): # E.g. we could look at the the expected environments themselves. # Possibly, we could provide multiple expansions and have the runner itself # choose the actual implementation based on fusion (and other) criteria. - return self._affinity(other) + other._affinity(self) + return ( + self.underlying_provider()._affinity(other) + + other.underlying_provider()._affinity(self)) def _affinity(self, other: "Provider"): if self is other or self == other: @@ -122,16 +130,18 @@ def create_transform(self, type, args, yaml_create_transform): self._service = self._service() if self._schema_transforms is None: try: - self._schema_transforms = [ - config.identifier + self._schema_transforms = { + config.identifier: config for config in external.SchemaAwareExternalTransform.discover( - self._service) - ] + self._service, ignore_errors=True) + } except Exception: - self._schema_transforms = [] + # It's possible this service doesn't vend schema transforms. 
+ self._schema_transforms = {} urn = self._urns[type] if urn in self._schema_transforms: - return external.SchemaAwareExternalTransform(urn, self._service, **args) + return external.SchemaAwareExternalTransform( + urn, self._service, rearrange_based_on_discovery=True, **args) else: return type >> self.create_external_transform(urn, args) @@ -247,6 +257,19 @@ def available(self): capture_output=True).returncode == 0 +@ExternalProvider.register_provider_type('python') +def python(urns, packages=()): + if packages: + return ExternalPythonProvider(urns, packages) + else: + return InlineProvider({ + name: + python_callable.PythonCallableWithSource.load_from_fully_qualified_name( + constructor) + for (name, constructor) in urns.items() + }) + + @ExternalProvider.register_provider_type('pythonPackage') class ExternalPythonProvider(ExternalProvider): def __init__(self, urns, packages): @@ -429,12 +452,6 @@ def _parse_window_spec(spec): # TODO: Triggering, etc. return beam.WindowInto(window_fn) - ios = { - key: getattr(apache_beam.io, key) - for key in dir(apache_beam.io) - if key.startswith('ReadFrom') or key.startswith('WriteTo') - } - return InlineProvider( dict({ 'Create': lambda elements, @@ -459,8 +476,7 @@ def _parse_window_spec(spec): 'Flatten': Flatten, 'WindowInto': WindowInto, 'GroupByKey': beam.GroupByKey, - }, - **ios)) + })) class PypiExpansionService: @@ -512,6 +528,50 @@ def __exit__(self, *args): self._service = None +@ExternalProvider.register_provider_type('renaming') +class RenamingProvider(Provider): + def __init__(self, transforms, mappings, underlying_provider): + if isinstance(underlying_provider, dict): + underlying_provider = ExternalProvider.provider_from_spec( + underlying_provider) + self._transforms = transforms + self._underlying_provider = underlying_provider + for transform in transforms.keys(): + if transform not in mappings: + raise ValueError(f'Missing transform {transform} in mappings.') + self._mappings = mappings + + def available(self) -> bool: + return self._underlying_provider.available() + + def provided_transforms(self) -> Iterable[str]: + return self._transforms.keys() + + def create_transform( + self, + typ: str, + args: Mapping[str, Any], + yaml_create_transform: Callable[ + [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform] + ) -> beam.PTransform: + """Creates a PTransform instance for the given transform type and arguments. 
+ """ + mappings = self._mappings[typ] + remapped_args = { + mappings.get(key, key): value + for key, value in args.items() + } + return self._underlying_provider.create_transform( + self._transforms[typ], remapped_args, yaml_create_transform) + + def _affinity(self, other): + raise NotImplementedError( + 'Should not be calling _affinity directly on this provider.') + + def underlying_provider(self): + return self._underlying_provider.underlying_provider() + + def parse_providers(provider_specs): providers = collections.defaultdict(list) for provider_spec in provider_specs: @@ -539,10 +599,12 @@ def merge_providers(*provider_sets): def standard_providers(): from apache_beam.yaml.yaml_mapping import create_mapping_provider + from apache_beam.yaml.yaml_io import io_providers with open(os.path.join(os.path.dirname(__file__), 'standard_providers.yaml')) as fin: standard_providers = yaml.load(fin, Loader=SafeLoader) return merge_providers( create_builtin_provider(), create_mapping_provider(), + io_providers(), parse_providers(standard_providers)) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 70cbf0b7cee3..dcc4a3ee7b4f 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -293,8 +293,11 @@ def unique_name(self, spec, ptransform, strictness=0): if 'name' in spec: name = spec['name'] strictness += 1 - else: + elif 'ExternalTransform' not in ptransform.label: + # The label may have interesting information. name = ptransform.label + else: + name = spec['type'] if name in self._seen_names: if strictness >= 2: raise ValueError(f'Duplicate name at {identify_object(spec)}: {name}')