@@ -405,12 +405,14 @@ public PCollection<GenericRecord> expand(PBegin input) {
.withSplit()
.withBeamSchemas(getInferBeamSchema())
.withAvroDataModel(getAvroDataModel())
.withProjection(getProjectionSchema(), getEncoderSchema()));
.withProjection(getProjectionSchema(), getEncoderSchema())
.withConfiguration(getConfiguration()));
}
return inputFiles.apply(
readFiles(getSchema())
.withBeamSchemas(getInferBeamSchema())
.withAvroDataModel(getAvroDataModel()));
.withAvroDataModel(getAvroDataModel())
.withConfiguration(getConfiguration()));
}
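The hunk above threads the Hadoop configuration from Read into the underlying ReadFiles. As a minimal usage sketch of the resulting surface (assuming Read exposes the same withConfiguration(Map<String, String>) shown below for Parse and ReadFiles, and assuming a pipeline and Avro schema in scope; the filepattern and the Parquet option key are illustrative only):

// Assumed imports: java.util.Collections, org.apache.avro.generic.GenericRecord,
// org.apache.beam.sdk.io.parquet.ParquetIO, org.apache.beam.sdk.values.PCollection.
PCollection<GenericRecord> records =
    pipeline.apply(
        ParquetIO.read(schema)
            .from("gs://my-bucket/input/*.parquet") // hypothetical path
            .withConfiguration(
                // Entries reach the ParquetReader via SerializableConfiguration.fromMap(...),
                // as in expand() above.
                Collections.singletonMap("parquet.avro.readInt96AsFixed", "true")));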

@Override
@@ -428,6 +430,8 @@ public abstract static class Parse<T> extends PTransform<PBegin, PCollection<T>>

abstract SerializableFunction<GenericRecord, T> getParseFn();

abstract @Nullable Coder<T> getCoder();

abstract @Nullable SerializableConfiguration getConfiguration();

abstract boolean isSplittable();
@@ -440,6 +444,8 @@ abstract static class Builder<T> {

abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> parseFn);

abstract Builder<T> setCoder(Coder<T> coder);

abstract Builder<T> setConfiguration(SerializableConfiguration configuration);

abstract Builder<T> setSplittable(boolean splittable);
@@ -455,6 +461,11 @@ public Parse<T> from(String inputFiles) {
return from(ValueProvider.StaticValueProvider.of(inputFiles));
}

/** Specify the output coder to use for output of the {@code ParseFn}. */
public Parse<T> withCoder(Coder<T> coder) {
return (coder == null) ? this : toBuilder().setCoder(coder).build();
}
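A short sketch of the new withCoder hook on Parse, assuming a pipeline in scope; the filepattern is a placeholder and StringUtf8Coder stands in for whatever coder matches the parse function's output:

// Assumed imports: org.apache.beam.sdk.coders.StringUtf8Coder,
// org.apache.beam.sdk.io.parquet.ParquetIO,
// org.apache.beam.sdk.transforms.SerializableFunction.
PCollection<String> parsed =
    pipeline.apply(
        ParquetIO.parseGenericRecords(
                (SerializableFunction<GenericRecord, String>) GenericRecord::toString)
            .from("/tmp/input/*.parquet") // hypothetical path
            // A lambda's output type is erased at runtime, so registry inference can
            // fail; the explicit coder added by this change sidesteps that.
            .withCoder(StringUtf8Coder.of()));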

/** Specify Hadoop configuration for ParquetReader. */
public Parse<T> withConfiguration(Map<String, String> configuration) {
return toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -474,6 +485,7 @@ public PCollection<T> expand(PBegin input) {
.apply(
parseFilesGenericRecords(getParseFn())
.toBuilder()
.setCoder(getCoder())
.setSplittable(isSplittable())
.build());
}
@@ -486,6 +498,8 @@ public abstract static class ParseFiles<T>

abstract SerializableFunction<GenericRecord, T> getParseFn();

abstract @Nullable Coder<T> getCoder();

abstract @Nullable SerializableConfiguration getConfiguration();

abstract boolean isSplittable();
@@ -496,13 +510,20 @@ abstract static class Builder<T> {
abstract static class Builder<T> {
abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> parseFn);

abstract Builder<T> setCoder(Coder<T> coder);

abstract Builder<T> setConfiguration(SerializableConfiguration configuration);

abstract Builder<T> setSplittable(boolean split);

abstract ParseFiles<T> build();
}

/** Specify the output coder to use for output of the {@code ParseFn}. */
public ParseFiles<T> withCoder(Coder<T> coder) {
return (coder == null) ? this : toBuilder().setCoder(coder).build();
}
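For ParseFiles the pattern is the same downstream of FileIO.readMatches(). This sketch mirrors the new test added below, with SCHEMA standing in for an Avro schema known to the caller and a hypothetical filepattern:

// Assumed imports: FileIO, ParquetIO, AvroUtils, SchemaCoder, Row,
// SerializableFunction, GenericRecord.
PCollection<Row> rows =
    pipeline
        .apply(FileIO.match().filepattern("/tmp/input/*.parquet")) // hypothetical path
        .apply(FileIO.readMatches())
        .apply(
            ParquetIO.parseFilesGenericRecords(
                    (SerializableFunction<GenericRecord, Row>)
                        record -> AvroUtils.toBeamRowStrict(record, null))
                // The registry cannot infer a coder for Row without a schema,
                // so the explicit SchemaCoder is required here.
                .withCoder(SchemaCoder.of(AvroUtils.toBeamSchema(SCHEMA))));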

/** Specify Hadoop configuration for ParquetReader. */
public ParseFiles<T> withConfiguration(Map<String, String> configuration) {
return toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -537,7 +558,7 @@ private boolean isGenericRecordOutput() {
/**
* Identifies the {@code Coder} to be used for the output PCollection.
*
* <p>Returns {@link AvroCoder} if expected output is {@link GenericRecord}.
* <p>Throws an exception if expected output is of type {@link GenericRecord}.
*
* @param coderRegistry the {@link org.apache.beam.sdk.Pipeline}'s CoderRegistry to identify
* Coder for expected output type of {@link #getParseFn()}
@@ -547,12 +568,17 @@ private Coder<T> inferCoder(CoderRegistry coderRegistry) {
throw new IllegalArgumentException("Parse can't be used for reading as GenericRecord.");
}

// Use explicitly provided coder
if (getCoder() != null) {
return getCoder();
}

// If not GenericRecord infer it from ParseFn.
try {
return coderRegistry.getCoder(TypeDescriptors.outputOf(getParseFn()));
} catch (CannotProvideCoderException e) {
throw new IllegalArgumentException(
"Unable to infer coder for output of parseFn. Specify it explicitly using withCoder().",
"Unable to infer coder for output of parseFn. Specify it explicitly using .withCoder().",
e);
}
}
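The resolution order in inferCoder is: GenericRecord outputs are rejected outright; an explicit coder from withCoder() wins next; otherwise the registry infers one from the parse function's output type, and the IllegalArgumentException above is the last resort. One sketch of the inference path: when parseFn is an anonymous class rather than a lambda, its output type survives erasure and the registry can usually infer the coder on its own (the record field "name" is purely illustrative):

ParquetIO.parseGenericRecords(
        new SerializableFunction<GenericRecord, String>() {
          @Override
          public String apply(GenericRecord record) {
            return String.valueOf(record.get("name")); // "name" is a hypothetical field
          }
        })
    .from("/tmp/input/*.parquet"); // hypothetical path; StringUtf8Coder inferred here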
@@ -618,6 +644,10 @@ public ReadFiles withConfiguration(Map<String, String> configuration) {
return toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
}

public ReadFiles withConfiguration(SerializableConfiguration configuration) {
return toBuilder().setConfiguration(configuration).build();
}
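The new overload takes a prebuilt SerializableConfiguration directly, which avoids flattening an existing Hadoop Configuration into a map first. A sketch with an illustrative option key, assuming a schema in scope and SerializableConfiguration's constructor taking a Hadoop Configuration:

// Assumed imports: org.apache.hadoop.conf.Configuration,
// org.apache.beam.sdk.io.hadoop.SerializableConfiguration.
Configuration hadoopConf = new Configuration(false); // skip loading default resources
hadoopConf.set("parquet.avro.readInt96AsFixed", "true"); // illustrative option
ParquetIO.ReadFiles readFiles =
    ParquetIO.readFiles(schema).withConfiguration(new SerializableConfiguration(hadoopConf));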

@Experimental(Kind.SCHEMAS)
public ReadFiles withBeamSchemas(boolean inferBeamSchema) {
return toBuilder().setInferBeamSchema(inferBeamSchema).build();
@@ -43,13 +43,16 @@
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.parquet.ParquetIO.GenericRecordPassthroughFn;
import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.utils.AvroUtils;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.junit.Rule;
import org.junit.Test;
@@ -296,6 +299,34 @@ public void testReadFilesAsJsonForUnknownSchemaFiles() {
mainPipeline.run().waitUntilFinish();
}

@Test
public void testReadFilesAsRowForUnknownSchemaFiles() {
List<GenericRecord> records = generateGenericRecords(1000);
List<Row> expectedRows =
records.stream().map(record -> AvroUtils.toBeamRowStrict(record, null)).collect(toList());

PCollection<Row> writeThenRead =
mainPipeline
.apply(Create.of(records).withCoder(AvroCoder.of(SCHEMA)))
.apply(
FileIO.<GenericRecord>write()
.via(ParquetIO.sink(SCHEMA))
.to(temporaryFolder.getRoot().getAbsolutePath()))
.getPerDestinationOutputFilenames()
.apply(Values.create())
.apply(FileIO.matchAll())
.apply(FileIO.readMatches())
.apply(
ParquetIO.parseFilesGenericRecords(
(SerializableFunction<GenericRecord, Row>)
record -> AvroUtils.toBeamRowStrict(record, null))
.withCoder(SchemaCoder.of(AvroUtils.toBeamSchema(SCHEMA))));

PAssert.that(writeThenRead).containsInAnyOrder(expectedRows);

mainPipeline.run().waitUntilFinish();
}

@Test
@SuppressWarnings({"nullable", "ConstantConditions"} /* forced check. */)
public void testReadFilesUnknownSchemaFilesForGenericRecordThrowException() {