diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index 94ad154ff7b0..dae7ae674c7c 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -18,8 +18,6 @@
package org.apache.beam.sdk.io.gcp.spanner;
import static org.apache.beam.sdk.io.gcp.spanner.MutationUtils.isPointDelete;
-import static org.apache.beam.sdk.io.gcp.spanner.SpannerIO.WriteGrouped.decode;
-import static org.apache.beam.sdk.io.gcp.spanner.SpannerIO.WriteGrouped.encode;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
@@ -41,13 +39,13 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
@@ -73,7 +71,6 @@
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
@@ -181,17 +178,17 @@
*
*
SpannerWriteResult
*
- * The {@link SpannerWriteResult SpannerWriteResult} object contains the results of the transform,
- * including a {@link PCollection} of MutationGroups that failed to write, and a {@link PCollection}
- * that can be used in batch pipelines as a completion signal to {@link
+ * The {@link SpannerWriteResult SpannerWriteResult} object contains the results of the
+ * transform, including a {@link PCollection} of MutationGroups that failed to write, and a {@link
+ * PCollection} that can be used in batch pipelines as a completion signal to {@link
* org.apache.beam.sdk.transforms.Wait Wait.OnSignal} to indicate when all input has been written.
* Note that in streaming pipelines, this signal will never be triggered as the input is unbounded
* and this {@link PCollection} is using the {@link GlobalWindow}.
*
- *
Batching
+ * Batching and Grouping
*
* To reduce the number of transactions sent to Spanner, the {@link Mutation Mutations} are
- * grouped into batches The default maximum size of the batch is set to 1MB or 5000 mutated cells,
+ * grouped into batches. The default maximum size of the batch is set to 1MB or 5000 mutated cells,
* or 500 rows (whichever is reached first). To override this use {@link
* Write#withBatchSizeBytes(long) withBatchSizeBytes()}, {@link Write#withMaxNumMutations(long)
* withMaxNumMutations()} or {@link Write#withMaxNumMutations(long) withMaxNumRows()}. Setting
@@ -205,16 +202,42 @@
* MaxNumMutations}.
*
*
The batches written are obtained from by grouping enough {@link Mutation Mutations} from the
- * Bundle provided by Beam to form (by default) 1000 batches. This group of {@link Mutation
- * Mutations} is then sorted by table and primary key, and the batches are created from the sorted
- * group. Each batch will then have rows with keys that are 'close' to each other to optimise write
- * performance. This grouping factor (number of batches) is controlled by the parameter {@link
+ * Bundle provided by Beam to form several batches. This group of {@link Mutation Mutations} is then
+ * sorted by table and primary key, and the batches are created from the sorted group. Each batch
+ * will then have rows for the same table, with keys that are 'close' to each other, thus optimising
+ * write efficiency by each batch affecting as few table splits as possible.
+ *
+ *
This grouping factor (number of batches) is controlled by the parameter {@link
* Write#withGroupingFactor(int) withGroupingFactor()}.
*
*
Note that each worker will need enough memory to hold {@code GroupingFactor x
* MaxBatchSizeBytes} Mutations, so if you have a large {@code MaxBatchSize} you may need to reduce
* {@code GroupingFactor}
*
+ *
While Grouping and Batching increase write efficiency, they dramatically increase the latency
+ * between when a Mutation is received by the transform, and when it is actually written to the
+ * database. This is because enough Mutations need to be received to fill the grouped batches. In
+ * Batch pipelines (bounded sources), this is not normally an issue, but in Streaming (unbounded)
+ * pipelines, this latency is often seen as unacceptable.
+ *
+ *
There are therefore 3 different ways that this transform can be configured:
+ *
+ *
+ * - With Grouping and Batching.
+ * This is the default for Batch pipelines, where sorted batches of Mutations are created and
+ * written. This is the most efficient way to ingest large amounts of data, but the highest
+ * latency before writing
+ * - With Batching but no Grouping
+ * If {@link Write#withGroupingFactor(int) .withGroupingFactor(1)} is set, grouping is
+ * disabled. This is the default for Streaming pipelines. Unsorted batches are created and
+ * written as soon as enough mutations to fill a batch are received. This reflects a
+ * compromise where a small amount of additional latency enables more efficient writes
+ * - Without any Batching
+ * If {@link Write#withBatchSizeBytes(long) .withBatchSizeBytes(0)} is set, no batching is
+ * performed and the Mutations are written to the database as soon as they are received,
+ * ensuring the lowest latency before Mutations are written.
+ *
+ *
* Monitoring
*
* Several counters are provided for monitoring purpooses:
@@ -983,20 +1006,6 @@ private void populateDisplayDataWithParamaters(DisplayData.Builder builder) {
}
}
- /**
- * A singleton to compare encoded MutationGroups by encoded Key that wraps {@code
- * UnsignedBytes#lexicographicalComparator} which unfortunately is not serializable.
- */
- private enum EncodedKvMutationGroupComparator
- implements Comparator>, Serializable {
- INSTANCE {
- @Override
- public int compare(KV a, KV b) {
- return UnsignedBytes.lexicographicalComparator().compare(a.getKey(), b.getKey());
- }
- }
- }
-
/** Same as {@link Write} but supports grouped mutations. */
public static class WriteGrouped
extends PTransform, SpannerWriteResult> {
@@ -1078,9 +1087,9 @@ public SpannerWriteResult expand(PCollection input) {
filteredMutations
.get(BATCHABLE_MUTATIONS_TAG)
.apply(
- "Gather And Sort",
+ "Gather Sort And Create Batches",
ParDo.of(
- new GatherBundleAndSortFn(
+ new GatherSortCreateBatchesFn(
spec.getBatchSizeBytes(),
spec.getMaxNumMutations(),
spec.getMaxNumRows(),
@@ -1091,15 +1100,6 @@ public SpannerWriteResult expand(PCollection input) {
? DEFAULT_GROUPING_FACTOR
: 1),
schemaView))
- .withSideInputs(schemaView))
- .apply(
- "Create Batches",
- ParDo.of(
- new BatchFn(
- spec.getBatchSizeBytes(),
- spec.getMaxNumMutations(),
- spec.getMaxNumRows(),
- schemaView))
.withSideInputs(schemaView));
// Merge the batched and unbatchable mutation PCollections and write to Spanner.
@@ -1163,70 +1163,125 @@ public void processElement(ProcessContext c) {
* occur, Therefore this DoFn has to be tested in isolation.
*/
@VisibleForTesting
- static class GatherBundleAndSortFn extends DoFn>> {
- private final long maxBatchSizeBytes;
- private final long maxNumMutations;
- private final long maxNumRows;
-
- // total size of the current batch.
- private long batchSizeBytes;
- // total number of mutated cells.
- private long batchCells;
- // total number of rows mutated.
- private long batchRows;
+ static class GatherSortCreateBatchesFn extends DoFn> {
+ private final long maxBatchSizeBytes;
+ private final long maxBatchNumMutations;
+ private final long maxBatchNumRows;
+ private final long maxSortableSizeBytes;
+ private final long maxSortableNumMutations;
+ private final long maxSortableNumRows;
private final PCollectionView schemaView;
+ private final ArrayList mutationsToSort = new ArrayList<>();
- private transient ArrayList> mutationsToSort = null;
+ // total size of MutationGroups in mutationsToSort.
+ private long sortableSizeBytes = 0;
+ // total number of mutated cells in mutationsToSort
+ private long sortableNumCells = 0;
+ // total number of rows mutated in mutationsToSort
+ private long sortableNumRows = 0;
- GatherBundleAndSortFn(
+ GatherSortCreateBatchesFn(
long maxBatchSizeBytes,
long maxNumMutations,
long maxNumRows,
long groupingFactor,
PCollectionView schemaView) {
- this.maxBatchSizeBytes = maxBatchSizeBytes * groupingFactor;
- this.maxNumMutations = maxNumMutations * groupingFactor;
- this.maxNumRows = maxNumRows * groupingFactor;
- this.schemaView = schemaView;
- }
+ this.maxBatchSizeBytes = maxBatchSizeBytes;
+ this.maxBatchNumMutations = maxNumMutations;
+ this.maxBatchNumRows = maxNumRows;
- @StartBundle
- public synchronized void startBundle() throws Exception {
- if (mutationsToSort == null) {
- initSorter();
- } else {
- throw new IllegalStateException("Sorter should be null here");
+ if (groupingFactor <= 0) {
+ groupingFactor = 1;
}
+
+ this.maxSortableSizeBytes = maxBatchSizeBytes * groupingFactor;
+ this.maxSortableNumMutations = maxNumMutations * groupingFactor;
+ this.maxSortableNumRows = maxNumRows * groupingFactor;
+ this.schemaView = schemaView;
+
+ initSorter();
}
- private void initSorter() {
- mutationsToSort = new ArrayList>((int) maxNumMutations);
- batchSizeBytes = 0;
- batchCells = 0;
- batchRows = 0;
+ private synchronized void initSorter() {
+ mutationsToSort.clear();
+ sortableSizeBytes = 0;
+ sortableNumCells = 0;
+ sortableNumRows = 0;
}
@FinishBundle
public synchronized void finishBundle(FinishBundleContext c) throws Exception {
- // Only output when there is something in the batch.
- if (mutationsToSort.isEmpty()) {
- mutationsToSort = null;
- } else {
- c.output(sortAndGetList(), Instant.now(), GlobalWindow.INSTANCE);
+ sortAndOutputBatches(new OutputReceiverForFinishBundle(c));
+ }
+
+ private synchronized void sortAndOutputBatches(OutputReceiver> out)
+ throws IOException {
+ try {
+ if (mutationsToSort.isEmpty()) {
+ // nothing to output.
+ return;
+ }
+
+ if (maxSortableNumMutations == maxBatchNumMutations) {
+ // no grouping is occurring, no need to sort and make batches, just output what we have.
+ outputBatch(out, 0, mutationsToSort.size());
+ return;
+ }
+
+ // Sort then split the sorted mutations into batches.
+ mutationsToSort.sort(Comparator.naturalOrder());
+ int batchStart = 0;
+ int batchEnd = 0;
+
+ // total size of the current batch.
+ long batchSizeBytes = 0;
+ // total number of mutated cells.
+ long batchCells = 0;
+ // total number of rows mutated.
+ long batchRows = 0;
+
+ // collect and output batches.
+ while (batchEnd < mutationsToSort.size()) {
+ MutationGroupContainer mg = mutationsToSort.get(batchEnd);
+
+ if (((batchCells + mg.numCells) > maxBatchNumMutations)
+ || ((batchSizeBytes + mg.sizeBytes) > maxBatchSizeBytes
+ || (batchRows + mg.numRows > maxBatchNumRows))) {
+ // Cannot add new element, current batch is full; output.
+ outputBatch(out, batchStart, batchEnd);
+ batchStart = batchEnd;
+ batchSizeBytes = 0;
+ batchCells = 0;
+ batchRows = 0;
+ }
+
+ batchEnd++;
+ batchSizeBytes += mg.sizeBytes;
+ batchCells += mg.numCells;
+ batchRows += mg.numRows;
+ }
+
+ if (batchStart < batchEnd) {
+ // output remaining elements
+ outputBatch(out, batchStart, mutationsToSort.size());
+ }
+ } finally {
+ initSorter();
}
}
- private Iterable> sortAndGetList() throws IOException {
- mutationsToSort.sort(EncodedKvMutationGroupComparator.INSTANCE);
- ArrayList> tmp = mutationsToSort;
- // Ensure no more mutations can be added.
- mutationsToSort = null;
- return tmp;
+ private void outputBatch(
+ OutputReceiver> out, int batchStart, int batchEnd) {
+ out.output(
+ mutationsToSort.subList(batchStart, batchEnd).stream()
+ .map(o -> o.mutationGroup)
+ .collect(Collectors.toList()));
}
@ProcessElement
- public void processElement(ProcessContext c) throws Exception {
+ public synchronized void processElement(
+ ProcessContext c, OutputReceiver> out) throws Exception {
SpannerSchema spannerSchema = c.sideInput(schemaView);
MutationKeyEncoder encoder = new MutationKeyEncoder(spannerSchema);
MutationGroup mg = c.element();
@@ -1235,79 +1290,69 @@ public void processElement(ProcessContext c) throws Exception {
long groupRows = mg.size();
synchronized (this) {
- if (((batchCells + groupCells) > maxNumMutations)
- || (batchSizeBytes + groupSize) > maxBatchSizeBytes
- || (batchRows + groupRows) > maxNumRows) {
- c.output(sortAndGetList());
- initSorter();
+ if (((sortableNumCells + groupCells) > maxSortableNumMutations)
+ || (sortableSizeBytes + groupSize) > maxSortableSizeBytes
+ || (sortableNumRows + groupRows) > maxSortableNumRows) {
+ sortAndOutputBatches(out);
}
- mutationsToSort.add(KV.of(encoder.encodeTableNameAndKey(mg.primary()), encode(mg)));
- batchSizeBytes += groupSize;
- batchCells += groupCells;
- batchRows += groupRows;
+ mutationsToSort.add(
+ new MutationGroupContainer(
+ mg, groupSize, groupCells, groupRows, encoder.encodeTableNameAndKey(mg.primary())));
+ sortableSizeBytes += groupSize;
+ sortableNumCells += groupCells;
+ sortableNumRows += groupRows;
}
}
- }
-
- /** Batches mutations together. */
- @VisibleForTesting
- static class BatchFn extends DoFn>, Iterable> {
- private final long maxBatchSizeBytes;
- private final long maxNumMutations;
- private final long maxNumRows;
- private final PCollectionView schemaView;
+ // Container class to store a MutationGroup, its sortable encoded key and its statistics.
+ private static final class MutationGroupContainer
+ implements Comparable {
+
+ final MutationGroup mutationGroup;
+ final long sizeBytes;
+ final long numCells;
+ final long numRows;
+ final byte[] encodedKey;
+
+ MutationGroupContainer(
+ MutationGroup mutationGroup,
+ long sizeBytes,
+ long numCells,
+ long numRows,
+ byte[] encodedKey) {
+ this.mutationGroup = mutationGroup;
+ this.sizeBytes = sizeBytes;
+ this.numCells = numCells;
+ this.numRows = numRows;
+ this.encodedKey = encodedKey;
+ }
- BatchFn(
- long maxBatchSizeBytes,
- long maxNumMutations,
- long maxNumRows,
- PCollectionView schemaView) {
- this.maxBatchSizeBytes = maxBatchSizeBytes;
- this.maxNumMutations = maxNumMutations;
- this.maxNumRows = maxNumRows;
- this.schemaView = schemaView;
+ @Override
+ public int compareTo(MutationGroupContainer o) {
+ return UnsignedBytes.lexicographicalComparator().compare(this.encodedKey, o.encodedKey);
+ }
}
- @ProcessElement
- public void processElement(ProcessContext c) throws Exception {
- SpannerSchema spannerSchema = c.sideInput(schemaView);
- // Current batch of mutations to be written.
- ImmutableList.Builder batch = ImmutableList.builder();
- // total size of the current batch.
- long batchSizeBytes = 0;
- // total number of mutated cells.
- long batchCells = 0;
- // total number of rows mutated.
- long batchRows = 0;
-
- // Iterate through list, outputting whenever a batch is complete.
- for (KV kv : c.element()) {
- MutationGroup mg = decode(kv.getValue());
-
- long groupSize = MutationSizeEstimator.sizeOf(mg);
- long groupCells = MutationCellCounter.countOf(spannerSchema, mg);
- long groupRows = mg.size();
-
- if (((batchCells + groupCells) > maxNumMutations)
- || ((batchSizeBytes + groupSize) > maxBatchSizeBytes
- || (batchRows + groupRows > maxNumRows))) {
- // Batch is full: output and reset.
- c.output(batch.build());
- batch = ImmutableList.builder();
- batchSizeBytes = 0;
- batchCells = 0;
- batchRows = 0;
- }
- batch.add(mg);
- batchSizeBytes += groupSize;
- batchCells += groupCells;
- batchRows += groupRows;
+ // TODO(BEAM-1287): Remove this when FinishBundle has added support for an {@link
+ // OutputReceiver}
+ private static class OutputReceiverForFinishBundle
+ implements OutputReceiver> {
+
+ private final FinishBundleContext c;
+
+ OutputReceiverForFinishBundle(FinishBundleContext c) {
+ this.c = c;
}
- // End of list, output what is left.
- if (batchCells > 0) {
- c.output(batch.build());
+
+ @Override
+ public void output(Iterable output) {
+ outputWithTimestamp(output, Instant.now());
+ }
+
+ @Override
+ public void outputWithTimestamp(Iterable output, Instant timestamp) {
+ c.output(output, timestamp, GlobalWindow.INSTANCE);
}
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
index 816556cf1bc3..3c9847bd5b29 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -49,24 +49,21 @@
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
-import java.util.stream.Collectors;
import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.coders.SerializableCoder;
-import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.BatchFn;
import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.BatchableMutationFilterFn;
import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.FailureMode;
-import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.GatherBundleAndSortFn;
-import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.WriteGrouped;
+import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.GatherSortCreateBatchesFn;
import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.WriteToSpannerFn;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
+import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.Sleeper;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
@@ -100,7 +97,6 @@ public class SpannerIOWriteTest implements Serializable {
@Captor public transient ArgumentCaptor> mutationBatchesCaptor;
@Captor public transient ArgumentCaptor> mutationGroupListCaptor;
@Captor public transient ArgumentCaptor mutationGroupCaptor;
- @Captor public transient ArgumentCaptor>> byteArrayKvListCaptor;
private FakeServiceFactory serviceFactory;
@@ -912,147 +908,150 @@ public void testBatchableMutationFilterFn_batchingDisabled() {
}
@Test
- public void testGatherBundleAndSortFn() throws Exception {
- GatherBundleAndSortFn testFn = new GatherBundleAndSortFn(10000000, 10, 1000, 100, null);
+ public void testGatherSortAndBatchFn() throws Exception {
+
+ GatherSortCreateBatchesFn testFn =
+ new GatherSortCreateBatchesFn(
+ 10000000, // batch bytes
+ 100, // batch up to 100 mutated cells.
+ 5, // batch rows
+ 100, // groupingFactor
+ null);
ProcessContext mockProcessContext = Mockito.mock(ProcessContext.class);
FinishBundleContext mockFinishBundleContext = Mockito.mock(FinishBundleContext.class);
when(mockProcessContext.sideInput(any())).thenReturn(getSchema());
// Capture the outputs.
- doNothing().when(mockProcessContext).output(byteArrayKvListCaptor.capture());
- // Capture the outputs.
- doNothing().when(mockFinishBundleContext).output(byteArrayKvListCaptor.capture(), any(), any());
+ doNothing()
+ .when(mockFinishBundleContext)
+ .output(mutationGroupListCaptor.capture(), any(), any());
MutationGroup[] mutationGroups =
new MutationGroup[] {
- g(m(4L)), g(m(1L)), g(m(5L), m(6L), m(7L), m(8L), m(9L)), g(del(2L)), g(m(3L))
+ // Unsorted group of 12 mutations.
+ // each mutation is considered 7 cells,
+ // should be sorted and output as 2 lists of 5, then 1 list of 2
+ // with mutations sorted in order.
+ g(m(4L)),
+ g(m(1L)),
+ g(m(7L)),
+ g(m(12L)),
+ g(m(10L)),
+ g(m(11L)),
+ g(m(2L)),
+ g(del(8L)),
+ g(m(3L)),
+ g(m(6L)),
+ g(m(9L)),
+ g(m(5L))
};
// Process all elements as one bundle.
- testFn.startBundle();
for (MutationGroup m : mutationGroups) {
when(mockProcessContext.element()).thenReturn(m);
- testFn.processElement(mockProcessContext);
+ // outputReceiver should not be called until end of bundle.
+ testFn.processElement(mockProcessContext, null);
}
testFn.finishBundle(mockFinishBundleContext);
verify(mockProcessContext, never()).output(any());
- verify(mockFinishBundleContext, times(1)).output(any(), any(), any());
+ verify(mockFinishBundleContext, times(3)).output(any(), any(), any());
- // Verify sorted output... first decode it...
- List sorted =
- byteArrayKvListCaptor.getValue().stream()
- .map(kv -> WriteGrouped.decode(kv.getValue()))
- .collect(Collectors.toList());
+ // Verify output are 3 batches of sorted values
assertThat(
- sorted,
- contains(g(m(1L)), g(del(2L)), g(m(3L)), g(m(4L)), g(m(5L), m(6L), m(7L), m(8L), m(9L))));
+ mutationGroupListCaptor.getAllValues(),
+ contains(
+ Arrays.asList(g(m(1L)), g(m(2L)), g(m(3L)), g(m(4L)), g(m(5L))),
+ Arrays.asList(g(m(6L)), g(m(7L)), g(del(8L)), g(m(9L)), g(m(10L))),
+ Arrays.asList(g(m(11L)), g(m(12L)))));
}
@Test
public void testGatherBundleAndSortFn_flushOversizedBundle() throws Exception {
- // Setup class to bundle every 3 mutations
- GatherBundleAndSortFn testFn =
- new GatherBundleAndSortFn(10000000, CELLS_PER_KEY, 1000, 3, null);
+ // Setup class to bundle every 6 rows and create batches of 2.
+ GatherSortCreateBatchesFn testFn =
+ new GatherSortCreateBatchesFn(
+ 10000000, // batch bytes
+ 100, // batch up to 100 mutated cells.
+ 2, // batch rows
+ 3, // groupingFactor
+ null);
ProcessContext mockProcessContext = Mockito.mock(ProcessContext.class);
FinishBundleContext mockFinishBundleContext = Mockito.mock(FinishBundleContext.class);
when(mockProcessContext.sideInput(any())).thenReturn(getSchema());
+ OutputReceiver> mockOutputReceiver = mock(OutputReceiver.class);
// Capture the outputs.
- doNothing().when(mockProcessContext).output(byteArrayKvListCaptor.capture());
+ doNothing().when(mockOutputReceiver).output(mutationGroupListCaptor.capture());
// Capture the outputs.
- doNothing().when(mockFinishBundleContext).output(byteArrayKvListCaptor.capture(), any(), any());
+ doNothing()
+ .when(mockFinishBundleContext)
+ .output(mutationGroupListCaptor.capture(), any(), any());
MutationGroup[] mutationGroups =
new MutationGroup[] {
+ // Unsorted group of 11 mutations.
+ // each mutation is considered 7 cells,
+ // the first 6 fill a group and should be sorted and output as 3 batches of 2,
+ // the remaining 5 should be output at finishBundle as batches of 2, 2 and 1.
g(m(4L)),
g(m(1L)),
- // end group
- g(m(5L), m(6L), m(7L), m(8L), m(9L)),
- // end group
+ g(m(7L)),
+ g(m(9L)),
g(m(10L)),
- g(m(3L)),
g(m(11L)),
- // end group.
- g(m(2L))
+ // end group
+ g(m(2L)),
+ g(del(8L)), // end batch
+ g(m(3L)),
+ g(m(6L)), // end batch
+ g(m(5L))
+ // end bundle, so end group and end batch.
};
// Process all elements as one bundle.
- testFn.startBundle();
for (MutationGroup m : mutationGroups) {
when(mockProcessContext.element()).thenReturn(m);
- testFn.processElement(mockProcessContext);
+ testFn.processElement(mockProcessContext, mockOutputReceiver);
}
testFn.finishBundle(mockFinishBundleContext);
- verify(mockProcessContext, times(3)).output(any());
- verify(mockFinishBundleContext, times(1)).output(any(), any(), any());
+ // processElement output receiver should have been called 3 times when the 1st group was full.
+ verify(mockOutputReceiver, times(3)).output(any());
+ // finishBundleContext output should be called 3 times when the bundle was finished.
+ verify(mockFinishBundleContext, times(3)).output(any(), any(), any());
- // verify sorted output... needs decoding...
- List>> kvGroups = byteArrayKvListCaptor.getAllValues();
- assertEquals(4, kvGroups.size());
+ List> mgListGroups = mutationGroupListCaptor.getAllValues();
- // decode list of lists of KV to a list of lists of MutationGroup.
- List> mgListGroups =
- kvGroups.stream()
- .map(
- l ->
- l.stream()
- .map(kv -> WriteGrouped.decode(kv.getValue()))
- .collect(Collectors.toList()))
- .collect(Collectors.toList());
-
- // verify contents of 4 sorted groups.
+ assertEquals(6, mgListGroups.size());
+ // verify contents of 6 sorted groups.
+ // first group should be 1,4,7,9,10,11
assertThat(mgListGroups.get(0), contains(g(m(1L)), g(m(4L))));
- assertThat(mgListGroups.get(1), contains(g(m(5L), m(6L), m(7L), m(8L), m(9L))));
- assertThat(mgListGroups.get(2), contains(g(m(3L)), g(m(10L)), g(m(11L))));
- assertThat(mgListGroups.get(3), contains(g(m(2L))));
+ assertThat(mgListGroups.get(1), contains(g(m(7L)), g(m(9L))));
+ assertThat(mgListGroups.get(2), contains(g(m(10L)), g(m(11L))));
+
+ // second group at finishBundle should be 2,3,5,6,8
+ assertThat(mgListGroups.get(3), contains(g(m(2L)), g(m(3L))));
+ assertThat(mgListGroups.get(4), contains(g(m(5L)), g(m(6L))));
+ assertThat(mgListGroups.get(5), contains(g(del(8L))));
}
@Test
public void testBatchFn_cells() throws Exception {
- // Setup class to bundle every 3 mutations (3xCELLS_PER_KEY cell mutations)
- BatchFn testFn = new BatchFn(10000000, 3 * CELLS_PER_KEY, 1000, null);
-
- ProcessContext mockProcessContext = Mockito.mock(ProcessContext.class);
- when(mockProcessContext.sideInput(any())).thenReturn(getSchema());
-
- // Capture the outputs.
- doNothing().when(mockProcessContext).output(mutationGroupListCaptor.capture());
-
- List mutationGroups =
- Arrays.asList(
- g(m(1L)),
- g(m(4L)),
- g(m(5L), m(6L), m(7L), m(8L), m(9L)),
- g(m(3L)),
- g(m(10L)),
- g(m(11L)),
- g(m(2L)));
-
- List> encodedInput =
- mutationGroups.stream()
- .map(mg -> KV.of((byte[]) null, WriteGrouped.encode(mg)))
- .collect(Collectors.toList());
-
- // Process elements.
- when(mockProcessContext.element()).thenReturn(encodedInput);
- testFn.processElement(mockProcessContext);
-
- verify(mockProcessContext, times(4)).output(any());
-
- List> batches = mutationGroupListCaptor.getAllValues();
- assertEquals(4, batches.size());
+ // Setup class to batch every 3 mutations (3xCELLS_PER_KEY cell mutations)
+ GatherSortCreateBatchesFn testFn =
+ new GatherSortCreateBatchesFn(
+ 10000000, // batch bytes
+ 3 * CELLS_PER_KEY, // batch up to 21 mutated cells - 3 mutations.
+ 100, // batch rows
+ 100, // groupingFactor
+ null);
- // verify contents of 4 batches.
- assertThat(batches.get(0), contains(g(m(1L)), g(m(4L))));
- assertThat(batches.get(1), contains(g(m(5L), m(6L), m(7L), m(8L), m(9L))));
- assertThat(batches.get(2), contains(g(m(3L)), g(m(10L)), g(m(11L))));
- assertThat(batches.get(3), contains(g(m(2L))));
+ testAndVerifyBatches(testFn);
}
@Test
@@ -1061,56 +1060,41 @@ public void testBatchFn_size() throws Exception {
long mutationSize = MutationSizeEstimator.sizeOf(m(1L));
// Setup class to bundle every 3 mutations by size)
- BatchFn testFn = new BatchFn(mutationSize * 3, 1000, 1000, null);
-
- ProcessContext mockProcessContext = Mockito.mock(ProcessContext.class);
- when(mockProcessContext.sideInput(any())).thenReturn(getSchema());
-
- // Capture the outputs.
- doNothing().when(mockProcessContext).output(mutationGroupListCaptor.capture());
-
- List mutationGroups =
- Arrays.asList(
- g(m(1L)),
- g(m(4L)),
- g(m(5L), m(6L), m(7L), m(8L), m(9L)),
- g(m(3L)),
- g(m(10L)),
- g(m(11L)),
- g(m(2L)));
-
- List> encodedInput =
- mutationGroups.stream()
- .map(mg -> KV.of((byte[]) null, WriteGrouped.encode(mg)))
- .collect(Collectors.toList());
-
- // Process elements.
- when(mockProcessContext.element()).thenReturn(encodedInput);
- testFn.processElement(mockProcessContext);
-
- verify(mockProcessContext, times(4)).output(any());
-
- List> batches = mutationGroupListCaptor.getAllValues();
- assertEquals(4, batches.size());
-
- // verify contents of 4 batches.
- assertThat(batches.get(0), contains(g(m(1L)), g(m(4L))));
- assertThat(batches.get(1), contains(g(m(5L), m(6L), m(7L), m(8L), m(9L))));
- assertThat(batches.get(2), contains(g(m(3L)), g(m(10L)), g(m(11L))));
- assertThat(batches.get(3), contains(g(m(2L))));
+ GatherSortCreateBatchesFn testFn =
+ new GatherSortCreateBatchesFn(
+ mutationSize * 3, // batch bytes = 3 mutations.
+ 100, // batch cells
+ 100, // batch rows
+ 100, // groupingFactor
+ null);
+
+ testAndVerifyBatches(testFn);
}
@Test
public void testBatchFn_rows() throws Exception {
- // Setup class to bundle every 3 mutations (3xCELLS_PER_KEY cell mutations)
- BatchFn testFn = new BatchFn(10000000, 1000, 3, null);
+ // Setup class to bundle every 3 rows
+ GatherSortCreateBatchesFn testFn =
+ new GatherSortCreateBatchesFn(
+ 10000, // batch bytes = 3 mutations.
+ 100, // batch cells
+ 3, // batch rows
+ 100, // groupingFactor
+ null);
+ testAndVerifyBatches(testFn);
+ }
+
+ private void testAndVerifyBatches(GatherSortCreateBatchesFn testFn) throws Exception {
ProcessContext mockProcessContext = Mockito.mock(ProcessContext.class);
+ FinishBundleContext mockFinishBundleContext = Mockito.mock(FinishBundleContext.class);
when(mockProcessContext.sideInput(any())).thenReturn(getSchema());
- // Capture the outputs.
- doNothing().when(mockProcessContext).output(mutationGroupListCaptor.capture());
+ // Capture the output at finish bundle.
+ doNothing()
+ .when(mockFinishBundleContext)
+ .output(mutationGroupListCaptor.capture(), any(), any());
List mutationGroups =
Arrays.asList(
@@ -1122,25 +1106,23 @@ public void testBatchFn_rows() throws Exception {
g(m(11L)),
g(m(2L)));
- List> encodedInput =
- mutationGroups.stream()
- .map(mg -> KV.of((byte[]) null, WriteGrouped.encode(mg)))
- .collect(Collectors.toList());
-
// Process elements.
- when(mockProcessContext.element()).thenReturn(encodedInput);
- testFn.processElement(mockProcessContext);
+ for (MutationGroup m : mutationGroups) {
+ when(mockProcessContext.element()).thenReturn(m);
+ testFn.processElement(mockProcessContext, null);
+ }
+ testFn.finishBundle(mockFinishBundleContext);
- verify(mockProcessContext, times(4)).output(any());
+ verify(mockFinishBundleContext, times(4)).output(any(), any(), any());
List> batches = mutationGroupListCaptor.getAllValues();
assertEquals(4, batches.size());
// verify contents of 4 batches.
- assertThat(batches.get(0), contains(g(m(1L)), g(m(4L))));
- assertThat(batches.get(1), contains(g(m(5L), m(6L), m(7L), m(8L), m(9L))));
- assertThat(batches.get(2), contains(g(m(3L)), g(m(10L)), g(m(11L))));
- assertThat(batches.get(3), contains(g(m(2L))));
+ assertThat(batches.get(0), contains(g(m(1L)), g(m(2L)), g(m(3L))));
+ assertThat(batches.get(1), contains(g(m(4L)))); // small batch : next mutation group is too big.
+ assertThat(batches.get(2), contains(g(m(5L), m(6L), m(7L), m(8L), m(9L))));
+ assertThat(batches.get(3), contains(g(m(10L)), g(m(11L))));
}
private static MutationGroup g(Mutation m, Mutation... other) {