@@ -66,27 +66,22 @@
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.Wait;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.Row;
@@ -101,7 +96,6 @@
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -1324,11 +1318,6 @@ public static class Write<T> extends PTransform<PCollection<T>, PDone> {
this.inner = inner;
}

/** See {@link WriteVoid#withAutoSharding()}. */
public Write<T> withAutoSharding() {
return new Write<>(inner.withAutoSharding());
}

/** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */
public Write<T> withDataSourceConfiguration(DataSourceConfiguration config) {
return new Write<>(inner.withDataSourceConfiguration(config));
@@ -1404,7 +1393,6 @@ public <V extends JdbcWriteResult> WriteWithResults<T, V> withWriteResults(
.setPreparedStatementSetter(inner.getPreparedStatementSetter())
.setStatement(inner.getStatement())
.setTable(inner.getTable())
.setAutoSharding(inner.getAutoSharding())
.build();
}

@@ -1420,50 +1408,6 @@ public PDone expand(PCollection<T> input) {
}
}

/* The maximum number of elements that will be included in a batch. */
private static final Integer MAX_BUNDLE_SIZE = 5000;

static <T> PCollection<Iterable<T>> batchElements(
PCollection<T> input, Boolean withAutoSharding) {
PCollection<Iterable<T>> iterables;
if (input.isBounded() == IsBounded.UNBOUNDED && withAutoSharding != null && withAutoSharding) {
iterables =
input
.apply(WithKeys.<String, T>of(""))
.apply(
GroupIntoBatches.<String, T>ofSize(DEFAULT_BATCH_SIZE)
.withMaxBufferingDuration(Duration.millis(200))
.withShardedKey())
.apply(Values.create());
} else {
iterables =
input.apply(
ParDo.of(
new DoFn<T, Iterable<T>>() {
List<T> outputList;

@ProcessElement
public void process(ProcessContext c) {
if (outputList == null) {
outputList = new ArrayList<>();
}
outputList.add(c.element());
if (outputList.size() > MAX_BUNDLE_SIZE) {
c.output(outputList);
outputList = null;
}
}

@FinishBundle
public void finish(FinishBundleContext c) {
c.output(outputList, Instant.now(), GlobalWindow.INSTANCE);
outputList = null;
}
}));
}
return iterables;
}

/** Interface implemented by functions that sets prepared statement data. */
@FunctionalInterface
interface PreparedStatementSetCaller extends Serializable {
@@ -1486,8 +1430,6 @@ void set(
@AutoValue
public abstract static class WriteWithResults<T, V extends JdbcWriteResult>
extends PTransform<PCollection<T>, PCollection<V>> {
abstract @Nullable Boolean getAutoSharding();

abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();

abstract @Nullable ValueProvider<String> getStatement();
@@ -1509,8 +1451,6 @@ abstract static class Builder<T, V extends JdbcWriteResult> {
abstract Builder<T, V> setDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn);

abstract Builder<T, V> setAutoSharding(Boolean autoSharding);

abstract Builder<T, V> setStatement(ValueProvider<String> statement);

abstract Builder<T, V> setPreparedStatementSetter(PreparedStatementSetter<T> setter);
@@ -1547,11 +1487,6 @@ public WriteWithResults<T, V> withPreparedStatementSetter(PreparedStatementSette
return toBuilder().setPreparedStatementSetter(setter).build();
}

/** If true, enables using a dynamically determined number of shards to write. */
public WriteWithResults<T, V> withAutoSharding() {
return toBuilder().setAutoSharding(true).build();
}

/**
* When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it
* will retry the statements. If {@link RetryStrategy#apply(SQLException)} returns {@code true},
@@ -1614,14 +1549,8 @@ public PCollection<V> expand(PCollection<T> input) {
checkArgument(
(getDataSourceProviderFn() != null),
"withDataSourceConfiguration() or withDataSourceProviderFn() is required");
checkArgument(
getAutoSharding() == null
|| (getAutoSharding() && input.isBounded() != IsBounded.UNBOUNDED),
"Autosharding is only supported for streaming pipelines.");
;

PCollection<Iterable<T>> iterables = JdbcIO.<T>batchElements(input, getAutoSharding());
return iterables.apply(
return input.apply(
ParDo.of(
new WriteFn<T, V>(
WriteFnSpec.builder()
@@ -1644,8 +1573,6 @@ public PCollection<V> expand(PCollection<T> input) {
@AutoValue
public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCollection<Void>> {

abstract @Nullable Boolean getAutoSharding();

abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();

abstract @Nullable ValueProvider<String> getStatement();
@@ -1664,8 +1591,6 @@ public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCo

@AutoValue.Builder
abstract static class Builder<T> {
abstract Builder<T> setAutoSharding(Boolean autoSharding);

abstract Builder<T> setDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn);

@@ -1684,11 +1609,6 @@ abstract Builder<T> setDataSourceProviderFn(
abstract WriteVoid<T> build();
}

/** If true, enables using a dynamically determined number of shards to write. */
public WriteVoid<T> withAutoSharding() {
return toBuilder().setAutoSharding(true).build();
}

public WriteVoid<T> withDataSourceConfiguration(DataSourceConfiguration config) {
return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
}
@@ -1788,10 +1708,7 @@ public PCollection<Void> expand(PCollection<T> input) {
checkArgument(
spec.getPreparedStatementSetter() != null, "withPreparedStatementSetter() is required");
}

PCollection<Iterable<T>> iterables = JdbcIO.<T>batchElements(input, getAutoSharding());

return iterables
return input
.apply(
ParDo.of(
new WriteFn<T, Void>(
@@ -2038,7 +1955,7 @@ public void populateDisplayData(DisplayData.Builder builder) {
* @param <T>
* @param <V>
*/
static class WriteFn<T, V> extends DoFn<Iterable<T>, V> {
static class WriteFn<T, V> extends DoFn<T, V> {

@AutoValue
abstract static class WriteFnSpec<T, V> implements Serializable, HasDisplayData {
@@ -2128,6 +2045,7 @@ abstract static class Builder<T, V> {
private Connection connection;
private PreparedStatement preparedStatement;
private static FluentBackoff retryBackOff;
private final List<T> records = new ArrayList<>();

public WriteFn(WriteFnSpec<T, V> spec) {
this.spec = spec;
@@ -2167,12 +2085,17 @@ private Connection getConnection() throws SQLException {

@ProcessElement
public void processElement(ProcessContext context) throws Exception {
executeBatch(context, context.element());
T record = context.element();
records.add(record);
if (records.size() >= spec.getBatchSize()) {
executeBatch(context);
}
}

@FinishBundle
public void finishBundle() throws Exception {
// We pass a null context because we only execute a final batch for WriteVoid cases.
executeBatch(null);
cleanUpStatementAndConnection();
}

@@ -2201,8 +2124,11 @@ private void cleanUpStatementAndConnection() throws Exception {
}
}

private void executeBatch(ProcessContext context, Iterable<T> records)
private void executeBatch(ProcessContext context)
throws SQLException, IOException, InterruptedException {
if (records.isEmpty()) {
return;
}
Long startTimeNs = System.nanoTime();
Sleeper sleeper = Sleeper.DEFAULT;
BackOff backoff = retryBackOff.backoff();
@@ -2211,18 +2137,16 @@ private void executeBatch(ProcessContext context, Iterable<T> records)
getConnection().prepareStatement(spec.getStatement().get())) {
try {
// add each record in the statement batch
int recordsInBatch = 0;
for (T record : records) {
processRecord(record, preparedStatement, context);
recordsInBatch += 1;
}
if (!spec.getReturnResults()) {
// execute the batch
preparedStatement.executeBatch();
// commit the changes
getConnection().commit();
}
RECORDS_PER_BATCH.update(recordsInBatch);
RECORDS_PER_BATCH.update(records.size());
MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs));
break;
} catch (SQLException exception) {
@@ -2240,6 +2164,7 @@ private void executeBatch(ProcessContext context, Iterable<T> records)
}
}
}
records.clear();
}

private void processRecord(T record, PreparedStatement preparedStatement, ProcessContext c) {
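For reference, here is a minimal sketch, not taken from this diff, of the buffer-and-flush pattern the reworked WriteFn above follows: elements are buffered one at a time in @ProcessElement and written once the batch size is reached, with a final flush in @FinishBundle so nothing is dropped at bundle boundaries. The class name and the flush(List<T>) hook are placeholders standing in for the JDBC-specific batch execution shown in the diff.

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.transforms.DoFn;

// Placeholder sketch of the buffer-and-flush DoFn pattern; flush() stands in for
// PreparedStatement#executeBatch plus commit in the real WriteFn.
abstract class BufferingWriteFn<T> extends DoFn<T, Void> {
  private final int batchSize;
  private final List<T> buffer = new ArrayList<>();

  BufferingWriteFn(int batchSize) {
    this.batchSize = batchSize;
  }

  // Subclasses perform the actual write of one accumulated batch.
  abstract void flush(List<T> records) throws Exception;

  @ProcessElement
  public void processElement(@Element T element) throws Exception {
    buffer.add(element);
    if (buffer.size() >= batchSize) {
      // Flush eagerly once the configured batch size is reached.
      flush(buffer);
      buffer.clear();
    }
  }

  @FinishBundle
  public void finishBundle() throws Exception {
    // Flush whatever is left so bundle boundaries never drop records.
    if (!buffer.isEmpty()) {
      flush(buffer);
      buffer.clear();
    }
  }
}

Keeping the buffer inside the DoFn, as this change does, removes the need for the earlier GroupIntoBatches/batchElements pre-grouping step and the Iterable<T> input element type.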
@@ -32,17 +32,13 @@
import java.util.UUID;
import java.util.function.Function;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.common.DatabaseTestHelper;
import org.apache.beam.sdk.io.common.HashingFn;
import org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.testutils.NamedTestResult;
import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
import org.apache.beam.sdk.testutils.metrics.MetricsReader;
@@ -55,7 +51,6 @@
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Instant;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
@@ -259,40 +254,6 @@ private PipelineResult runRead() {
return pipelineRead.run();
}

@Test
public void testWriteWithAutosharding() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
DatabaseTestHelper.createTable(dataSource, firstTableName);
try {
List<KV<Integer, String>> data = getTestDataToWrite(EXPECTED_ROW_COUNT);
TestStream.Builder<KV<Integer, String>> ts =
TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
.advanceWatermarkTo(Instant.now());
for (KV<Integer, String> elm : data) {
ts.addElements(elm);
}

PCollection<KV<Integer, String>> dataCollection =
pipelineWrite.apply(ts.advanceWatermarkToInfinity());
dataCollection.apply(
JdbcIO.<KV<Integer, String>>write()
.withDataSourceProviderFn(voidInput -> dataSource)
.withStatement(String.format("insert into %s values(?, ?) returning *", tableName))
.withAutoSharding()
.withPreparedStatementSetter(
(element, statement) -> {
statement.setInt(1, element.getKey());
statement.setString(2, element.getValue());
}));

pipelineWrite.run().waitUntilFinish();

runRead();
} finally {
DatabaseTestHelper.deleteTable(dataSource, firstTableName);
}
}

@Test
public void testWriteWithWriteResults() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
@@ -75,7 +75,6 @@
import org.apache.beam.sdk.testing.ExpectedLogs;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -90,7 +89,6 @@
import org.hamcrest.TypeSafeMatcher;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.joda.time.LocalDate;
import org.joda.time.chrono.ISOChronology;
import org.junit.BeforeClass;
@@ -530,31 +528,6 @@ public void testWrite() throws Exception {
}
}

@Test
public void testWriteWithAutosharding() throws Exception {
String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
DatabaseTestHelper.createTable(DATA_SOURCE, tableName);
TestStream.Builder<KV<Integer, String>> ts =
TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
.advanceWatermarkTo(Instant.now());

try {
List<KV<Integer, String>> data = getDataToWrite(EXPECTED_ROW_COUNT);
for (KV<Integer, String> elm : data) {
ts = ts.addElements(elm);
}
pipeline
.apply(ts.advanceWatermarkToInfinity())
.apply(getJdbcWrite(tableName).withAutoSharding());

pipeline.run().waitUntilFinish();

assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT);
} finally {
DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
}
}

@Test
public void testWriteWithWriteResults() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
@@ -575,9 +548,6 @@ public void testWriteWithWriteResults() throws Exception {
}));
resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER);

PAssert.thatSingleton(resultSetCollection.apply(Count.globally()))
.isEqualTo((long) EXPECTED_ROW_COUNT);

List<JdbcTestHelper.TestDto> expectedResult = new ArrayList<>();
for (int i = 0; i < EXPECTED_ROW_COUNT; i++) {
expectedResult.add(new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT));
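As an orientation, a hedged usage sketch (not one of the tests above) of writing through JdbcIO after this change: Write no longer exposes withAutoSharding(), and the batch size is controlled with withBatchSize(), which the buffered WriteFn uses as its flush threshold. The driver class, connection URL, table, and element type below are illustrative assumptions only.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.KV;

public class JdbcWriteSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
    p.apply(Create.of(KV.of(1, "one"), KV.of(2, "two")))
        .apply(
            JdbcIO.<KV<Integer, String>>write()
                .withDataSourceConfiguration(
                    // Placeholder driver and URL; substitute a real data source.
                    JdbcIO.DataSourceConfiguration.create(
                        "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/example"))
                .withStatement("insert into example_table values(?, ?)")
                // Records are buffered in WriteFn and flushed once this many accumulate
                // (plus a final flush when the bundle finishes).
                .withBatchSize(500L)
                .withPreparedStatementSetter(
                    (element, statement) -> {
                      statement.setInt(1, element.getKey());
                      statement.setString(2, element.getValue());
                    }));
    p.run().waitUntilFinish();
  }
}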