diff --git a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java index 9f9e64fc4a2e..3cfe08ba4d11 100644 --- a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java +++ b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java @@ -150,7 +150,7 @@ public static void createTableWithStatement(DataSource dataSource, String stmt) public static ArrayList> getTestDataToWrite(long rowsToAdd) { ArrayList> data = new ArrayList<>(); for (int i = 0; i < rowsToAdd; i++) { - KV kv = KV.of(i, "Test"); + KV kv = KV.of(i, TestRow.getNameForSeed(i)); data.add(kv); } return data; diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java index 2ecdde7626ed..288f1467fa30 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java @@ -25,7 +25,6 @@ import com.google.cloud.Timestamp; import java.sql.SQLException; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -51,8 +50,8 @@ import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.FlatMapElements; import org.apache.beam.sdk.transforms.Impulse; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -62,6 +61,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.joda.time.Duration; import org.joda.time.Instant; @@ -96,8 +96,6 @@ @RunWith(JUnit4.class) public class JdbcIOIT { - // the number of rows written to table in normal integration tests (not the performance test). - private static final int EXPECTED_ROW_COUNT = 1000; private static final String NAMESPACE = JdbcIOIT.class.getName(); // the number of rows written to table in the performance test. private static int numberOfRows; @@ -117,6 +115,7 @@ public static void setup() { } org.junit.Assume.assumeNotNull(options); numberOfRows = options.getNumberOfRecords(); + dataSource = DatabaseTestHelper.getPostgresDataSource(options); tableName = DatabaseTestHelper.getTestTableName("IT"); settings = @@ -137,7 +136,7 @@ public void testWriteThenRead() throws SQLException { try { PipelineResult writeResult = runWrite(); PipelineResult.State writeState = writeResult.waitUntilFinish(); - PipelineResult readResult = runRead(); + PipelineResult readResult = runRead(tableName); PipelineResult.State readState = readResult.waitUntilFinish(); gatherAndPublishMetrics(writeResult, readResult); // Fail the test if pipeline failed. @@ -234,7 +233,10 @@ private PipelineResult runWrite() { * verify that their values are correct. Where first/last 500 rows is determined by the fact that * we know all rows have a unique id - we can use the natural ordering of that key. */ - private PipelineResult runRead() { + private PipelineResult runRead(String tableName) { + if (tableName == null) { + tableName = JdbcIOIT.tableName; + } PCollection namesAndIds = pipelineRead .apply( @@ -382,9 +384,15 @@ public void testWriteWithWriteResults() throws Exception { String firstTableName = DatabaseTestHelper.getTestTableName("JDBCIT_WRITE"); DatabaseTestHelper.createTable(dataSource, firstTableName); try { - ArrayList> data = getTestDataToWrite(EXPECTED_ROW_COUNT); - PCollection> dataCollection = pipelineWrite.apply(Create.of(data)); + PCollection> dataCollection = + pipelineWrite + .apply(GenerateSequence.from(0).to(numberOfRows)) + .apply( + FlatMapElements.into( + TypeDescriptors.kvs( + TypeDescriptors.integers(), TypeDescriptors.strings())) + .via(num -> getTestDataToWrite(1))); PCollection resultSetCollection = dataCollection.apply( getJdbcWriteWithReturning(firstTableName) @@ -397,16 +405,12 @@ public void testWriteWithWriteResults() throws Exception { })); resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER); - List expectedResult = new ArrayList<>(); - for (int id = 0; id < EXPECTED_ROW_COUNT; id++) { - expectedResult.add(new JdbcTestHelper.TestDto(id)); - } - - PAssert.that(resultSetCollection).containsInAnyOrder(expectedResult); + PAssert.that(resultSetCollection.apply(Count.globally())) + .containsInAnyOrder(Long.valueOf(numberOfRows)); pipelineWrite.run().waitUntilFinish(); - assertRowCount(dataSource, firstTableName, EXPECTED_ROW_COUNT); + assertRowCount(dataSource, firstTableName, numberOfRows); } finally { DatabaseTestHelper.deleteTable(dataSource, firstTableName); }