diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index 69f438055da9..07f227ad250a 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -46,6 +46,7 @@ import java.util.Arrays;
 import java.util.Comparator;
 import java.util.List;
+import java.util.OptionalInt;
 import java.util.concurrent.TimeUnit;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
@@ -64,6 +65,7 @@ import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Wait;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -73,6 +75,7 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
@@ -352,7 +355,6 @@ public static Write write() {
         .setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES)
         .setMaxNumMutations(DEFAULT_MAX_NUM_MUTATIONS)
         .setMaxNumRows(DEFAULT_MAX_NUM_ROWS)
-        .setGroupingFactor(DEFAULT_GROUPING_FACTOR)
         .setFailureMode(FailureMode.FAIL_FAST)
         .build();
   }
@@ -783,7 +785,7 @@ public abstract static class Write extends PTransform<PCollection<Mutation>, Spa
     @Nullable
     abstract PCollection getSchemaReadySignal();
 
-    abstract int getGroupingFactor();
+    abstract OptionalInt getGroupingFactor();
 
     abstract Builder toBuilder();
@@ -967,8 +969,14 @@ private void populateDisplayDataWithParamaters(DisplayData.Builder builder) {
       builder.add(
           DisplayData.item("maxNumRows", getMaxNumRows())
               .withLabel("Max number of rows in each batch"));
+      // Grouping factor default value depends on whether it is a batch or streaming pipeline.
+      // This function is not aware of that state, so use 'DEFAULT' if unset.
       builder.add(
-          DisplayData.item("groupingFactor", getGroupingFactor())
+          DisplayData.item(
+              "groupingFactor",
+              (getGroupingFactor().isPresent()
+                  ? Integer.toString(getGroupingFactor().getAsInt())
+                  : "DEFAULT"))
               .withLabel("Number of batches to sort over"));
     }
   }
@@ -1033,7 +1041,11 @@ public SpannerWriteResult expand(PCollection<MutationGroup> input) {
       // Filter out mutation groups too big to be batched.
       PCollectionTuple filteredMutations =
           input
-              .apply("To Global Window", Window.into(new GlobalWindows()))
+              .apply(
+                  "RewindowIntoGlobal",
+                  Window.into(new GlobalWindows())
+                      .triggering(DefaultTrigger.of())
+                      .discardingFiredPanes())
              .apply(
                  "Filter Unbatchable Mutations",
                  ParDo.of(
@@ -1059,7 +1071,12 @@ public SpannerWriteResult expand(PCollection<MutationGroup> input) {
                           spec.getBatchSizeBytes(),
                           spec.getMaxNumMutations(),
                           spec.getMaxNumRows(),
-                          spec.getGroupingFactor(),
+                          // Do not group on streaming unless explicitly set.
+                          spec.getGroupingFactor()
+                              .orElse(
+                                  input.isBounded() == IsBounded.BOUNDED
+                                      ? DEFAULT_GROUPING_FACTOR
+                                      : 1),
                           schemaView))
                   .withSideInputs(schemaView))
           .apply(
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
index aaca469105c4..ffcbaefff9a5 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -300,6 +300,54 @@ public void streamingWrites() throws Exception {
     verifyBatches(batch(m(1L), m(2L)), batch(m(3L), m(4L)), batch(m(5L), m(6L)));
   }
 
+  @Test
+  public void streamingWritesWithGrouping() throws Exception {
+
+    // verify that grouping/sorting occurs when set.
+    TestStream<Mutation> testStream =
+        TestStream.create(SerializableCoder.of(Mutation.class))
+            .addElements(m(1L), m(5L), m(2L), m(4L), m(3L), m(6L))
+            .advanceWatermarkToInfinity();
+    pipeline
+        .apply(testStream)
+        .apply(
+            SpannerIO.write()
+                .withProjectId("test-project")
+                .withInstanceId("test-instance")
+                .withDatabaseId("test-database")
+                .withServiceFactory(serviceFactory)
+                .withGroupingFactor(40)
+                .withMaxNumRows(2));
+    pipeline.run();
+
+    // Output should be batches of sorted mutations.
+    verifyBatches(batch(m(1L), m(2L)), batch(m(3L), m(4L)), batch(m(5L), m(6L)));
+  }
+
+  @Test
+  public void streamingWritesNoGrouping() throws Exception {
+
+    // verify that grouping/sorting does not occur - batches should be created in received order.
+    TestStream<Mutation> testStream =
+        TestStream.create(SerializableCoder.of(Mutation.class))
+            .addElements(m(1L), m(5L), m(2L), m(4L), m(3L), m(6L))
+            .advanceWatermarkToInfinity();
+
+    // verify that grouping/sorting does not occur when not set.
+    pipeline
+        .apply(testStream)
+        .apply(
+            SpannerIO.write()
+                .withProjectId("test-project")
+                .withInstanceId("test-instance")
+                .withDatabaseId("test-database")
+                .withServiceFactory(serviceFactory)
+                .withMaxNumRows(2));
+    pipeline.run();
+
+    verifyBatches(batch(m(1L), m(5L)), batch(m(2L), m(4L)), batch(m(3L), m(6L)));
+  }
+
   @Test
   public void reportFailures() throws Exception {
 
@@ -608,7 +656,18 @@ public void displayDataWrite() throws Exception {
     assertThat(data, hasDisplayItem("batchSizeBytes", 123));
     assertThat(data, hasDisplayItem("maxNumMutations", 456));
     assertThat(data, hasDisplayItem("maxNumRows", 789));
-    assertThat(data, hasDisplayItem("groupingFactor", 100));
+    assertThat(data, hasDisplayItem("groupingFactor", "100"));
+
+    // check for default grouping value
+    write =
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database");
+
+    data = DisplayData.from(write);
+    assertThat(data.items(), hasSize(7));
+    assertThat(data, hasDisplayItem("groupingFactor", "DEFAULT"));
   }
 
   @Test
@@ -632,7 +691,19 @@ public void displayDataWriteGrouped() throws Exception {
     assertThat(data, hasDisplayItem("batchSizeBytes", 123));
     assertThat(data, hasDisplayItem("maxNumMutations", 456));
     assertThat(data, hasDisplayItem("maxNumRows", 789));
-    assertThat(data, hasDisplayItem("groupingFactor", 100));
+    assertThat(data, hasDisplayItem("groupingFactor", "100"));
+
+    // check for default grouping value
+    writeGrouped =
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .grouped();
+
+    data = DisplayData.from(writeGrouped);
+    assertThat(data.items(), hasSize(7));
+    assertThat(data, hasDisplayItem("groupingFactor", "DEFAULT"));
   }
 
   @Test
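
Reviewer note: the sketch below is not part of the patch; it only illustrates the behavioural change. After this change, a write over an unbounded (streaming) input defaults to a grouping factor of 1 (batches are formed in arrival order, with no sorting), while bounded inputs keep DEFAULT_GROUPING_FACTOR; calling withGroupingFactor() explicitly opts a stream back into sorting. The table name, resource IDs, and the single mutation are placeholders, and TestStream needs a runner that supports it (e.g. the direct runner).

import com.google.cloud.spanner.Mutation;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.gcp.spanner.SpannerIO;
import org.apache.beam.sdk.testing.TestStream;

public class StreamingSpannerWriteSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Unbounded source; one placeholder row stands in for a real stream.
    TestStream<Mutation> source =
        TestStream.create(SerializableCoder.of(Mutation.class))
            .addElements(
                Mutation.newInsertOrUpdateBuilder("test").set("key").to(1L).build())
            .advanceWatermarkToInfinity();

    pipeline
        .apply(source)
        .apply(
            SpannerIO.write()
                .withProjectId("test-project") // placeholder resource names
                .withInstanceId("test-instance")
                .withDatabaseId("test-database")
                // Left unset, this streaming write would use a grouping factor
                // of 1; setting it explicitly re-enables sorting over 40 batches.
                .withGroupingFactor(40));

    pipeline.run().waitUntilFinish();
  }
}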