From ee7272fe77d78240c6f8e0b79cdc93b460171b8c Mon Sep 17 00:00:00 2001 From: Daniel Mills Date: Tue, 5 Apr 2016 16:45:33 -0700 Subject: [PATCH] Make BoundedReadFromUnboundedSourceTest work across runners --- .../BoundedReadFromUnboundedSourceTest.java | 4 +-- .../runners/dataflow/TestCountingSource.java | 25 +++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java index 7cac67a20069..a98402d559ed 100644 --- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java +++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/BoundedReadFromUnboundedSourceTest.java @@ -95,7 +95,7 @@ public Void apply(Iterable> input) { Collections.sort(values); for (int i = 0; i < values.size(); i++) { assertEquals(i, (int) values.get(i)); - } + } if (finalizeTracker != null) { assertThat(finalizeTracker, containsInAnyOrder(values.size() - 1)); } @@ -110,7 +110,7 @@ private void test(boolean dedup, boolean timeBound) throws Exception { finalizeTracker = new ArrayList<>(); TestCountingSource.setFinalizeTracker(finalizeTracker); } - TestCountingSource source = new TestCountingSource(Integer.MAX_VALUE); + TestCountingSource source = new TestCountingSource(Integer.MAX_VALUE).withoutSplitting(); if (dedup) { source = source.withDedup(); } diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java index 778060f30fa2..207734a00f1e 100644 --- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java +++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java @@ -56,6 +56,7 @@ public class TestCountingSource private final int shardNumber; private final boolean dedup; private final boolean throwOnFirstSnapshot; + private final boolean allowSplitting; /** * We only allow an exception to be thrown from getCheckpointMark @@ -69,27 +70,36 @@ public static void setFinalizeTracker(List finalizeTracker) { } public TestCountingSource(int numMessagesPerShard) { - this(numMessagesPerShard, 0, false, false); + this(numMessagesPerShard, 0, false, false, true); } public TestCountingSource withDedup() { - return new TestCountingSource(numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot, true); } private TestCountingSource withShardNumber(int shardNumber) { - return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true); } public TestCountingSource withThrowOnFirstSnapshot(boolean throwOnFirstSnapshot) { - return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true); } - private TestCountingSource( - int numMessagesPerShard, int shardNumber, boolean dedup, boolean throwOnFirstSnapshot) { + public TestCountingSource withoutSplitting() { + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, false); + } + + private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup, + boolean throwOnFirstSnapshot, boolean allowSplitting) { this.numMessagesPerShard = numMessagesPerShard; this.shardNumber = shardNumber; this.dedup = dedup; this.throwOnFirstSnapshot = throwOnFirstSnapshot; + this.allowSplitting = allowSplitting; } public int getShardNumber() { @@ -100,7 +110,8 @@ public int getShardNumber() { public List generateInitialSplits( int desiredNumSplits, PipelineOptions options) { List splits = new ArrayList<>(); - for (int i = 0; i < desiredNumSplits; i++) { + int numSplits = allowSplitting ? desiredNumSplits : 1; + for (int i = 0; i < numSplits; i++) { splits.add(withShardNumber(i)); } return splits;