diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java index d0863a46f5..16752a9641 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java @@ -55,6 +55,7 @@ public class TestCountingSource private final int shardNumber; private final boolean dedup; private final boolean throwOnFirstSnapshot; + private final boolean allowSplitting; /** * We only allow an exception to be thrown from getCheckpointMark @@ -68,27 +69,36 @@ public static void setFinalizeTracker(List finalizeTracker) { } public TestCountingSource(int numMessagesPerShard) { - this(numMessagesPerShard, 0, false, false); + this(numMessagesPerShard, 0, false, false, true); } public TestCountingSource withDedup() { - return new TestCountingSource(numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot, true); } private TestCountingSource withShardNumber(int shardNumber) { - return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true); } public TestCountingSource withThrowOnFirstSnapshot(boolean throwOnFirstSnapshot) { - return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true); } - private TestCountingSource( - int numMessagesPerShard, int shardNumber, boolean dedup, boolean throwOnFirstSnapshot) { + public TestCountingSource withoutSplitting() { + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, false); + } + + private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup, + boolean throwOnFirstSnapshot, boolean allowSplitting) { this.numMessagesPerShard = numMessagesPerShard; this.shardNumber = shardNumber; this.dedup = dedup; this.throwOnFirstSnapshot = throwOnFirstSnapshot; + this.allowSplitting = allowSplitting; } public int getShardNumber() { @@ -99,7 +109,8 @@ public int getShardNumber() { public List generateInitialSplits( int desiredNumSplits, PipelineOptions options) { List splits = new ArrayList<>(); - for (int i = 0; i < desiredNumSplits; i++) { + int numSplits = allowSplitting ? desiredNumSplits : 1; + for (int i = 0; i < numSplits; i++) { splits.add(withShardNumber(i)); } return splits; @@ -143,7 +154,11 @@ public boolean requiresDeduping() { return dedup; } - private class CountingSourceReader extends UnboundedReader> { + /** + * Public only so that the checkpoint can be conveyed from {@link #getCheckpointMark()} to + * {@link TestCountingSource#createReader(PipelineOptions, CounterMark)} without cast. + */ + public class CountingSourceReader extends UnboundedReader> { private int current; public CountingSourceReader(int startingPoint) { @@ -152,21 +167,20 @@ public CountingSourceReader(int startingPoint) { @Override public boolean start() { - return true; + return advance(); } @Override public boolean advance() { - if (current < numMessagesPerShard - 1) { - // If testing dedup, occasionally insert a duplicate value; - if (dedup && ThreadLocalRandom.current().nextInt(5) == 0) { - return true; - } - current++; - return true; - } else { + if (current >= numMessagesPerShard - 1) { return false; } + // If testing dedup, occasionally insert a duplicate value; + if (current >= 0 && dedup && ThreadLocalRandom.current().nextInt(5) == 0) { + return true; + } + current++; + return true; } @Override @@ -204,12 +218,14 @@ public Instant getWatermark() { } @Override - public CheckpointMark getCheckpointMark() { + public CounterMark getCheckpointMark() { if (throwOnFirstSnapshot && !thrown) { thrown = true; LOG.error("Throwing exception while checkpointing counter"); throw new RuntimeException("failed during checkpoint"); } + // The checkpoint can assume all records read, including the current, have + // been commited. return new CounterMark(current); } @@ -222,7 +238,12 @@ public long getSplitBacklogBytes() { @Override public CountingSourceReader createReader( PipelineOptions options, @Nullable CounterMark checkpointMark) { - return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : 0); + if (checkpointMark == null) { + LOG.debug("creating reader"); + } else { + LOG.debug("restoring reader from checkpoint with current = {}", checkpointMark.current); + } + return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : -1); } @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java new file mode 100644 index 0000000000..b03398f479 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +/** + * Test the TestCountingSource. + */ +@RunWith(JUnit4.class) +public class TestCountingSourceTest { + @Test + public void testRespectsCheckpointContract() throws IOException { + TestCountingSource source = new TestCountingSource(3); + PipelineOptions options = PipelineOptionsFactory.create(); + TestCountingSource.CountingSourceReader reader = + source.createReader(options, null /* no checkpoint */); + assertTrue(reader.start()); + assertEquals(0L, (long) reader.getCurrent().getValue()); + assertTrue(reader.advance()); + assertEquals(1L, (long) reader.getCurrent().getValue()); + TestCountingSource.CounterMark checkpoint = reader.getCheckpointMark(); + checkpoint.finalizeCheckpoint(); + reader = source.createReader(options, checkpoint); + assertTrue(reader.start()); + assertEquals(2L, (long) reader.getCurrent().getValue()); + assertFalse(reader.advance()); + } + + @Test + public void testCanResumeWithExpandedCount() throws IOException { + TestCountingSource source = new TestCountingSource(1); + PipelineOptions options = PipelineOptionsFactory.create(); + TestCountingSource.CountingSourceReader reader = + source.createReader(options, null /* no checkpoint */); + assertTrue(reader.start()); + assertEquals(0L, (long) reader.getCurrent().getValue()); + assertFalse(reader.advance()); + TestCountingSource.CounterMark checkpoint = reader.getCheckpointMark(); + checkpoint.finalizeCheckpoint(); + source = new TestCountingSource(2); + reader = source.createReader(options, checkpoint); + assertTrue(reader.start()); + assertEquals(1L, (long) reader.getCurrent().getValue()); + assertFalse(reader.advance()); + } +}