Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,8 @@ public PCollection<T> expand(PBegin input) {
getFilepattern(),
getMatchConfiguration().getEmptyMatchTreatment(),
getRecordClass(),
getSchema())));
getSchema(),
null)));
return getInferBeamSchema() ? setBeamSchema(read, getRecordClass(), getSchema()) : read;
}

Expand Down Expand Up @@ -734,9 +735,14 @@ private static <T> AvroSource<T> createSource(
ValueProvider<String> filepattern,
EmptyMatchTreatment emptyMatchTreatment,
Class<T> recordClass,
Schema schema) {
Schema schema,
@Nullable AvroSource.DatumReaderFactory<T> readerFactory) {
AvroSource<?> source =
AvroSource.from(filepattern).withEmptyMatchTreatment(emptyMatchTreatment);

if (readerFactory != null) {
source = source.withDatumReaderFactory(readerFactory);
}
return recordClass == GenericRecord.class
? (AvroSource<T>) source.withSchema(schema)
: source.withSchema(recordClass);
Expand All @@ -759,6 +765,9 @@ public abstract static class ReadFiles<T>

abstract boolean getInferBeamSchema();

@Nullable
abstract AvroSource.DatumReaderFactory<T> getDatumReaderFactory();

abstract Builder<T> toBuilder();

@AutoValue.Builder
Expand All @@ -771,6 +780,8 @@ abstract static class Builder<T> {

abstract Builder<T> setInferBeamSchema(boolean infer);

abstract Builder<T> setDatumReaderFactory(AvroSource.DatumReaderFactory<T> factory);

abstract ReadFiles<T> build();
}

Expand All @@ -788,6 +799,10 @@ public ReadFiles<T> withBeamSchemas(boolean withBeamSchemas) {
return toBuilder().setInferBeamSchema(withBeamSchemas).build();
}

public ReadFiles<T> withDatumReaderFactory(AvroSource.DatumReaderFactory<T> factory) {
return toBuilder().setDatumReaderFactory(factory).build();
}

@Override
public PCollection<T> expand(PCollection<FileIO.ReadableFile> input) {
checkNotNull(getSchema(), "schema");
Expand All @@ -796,7 +811,8 @@ public PCollection<T> expand(PCollection<FileIO.ReadableFile> input) {
"Read all via FileBasedSource",
new ReadAllViaFileBasedSource<>(
getDesiredBundleSizeBytes(),
new CreateSourceFn<>(getRecordClass(), getSchema().toString()),
new CreateSourceFn<>(
getRecordClass(), getSchema().toString(), getDatumReaderFactory()),
AvroCoder.of(getRecordClass(), getSchema())));
return getInferBeamSchema() ? setBeamSchema(read, getRecordClass(), getSchema()) : read;
}
Expand Down Expand Up @@ -913,12 +929,15 @@ private static class CreateSourceFn<T>
implements SerializableFunction<String, FileBasedSource<T>> {
private final Class<T> recordClass;
private final Supplier<Schema> schemaSupplier;
private final AvroSource.DatumReaderFactory<T> readerFactory;

CreateSourceFn(Class<T> recordClass, String jsonSchema) {
CreateSourceFn(
Class<T> recordClass, String jsonSchema, AvroSource.DatumReaderFactory<T> readerFactory) {
this.recordClass = recordClass;
this.schemaSupplier =
Suppliers.memoize(
Suppliers.compose(new JsonToSchema(), Suppliers.ofInstance(jsonSchema)));
this.readerFactory = readerFactory;
}

@Override
Expand All @@ -927,7 +946,8 @@ public FileBasedSource<T> apply(String input) {
StaticValueProvider.of(input),
EmptyMatchTreatment.DISALLOW,
recordClass,
schemaSupplier.get());
schemaSupplier.get(),
readerFactory);
}

private static class JsonToSchema implements Function<String, Schema>, Serializable {
Expand Down
86 changes: 67 additions & 19 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,15 @@ public class AvroSource<T> extends BlockBasedSource<T> {
// The default sync interval is 64k.
private static final long DEFAULT_MIN_BUNDLE_SIZE = 2L * DataFileConstants.DEFAULT_SYNC_INTERVAL;

@FunctionalInterface
public interface DatumReaderFactory<T> extends Serializable {
DatumReader<T> apply(Schema writer, Schema reader);
}

private static final DatumReaderFactory<?> GENERIC_DATUM_READER_FACTORY = GenericDatumReader::new;

private static final DatumReaderFactory<?> REFLECT_DATUM_READER_FACTORY = ReflectDatumReader::new;

// Use cases of AvroSource are:
// 1) AvroSource<GenericRecord> Reading GenericRecord records with a specified schema.
// 2) AvroSource<Foo> Reading records of a generated Avro class Foo.
Expand All @@ -147,6 +156,7 @@ public class AvroSource<T> extends BlockBasedSource<T> {
// readerSchemaString | non-null | non-null | null |
// parseFn | null | null | non-null |
// outputCoder | null | null | non-null |
// readerFactory | either | either | either |
private static class Mode<T> implements Serializable {
private final Class<?> type;

Expand All @@ -157,15 +167,19 @@ private static class Mode<T> implements Serializable {

@Nullable private final Coder<T> outputCoder;

@Nullable private final DatumReaderFactory<?> readerFactory;

private Mode(
Class<?> type,
@Nullable String readerSchemaString,
@Nullable SerializableFunction<GenericRecord, T> parseFn,
@Nullable Coder<T> outputCoder) {
@Nullable Coder<T> outputCoder,
@Nullable DatumReaderFactory<?> readerFactory) {
this.type = type;
this.readerSchemaString = internSchemaString(readerSchemaString);
this.parseFn = parseFn;
this.outputCoder = outputCoder;
this.readerFactory = readerFactory;
}

private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException {
Expand All @@ -188,19 +202,38 @@ private void validate() {
"schema must be specified using withSchema() when not using a parse fn");
}
}

private Mode<T> withReaderFactory(DatumReaderFactory<?> factory) {
return new Mode<>(type, readerSchemaString, parseFn, outputCoder, factory);
}

private DatumReader<?> createReader(Schema writerSchema, Schema readerSchema) {
DatumReaderFactory<?> factory = this.readerFactory;
if (factory == null) {
factory =
(type == GenericRecord.class)
? GENERIC_DATUM_READER_FACTORY
: REFLECT_DATUM_READER_FACTORY;
}
return factory.apply(writerSchema, readerSchema);
}
}

private static Mode<GenericRecord> readGenericRecordsWithSchema(String schema) {
return new Mode<>(GenericRecord.class, schema, null, null);
private static Mode<GenericRecord> readGenericRecordsWithSchema(
String schema, @Nullable DatumReaderFactory<?> factory) {
return new Mode<>(GenericRecord.class, schema, null, null, factory);
}

private static <T> Mode<T> readGeneratedClasses(Class<T> clazz) {
return new Mode<>(clazz, ReflectData.get().getSchema(clazz).toString(), null, null);
private static <T> Mode<T> readGeneratedClasses(
Class<T> clazz, @Nullable DatumReaderFactory<?> factory) {
return new Mode<>(clazz, ReflectData.get().getSchema(clazz).toString(), null, null, factory);
}

private static <T> Mode<T> parseGenericRecords(
SerializableFunction<GenericRecord, T> parseFn, Coder<T> outputCoder) {
return new Mode<>(GenericRecord.class, null, parseFn, outputCoder);
SerializableFunction<GenericRecord, T> parseFn,
Coder<T> outputCoder,
@Nullable DatumReaderFactory<?> factory) {
return new Mode<>(GenericRecord.class, null, parseFn, outputCoder, factory);
}

private final Mode<T> mode;
Expand All @@ -214,7 +247,7 @@ public static AvroSource<GenericRecord> from(ValueProvider<String> fileNameOrPat
fileNameOrPattern,
EmptyMatchTreatment.DISALLOW,
DEFAULT_MIN_BUNDLE_SIZE,
readGenericRecordsWithSchema(null /* will need to be specified in withSchema */));
readGenericRecordsWithSchema(null /* will need to be specified in withSchema */, null));
}

public static AvroSource<GenericRecord> from(Metadata metadata) {
Expand All @@ -223,7 +256,7 @@ public static AvroSource<GenericRecord> from(Metadata metadata) {
DEFAULT_MIN_BUNDLE_SIZE,
0,
metadata.sizeBytes(),
readGenericRecordsWithSchema(null /* will need to be specified in withSchema */));
readGenericRecordsWithSchema(null /* will need to be specified in withSchema */, null));
}

/** Like {@link #from(ValueProvider)}. */
Expand All @@ -243,7 +276,7 @@ public AvroSource<GenericRecord> withSchema(String schema) {
getFileOrPatternSpecProvider(),
getEmptyMatchTreatment(),
getMinBundleSize(),
readGenericRecordsWithSchema(schema));
readGenericRecordsWithSchema(schema, mode.readerFactory));
}

/** Like {@link #withSchema(String)}. */
Expand All @@ -261,13 +294,13 @@ public <X> AvroSource<X> withSchema(Class<X> clazz) {
getMinBundleSize(),
getStartOffset(),
getEndOffset(),
readGeneratedClasses(clazz));
readGeneratedClasses(clazz, mode.readerFactory));
}
return new AvroSource<>(
getFileOrPatternSpecProvider(),
getEmptyMatchTreatment(),
getMinBundleSize(),
readGeneratedClasses(clazz));
readGeneratedClasses(clazz, mode.readerFactory));
}

