From c7fc09691557e19811fa96164766e6c53a4eb8ae Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 25 Apr 2016 14:26:12 -0700 Subject: [PATCH 1/2] Backport BEAM pull/235 --- .../runners/dataflow/TestCountingSource.java | 59 +++++++++++++------ .../dataflow/TestCountingSourceTest.java | 54 +++++++++++++++++ 2 files changed, 94 insertions(+), 19 deletions(-) create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java index d0863a46f5..612b117827 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java @@ -55,6 +55,7 @@ public class TestCountingSource private final int shardNumber; private final boolean dedup; private final boolean throwOnFirstSnapshot; + private final boolean allowSplitting; /** * We only allow an exception to be thrown from getCheckpointMark @@ -68,27 +69,36 @@ public static void setFinalizeTracker(List finalizeTracker) { } public TestCountingSource(int numMessagesPerShard) { - this(numMessagesPerShard, 0, false, false); + this(numMessagesPerShard, 0, false, false, true); } public TestCountingSource withDedup() { - return new TestCountingSource(numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot, true); } private TestCountingSource withShardNumber(int shardNumber) { - return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true); } public TestCountingSource withThrowOnFirstSnapshot(boolean throwOnFirstSnapshot) { - return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot); + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true); } - private TestCountingSource( - int numMessagesPerShard, int shardNumber, boolean dedup, boolean throwOnFirstSnapshot) { + public TestCountingSource withoutSplitting() { + return new TestCountingSource( + numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, false); + } + + private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup, + boolean throwOnFirstSnapshot, boolean allowSplitting) { this.numMessagesPerShard = numMessagesPerShard; this.shardNumber = shardNumber; this.dedup = dedup; this.throwOnFirstSnapshot = throwOnFirstSnapshot; + this.allowSplitting = allowSplitting; } public int getShardNumber() { @@ -99,7 +109,8 @@ public int getShardNumber() { public List generateInitialSplits( int desiredNumSplits, PipelineOptions options) { List splits = new ArrayList<>(); - for (int i = 0; i < desiredNumSplits; i++) { + int numSplits = allowSplitting ? desiredNumSplits : 1; + for (int i = 0; i < numSplits; i++) { splits.add(withShardNumber(i)); } return splits; @@ -143,7 +154,11 @@ public boolean requiresDeduping() { return dedup; } - private class CountingSourceReader extends UnboundedReader> { + /** + * Public only so that the checkpoint can be conveyed from {@link #getCheckpointMark()} to + * {@link TestCountingSource#createReader(PipelineOptions, CounterMark)} without cast. + */ + public class CountingSourceReader extends UnboundedReader> { private int current; public CountingSourceReader(int startingPoint) { @@ -152,21 +167,20 @@ public CountingSourceReader(int startingPoint) { @Override public boolean start() { - return true; + return advance(); } @Override public boolean advance() { - if (current < numMessagesPerShard - 1) { - // If testing dedup, occasionally insert a duplicate value; - if (dedup && ThreadLocalRandom.current().nextInt(5) == 0) { - return true; - } - current++; - return true; - } else { + if (current >= numMessagesPerShard) { return false; } + // If testing dedup, occasionally insert a duplicate value; + if (current >= 0 && dedup && ThreadLocalRandom.current().nextInt(5) == 0) { + return true; + } + current++; + return current < numMessagesPerShard; } @Override @@ -204,12 +218,14 @@ public Instant getWatermark() { } @Override - public CheckpointMark getCheckpointMark() { + public CounterMark getCheckpointMark() { if (throwOnFirstSnapshot && !thrown) { thrown = true; LOG.error("Throwing exception while checkpointing counter"); throw new RuntimeException("failed during checkpoint"); } + // The checkpoint can assume all records read, including the current, have + // been commited. return new CounterMark(current); } @@ -222,7 +238,12 @@ public long getSplitBacklogBytes() { @Override public CountingSourceReader createReader( PipelineOptions options, @Nullable CounterMark checkpointMark) { - return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : 0); + if (checkpointMark == null) { + LOG.debug("creating reader"); + } else { + LOG.debug("restoring reader from checkpoint with current = {}", checkpointMark.current); + } + return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : -1); } @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java new file mode 100644 index 0000000000..3c5144a111 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.runners.dataflow; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; + +/** + * Test the TestCountingSource. + */ +@RunWith(JUnit4.class) +public class TestCountingSourceTest { + @Test + public void testRespectsCheckpointContract() throws IOException { + TestCountingSource source = new TestCountingSource(3); + PipelineOptions options = PipelineOptionsFactory.create(); + TestCountingSource.CountingSourceReader reader = + source.createReader(options, null /* no checkpoint */); + assertTrue(reader.start()); + assertEquals(0L, (long) reader.getCurrent().getValue()); + assertTrue(reader.advance()); + assertEquals(1L, (long) reader.getCurrent().getValue()); + TestCountingSource.CounterMark checkpoint = reader.getCheckpointMark(); + checkpoint.finalizeCheckpoint(); + reader = source.createReader(options, checkpoint); + assertTrue(reader.start()); + assertEquals(2L, (long) reader.getCurrent().getValue()); + assertFalse(reader.advance()); + } +} From 0bffaab75b60c36da6e82d55b58974d0247bb262 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Tue, 26 Apr 2016 16:41:12 -0700 Subject: [PATCH 2/2] Mirror BEAM --- .../runners/dataflow/TestCountingSource.java | 4 ++-- .../dataflow/TestCountingSourceTest.java | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java index 612b117827..16752a9641 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSource.java @@ -172,7 +172,7 @@ public boolean start() { @Override public boolean advance() { - if (current >= numMessagesPerShard) { + if (current >= numMessagesPerShard - 1) { return false; } // If testing dedup, occasionally insert a duplicate value; @@ -180,7 +180,7 @@ public boolean advance() { return true; } current++; - return current < numMessagesPerShard; + return true; } @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java index 3c5144a111..b03398f479 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/dataflow/TestCountingSourceTest.java @@ -51,4 +51,22 @@ public void testRespectsCheckpointContract() throws IOException { assertEquals(2L, (long) reader.getCurrent().getValue()); assertFalse(reader.advance()); } + + @Test + public void testCanResumeWithExpandedCount() throws IOException { + TestCountingSource source = new TestCountingSource(1); + PipelineOptions options = PipelineOptionsFactory.create(); + TestCountingSource.CountingSourceReader reader = + source.createReader(options, null /* no checkpoint */); + assertTrue(reader.start()); + assertEquals(0L, (long) reader.getCurrent().getValue()); + assertFalse(reader.advance()); + TestCountingSource.CounterMark checkpoint = reader.getCheckpointMark(); + checkpoint.finalizeCheckpoint(); + source = new TestCountingSource(2); + reader = source.createReader(options, checkpoint); + assertTrue(reader.start()); + assertEquals(1L, (long) reader.getCurrent().getValue()); + assertFalse(reader.advance()); + } }