Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,27 @@
import java.io.IOException;
import org.apache.avro.Schema;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.MimeTypes;

class AvroRowWriter<T> extends BigQueryRowWriter<T> {
private final DataFileWriter<GenericRecord> writer;
class AvroRowWriter<AvroT, T> extends BigQueryRowWriter<T> {
private final DataFileWriter<AvroT> writer;
private final Schema schema;
private final SerializableFunction<AvroWriteRequest<T>, GenericRecord> toAvroRecord;
private final SerializableFunction<AvroWriteRequest<T>, AvroT> toAvroRecord;

AvroRowWriter(
String basename,
Schema schema,
SerializableFunction<AvroWriteRequest<T>, GenericRecord> toAvroRecord)
SerializableFunction<AvroWriteRequest<T>, AvroT> toAvroRecord,
SerializableFunction<Schema, DatumWriter<AvroT>> writerFactory)
throws Exception {
super(basename, MimeTypes.BINARY);

this.schema = schema;
this.toAvroRecord = toAvroRecord;
this.writer =
new DataFileWriter<GenericRecord>(new GenericDatumWriter<>())
.create(schema, getOutputStream());
new DataFileWriter<>(writerFactory.apply(schema)).create(schema, getOutputStream());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.annotations.Experimental;
Expand Down Expand Up @@ -294,14 +296,16 @@
* <ul>
* <li>{@link BigQueryIO.Write#withAvroFormatFunction(SerializableFunction)} (recommended) to
* write data using avro records.
* <li>{@link BigQueryIO.Write#withAvroWriter} to write avro data using a user-specified {@link
* DatumWriter} (and format function).
* <li>{@link BigQueryIO.Write#withFormatFunction(SerializableFunction)} to write data as json
* encoded {@link TableRow TableRows}.
* </ul>
*
* If {@link BigQueryIO.Write#withAvroFormatFunction(SerializableFunction)} is used, the table
* schema MUST be specified using one of the {@link Write#withJsonSchema(String)}, {@link
* Write#withJsonSchema(ValueProvider)}, {@link Write#withSchemaFromView(PCollectionView)} methods,
* or {@link Write#to(DynamicDestinations)}.
* If {@link BigQueryIO.Write#withAvroFormatFunction(SerializableFunction)} or {@link
* BigQueryIO.Write#withAvroWriter} is used, the table schema MUST be specified using one of the
* {@link Write#withJsonSchema(String)}, {@link Write#withJsonSchema(ValueProvider)}, {@link
* Write#withSchemaFromView(PCollectionView)} methods, or {@link Write#to(DynamicDestinations)}.
*
* <pre>{@code
* class Quote {
Expand Down Expand Up @@ -488,6 +492,9 @@ public class BigQueryIO {
*/
static final SerializableFunction<TableRow, TableRow> IDENTITY_FORMATTER = input -> input;

static final SerializableFunction<org.apache.avro.Schema, DatumWriter<GenericRecord>>
GENERIC_DATUM_WRITER_FACTORY = schema -> new GenericDatumWriter<>();

private static final SerializableFunction<TableSchema, org.apache.avro.Schema>
DEFAULT_AVRO_SCHEMA_FACTORY =
new SerializableFunction<TableSchema, org.apache.avro.Schema>() {
Expand Down Expand Up @@ -1763,7 +1770,7 @@ public enum Method {
abstract SerializableFunction<T, TableRow> getFormatFunction();

@Nullable
abstract SerializableFunction<AvroWriteRequest<T>, GenericRecord> getAvroFormatFunction();
abstract RowWriterFactory.AvroRowWriterFactory<T, ?, ?> getAvroRowWriterFactory();

@Nullable
abstract SerializableFunction<TableSchema, org.apache.avro.Schema> getAvroSchemaFactory();
Expand Down Expand Up @@ -1851,8 +1858,8 @@ abstract Builder<T> setTableFunction(

abstract Builder<T> setFormatFunction(SerializableFunction<T, TableRow> formatFunction);

abstract Builder<T> setAvroFormatFunction(
SerializableFunction<AvroWriteRequest<T>, GenericRecord> avroFormatFunction);
abstract Builder<T> setAvroRowWriterFactory(
RowWriterFactory.AvroRowWriterFactory<T, ?, ?> avroRowWriterFactory);

abstract Builder<T> setAvroSchemaFactory(
SerializableFunction<TableSchema, org.apache.avro.Schema> avroSchemaFactory);
Expand Down Expand Up @@ -2056,13 +2063,43 @@ public Write<T> withFormatFunction(SerializableFunction<T, TableRow> formatFunct
}

/**
* Formats the user's type into a {@link GenericRecord} to be written to BigQuery.
* Formats the user's type into a {@link GenericRecord} to be written to BigQuery. The
* GenericRecords are written as avro using the standard {@link GenericDatumWriter}.
*
* <p>This is mutually exclusive with {@link #withFormatFunction}, only one may be set.
*/
public Write<T> withAvroFormatFunction(
SerializableFunction<AvroWriteRequest<T>, GenericRecord> avroFormatFunction) {
return toBuilder().setAvroFormatFunction(avroFormatFunction).setOptimizeWrites(true).build();
return withAvroWriter(avroFormatFunction, GENERIC_DATUM_WRITER_FACTORY);
}

/**
* Writes the user's type as avro using the supplied {@link DatumWriter}.
*
* <p>This is mutually exclusive with {@link #withFormatFunction}, only one may be set.
*
* <p>Overwrites {@link #withAvroFormatFunction} if it has been set.
*/
public Write<T> withAvroWriter(
SerializableFunction<org.apache.avro.Schema, DatumWriter<T>> writerFactory) {
return withAvroWriter(AvroWriteRequest::getElement, writerFactory);
}

/**
* Convert's the user's type to an avro record using the supplied avroFormatFunction. Records
* are then written using the supplied writer instances returned from writerFactory.
*
* <p>This is mutually exclusive with {@link #withFormatFunction}, only one may be set.
*
* <p>Overwrites {@link #withAvroFormatFunction} if it has been set.
*/
public <AvroT> Write<T> withAvroWriter(
SerializableFunction<AvroWriteRequest<T>, AvroT> avroFormatFunction,
SerializableFunction<org.apache.avro.Schema, DatumWriter<AvroT>> writerFactory) {
return toBuilder()
.setOptimizeWrites(true)
.setAvroRowWriterFactory(RowWriterFactory.avroRecords(avroFormatFunction, writerFactory))
.build();
}

/**
Expand Down Expand Up @@ -2484,7 +2521,7 @@ public WriteResult expand(PCollection<T> input) {
if (method != Method.FILE_LOADS) {
// we only support writing avro for FILE_LOADS
checkArgument(
getAvroFormatFunction() == null,
getAvroRowWriterFactory() == null,
"Writing avro formatted data is only supported for FILE_LOADS, however "
+ "the method was %s",
method);
Expand Down Expand Up @@ -2546,8 +2583,8 @@ private <DestinationT> WriteResult expandTyped(
PCollection<T> input, DynamicDestinations<T, DestinationT> dynamicDestinations) {
boolean optimizeWrites = getOptimizeWrites();
SerializableFunction<T, TableRow> formatFunction = getFormatFunction();
SerializableFunction<AvroWriteRequest<T>, GenericRecord> avroFormatFunction =
getAvroFormatFunction();
RowWriterFactory.AvroRowWriterFactory<T, ?, DestinationT> avroRowWriterFactory =
(RowWriterFactory.AvroRowWriterFactory<T, ?, DestinationT>) getAvroRowWriterFactory();

boolean hasSchema =
getJsonSchema() != null
Expand All @@ -2559,8 +2596,8 @@ private <DestinationT> WriteResult expandTyped(
optimizeWrites = true;

checkArgument(
avroFormatFunction == null,
"avroFormatFunction is unsupported when using Beam schemas.");
avroRowWriterFactory == null,
"avro avroFormatFunction is unsupported when using Beam schemas.");

if (formatFunction == null) {
// If no format function set, then we will automatically convert the input type to a
Expand Down Expand Up @@ -2593,10 +2630,10 @@ private <DestinationT> WriteResult expandTyped(
Method method = resolveMethod(input);
if (optimizeWrites) {
RowWriterFactory<T, DestinationT> rowWriterFactory;
if (avroFormatFunction != null) {
if (avroRowWriterFactory != null) {
checkArgument(
formatFunction == null,
"Only one of withFormatFunction or withAvroFormatFunction maybe set, not both.");
"Only one of withFormatFunction or withAvroFormatFunction/withAvroWriter maybe set, not both.");

SerializableFunction<TableSchema, org.apache.avro.Schema> avroSchemaFactory =
getAvroSchemaFactory();
Expand All @@ -2607,9 +2644,7 @@ private <DestinationT> WriteResult expandTyped(
+ "is set but no avroSchemaFactory is defined.");
avroSchemaFactory = DEFAULT_AVRO_SCHEMA_FACTORY;
}
rowWriterFactory =
RowWriterFactory.avroRecords(
avroFormatFunction, avroSchemaFactory, dynamicDestinations);
rowWriterFactory = avroRowWriterFactory.prepare(dynamicDestinations, avroSchemaFactory);
} else if (formatFunction != null) {
rowWriterFactory = RowWriterFactory.tableRows(formatFunction);
} else {
Expand All @@ -2634,7 +2669,7 @@ private <DestinationT> WriteResult expandTyped(
rowWriterFactory,
method);
} else {
checkArgument(avroFormatFunction == null);
checkArgument(avroRowWriterFactory == null);
checkArgument(
formatFunction != null,
"A function must be provided to convert the input type into a TableRow or "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import com.google.api.services.bigquery.model.TableSchema;
import java.io.Serializable;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.beam.sdk.transforms.SerializableFunction;

abstract class RowWriterFactory<ElementT, DestinationT> implements Serializable {
Expand Down Expand Up @@ -74,29 +74,38 @@ String getSourceFormat() {
}
}

static <ElementT, DestinationT> RowWriterFactory<ElementT, DestinationT> avroRecords(
SerializableFunction<AvroWriteRequest<ElementT>, GenericRecord> toAvro,
SerializableFunction<TableSchema, Schema> schemaFactory,
DynamicDestinations<?, DestinationT> dynamicDestinations) {
return new AvroRowWriterFactory<>(toAvro, schemaFactory, dynamicDestinations);
static <ElementT, AvroT, DestinationT>
AvroRowWriterFactory<ElementT, AvroT, DestinationT> avroRecords(
SerializableFunction<AvroWriteRequest<ElementT>, AvroT> toAvro,
SerializableFunction<Schema, DatumWriter<AvroT>> writerFactory) {
return new AvroRowWriterFactory<>(toAvro, writerFactory, null, null);
}

private static final class AvroRowWriterFactory<ElementT, DestinationT>
static final class AvroRowWriterFactory<ElementT, AvroT, DestinationT>
extends RowWriterFactory<ElementT, DestinationT> {

private final SerializableFunction<AvroWriteRequest<ElementT>, GenericRecord> toAvro;
private final SerializableFunction<AvroWriteRequest<ElementT>, AvroT> toAvro;
private final SerializableFunction<Schema, DatumWriter<AvroT>> writerFactory;
private final SerializableFunction<TableSchema, Schema> schemaFactory;
private final DynamicDestinations<?, DestinationT> dynamicDestinations;

private AvroRowWriterFactory(
SerializableFunction<AvroWriteRequest<ElementT>, GenericRecord> toAvro,
SerializableFunction<AvroWriteRequest<ElementT>, AvroT> toAvro,
SerializableFunction<Schema, DatumWriter<AvroT>> writerFactory,
SerializableFunction<TableSchema, Schema> schemaFactory,
DynamicDestinations<?, DestinationT> dynamicDestinations) {
this.toAvro = toAvro;
this.writerFactory = writerFactory;
this.schemaFactory = schemaFactory;
this.dynamicDestinations = dynamicDestinations;
}

AvroRowWriterFactory<ElementT, AvroT, DestinationT> prepare(
DynamicDestinations<?, DestinationT> dynamicDestinations,
SerializableFunction<TableSchema, Schema> schemaFactory) {
return new AvroRowWriterFactory<>(toAvro, writerFactory, schemaFactory, dynamicDestinations);
}

@Override
OutputType getOutputType() {
return OutputType.AvroGenericRecord;
Expand All @@ -107,7 +116,7 @@ BigQueryRowWriter<ElementT> createRowWriter(String tempFilePrefix, DestinationT
throws Exception {
TableSchema tableSchema = dynamicDestinations.getSchema(destination);
Schema avroSchema = schemaFactory.apply(tableSchema);
return new AvroRowWriter<>(tempFilePrefix, avroSchema, toAvro);
return new AvroRowWriter<>(tempFilePrefix, avroSchema, toAvro, writerFactory);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.Encoder;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
Expand Down Expand Up @@ -782,6 +785,66 @@ public void testWriteAvro() throws Exception {
.set("instantVal", "2019-02-01 00:00:00 UTC")));
}

@Test
public void testWriteAvroWithCustomWriter() throws Exception {
SerializableFunction<AvroWriteRequest<InputRecord>, GenericRecord> formatFunction =
r -> {
GenericRecord rec = new GenericData.Record(r.getSchema());
InputRecord i = r.getElement();
rec.put("strVal", i.strVal());
rec.put("longVal", i.longVal());
rec.put("doubleVal", i.doubleVal());
rec.put("instantVal", i.instantVal().getMillis() * 1000);
return rec;
};

SerializableFunction<org.apache.avro.Schema, DatumWriter<GenericRecord>> customWriterFactory =
s ->
new GenericDatumWriter<GenericRecord>() {
@Override
protected void writeString(org.apache.avro.Schema schema, Object datum, Encoder out)
throws IOException {
super.writeString(schema, datum.toString() + "_custom", out);
}
};

p.apply(
Create.of(
InputRecord.create("test", 1, 1.0, Instant.parse("2019-01-01T00:00:00Z")),
InputRecord.create("test2", 2, 2.0, Instant.parse("2019-02-01T00:00:00Z")))
.withCoder(INPUT_RECORD_CODER))
.apply(
BigQueryIO.<InputRecord>write()
.to("dataset-id.table-id")
.withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
.withSchema(
new TableSchema()
.setFields(
ImmutableList.of(
new TableFieldSchema().setName("strVal").setType("STRING"),
new TableFieldSchema().setName("longVal").setType("INTEGER"),
new TableFieldSchema().setName("doubleVal").setType("FLOAT"),
new TableFieldSchema().setName("instantVal").setType("TIMESTAMP"))))
.withTestServices(fakeBqServices)
.withAvroWriter(formatFunction, customWriterFactory)
.withoutValidation());
p.run();

assertThat(
fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"),
containsInAnyOrder(
new TableRow()
.set("strVal", "test_custom")
.set("longVal", "1")
.set("doubleVal", 1.0D)
.set("instantVal", "2019-01-01 00:00:00 UTC"),
new TableRow()
.set("strVal", "test2_custom")
.set("longVal", "2")
.set("doubleVal", 2.0D)
.set("instantVal", "2019-02-01 00:00:00 UTC")));
}

@Test
public void testStreamingWrite() throws Exception {
p.apply(
Expand Down Expand Up @@ -1352,7 +1415,7 @@ public void testWriteValidateFailsBothFormatFunctions() {

thrown.expect(IllegalArgumentException.class);
thrown.expectMessage(
"Only one of withFormatFunction or withAvroFormatFunction maybe set, not both");
"Only one of withFormatFunction or withAvroFormatFunction/withAvroWriter maybe set, not both.");
p.apply(Create.empty(INPUT_RECORD_CODER))
.apply(
BigQueryIO.<InputRecord>write()
Expand Down