From fdf1e5aa4d742d0924436995637a532b564ec31c Mon Sep 17 00:00:00 2001 From: Pablo Estrada Date: Thu, 20 Jan 2022 14:27:55 -0800 Subject: [PATCH 1/3] Revert "Revert "Merge pull request #15863 from [BEAM-13184] Autosharding for JdbcIO.write* transforms"" This reverts commit 421bc8068fc561a358cfbf6c9842408672872120. --- .../org/apache/beam/sdk/io/jdbc/JdbcIO.java | 107 +++++++++++++++--- .../org/apache/beam/sdk/io/jdbc/JdbcIOIT.java | 39 +++++++ .../apache/beam/sdk/io/jdbc/JdbcIOTest.java | 30 +++++ 3 files changed, 160 insertions(+), 16 deletions(-) diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 0f6b9c39b12d..5755754b3580 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -66,15 +66,19 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.GroupIntoBatches; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SerializableFunctions; +import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.Wait; +import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.HasDisplayData; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.BackOffUtils; import org.apache.beam.sdk.util.FluentBackoff; @@ -82,6 +86,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.Row; @@ -96,6 +101,7 @@ import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -1318,6 +1324,11 @@ public static class Write extends PTransform, PDone> { this.inner = inner; } + /** See {@link WriteVoid#withAutoSharding()}. */ + public Write withAutoSharding() { + return new Write<>(inner.withAutoSharding()); + } + /** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */ public Write withDataSourceConfiguration(DataSourceConfiguration config) { return new Write<>(inner.withDataSourceConfiguration(config)); @@ -1393,6 +1404,7 @@ public WriteWithResults withWriteResults( .setPreparedStatementSetter(inner.getPreparedStatementSetter()) .setStatement(inner.getStatement()) .setTable(inner.getTable()) + .setAutoSharding(inner.getAutoSharding()) .build(); } @@ -1408,6 +1420,50 @@ public PDone expand(PCollection input) { } } + /* The maximum number of elements that will be included in a batch. 
*/ + private static final Integer MAX_BUNDLE_SIZE = 5000; + + static PCollection> batchElements( + PCollection input, Boolean withAutoSharding) { + PCollection> iterables; + if (input.isBounded() == IsBounded.UNBOUNDED && withAutoSharding != null && withAutoSharding) { + iterables = + input + .apply(WithKeys.of("")) + .apply( + GroupIntoBatches.ofSize(DEFAULT_BATCH_SIZE) + .withMaxBufferingDuration(Duration.millis(200)) + .withShardedKey()) + .apply(Values.create()); + } else { + iterables = + input.apply( + ParDo.of( + new DoFn>() { + List outputList; + + @ProcessElement + public void process(ProcessContext c) { + if (outputList == null) { + outputList = new ArrayList<>(); + } + outputList.add(c.element()); + if (outputList.size() > MAX_BUNDLE_SIZE) { + c.output(outputList); + outputList = null; + } + } + + @FinishBundle + public void finish(FinishBundleContext c) { + c.output(outputList, Instant.now(), GlobalWindow.INSTANCE); + outputList = null; + } + })); + } + return iterables; + } + /** Interface implemented by functions that sets prepared statement data. */ @FunctionalInterface interface PreparedStatementSetCaller extends Serializable { @@ -1430,6 +1486,8 @@ void set( @AutoValue public abstract static class WriteWithResults extends PTransform, PCollection> { + abstract @Nullable Boolean getAutoSharding(); + abstract @Nullable SerializableFunction getDataSourceProviderFn(); abstract @Nullable ValueProvider getStatement(); @@ -1451,6 +1509,8 @@ abstract static class Builder { abstract Builder setDataSourceProviderFn( SerializableFunction dataSourceProviderFn); + abstract Builder setAutoSharding(Boolean autoSharding); + abstract Builder setStatement(ValueProvider statement); abstract Builder setPreparedStatementSetter(PreparedStatementSetter setter); @@ -1487,6 +1547,11 @@ public WriteWithResults withPreparedStatementSetter(PreparedStatementSette return toBuilder().setPreparedStatementSetter(setter).build(); } + /** If true, enables using a dynamically determined number of shards to write. */ + public WriteWithResults withAutoSharding() { + return toBuilder().setAutoSharding(true).build(); + } + /** * When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it * will retry the statements. 
If {@link RetryStrategy#apply(SQLException)} returns {@code true}, @@ -1549,8 +1614,14 @@ public PCollection expand(PCollection input) { checkArgument( (getDataSourceProviderFn() != null), "withDataSourceConfiguration() or withDataSourceProviderFn() is required"); + checkArgument( + getAutoSharding() == null + || (getAutoSharding() && input.isBounded() != IsBounded.UNBOUNDED), + "Autosharding is only supported for streaming pipelines."); + ; - return input.apply( + PCollection> iterables = JdbcIO.batchElements(input, getAutoSharding()); + return iterables.apply( ParDo.of( new WriteFn( WriteFnSpec.builder() @@ -1573,6 +1644,8 @@ public PCollection expand(PCollection input) { @AutoValue public abstract static class WriteVoid extends PTransform, PCollection> { + abstract @Nullable Boolean getAutoSharding(); + abstract @Nullable SerializableFunction getDataSourceProviderFn(); abstract @Nullable ValueProvider getStatement(); @@ -1591,6 +1664,8 @@ public abstract static class WriteVoid extends PTransform, PCo @AutoValue.Builder abstract static class Builder { + abstract Builder setAutoSharding(Boolean autoSharding); + abstract Builder setDataSourceProviderFn( SerializableFunction dataSourceProviderFn); @@ -1609,6 +1684,11 @@ abstract Builder setDataSourceProviderFn( abstract WriteVoid build(); } + /** If true, enables using a dynamically determined number of shards to write. */ + public WriteVoid withAutoSharding() { + return toBuilder().setAutoSharding(true).build(); + } + public WriteVoid withDataSourceConfiguration(DataSourceConfiguration config) { return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config)); } @@ -1708,7 +1788,10 @@ public PCollection expand(PCollection input) { checkArgument( spec.getPreparedStatementSetter() != null, "withPreparedStatementSetter() is required"); } - return input + + PCollection> iterables = JdbcIO.batchElements(input, getAutoSharding()); + + return iterables .apply( ParDo.of( new WriteFn( @@ -1955,7 +2038,7 @@ public void populateDisplayData(DisplayData.Builder builder) { * @param * @param */ - static class WriteFn extends DoFn { + static class WriteFn extends DoFn, V> { @AutoValue abstract static class WriteFnSpec implements Serializable, HasDisplayData { @@ -2045,7 +2128,6 @@ abstract static class Builder { private Connection connection; private PreparedStatement preparedStatement; private static FluentBackoff retryBackOff; - private final List records = new ArrayList<>(); public WriteFn(WriteFnSpec spec) { this.spec = spec; @@ -2085,17 +2167,12 @@ private Connection getConnection() throws SQLException { @ProcessElement public void processElement(ProcessContext context) throws Exception { - T record = context.element(); - records.add(record); - if (records.size() >= spec.getBatchSize()) { - executeBatch(context); - } + executeBatch(context, context.element()); } @FinishBundle public void finishBundle() throws Exception { // We pass a null context because we only execute a final batch for WriteVoid cases. 
- executeBatch(null); cleanUpStatementAndConnection(); } @@ -2124,11 +2201,8 @@ private void cleanUpStatementAndConnection() throws Exception { } } - private void executeBatch(ProcessContext context) + private void executeBatch(ProcessContext context, Iterable records) throws SQLException, IOException, InterruptedException { - if (records.isEmpty()) { - return; - } Long startTimeNs = System.nanoTime(); Sleeper sleeper = Sleeper.DEFAULT; BackOff backoff = retryBackOff.backoff(); @@ -2137,8 +2211,10 @@ private void executeBatch(ProcessContext context) getConnection().prepareStatement(spec.getStatement().get())) { try { // add each record in the statement batch + int recordsInBatch = 0; for (T record : records) { processRecord(record, preparedStatement, context); + recordsInBatch += 1; } if (!spec.getReturnResults()) { // execute the batch @@ -2146,7 +2222,7 @@ private void executeBatch(ProcessContext context) // commit the changes getConnection().commit(); } - RECORDS_PER_BATCH.update(records.size()); + RECORDS_PER_BATCH.update(recordsInBatch); MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)); break; } catch (SQLException exception) { @@ -2164,7 +2240,6 @@ private void executeBatch(ProcessContext context) } } } - records.clear(); } private void processRecord(T record, PreparedStatement preparedStatement, ProcessContext c) { diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java index 8cebbbd56b09..59bc7641acfb 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java @@ -32,6 +32,9 @@ import java.util.UUID; import java.util.function.Function; import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.common.DatabaseTestHelper; import org.apache.beam.sdk.io.common.HashingFn; @@ -39,6 +42,7 @@ import org.apache.beam.sdk.io.common.TestRow; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.testutils.NamedTestResult; import org.apache.beam.sdk.testutils.metrics.IOITMetrics; import org.apache.beam.sdk.testutils.metrics.MetricsReader; @@ -51,6 +55,7 @@ import org.apache.beam.sdk.transforms.Top; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.joda.time.Instant; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; @@ -254,6 +259,40 @@ private PipelineResult runRead() { return pipelineRead.run(); } + @Test + public void testWriteWithAutosharding() throws Exception { + String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + DatabaseTestHelper.createTable(dataSource, firstTableName); + try { + List> data = getTestDataToWrite(EXPECTED_ROW_COUNT); + TestStream.Builder> ts = + TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of())) + .advanceWatermarkTo(Instant.now()); + for (KV elm : data) { + ts.addElements(elm); + } + + PCollection> dataCollection = + pipelineWrite.apply(ts.advanceWatermarkToInfinity()); + dataCollection.apply( + JdbcIO.>write() + .withDataSourceProviderFn(voidInput -> dataSource) + 
.withStatement(String.format("insert into %s values(?, ?) returning *", tableName)) + .withAutoSharding() + .withPreparedStatementSetter( + (element, statement) -> { + statement.setInt(1, element.getKey()); + statement.setString(2, element.getValue()); + })); + + pipelineWrite.run().waitUntilFinish(); + + runRead(); + } finally { + DatabaseTestHelper.deleteTable(dataSource, firstTableName); + } + } + @Test public void testWriteWithWriteResults() throws Exception { String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 67cd1dbf0aaf..536026a6ae00 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -75,6 +75,7 @@ import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.SerializableFunction; @@ -89,6 +90,7 @@ import org.hamcrest.TypeSafeMatcher; import org.joda.time.DateTime; import org.joda.time.Duration; +import org.joda.time.Instant; import org.joda.time.LocalDate; import org.joda.time.chrono.ISOChronology; import org.junit.BeforeClass; @@ -528,6 +530,31 @@ public void testWrite() throws Exception { } } + @Test + public void testWriteWithAutosharding() throws Exception { + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + DatabaseTestHelper.createTable(DATA_SOURCE, tableName); + TestStream.Builder> ts = + TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of())) + .advanceWatermarkTo(Instant.now()); + + try { + List> data = getDataToWrite(EXPECTED_ROW_COUNT); + for (KV elm : data) { + ts = ts.addElements(elm); + } + pipeline + .apply(ts.advanceWatermarkToInfinity()) + .apply(getJdbcWrite(tableName).withAutoSharding()); + + pipeline.run().waitUntilFinish(); + + assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT); + } finally { + DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); + } + } + @Test public void testWriteWithWriteResults() throws Exception { String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); @@ -548,6 +575,9 @@ public void testWriteWithWriteResults() throws Exception { })); resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER); + PAssert.thatSingleton(resultSetCollection.apply(Count.globally())) + .isEqualTo((long) EXPECTED_ROW_COUNT); + List expectedResult = new ArrayList<>(); for (int i = 0; i < EXPECTED_ROW_COUNT; i++) { expectedResult.add(new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT)); From 9699c9ba5113b4654b426c2f2232c76852c0b8c1 Mon Sep 17 00:00:00 2001 From: Pablo Estrada Date: Thu, 20 Jan 2022 14:58:44 -0800 Subject: [PATCH 2/3] Using batchSize to define element batch size --- .../src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 5755754b3580..56af8556cc56 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ 
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -1421,7 +1421,6 @@ public PDone expand(PCollection input) { } /* The maximum number of elements that will be included in a batch. */ - private static final Integer MAX_BUNDLE_SIZE = 5000; static PCollection> batchElements( PCollection input, Boolean withAutoSharding) { @@ -1431,7 +1430,7 @@ static PCollection> batchElements( input .apply(WithKeys.of("")) .apply( - GroupIntoBatches.ofSize(DEFAULT_BATCH_SIZE) + GroupIntoBatches.ofSize(getBatchSize()) .withMaxBufferingDuration(Duration.millis(200)) .withShardedKey()) .apply(Values.create()); @@ -1448,7 +1447,7 @@ public void process(ProcessContext c) { outputList = new ArrayList<>(); } outputList.add(c.element()); - if (outputList.size() > MAX_BUNDLE_SIZE) { + if (outputList.size() > getBatchSize()) { c.output(outputList); outputList = null; } From b4bd6147268f438b3fe8ba8def583a25a8ee29c2 Mon Sep 17 00:00:00 2001 From: Pablo Estrada Date: Thu, 20 Jan 2022 17:58:49 -0800 Subject: [PATCH 3/3] Handle corner case for null list --- .../java/org/apache/beam/sdk/io/jdbc/JdbcIO.java | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 56af8556cc56..14f5e6935541 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -1423,14 +1423,14 @@ public PDone expand(PCollection input) { /* The maximum number of elements that will be included in a batch. */ static PCollection> batchElements( - PCollection input, Boolean withAutoSharding) { + PCollection input, Boolean withAutoSharding, long batchSize) { PCollection> iterables; if (input.isBounded() == IsBounded.UNBOUNDED && withAutoSharding != null && withAutoSharding) { iterables = input .apply(WithKeys.of("")) .apply( - GroupIntoBatches.ofSize(getBatchSize()) + GroupIntoBatches.ofSize(batchSize) .withMaxBufferingDuration(Duration.millis(200)) .withShardedKey()) .apply(Values.create()); @@ -1447,7 +1447,7 @@ public void process(ProcessContext c) { outputList = new ArrayList<>(); } outputList.add(c.element()); - if (outputList.size() > getBatchSize()) { + if (outputList.size() > batchSize) { c.output(outputList); outputList = null; } @@ -1455,7 +1455,9 @@ public void process(ProcessContext c) { @FinishBundle public void finish(FinishBundleContext c) { - c.output(outputList, Instant.now(), GlobalWindow.INSTANCE); + if (outputList != null && outputList.size() > 0) { + c.output(outputList, Instant.now(), GlobalWindow.INSTANCE); + } outputList = null; } })); @@ -1619,7 +1621,8 @@ public PCollection expand(PCollection input) { "Autosharding is only supported for streaming pipelines."); ; - PCollection> iterables = JdbcIO.batchElements(input, getAutoSharding()); + PCollection> iterables = + JdbcIO.batchElements(input, getAutoSharding(), DEFAULT_BATCH_SIZE); return iterables.apply( ParDo.of( new WriteFn( @@ -1788,7 +1791,8 @@ public PCollection expand(PCollection input) { spec.getPreparedStatementSetter() != null, "withPreparedStatementSetter() is required"); } - PCollection> iterables = JdbcIO.batchElements(input, getAutoSharding()); + PCollection> iterables = + JdbcIO.batchElements(input, getAutoSharding(), getBatchSize()); return iterables .apply(
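For readers trying out this change, below is a minimal, self-contained sketch of the new withAutoSharding() option on JdbcIO.write(), modeled on the testWriteWithAutosharding tests added in this patch. The driver class, connection URL, table name, and insert statement are illustrative assumptions, not values taken from the patch; per batchElements(), autosharding only takes effect on unbounded (streaming) inputs, where elements are grouped with GroupIntoBatches.withShardedKey() before being written.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.KvCoder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.coders.VarIntCoder;
    import org.apache.beam.sdk.io.jdbc.JdbcIO;
    import org.apache.beam.sdk.testing.TestStream;
    import org.apache.beam.sdk.values.KV;
    import org.joda.time.Instant;

    public class AutoShardingWriteSketch {
      public static void main(String[] args) {
        Pipeline pipeline = Pipeline.create();

        // Build an unbounded input with TestStream, as the unit/integration tests in this
        // patch do; with autosharding enabled, JdbcIO batches these elements via
        // GroupIntoBatches.withShardedKey() rather than per-bundle buffering.
        TestStream<KV<Integer, String>> input =
            TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
                .advanceWatermarkTo(Instant.now())
                .addElements(KV.of(1, "one"), KV.of(2, "two"), KV.of(3, "three"))
                .advanceWatermarkToInfinity();

        pipeline
            .apply(input)
            .apply(
                JdbcIO.<KV<Integer, String>>write()
                    .withDataSourceConfiguration(
                        JdbcIO.DataSourceConfiguration.create(
                            "org.apache.derby.jdbc.ClientDriver",        // assumed driver
                            "jdbc:derby://localhost:1527/target/beam"))  // assumed URL
                    .withStatement("insert into test_table values(?, ?)") // assumed table
                    .withAutoSharding() // new in this patch: dynamically determined shard count
                    .withPreparedStatementSetter(
                        (element, statement) -> {
                          statement.setInt(1, element.getKey());
                          statement.setString(2, element.getValue());
                        }));

        pipeline.run().waitUntilFinish();
      }
    }

Bounded inputs keep the existing behavior: batchElements() buffers elements per bundle up to the configured batch size and flushes any remainder in @FinishBundle, so withAutoSharding() is a no-op unless the pipeline is streaming.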