diff --git a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
index 6adbc59db3d6..5c2a19dd5411 100644
--- a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
+++ b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
@@ -405,12 +405,14 @@ public PCollection<GenericRecord> expand(PBegin input) {
                 .withSplit()
                 .withBeamSchemas(getInferBeamSchema())
                 .withAvroDataModel(getAvroDataModel())
-                .withProjection(getProjectionSchema(), getEncoderSchema()));
+                .withProjection(getProjectionSchema(), getEncoderSchema())
+                .withConfiguration(getConfiguration()));
       }
       return inputFiles.apply(
           readFiles(getSchema())
               .withBeamSchemas(getInferBeamSchema())
-              .withAvroDataModel(getAvroDataModel()));
+              .withAvroDataModel(getAvroDataModel())
+              .withConfiguration(getConfiguration()));
     }
 
     @Override
@@ -428,6 +430,8 @@ public abstract static class Parse<T> extends PTransform<PBegin, PCollection<T>> {
     abstract SerializableFunction<GenericRecord, T> getParseFn();
 
+    abstract @Nullable Coder<T> getCoder();
+
     abstract @Nullable SerializableConfiguration getConfiguration();
 
     abstract boolean isSplittable();
@@ -440,6 +444,8 @@ abstract static class Builder<T> {
       abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> parseFn);
 
+      abstract Builder<T> setCoder(Coder<T> coder);
+
       abstract Builder<T> setConfiguration(SerializableConfiguration configuration);
 
       abstract Builder<T> setSplittable(boolean splittable);
@@ -455,6 +461,11 @@ public Parse<T> from(String inputFiles) {
       return from(ValueProvider.StaticValueProvider.of(inputFiles));
     }
 
+    /** Specify the output coder to use for output of the {@code ParseFn}. */
+    public Parse<T> withCoder(Coder<T> coder) {
+      return (coder == null) ? this : toBuilder().setCoder(coder).build();
+    }
+
     /** Specify Hadoop configuration for ParquetReader. */
     public Parse<T> withConfiguration(Map<String, String> configuration) {
       return toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -474,6 +485,7 @@ public PCollection<T> expand(PBegin input) {
           .apply(
               parseFilesGenericRecords(getParseFn())
                   .toBuilder()
+                  .setCoder(getCoder())
                   .setSplittable(isSplittable())
                   .build());
     }
@@ -486,6 +498,8 @@ public abstract static class ParseFiles<T>
     abstract SerializableFunction<GenericRecord, T> getParseFn();
 
+    abstract @Nullable Coder<T> getCoder();
+
     abstract @Nullable SerializableConfiguration getConfiguration();
 
     abstract boolean isSplittable();
@@ -496,6 +510,8 @@ public abstract static class ParseFiles<T>
     abstract static class Builder<T> {
       abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> parseFn);
 
+      abstract Builder<T> setCoder(Coder<T> coder);
+
       abstract Builder<T> setConfiguration(SerializableConfiguration configuration);
 
       abstract Builder<T> setSplittable(boolean split);
@@ -503,6 +519,11 @@ abstract static class Builder<T> {
       abstract ParseFiles<T> build();
     }
 
+    /** Specify the output coder to use for output of the {@code ParseFn}. */
+    public ParseFiles<T> withCoder(Coder<T> coder) {
+      return (coder == null) ? this : toBuilder().setCoder(coder).build();
+    }
+
     /** Specify Hadoop configuration for ParquetReader. */
     public ParseFiles<T> withConfiguration(Map<String, String> configuration) {
       return toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -537,7 +558,7 @@ private boolean isGenericRecordOutput() {
     /**
      * Identifies the {@code Coder} to be used for the output PCollection.
      *
-     * <p>Returns {@link AvroCoder} if expected output is {@link GenericRecord}.
+     * <p>Throws an exception if expected output is of type {@link GenericRecord}.
      *
      * @param coderRegistry the {@link org.apache.beam.sdk.Pipeline}'s CoderRegistry to identify
      *     Coder for expected output type of {@link #getParseFn()}
@@ -547,12 +568,17 @@ private Coder<T> inferCoder(CoderRegistry coderRegistry) {
         throw new IllegalArgumentException("Parse can't be used for reading as GenericRecord.");
       }
 
+      // Use explicitly provided coder.
+      if (getCoder() != null) {
+        return getCoder();
+      }
+
       // If not GenericRecord infer it from ParseFn.
       try {
         return coderRegistry.getCoder(TypeDescriptors.outputOf(getParseFn()));
       } catch (CannotProvideCoderException e) {
         throw new IllegalArgumentException(
-            "Unable to infer coder for output of parseFn. Specify it explicitly using withCoder().",
+            "Unable to infer coder for output of parseFn. Specify it explicitly using .withCoder().",
             e);
       }
     }
@@ -618,6 +644,10 @@ public ReadFiles withConfiguration(Map<String, String> configuration) {
       return toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
     }
 
+    public ReadFiles withConfiguration(SerializableConfiguration configuration) {
+      return toBuilder().setConfiguration(configuration).build();
+    }
+
     @Experimental(Kind.SCHEMAS)
     public ReadFiles withBeamSchemas(boolean inferBeamSchema) {
       return toBuilder().setInferBeamSchema(inferBeamSchema).build();
diff --git a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
index f3406dfa8b57..301d1022b6ab 100644
--- a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
+++ b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
@@ -43,6 +43,8 @@
 import org.apache.beam.sdk.io.FileIO;
 import org.apache.beam.sdk.io.parquet.ParquetIO.GenericRecordPassthroughFn;
 import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -50,6 +52,7 @@
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
 import org.apache.parquet.hadoop.metadata.BlockMetaData;
 import org.junit.Rule;
 import org.junit.Test;
@@ -296,6 +299,34 @@ public void testReadFilesAsJsonForUnknownSchemaFiles() {
     mainPipeline.run().waitUntilFinish();
   }
 
+  @Test
+  public void testReadFilesAsRowForUnknownSchemaFiles() {
+    List<GenericRecord> records = generateGenericRecords(1000);
+    List<Row> expectedRows =
+        records.stream().map(record -> AvroUtils.toBeamRowStrict(record, null)).collect(toList());
+
+    PCollection<Row> writeThenRead =
+        mainPipeline
+            .apply(Create.of(records).withCoder(AvroCoder.of(SCHEMA)))
+            .apply(
+                FileIO.<GenericRecord>write()
+                    .via(ParquetIO.sink(SCHEMA))
+                    .to(temporaryFolder.getRoot().getAbsolutePath()))
+            .getPerDestinationOutputFilenames()
+            .apply(Values.create())
+            .apply(FileIO.matchAll())
+            .apply(FileIO.readMatches())
+            .apply(
+                ParquetIO.parseFilesGenericRecords(
+                        (SerializableFunction<GenericRecord, Row>)
+                            record -> AvroUtils.toBeamRowStrict(record, null))
+                    .withCoder(SchemaCoder.of(AvroUtils.toBeamSchema(SCHEMA))));
+
+    PAssert.that(writeThenRead).containsInAnyOrder(expectedRows);
+
+    mainPipeline.run().waitUntilFinish();
+  }
+
   @Test
   @SuppressWarnings({"nullable", "ConstantConditions"} /* forced check. */)
   public void testReadFilesUnknownSchemaFilesForGenericRecordThrowException() {
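
Usage note (not part of the patch): the sketch below shows how the new withCoder() hook on ParquetIO.parseGenericRecords might be called from a pipeline when the CoderRegistry cannot infer the parse function's output coder. The Avro schema, class name, and file pattern are hypothetical; the toBeamRowStrict conversion mirrors the test added above.

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.parquet.ParquetIO;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.utils.AvroUtils;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

public class ParquetParseWithCoderExample {
  public static void main(String[] args) {
    // Hypothetical Avro record schema; any record schema works the same way.
    Schema avroSchema =
        SchemaBuilder.record("Example")
            .fields()
            .requiredInt("id")
            .requiredString("name")
            .endRecord();

    Pipeline pipeline = Pipeline.create();

    // The lambda's Row output type cannot be recovered by coder inference,
    // which previously made inferCoder() throw. With this change the coder
    // is supplied explicitly via withCoder() instead.
    PCollection<Row> rows =
        pipeline.apply(
            ParquetIO.parseGenericRecords(
                    (SerializableFunction<GenericRecord, Row>)
                        record -> AvroUtils.toBeamRowStrict(record, null))
                .from("/tmp/input/*.parquet") // hypothetical file pattern
                .withCoder(SchemaCoder.of(AvroUtils.toBeamSchema(avroSchema))));

    pipeline.run().waitUntilFinish();
  }
}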