[BEAM-13184] Autosharding for JdbcIO.write* transforms #15863
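For context, a minimal usage sketch of the new option. It mirrors the test added below; the input collection, JDBC driver, URL, and table name are illustrative placeholders rather than anything defined in this PR.

```java
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

class AutoShardingUsageSketch {
  // Applies the autosharded write to an existing (typically streaming) input.
  // Driver, URL, table, and the `rows` collection are placeholders.
  static void writeWithAutoSharding(PCollection<KV<Integer, String>> rows) {
    rows.apply(
        JdbcIO.<KV<Integer, String>>write()
            .withDataSourceConfiguration(
                JdbcIO.DataSourceConfiguration.create(
                    "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/example_db"))
            .withStatement("insert into example_table values(?, ?)")
            .withAutoSharding() // new in this PR: shard count is determined dynamically
            .withPreparedStatementSetter(
                (element, statement) -> {
                  statement.setInt(1, element.getKey());
                  statement.setString(2, element.getValue());
                }));
  }
}
```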
Changes from all commits: 89d1650, fe8f980, 983b960, ab260c5, 1edce47, b4dae93
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,22 +66,27 @@ | |
| import org.apache.beam.sdk.transforms.DoFn; | ||
| import org.apache.beam.sdk.transforms.Filter; | ||
| import org.apache.beam.sdk.transforms.GroupByKey; | ||
| import org.apache.beam.sdk.transforms.GroupIntoBatches; | ||
| import org.apache.beam.sdk.transforms.PTransform; | ||
| import org.apache.beam.sdk.transforms.ParDo; | ||
| import org.apache.beam.sdk.transforms.Reshuffle; | ||
| import org.apache.beam.sdk.transforms.SerializableFunction; | ||
| import org.apache.beam.sdk.transforms.SerializableFunctions; | ||
| import org.apache.beam.sdk.transforms.Values; | ||
| import org.apache.beam.sdk.transforms.View; | ||
| import org.apache.beam.sdk.transforms.Wait; | ||
| import org.apache.beam.sdk.transforms.WithKeys; | ||
| import org.apache.beam.sdk.transforms.display.DisplayData; | ||
| import org.apache.beam.sdk.transforms.display.HasDisplayData; | ||
| import org.apache.beam.sdk.transforms.windowing.GlobalWindow; | ||
| import org.apache.beam.sdk.util.BackOff; | ||
| import org.apache.beam.sdk.util.BackOffUtils; | ||
| import org.apache.beam.sdk.util.FluentBackoff; | ||
| import org.apache.beam.sdk.util.Sleeper; | ||
| import org.apache.beam.sdk.values.KV; | ||
| import org.apache.beam.sdk.values.PBegin; | ||
| import org.apache.beam.sdk.values.PCollection; | ||
| import org.apache.beam.sdk.values.PCollection.IsBounded; | ||
| import org.apache.beam.sdk.values.PCollectionView; | ||
| import org.apache.beam.sdk.values.PDone; | ||
| import org.apache.beam.sdk.values.Row; | ||
|
|
@@ -96,6 +101,7 @@ | |
| import org.apache.commons.pool2.impl.GenericObjectPoolConfig; | ||
| import org.checkerframework.checker.nullness.qual.Nullable; | ||
| import org.joda.time.Duration; | ||
| import org.joda.time.Instant; | ||
| import org.slf4j.Logger; | ||
| import org.slf4j.LoggerFactory; | ||
|
|
||
|
|
@@ -1318,6 +1324,11 @@ public static class Write<T> extends PTransform<PCollection<T>, PDone> { | |
| this.inner = inner; | ||
| } | ||
|
|
||
| /** See {@link WriteVoid#withAutoSharding()}. */ | ||
| public Write<T> withAutoSharding() { | ||
| return new Write<>(inner.withAutoSharding()); | ||
| } | ||
|
|
||
| /** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */ | ||
| public Write<T> withDataSourceConfiguration(DataSourceConfiguration config) { | ||
| return new Write<>(inner.withDataSourceConfiguration(config)); | ||
|
|
@@ -1393,6 +1404,7 @@ public <V extends JdbcWriteResult> WriteWithResults<T, V> withWriteResults( | |
| .setPreparedStatementSetter(inner.getPreparedStatementSetter()) | ||
| .setStatement(inner.getStatement()) | ||
| .setTable(inner.getTable()) | ||
| .setAutoSharding(inner.getAutoSharding()) | ||
| .build(); | ||
| } | ||
|
|
||
|
|
@@ -1408,6 +1420,50 @@ public PDone expand(PCollection<T> input) { | |
| } | ||
| } | ||
|
|
||
| /* The maximum number of elements that will be included in a batch. */ | ||
| private static final Integer MAX_BUNDLE_SIZE = 5000; | ||
|
|
||
| static <T> PCollection<Iterable<T>> batchElements( | ||
| PCollection<T> input, Boolean withAutoSharding) { | ||
| PCollection<Iterable<T>> iterables; | ||
| if (input.isBounded() == IsBounded.UNBOUNDED && withAutoSharding != null && withAutoSharding) { | ||
| iterables = | ||
| input | ||
| .apply(WithKeys.<String, T>of("")) | ||
| .apply( | ||
| GroupIntoBatches.<String, T>ofSize(DEFAULT_BATCH_SIZE) | ||
| .withMaxBufferingDuration(Duration.millis(200)) | ||
| .withShardedKey()) | ||
| .apply(Values.create()); | ||
| } else { | ||
| iterables = | ||
| input.apply( | ||
| ParDo.of( | ||
| new DoFn<T, Iterable<T>>() { | ||
| List<T> outputList; | ||
|
|
||
| @ProcessElement | ||
| public void process(ProcessContext c) { | ||
| if (outputList == null) { | ||
| outputList = new ArrayList<>(); | ||
| } | ||
| outputList.add(c.element()); | ||
| if (outputList.size() > MAX_BUNDLE_SIZE) { | ||
| c.output(outputList); | ||
| outputList = null; | ||
| } | ||
| } | ||
|
|
||
| @FinishBundle | ||
| public void finish(FinishBundleContext c) { | ||
| c.output(outputList, Instant.now(), GlobalWindow.INSTANCE); | ||
| outputList = null; | ||
| } | ||
| })); | ||
| } | ||
| return iterables; | ||
| } | ||
|
|
||
| /** Interface implemented by functions that sets prepared statement data. */ | ||
| @FunctionalInterface | ||
| interface PreparedStatementSetCaller extends Serializable { | ||
|
|
@@ -1430,6 +1486,8 @@ void set( | |
| @AutoValue | ||
| public abstract static class WriteWithResults<T, V extends JdbcWriteResult> | ||
| extends PTransform<PCollection<T>, PCollection<V>> { | ||
| abstract @Nullable Boolean getAutoSharding(); | ||
|
|
||
| abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn(); | ||
|
|
||
| abstract @Nullable ValueProvider<String> getStatement(); | ||
|
|
@@ -1451,6 +1509,8 @@ abstract static class Builder<T, V extends JdbcWriteResult> { | |
| abstract Builder<T, V> setDataSourceProviderFn( | ||
| SerializableFunction<Void, DataSource> dataSourceProviderFn); | ||
|
|
||
| abstract Builder<T, V> setAutoSharding(Boolean autoSharding); | ||
|
|
||
| abstract Builder<T, V> setStatement(ValueProvider<String> statement); | ||
|
|
||
| abstract Builder<T, V> setPreparedStatementSetter(PreparedStatementSetter<T> setter); | ||
|
|
@@ -1487,6 +1547,11 @@ public WriteWithResults<T, V> withPreparedStatementSetter(PreparedStatementSette | |
| return toBuilder().setPreparedStatementSetter(setter).build(); | ||
| } | ||
|
|
||
| /** If true, enables using a dynamically determined number of shards to write. */ | ||
| public WriteWithResults<T, V> withAutoSharding() { | ||
| return toBuilder().setAutoSharding(true).build(); | ||
| } | ||
|
|
||
| /** | ||
| * When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it | ||
| * will retry the statements. If {@link RetryStrategy#apply(SQLException)} returns {@code true}, | ||
|
|
@@ -1549,8 +1614,14 @@ public PCollection<V> expand(PCollection<T> input) { | |
| checkArgument( | ||
| (getDataSourceProviderFn() != null), | ||
| "withDataSourceConfiguration() or withDataSourceProviderFn() is required"); | ||
| checkArgument( | ||
| getAutoSharding() == null | ||
| || (getAutoSharding() && input.isBounded() != IsBounded.UNBOUNDED), | ||
| "Autosharding is only supported for streaming pipelines."); | ||
| ; | ||
|
|
||
| return input.apply( | ||
| PCollection<Iterable<T>> iterables = JdbcIO.<T>batchElements(input, getAutoSharding()); | ||
| return iterables.apply( | ||
| ParDo.of( | ||
| new WriteFn<T, V>( | ||
| WriteFnSpec.builder() | ||
|
|
@@ -1573,6 +1644,8 @@ public PCollection<V> expand(PCollection<T> input) { | |
| @AutoValue | ||
| public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCollection<Void>> { | ||
|
|
||
| abstract @Nullable Boolean getAutoSharding(); | ||
|
|
||
| abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn(); | ||
|
|
||
| abstract @Nullable ValueProvider<String> getStatement(); | ||
|
|
@@ -1591,6 +1664,8 @@ public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCo | |
|
|
||
| @AutoValue.Builder | ||
| abstract static class Builder<T> { | ||
| abstract Builder<T> setAutoSharding(Boolean autoSharding); | ||
|
|
||
| abstract Builder<T> setDataSourceProviderFn( | ||
| SerializableFunction<Void, DataSource> dataSourceProviderFn); | ||
|
|
||
|
|
@@ -1609,6 +1684,11 @@ abstract Builder<T> setDataSourceProviderFn( | |
| abstract WriteVoid<T> build(); | ||
| } | ||
|
|
||
| /** If true, enables using a dynamically determined number of shards to write. */ | ||
| public WriteVoid<T> withAutoSharding() { | ||
| return toBuilder().setAutoSharding(true).build(); | ||
| } | ||
|
|
||
| public WriteVoid<T> withDataSourceConfiguration(DataSourceConfiguration config) { | ||
| return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config)); | ||
| } | ||
|
|
@@ -1708,7 +1788,10 @@ public PCollection<Void> expand(PCollection<T> input) { | |
| checkArgument( | ||
| spec.getPreparedStatementSetter() != null, "withPreparedStatementSetter() is required"); | ||
| } | ||
| return input | ||
|
|
||
| PCollection<Iterable<T>> iterables = JdbcIO.<T>batchElements(input, getAutoSharding()); | ||
|
|
||
| return iterables | ||
| .apply( | ||
| ParDo.of( | ||
| new WriteFn<T, Void>( | ||
|
|
@@ -1955,7 +2038,7 @@ public void populateDisplayData(DisplayData.Builder builder) { | |
| * @param <T> | ||
| * @param <V> | ||
| */ | ||
| static class WriteFn<T, V> extends DoFn<T, V> { | ||
| static class WriteFn<T, V> extends DoFn<Iterable<T>, V> { | ||
|
|
||
| @AutoValue | ||
| abstract static class WriteFnSpec<T, V> implements Serializable, HasDisplayData { | ||
|
|
@@ -2045,7 +2128,6 @@ abstract static class Builder<T, V> { | |
| private Connection connection; | ||
| private PreparedStatement preparedStatement; | ||
| private static FluentBackoff retryBackOff; | ||
| private final List<T> records = new ArrayList<>(); | ||
|
|
||
| public WriteFn(WriteFnSpec<T, V> spec) { | ||
| this.spec = spec; | ||
|
|
@@ -2085,17 +2167,12 @@ private Connection getConnection() throws SQLException { | |
|
|
||
| @ProcessElement | ||
| public void processElement(ProcessContext context) throws Exception { | ||
| T record = context.element(); | ||
| records.add(record); | ||
| if (records.size() >= spec.getBatchSize()) { | ||
|
Reviewer: @pabloem Hi Pablo. Not sure if this is the right place to ask, but does this removal drop the configurability of the batch size? In batchElements I see it uses the MAX_BUNDLE_SIZE constant, so it seems batchSize may no longer be used, though I may be missing or misunderstanding something.

pabloem (Member, Author): Good catch. I'll fix this before the release (2.36.0), so it'll only be available in the next one.

Reviewer: Great. Thank you!
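To illustrate the point, one possible shape of such a fix would be to thread the configured batch size through batchElements instead of the fixed constant. This is only a sketch under an assumed signature, not the code that landed in this PR.

```java
// Sketch only: pass the user-configured batch size into batchElements rather than
// relying on the fixed MAX_BUNDLE_SIZE constant. Signature and wiring are assumed.
static <T> PCollection<Iterable<T>> batchElements(
    PCollection<T> input, @Nullable Boolean withAutoSharding, long batchSize) {
  if (input.isBounded() == IsBounded.UNBOUNDED && Boolean.TRUE.equals(withAutoSharding)) {
    return input
        .apply(WithKeys.<String, T>of(""))
        .apply(
            GroupIntoBatches.<String, T>ofSize(batchSize) // honors withBatchSize(...)
                .withMaxBufferingDuration(Duration.millis(200))
                .withShardedKey())
        .apply(Values.create());
  }
  return input.apply(
      ParDo.of(
          new DoFn<T, Iterable<T>>() {
            private List<T> outputList;

            @ProcessElement
            public void process(ProcessContext c) {
              if (outputList == null) {
                outputList = new ArrayList<>();
              }
              outputList.add(c.element());
              if (outputList.size() >= batchSize) { // honors withBatchSize(...) here too
                c.output(outputList);
                outputList = null;
              }
            }

            @FinishBundle
            public void finish(FinishBundleContext c) {
              if (outputList != null && !outputList.isEmpty()) {
                c.output(outputList, Instant.now(), GlobalWindow.INSTANCE);
              }
              outputList = null;
            }
          }));
}
```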
||
| executeBatch(context); | ||
| } | ||
| executeBatch(context, context.element()); | ||
| } | ||
|
|
||
| @FinishBundle | ||
| public void finishBundle() throws Exception { | ||
| // We pass a null context because we only execute a final batch for WriteVoid cases. | ||
| executeBatch(null); | ||
| cleanUpStatementAndConnection(); | ||
| } | ||
|
|
||
|
|
@@ -2124,11 +2201,8 @@ private void cleanUpStatementAndConnection() throws Exception { | |
| } | ||
| } | ||
|
|
||
| private void executeBatch(ProcessContext context) | ||
| private void executeBatch(ProcessContext context, Iterable<T> records) | ||
| throws SQLException, IOException, InterruptedException { | ||
| if (records.isEmpty()) { | ||
| return; | ||
| } | ||
| Long startTimeNs = System.nanoTime(); | ||
| Sleeper sleeper = Sleeper.DEFAULT; | ||
| BackOff backoff = retryBackOff.backoff(); | ||
|
|
@@ -2137,16 +2211,18 @@ private void executeBatch(ProcessContext context) | |
| getConnection().prepareStatement(spec.getStatement().get())) { | ||
| try { | ||
| // add each record in the statement batch | ||
| int recordsInBatch = 0; | ||
| for (T record : records) { | ||
| processRecord(record, preparedStatement, context); | ||
| recordsInBatch += 1; | ||
| } | ||
| if (!spec.getReturnResults()) { | ||
| // execute the batch | ||
| preparedStatement.executeBatch(); | ||
| // commit the changes | ||
| getConnection().commit(); | ||
| } | ||
| RECORDS_PER_BATCH.update(records.size()); | ||
| RECORDS_PER_BATCH.update(recordsInBatch); | ||
| MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)); | ||
| break; | ||
| } catch (SQLException exception) { | ||
|
|
@@ -2164,7 +2240,6 @@ private void executeBatch(ProcessContext context) | |
| } | ||
| } | ||
| } | ||
| records.clear(); | ||
| } | ||
|
|
||
| private void processRecord(T record, PreparedStatement preparedStatement, ProcessContext c) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,13 +32,17 @@ | |
| import java.util.UUID; | ||
| import java.util.function.Function; | ||
| import org.apache.beam.sdk.PipelineResult; | ||
| import org.apache.beam.sdk.coders.KvCoder; | ||
| import org.apache.beam.sdk.coders.StringUtf8Coder; | ||
| import org.apache.beam.sdk.coders.VarIntCoder; | ||
| import org.apache.beam.sdk.io.GenerateSequence; | ||
| import org.apache.beam.sdk.io.common.DatabaseTestHelper; | ||
| import org.apache.beam.sdk.io.common.HashingFn; | ||
| import org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions; | ||
| import org.apache.beam.sdk.io.common.TestRow; | ||
| import org.apache.beam.sdk.testing.PAssert; | ||
| import org.apache.beam.sdk.testing.TestPipeline; | ||
| import org.apache.beam.sdk.testing.TestStream; | ||
| import org.apache.beam.sdk.testutils.NamedTestResult; | ||
| import org.apache.beam.sdk.testutils.metrics.IOITMetrics; | ||
| import org.apache.beam.sdk.testutils.metrics.MetricsReader; | ||
|
|
@@ -51,6 +55,7 @@ | |
| import org.apache.beam.sdk.transforms.Top; | ||
| import org.apache.beam.sdk.values.KV; | ||
| import org.apache.beam.sdk.values.PCollection; | ||
| import org.joda.time.Instant; | ||
| import org.junit.AfterClass; | ||
| import org.junit.BeforeClass; | ||
| import org.junit.Rule; | ||
|
|
@@ -258,6 +263,40 @@ private PipelineResult runRead() { | |
| return pipelineRead.run(); | ||
| } | ||
|
|
||
| @Test | ||
| public void testWriteWithAutosharding() throws Exception { | ||
| String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); | ||
| DatabaseTestHelper.createTable(dataSource, firstTableName); | ||
| try { | ||
| List<KV<Integer, String>> data = getTestDataToWrite(EXPECTED_ROW_COUNT); | ||
| TestStream.Builder<KV<Integer, String>> ts = | ||
| TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of())) | ||
| .advanceWatermarkTo(Instant.now()); | ||
| for (KV<Integer, String> elm : data) { | ||
| ts.addElements(elm); | ||
| } | ||
|
|
||
| PCollection<KV<Integer, String>> dataCollection = | ||
| pipelineWrite.apply(ts.advanceWatermarkToInfinity()); | ||
| dataCollection.apply( | ||
| JdbcIO.<KV<Integer, String>>write() | ||
| .withDataSourceProviderFn(voidInput -> dataSource) | ||
| .withStatement(String.format("insert into %s values(?, ?) returning *", tableName)) | ||
| .withAutoSharding() | ||
|
Reviewer (Contributor): Are we actually able to test that auto-sharding worked somehow?

pabloem (Member, Author): Hm, not really the way things are now. Perhaps analyze the graph and check that the GroupIntoBatches (GIB) transform is in it, but is that worth it?
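The graph-inspection idea could look roughly like the sketch below: traverse the pipeline with a PipelineVisitor and check whether a GroupIntoBatches node was inserted. The helper is hypothetical and is not part of this PR.

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.GroupIntoBatches;

final class AutoShardingGraphCheck {
  // Returns true if any composite node in the pipeline is a GroupIntoBatches transform
  // (the step that withAutoSharding() inserts on the write path).
  static boolean usesGroupIntoBatches(Pipeline pipeline) {
    final boolean[] found = {false};
    pipeline.traverseTopologically(
        new Pipeline.PipelineVisitor.Defaults() {
          @Override
          public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
            if (node.getTransform() instanceof GroupIntoBatches
                || node.getTransform() instanceof GroupIntoBatches.WithShardedKey) {
              found[0] = true;
            }
            return CompositeBehavior.ENTER_TRANSFORM;
          }
        });
    return found[0];
  }
}
```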
||
| .withPreparedStatementSetter( | ||
| (element, statement) -> { | ||
| statement.setInt(1, element.getKey()); | ||
| statement.setString(2, element.getValue()); | ||
| })); | ||
|
|
||
| pipelineWrite.run().waitUntilFinish(); | ||
|
|
||
| runRead(); | ||
| } finally { | ||
| DatabaseTestHelper.deleteTable(dataSource, firstTableName); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| public void testWriteWithWriteResults() throws Exception { | ||
| String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); | ||
|
|
||