From 74f08eacea64f872cd95932e2e7f372862d929d3 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 29 Feb 2016 13:48:09 -0800 Subject: [PATCH] Add UnboundedCountingInput#withRate The rate controls the speed at which UnboundedCountingInput outputs elements. This is an aggregate rate across all instances of the source, and thus elements will not necessarily be output "smoothly", or within the first period. The aggregate rate, however, will be approximately equal to the provided rate. Add package-private CountingSource#createUnbounded() to expose the UnboundedCountingSource type. Make UnboundedCountingSource package-private. --- .../cloud/dataflow/sdk/io/CountingInput.java | 42 +++++++- .../cloud/dataflow/sdk/io/CountingSource.java | 99 +++++++++++++++++-- .../dataflow/sdk/io/CountingInputTest.java | 24 ++++- .../dataflow/sdk/io/CountingSourceTest.java | 73 ++++++++++++++ 4 files changed, 224 insertions(+), 14 deletions(-) diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java index 91e3f13113..c743146e10 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java @@ -90,7 +90,11 @@ public static BoundedCountingInput upTo(long numElements) { */ public static UnboundedCountingInput unbounded() { return new UnboundedCountingInput( - new NowTimestampFn(), Optional.absent(), Optional.absent()); + new NowTimestampFn(), + 1L /* Elements per period */, + Duration.ZERO /* Period length */, + Optional.absent() /* Maximum number of records */, + Optional.absent() /* Maximum read duration */); } /** @@ -133,14 +137,20 @@ public void populateDisplayData(DisplayData.Builder builder) { */ public static class UnboundedCountingInput extends PTransform> { private final SerializableFunction timestampFn; + private final long elementsPerPeriod; + private final Duration period; private final 
Optional maxNumRecords; private final Optional maxReadTime; private UnboundedCountingInput( SerializableFunction timestampFn, + long elementsPerPeriod, + Duration period, Optional maxNumRecords, Optional maxReadTime) { this.timestampFn = timestampFn; + this.elementsPerPeriod = elementsPerPeriod; + this.period = period; this.maxNumRecords = maxNumRecords; this.maxReadTime = maxReadTime; } @@ -152,7 +162,8 @@ private UnboundedCountingInput( *

Note that the timestamps produced by {@code timestampFn} may not decrease. */ public UnboundedCountingInput withTimestampFn(SerializableFunction timestampFn) { - return new UnboundedCountingInput(timestampFn, maxNumRecords, maxReadTime); + return new UnboundedCountingInput( + timestampFn, elementsPerPeriod, period, maxNumRecords, maxReadTime); } /** @@ -165,7 +176,23 @@ public UnboundedCountingInput withTimestampFn(SerializableFunction 0, "MaxRecords must be a positive (nonzero) value. Got %s", maxRecords); - return new UnboundedCountingInput(timestampFn, Optional.of(maxRecords), maxReadTime); + return new UnboundedCountingInput( + timestampFn, elementsPerPeriod, period, Optional.of(maxRecords), maxReadTime); + } + + /** + * Returns an {@link UnboundedCountingInput} like this one, but with output production limited + * to an aggregate rate of no more than {@code numElements} elements per {@code periodLength}. + * + *

Note that when there are multiple splits, each split outputs independently. This may lead + * to elements not being produced evenly across time, though the aggregate rate will still + * approach the specified rate. + * + *

A duration of {@link Duration#ZERO} will produce output as fast as possible. + */ + public UnboundedCountingInput withRate(long numElements, Duration periodLength) { + return new UnboundedCountingInput( + timestampFn, numElements, periodLength, maxNumRecords, maxReadTime); } /** @@ -177,13 +204,18 @@ public UnboundedCountingInput withMaxNumRecords(long maxRecords) { */ public UnboundedCountingInput withMaxReadTime(Duration readTime) { checkNotNull(readTime, "ReadTime cannot be null"); - return new UnboundedCountingInput(timestampFn, maxNumRecords, Optional.of(readTime)); + return new UnboundedCountingInput( + timestampFn, elementsPerPeriod, period, maxNumRecords, Optional.of(readTime)); } @SuppressWarnings("deprecation") @Override public PCollection apply(PBegin begin) { - Unbounded read = Read.from(CountingSource.unboundedWithTimestampFn(timestampFn)); + Unbounded read = + Read.from( + CountingSource.createUnbounded() + .withTimestampFn(timestampFn) + .withRate(elementsPerPeriod, period)); if (!maxNumRecords.isPresent() && !maxReadTime.isPresent()) { return begin.apply(read); } else if (maxNumRecords.isPresent() && !maxReadTime.isPresent()) { diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java index 8cca7e2179..20568441ed 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java @@ -17,6 +17,7 @@ package com.google.cloud.dataflow.sdk.io; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import com.google.cloud.dataflow.sdk.coders.AvroCoder; import com.google.cloud.dataflow.sdk.coders.Coder; @@ -29,6 +30,7 @@ import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.common.collect.ImmutableList; +import org.joda.time.Duration; import org.joda.time.Instant; import 
java.io.IOException; @@ -81,6 +83,14 @@ public static BoundedSource upTo(long numElements) { return new BoundedCountingSource(0, numElements); } + /** + * Create a new {@link UnboundedCountingSource}. + */ + // package-private to return a typed UnboundedCountingSource rather than the UnboundedSource type. + static UnboundedCountingSource createUnbounded() { + return new UnboundedCountingSource(0, 1, 1L, Duration.ZERO, new NowTimestampFn()); + } + /** * Creates an {@link UnboundedSource} that will produce numbers starting from {@code 0} up to * {@link Long#MAX_VALUE}. @@ -113,7 +123,7 @@ public static UnboundedSource unbounded() { @Deprecated public static UnboundedSource unboundedWithTimestampFn( SerializableFunction timestampFn) { - return new UnboundedCountingSource(0, 1, timestampFn); + return new UnboundedCountingSource(0, 1, 1L, Duration.ZERO, timestampFn); } ///////////////////////////////////////////////////////////////////////////////////////////// @@ -231,11 +241,15 @@ public void close() throws IOException {} /** * An implementation of {@link CountingSource} that produces an unbounded {@link PCollection}. */ - private static class UnboundedCountingSource extends UnboundedSource { + static class UnboundedCountingSource extends UnboundedSource { /** The first number (>= 0) generated by this {@link UnboundedCountingSource}. */ private final long start; /** The interval between numbers generated by this {@link UnboundedCountingSource}. */ private final long stride; + /** The number of elements to produce each period. */ + private final long elementsPerPeriod; + /** The time between producing numbers from this {@link UnboundedCountingSource}. */ + private final Duration period; /** The function used to produce timestamps for the generated elements. */ private final SerializableFunction timestampFn; @@ -248,13 +262,45 @@ private static class UnboundedCountingSource extends UnboundedSourceNote that the timestamps produced by {@code timestampFn} may not decrease. 
*/ - public UnboundedCountingSource( - long start, long stride, SerializableFunction timestampFn) { + private UnboundedCountingSource( + long start, + long stride, + long elementsPerPeriod, + Duration period, + SerializableFunction timestampFn) { this.start = start; this.stride = stride; + checkArgument( + elementsPerPeriod > 0L, + "Must produce at least one element per period, got %s", + elementsPerPeriod); + this.elementsPerPeriod = elementsPerPeriod; + checkArgument( + period.getMillis() >= 0L, "Must have a non-negative period length, got %s", period); + this.period = period; this.timestampFn = timestampFn; } + /** + * Returns an {@link UnboundedCountingSource} like this one that limits output to a rate of + * {@code elementsPerPeriod} elements per {@code period}. A zero period disables the limit. + */ + public UnboundedCountingSource withRate(long elementsPerPeriod, Duration period) { + return new UnboundedCountingSource(start, stride, elementsPerPeriod, period, timestampFn); + } + + /** + * Returns an {@link UnboundedCountingSource} like this one where the timestamps of output + * elements are supplied by the specified function. + * + *

Note that timestamps produced by {@code timestampFn} may not decrease. + */ + public UnboundedCountingSource withTimestampFn( + SerializableFunction timestampFn) { + checkNotNull(timestampFn); + return new UnboundedCountingSource(start, stride, elementsPerPeriod, period, timestampFn); + } + /** * Splits an unbounded source {@code desiredNumSplits} ways by giving each split every * {@code desiredNumSplits}th element that this {@link UnboundedCountingSource} @@ -275,7 +321,9 @@ public List> generat for (int i = 0; i < desiredNumSplits; ++i) { // Starts offset by the original stride. Using Javadoc example, this generates starts of // 0, 2, and 4. - splits.add(new UnboundedCountingSource(start + i * stride, newStride, timestampFn)); + splits.add( + new UnboundedCountingSource( + start + i * stride, newStride, elementsPerPeriod, period, timestampFn)); } return splits.build(); } @@ -309,6 +357,7 @@ private static class UnboundedCountingReader extends UnboundedReader { private UnboundedCountingSource source; private long current; private Instant currentTimestamp; + private Instant firstStarted; public UnboundedCountingReader(UnboundedCountingSource source, CounterMark mark) { this.source = source; @@ -318,11 +367,15 @@ public UnboundedCountingReader(UnboundedCountingSource source, CounterMark mark) this.current = source.start - source.stride; } else { this.current = mark.getLastEmitted(); + this.firstStarted = mark.getStartTime(); } } @Override public boolean start() throws IOException { + if (firstStarted == null) { + this.firstStarted = Instant.now(); + } return advance(); } @@ -332,11 +385,25 @@ public boolean advance() throws IOException { if (Long.MAX_VALUE - source.stride < current) { return false; } - current += source.stride; + long nextValue = current + source.stride; + if (expectedValue() < nextValue) { + return false; + } + current = nextValue; currentTimestamp = source.timestampFn.apply(current); return true; } + private long expectedValue() { + if 
(source.period.getMillis() == 0L) { + return Long.MAX_VALUE; + } + double periodsElapsed = + (Instant.now().getMillis() - firstStarted.getMillis()) + / (double) source.period.getMillis(); + return (long) (source.elementsPerPeriod * periodsElapsed); + } + @Override public Instant getWatermark() { return source.timestampFn.apply(current); @@ -344,7 +411,7 @@ public Instant getWatermark() { @Override public CounterMark getCheckpointMark() { - return new CounterMark(current); + return new CounterMark(current, firstStarted); } @Override @@ -364,6 +431,12 @@ public Instant getCurrentTimestamp() throws NoSuchElementException { @Override public void close() throws IOException {} + + @Override + public long getSplitBacklogBytes() { + long expected = expectedValue(); + return Math.max(0L, 8 * (expected - current) / source.stride); + } } /** @@ -374,12 +447,14 @@ public void close() throws IOException {} public static class CounterMark implements UnboundedSource.CheckpointMark { /** The last value emitted. */ private final long lastEmitted; + private final Instant startTime; /** * Creates a checkpoint mark reflecting the last emitted value. */ - public CounterMark(long lastEmitted) { + public CounterMark(long lastEmitted, Instant startTime) { this.lastEmitted = lastEmitted; + this.startTime = startTime; } /** @@ -389,11 +464,19 @@ public long getLastEmitted() { return lastEmitted; } + /** + * Returns the time the reader was started. 
+ */ + public Instant getStartTime() { + return startTime; + } + ///////////////////////////////////////////////////////////////////////////////////// @SuppressWarnings("unused") // For AvroCoder private CounterMark() { this.lastEmitted = 0L; + this.startTime = Instant.now(); } @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java index d06688a29c..529536e7a5 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java @@ -19,7 +19,8 @@ import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; -import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.io.CountingInput.UnboundedCountingInput; @@ -96,6 +97,27 @@ public void testUnboundedInput() { p.run(); } + @Test + public void testUnboundedInputRate() { + Pipeline p = TestPipeline.create(); + long numElements = 5000; + + long elemsPerPeriod = 10L; + Duration periodLength = Duration.millis(8); + PCollection input = + p.apply( + CountingInput.unbounded() + .withRate(elemsPerPeriod, periodLength) + .withMaxNumRecords(numElements)); + + addCountingAsserts(input, numElements); + long expectedRuntimeMillis = (periodLength.getMillis() * numElements) / elemsPerPeriod; + Instant startTime = Instant.now(); + p.run(); + Instant endTime = Instant.now(); + assertThat(endTime.isAfter(startTime.plus(expectedRuntimeMillis)), is(true)); + } + private static class ElementValueDiff extends DoFn { @Override public void processElement(ProcessContext c) throws Exception { diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java index cc01cc8576..3051535d5d 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java @@ -16,12 +16,16 @@ package com.google.cloud.dataflow.sdk.io; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; import com.google.cloud.dataflow.sdk.io.CountingSource.CounterMark; +import com.google.cloud.dataflow.sdk.io.CountingSource.UnboundedCountingSource; import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; import com.google.cloud.dataflow.sdk.testing.DataflowAssert; @@ -39,6 +43,7 @@ import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PCollectionList; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -179,6 +184,40 @@ public void testUnboundedSourceTimestamps() { p.run(); } + @Test + public void testUnboundedSourceWithRate() { + Pipeline p = TestPipeline.create(); + + Duration period = Duration.millis(5); + long numElements = 1000L; + + PCollection input = + p.apply( + Read.from( + CountingSource.createUnbounded() + .withTimestampFn(new ValueAsTimestampFn()) + .withRate(1, period)) + .withMaxNumRecords(numElements)); + addCountingAsserts(input, numElements); + + PCollection diffs = + input + .apply("TimestampDiff", ParDo.of(new ElementValueDiff())) + .apply("RemoveDuplicateTimestamps", RemoveDuplicates.create()); + // This assert also confirms that diffs only has one unique value. 
+ DataflowAssert.thatSingleton(diffs).isEqualTo(0L); + + Instant started = Instant.now(); + p.run(); + Instant finished = Instant.now(); + Duration expectedDuration = period.multipliedBy((int) numElements); + assertThat( + started + .plus(expectedDuration) + .isBefore(finished), + is(true)); + } + @Test @Category(RunnableOnService.class) public void testUnboundedSourceSplits() throws Exception { @@ -204,6 +243,40 @@ public void testUnboundedSourceSplits() throws Exception { p.run(); } + @Test + public void testUnboundedSourceRateSplits() throws Exception { + Pipeline p = TestPipeline.create(); + int elementsPerPeriod = 10; + Duration period = Duration.millis(5); + + long numElements = 1000; + int numSplits = 10; + + UnboundedCountingSource initial = + CountingSource.createUnbounded().withRate(elementsPerPeriod, period); + List> splits = + initial.generateInitialSplits(numSplits, p.getOptions()); + assertEquals("Expected exact splitting", numSplits, splits.size()); + + long elementsPerSplit = numElements / numSplits; + assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits); + PCollectionList pcollections = PCollectionList.empty(p); + for (int i = 0; i < splits.size(); ++i) { + pcollections = + pcollections.and( + p.apply("split" + i, Read.from(splits.get(i)).withMaxNumRecords(elementsPerSplit))); + } + PCollection input = pcollections.apply(Flatten.pCollections()); + + addCountingAsserts(input, numElements); + Instant startTime = Instant.now(); + p.run(); + Instant endTime = Instant.now(); + // 500 ms if the readers are all initialized in parallel; 5000 ms if they are evaluated serially + long expectedMinimumMillis = (numElements * period.getMillis()) / elementsPerPeriod; + assertThat(expectedMinimumMillis, lessThan(endTime.getMillis() - startTime.getMillis())); + } + /** * A timestamp function that uses the given value as the timestamp. 
Because the input values will * not wrap, this function is non-decreasing and meets the timestamp function criteria laid out