From d2ca19961dc3b4182dc18e3bc160342bdeff2a5a Mon Sep 17 00:00:00 2001
From: steve
Date: Fri, 1 May 2020 11:00:35 -0400
Subject: [PATCH] Allow users of AvroIO to specify a custom DatumReader implementation

---
 .../java/org/apache/beam/sdk/io/AvroIO.java |  30 +++++--
 .../org/apache/beam/sdk/io/AvroSource.java  |  86 +++++++++++++++----
 .../apache/beam/sdk/io/AvroSourceTest.java  |  45 ++++++++++
 3 files changed, 137 insertions(+), 24 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
index da324e044f78..f98f1feb1ee6 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
@@ -697,7 +697,8 @@ public PCollection expand(PBegin input) {
                       getFilepattern(),
                       getMatchConfiguration().getEmptyMatchTreatment(),
                       getRecordClass(),
-                      getSchema())));
+                      getSchema(),
+                      null)));
       return getInferBeamSchema() ? setBeamSchema(read, getRecordClass(), getSchema()) : read;
     }
 
@@ -734,9 +735,14 @@ private static AvroSource createSource(
       ValueProvider filepattern,
       EmptyMatchTreatment emptyMatchTreatment,
       Class recordClass,
-      Schema schema) {
+      Schema schema,
+      @Nullable AvroSource.DatumReaderFactory readerFactory) {
     AvroSource source =
         AvroSource.from(filepattern).withEmptyMatchTreatment(emptyMatchTreatment);
+
+    if (readerFactory != null) {
+      source = source.withDatumReaderFactory(readerFactory);
+    }
     return recordClass == GenericRecord.class
         ? (AvroSource) source.withSchema(schema)
         : source.withSchema(recordClass);
@@ -759,6 +765,9 @@ public abstract static class ReadFiles
 
     abstract boolean getInferBeamSchema();
 
+    @Nullable
+    abstract AvroSource.DatumReaderFactory getDatumReaderFactory();
+
     abstract Builder toBuilder();
 
     @AutoValue.Builder
@@ -771,6 +780,8 @@ abstract static class Builder {
 
       abstract Builder setInferBeamSchema(boolean infer);
 
+      abstract Builder setDatumReaderFactory(AvroSource.DatumReaderFactory factory);
+
       abstract ReadFiles build();
     }
 
@@ -788,6 +799,10 @@ public ReadFiles withBeamSchemas(boolean withBeamSchemas) {
       return toBuilder().setInferBeamSchema(withBeamSchemas).build();
     }
 
+    public ReadFiles withDatumReaderFactory(AvroSource.DatumReaderFactory factory) {
+      return toBuilder().setDatumReaderFactory(factory).build();
+    }
+
     @Override
     public PCollection expand(PCollection input) {
       checkNotNull(getSchema(), "schema");
@@ -796,7 +811,8 @@ public PCollection expand(PCollection input) {
               "Read all via FileBasedSource",
               new ReadAllViaFileBasedSource<>(
                   getDesiredBundleSizeBytes(),
-                  new CreateSourceFn<>(getRecordClass(), getSchema().toString()),
+                  new CreateSourceFn<>(
+                      getRecordClass(), getSchema().toString(), getDatumReaderFactory()),
                   AvroCoder.of(getRecordClass(), getSchema())));
       return getInferBeamSchema() ? setBeamSchema(read, getRecordClass(), getSchema()) : read;
     }
 
@@ -913,12 +929,15 @@ private static class CreateSourceFn
       implements SerializableFunction> {
     private final Class recordClass;
     private final Supplier schemaSupplier;
+    private final AvroSource.DatumReaderFactory readerFactory;
 
-    CreateSourceFn(Class recordClass, String jsonSchema) {
+    CreateSourceFn(
+        Class recordClass, String jsonSchema, AvroSource.DatumReaderFactory readerFactory) {
      this.recordClass = recordClass;
      this.schemaSupplier =
          Suppliers.memoize(
              Suppliers.compose(new JsonToSchema(), Suppliers.ofInstance(jsonSchema)));
+      this.readerFactory = readerFactory;
     }
 
     @Override
@@ -927,7 +946,8 @@ public FileBasedSource apply(String input) {
           StaticValueProvider.of(input),
           EmptyMatchTreatment.DISALLOW,
           recordClass,
-          schemaSupplier.get());
+          schemaSupplier.get(),
+          readerFactory);
     }
 
     private static class JsonToSchema implements Function, Serializable {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
index 3ec0567af91e..2243d6573ade 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
@@ -137,6 +137,15 @@ public class AvroSource extends BlockBasedSource {
   // The default sync interval is 64k.
   private static final long DEFAULT_MIN_BUNDLE_SIZE = 2L * DataFileConstants.DEFAULT_SYNC_INTERVAL;
 
+  @FunctionalInterface
+  public interface DatumReaderFactory extends Serializable {
+    DatumReader apply(Schema writer, Schema reader);
+  }
+
+  private static final DatumReaderFactory GENERIC_DATUM_READER_FACTORY = GenericDatumReader::new;
+
+  private static final DatumReaderFactory REFLECT_DATUM_READER_FACTORY = ReflectDatumReader::new;
+
   // Use cases of AvroSource are:
   // 1) AvroSource Reading GenericRecord records with a specified schema.
   // 2) AvroSource Reading records of a generated Avro class Foo.
@@ -147,6 +156,7 @@ public class AvroSource extends BlockBasedSource {
   //   readerSchemaString  |  non-null  |  non-null  |    null    |
   //   parseFn             |    null    |    null    |  non-null  |
   //   outputCoder         |    null    |    null    |  non-null  |
+  //   readerFactory       |   either   |   either   |   either   |
   private static class Mode implements Serializable {
     private final Class type;
 
@@ -157,15 +167,19 @@ private static class Mode implements Serializable {
 
     @Nullable private final Coder outputCoder;
 
+    @Nullable private final DatumReaderFactory readerFactory;
+
     private Mode(
         Class type,
         @Nullable String readerSchemaString,
         @Nullable SerializableFunction parseFn,
-        @Nullable Coder outputCoder) {
+        @Nullable Coder outputCoder,
+        @Nullable DatumReaderFactory readerFactory) {
       this.type = type;
       this.readerSchemaString = internSchemaString(readerSchemaString);
       this.parseFn = parseFn;
       this.outputCoder = outputCoder;
+      this.readerFactory = readerFactory;
     }
 
     private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException {
@@ -188,19 +202,38 @@ private void validate() {
             "schema must be specified using withSchema() when not using a parse fn");
       }
     }
+
+    private Mode withReaderFactory(DatumReaderFactory factory) {
+      return new Mode<>(type, readerSchemaString, parseFn, outputCoder, factory);
+    }
+
+    private DatumReader createReader(Schema writerSchema, Schema readerSchema) {
+      DatumReaderFactory factory = this.readerFactory;
+      if (factory == null) {
+        factory =
+            (type == GenericRecord.class)
+                ? GENERIC_DATUM_READER_FACTORY
+                : REFLECT_DATUM_READER_FACTORY;
+      }
+      return factory.apply(writerSchema, readerSchema);
+    }
   }
 
-  private static Mode readGenericRecordsWithSchema(String schema) {
-    return new Mode<>(GenericRecord.class, schema, null, null);
+  private static Mode readGenericRecordsWithSchema(
+      String schema, @Nullable DatumReaderFactory factory) {
+    return new Mode<>(GenericRecord.class, schema, null, null, factory);
   }
 
-  private static Mode readGeneratedClasses(Class clazz) {
-    return new Mode<>(clazz, ReflectData.get().getSchema(clazz).toString(), null, null);
+  private static Mode readGeneratedClasses(
+      Class clazz, @Nullable DatumReaderFactory factory) {
+    return new Mode<>(clazz, ReflectData.get().getSchema(clazz).toString(), null, null, factory);
   }
 
   private static Mode parseGenericRecords(
-      SerializableFunction parseFn, Coder outputCoder) {
-    return new Mode<>(GenericRecord.class, null, parseFn, outputCoder);
+      SerializableFunction parseFn,
+      Coder outputCoder,
+      @Nullable DatumReaderFactory factory) {
+    return new Mode<>(GenericRecord.class, null, parseFn, outputCoder, factory);
   }
 
   private final Mode mode;
@@ -214,7 +247,7 @@ public static AvroSource from(ValueProvider fileNameOrPat
         fileNameOrPattern,
         EmptyMatchTreatment.DISALLOW,
         DEFAULT_MIN_BUNDLE_SIZE,
-        readGenericRecordsWithSchema(null /* will need to be specified in withSchema */));
+        readGenericRecordsWithSchema(null /* will need to be specified in withSchema */, null));
   }
 
   public static AvroSource from(Metadata metadata) {
@@ -223,7 +256,7 @@ public static AvroSource from(Metadata metadata) {
         DEFAULT_MIN_BUNDLE_SIZE,
         0,
         metadata.sizeBytes(),
-        readGenericRecordsWithSchema(null /* will need to be specified in withSchema */));
+        readGenericRecordsWithSchema(null /* will need to be specified in withSchema */, null));
   }
 
   /** Like {@link #from(ValueProvider)}. */
@@ -243,7 +276,7 @@ public AvroSource withSchema(String schema) {
         getFileOrPatternSpecProvider(),
         getEmptyMatchTreatment(),
         getMinBundleSize(),
-        readGenericRecordsWithSchema(schema));
+        readGenericRecordsWithSchema(schema, mode.readerFactory));
   }
 
   /** Like {@link #withSchema(String)}. */
@@ -261,13 +294,13 @@ public AvroSource withSchema(Class clazz) {
           getMinBundleSize(),
           getStartOffset(),
           getEndOffset(),
-          readGeneratedClasses(clazz));
+          readGeneratedClasses(clazz, mode.readerFactory));
     }
     return new AvroSource<>(
         getFileOrPatternSpecProvider(),
         getEmptyMatchTreatment(),
         getMinBundleSize(),
-        readGeneratedClasses(clazz));
+        readGeneratedClasses(clazz, mode.readerFactory));
   }
 
   /**
@@ -284,13 +317,13 @@ public AvroSource withParseFn(
           getMinBundleSize(),
           getStartOffset(),
           getEndOffset(),
-          parseGenericRecords(parseFn, coder));
+          parseGenericRecords(parseFn, coder, mode.readerFactory));
     }
     return new AvroSource<>(
         getFileOrPatternSpecProvider(),
         getEmptyMatchTreatment(),
         getMinBundleSize(),
-        parseGenericRecords(parseFn, coder));
+        parseGenericRecords(parseFn, coder, mode.readerFactory));
   }
 
   /**
@@ -306,6 +339,16 @@ public AvroSource withMinBundleSize(long minBundleSize) {
         getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), minBundleSize, mode);
   }
 
+  public AvroSource withDatumReaderFactory(DatumReaderFactory factory) {
+    Mode newMode = mode.withReaderFactory(factory);
+    if (getMode() == SINGLE_FILE_OR_SUBRANGE) {
+      return new AvroSource<>(
+          getSingleFileMetadata(), getMinBundleSize(), getStartOffset(), getEndOffset(), newMode);
+    }
+    return new AvroSource<>(
+        getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), getMinBundleSize(), newMode);
+  }
+
   /** Constructor for FILEPATTERN mode. */
   private AvroSource(
       ValueProvider fileNameOrPattern,
@@ -576,11 +619,16 @@ private static InputStream decodeAsInputStream(byte[] data, String codec) throws
       Schema readerSchema =
           internOrParseSchemaString(
               MoreObjects.firstNonNull(mode.readerSchemaString, writerSchemaString));
-      this.reader =
-          (mode.type == GenericRecord.class)
-              ? new GenericDatumReader(writerSchema, readerSchema)
-              : new ReflectDatumReader(writerSchema, readerSchema);
-      this.decoder = DecoderFactory.get().binaryDecoder(decodeAsInputStream(data, codec), null);
+
+      this.reader = mode.createReader(writerSchema, readerSchema);
+
+      if (codec.equals(DataFileConstants.NULL_CODEC)) {
+        // Avro can read from a byte[] using a more efficient implementation. If the input is not
+        // compressed, pass the data in directly.
+        this.decoder = DecoderFactory.get().binaryDecoder(data, null);
+      } else {
+        this.decoder = DecoderFactory.get().binaryDecoder(decodeAsInputStream(data, codec), null);
+      }
     }
 
     @Override
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java
index 398dc65d4924..8503870d9328 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java
@@ -36,13 +36,16 @@
 import java.util.NoSuchElementException;
 import java.util.Objects;
 import java.util.Random;
+import java.util.stream.Collectors;
 import org.apache.avro.Schema;
 import org.apache.avro.file.CodecFactory;
 import org.apache.avro.file.DataFileConstants;
 import org.apache.avro.file.DataFileWriter;
+import org.apache.avro.generic.GenericDatumReader;
 import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.avro.generic.GenericRecord;
 import org.apache.avro.io.DatumWriter;
+import org.apache.avro.io.Decoder;
 import org.apache.avro.reflect.AvroDefault;
 import org.apache.avro.reflect.Nullable;
 import org.apache.avro.reflect.ReflectData;
@@ -535,6 +538,48 @@ public void testParseFn() throws Exception {
     assertThat(actual, containsInAnyOrder(expected.toArray()));
   }
 
+  @Test
+  public void testDatumReaderFactoryWithGenericRecord() throws Exception {
+    List inputBirds = createRandomRecords(100);
+
+    String filename =
+        generateTestFile(
+            "tmp.avro",
+            inputBirds,
+            SyncBehavior.SYNC_DEFAULT,
+            0,
+            AvroCoder.of(Bird.class),
+            DataFileConstants.NULL_CODEC);
+
+    AvroSource.DatumReaderFactory factory =
+        (writer, reader) ->
+            new GenericDatumReader(writer, reader) {
+              @Override
+              protected Object readString(Object old, Decoder in) throws IOException {
+                return super.readString(old, in) + "_custom";
+              }
+            };
+
+    AvroSource source =
+        AvroSource.from(filename)
+            .withParseFn(
+                input ->
+                    new Bird(
+                        (long) input.get("number"),
+                        input.get("species").toString(),
+                        input.get("quality").toString(),
+                        (long) input.get("quantity")),
+                AvroCoder.of(Bird.class))
+            .withDatumReaderFactory(factory);
+    List actual = SourceTestUtils.readFromSource(source, null);
+    List expected =
+        inputBirds.stream()
+            .map(b -> new Bird(b.number, b.species + "_custom", b.quality + "_custom", b.quantity))
+            .collect(Collectors.toList());
+
+    assertThat(actual, containsInAnyOrder(expected.toArray()));
+  }
+
   private void assertEqualsWithGeneric(List expected, List actual) {
     assertEquals(expected.size(), actual.size());
     for (int i = 0; i < expected.size(); i++) {
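
Usage sketch (not part of the patch): the snippet below mirrors the pattern exercised by testDatumReaderFactoryWithGenericRecord, but at pipeline level through AvroIO.readFilesGenericRecords. It assumes the parameterized form of the new hook (DatumReaderFactory<T> on AvroSource, withDatumReaderFactory on ReadFiles<T>); the class name, schema literal, file pattern, and string-trimming reader are illustrative placeholders only.

// Illustrative sketch only; the schema, file pattern, and trimming reader below are
// placeholders and are not part of the patch.
import java.io.IOException;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.Decoder;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.AvroIO;
import org.apache.beam.sdk.io.AvroSource;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.values.PCollection;

public class CustomDatumReaderPipeline {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    // Reader schema for the files being read (placeholder definition).
    Schema schema =
        new Schema.Parser()
            .parse(
                "{\"type\":\"record\",\"name\":\"Bird\",\"fields\":["
                    + "{\"name\":\"species\",\"type\":\"string\"},"
                    + "{\"name\":\"quantity\",\"type\":\"long\"}]}");

    // Custom reader plugged in through the new factory hook: decodes strings as usual,
    // then trims surrounding whitespace.
    AvroSource.DatumReaderFactory<GenericRecord> readerFactory =
        (writer, reader) ->
            new GenericDatumReader<GenericRecord>(writer, reader) {
              @Override
              protected Object readString(Object old, Decoder in) throws IOException {
                return super.readString(old, in).toString().trim();
              }
            };

    PCollection<GenericRecord> records =
        p.apply(FileIO.match().filepattern("/tmp/birds-*.avro"))
            .apply(FileIO.readMatches())
            .apply(
                "ReadWithCustomReader",
                AvroIO.readFilesGenericRecords(schema).withDatumReaderFactory(readerFactory));

    p.run().waitUntilFinish();
  }
}

The same factory can also be handed directly to an AvroSource via withDatumReaderFactory when constructing a source by hand, as the new unit test does.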