39 changes: 37 additions & 2 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
@@ -1257,6 +1257,8 @@ public abstract static class TypedWrite<UserT, DestinationT, OutputT>
abstract @Nullable DynamicAvroDestinations<UserT, DestinationT, OutputT>
getDynamicDestinations();

abstract AvroSink.@Nullable DatumWriterFactory<OutputT> getDatumWriterFactory();

/**
* The codec used to encode the blocks in the Avro file. String value drawn from those in
* https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html
@@ -1305,6 +1307,9 @@ abstract Builder<UserT, DestinationT, OutputT> setMetadata(
abstract Builder<UserT, DestinationT, OutputT> setDynamicDestinations(
DynamicAvroDestinations<UserT, DestinationT, OutputT> dynamicDestinations);

abstract Builder<UserT, DestinationT, OutputT> setDatumWriterFactory(
AvroSink.DatumWriterFactory<OutputT> datumWriterFactory);

abstract TypedWrite<UserT, DestinationT, OutputT> build();
}

@@ -1498,6 +1503,15 @@ public TypedWrite<UserT, DestinationT, OutputT> withCodec(CodecFactory codec) {
return toBuilder().setCodec(new SerializableAvroCodecFactory(codec)).build();
}

/**
* Specifies a {@link AvroSink.DatumWriterFactory} to use for creating {@link
* org.apache.avro.io.DatumWriter} instances.
*/
public TypedWrite<UserT, DestinationT, OutputT> withDatumWriterFactory(
AvroSink.DatumWriterFactory<OutputT> datumWriterFactory) {
return toBuilder().setDatumWriterFactory(datumWriterFactory).build();
}

/**
* Writes to Avro file(s) with the specified metadata.
*
@@ -1539,7 +1553,8 @@ DynamicAvroDestinations<UserT, DestinationT, OutputT> resolveDynamicDestinations
getSchema(),
getMetadata(),
getCodec().getCodec(),
getFormatFunction());
getFormatFunction(),
getDatumWriterFactory());
}
return dynamicDestinations;
}
@@ -1700,6 +1715,11 @@ public Write<T> withCodec(CodecFactory codec) {
return new Write<>(inner.withCodec(codec));
}

/** See {@link TypedWrite#withDatumWriterFactory}. */
public Write<T> withDatumWriterFactory(AvroSink.DatumWriterFactory<T> datumWriterFactory) {
return new Write<>(inner.withDatumWriterFactory(datumWriterFactory));
}

/**
* Specify that output filenames are wanted.
*
@@ -1742,7 +1762,22 @@ public static <UserT, OutputT> DynamicAvroDestinations<UserT, Void, OutputT> constantDestinations(
Map<String, Object> metadata,
CodecFactory codec,
SerializableFunction<UserT, OutputT> formatFunction) {
return new ConstantAvroDestination<>(filenamePolicy, schema, metadata, codec, formatFunction);
return constantDestinations(filenamePolicy, schema, metadata, codec, formatFunction, null);
}

/**
* Returns a {@link DynamicAvroDestinations} that always returns the same {@link FilenamePolicy},
* schema, metadata, and codec.
*/
public static <UserT, OutputT> DynamicAvroDestinations<UserT, Void, OutputT> constantDestinations(
FilenamePolicy filenamePolicy,
Schema schema,
Map<String, Object> metadata,
CodecFactory codec,
SerializableFunction<UserT, OutputT> formatFunction,
AvroSink.@Nullable DatumWriterFactory<OutputT> datumWriterFactory) {
return new ConstantAvroDestination<>(
filenamePolicy, schema, metadata, codec, formatFunction, datumWriterFactory);
}
/////////////////////////////////////////////////////////////////////////////

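Taken together, the AvroIO.java changes expose the factory on both TypedWrite and Write. A minimal sketch of how a caller might wire the new withDatumWriterFactory hook; the output prefix, schema, and identity format function below are illustrative assumptions, not part of this change:

```java
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.io.AvroIO;

class DatumWriterFactoryExample {
  static AvroIO.TypedWrite<GenericRecord, Void, GenericRecord> buildWrite() {
    // Hypothetical writer schema; any record schema works here.
    Schema schema =
        SchemaBuilder.record("user").fields().requiredString("name").endRecord();

    // Elements are already GenericRecords, so the format function is the identity;
    // the sink invokes the factory with the destination's schema when it opens a file.
    return AvroIO.<GenericRecord, GenericRecord>writeCustomType()
        .to("/tmp/avro/part") // assumed output prefix
        .withSchema(schema)
        .withFormatFunction(r -> r)
        .withDatumWriterFactory(GenericDatumWriter::new)
        .withSuffix(".avro");
  }
}
```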
27 changes: 21 additions & 6 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io;

import java.io.Serializable;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.Map;
@@ -32,9 +33,15 @@
import org.checkerframework.checker.nullness.qual.Nullable;

/** A {@link FileBasedSink} for Avro files. */
class AvroSink<UserT, DestinationT, OutputT> extends FileBasedSink<UserT, DestinationT, OutputT> {
public class AvroSink<UserT, DestinationT, OutputT>
extends FileBasedSink<UserT, DestinationT, OutputT> {
private final boolean genericRecords;

@FunctionalInterface
public interface DatumWriterFactory<T> extends Serializable {
DatumWriter<T> apply(Schema writer);
}

AvroSink(
ValueProvider<ResourceId> outputPrefix,
DynamicAvroDestinations<UserT, DestinationT, OutputT> dynamicDestinations,
@@ -57,7 +64,7 @@ public WriteOperation<DestinationT, OutputT> createWriteOperation() {
/** A {@link WriteOperation WriteOperation} for Avro files. */
private static class AvroWriteOperation<DestinationT, OutputT>
extends WriteOperation<DestinationT, OutputT> {
private final DynamicAvroDestinations<?, DestinationT, ?> dynamicDestinations;
private final DynamicAvroDestinations<?, DestinationT, OutputT> dynamicDestinations;
private final boolean genericRecords;

private AvroWriteOperation(AvroSink<?, DestinationT, OutputT> sink, boolean genericRecords) {
@@ -78,12 +85,12 @@ private static class AvroWriter<DestinationT, OutputT> extends Writer<DestinationT, OutputT>
// Initialized in prepareWrite
private @Nullable DataFileWriter<OutputT> dataFileWriter;

private final DynamicAvroDestinations<?, DestinationT, ?> dynamicDestinations;
private final DynamicAvroDestinations<?, DestinationT, OutputT> dynamicDestinations;
private final boolean genericRecords;

public AvroWriter(
WriteOperation<DestinationT, OutputT> writeOperation,
DynamicAvroDestinations<?, DestinationT, ?> dynamicDestinations,
DynamicAvroDestinations<?, DestinationT, OutputT> dynamicDestinations,
boolean genericRecords) {
super(writeOperation, MimeTypes.BINARY);
this.dynamicDestinations = dynamicDestinations;
@@ -97,9 +104,17 @@ protected void prepareWrite(WritableByteChannel channel) throws Exception {
CodecFactory codec = dynamicDestinations.getCodec(destination);
Schema schema = dynamicDestinations.getSchema(destination);
Map<String, Object> metadata = dynamicDestinations.getMetadata(destination);
DatumWriter<OutputT> datumWriter;
DatumWriterFactory<OutputT> datumWriterFactory =
dynamicDestinations.getDatumWriterFactory(destination);

if (datumWriterFactory == null) {
datumWriter =
genericRecords ? new GenericDatumWriter<>(schema) : new ReflectDatumWriter<>(schema);
} else {
datumWriter = datumWriterFactory.apply(schema);
}

DatumWriter<OutputT> datumWriter =
genericRecords ? new GenericDatumWriter<>(schema) : new ReflectDatumWriter<>(schema);
dataFileWriter = new DataFileWriter<>(datumWriter).setCodec(codec);
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
Object v = entry.getValue();
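Because DatumWriterFactory is declared as a serializable @FunctionalInterface with a single apply(Schema) method, a lambda or constructor reference is enough. A sketch; MyPojo and the AllowNull choice are assumptions for illustration, not part of this PR:

```java
import java.io.Serializable;
import org.apache.avro.Schema;
import org.apache.avro.reflect.ReflectData;
import org.apache.avro.reflect.ReflectDatumWriter;
import org.apache.beam.sdk.io.AvroSink;

class NullableReflectFactoryExample {
  // MyPojo is a stand-in for any user type written via reflection.
  static class MyPojo implements Serializable {
    String name;
  }

  // A factory that builds ReflectDatumWriters backed by AllowNull reflect data,
  // an assumed preference for tolerating null POJO fields during reflective writes.
  static final AvroSink.DatumWriterFactory<MyPojo> FACTORY =
      (Schema writerSchema) ->
          new ReflectDatumWriter<>(writerSchema, ReflectData.AllowNull.get());
}
```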
sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java
@@ -48,6 +48,7 @@ public Schema apply(String input) {
private final Map<String, Object> metadata;
private final SerializableAvroCodecFactory codec;
private final SerializableFunction<UserT, OutputT> formatFunction;
private final AvroSink.DatumWriterFactory<OutputT> datumWriterFactory;

private class Metadata implements HasDisplayData {
@Override
@@ -74,11 +75,22 @@ public ConstantAvroDestination(
Map<String, Object> metadata,
CodecFactory codec,
SerializableFunction<UserT, OutputT> formatFunction) {
this(filenamePolicy, schema, metadata, codec, formatFunction, null);
}

public ConstantAvroDestination(
FilenamePolicy filenamePolicy,
Schema schema,
Map<String, Object> metadata,
CodecFactory codec,
SerializableFunction<UserT, OutputT> formatFunction,
AvroSink.@Nullable DatumWriterFactory<OutputT> datumWriterFactory) {
this.filenamePolicy = filenamePolicy;
this.schema = Suppliers.compose(new SchemaFunction(), Suppliers.ofInstance(schema.toString()));
this.metadata = metadata;
this.codec = new SerializableAvroCodecFactory(codec);
this.formatFunction = formatFunction;
this.datumWriterFactory = datumWriterFactory;
}

@Override
@@ -116,6 +128,11 @@ public CodecFactory getCodec(Void destination) {
return codec.getCodec();
}

@Override
public AvroSink.@Nullable DatumWriterFactory<OutputT> getDatumWriterFactory(Void destination) {
return datumWriterFactory;
}

@Override
public void populateDisplayData(DisplayData.Builder builder) {
filenamePolicy.populateDisplayData(builder);
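The overloaded constructor above is reached through the new AvroIO.constantDestinations overload. A sketch of calling it directly; the path, codec, metadata, and identity format function are assumptions:

```java
import java.util.Collections;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.io.AvroIO;
import org.apache.beam.sdk.io.DefaultFilenamePolicy;
import org.apache.beam.sdk.io.DynamicAvroDestinations;
import org.apache.beam.sdk.io.FileBasedSink;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;

class ConstantDestinationsExample {
  static DynamicAvroDestinations<GenericRecord, Void, GenericRecord> build() {
    Schema schema =
        SchemaBuilder.record("user").fields().requiredString("name").endRecord();
    return AvroIO.constantDestinations(
        DefaultFilenamePolicy.fromStandardParameters(
            StaticValueProvider.of(
                FileBasedSink.convertToFileResourceIfPossible("/tmp/avro/part")), // assumed path
            /* shardTemplate= */ null,
            ".avro",
            /* windowedWrites= */ false),
        schema,
        Collections.emptyMap(),     // no extra file metadata
        CodecFactory.snappyCodec(), // assumed codec choice
        r -> r,                     // identity format function
        GenericDatumWriter::new);   // the new trailing factory argument (may be null)
  }
}
```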
sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java
@@ -22,11 +22,12 @@
import org.apache.avro.file.CodecFactory;
import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* A specialization of {@link DynamicDestinations} for {@link AvroIO}. In addition to dynamic file
* destinations, this allows specifying other AVRO properties (schema, metadata, codec) per
* destination.
* destinations, this allows specifying other AVRO properties (schema, metadata, codec, datum
* writer) per destination.
*/
public abstract class DynamicAvroDestinations<UserT, DestinationT, OutputT>
extends DynamicDestinations<UserT, DestinationT, OutputT> {
@@ -42,4 +43,13 @@ public Map<String, Object> getMetadata(DestinationT destination) {
public CodecFactory getCodec(DestinationT destination) {
return AvroIO.TypedWrite.DEFAULT_CODEC;
}

/**
* Returns a {@link AvroSink.DatumWriterFactory} for a given destination. If provided, it will
* be used to create {@link org.apache.avro.io.DatumWriter} instances as required.
*/
public AvroSink.@Nullable DatumWriterFactory<OutputT> getDatumWriterFactory(
DestinationT destinationT) {
return null;
}
}
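
The default implementation returns null, so existing subclasses keep the generic/reflect behaviour. A sketch of a subclass that opts in per destination; the class name, the "color" routing field, the "legacy" destination, and the paths are hypothetical. Note the schema is kept as a JSON string because Avro's Schema is not serializable:

```java
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.io.AvroSink;
import org.apache.beam.sdk.io.DefaultFilenamePolicy;
import org.apache.beam.sdk.io.DynamicAvroDestinations;
import org.apache.beam.sdk.io.FileBasedSink;
import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.checkerframework.checker.nullness.qual.Nullable;

class ByColorDestinations
    extends DynamicAvroDestinations<GenericRecord, String, GenericRecord> {
  // Schema is not Serializable, so keep its JSON form and re-parse on demand.
  private final String schemaJson;

  ByColorDestinations(Schema schema) {
    this.schemaJson = schema.toString();
  }

  @Override
  public GenericRecord formatRecord(GenericRecord record) {
    return record;
  }

  @Override
  public String getDestination(GenericRecord element) {
    return element.get("color").toString(); // assumes a "color" routing field
  }

  @Override
  public String getDefaultDestination() {
    return "unknown";
  }

  @Override
  public FilenamePolicy getFilenamePolicy(String destination) {
    return DefaultFilenamePolicy.fromStandardParameters(
        StaticValueProvider.of(
            FileBasedSink.convertToFileResourceIfPossible("/tmp/avro/" + destination)),
        null,
        ".avro",
        false);
  }

  @Override
  public Schema getSchema(String destination) {
    return new Schema.Parser().parse(schemaJson);
  }

  @Override
  public AvroSink.@Nullable DatumWriterFactory<GenericRecord> getDatumWriterFactory(
      String destination) {
    // Returning null falls back to the sink's generic/reflect default.
    return "legacy".equals(destination) ? GenericDatumWriter::new : null;
  }
}
```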
sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java
@@ -47,12 +47,17 @@
import java.util.Objects;
import java.util.Random;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.DataFileStream;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.GenericRecordBuilder;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.Encoder;
import org.apache.avro.reflect.ReflectData;
import org.apache.avro.reflect.ReflectDatumReader;
import org.apache.beam.sdk.coders.AvroCoder;
@@ -1486,6 +1491,73 @@ public void testAvroSinkShardedWrite() throws Exception {

runTestWrite(expectedElements, 4);
}

@Test
@Category(NeedsRunner.class)
public void testAvroSinkWriteWithCustomFactory() throws Exception {
Integer[] expectedElements = new Integer[] {1, 2, 3, 4, 5};

File baseOutputFile = new File(tmpFolder.getRoot(), "prefix");
String outputFilePrefix = baseOutputFile.getAbsolutePath();

Schema recordSchema = SchemaBuilder.record("root").fields().requiredInt("i1").endRecord();

AvroIO.TypedWrite<Integer, Void, Integer> write =
AvroIO.<Integer, Integer>writeCustomType()
.to(outputFilePrefix)
.withSchema(recordSchema)
.withFormatFunction(f -> f)
.withDatumWriterFactory(
f ->
new DatumWriter<Integer>() {
private DatumWriter<GenericRecord> inner = new GenericDatumWriter<>(f);

@Override
public void setSchema(Schema schema) {
inner.setSchema(schema);
}

@Override
public void write(Integer datum, Encoder out) throws IOException {
GenericRecord record =
new GenericRecordBuilder(f).set("i1", datum).build();
inner.write(record, out);
}
})
.withSuffix(".avro");

write = write.withoutSharding();

writePipeline.apply(Create.of(ImmutableList.copyOf(expectedElements))).apply(write);
writePipeline.run();

File expectedFile =
new File(
DefaultFilenamePolicy.constructName(
FileBasedSink.convertToFileResourceIfPossible(outputFilePrefix),
"",
".avro",
1,
1,
null,
null)
.toString());

assertTrue("Expected output file " + expectedFile.getName(), expectedFile.exists());
DataFileReader<GenericRecord> dataFileReader =
new DataFileReader<>(expectedFile, new GenericDatumReader<>(recordSchema));

List<GenericRecord> actualRecords = new ArrayList<>();
Iterators.addAll(actualRecords, dataFileReader);

GenericRecord[] expectedRecords =
Arrays.stream(expectedElements)
.map(i -> new GenericRecordBuilder(recordSchema).set("i1", i).build())
.toArray(GenericRecord[]::new);

assertThat(actualRecords, containsInAnyOrder(expectedRecords));
}

// TODO: for Write only, test withSuffix,
// withShardNameTemplate and withoutSharding.
}