Skip to content
This repository was archived by the owner on Nov 11, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class TestCountingSource
private final int shardNumber;
private final boolean dedup;
private final boolean throwOnFirstSnapshot;
private final boolean allowSplitting;

/**
* We only allow an exception to be thrown from getCheckpointMark
Expand All @@ -68,27 +69,36 @@ public static void setFinalizeTracker(List<Integer> finalizeTracker) {
}

public TestCountingSource(int numMessagesPerShard) {
this(numMessagesPerShard, 0, false, false);
this(numMessagesPerShard, 0, false, false, true);
}

public TestCountingSource withDedup() {
return new TestCountingSource(numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot);
return new TestCountingSource(
numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot, true);
}

private TestCountingSource withShardNumber(int shardNumber) {
return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot);
return new TestCountingSource(
numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true);
}

public TestCountingSource withThrowOnFirstSnapshot(boolean throwOnFirstSnapshot) {
return new TestCountingSource(numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot);
return new TestCountingSource(
numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true);
}

private TestCountingSource(
int numMessagesPerShard, int shardNumber, boolean dedup, boolean throwOnFirstSnapshot) {
public TestCountingSource withoutSplitting() {
return new TestCountingSource(
numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, false);
}

private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup,
boolean throwOnFirstSnapshot, boolean allowSplitting) {
this.numMessagesPerShard = numMessagesPerShard;
this.shardNumber = shardNumber;
this.dedup = dedup;
this.throwOnFirstSnapshot = throwOnFirstSnapshot;
this.allowSplitting = allowSplitting;
}

public int getShardNumber() {
Expand All @@ -99,7 +109,8 @@ public int getShardNumber() {
public List<TestCountingSource> generateInitialSplits(
int desiredNumSplits, PipelineOptions options) {
List<TestCountingSource> splits = new ArrayList<>();
for (int i = 0; i < desiredNumSplits; i++) {
int numSplits = allowSplitting ? desiredNumSplits : 1;
for (int i = 0; i < numSplits; i++) {
splits.add(withShardNumber(i));
}
return splits;
Expand Down Expand Up @@ -143,7 +154,11 @@ public boolean requiresDeduping() {
return dedup;
}

private class CountingSourceReader extends UnboundedReader<KV<Integer, Integer>> {
/**
* Public only so that the checkpoint can be conveyed from {@link #getCheckpointMark()} to
* {@link TestCountingSource#createReader(PipelineOptions, CounterMark)} without cast.
*/
public class CountingSourceReader extends UnboundedReader<KV<Integer, Integer>> {
private int current;

public CountingSourceReader(int startingPoint) {
Expand All @@ -152,21 +167,20 @@ public CountingSourceReader(int startingPoint) {

@Override
public boolean start() {
return true;
return advance();
}

@Override
public boolean advance() {
if (current < numMessagesPerShard - 1) {
// If testing dedup, occasionally insert a duplicate value;
if (dedup && ThreadLocalRandom.current().nextInt(5) == 0) {
return true;
}
current++;
return true;
} else {
if (current >= numMessagesPerShard - 1) {
return false;
}
// If testing dedup, occasionally insert a duplicate value;
if (current >= 0 && dedup && ThreadLocalRandom.current().nextInt(5) == 0) {
return true;
}
current++;
return true;
}

@Override
Expand Down Expand Up @@ -204,12 +218,14 @@ public Instant getWatermark() {
}

@Override
public CheckpointMark getCheckpointMark() {
public CounterMark getCheckpointMark() {
if (throwOnFirstSnapshot && !thrown) {
thrown = true;
LOG.error("Throwing exception while checkpointing counter");
throw new RuntimeException("failed during checkpoint");
}
// The checkpoint can assume all records read, including the current, have
// been commited.
return new CounterMark(current);
}

Expand All @@ -222,7 +238,12 @@ public long getSplitBacklogBytes() {
@Override
public CountingSourceReader createReader(
PipelineOptions options, @Nullable CounterMark checkpointMark) {
return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : 0);
if (checkpointMark == null) {
LOG.debug("creating reader");
} else {
LOG.debug("restoring reader from checkpoint with current = {}", checkpointMark.current);
}
return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : -1);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (C) 2015 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package com.google.cloud.dataflow.sdk.runners.dataflow;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.io.IOException;

/**
* Test the TestCountingSource.
*/
@RunWith(JUnit4.class)
public class TestCountingSourceTest {
@Test
public void testRespectsCheckpointContract() throws IOException {
TestCountingSource source = new TestCountingSource(3);
PipelineOptions options = PipelineOptionsFactory.create();
TestCountingSource.CountingSourceReader reader =
source.createReader(options, null /* no checkpoint */);
assertTrue(reader.start());
assertEquals(0L, (long) reader.getCurrent().getValue());
assertTrue(reader.advance());
assertEquals(1L, (long) reader.getCurrent().getValue());
TestCountingSource.CounterMark checkpoint = reader.getCheckpointMark();
checkpoint.finalizeCheckpoint();
reader = source.createReader(options, checkpoint);
assertTrue(reader.start());
assertEquals(2L, (long) reader.getCurrent().getValue());
assertFalse(reader.advance());
}

@Test
public void testCanResumeWithExpandedCount() throws IOException {
TestCountingSource source = new TestCountingSource(1);
PipelineOptions options = PipelineOptionsFactory.create();
TestCountingSource.CountingSourceReader reader =
source.createReader(options, null /* no checkpoint */);
assertTrue(reader.start());
assertEquals(0L, (long) reader.getCurrent().getValue());
assertFalse(reader.advance());
TestCountingSource.CounterMark checkpoint = reader.getCheckpointMark();
checkpoint.finalizeCheckpoint();
source = new TestCountingSource(2);
reader = source.createReader(options, checkpoint);
assertTrue(reader.start());
assertEquals(1L, (long) reader.getCurrent().getValue());
assertFalse(reader.advance());
}
}