Skip to content
This repository was archived by the owner on Nov 11, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 50 additions & 18 deletions sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import com.google.cloud.dataflow.sdk.coders.AvroCoder;
import com.google.cloud.dataflow.sdk.coders.DefaultCoder;
import com.google.cloud.dataflow.sdk.io.AvroIO.Write.Bound;
import com.google.cloud.dataflow.sdk.runners.DirectPipeline;
import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
import com.google.cloud.dataflow.sdk.testing.TestPipeline;
Expand All @@ -44,6 +45,7 @@
import org.junit.runners.JUnit4;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -197,30 +199,60 @@ public void testAvroIOWriteAndReadSchemaUpgrade() throws Throwable {
}

@SuppressWarnings("deprecation") // using AvroCoder#createDatumReader for tests.
@Test
public void testAvroSinkWrite() throws Exception {
String outputFilePrefix = new File(tmpFolder.getRoot(), "prefix").getAbsolutePath();
String[] expectedElements = new String[] {"first", "second", "third"};

private void runTestWrite(String[] expectedElements, int numShards) throws IOException {
File baseOutputFile = new File(tmpFolder.getRoot(), "prefix");
String outputFilePrefix = baseOutputFile.getAbsolutePath();
TestPipeline p = TestPipeline.create();
p.apply(Create.<String>of(expectedElements))
.apply(AvroIO.Write.to(outputFilePrefix).withSchema(String.class));
Bound<String> write = AvroIO.Write.to(outputFilePrefix).withSchema(String.class);
if (numShards > 1) {
write = write.withNumShards(numShards);
} else {
write = write.withoutSharding();
}
p.apply(Create.<String>of(expectedElements)).apply(write);
p.run();

// Validate that the data written matches the expected elements in the expected order
String expectedName =
IOChannelUtils.constructName(
outputFilePrefix, ShardNameTemplate.INDEX_OF_MAX, "" /* no suffix */, 0, 1);
File outputFile = new File(expectedName);
assertTrue("Expected output file " + expectedName, outputFile.exists());
try (DataFileReader<String> reader =
new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) {
List<String> actualElements = new ArrayList<>();
Iterators.addAll(actualElements, reader);
assertThat(actualElements, containsInAnyOrder(expectedElements));
List<File> expectedFiles = new ArrayList<>();
if (numShards == 1) {
expectedFiles.add(baseOutputFile);
} else {
for (int i = 0; i < numShards; i++) {
expectedFiles.add(
new File(
IOChannelUtils.constructName(
outputFilePrefix,
write.getShardNameTemplate(),
"" /* no suffix */,
i,
numShards)));
}
}

List<String> actualElements = new ArrayList<>();
for (File outputFile : expectedFiles) {
assertTrue("Expected output file " + outputFile.getName(), outputFile.exists());
try (DataFileReader<String> reader =
new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) {
Iterators.addAll(actualElements, reader);
}
}
assertThat(actualElements, containsInAnyOrder(expectedElements));
}

// TODO: for Write only, test withSuffix, withNumShards,
/**
 * Writes three string elements through {@code runTestWrite} with a shard count of one, which
 * makes the helper configure the write without sharding and look for a single output file.
 */
@Test
public void testAvroSinkWrite() throws Exception {
  final String[] unshardedInput = {"first", "second", "third"};
  runTestWrite(unshardedInput, 1);
}

/**
 * Writes five string elements through {@code runTestWrite} with four shards; the helper reads
 * every expected shard file back and compares the combined contents, ignoring order.
 */
@Test
public void testAvroSinkShardedWrite() throws Exception {
  final int shardCount = 4;
  final String[] shardedInput = {"first", "second", "third", "fourth", "fifth"};
  runTestWrite(shardedInput, shardCount);
}
// TODO: for Write only, test withSuffix and withShardNameTemplate.
}
56 changes: 41 additions & 15 deletions sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.util.CoderUtils;
import com.google.cloud.dataflow.sdk.util.GcsUtil;
import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
import com.google.cloud.dataflow.sdk.util.TestCredential;
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath;
import com.google.cloud.dataflow.sdk.values.PCollection;
Expand Down Expand Up @@ -198,36 +199,57 @@ public void testReadNamed() throws Exception {
}

<T> void runTestWrite(T[] elems, Coder<T> coder) throws Exception {
File tmpFile = tmpFolder.newFile("file.txt");
String filename = tmpFile.getPath();
runTestWrite(elems, coder, 1);
}

<T> void runTestWrite(T[] elems, Coder<T> coder, int numShards) throws Exception {
String filename = tmpFolder.newFile("file.txt").getPath();

Pipeline p = TestPipeline.create();

PCollection<T> input =
p.apply(Create.of(Arrays.asList(elems)).withCoder(coder));
PCollection<T> input = p.apply(Create.of(Arrays.asList(elems)).withCoder(coder));

TextIO.Write.Bound<T> write;
if (coder.equals(StringUtf8Coder.of())) {
TextIO.Write.Bound<String> writeStrings =
TextIO.Write.to(filename).withoutSharding();
TextIO.Write.Bound<String> writeStrings = TextIO.Write.to(filename);
// T==String
write = (TextIO.Write.Bound<T>) writeStrings;
} else {
write = TextIO.Write.to(filename).withCoder(coder).withoutSharding();
write = TextIO.Write.to(filename).withCoder(coder);
}
if (numShards == 1) {
write = write.withoutSharding();
} else {
write = write.withNumShards(numShards).withShardNameTemplate(ShardNameTemplate.INDEX_OF_MAX);
}

input.apply(write);

p.run();

List<File> expectedFiles = new ArrayList<>();
if (numShards == 1) {
expectedFiles.add(new File(filename));
} else {
for (int i = 0; i < numShards; i++) {
expectedFiles.add(
new File(
tmpFolder.getRoot(),
IOChannelUtils.constructName(
"file.txt", ShardNameTemplate.INDEX_OF_MAX, "", i, numShards)));
}
}

List<String> actual = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(new FileReader(tmpFile))) {
for (;;) {
String line = reader.readLine();
if (line == null) {
break;
for (File tmpFile : expectedFiles) {
try (BufferedReader reader = new BufferedReader(new FileReader(tmpFile))) {
for (;;) {
String line = reader.readLine();
if (line == null) {
break;
}
actual.add(line);
}
actual.add(line);
}
}

Expand All @@ -239,8 +261,7 @@ <T> void runTestWrite(T[] elems, Coder<T> coder) throws Exception {
expected[i] = line;
}

assertThat(actual,
containsInAnyOrder(expected));
assertThat(actual, containsInAnyOrder(expected));
}

@Test
Expand Down Expand Up @@ -284,6 +305,11 @@ public void testWriteNamed() {
}
}

/**
 * Runs the shared write helper on {@code LINES_ARRAY} with {@code StringUtf8Coder} and five
 * shards, so the helper takes its sharded branch instead of the single-file one.
 */
@Test
public void testShardedWrite() throws Exception {
  final int shardCount = 5;
  runTestWrite(LINES_ARRAY, StringUtf8Coder.of(), shardCount);
}

@Test
public void testUnsupportedFilePattern() throws IOException {
File outFolder = tmpFolder.newFolder();
Expand Down