diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java index 2258a9136f..7db558dc5a 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java @@ -24,6 +24,7 @@ import com.google.cloud.dataflow.sdk.coders.AvroCoder; import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO.Write.Bound; import com.google.cloud.dataflow.sdk.runners.DirectPipeline; import com.google.cloud.dataflow.sdk.testing.DataflowAssert; import com.google.cloud.dataflow.sdk.testing.TestPipeline; @@ -44,6 +45,7 @@ import org.junit.runners.JUnit4; import java.io.File; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -197,30 +199,60 @@ public void testAvroIOWriteAndReadSchemaUpgrade() throws Throwable { } @SuppressWarnings("deprecation") // using AvroCoder#createDatumReader for tests. - @Test - public void testAvroSinkWrite() throws Exception { - String outputFilePrefix = new File(tmpFolder.getRoot(), "prefix").getAbsolutePath(); - String[] expectedElements = new String[] {"first", "second", "third"}; - + private void runTestWrite(String[] expectedElements, int numShards) throws IOException { + File baseOutputFile = new File(tmpFolder.getRoot(), "prefix"); + String outputFilePrefix = baseOutputFile.getAbsolutePath(); TestPipeline p = TestPipeline.create(); - p.apply(Create.of(expectedElements)) - .apply(AvroIO.Write.to(outputFilePrefix).withSchema(String.class)); + Bound write = AvroIO.Write.to(outputFilePrefix).withSchema(String.class); + if (numShards > 1) { + write = write.withNumShards(numShards); + } else { + write = write.withoutSharding(); + } + p.apply(Create.of(expectedElements)).apply(write); p.run(); // Validate that the data written matches the expected elements in the expected order - String expectedName = - IOChannelUtils.constructName( - outputFilePrefix, ShardNameTemplate.INDEX_OF_MAX, "" /* no suffix */, 0, 1); - File outputFile = new File(expectedName); - assertTrue("Expected output file " + expectedName, outputFile.exists()); - try (DataFileReader reader = - new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) { - List actualElements = new ArrayList<>(); - Iterators.addAll(actualElements, reader); - assertThat(actualElements, containsInAnyOrder(expectedElements)); + List expectedFiles = new ArrayList<>(); + if (numShards == 1) { + expectedFiles.add(baseOutputFile); + } else { + for (int i = 0; i < numShards; i++) { + expectedFiles.add( + new File( + IOChannelUtils.constructName( + outputFilePrefix, + write.getShardNameTemplate(), + "" /* no suffix */, + i, + numShards))); + } + } + + List actualElements = new ArrayList<>(); + for (File outputFile : expectedFiles) { + assertTrue("Expected output file " + outputFile.getName(), outputFile.exists()); + try (DataFileReader reader = + new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) { + Iterators.addAll(actualElements, reader); + } } + assertThat(actualElements, containsInAnyOrder(expectedElements)); } - // TODO: for Write only, test withSuffix, withNumShards, + @Test + public void testAvroSinkWrite() throws Exception { + String[] expectedElements = new String[] {"first", "second", "third"}; + + runTestWrite(expectedElements, 1); + } + + @Test + public void testAvroSinkShardedWrite() throws Exception { + String[] expectedElements = new String[] {"first", "second", "third", "fourth", "fifth"}; + + runTestWrite(expectedElements, 4); + } + // TODO: for Write only, test withSuffix, // withShardNameTemplate and withoutSharding. } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java index 0a8e381108..40d11d9866 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java @@ -42,6 +42,7 @@ import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.util.CoderUtils; import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; import com.google.cloud.dataflow.sdk.util.TestCredential; import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; import com.google.cloud.dataflow.sdk.values.PCollection; @@ -198,36 +199,57 @@ public void testReadNamed() throws Exception { } void runTestWrite(T[] elems, Coder coder) throws Exception { - File tmpFile = tmpFolder.newFile("file.txt"); - String filename = tmpFile.getPath(); + runTestWrite(elems, coder, 1); + } + + void runTestWrite(T[] elems, Coder coder, int numShards) throws Exception { + String filename = tmpFolder.newFile("file.txt").getPath(); Pipeline p = TestPipeline.create(); - PCollection input = - p.apply(Create.of(Arrays.asList(elems)).withCoder(coder)); + PCollection input = p.apply(Create.of(Arrays.asList(elems)).withCoder(coder)); TextIO.Write.Bound write; if (coder.equals(StringUtf8Coder.of())) { - TextIO.Write.Bound writeStrings = - TextIO.Write.to(filename).withoutSharding(); + TextIO.Write.Bound writeStrings = TextIO.Write.to(filename); // T==String write = (TextIO.Write.Bound) writeStrings; } else { - write = TextIO.Write.to(filename).withCoder(coder).withoutSharding(); + write = TextIO.Write.to(filename).withCoder(coder); + } + if (numShards == 1) { + write = write.withoutSharding(); + } else { + write = write.withNumShards(numShards).withShardNameTemplate(ShardNameTemplate.INDEX_OF_MAX); } input.apply(write); p.run(); + List expectedFiles = new ArrayList<>(); + if (numShards == 1) { + expectedFiles.add(new File(filename)); + } else { + for (int i = 0; i < numShards; i++) { + expectedFiles.add( + new File( + tmpFolder.getRoot(), + IOChannelUtils.constructName( + "file.txt", ShardNameTemplate.INDEX_OF_MAX, "", i, numShards))); + } + } + List actual = new ArrayList<>(); - try (BufferedReader reader = new BufferedReader(new FileReader(tmpFile))) { - for (;;) { - String line = reader.readLine(); - if (line == null) { - break; + for (File tmpFile : expectedFiles) { + try (BufferedReader reader = new BufferedReader(new FileReader(tmpFile))) { + for (;;) { + String line = reader.readLine(); + if (line == null) { + break; + } + actual.add(line); } - actual.add(line); } } @@ -239,8 +261,7 @@ void runTestWrite(T[] elems, Coder coder) throws Exception { expected[i] = line; } - assertThat(actual, - containsInAnyOrder(expected)); + assertThat(actual, containsInAnyOrder(expected)); } @Test @@ -284,6 +305,11 @@ public void testWriteNamed() { } } + @Test + public void testShardedWrite() throws Exception { + runTestWrite(LINES_ARRAY, StringUtf8Coder.of(), 5); + } + @Test public void testUnsupportedFilePattern() throws IOException { File outFolder = tmpFolder.newFolder();