diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java index 91e3f13113..c743146e10 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingInput.java @@ -90,7 +90,11 @@ public static BoundedCountingInput upTo(long numElements) { */ public static UnboundedCountingInput unbounded() { return new UnboundedCountingInput( - new NowTimestampFn(), Optional.absent(), Optional.absent()); + new NowTimestampFn(), + 1L /* Elements per period */, + Duration.ZERO /* Period length */, + Optional.absent() /* Maximum number of records */, + Optional.absent() /* Maximum read duration */); } /** @@ -133,14 +137,20 @@ public void populateDisplayData(DisplayData.Builder builder) { */ public static class UnboundedCountingInput extends PTransform> { private final SerializableFunction timestampFn; + private final long elementsPerPeriod; + private final Duration period; private final Optional maxNumRecords; private final Optional maxReadTime; private UnboundedCountingInput( SerializableFunction timestampFn, + long elementsPerPeriod, + Duration period, Optional maxNumRecords, Optional maxReadTime) { this.timestampFn = timestampFn; + this.elementsPerPeriod = elementsPerPeriod; + this.period = period; this.maxNumRecords = maxNumRecords; this.maxReadTime = maxReadTime; } @@ -152,7 +162,8 @@ private UnboundedCountingInput( *

Note that the timestamps produced by {@code timestampFn} may not decrease. */ public UnboundedCountingInput withTimestampFn(SerializableFunction timestampFn) { - return new UnboundedCountingInput(timestampFn, maxNumRecords, maxReadTime); + return new UnboundedCountingInput( + timestampFn, elementsPerPeriod, period, maxNumRecords, maxReadTime); } /** @@ -165,7 +176,23 @@ public UnboundedCountingInput withTimestampFn(SerializableFunction 0, "MaxRecords must be a positive (nonzero) value. Got %s", maxRecords); - return new UnboundedCountingInput(timestampFn, Optional.of(maxRecords), maxReadTime); + return new UnboundedCountingInput( + timestampFn, elementsPerPeriod, period, Optional.of(maxRecords), maxReadTime); + } + + /** + * Returns an {@link UnboundedCountingInput} like this one, but with output production limited + * to an aggregate rate of no more than the number of elements per the period length. + * + *

Note that when there are multiple splits, each split outputs independently. This may lead + * to elements not being produced evenly across time, though the aggregate rate will still + * approach the specified rate. + * + *

A duration of {@link Duration#ZERO} will produce output as fast as possible. + */ + public UnboundedCountingInput withRate(long numElements, Duration periodLength) { + return new UnboundedCountingInput( + timestampFn, numElements, periodLength, maxNumRecords, maxReadTime); } /** @@ -177,13 +204,18 @@ public UnboundedCountingInput withMaxNumRecords(long maxRecords) { */ public UnboundedCountingInput withMaxReadTime(Duration readTime) { checkNotNull(readTime, "ReadTime cannot be null"); - return new UnboundedCountingInput(timestampFn, maxNumRecords, Optional.of(readTime)); + return new UnboundedCountingInput( + timestampFn, elementsPerPeriod, period, maxNumRecords, Optional.of(readTime)); } @SuppressWarnings("deprecation") @Override public PCollection apply(PBegin begin) { - Unbounded read = Read.from(CountingSource.unboundedWithTimestampFn(timestampFn)); + Unbounded read = + Read.from( + CountingSource.createUnbounded() + .withTimestampFn(timestampFn) + .withRate(elementsPerPeriod, period)); if (!maxNumRecords.isPresent() && !maxReadTime.isPresent()) { return begin.apply(read); } else if (maxNumRecords.isPresent() && !maxReadTime.isPresent()) { diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java index 8cca7e2179..20568441ed 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/CountingSource.java @@ -17,6 +17,7 @@ package com.google.cloud.dataflow.sdk.io; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import com.google.cloud.dataflow.sdk.coders.AvroCoder; import com.google.cloud.dataflow.sdk.coders.Coder; @@ -29,6 +30,7 @@ import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.common.collect.ImmutableList; +import org.joda.time.Duration; import org.joda.time.Instant; import java.io.IOException; @@ -81,6 +83,14 @@ public static BoundedSource upTo(long numElements) { return new BoundedCountingSource(0, numElements); } + /** + * Create a new {@link UnboundedCountingSource}. + */ + // package-private to return a typed UnboundedCountingSource rather than the UnboundedSource type. + static UnboundedCountingSource createUnbounded() { + return new UnboundedCountingSource(0, 1, 1L, Duration.ZERO, new NowTimestampFn()); + } + /** * Creates an {@link UnboundedSource} that will produce numbers starting from {@code 0} up to * {@link Long#MAX_VALUE}. @@ -113,7 +123,7 @@ public static UnboundedSource unbounded() { @Deprecated public static UnboundedSource unboundedWithTimestampFn( SerializableFunction timestampFn) { - return new UnboundedCountingSource(0, 1, timestampFn); + return new UnboundedCountingSource(0, 1, 1L, Duration.ZERO, timestampFn); } ///////////////////////////////////////////////////////////////////////////////////////////// @@ -231,11 +241,15 @@ public void close() throws IOException {} /** * An implementation of {@link CountingSource} that produces an unbounded {@link PCollection}. */ - private static class UnboundedCountingSource extends UnboundedSource { + static class UnboundedCountingSource extends UnboundedSource { /** The first number (>= 0) generated by this {@link UnboundedCountingSource}. */ private final long start; /** The interval between numbers generated by this {@link UnboundedCountingSource}. */ private final long stride; + /** The number of elements to produce each period. */ + private final long elementsPerPeriod; + /** The time between producing numbers from this {@link UnboundedCountingSource}. */ + private final Duration period; /** The function used to produce timestamps for the generated elements. */ private final SerializableFunction timestampFn; @@ -248,13 +262,45 @@ private static class UnboundedCountingSource extends UnboundedSourceNote that the timestamps produced by {@code timestampFn} may not decrease. */ - public UnboundedCountingSource( - long start, long stride, SerializableFunction timestampFn) { + private UnboundedCountingSource( + long start, + long stride, + long elementsPerPeriod, + Duration period, + SerializableFunction timestampFn) { this.start = start; this.stride = stride; + checkArgument( + elementsPerPeriod > 0L, + "Must produce at least one element per period, got %s", + elementsPerPeriod); + this.elementsPerPeriod = elementsPerPeriod; + checkArgument( + period.getMillis() >= 0L, "Must have a non-negative period length, got %s", period); + this.period = period; this.timestampFn = timestampFn; } + /** + * Returns an {@link UnboundedCountingSource} like this one with the specified period. Elements + * will be produced with an interval between them equal to the period. + */ + public UnboundedCountingSource withRate(long elementsPerPeriod, Duration period) { + return new UnboundedCountingSource(start, stride, elementsPerPeriod, period, timestampFn); + } + + /** + * Returns an {@link UnboundedCountingSource} like this one where the timestamp of output + * elements are supplied by the specified function. + * + *

Note that timestamps produced by {@code timestampFn} may not decrease. + */ + public UnboundedCountingSource withTimestampFn( + SerializableFunction timestampFn) { + checkNotNull(timestampFn); + return new UnboundedCountingSource(start, stride, elementsPerPeriod, period, timestampFn); + } + /** * Splits an unbounded source {@code desiredNumSplits} ways by giving each split every * {@code desiredNumSplits}th element that this {@link UnboundedCountingSource} @@ -275,7 +321,9 @@ public List> generat for (int i = 0; i < desiredNumSplits; ++i) { // Starts offset by the original stride. Using Javadoc example, this generates starts of // 0, 2, and 4. - splits.add(new UnboundedCountingSource(start + i * stride, newStride, timestampFn)); + splits.add( + new UnboundedCountingSource( + start + i * stride, newStride, elementsPerPeriod, period, timestampFn)); } return splits.build(); } @@ -309,6 +357,7 @@ private static class UnboundedCountingReader extends UnboundedReader { private UnboundedCountingSource source; private long current; private Instant currentTimestamp; + private Instant firstStarted; public UnboundedCountingReader(UnboundedCountingSource source, CounterMark mark) { this.source = source; @@ -318,11 +367,15 @@ public UnboundedCountingReader(UnboundedCountingSource source, CounterMark mark) this.current = source.start - source.stride; } else { this.current = mark.getLastEmitted(); + this.firstStarted = mark.getStartTime(); } } @Override public boolean start() throws IOException { + if (firstStarted == null) { + this.firstStarted = Instant.now(); + } return advance(); } @@ -332,11 +385,25 @@ public boolean advance() throws IOException { if (Long.MAX_VALUE - source.stride < current) { return false; } - current += source.stride; + long nextValue = current + source.stride; + if (expectedValue() < nextValue) { + return false; + } + current = nextValue; currentTimestamp = source.timestampFn.apply(current); return true; } + private long expectedValue() { + if (source.period.getMillis() == 0L) { + return Long.MAX_VALUE; + } + double periodsElapsed = + (Instant.now().getMillis() - firstStarted.getMillis()) + / (double) source.period.getMillis(); + return (long) (source.elementsPerPeriod * periodsElapsed); + } + @Override public Instant getWatermark() { return source.timestampFn.apply(current); @@ -344,7 +411,7 @@ public Instant getWatermark() { @Override public CounterMark getCheckpointMark() { - return new CounterMark(current); + return new CounterMark(current, firstStarted); } @Override @@ -364,6 +431,12 @@ public Instant getCurrentTimestamp() throws NoSuchElementException { @Override public void close() throws IOException {} + + @Override + public long getSplitBacklogBytes() { + long expected = expectedValue(); + return Math.max(0L, 8 * (expected - current) / source.stride); + } } /** @@ -374,12 +447,14 @@ public void close() throws IOException {} public static class CounterMark implements UnboundedSource.CheckpointMark { /** The last value emitted. */ private final long lastEmitted; + private final Instant startTime; /** * Creates a checkpoint mark reflecting the last emitted value. */ - public CounterMark(long lastEmitted) { + public CounterMark(long lastEmitted, Instant startTime) { this.lastEmitted = lastEmitted; + this.startTime = startTime; } /** @@ -389,11 +464,19 @@ public long getLastEmitted() { return lastEmitted; } + /** + * Returns the time the reader was started. + */ + public Instant getStartTime() { + return startTime; + } + ///////////////////////////////////////////////////////////////////////////////////// @SuppressWarnings("unused") // For AvroCoder private CounterMark() { this.lastEmitted = 0L; + this.startTime = Instant.now(); } @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java index d06688a29c..529536e7a5 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingInputTest.java @@ -19,7 +19,8 @@ import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; -import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.io.CountingInput.UnboundedCountingInput; @@ -96,6 +97,27 @@ public void testUnboundedInput() { p.run(); } + @Test + public void testUnboundedInputRate() { + Pipeline p = TestPipeline.create(); + long numElements = 5000; + + long elemsPerPeriod = 10L; + Duration periodLength = Duration.millis(8); + PCollection input = + p.apply( + CountingInput.unbounded() + .withRate(elemsPerPeriod, periodLength) + .withMaxNumRecords(numElements)); + + addCountingAsserts(input, numElements); + long expectedRuntimeMillis = (periodLength.getMillis() * numElements) / elemsPerPeriod; + Instant startTime = Instant.now(); + p.run(); + Instant endTime = Instant.now(); + assertThat(endTime.isAfter(startTime.plus(expectedRuntimeMillis)), is(true)); + } + private static class ElementValueDiff extends DoFn { @Override public void processElement(ProcessContext c) throws Exception { diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java index cc01cc8576..3051535d5d 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/CountingSourceTest.java @@ -16,12 +16,16 @@ package com.google.cloud.dataflow.sdk.io; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; import com.google.cloud.dataflow.sdk.io.CountingSource.CounterMark; +import com.google.cloud.dataflow.sdk.io.CountingSource.UnboundedCountingSource; import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; import com.google.cloud.dataflow.sdk.testing.DataflowAssert; @@ -39,6 +43,7 @@ import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PCollectionList; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -179,6 +184,40 @@ public void testUnboundedSourceTimestamps() { p.run(); } + @Test + public void testUnboundedSourceWithRate() { + Pipeline p = TestPipeline.create(); + + Duration period = Duration.millis(5); + long numElements = 1000L; + + PCollection input = + p.apply( + Read.from( + CountingSource.createUnbounded() + .withTimestampFn(new ValueAsTimestampFn()) + .withRate(1, period)) + .withMaxNumRecords(numElements)); + addCountingAsserts(input, numElements); + + PCollection diffs = + input + .apply("TimestampDiff", ParDo.of(new ElementValueDiff())) + .apply("RemoveDuplicateTimestamps", RemoveDuplicates.create()); + // This assert also confirms that diffs only has one unique value. + DataflowAssert.thatSingleton(diffs).isEqualTo(0L); + + Instant started = Instant.now(); + p.run(); + Instant finished = Instant.now(); + Duration expectedDuration = period.multipliedBy((int) numElements); + assertThat( + started + .plus(expectedDuration) + .isBefore(finished), + is(true)); + } + @Test @Category(RunnableOnService.class) public void testUnboundedSourceSplits() throws Exception { @@ -204,6 +243,40 @@ public void testUnboundedSourceSplits() throws Exception { p.run(); } + @Test + public void testUnboundedSourceRateSplits() throws Exception { + Pipeline p = TestPipeline.create(); + int elementsPerPeriod = 10; + Duration period = Duration.millis(5); + + long numElements = 1000; + int numSplits = 10; + + UnboundedCountingSource initial = + CountingSource.createUnbounded().withRate(elementsPerPeriod, period); + List> splits = + initial.generateInitialSplits(numSplits, p.getOptions()); + assertEquals("Expected exact splitting", numSplits, splits.size()); + + long elementsPerSplit = numElements / numSplits; + assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits); + PCollectionList pcollections = PCollectionList.empty(p); + for (int i = 0; i < splits.size(); ++i) { + pcollections = + pcollections.and( + p.apply("split" + i, Read.from(splits.get(i)).withMaxNumRecords(elementsPerSplit))); + } + PCollection input = pcollections.apply(Flatten.pCollections()); + + addCountingAsserts(input, numElements); + Instant startTime = Instant.now(); + p.run(); + Instant endTime = Instant.now(); + // 500 ms if the readers are all initialized in parallel; 5000 ms if they are evaluated serially + long expectedMinimumMillis = (numElements * period.getMillis()) / elementsPerPeriod; + assertThat(expectedMinimumMillis, lessThan(endTime.getMillis() - startTime.getMillis())); + } + /** * A timestamp function that uses the given value as the timestamp. Because the input values will * not wrap, this function is non-decreasing and meets the timestamp function criteria laid out