/**
Expand All @@ -284,13 +317,13 @@ public <X> AvroSource<X> withParseFn(
getMinBundleSize(),
getStartOffset(),
getEndOffset(),
parseGenericRecords(parseFn, coder));
parseGenericRecords(parseFn, coder, mode.readerFactory));
}
return new AvroSource<>(
getFileOrPatternSpecProvider(),
getEmptyMatchTreatment(),
getMinBundleSize(),
parseGenericRecords(parseFn, coder));
parseGenericRecords(parseFn, coder, mode.readerFactory));
}

/**
Expand All @@ -306,6 +339,16 @@ public AvroSource<T> withMinBundleSize(long minBundleSize) {
getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), minBundleSize, mode);
}

public AvroSource<T> withDatumReaderFactory(DatumReaderFactory<?> factory) {
Mode<T> newMode = mode.withReaderFactory(factory);
if (getMode() == SINGLE_FILE_OR_SUBRANGE) {
return new AvroSource<>(
getSingleFileMetadata(), getMinBundleSize(), getStartOffset(), getEndOffset(), newMode);
}
return new AvroSource<>(
getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), getMinBundleSize(), newMode);
}

/** Constructor for FILEPATTERN mode. */
private AvroSource(
ValueProvider<String> fileNameOrPattern,
Expand Down Expand Up @@ -576,11 +619,16 @@ private static InputStream decodeAsInputStream(byte[] data, String codec) throws
Schema readerSchema =
internOrParseSchemaString(
MoreObjects.firstNonNull(mode.readerSchemaString, writerSchemaString));
this.reader =
(mode.type == GenericRecord.class)
? new GenericDatumReader<T>(writerSchema, readerSchema)
: new ReflectDatumReader<T>(writerSchema, readerSchema);
this.decoder = DecoderFactory.get().binaryDecoder(decodeAsInputStream(data, codec), null);

this.reader = mode.createReader(writerSchema, readerSchema);

if (codec.equals(DataFileConstants.NULL_CODEC)) {
// Avro can read from a byte[] using a more efficient implementation. If the input is not
// compressed, pass the data in directly.
this.decoder = DecoderFactory.get().binaryDecoder(data, null);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, avoids the input stream wrapper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat!

} else {
this.decoder = DecoderFactory.get().binaryDecoder(decodeAsInputStream(data, codec), null);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.avro.Schema;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileConstants;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.Decoder;
import org.apache.avro.reflect.AvroDefault;
import org.apache.avro.reflect.Nullable;
import org.apache.avro.reflect.ReflectData;
Expand Down Expand Up @@ -535,6 +538,48 @@ public void testParseFn() throws Exception {
assertThat(actual, containsInAnyOrder(expected.toArray()));
}

@Test
public void testDatumReaderFactoryWithGenericRecord() throws Exception {
List<Bird> inputBirds = createRandomRecords(100);

String filename =
generateTestFile(
"tmp.avro",
inputBirds,
SyncBehavior.SYNC_DEFAULT,
0,
AvroCoder.of(Bird.class),
DataFileConstants.NULL_CODEC);

AvroSource.DatumReaderFactory<GenericRecord> factory =
(writer, reader) ->
new GenericDatumReader<GenericRecord>(writer, reader) {
@Override
protected Object readString(Object old, Decoder in) throws IOException {
return super.readString(old, in) + "_custom";
}
};

AvroSource<Bird> source =
AvroSource.from(filename)
.withParseFn(
input ->
new Bird(
(long) input.get("number"),
input.get("species").toString(),
input.get("quality").toString(),
(long) input.get("quantity")),
AvroCoder.of(Bird.class))
.withDatumReaderFactory(factory);
List<Bird> actual = SourceTestUtils.readFromSource(source, null);
List<Bird> expected =
inputBirds.stream()
.map(b -> new Bird(b.number, b.species + "_custom", b.quality + "_custom", b.quantity))
.collect(Collectors.toList());

assertThat(actual, containsInAnyOrder(expected.toArray()));
}

private void assertEqualsWithGeneric(List<Bird> expected, List<GenericRecord> actual) {
assertEquals(expected.size(), actual.size());
for (int i = 0; i < expected.size(); i++) {
Expand Down