diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java index fda17f6693af..0e69218bc346 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java @@ -1257,6 +1257,8 @@ public abstract static class TypedWrite abstract @Nullable DynamicAvroDestinations getDynamicDestinations(); + abstract AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory(); + /** * The codec used to encode the blocks in the Avro file. String value drawn from those in * https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html @@ -1305,6 +1307,9 @@ abstract Builder setMetadata( abstract Builder setDynamicDestinations( DynamicAvroDestinations dynamicDestinations); + abstract Builder setDatumWriterFactory( + AvroSink.DatumWriterFactory datumWriterFactory); + abstract TypedWrite build(); } @@ -1498,6 +1503,15 @@ public TypedWrite withCodec(CodecFactory codec) { return toBuilder().setCodec(new SerializableAvroCodecFactory(codec)).build(); } + /** + * Specifies a {@link AvroSink.DatumWriterFactory} to use for creating {@link + * org.apache.avro.io.DatumWriter} instances. + */ + public TypedWrite withDatumWriterFactory( + AvroSink.DatumWriterFactory datumWriterFactory) { + return toBuilder().setDatumWriterFactory(datumWriterFactory).build(); + } + /** * Writes to Avro file(s) with the specified metadata. * @@ -1539,7 +1553,8 @@ DynamicAvroDestinations resolveDynamicDestinations getSchema(), getMetadata(), getCodec().getCodec(), - getFormatFunction()); + getFormatFunction(), + getDatumWriterFactory()); } return dynamicDestinations; } @@ -1700,6 +1715,11 @@ public Write withCodec(CodecFactory codec) { return new Write<>(inner.withCodec(codec)); } + /** See {@link TypedWrite#withDatumWriterFactory}. */ + public Write withDatumWriterFactory(AvroSink.DatumWriterFactory datumWriterFactory) { + return new Write<>(inner.withDatumWriterFactory(datumWriterFactory)); + } + /** * Specify that output filenames are wanted. * @@ -1742,7 +1762,22 @@ public static DynamicAvroDestinations con Map metadata, CodecFactory codec, SerializableFunction formatFunction) { - return new ConstantAvroDestination<>(filenamePolicy, schema, metadata, codec, formatFunction); + return constantDestinations(filenamePolicy, schema, metadata, codec, formatFunction, null); + } + + /** + * Returns a {@link DynamicAvroDestinations} that always returns the same {@link FilenamePolicy}, + * schema, metadata, and codec. + */ + public static DynamicAvroDestinations constantDestinations( + FilenamePolicy filenamePolicy, + Schema schema, + Map metadata, + CodecFactory codec, + SerializableFunction formatFunction, + AvroSink.@Nullable DatumWriterFactory datumWriterFactory) { + return new ConstantAvroDestination<>( + filenamePolicy, schema, metadata, codec, formatFunction, datumWriterFactory); } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java index 2ee0181eefc6..fe198b1c155a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io; +import java.io.Serializable; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Map; @@ -32,9 +33,15 @@ import org.checkerframework.checker.nullness.qual.Nullable; /** A {@link FileBasedSink} for Avro files. */ -class AvroSink extends FileBasedSink { +public class AvroSink + extends FileBasedSink { private final boolean genericRecords; + @FunctionalInterface + public interface DatumWriterFactory extends Serializable { + DatumWriter apply(Schema writer); + } + AvroSink( ValueProvider outputPrefix, DynamicAvroDestinations dynamicDestinations, @@ -57,7 +64,7 @@ public WriteOperation createWriteOperation() { /** A {@link WriteOperation WriteOperation} for Avro files. */ private static class AvroWriteOperation extends WriteOperation { - private final DynamicAvroDestinations dynamicDestinations; + private final DynamicAvroDestinations dynamicDestinations; private final boolean genericRecords; private AvroWriteOperation(AvroSink sink, boolean genericRecords) { @@ -78,12 +85,12 @@ private static class AvroWriter extends Writer dataFileWriter; - private final DynamicAvroDestinations dynamicDestinations; + private final DynamicAvroDestinations dynamicDestinations; private final boolean genericRecords; public AvroWriter( WriteOperation writeOperation, - DynamicAvroDestinations dynamicDestinations, + DynamicAvroDestinations dynamicDestinations, boolean genericRecords) { super(writeOperation, MimeTypes.BINARY); this.dynamicDestinations = dynamicDestinations; @@ -97,9 +104,17 @@ protected void prepareWrite(WritableByteChannel channel) throws Exception { CodecFactory codec = dynamicDestinations.getCodec(destination); Schema schema = dynamicDestinations.getSchema(destination); Map metadata = dynamicDestinations.getMetadata(destination); + DatumWriter datumWriter; + DatumWriterFactory datumWriterFactory = + dynamicDestinations.getDatumWriterFactory(destination); + + if (datumWriterFactory == null) { + datumWriter = + genericRecords ? new GenericDatumWriter<>(schema) : new ReflectDatumWriter<>(schema); + } else { + datumWriter = datumWriterFactory.apply(schema); + } - DatumWriter datumWriter = - genericRecords ? new GenericDatumWriter<>(schema) : new ReflectDatumWriter<>(schema); dataFileWriter = new DataFileWriter<>(datumWriter).setCodec(codec); for (Map.Entry entry : metadata.entrySet()) { Object v = entry.getValue(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java index 85a0627b1214..cf76cdd7ebe3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java @@ -48,6 +48,7 @@ public Schema apply(String input) { private final Map metadata; private final SerializableAvroCodecFactory codec; private final SerializableFunction formatFunction; + private final AvroSink.DatumWriterFactory datumWriterFactory; private class Metadata implements HasDisplayData { @Override @@ -74,11 +75,22 @@ public ConstantAvroDestination( Map metadata, CodecFactory codec, SerializableFunction formatFunction) { + this(filenamePolicy, schema, metadata, codec, formatFunction, null); + } + + public ConstantAvroDestination( + FilenamePolicy filenamePolicy, + Schema schema, + Map metadata, + CodecFactory codec, + SerializableFunction formatFunction, + AvroSink.@Nullable DatumWriterFactory datumWriterFactory) { this.filenamePolicy = filenamePolicy; this.schema = Suppliers.compose(new SchemaFunction(), Suppliers.ofInstance(schema.toString())); this.metadata = metadata; this.codec = new SerializableAvroCodecFactory(codec); this.formatFunction = formatFunction; + this.datumWriterFactory = datumWriterFactory; } @Override @@ -116,6 +128,11 @@ public CodecFactory getCodec(Void destination) { return codec.getCodec(); } + @Override + public AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory(Void destination) { + return datumWriterFactory; + } + @Override public void populateDisplayData(DisplayData.Builder builder) { filenamePolicy.populateDisplayData(builder); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java index 023d397f6982..4bb450bffe1c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java @@ -22,11 +22,12 @@ import org.apache.avro.file.CodecFactory; import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; /** * A specialization of {@link DynamicDestinations} for {@link AvroIO}. In addition to dynamic file - * destinations, this allows specifying other AVRO properties (schema, metadata, codec) per - * destination. + * destinations, this allows specifying other AVRO properties (schema, metadata, codec, datum + * writer) per destination. */ public abstract class DynamicAvroDestinations extends DynamicDestinations { @@ -42,4 +43,13 @@ public Map getMetadata(DestinationT destination) { public CodecFactory getCodec(DestinationT destination) { return AvroIO.TypedWrite.DEFAULT_CODEC; } + + /** + * Return a {@link AvroSink.DatumWriterFactory} for a given destination. If provided, it will be + * used to created {@link org.apache.avro.io.DatumWriter} instances as required. + */ + public AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory( + DestinationT destinationT) { + return null; + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java index a00b69ad9f3e..c8b3e56588e5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java @@ -47,12 +47,17 @@ import java.util.Objects; import java.util.Random; import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; import org.apache.avro.file.CodecFactory; import org.apache.avro.file.DataFileReader; import org.apache.avro.file.DataFileStream; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.Encoder; import org.apache.avro.reflect.ReflectData; import org.apache.avro.reflect.ReflectDatumReader; import org.apache.beam.sdk.coders.AvroCoder; @@ -1486,6 +1491,73 @@ public void testAvroSinkShardedWrite() throws Exception { runTestWrite(expectedElements, 4); } + + @Test + @Category(NeedsRunner.class) + public void testAvroSinkWriteWithCustomFactory() throws Exception { + Integer[] expectedElements = new Integer[] {1, 2, 3, 4, 5}; + + File baseOutputFile = new File(tmpFolder.getRoot(), "prefix"); + String outputFilePrefix = baseOutputFile.getAbsolutePath(); + + Schema recordSchema = SchemaBuilder.record("root").fields().requiredInt("i1").endRecord(); + + AvroIO.TypedWrite write = + AvroIO.writeCustomType() + .to(outputFilePrefix) + .withSchema(recordSchema) + .withFormatFunction(f -> f) + .withDatumWriterFactory( + f -> + new DatumWriter() { + private DatumWriter inner = new GenericDatumWriter<>(f); + + @Override + public void setSchema(Schema schema) { + inner.setSchema(schema); + } + + @Override + public void write(Integer datum, Encoder out) throws IOException { + GenericRecord record = + new GenericRecordBuilder(f).set("i1", datum).build(); + inner.write(record, out); + } + }) + .withSuffix(".avro"); + + write = write.withoutSharding(); + + writePipeline.apply(Create.of(ImmutableList.copyOf(expectedElements))).apply(write); + writePipeline.run(); + + File expectedFile = + new File( + DefaultFilenamePolicy.constructName( + FileBasedSink.convertToFileResourceIfPossible(outputFilePrefix), + "", + ".avro", + 1, + 1, + null, + null) + .toString()); + + assertTrue("Expected output file " + expectedFile.getName(), expectedFile.exists()); + DataFileReader dataFileReader = + new DataFileReader<>(expectedFile, new GenericDatumReader<>(recordSchema)); + + List actualRecords = new ArrayList<>(); + Iterators.addAll(actualRecords, dataFileReader); + + GenericRecord[] expectedRecords = + Arrays.stream(expectedElements) + .map(i -> new GenericRecordBuilder(recordSchema).set("i1", i).build()) + .toArray(GenericRecord[]::new); + + assertThat(actualRecords, containsInAnyOrder(expectedRecords)); + } + // TODO: for Write only, test withSuffix, // withShardNameTemplate and withoutSharding. }