@@ -33,6 +33,7 @@
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration;
@@ -94,7 +95,7 @@ protected SchemaTransform from(

@Override
public String identifier() {
return "beam:schematransform:org.apache.beam:bigquery_storage_write:v1";
return "beam:schematransform:org.apache.beam:bigquery_storage_write:v2";
}

@Override
@@ -125,6 +126,24 @@ public abstract static class BigQueryStorageWriteApiSchemaTransformConfiguration
.put(WriteDisposition.WRITE_APPEND.name(), WriteDisposition.WRITE_APPEND)
.build();

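/** Specifies where rows that failed to be written are routed. */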
@AutoValue
public abstract static class ErrorHandling {
@SchemaFieldDescription("The name of the output PCollection containing failed writes.")
public abstract String getOutput();

public static Builder builder() {
return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration_ErrorHandling
.Builder();
}

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setOutput(String output);

public abstract ErrorHandling build();
}
}

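/** Validates this configuration, throwing if a field is malformed or inconsistent. */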
public void validate() {
String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: ";

@@ -151,6 +170,12 @@ public void validate() {
this.getWriteDisposition(),
WRITE_DISPOSITIONS.keySet());
}

if (this.getErrorHandling() != null) {
checkArgument(
!Strings.isNullOrEmpty(this.getErrorHandling().getOutput()),
invalidConfigMessage + "Output must not be empty if error handling specified.");
}
}

/**
@@ -198,6 +223,10 @@ public static Builder builder() {
@Nullable
public abstract Boolean getAutoSharding();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
@Nullable
public abstract ErrorHandling getErrorHandling();

/** Builder for {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. */
@AutoValue.Builder
public abstract static class Builder {
@@ -214,6 +243,8 @@ public abstract static class Builder {

public abstract Builder setAutoSharding(Boolean autoSharding);

public abstract Builder setErrorHandling(ErrorHandling errorHandling);

/** Builds a {@link BigQueryStorageWriteApiSchemaTransformConfiguration} instance. */
public abstract BigQueryStorageWriteApiSchemaTransformProvider
.BigQueryStorageWriteApiSchemaTransformConfiguration
@@ -244,7 +275,7 @@ public void setBigQueryServices(BigQueryServices testBigQueryServices) {

// A generic DoFn that counts the elements of the input PCollection, using a
// metrics Counter initialized with the given name argument.
private static class ElementCounterFn extends DoFn<Row, Row> {
private static class ElementCounterFn<T> extends DoFn<T, T> {

private Counter bqGenericElementCounter;
private Long elementsInBundle = 0L;
@@ -267,6 +298,18 @@ public void finish(FinishBundleContext c) {
}
}

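// Used when no error handling is configured: any failed insert fails the pipeline
// by rethrowing that insert's error message.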
private static class FailOnError extends DoFn<BigQueryStorageApiInsertError, Void> {
@ProcessElement
public void process(ProcessContext c) {
throw new RuntimeException(c.element().getErrorMessage());
}
}

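// Consumes its input without producing output; used to derive an empty "post_write"
// signal PCollection from the failed-insert stream.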
private static class NoOutputDoFn<T> extends DoFn<T, Row> {
@ProcessElement
public void process(ProcessContext c) {}
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
// Check that the input exists
@@ -294,53 +337,55 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
WriteResult result =
inputRows
.apply(
"element-count", ParDo.of(new ElementCounterFn("BigQuery-write-element-counter")))
"element-count",
ParDo.of(new ElementCounterFn<Row>("BigQuery-write-element-counter")))
.setRowSchema(inputSchema)
.apply(write);

Schema rowSchema = inputRows.getSchema();
Schema errorSchema =
Schema.of(
Field.of("failed_row", FieldType.row(rowSchema)),
Field.of("error_message", FieldType.STRING));

// Failed rows
PCollection<Row> failedRows =
result
.getFailedStorageApiInserts()
.apply(
"Construct failed rows",
MapElements.into(TypeDescriptors.rows())
.via(
(storageError) ->
BigQueryUtils.toBeamRow(rowSchema, storageError.getRow())))
.setRowSchema(rowSchema);

// Failed rows with error message
PCollection<Row> failedRowsWithErrors =
// Emit an empty PCollection that completes after the write, so downstream steps
// have something they can sequence against.
PCollection<Row> postWrite =
result
.getFailedStorageApiInserts()
.apply(
"Construct failed rows and errors",
MapElements.into(TypeDescriptors.rows())
.via(
(storageError) ->
Row.withSchema(errorSchema)
.withFieldValue("error_message", storageError.getErrorMessage())
.withFieldValue(
"failed_row",
BigQueryUtils.toBeamRow(rowSchema, storageError.getRow()))
.build()))
.setRowSchema(errorSchema);

PCollection<Row> failedRowsOutput =
failedRows
.apply("error-count", ParDo.of(new ElementCounterFn("BigQuery-write-error-counter")))
.setRowSchema(rowSchema);

return PCollectionRowTuple.of(FAILED_ROWS_TAG, failedRowsOutput)
.and(FAILED_ROWS_WITH_ERRORS_TAG, failedRowsWithErrors)
.and("errors", failedRowsWithErrors);
.apply("post-write", ParDo.of(new NoOutputDoFn<BigQueryStorageApiInsertError>()))
.setRowSchema(Schema.of());

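// With no error handling configured, surface failed inserts as pipeline failures.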
if (configuration.getErrorHandling() == null) {
result
.getFailedStorageApiInserts()
.apply("Error on failed inserts", ParDo.of(new FailOnError()));
return PCollectionRowTuple.of("post_write", postWrite);
} else {
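// Error handling is configured: count failed inserts and emit them, along with
// their error messages, on the user-specified output.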
result
.getFailedStorageApiInserts()
.apply(
"error-count",
ParDo.of(
new ElementCounterFn<BigQueryStorageApiInsertError>(
"BigQuery-write-error-counter")));

// Failed rows with error message
Schema errorSchema =
Schema.of(
Field.of("failed_row", FieldType.row(inputSchema)),
Field.of("error_message", FieldType.STRING));
PCollection<Row> failedRowsWithErrors =
result
.getFailedStorageApiInserts()
.apply(
"Construct failed rows and errors",
MapElements.into(TypeDescriptors.rows())
.via(
(storageError) ->
Row.withSchema(errorSchema)
.withFieldValue("error_message", storageError.getErrorMessage())
.withFieldValue(
"failed_row",
BigQueryUtils.toBeamRow(inputSchema, storageError.getRow()))
.build()))
.setRowSchema(errorSchema);
return PCollectionRowTuple.of("post_write", postWrite)
.and(configuration.getErrorHandling().getOutput(), failedRowsWithErrors);
}
}

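// A minimal usage sketch (not part of this change) showing how a caller might
// configure the transform so failed rows are surfaced on a named output. The table
// spec and output name below are hypothetical.
//
//   BigQueryStorageWriteApiSchemaTransformConfiguration config =
//       BigQueryStorageWriteApiSchemaTransformConfiguration.builder()
//           .setTable("my-project:my_dataset.my_table")
//           .setErrorHandling(
//               BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder()
//                   .setOutput("my_failed_rows")
//                   .build())
//           .build();
//
// The transform then returns a "post_write" PCollection plus, when error handling is
// set, a PCollection under the configured name whose rows carry "failed_row" and
// "error_message" fields.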
BigQueryIO.Write<Row> createStorageWriteApiTransform() {
@@ -48,9 +48,11 @@
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -211,7 +213,13 @@ public void testInputElementCount() throws Exception {
public void testFailedRows() throws Exception {
String tableSpec = "project:dataset.write_with_fail";
BigQueryStorageWriteApiSchemaTransformConfiguration config =
BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build();
BigQueryStorageWriteApiSchemaTransformConfiguration.builder()
.setTable(tableSpec)
.setErrorHandling(
BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder()
.setOutput("FailedRows")
.build())
.build();

String failValue = "fail_me";

@@ -234,7 +242,15 @@ public void testFailedRows() throws Exception {
fakeDatasetService.setShouldFailRow(shouldFailRow);

PCollectionRowTuple result = runWithConfig(config, totalRows);
PCollection<Row> failedRows = result.get("FailedRows");
PCollection<Row> failedRows =
result
.get("FailedRows")
.apply(
"ExtractFailedRows",
MapElements.into(TypeDescriptors.rows())
.via((rowAndError) -> rowAndError.<Row>getValue("failed_row")))
.setRowSchema(SCHEMA);

PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows);
p.run().waitUntilFinish();
@@ -250,7 +266,13 @@ public void testFailedRows() throws Exception {
public void testErrorCount() throws Exception {
String tableSpec = "project:dataset.error_count";
BigQueryStorageWriteApiSchemaTransformConfiguration config =
BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build();
BigQueryStorageWriteApiSchemaTransformConfiguration.builder()
.setTable(tableSpec)
.setErrorHandling(
BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder()
.setOutput("FailedRows")
.build())
.build();

Function<TableRow, Boolean> shouldFailRow =
(Function<TableRow, Boolean> & Serializable) tr -> tr.get("name").equals("a");