diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 0f6b9c39b12d..14f5e6935541 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -66,15 +66,19 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Filter;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Wait;
+import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.sdk.util.BackOffUtils;
 import org.apache.beam.sdk.util.FluentBackoff;
@@ -82,6 +86,7 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
@@ -96,6 +101,7 @@ import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -1318,6 +1324,11 @@ public static class Write<T> extends PTransform<PCollection<T>, PDone> {
       this.inner = inner;
     }
 
+    /** See {@link WriteVoid#withAutoSharding()}. */
+    public Write<T> withAutoSharding() {
+      return new Write<>(inner.withAutoSharding());
+    }
+
     /** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */
     public Write<T> withDataSourceConfiguration(DataSourceConfiguration config) {
       return new Write<>(inner.withDataSourceConfiguration(config));
@@ -1393,6 +1404,7 @@ public WriteWithResults<T, V> withWriteResults(
           .setPreparedStatementSetter(inner.getPreparedStatementSetter())
           .setStatement(inner.getStatement())
           .setTable(inner.getTable())
+          .setAutoSharding(inner.getAutoSharding())
           .build();
     }
 
@@ -1408,6 +1420,51 @@ public PDone expand(PCollection<T> input) {
     }
   }
 
+  /* The maximum number of elements that will be included in a batch. */
+
+  static <T> PCollection<Iterable<T>> batchElements(
+      PCollection<T> input, Boolean withAutoSharding, long batchSize) {
+    PCollection<Iterable<T>> iterables;
+    if (input.isBounded() == IsBounded.UNBOUNDED && withAutoSharding != null && withAutoSharding) {
+      iterables =
+          input
+              .apply(WithKeys.of(""))
+              .apply(
+                  GroupIntoBatches.<String, T>ofSize(batchSize)
+                      .withMaxBufferingDuration(Duration.millis(200))
+                      .withShardedKey())
+              .apply(Values.create());
+    } else {
+      iterables =
+          input.apply(
+              ParDo.of(
+                  new DoFn<T, Iterable<T>>() {
+                    List<T> outputList;
+
+                    @ProcessElement
+                    public void process(ProcessContext c) {
+                      if (outputList == null) {
+                        outputList = new ArrayList<>();
+                      }
+                      outputList.add(c.element());
+                      if (outputList.size() > batchSize) {
+                        c.output(outputList);
+                        outputList = null;
+                      }
+                    }
+
+                    @FinishBundle
+                    public void finish(FinishBundleContext c) {
+                      if (outputList != null && outputList.size() > 0) {
+                        c.output(outputList, Instant.now(), GlobalWindow.INSTANCE);
+                      }
+                      outputList = null;
+                    }
+                  }));
+    }
+    return iterables;
+  }
+
   /** Interface implemented by functions that sets prepared statement data. */
   @FunctionalInterface
   interface PreparedStatementSetCaller extends Serializable {
@@ -1430,6 +1487,8 @@ void set(
   @AutoValue
   public abstract static class WriteWithResults<T, V extends JdbcWriteResult>
       extends PTransform<PCollection<T>, PCollection<V>> {
+    abstract @Nullable Boolean getAutoSharding();
+
     abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();
 
     abstract @Nullable ValueProvider<String> getStatement();
@@ -1451,6 +1510,8 @@ abstract static class Builder<T, V extends JdbcWriteResult> {
       abstract Builder<T, V> setDataSourceProviderFn(
          SerializableFunction<Void, DataSource> dataSourceProviderFn);
 
+      abstract Builder<T, V> setAutoSharding(Boolean autoSharding);
+
       abstract Builder<T, V> setStatement(ValueProvider<String> statement);
 
       abstract Builder<T, V> setPreparedStatementSetter(PreparedStatementSetter<T> setter);
@@ -1487,6 +1548,11 @@ public WriteWithResults<T, V> withPreparedStatementSetter(PreparedStatementSette
       return toBuilder().setPreparedStatementSetter(setter).build();
     }
 
+    /** If true, enables using a dynamically determined number of shards to write. */
+    public WriteWithResults<T, V> withAutoSharding() {
+      return toBuilder().setAutoSharding(true).build();
+    }
+
     /**
      * When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it
      * will retry the statements. If {@link RetryStrategy#apply(SQLException)} returns {@code true},
@@ -1549,8 +1615,15 @@ public PCollection<V> expand(PCollection<T> input) {
       checkArgument(
           (getDataSourceProviderFn() != null),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
-
-      return input.apply(
+      checkArgument(
+          getAutoSharding() == null
+              || (getAutoSharding() && input.isBounded() != IsBounded.UNBOUNDED),
+          "Autosharding is only supported for streaming pipelines.");
+      ;
+
+      PCollection<Iterable<T>> iterables =
+          JdbcIO.batchElements(input, getAutoSharding(), DEFAULT_BATCH_SIZE);
+      return iterables.apply(
           ParDo.of(
               new WriteFn<T, V>(
                   WriteFnSpec.builder()
@@ -1573,6 +1646,8 @@ public PCollection<V> expand(PCollection<T> input) {
   @AutoValue
   public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCollection<Void>> {
+    abstract @Nullable Boolean getAutoSharding();
+
     abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();
 
     abstract @Nullable ValueProvider<String> getStatement();
@@ -1591,6 +1666,8 @@ public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCo
     @AutoValue.Builder
     abstract static class Builder<T> {
+      abstract Builder<T> setAutoSharding(Boolean autoSharding);
+
       abstract Builder<T> setDataSourceProviderFn(
           SerializableFunction<Void, DataSource> dataSourceProviderFn);
 
@@ -1609,6 +1686,11 @@ abstract Builder<T> setDataSourceProviderFn(
       abstract WriteVoid<T> build();
     }
 
+    /** If true, enables using a dynamically determined number of shards to write. */
+    public WriteVoid<T> withAutoSharding() {
+      return toBuilder().setAutoSharding(true).build();
+    }
+
     public WriteVoid<T> withDataSourceConfiguration(DataSourceConfiguration config) {
       return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
     }
@@ -1708,7 +1790,11 @@ public PCollection<Void> expand(PCollection<T> input) {
         checkArgument(
             spec.getPreparedStatementSetter() != null,
             "withPreparedStatementSetter() is required");
       }
-      return input
+
+      PCollection<Iterable<T>> iterables =
+          JdbcIO.batchElements(input, getAutoSharding(), getBatchSize());
+
+      return iterables
           .apply(
               ParDo.of(
                   new WriteFn<T, Void>(
@@ -1955,7 +2041,7 @@ public void populateDisplayData(DisplayData.Builder builder) {
    * @param <T>
    * @param <V>
    */
-  static class WriteFn<T, V> extends DoFn<T, V> {
+  static class WriteFn<T, V> extends DoFn<Iterable<T>, V> {
 
     @AutoValue
     abstract static class WriteFnSpec implements Serializable, HasDisplayData {
@@ -2045,7 +2131,6 @@ abstract static class Builder {
     private Connection connection;
     private PreparedStatement preparedStatement;
     private static FluentBackoff retryBackOff;
-    private final List<T> records = new ArrayList<>();
 
     public WriteFn(WriteFnSpec spec) {
       this.spec = spec;
@@ -2085,17 +2170,12 @@ private Connection getConnection() throws SQLException {
 
     @ProcessElement
     public void processElement(ProcessContext context) throws Exception {
-      T record = context.element();
-      records.add(record);
-      if (records.size() >= spec.getBatchSize()) {
-        executeBatch(context);
-      }
+      executeBatch(context, context.element());
     }
 
     @FinishBundle
     public void finishBundle() throws Exception {
       // We pass a null context because we only execute a final batch for WriteVoid cases.
-      executeBatch(null);
       cleanUpStatementAndConnection();
     }
 
@@ -2124,11 +2204,8 @@ private void cleanUpStatementAndConnection() throws Exception {
       }
     }
 
-    private void executeBatch(ProcessContext context)
+    private void executeBatch(ProcessContext context, Iterable<T> records)
         throws SQLException, IOException, InterruptedException {
-      if (records.isEmpty()) {
-        return;
-      }
       Long startTimeNs = System.nanoTime();
       Sleeper sleeper = Sleeper.DEFAULT;
       BackOff backoff = retryBackOff.backoff();
@@ -2137,8 +2214,10 @@ private void executeBatch(ProcessContext context)
           getConnection().prepareStatement(spec.getStatement().get())) {
         try {
           // add each record in the statement batch
+          int recordsInBatch = 0;
           for (T record : records) {
             processRecord(record, preparedStatement, context);
+            recordsInBatch += 1;
           }
           if (!spec.getReturnResults()) {
             // execute the batch
@@ -2146,7 +2225,7 @@ private void executeBatch(ProcessContext context)
             // commit the changes
             getConnection().commit();
           }
-          RECORDS_PER_BATCH.update(records.size());
+          RECORDS_PER_BATCH.update(recordsInBatch);
           MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs));
           break;
         } catch (SQLException exception) {
@@ -2164,7 +2243,6 @@ private void executeBatch(ProcessContext context)
           }
         }
       }
-      records.clear();
     }
 
     private void processRecord(T record, PreparedStatement preparedStatement, ProcessContext c) {
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
index 8cebbbd56b09..59bc7641acfb 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
@@ -32,6 +32,9 @@ import java.util.UUID;
 import java.util.function.Function;
 import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.common.DatabaseTestHelper;
 import org.apache.beam.sdk.io.common.HashingFn;
@@ -39,6 +42,7 @@ import org.apache.beam.sdk.io.common.TestRow;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.testutils.NamedTestResult;
 import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
 import org.apache.beam.sdk.testutils.metrics.MetricsReader;
@@ -51,6 +55,7 @@ import org.apache.beam.sdk.transforms.Top;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Instant;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
@@ -254,6 +259,40 @@ private PipelineResult runRead() {
     return pipelineRead.run();
   }
 
+  @Test
+  public void testWriteWithAutosharding() throws Exception {
+    String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    DatabaseTestHelper.createTable(dataSource, firstTableName);
+    try {
+      List<KV<Integer, String>> data = getTestDataToWrite(EXPECTED_ROW_COUNT);
+      TestStream.Builder<KV<Integer, String>> ts =
+          TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
+              .advanceWatermarkTo(Instant.now());
+      for (KV<Integer, String> elm : data) {
+        ts = ts.addElements(elm);
+      }
+
+      PCollection<KV<Integer, String>> dataCollection =
+          pipelineWrite.apply(ts.advanceWatermarkToInfinity());
+      dataCollection.apply(
+          JdbcIO.<KV<Integer, String>>write()
+              .withDataSourceProviderFn(voidInput -> dataSource)
+              .withStatement(String.format("insert into %s values(?, ?) returning *", tableName))
+              .withAutoSharding()
+              .withPreparedStatementSetter(
+                  (element, statement) -> {
+                    statement.setInt(1, element.getKey());
+                    statement.setString(2, element.getValue());
+                  }));
+
+      pipelineWrite.run().waitUntilFinish();
+
+      runRead();
+    } finally {
+      DatabaseTestHelper.deleteTable(dataSource, firstTableName);
+    }
+  }
+
   @Test
   public void testWriteWithWriteResults() throws Exception {
     String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index 67cd1dbf0aaf..536026a6ae00 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -75,6 +75,7 @@ import org.apache.beam.sdk.testing.ExpectedLogs;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -89,6 +90,7 @@ import org.hamcrest.TypeSafeMatcher;
 import org.joda.time.DateTime;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.joda.time.LocalDate;
 import org.joda.time.chrono.ISOChronology;
 import org.junit.BeforeClass;
@@ -528,6 +530,31 @@ public void testWrite() throws Exception {
     }
   }
 
+  @Test
+  public void testWriteWithAutosharding() throws Exception {
+    String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    DatabaseTestHelper.createTable(DATA_SOURCE, tableName);
+    TestStream.Builder<KV<Integer, String>> ts =
+        TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
+            .advanceWatermarkTo(Instant.now());
+
+    try {
+      List<KV<Integer, String>> data = getDataToWrite(EXPECTED_ROW_COUNT);
+      for (KV<Integer, String> elm : data) {
+        ts = ts.addElements(elm);
+      }
+      pipeline
+          .apply(ts.advanceWatermarkToInfinity())
+          .apply(getJdbcWrite(tableName).withAutoSharding());
+
+      pipeline.run().waitUntilFinish();
+
+      assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT);
+    } finally {
+      DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
+    }
+  }
+
   @Test
   public void testWriteWithWriteResults() throws Exception {
     String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
@@ -548,6 +575,9 @@ public void testWriteWithWriteResults() throws Exception {
                 }));
     resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER);
 
+    PAssert.thatSingleton(resultSetCollection.apply(Count.globally()))
+        .isEqualTo((long) EXPECTED_ROW_COUNT);
+
     List<JdbcTestHelper.TestDto> expectedResult = new ArrayList<>();
     for (int i = 0; i < EXPECTED_ROW_COUNT; i++) {
       expectedResult.add(new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT));
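
Reviewer note (not part of the diff): a minimal usage sketch of the new option. Only withAutoSharding() comes from this change; the driver class, connection URL, table name, statement, and KV<Integer, String> element type below are placeholder assumptions for illustration. Autosharding is accepted only for unbounded inputs, matching the checkArgument added in WriteWithResults#expand; bounded inputs keep the existing bundle-local batching path in batchElements().

import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

class AutoShardingUsageSketch {
  // Writes an unbounded stream of (id, name) pairs, letting the runner choose the shard count.
  static void writeStreaming(PCollection<KV<Integer, String>> rows) {
    rows.apply(
        JdbcIO.<KV<Integer, String>>write()
            .withDataSourceConfiguration(
                JdbcIO.DataSourceConfiguration.create(
                    "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/example_db"))
            .withStatement("insert into example_table values(?, ?)")
            // New in this change: defer the number of writer shards to the runner.
            // Valid only when the input PCollection is unbounded (streaming).
            .withAutoSharding()
            .withPreparedStatementSetter(
                (element, statement) -> {
                  statement.setInt(1, element.getKey());
                  statement.setString(2, element.getValue());
                }));
  }
}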