private static final Coder<String> DEFAULT_PUBSUB_CODER = StringUtf8Coder.of();
@@ -141,48 +132,6 @@ private static void validatePubsubName(String name) {
}
}
- /**
- * Returns the {@link Instant} that corresponds to the timestamp in the supplied
- * {@link PubsubMessage} under the specified {@code label}. See
- * {@link PubsubIO.Read#timestampLabel(String)} for details about how these messages are
- * parsed.
- *
- * The {@link Clock} parameter is used to virtualize time for testing.
- *
- * @throws IllegalArgumentException if the timestamp label is provided, but there is no
- * corresponding attribute in the message or the value provided is not a valid timestamp
- * string.
- * @see PubsubIO.Read#timestampLabel(String)
- */
- @VisibleForTesting
- protected static Instant assignMessageTimestamp(
- PubsubMessage message, @Nullable String label, Clock clock) {
- if (label == null) {
- return new Instant(clock.currentTimeMillis());
- }
-
- // Extract message attributes, defaulting to empty map if null.
- Map<String, String> attributes = firstNonNull(
- message.getAttributes(), ImmutableMap.<String, String>of());
-
- String timestampStr = attributes.get(label);
- checkArgument(timestampStr != null && !timestampStr.isEmpty(),
- "PubSub message is missing a timestamp in label: %s", label);
-
- long millisSinceEpoch;
- try {
- // Try parsing as milliseconds since epoch. Note that a string in RFC 3339 format
- // cannot accidentally parse as a long here.
- // An IllegalArgumentException is expected if parsing fails; we use it to fall back to RFC 3339.
- millisSinceEpoch = Long.parseLong(timestampStr);
- } catch (IllegalArgumentException e) {
- // Try parsing as RFC3339 string. DateTime.parseRfc3339 will throw an IllegalArgumentException
- // if parsing fails, and the caller should handle.
- millisSinceEpoch = DateTime.parseRfc3339(timestampStr).getValue();
- }
- return new Instant(millisSinceEpoch);
- }
-
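For reference, a quick sketch of the two timestamp encodings the removed helper accepted (the attribute name `myTimestamp` is hypothetical):

    // Both messages below carry the same instant, 2016-03-01T00:00:00.000Z.
    PubsubMessage m1 = new PubsubMessage()
        .setAttributes(ImmutableMap.of("myTimestamp", "1456790400000"));        // ms since epoch
    PubsubMessage m2 = new PubsubMessage()
        .setAttributes(ImmutableMap.of("myTimestamp", "2016-03-01T00:00:00Z")); // RFC 3339
    // assignMessageTimestamp(m1, "myTimestamp", Clock.SYSTEM)
    //     .equals(assignMessageTimestamp(m2, "myTimestamp", Clock.SYSTEM)) holds.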
/**
* Populate common {@link DisplayData} between Pubsub source and sink.
*/
@@ -415,10 +364,9 @@ public String asPath() {
* the stream.
*
* When running with a {@link PipelineRunner} that only supports bounded
- * {@link PCollection PCollections} (such as {@link DirectPipelineRunner} or
- * {@link DataflowPipelineRunner} without {@code --streaming}), only a bounded portion of the
- * input Pub/Sub stream can be processed. As such, either {@link Bound#maxNumRecords(int)} or
- * {@link Bound#maxReadTime(Duration)} must be set.
+ * {@link PCollection PCollections} (such as {@link DirectPipelineRunner}),
+ * only a bounded portion of the input Pub/Sub stream can be processed. As such, either
+ * {@link Bound#maxNumRecords(int)} or {@link Bound#maxReadTime(Duration)} must be set.
*/
public static class Read {
/**
@@ -684,24 +632,34 @@ public Bound maxReadTime(Duration maxReadTime) {
@Override
public PCollection<T> apply(PInput input) {
if (topic == null && subscription == null) {
- throw new IllegalStateException("need to set either the topic or the subscription for "
+ throw new IllegalStateException("Need to set either the topic or the subscription for "
+ "a PubsubIO.Read transform");
}
if (topic != null && subscription != null) {
- throw new IllegalStateException("Can't set both the topic and the subscription for a "
- + "PubsubIO.Read transform");
+ throw new IllegalStateException("Can't set both the topic and the subscription for "
+ + "a PubsubIO.Read transform");
}
boolean boundedOutput = getMaxNumRecords() > 0 || getMaxReadTime() != null;
if (boundedOutput) {
return input.getPipeline().begin()
- .apply(Create.of((Void) null)).setCoder(VoidCoder.of())
- .apply(ParDo.of(new PubsubReader())).setCoder(coder);
+ .apply(Create.of((Void) null)).setCoder(VoidCoder.of())
+ .apply(ParDo.of(new PubsubBoundedReader())).setCoder(coder);
} else {
- return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
- .setCoder(coder);
+ @Nullable ProjectPath projectPath =
+ topic == null ? null : PubsubClient.projectPathFromId(topic.project);
+ @Nullable TopicPath topicPath =
+ topic == null ? null : PubsubClient.topicPathFromName(topic.project, topic.topic);
+ @Nullable SubscriptionPath subscriptionPath =
+ subscription == null
+ ? null
+ : PubsubClient.subscriptionPathFromName(
+ subscription.project, subscription.subscription);
+ return input.getPipeline().begin()
+ .apply(new PubsubUnboundedSource<T>(
+ FACTORY, projectPath, topicPath, subscriptionPath,
+ coder, timestampLabel, idLabel));
}
}
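As a usage sketch (project and topic names hypothetical): setting either bound keeps the output bounded and takes the PubsubBoundedReader path below; otherwise the new PubsubUnboundedSource path is taken:

    // Bounded: stops after 1000 records, suitable for bounded-only runners.
    PCollection<String> sample = p.apply(
        PubsubIO.Read.topic("projects/my-project/topics/my-topic").maxNumRecords(1000));

    // Unbounded: no bound set, so PubsubUnboundedSource is applied instead.
    PCollection<String> stream = p.apply(
        PubsubIO.Read.topic("projects/my-project/topics/my-topic"));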
@@ -755,88 +713,106 @@ public Duration getMaxReadTime() {
return maxReadTime;
}
- private class PubsubReader extends DoFn<Void, T> {
+ /**
+ * Default reader when Pubsub subscription has some form of upper bound.
+ *
+ * <p>TODO: Consider replacing with BoundedReadFromUnboundedSource on top
+ * of PubsubUnboundedSource.
+ *
+ * <p>NOTE: This is not the implementation used when running on the Google Cloud Dataflow
+ * service in streaming mode.
+ *
+ * <p>Public so can be suppressed by runners.
+ */
+ public class PubsubBoundedReader extends DoFn<Void, T> {
private static final int DEFAULT_PULL_SIZE = 100;
+ private static final int ACK_TIMEOUT_SEC = 60;
@Override
public void processElement(ProcessContext c) throws IOException {
- Pubsub pubsubClient =
- Transport.newPubsubClient(c.getPipelineOptions().as(DataflowPipelineOptions.class))
- .build();
-
- String subscription;
- if (getSubscription() == null) {
- String topic = getTopic().asPath();
- String[] split = topic.split("/");
- subscription =
- "projects/" + split[1] + "/subscriptions/" + split[3] + "_dataflow_"
- + new Random().nextLong();
- Subscription subInfo = new Subscription().setAckDeadlineSeconds(60).setTopic(topic);
- try {
- pubsubClient.projects().subscriptions().create(subscription, subInfo).execute();
- } catch (Exception e) {
- throw new RuntimeException("Failed to create subscription: ", e);
+ try (PubsubClient pubsubClient =
+ FACTORY.newClient(timestampLabel, idLabel,
+ c.getPipelineOptions().as(DataflowPipelineOptions.class))) {
+
+ PubsubClient.SubscriptionPath subscriptionPath;
+ if (getSubscription() == null) {
+ TopicPath topicPath =
+ PubsubClient.topicPathFromName(getTopic().project, getTopic().topic);
+ // The subscription will be registered under this pipeline's project if we know it.
+ // Otherwise we'll fall back to the topic's project.
+ // Note that they don't need to be the same.
+ String projectId =
+ c.getPipelineOptions().as(DataflowPipelineOptions.class).getProject();
+ if (Strings.isNullOrEmpty(projectId)) {
+ projectId = getTopic().project;
+ }
+ ProjectPath projectPath = PubsubClient.projectPathFromId(projectId);
+ try {
+ subscriptionPath =
+ pubsubClient.createRandomSubscription(projectPath, topicPath, ACK_TIMEOUT_SEC);
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to create subscription: ", e);
+ }
+ } else {
+ subscriptionPath =
+ PubsubClient.subscriptionPathFromName(getSubscription().project,
+ getSubscription().subscription);
}
- } else {
- subscription = getSubscription().asPath();
- }
- Instant endTime = (getMaxReadTime() == null)
- ? new Instant(Long.MAX_VALUE) : Instant.now().plus(getMaxReadTime());
+ Instant endTime = (getMaxReadTime() == null)
+ ? new Instant(Long.MAX_VALUE) : Instant.now().plus(getMaxReadTime());
- List<PubsubMessage> messages = new ArrayList<>();
+ List<IncomingMessage> messages = new ArrayList<>();
- Throwable finallyBlockException = null;
- try {
- while ((getMaxNumRecords() == 0 || messages.size() < getMaxNumRecords())
- && Instant.now().isBefore(endTime)) {
- PullRequest pullRequest = new PullRequest().setReturnImmediately(false);
- if (getMaxNumRecords() > 0) {
- pullRequest.setMaxMessages(getMaxNumRecords() - messages.size());
- } else {
- pullRequest.setMaxMessages(DEFAULT_PULL_SIZE);
- }
+ Throwable finallyBlockException = null;
+ try {
+ while ((getMaxNumRecords() == 0 || messages.size() < getMaxNumRecords())
+ && Instant.now().isBefore(endTime)) {
+ int batchSize = DEFAULT_PULL_SIZE;
+ if (getMaxNumRecords() > 0) {
+ batchSize = Math.min(batchSize, getMaxNumRecords() - messages.size());
+ }
- PullResponse pullResponse =
- pubsubClient.projects().subscriptions().pull(subscription, pullRequest).execute();
- List<String> ackIds = new ArrayList<>();
- if (pullResponse.getReceivedMessages() != null) {
- for (ReceivedMessage received : pullResponse.getReceivedMessages()) {
- messages.add(received.getMessage());
- ackIds.add(received.getAckId());
+ List<IncomingMessage> batchMessages =
+ pubsubClient.pull(System.currentTimeMillis(), subscriptionPath, batchSize,
+ false);
+ List<String> ackIds = new ArrayList<>();
+ for (IncomingMessage message : batchMessages) {
+ messages.add(message);
+ ackIds.add(message.ackId);
+ }
+ if (ackIds.size() != 0) {
+ pubsubClient.acknowledge(subscriptionPath, ackIds);
}
}
-
- if (ackIds.size() != 0) {
- AcknowledgeRequest ackRequest = new AcknowledgeRequest().setAckIds(ackIds);
- pubsubClient.projects()
- .subscriptions()
- .acknowledge(subscription, ackRequest)
- .execute();
+ } catch (IOException e) {
+ throw new RuntimeException("Unexpected exception while reading from Pubsub: ", e);
+ } finally {
+ if (getSubscription() == null) {
+ try {
+ pubsubClient.deleteSubscription(subscriptionPath);
+ } catch (Exception e) {
+ finallyBlockException = e;
+ }
}
}
- } catch (IOException e) {
- throw new RuntimeException("Unexpected exception while reading from Pubsub: ", e);
- } finally {
- if (getTopic() != null) {
- try {
- pubsubClient.projects().subscriptions().delete(subscription).execute();
- } catch (IOException e) {
- finallyBlockException = new RuntimeException("Failed to delete subscription: ", e);
- LOG.error("Failed to delete subscription: ", e);
- }
+ if (finallyBlockException != null) {
+ throw new RuntimeException("Failed to delete subscription: ", finallyBlockException);
}
- }
- if (finallyBlockException != null) {
- throw new RuntimeException(finallyBlockException);
- }
- for (PubsubMessage message : messages) {
- c.outputWithTimestamp(
- CoderUtils.decodeFromByteArray(getCoder(), message.decodeData()),
- assignMessageTimestamp(message, getTimestampLabel(), Clock.SYSTEM));
+ for (IncomingMessage message : messages) {
+ c.outputWithTimestamp(
+ CoderUtils.decodeFromByteArray(getCoder(), message.elementBytes),
+ new Instant(message.timestampMsSinceEpoch));
+ }
}
}
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
+ Bound.this.populateDisplayData(builder);
+ }
}
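Condensing the reader's control flow above (names from this file; nothing new assumed):

    // 1. If reading from a topic, createRandomSubscription(projectPath, topicPath,
    //    ACK_TIMEOUT_SEC); otherwise use the configured subscription.
    // 2. Until maxNumRecords is reached or maxReadTime elapses:
    //      pull(...) up to min(DEFAULT_PULL_SIZE, remaining) messages, collect them,
    //      then acknowledge(...) the batch's ack ids.
    // 3. In the finally block, deleteSubscription(...) if we created it in step 1.
    // 4. Emit each message via CoderUtils.decodeFromByteArray with its recorded timestamp.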
}
@@ -1003,8 +979,20 @@ public PDone apply(PCollection<T> input) {
if (topic == null) {
throw new IllegalStateException("need to set the topic of a PubsubIO.Write transform");
}
- input.apply(ParDo.of(new PubsubWriter()));
- return PDone.in(input.getPipeline());
+ switch (input.isBounded()) {
+ case BOUNDED:
+ input.apply(ParDo.of(new PubsubBoundedWriter()));
+ return PDone.in(input.getPipeline());
+ case UNBOUNDED:
+ return input.apply(new PubsubUnboundedSink<T>(
+ FACTORY,
+ PubsubClient.topicPathFromName(topic.project, topic.topic),
+ coder,
+ timestampLabel,
+ idLabel,
+ 100 /* numShards */));
+ }
+ throw new RuntimeException(); // cases are exhaustive.
}
@Override
@@ -1034,31 +1022,34 @@ public Coder<T> getCoder() {
return coder;
}
- private class PubsubWriter extends DoFn<T, Void> {
+ /**
+ * Writer to Pubsub which batches messages from bounded collections.
+ *
+ * <p>NOTE: This is not the implementation used when running on the Google Cloud Dataflow
+ * service in streaming mode.
+ *
+ * <p>Public so can be suppressed by runners.
+ */
+ public class PubsubBoundedWriter extends DoFn<T, Void> {
private static final int MAX_PUBLISH_BATCH_SIZE = 100;
- private transient List<PubsubMessage> output;
- private transient Pubsub pubsubClient;
+ private transient List<OutgoingMessage> output;
+ private transient PubsubClient pubsubClient;
@Override
- public void startBundle(Context c) {
+ public void startBundle(Context c) throws IOException {
this.output = new ArrayList<>();
+ // NOTE: idLabel is ignored.
this.pubsubClient =
- Transport.newPubsubClient(c.getPipelineOptions().as(DataflowPipelineOptions.class))
- .build();
+ FACTORY.newClient(timestampLabel, null,
+ c.getPipelineOptions().as(DataflowPipelineOptions.class));
}
@Override
public void processElement(ProcessContext c) throws IOException {
- PubsubMessage message =
- new PubsubMessage().encodeData(CoderUtils.encodeToByteArray(getCoder(), c.element()));
- if (getTimestampLabel() != null) {
- Map<String, String> attributes = message.getAttributes();
- if (attributes == null) {
- attributes = new HashMap<>();
- message.setAttributes(attributes);
- }
- attributes.put(getTimestampLabel(), String.valueOf(c.timestamp().getMillis()));
- }
+ // NOTE: The record id is always null.
+ OutgoingMessage message =
+ new OutgoingMessage(CoderUtils.encodeToByteArray(getCoder(), c.element()),
+ c.timestamp().getMillis(), null);
output.add(message);
if (output.size() >= MAX_PUBLISH_BATCH_SIZE) {
@@ -1071,18 +1062,22 @@ public void finishBundle(Context c) throws IOException {
if (!output.isEmpty()) {
publish();
}
+ output = null;
+ pubsubClient.close();
+ pubsubClient = null;
}
private void publish() throws IOException {
- PublishRequest publishRequest = new PublishRequest().setMessages(output);
- pubsubClient.projects().topics()
- .publish(getTopic().asPath(), publishRequest)
- .execute();
+ int n = pubsubClient.publish(
+ PubsubClient.topicPathFromName(getTopic().project, getTopic().topic),
+ output);
+ checkState(n == output.size());
output.clear();
}
@Override
public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
Bound.this.populateDisplayData(builder);
}
}
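A matching write-side sketch (topic name hypothetical); the dispatch in apply above means bounded inputs use this batching PubsubBoundedWriter while unbounded inputs go through PubsubUnboundedSink:

    PCollection<String> lines = ...;  // bounded or unbounded
    lines.apply(PubsubIO.Write.topic("projects/my-project/topics/my-topic"));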
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSink.java
new file mode 100644
index 0000000000..98fc29ba6e
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSink.java
@@ -0,0 +1,445 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.io;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder;
+import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.coders.CustomCoder;
+import com.google.cloud.dataflow.sdk.coders.KvCoder;
+import com.google.cloud.dataflow.sdk.coders.NullableCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.Sum;
+import com.google.cloud.dataflow.sdk.transforms.display.DisplayData;
+import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder;
+import com.google.cloud.dataflow.sdk.transforms.windowing.AfterFirst;
+import com.google.cloud.dataflow.sdk.transforms.windowing.AfterPane;
+import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime;
+import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Repeatedly;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
+import com.google.cloud.dataflow.sdk.util.CoderUtils;
+import com.google.cloud.dataflow.sdk.util.PubsubClient;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.OutgoingMessage;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.PubsubClientFactory;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.hash.Hashing;
+
+import org.joda.time.Duration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+
+import javax.annotation.Nullable;
+
+/**
+ * A PTransform which streams messages to Pubsub.
+ *
+ * <ul>
+ * <li>The underlying implementation is just a {@link GroupByKey} followed by a {@link ParDo} which
+ * publishes as a side effect. (In the future we want to design and switch to a custom
+ * {@code UnboundedSink} implementation so as to gain access to system watermark and
+ * end-of-pipeline cleanup.)
+ * <li>We try to send messages in batches while also limiting send latency.
+ * <li>No stats are logged. Rather some counters are used to keep track of elements and batches.
+ * <li>Though some background threads are used by the underlying netty system all actual Pubsub
+ * calls are blocking. We rely on the underlying runner to allow multiple {@link DoFn} instances
+ * to execute concurrently and hide latency.
+ * <li>A failed bundle will cause messages to be resent. Thus we rely on the Pubsub consumer
+ * to dedup messages.
+ * </ul>
+ *
+ * <p>NOTE: This is not the implementation used when running on the Google Cloud Dataflow service.
+ */
+public class PubsubUnboundedSink<T> extends PTransform<PCollection<T>, PDone> {
+ private static final Logger LOG = LoggerFactory.getLogger(PubsubUnboundedSink.class);
+
+ /**
+ * Default maximum number of messages per publish.
+ */
+ private static final int DEFAULT_PUBLISH_BATCH_SIZE = 1000;
+
+ /**
+ * Default maximum size of a publish batch, in bytes.
+ */
+ private static final int DEFAULT_PUBLISH_BATCH_BYTES = 400000;
+
+ /**
+ * Default longest delay between receiving a message and pushing it to Pubsub.
+ */
+ private static final Duration DEFAULT_MAX_LATENCY = Duration.standardSeconds(2);
+
+ /**
+ * Coder for conveying outgoing messages between internal stages.
+ */
+ private static class OutgoingMessageCoder extends CustomCoder<OutgoingMessage> {
+ private static final NullableCoder<String> RECORD_ID_CODER =
+ NullableCoder.of(StringUtf8Coder.of());
+
+ @Override
+ public void encode(
+ OutgoingMessage value, OutputStream outStream, Context context)
+ throws CoderException, IOException {
+ ByteArrayCoder.of().encode(value.elementBytes, outStream, Context.NESTED);
+ BigEndianLongCoder.of().encode(value.timestampMsSinceEpoch, outStream, Context.NESTED);
+ RECORD_ID_CODER.encode(value.recordId, outStream, Context.NESTED);
+ }
+
+ @Override
+ public OutgoingMessage decode(
+ InputStream inStream, Context context) throws CoderException, IOException {
+ byte[] elementBytes = ByteArrayCoder.of().decode(inStream, Context.NESTED);
+ long timestampMsSinceEpoch = BigEndianLongCoder.of().decode(inStream, Context.NESTED);
+ @Nullable String recordId = RECORD_ID_CODER.decode(inStream, Context.NESTED);
+ return new OutgoingMessage(elementBytes, timestampMsSinceEpoch, recordId);
+ }
+ }
+
+ @VisibleForTesting
+ static final Coder<OutgoingMessage> CODER = new OutgoingMessageCoder();
+
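A round-trip sketch for the coder above (payload and timestamp arbitrary; Guava's Charsets assumed; exception handling elided):

    OutgoingMessage original = new OutgoingMessage(
        "hello".getBytes(Charsets.UTF_8), 1456790400000L, null /* recordId */);
    byte[] bytes = CoderUtils.encodeToByteArray(CODER, original);
    OutgoingMessage roundTripped = CoderUtils.decodeFromByteArray(CODER, bytes);
    // roundTripped has the same elementBytes, timestampMsSinceEpoch and null recordId.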
+ // ================================================================================
+ // RecordIdMethod
+ // ================================================================================
+
+ /**
+ * Specify how record ids are to be generated.
+ */
+ @VisibleForTesting
+ enum RecordIdMethod {
+ /** Leave null. */
+ NONE,
+ /** Generate randomly. */
+ RANDOM,
+ /** Generate deterministically. For testing only. */
+ DETERMINISTIC
+ }
+
+ // ================================================================================
+ // ShardFn
+ // ================================================================================
+
+ /**
+ * Convert elements to messages and shard them.
+ */
+ private static class ShardFn<T> extends DoFn<T, KV<Integer, OutgoingMessage>> {
+ private final Aggregator<Long, Long> elementCounter =
+ createAggregator("elements", new Sum.SumLongFn());
+ private final Coder<T> elementCoder;
+ private final int numShards;
+ private final RecordIdMethod recordIdMethod;
+
+ ShardFn(Coder<T> elementCoder, int numShards, RecordIdMethod recordIdMethod) {
+ this.elementCoder = elementCoder;
+ this.numShards = numShards;
+ this.recordIdMethod = recordIdMethod;
+ }
+
+ @Override
+ public void processElement(ProcessContext c) throws Exception {
+ elementCounter.addValue(1L);
+ byte[] elementBytes = CoderUtils.encodeToByteArray(elementCoder, c.element());
+ long timestampMsSinceEpoch = c.timestamp().getMillis();
+ @Nullable String recordId = null;
+ switch (recordIdMethod) {
+ case NONE:
+ break;
+ case DETERMINISTIC:
+ recordId = Hashing.murmur3_128().hashBytes(elementBytes).toString();
+ break;
+ case RANDOM:
+ // Since these elements go through a GroupByKey, any failures while sending to
+ // Pubsub will be retried without falling back and generating a new record id.
+ // Thus even though we may send the same message to Pubsub twice, it is guaranteed
+ // to have the same record id.
+ recordId = UUID.randomUUID().toString();
+ break;
+ }
+ c.output(KV.of(ThreadLocalRandom.current().nextInt(numShards),
+ new OutgoingMessage(elementBytes, timestampMsSinceEpoch, recordId)));
+ }
+
+ @Override
+ public void populateDisplayData(Builder builder) {
+ super.populateDisplayData(builder);
+ builder.add(DisplayData.item("numShards", numShards));
+ }
+ }
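For instance, DETERMINISTIC derives a stable id purely from the element bytes, so retries produce the same id (a minimal sketch using the same Guava call as above):

    byte[] payload = "payload".getBytes(Charsets.UTF_8);
    String id1 = Hashing.murmur3_128().hashBytes(payload).toString();
    String id2 = Hashing.murmur3_128().hashBytes(payload).toString();
    // id1.equals(id2) holds; RANDOM instead mints one UUID per element, which the
    // GroupByKey then pins across any publish retries.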
+
+ // ================================================================================
+ // WriterFn
+ // ================================================================================
+
+ /**
+ * Publish messages to Pubsub in batches.
+ */
+ private static class WriterFn
+ extends DoFn<KV<Integer, Iterable<OutgoingMessage>>, Void> {
+ private final PubsubClientFactory pubsubFactory;
+ private final TopicPath topic;
+ private final String timestampLabel;
+ private final String idLabel;
+ private final int publishBatchSize;
+ private final int publishBatchBytes;
+
+ /**
+ * Client on which to talk to Pubsub. Null until created by {@link #startBundle}.
+ */
+ @Nullable
+ private transient PubsubClient pubsubClient;
+
+ private final Aggregator<Long, Long> batchCounter =
+ createAggregator("batches", new Sum.SumLongFn());
+ private final Aggregator<Long, Long> elementCounter =
+ createAggregator("elements", new Sum.SumLongFn());
+ private final Aggregator<Long, Long> byteCounter =
+ createAggregator("bytes", new Sum.SumLongFn());
+
+ WriterFn(
+ PubsubClientFactory pubsubFactory, TopicPath topic, String timestampLabel,
+ String idLabel, int publishBatchSize, int publishBatchBytes) {
+ this.pubsubFactory = pubsubFactory;
+ this.topic = topic;
+ this.timestampLabel = timestampLabel;
+ this.idLabel = idLabel;
+ this.publishBatchSize = publishBatchSize;
+ this.publishBatchBytes = publishBatchBytes;
+ }
+
+ /**
+ * BLOCKING
+ * Send {@code messages} as a batch to Pubsub.
+ */
+ private void publishBatch(List<OutgoingMessage> messages, int bytes)
+ throws IOException {
+ int n = pubsubClient.publish(topic, messages);
+ checkState(n == messages.size(), "Attempted to publish %s messages but %s were successful",
+ messages.size(), n);
+ batchCounter.addValue(1L);
+ elementCounter.addValue((long) messages.size());
+ byteCounter.addValue((long) bytes);
+ }
+
+ @Override
+ public void startBundle(Context c) throws Exception {
+ checkState(pubsubClient == null, "startBundle invoked without prior finishBundle");
+ pubsubClient =
+ pubsubFactory.newClient(timestampLabel, idLabel,
+ c.getPipelineOptions().as(DataflowPipelineOptions.class));
+ }
+
+ @Override
+ public void processElement(ProcessContext c) throws Exception {
+ List<OutgoingMessage> pubsubMessages = new ArrayList<>(publishBatchSize);
+ int bytes = 0;
+ for (OutgoingMessage message : c.element().getValue()) {
+ if (!pubsubMessages.isEmpty()
+ && bytes + message.elementBytes.length > publishBatchBytes) {
+ // Break batches that are large in bytes into smaller ones.
+ // (We've already broken by batch size using the trigger below, though that may
+ // run slightly over the actual PUBLISH_BATCH_SIZE. We'll consider that ok since
+ // the hard limit from Pubsub is by bytes rather than number of messages.)
+ // BLOCKS until published.
+ publishBatch(pubsubMessages, bytes);
+ pubsubMessages.clear();
+ bytes = 0;
+ }
+ pubsubMessages.add(message);
+ bytes += message.elementBytes.length;
+ }
+ if (!pubsubMessages.isEmpty()) {
+ // BLOCKS until published.
+ publishBatch(pubsubMessages, bytes);
+ }
+ }
+
+ @Override
+ public void finishBundle(Context c) throws Exception {
+ pubsubClient.close();
+ pubsubClient = null;
+ }
+
+ @Override
+ public void populateDisplayData(Builder builder) {
+ super.populateDisplayData(builder);
+ builder.add(DisplayData.item("topic", topic.getPath()));
+ builder.add(DisplayData.item("transport", pubsubFactory.getKind()));
+ builder.addIfNotNull(DisplayData.item("timestampLabel", timestampLabel));
+ builder.addIfNotNull(DisplayData.item("idLabel", idLabel));
+ }
+ }
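A worked example of the byte-capped batching in processElement above, with hypothetical sizes:

    // publishBatchBytes = 400000; incoming messages of 150000 bytes each:
    //   msg1 -> batch [msg1],       bytes = 150000
    //   msg2 -> batch [msg1, msg2], bytes = 300000
    //   msg3 -> 300000 + 150000 > 400000, so [msg1, msg2] is published first
    //           (BLOCKING), then a fresh batch [msg3] is started.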
+
+ // ================================================================================
+ // PubsubUnboundedSink
+ // ================================================================================
+
+ /**
+ * Which factory to use for creating Pubsub transport.
+ */
+ private final PubsubClientFactory pubsubFactory;
+
+ /**
+ * Pubsub topic to publish to.
+ */
+ private final TopicPath topic;
+
+ /**
+ * Coder for elements. It is the responsibility of the underlying Pubsub transport to
+ * re-encode element bytes if necessary, eg as Base64 strings.
+ */
+ private final Coder<T> elementCoder;
+
+ /**
+ * Pubsub metadata field holding timestamp of each element, or {@literal null} if we should use
+ * Pubsub message publish timestamp instead.
+ */
+ @Nullable
+ private final String timestampLabel;
+
+ /**
+ * Pubsub metadata field holding id for each element, or {@literal null} if we need to generate
+ * a unique id ourselves.
+ */
+ @Nullable
+ private final String idLabel;
+
+ /**
+ * Number of 'shards' to use so that latency in Pubsub publish can be hidden. Generally this
+ * should be a small multiple of the number of available cores. Too small a number results
+ * in too much time lost to blocking Pubsub calls. Too large a number results in too many
+ * single-element batches being sent to Pubsub with high per-batch overhead.
+ */
+ private final int numShards;
+
+ /**
+ * Maximum number of messages per publish.
+ */
+ private final int publishBatchSize;
+
+ /**
+ * Maximum size of a publish batch, in bytes.
+ */
+ private final int publishBatchBytes;
+
+ /**
+ * Longest delay between receiving a message and pushing it to Pubsub.
+ */
+ private final Duration maxLatency;
+
+ /**
+ * How record ids should be generated for each record (if {@link #idLabel} is non-{@literal
+ * null}).
+ */
+ private final RecordIdMethod recordIdMethod;
+
+ @VisibleForTesting
+ PubsubUnboundedSink(
+ PubsubClientFactory pubsubFactory,
+ TopicPath topic,
+ Coder<T> elementCoder,
+ String timestampLabel,
+ String idLabel,
+ int numShards,
+ int publishBatchSize,
+ int publishBatchBytes,
+ Duration maxLatency,
+ RecordIdMethod recordIdMethod) {
+ this.pubsubFactory = pubsubFactory;
+ this.topic = topic;
+ this.elementCoder = elementCoder;
+ this.timestampLabel = timestampLabel;
+ this.idLabel = idLabel;
+ this.numShards = numShards;
+ this.publishBatchSize = publishBatchSize;
+ this.publishBatchBytes = publishBatchBytes;
+ this.maxLatency = maxLatency;
+ this.recordIdMethod = idLabel == null ? RecordIdMethod.NONE : recordIdMethod;
+ }
+
+ public PubsubUnboundedSink(
+ PubsubClientFactory pubsubFactory,
+ TopicPath topic,
+ Coder<T> elementCoder,
+ String timestampLabel,
+ String idLabel,
+ int numShards) {
+ this(pubsubFactory, topic, elementCoder, timestampLabel, idLabel, numShards,
+ DEFAULT_PUBLISH_BATCH_SIZE, DEFAULT_PUBLISH_BATCH_BYTES, DEFAULT_MAX_LATENCY,
+ RecordIdMethod.RANDOM);
+ }
+
+ public TopicPath getTopic() {
+ return topic;
+ }
+
+ @Nullable
+ public String getTimestampLabel() {
+ return timestampLabel;
+ }
+
+ @Nullable
+ public String getIdLabel() {
+ return idLabel;
+ }
+
+ public Coder<T> getElementCoder() {
+ return elementCoder;
+ }
+
+ @Override
+ public PDone apply(PCollection<T> input) {
+ input.apply(
+ Window.named("PubsubUnboundedSink.Window")
+ .into(new GlobalWindows())
+ .triggering(
+ Repeatedly.forever(
+ AfterFirst.of(AfterPane.elementCountAtLeast(publishBatchSize),
+ AfterProcessingTime.pastFirstElementInPane()
+ .plusDelayOf(maxLatency))))
+ .discardingFiredPanes())
+ .apply(ParDo.named("PubsubUnboundedSink.Shard")
+ .of(new ShardFn<T>(elementCoder, numShards, recordIdMethod)))
+ .setCoder(KvCoder.of(VarIntCoder.of(), CODER))
+ .apply(GroupByKey.<Integer, OutgoingMessage>create())
+ .apply(ParDo.named("PubsubUnboundedSink.Writer")
+ .of(new WriterFn(pubsubFactory, topic, timestampLabel, idLabel,
+ publishBatchSize, publishBatchBytes)));
+ return PDone.in(input.getPipeline());
+ }
+}
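A usage sketch of the public constructor (the factory and names are hypothetical):

    PCollection<String> messages = ...;
    messages.apply(new PubsubUnboundedSink<String>(
        pubsubFactory,  // some PubsubClientFactory
        PubsubClient.topicPathFromName("my-project", "my-topic"),
        StringUtf8Coder.of(),
        null /* timestampLabel: use Pubsub publish time */,
        null /* idLabel: rely on consumer dedup */,
        10 /* numShards */));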
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java
new file mode 100644
index 0000000000..6c60877811
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java
@@ -0,0 +1,1300 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.io;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.api.client.util.Clock;
+
+import com.google.cloud.dataflow.sdk.coders.AtomicCoder;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.coders.ListCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.Sum;
+import com.google.cloud.dataflow.sdk.transforms.Sum.SumLongFn;
+import com.google.cloud.dataflow.sdk.transforms.display.DisplayData;
+import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder;
+import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
+import com.google.cloud.dataflow.sdk.util.BucketingFunction;
+import com.google.cloud.dataflow.sdk.util.CoderUtils;
+import com.google.cloud.dataflow.sdk.util.MovingFunction;
+import com.google.cloud.dataflow.sdk.util.PubsubClient;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.ProjectPath;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.PubsubClientFactory;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.SubscriptionPath;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath;
+import com.google.cloud.dataflow.sdk.values.PBegin;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Charsets;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.security.GeneralSecurityException;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.annotation.Nullable;
+
+/**
+ * A PTransform which streams messages from Pubsub.
+ *
+ * <ul>
+ * <li>The underlying implementation is an {@link UnboundedSource} which receives messages
+ * in batches and hands them out one at a time.
+ * <li>The watermark (either in Pubsub processing time or custom timestamp time) is estimated
+ * by keeping track of the minimum of the last minute's worth of messages. This assumes Pubsub
+ * delivers the oldest (in Pubsub processing time) available message at least once a minute,
+ * and that custom timestamps are 'mostly' monotonic with Pubsub processing time. Unfortunately
+ * both of those assumptions are fragile. Thus the estimated watermark may get ahead of
+ * the 'true' watermark and cause some messages to be late.
+ * <li>Checkpoints are used both to ACK received messages back to Pubsub (so that they may
+ * be retired on the Pubsub end), and to NACK already consumed messages should a checkpoint
+ * need to be restored (so that Pubsub will resend those messages promptly).
+ * <li>The backlog is determined by each reader using the messages which have been pulled from
+ * Pubsub but not yet consumed downstream. The backlog does not take account of any messages queued
+ * by Pubsub for the subscription. Unfortunately there is currently no API to determine
+ * the size of the Pubsub queue's backlog.
+ * <li>The subscription must already exist.
+ * <li>The subscription timeout is read whenever a reader is started. However it is not
+ * checked thereafter despite the timeout being user-changeable on-the-fly.
+ * <li>We log vital stats every 30 seconds.
+ * <li>Though some background threads may be used by the underlying transport all Pubsub calls
+ * are blocking. We rely on the underlying runner to allow multiple
+ * {@link UnboundedSource.UnboundedReader} instances to execute concurrently and thus hide latency.
+ * </ul>
+ *
+ * <p>NOTE: This is not the implementation used when running on the Google Cloud Dataflow service.
+ */
+public class PubsubUnboundedSource<T> extends PTransform<PBegin, PCollection<T>> {
+ private static final Logger LOG = LoggerFactory.getLogger(PubsubUnboundedSource.class);
+
+ /**
+ * Default ACK timeout for created subscriptions.
+ */
+ private static final int DEFAULT_ACK_TIMEOUT_SEC = 60;
+
+ /**
+ * Coder for checkpoints.
+ */
+ private static final PubsubCheckpointCoder<?> CHECKPOINT_CODER = new PubsubCheckpointCoder<>();
+
+ /**
+ * Maximum number of messages per pull.
+ */
+ private static final int PULL_BATCH_SIZE = 1000;
+
+ /**
+ * Maximum number of ACK ids per ACK or ACK extension call.
+ */
+ private static final int ACK_BATCH_SIZE = 2000;
+
+ /**
+ * Maximum number of messages in flight.
+ */
+ private static final int MAX_IN_FLIGHT = 20000;
+
+ /**
+ * Timeout for round trip from receiving a message to finally ACKing it back to Pubsub.
+ */
+ private static final Duration PROCESSING_TIMEOUT = Duration.standardSeconds(120);
+
+ /**
+ * Percentage of ack timeout by which to extend acks when they are near timeout.
+ */
+ private static final int ACK_EXTENSION_PCT = 50;
+
+ /**
+ * Percentage of ack timeout we should use as a safety margin. We'll try to extend acks
+ * by this margin before the ack actually expires.
+ */
+ private static final int ACK_SAFETY_PCT = 20;
+
+ /**
+ * For stats only: How close we can get to an ack deadline before we risk it being already
+ * considered passed by Pubsub.
+ */
+ private static final Duration ACK_TOO_LATE = Duration.standardSeconds(2);
+
+ /**
+ * Period of samples to determine watermark and other stats.
+ */
+ private static final Duration SAMPLE_PERIOD = Duration.standardMinutes(1);
+
+ /**
+ * Period of updates to determine watermark and other stats.
+ */
+ private static final Duration SAMPLE_UPDATE = Duration.standardSeconds(5);
+
+ /**
+ * Period for logging stats.
+ */
+ private static final Duration LOG_PERIOD = Duration.standardSeconds(30);
+
+ /**
+ * Minimum number of unread messages required before considering updating watermark.
+ */
+ private static final int MIN_WATERMARK_MESSAGES = 10;
+
+ /**
+ * Minimum number of SAMPLE_UPDATE periods over which unread messages should be spread
+ * before considering updating watermark.
+ */
+ private static final int MIN_WATERMARK_SPREAD = 2;
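Read together, a sketch of how these constants gate watermark updates (behaviour as configured here; the reporting logic itself lives in the reader below):

    // A candidate watermark is only trusted once at least MIN_WATERMARK_MESSAGES
    // samples have been observed, spread over at least MIN_WATERMARK_SPREAD of the
    // SAMPLE_UPDATE (5s) buckets within the SAMPLE_PERIOD (1min) window; until then
    // the previously reported watermark is retained.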
+
+ /**
+ * Additional sharding so that we can hide read message latency.
+ */
+ private static final int SCALE_OUT = 4;
+
+ // TODO: Would prefer to use MinLongFn but it is a BinaryCombineFn rather
+ // than a BinaryCombineLongFn. [BEAM-285]
+ private static final Combine.BinaryCombineLongFn MIN =
+ new Combine.BinaryCombineLongFn() {
+ @Override
+ public long apply(long left, long right) {
+ return Math.min(left, right);
+ }
+
+ @Override
+ public long identity() {
+ return Long.MAX_VALUE;
+ }
+ };
+
+ private static final Combine.BinaryCombineLongFn MAX =
+ new Combine.BinaryCombineLongFn() {
+ @Override
+ public long apply(long left, long right) {
+ return Math.max(left, right);
+ }
+
+ @Override
+ public long identity() {
+ return Long.MIN_VALUE;
+ }
+ };
+
+ private static final Combine.BinaryCombineLongFn SUM = new SumLongFn();
+
+ // ================================================================================
+ // Checkpoint
+ // ================================================================================
+
+ /**
+ * Which messages have been durably committed and thus can now be ACKed.
+ * Which messages have been read but not yet committed, in which case they should be NACKed if
+ * we need to restore.
+ */
+ @VisibleForTesting
+ static class PubsubCheckpoint<T> implements UnboundedSource.CheckpointMark {
+ /**
+ * If the checkpoint is for persisting: the reader whose snapshotted state we are persisting.
+ * If the checkpoint is for restoring: initially {@literal null}, then explicitly set.
+ * Not persisted in durable checkpoint.
+ * CAUTION: Between a checkpoint being taken and {@link #finalizeCheckpoint()} being called
+ * the 'true' active reader may have changed.
+ */
+ @Nullable
+ private PubsubReader<T> reader;
+
+ /**
+ * If the checkpoint is for persisting: The ACK ids of messages which have been passed
+ * downstream since the last checkpoint.
+ * If the checkpoint is for restoring: {@literal null}.
+ * Not persisted in durable checkpoint.
+ */
+ @Nullable
+ private final List<String> safeToAckIds;
+
+ /**
+ * If the checkpoint is for persisting: The ACK ids of messages which have been received
+ * from Pubsub but not yet passed downstream at the time of the snapshot.
+ * If the checkpoint is for restoring: Same, but recovered from durable storage.
+ */
+ @VisibleForTesting
+ final List<String> notYetReadIds;
+
+ public PubsubCheckpoint(
+ @Nullable PubsubReader<T> reader, @Nullable List<String> safeToAckIds,
+ List<String> notYetReadIds) {
+ this.reader = reader;
+ this.safeToAckIds = safeToAckIds;
+ this.notYetReadIds = notYetReadIds;
+ }
+
+ /**
+ * BLOCKING
+ * All messages which have been passed downstream have now been durably committed.
+ * We can ACK them upstream.
+ * CAUTION: This may never be called.
+ */
+ @Override
+ public void finalizeCheckpoint() throws IOException {
+ checkState(reader != null && safeToAckIds != null, "Cannot finalize a restored checkpoint");
+ // Even if the 'true' active reader has changed since the checkpoint was taken we are
+ // fine:
+ // - The underlying Pubsub topic will not have changed, so the following ACKs will still
+ // go to the right place.
+ // - We'll delete the ACK ids from the reader's in-flight state, but that only affects
+ // flow control and stats, neither of which are relevant anymore.
+ try {
+ int n = safeToAckIds.size();
+ List<String> batchSafeToAckIds = new ArrayList<>(Math.min(n, ACK_BATCH_SIZE));
+ for (String ackId : safeToAckIds) {
+ batchSafeToAckIds.add(ackId);
+ if (batchSafeToAckIds.size() >= ACK_BATCH_SIZE) {
+ reader.ackBatch(batchSafeToAckIds);
+ n -= batchSafeToAckIds.size();
+ // CAUTION: Don't reuse the same list since ackBatch holds on to it.
+ batchSafeToAckIds = new ArrayList<>(Math.min(n, ACK_BATCH_SIZE));
+ }
+ }
+ if (!batchSafeToAckIds.isEmpty()) {
+ reader.ackBatch(batchSafeToAckIds);
+ }
+ } finally {
+ checkState(reader.numInFlightCheckpoints.decrementAndGet() >= 0,
+ "Miscounted in-flight checkpoints");
+ }
+ }
+
+ /**
+ * Return current time according to {@code reader}.
+ */
+ private static long now(PubsubReader<?> reader) {
+ if (reader.outer.outer.clock == null) {
+ return System.currentTimeMillis();
+ } else {
+ return reader.outer.outer.clock.currentTimeMillis();
+ }
+ }
+
+ /**
+ * BLOCKING
+ * NACK all messages which have been read from Pubsub but not passed downstream.
+ * This way Pubsub will send them again promptly.
+ */
+ public void nackAll(PubsubReader<T> reader) throws IOException {
+ checkState(this.reader == null, "Cannot nackAll on persisting checkpoint");
+ List<String> batchYetToAckIds =
+ new ArrayList<>(Math.min(notYetReadIds.size(), ACK_BATCH_SIZE));
+ for (String ackId : notYetReadIds) {
+ batchYetToAckIds.add(ackId);
+ if (batchYetToAckIds.size() >= ACK_BATCH_SIZE) {
+ long nowMsSinceEpoch = now(reader);
+ reader.nackBatch(nowMsSinceEpoch, batchYetToAckIds);
+ batchYetToAckIds.clear();
+ }
+ }
+ if (!batchYetToAckIds.isEmpty()) {
+ long nowMsSinceEpoch = now(reader);
+ reader.nackBatch(nowMsSinceEpoch, batchYetToAckIds);
+ }
+ }
+ }
+
+ /**
+ * The coder for our checkpoints.
+ */
+ private static class PubsubCheckpointCoder<T> extends AtomicCoder<PubsubCheckpoint<T>> {
+ private static final Coder<List<String>> LIST_CODER = ListCoder.of(StringUtf8Coder.of());
+
+ @Override
+ public void encode(PubsubCheckpoint<T> value, OutputStream outStream, Context context)
+ throws IOException {
+ LIST_CODER.encode(value.notYetReadIds, outStream, context);
+ }
+
+ @Override
+ public PubsubCheckpoint<T> decode(InputStream inStream, Context context) throws IOException {
+ List<String> notYetReadIds = LIST_CODER.decode(inStream, context);
+ return new PubsubCheckpoint<>(null, null, notYetReadIds);
+ }
+ }
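A sketch of what survives a checkpoint round trip (Guava's ImmutableList assumed; exceptions elided):

    PubsubCheckpointCoder<String> coder = new PubsubCheckpointCoder<>();
    PubsubCheckpoint<String> taken =
        new PubsubCheckpoint<>(null, null, ImmutableList.of("ack-1", "ack-2"));
    PubsubCheckpoint<String> restored =
        CoderUtils.decodeFromByteArray(coder, CoderUtils.encodeToByteArray(coder, taken));
    // Only notYetReadIds survives; restored.reader and restored.safeToAckIds are null,
    // which is why a restored checkpoint can only nackAll(), never finalizeCheckpoint().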
+
+ // ================================================================================
+ // Reader
+ // ================================================================================
+
+ /**
+ * A reader which keeps track of which messages have been received from Pubsub
+ * but not yet consumed downstream and/or ACKed back to Pubsub.
+ */
+ @VisibleForTesting
+ static class PubsubReader<T> extends UnboundedSource.UnboundedReader<T> {
+ /**
+ * For access to topic and checkpointCoder.
+ */
+ private final PubsubSource<T> outer;
+
+ /**
+ * Client on which to talk to Pubsub. Null if closed.
+ */
+ @Nullable
+ private PubsubClient pubsubClient;
+
+ /**
+ * Ack timeout, in ms, as set on subscription when we first start reading. Not
+ * updated thereafter. -1 if not yet determined.
+ */
+ private int ackTimeoutMs;
+
+ /**
+ * ACK ids of messages we have delivered downstream but not yet ACKed.
+ */
+ private Set<String> safeToAckIds;
+
+ /**
+ * Messages we have received from Pubsub and not yet delivered downstream.
+ * We preserve their order.
+ */
+ private final Queue<PubsubClient.IncomingMessage> notYetRead;
+
+ private static class InFlightState {
+ /**
+ * When the request which yielded this message was issued.
+ */
+ long requestTimeMsSinceEpoch;
+
+ /**
+ * When Pubsub will consider this message's ACK to timeout and thus it needs to be
+ * extended.
+ */
+ long ackDeadlineMsSinceEpoch;
+
+ public InFlightState(long requestTimeMsSinceEpoch, long ackDeadlineMsSinceEpoch) {
+ this.requestTimeMsSinceEpoch = requestTimeMsSinceEpoch;
+ this.ackDeadlineMsSinceEpoch = ackDeadlineMsSinceEpoch;
+ }
+ }
+
+ /**
+ * Map from ACK ids of messages we have received from Pubsub but not yet ACKed to their
+ * in flight state. Ordered from earliest to latest ACK deadline.
+ */
+ private final LinkedHashMap<String, InFlightState> inFlight;
+
+ /**
+ * Batches of successfully ACKed ids which need to be pruned from the above.
+ * CAUTION: Accessed by both reader and checkpointing threads.
+ */
+ private final Queue<List<String>> ackedIds;
+
+ /**
+ * Byte size of undecoded elements in {@link #notYetRead}.
+ */
+ private long notYetReadBytes;
+
+ /**
+ * Bucketed map from received time (as system time, ms since epoch) to message
+ * timestamps (ms since epoch) of all received but not-yet read messages.
+ * Used to estimate watermark.
+ */
+ private BucketingFunction minUnreadTimestampMsSinceEpoch;
+
+ /**
+ * Minimum of timestamps (ms since epoch) of all recently read messages.
+ * Used to estimate watermark.
+ */
+ private MovingFunction minReadTimestampMsSinceEpoch;
+
+ /**
+ * System time (ms since epoch) we last received a message from Pubsub, or -1 if
+ * not yet received any messages.
+ */
+ private long lastReceivedMsSinceEpoch;
+
+ /**
+ * The last reported watermark (ms since epoch), or beginning of time if none yet reported.
+ */
+ private long lastWatermarkMsSinceEpoch;
+
+ /**
+ * The current message, or {@literal null} if none.
+ */
+ @Nullable
+ private PubsubClient.IncomingMessage current;
+
+ /**
+ * Stats only: System time (ms since epoch) we last logged stats, or -1 if never.
+ */
+ private long lastLogTimestampMsSinceEpoch;
+
+ /**
+ * Stats only: Total number of messages received.
+ */
+ private long numReceived;
+
+ /**
+ * Stats only: Number of messages which have recently been received.
+ */
+ private MovingFunction numReceivedRecently;
+
+ /**
+ * Stats only: Number of messages which have recently had their deadline extended.
+ */
+ private MovingFunction numExtendedDeadlines;
+
+ /**
+ * Stats only: Number of messages which have recently had their deadline extended even
+ * though it may be too late to do so.
+ */
+ private MovingFunction numLateDeadlines;
+
+
+ /**
+ * Stats only: Number of messages which have recently been ACKed.
+ */
+ private MovingFunction numAcked;
+
+ /**
+ * Stats only: Number of messages which have recently expired (ACKs were extended for too
+ * long).
+ */
+ private MovingFunction numExpired;
+
+ /**
+ * Stats only: Number of messages which have recently been NACKed.
+ */
+ private MovingFunction numNacked;
+
+ /**
+ * Stats only: Number of message bytes which have recently been read by downstream consumer.
+ */
+ private MovingFunction numReadBytes;
+
+ /**
+ * Stats only: Minimum of timestamp (ms since epoch) of all recently received messages.
+ * Used to estimate timestamp skew. Does not contribute to watermark estimator.
+ */
+ private MovingFunction minReceivedTimestampMsSinceEpoch;
+
+ /**
+ * Stats only: Maximum of timestamp (ms since epoch) of all recently received messages.
+ * Used to estimate timestamp skew.
+ */
+ private MovingFunction maxReceivedTimestampMsSinceEpoch;
+
+ /**
+ * Stats only: Minimum of recent estimated watermarks (ms since epoch).
+ */
+ private MovingFunction minWatermarkMsSinceEpoch;
+
+ /**
+ * Stats only: Maximum of recent estimated watermarks (ms since epoch).
+ */
+ private MovingFunction maxWatermarkMsSinceEpoch;
+
+ /**
+ * Stats only: Number of messages with timestamps strictly behind the estimated watermark
+ * at the time they are received. These may be considered 'late' by downstream computations.
+ */
+ private MovingFunction numLateMessages;
+
+ /**
+ * Stats only: Current number of checkpoints in flight.
+ * CAUTION: Accessed by both checkpointing and reader threads.
+ */
+ private AtomicInteger numInFlightCheckpoints;
+
+ /**
+ * Stats only: Maximum number of checkpoints in flight at any time.
+ */
+ private int maxInFlightCheckpoints;
+
+ private static MovingFunction newFun(Combine.BinaryCombineLongFn function) {
+ return new MovingFunction(SAMPLE_PERIOD.getMillis(),
+ SAMPLE_UPDATE.getMillis(),
+ MIN_WATERMARK_SPREAD,
+ MIN_WATERMARK_MESSAGES,
+ function);
+ }
+
+ /**
+ * Construct a reader.
+ */
+ public PubsubReader(DataflowPipelineOptions options, PubsubSource<T> outer)
+ throws IOException, GeneralSecurityException {
+ this.outer = outer;
+ pubsubClient =
+ outer.outer.pubsubFactory.newClient(outer.outer.timestampLabel, outer.outer.idLabel,
+ options);
+ ackTimeoutMs = -1;
+ safeToAckIds = new HashSet<>();
+ notYetRead = new ArrayDeque<>();
+ inFlight = new LinkedHashMap<>();
+ ackedIds = new ConcurrentLinkedQueue<>();
+ notYetReadBytes = 0;
+ minUnreadTimestampMsSinceEpoch = new BucketingFunction(SAMPLE_UPDATE.getMillis(),
+ MIN_WATERMARK_SPREAD,
+ MIN_WATERMARK_MESSAGES,
+ MIN);
+ minReadTimestampMsSinceEpoch = newFun(MIN);
+ lastReceivedMsSinceEpoch = -1;
+ lastWatermarkMsSinceEpoch = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis();
+ current = null;
+ lastLogTimestampMsSinceEpoch = -1;
+ numReceived = 0L;
+ numReceivedRecently = newFun(SUM);
+ numExtendedDeadlines = newFun(SUM);
+ numLateDeadlines = newFun(SUM);
+ numAcked = newFun(SUM);
+ numExpired = newFun(SUM);
+ numNacked = newFun(SUM);
+ numReadBytes = newFun(SUM);
+ minReceivedTimestampMsSinceEpoch = newFun(MIN);
+ maxReceivedTimestampMsSinceEpoch = newFun(MAX);
+ minWatermarkMsSinceEpoch = newFun(MIN);
+ maxWatermarkMsSinceEpoch = newFun(MAX);
+ numLateMessages = newFun(SUM);
+ numInFlightCheckpoints = new AtomicInteger();
+ maxInFlightCheckpoints = 0;
+ }
+
+ @VisibleForTesting
+ PubsubClient getPubsubClient() {
+ return pubsubClient;
+ }
+
+ /**
+ * BLOCKING
+ * ACK {@code ackIds} back to Pubsub.
+ * CAUTION: May be invoked from a separate checkpointing thread.
+ * CAUTION: Retains {@code ackIds}.
+ */
+ void ackBatch(List<String> ackIds) throws IOException {
+ pubsubClient.acknowledge(outer.outer.subscription, ackIds);
+ ackedIds.add(ackIds);
+ }
+
+ /**
+ * BLOCKING
+ * NACK (ie request deadline extension of 0) receipt of messages from Pubsub
+ * with the given {@code ackIds}. Does not retain {@code ackIds}.
+ */
+ public void nackBatch(long nowMsSinceEpoch, List<String> ackIds) throws IOException {
+ pubsubClient.modifyAckDeadline(outer.outer.subscription, ackIds, 0);
+ numNacked.add(nowMsSinceEpoch, ackIds.size());
+ }
+
+ /**
+ * BLOCKING
+ * Extend the processing deadline for messages from Pubsub with the given {@code ackIds}.
+ * Does not retain {@code ackIds}.
+ */
+ private void extendBatch(long nowMsSinceEpoch, List<String> ackIds) throws IOException {
+ int extensionSec = (ackTimeoutMs * ACK_EXTENSION_PCT) / (100 * 1000);
+ pubsubClient.modifyAckDeadline(outer.outer.subscription, ackIds, extensionSec);
+ numExtendedDeadlines.add(nowMsSinceEpoch, ackIds.size());
+ }
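Worked through with the defaults above:

    // ackTimeoutMs = 60000 (a 60s subscription deadline) and ACK_EXTENSION_PCT = 50:
    //   extensionSec = (60000 * 50) / (100 * 1000) = 30
    // so each extension pushes the deadline out by half the original ack timeout;
    // extend() attempts it once a deadline is within ACK_SAFETY_PCT (20%) of expiry.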
+
+ /**
+ * Return the current time, in ms since epoch.
+ */
+ private long now() {
+ if (outer.outer.clock == null) {
+ return System.currentTimeMillis();
+ } else {
+ return outer.outer.clock.currentTimeMillis();
+ }
+ }
+
+ /**
+ * Messages which have been ACKed (via the checkpoint finalize) are no longer in flight.
+ * This is only used for flow control and stats.
+ */
+ private void retire() throws IOException {
+ long nowMsSinceEpoch = now();
+ while (true) {
+ List<String> ackIds = ackedIds.poll();
+ if (ackIds == null) {
+ return;
+ }
+ numAcked.add(nowMsSinceEpoch, ackIds.size());
+ for (String ackId : ackIds) {
+ inFlight.remove(ackId);
+ safeToAckIds.remove(ackId);
+ }
+ }
+ }
+
+ /**
+ * BLOCKING
+ * Extend deadline for all messages which need it.
+ * CAUTION: If extensions can't keep up with wallclock then we'll never return.
+ */
+ private void extend() throws IOException {
+ while (true) {
+ long nowMsSinceEpoch = now();
+ List<String> assumeExpired = new ArrayList<>();
+ List<String> toBeExtended = new ArrayList<>();
+ List<String> toBeExpired = new ArrayList<>();
+ // Messages will be in increasing deadline order.
+ for (Map.Entry<String, InFlightState> entry : inFlight.entrySet()) {
+ if (entry.getValue().ackDeadlineMsSinceEpoch - (ackTimeoutMs * ACK_SAFETY_PCT) / 100
+ > nowMsSinceEpoch) {
+ // All remaining messages don't need their ACKs to be extended.
+ break;
+ }
+
+ if (entry.getValue().ackDeadlineMsSinceEpoch - ACK_TOO_LATE.getMillis()
+ < nowMsSinceEpoch) {
+ // Pubsub may have already considered this message to have expired.
+ // If so it will (eventually) be made available on a future pull request.
+ // If this message ends up being committed then it will be considered a duplicate
+ // when re-pulled.
+ assumeExpired.add(entry.getKey());
+ continue;
+ }
+
+ if (entry.getValue().requestTimeMsSinceEpoch + PROCESSING_TIMEOUT.getMillis()
+ < nowMsSinceEpoch) {
+ // This message has been in-flight for too long.
+ // Give up on it, otherwise we risk extending its ACK indefinitely.
+ toBeExpired.add(entry.getKey());
+ continue;
+ }
+
+ // Extend the ACK for this message.
+ toBeExtended.add(entry.getKey());
+ if (toBeExtended.size() >= ACK_BATCH_SIZE) {
+ // Enough for one batch.
+ break;
+ }
+ }
+
+ if (assumeExpired.isEmpty() && toBeExtended.isEmpty() && toBeExpired.isEmpty()) {
+ // Nothing to be done.
+ return;
+ }
+
+ if (!assumeExpired.isEmpty()) {
+ // If we didn't make the ACK deadline assume expired and no longer in flight.
+ numLateDeadlines.add(nowMsSinceEpoch, assumeExpired.size());
+ for (String ackId : assumeExpired) {
+ inFlight.remove(ackId);
+ }
+ }
+
+ if (!toBeExpired.isEmpty()) {
+ // Expired messages are no longer considered in flight.
+ numExpired.add(nowMsSinceEpoch, toBeExpired.size());
+ for (String ackId : toBeExpired) {
+ inFlight.remove(ackId);
+ }
+ }
+
+ if (!toBeExtended.isEmpty()) {
+ // Pubsub extends acks from its notion of current time.
+ // We'll try to track that on our side, but note the deadlines won't necessarily agree.
+ long newDeadlineMsSinceEpoch = nowMsSinceEpoch + (ackTimeoutMs * ACK_EXTENSION_PCT) / 100;
+ for (String ackId : toBeExtended) {
+ // Maintain increasing ack deadline order.
+ InFlightState state = inFlight.remove(ackId);
+ inFlight.put(ackId,
+ new InFlightState(state.requestTimeMsSinceEpoch, newDeadlineMsSinceEpoch));
+ }
+ // BLOCKs until extended.
+ extendBatch(nowMsSinceEpoch, toBeExtended);
+ }
+ }
+ }
+
+ /**
+ * BLOCKING
+ * Fetch another batch of messages from Pubsub.
+ */
+ private void pull() throws IOException {
+ if (inFlight.size() >= MAX_IN_FLIGHT) {
+ // Wait for checkpoint to be finalized before pulling anymore.
+ // There may be lag while checkpoints are persisted and the finalizeCheckpoint method
+ // is invoked. By limiting the in-flight messages we can ensure we don't end up consuming
+ // messages faster than we can checkpoint them.
+ return;
+ }
+
+ long requestTimeMsSinceEpoch = now();
+ long deadlineMsSinceEpoch = requestTimeMsSinceEpoch + ackTimeoutMs;
+
+ // Pull the next batch.
+ // BLOCKs until received.
+ Collection<PubsubClient.IncomingMessage> receivedMessages =
+ pubsubClient.pull(requestTimeMsSinceEpoch,
+ outer.outer.subscription,
+ PULL_BATCH_SIZE, true);
+ if (receivedMessages.isEmpty()) {
+ // Nothing available yet. Try again later.
+ return;
+ }
+
+ lastReceivedMsSinceEpoch = requestTimeMsSinceEpoch;
+
+ // Capture the received messages.
+ for (PubsubClient.IncomingMessage incomingMessage : receivedMessages) {
+ notYetRead.add(incomingMessage);
+ notYetReadBytes += incomingMessage.elementBytes.length;
+ inFlight.put(incomingMessage.ackId,
+ new InFlightState(requestTimeMsSinceEpoch, deadlineMsSinceEpoch));
+ numReceived++;
+ numReceivedRecently.add(requestTimeMsSinceEpoch, 1L);
+ minReceivedTimestampMsSinceEpoch.add(requestTimeMsSinceEpoch,
+ incomingMessage.timestampMsSinceEpoch);
+ maxReceivedTimestampMsSinceEpoch.add(requestTimeMsSinceEpoch,
+ incomingMessage.timestampMsSinceEpoch);
+ minUnreadTimestampMsSinceEpoch.add(requestTimeMsSinceEpoch,
+ incomingMessage.timestampMsSinceEpoch);
+ }
+ }
+
+ /**
+ * Log stats if time to do so.
+ */
+ private void stats() {
+ long nowMsSinceEpoch = now();
+ if (lastLogTimestampMsSinceEpoch < 0) {
+ lastLogTimestampMsSinceEpoch = nowMsSinceEpoch;
+ return;
+ }
+ long deltaMs = nowMsSinceEpoch - lastLogTimestampMsSinceEpoch;
+ if (deltaMs < LOG_PERIOD.getMillis()) {
+ return;
+ }
+
+ String messageSkew = "unknown";
+ long minTimestamp = minReceivedTimestampMsSinceEpoch.get(nowMsSinceEpoch);
+ long maxTimestamp = maxReceivedTimestampMsSinceEpoch.get(nowMsSinceEpoch);
+ if (minTimestamp < Long.MAX_VALUE && maxTimestamp > Long.MIN_VALUE) {
+ messageSkew = (maxTimestamp - minTimestamp) + "ms";
+ }
+
+ String watermarkSkew = "unknown";
+ long minWatermark = minWatermarkMsSinceEpoch.get(nowMsSinceEpoch);
+ long maxWatermark = maxWatermarkMsSinceEpoch.get(nowMsSinceEpoch);
+ if (minWatermark < Long.MAX_VALUE && maxWatermark > Long.MIN_VALUE) {
+ watermarkSkew = (maxWatermark - minWatermark) + "ms";
+ }
+
+ String oldestInFlight = "no";
+ String oldestAckId = Iterables.getFirst(inFlight.keySet(), null);
+ if (oldestAckId != null) {
+ oldestInFlight =
+ (nowMsSinceEpoch - inFlight.get(oldestAckId).requestTimeMsSinceEpoch) + "ms";
+ }
+
+ LOG.info("Pubsub {} has "
+ + "{} received messages, "
+ + "{} current unread messages, "
+ + "{} current unread bytes, "
+ + "{} current in-flight msgs, "
+ + "{} oldest in-flight, "
+ + "{} current in-flight checkpoints, "
+ + "{} max in-flight checkpoints, "
+ + "{}B/s recent read, "
+ + "{} recent received, "
+ + "{} recent extended, "
+ + "{} recent late extended, "
+ + "{} recent ACKed, "
+ + "{} recent NACKed, "
+ + "{} recent expired, "
+ + "{} recent message timestamp skew, "
+ + "{} recent watermark skew, "
+ + "{} recent late messages, "
+ + "{} last reported watermark",
+ outer.outer.subscription,
+ numReceived,
+ notYetRead.size(),
+ notYetReadBytes,
+ inFlight.size(),
+ oldestInFlight,
+ numInFlightCheckpoints.get(),
+ maxInFlightCheckpoints,
+ numReadBytes.get(nowMsSinceEpoch) / (SAMPLE_PERIOD.getMillis() / 1000L),
+ numReceivedRecently.get(nowMsSinceEpoch),
+ numExtendedDeadlines.get(nowMsSinceEpoch),
+ numLateDeadlines.get(nowMsSinceEpoch),
+ numAcked.get(nowMsSinceEpoch),
+ numNacked.get(nowMsSinceEpoch),
+ numExpired.get(nowMsSinceEpoch),
+ messageSkew,
+ watermarkSkew,
+ numLateMessages.get(nowMsSinceEpoch),
+ new Instant(lastWatermarkMsSinceEpoch));
+
+ lastLogTimestampMsSinceEpoch = nowMsSinceEpoch;
+ }
+
+ @Override
+ public boolean start() throws IOException {
+ // Determine the ack timeout.
+ ackTimeoutMs = pubsubClient.ackDeadlineSeconds(outer.outer.subscription) * 1000;
+ return advance();
+ }
+
+ /**
+ * BLOCKING
+     * Return {@literal true} if a Pubsub message is available, {@literal false} if
+ * none is available at this time or we are over-subscribed. May BLOCK while extending
+ * ACKs or fetching available messages. Will not block waiting for messages.
+ */
+ @Override
+ public boolean advance() throws IOException {
+ // Emit stats.
+ stats();
+
+ if (current != null) {
+ // Current is consumed. It can no longer contribute to holding back the watermark.
+ minUnreadTimestampMsSinceEpoch.remove(current.requestTimeMsSinceEpoch);
+ current = null;
+ }
+
+ // Retire state associated with ACKed messages.
+ retire();
+
+ // Extend all pressing deadlines.
+ // Will BLOCK until done.
+      // If the system is pulling messages only to let them sit in a downstream queue then
+ // this will have the effect of slowing down the pull rate.
+ // However, if the system is genuinely taking longer to process each message then
+ // the work to extend ACKs would be better done in the background.
+ extend();
+
+ if (notYetRead.isEmpty()) {
+ // Pull another batch.
+ // Will BLOCK until fetch returns, but will not block until a message is available.
+ pull();
+ }
+
+ // Take one message from queue.
+ current = notYetRead.poll();
+ if (current == null) {
+ // Try again later.
+ return false;
+ }
+ notYetReadBytes -= current.elementBytes.length;
+ checkState(notYetReadBytes >= 0);
+ long nowMsSinceEpoch = now();
+ numReadBytes.add(nowMsSinceEpoch, current.elementBytes.length);
+ minReadTimestampMsSinceEpoch.add(nowMsSinceEpoch, current.timestampMsSinceEpoch);
+ if (current.timestampMsSinceEpoch < lastWatermarkMsSinceEpoch) {
+ numLateMessages.add(nowMsSinceEpoch, 1L);
+ }
+
+ // Current message can be considered 'read' and will be persisted by the next
+ // checkpoint. So it is now safe to ACK back to Pubsub.
+ safeToAckIds.add(current.ackId);
+ return true;
+ }
+
+ @Override
+ public T getCurrent() throws NoSuchElementException {
+ if (current == null) {
+ throw new NoSuchElementException();
+ }
+ try {
+ return CoderUtils.decodeFromByteArray(outer.outer.elementCoder, current.elementBytes);
+ } catch (CoderException e) {
+ throw new RuntimeException("Unable to decode element from Pubsub message: ", e);
+ }
+ }
+
+ @Override
+ public Instant getCurrentTimestamp() throws NoSuchElementException {
+ if (current == null) {
+ throw new NoSuchElementException();
+ }
+ return new Instant(current.timestampMsSinceEpoch);
+ }
+
+ @Override
+ public byte[] getCurrentRecordId() throws NoSuchElementException {
+ if (current == null) {
+ throw new NoSuchElementException();
+ }
+ return current.recordId.getBytes(Charsets.UTF_8);
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (pubsubClient != null) {
+ pubsubClient.close();
+ pubsubClient = null;
+ }
+ }
+
+ @Override
+ public PubsubSource getCurrentSource() {
+ return outer;
+ }
+
+ @Override
+ public Instant getWatermark() {
+ if (pubsubClient.isEOF() && notYetRead.isEmpty()) {
+ // For testing only: Advance the watermark to the end of time to signal
+ // the test is complete.
+ return BoundedWindow.TIMESTAMP_MAX_VALUE;
+ }
+
+ // NOTE: We'll allow the watermark to go backwards. The underlying runner is responsible
+ // for aggregating all reported watermarks and ensuring the aggregate is latched.
+ // If we attempt to latch locally then it is possible a temporary starvation of one reader
+ // could cause its estimated watermark to fast forward to current system time. Then when
+ // the reader resumes its watermark would be unable to resume tracking.
+ // By letting the underlying runner latch we avoid any problems due to localized starvation.
+ long nowMsSinceEpoch = now();
+ long readMin = minReadTimestampMsSinceEpoch.get(nowMsSinceEpoch);
+ long unreadMin = minUnreadTimestampMsSinceEpoch.get();
+ if (readMin == Long.MAX_VALUE
+ && unreadMin == Long.MAX_VALUE
+ && lastReceivedMsSinceEpoch >= 0
+ && nowMsSinceEpoch > lastReceivedMsSinceEpoch + SAMPLE_PERIOD.getMillis()) {
+ // We don't currently have any unread messages pending, we have not had any messages
+ // read for a while, and we have not received any new messages from Pubsub for a while.
+ // Advance watermark to current time.
+ // TODO: Estimate a timestamp lag.
+ lastWatermarkMsSinceEpoch = nowMsSinceEpoch;
+ } else if (minReadTimestampMsSinceEpoch.isSignificant()
+ || minUnreadTimestampMsSinceEpoch.isSignificant()) {
+ // Take minimum of the timestamps in all unread messages and recently read messages.
+ lastWatermarkMsSinceEpoch = Math.min(readMin, unreadMin);
+ }
+ // else: We're not confident enough to estimate a new watermark. Stick with the old one.
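+      // Worked example (timestamps illustrative): with oldest recently-read timestamp
+      // 10:00:03 and oldest unread timestamp 10:00:05, the estimate above is their
+      // minimum, 10:00:03.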
+ minWatermarkMsSinceEpoch.add(nowMsSinceEpoch, lastWatermarkMsSinceEpoch);
+ maxWatermarkMsSinceEpoch.add(nowMsSinceEpoch, lastWatermarkMsSinceEpoch);
+ return new Instant(lastWatermarkMsSinceEpoch);
+ }
+
+ @Override
+ public PubsubCheckpoint getCheckpointMark() {
+ int cur = numInFlightCheckpoints.incrementAndGet();
+ maxInFlightCheckpoints = Math.max(maxInFlightCheckpoints, cur);
+ // It's possible for a checkpoint to be taken but never finalized.
+ // So we simply copy whatever safeToAckIds we currently have.
+ // It's possible a later checkpoint will be taken before an earlier one is finalized,
+ // in which case we'll double ACK messages to Pubsub. However Pubsub is fine with that.
+      List<String> snapshotSafeToAckIds = Lists.newArrayList(safeToAckIds);
+      List<String> snapshotNotYetReadIds = new ArrayList<>(notYetRead.size());
+ for (PubsubClient.IncomingMessage incomingMessage : notYetRead) {
+ snapshotNotYetReadIds.add(incomingMessage.ackId);
+ }
+ return new PubsubCheckpoint<>(this, snapshotSafeToAckIds, snapshotNotYetReadIds);
+ }
+
+ @Override
+ public long getSplitBacklogBytes() {
+ return notYetReadBytes;
+ }
+ }
+
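+  // Illustrative sketch (not part of this change) of how a runner might drive the reader
+  // above; 'source', 'options' and 'process' are hypothetical. A false return from
+  // advance() means "nothing available right now", not end-of-stream:
+  //
+  //   PubsubReader<T> reader = source.createReader(options, null);
+  //   boolean available = reader.start();
+  //   while (true) {
+  //     if (available) {
+  //       process(reader.getCurrent(), reader.getCurrentTimestamp());
+  //     }
+  //     available = reader.advance();
+  //   }
+  //
+  // with the runner periodically taking getCheckpointMark(), persisting it, and
+  // finalizing it so ACKs flow back to Pubsub.
+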
+ // ================================================================================
+ // Source
+ // ================================================================================
+
+ @VisibleForTesting
+  static class PubsubSource<T> extends UnboundedSource<T, PubsubCheckpoint<T>> {
+    public final PubsubUnboundedSource<T> outer;
+
+    public PubsubSource(PubsubUnboundedSource<T> outer) {
+ this.outer = outer;
+ }
+
+ @Override
+    public List<PubsubSource<T>> generateInitialSplits(
+        int desiredNumSplits, PipelineOptions options) throws Exception {
+      List<PubsubSource<T>> result = new ArrayList<>(desiredNumSplits);
+ for (int i = 0; i < desiredNumSplits * SCALE_OUT; i++) {
+ // Since the source is immutable and Pubsub automatically shards we simply
+      // replicate ourselves the requested number of times.
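+      // E.g. with desiredNumSplits = 4 and SCALE_OUT = 2 (value illustrative) this loop
+      // returns 8 identical sources, relying on Pubsub to balance messages between them.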
+ result.add(this);
+ }
+ return result;
+ }
+
+ @Override
+    public PubsubReader<T> createReader(
+        PipelineOptions options,
+        @Nullable PubsubCheckpoint<T> checkpoint) {
+      PubsubReader<T> reader;
+ try {
+ reader = new PubsubReader<>(options.as(DataflowPipelineOptions.class), this);
+ } catch (GeneralSecurityException | IOException e) {
+ throw new RuntimeException("Unable to subscribe to " + outer.subscription + ": ", e);
+ }
+ if (checkpoint != null) {
+ // NACK all messages we may have lost.
+ try {
+ // Will BLOCK until NACKed.
+ checkpoint.nackAll(reader);
+ } catch (IOException e) {
+ LOG.error("Pubsub {} cannot have {} lost messages NACKed, ignoring: {}",
+ outer.subscription, checkpoint.notYetReadIds.size(), e);
+ }
+ }
+ return reader;
+ }
+
+ @Nullable
+ @Override
+    public Coder<PubsubCheckpoint<T>> getCheckpointMarkCoder() {
+      @SuppressWarnings("unchecked") PubsubCheckpointCoder<T> typedCoder =
+          (PubsubCheckpointCoder<T>) CHECKPOINT_CODER;
+ return typedCoder;
+ }
+
+ @Override
+    public Coder<T> getDefaultOutputCoder() {
+ return outer.elementCoder;
+ }
+
+ @Override
+ public void validate() {
+ // Nothing to validate.
+ }
+
+ @Override
+ public boolean requiresDeduping() {
+ // We cannot prevent re-offering already read messages after a restore from checkpoint.
+ return true;
+ }
+ }
+
+ // ================================================================================
+ // StatsFn
+ // ================================================================================
+
+  private static class StatsFn<T> extends DoFn<T, T> {
+    private final Aggregator<Long, Long> elementCounter =
+ createAggregator("elements", new Sum.SumLongFn());
+
+ private final PubsubClientFactory pubsubFactory;
+ private final SubscriptionPath subscription;
+ @Nullable
+ private final String timestampLabel;
+ @Nullable
+ private final String idLabel;
+
+ public StatsFn(
+ PubsubClientFactory pubsubFactory,
+ SubscriptionPath subscription,
+ @Nullable
+ String timestampLabel,
+ @Nullable
+ String idLabel) {
+ this.pubsubFactory = pubsubFactory;
+ this.subscription = subscription;
+ this.timestampLabel = timestampLabel;
+ this.idLabel = idLabel;
+ }
+
+ @Override
+ public void processElement(ProcessContext c) throws Exception {
+ elementCounter.addValue(1L);
+ c.output(c.element());
+ }
+
+ @Override
+ public void populateDisplayData(Builder builder) {
+ super.populateDisplayData(builder);
+ builder.add(DisplayData.item("subscription", subscription.getPath()));
+ builder.add(DisplayData.item("transport", pubsubFactory.getKind()));
+ builder.addIfNotNull(DisplayData.item("timestampLabel", timestampLabel));
+ builder.addIfNotNull(DisplayData.item("idLabel", idLabel));
+ }
+ }
+
+ // ================================================================================
+ // PubsubUnboundedSource
+ // ================================================================================
+
+ /**
+ * For testing only: Clock to use for all timekeeping. If {@literal null} use system clock.
+ */
+ @Nullable
+ private Clock clock;
+
+ /**
+ * Factory for creating underlying Pubsub transport.
+ */
+ private final PubsubClientFactory pubsubFactory;
+
+ /**
+ * Project under which to create a subscription if only the {@link #topic} was given.
+ */
+ @Nullable
+ private final ProjectPath project;
+
+ /**
+ * Topic to read from. If {@literal null}, then {@link #subscription} must be given.
+ * Otherwise {@link #subscription} must be null.
+ */
+ @Nullable
+ private final TopicPath topic;
+
+ /**
+ * Subscription to read from. If {@literal null} then {@link #topic} must be given.
+ * Otherwise {@link #topic} must be null.
+ *
+   * If no subscription is given a random one will be created when the transform is
+   * applied. This field will be updated with that subscription's path. The created
+ * subscription is never deleted.
+ */
+ @Nullable
+ private SubscriptionPath subscription;
+
+ /**
+ * Coder for elements. Elements are effectively double-encoded: first to a byte array
+   * using this coder, then to a base-64 string to conform to Pubsub's payload
+ * conventions.
+ */
+  private final Coder<T> elementCoder;
+
+ /**
+ * Pubsub metadata field holding timestamp of each element, or {@literal null} if should use
+ * Pubsub message publish timestamp instead.
+ */
+ @Nullable
+ private final String timestampLabel;
+
+ /**
+ * Pubsub metadata field holding id for each element, or {@literal null} if need to generate
+ * a unique id ourselves.
+ */
+ @Nullable
+ private final String idLabel;
+
+ @VisibleForTesting
+ PubsubUnboundedSource(
+ Clock clock,
+ PubsubClientFactory pubsubFactory,
+ @Nullable ProjectPath project,
+ @Nullable TopicPath topic,
+ @Nullable SubscriptionPath subscription,
+      Coder<T> elementCoder,
+ @Nullable String timestampLabel,
+ @Nullable String idLabel) {
+ checkArgument((topic == null) != (subscription == null),
+ "Exactly one of topic and subscription must be given");
+ checkArgument((topic == null) == (project == null),
+ "Project must be given if topic is given");
+ this.clock = clock;
+ this.pubsubFactory = checkNotNull(pubsubFactory);
+ this.project = project;
+ this.topic = topic;
+ this.subscription = subscription;
+ this.elementCoder = checkNotNull(elementCoder);
+ this.timestampLabel = timestampLabel;
+ this.idLabel = idLabel;
+ }
+
+ /**
+ * Construct an unbounded source to consume from the Pubsub {@code subscription}.
+ */
+ public PubsubUnboundedSource(
+ PubsubClientFactory pubsubFactory,
+ @Nullable ProjectPath project,
+ @Nullable TopicPath topic,
+ @Nullable SubscriptionPath subscription,
+      Coder<T> elementCoder,
+ @Nullable String timestampLabel,
+ @Nullable String idLabel) {
+ this(null, pubsubFactory, project, topic, subscription, elementCoder, timestampLabel, idLabel);
+ }
+
+  public Coder<T> getElementCoder() {
+ return elementCoder;
+ }
+
+ @Nullable
+ public ProjectPath getProject() {
+ return project;
+ }
+
+ @Nullable
+ public TopicPath getTopic() {
+ return topic;
+ }
+
+ @Nullable
+ public SubscriptionPath getSubscription() {
+ return subscription;
+ }
+
+ @Nullable
+ public String getTimestampLabel() {
+ return timestampLabel;
+ }
+
+ @Nullable
+ public String getIdLabel() {
+ return idLabel;
+ }
+
+ @Override
+  public PCollection<T> apply(PBegin input) {
+ if (subscription == null) {
+ try {
+ try (PubsubClient pubsubClient =
+ pubsubFactory.newClient(timestampLabel, idLabel,
+ input.getPipeline()
+ .getOptions()
+ .as(DataflowPipelineOptions.class))) {
+ subscription =
+ pubsubClient.createRandomSubscription(project, topic, DEAULT_ACK_TIMEOUT_SEC);
+ LOG.warn("Created subscription {} to topic {}."
+ + " Note this subscription WILL NOT be deleted when the pipeline terminates",
+ subscription, topic);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to create subscription: ", e);
+ }
+ }
+
+ return input.getPipeline().begin()
+ .apply(Read.from(new PubsubSource(this)))
+ .apply(ParDo.named("PubsubUnboundedSource.Stats")
+ .of(new StatsFn(pubsubFactory, subscription,
+ timestampLabel, idLabel)));
+ }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java
index 1477791538..61429b24b2 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java
@@ -59,6 +59,8 @@
import com.google.cloud.dataflow.sdk.io.BigQueryIO;
import com.google.cloud.dataflow.sdk.io.FileBasedSink;
import com.google.cloud.dataflow.sdk.io.PubsubIO;
+import com.google.cloud.dataflow.sdk.io.PubsubUnboundedSink;
+import com.google.cloud.dataflow.sdk.io.PubsubUnboundedSource;
import com.google.cloud.dataflow.sdk.io.Read;
import com.google.cloud.dataflow.sdk.io.ShardNameTemplate;
import com.google.cloud.dataflow.sdk.io.TextIO;
@@ -75,7 +77,6 @@
import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext;
import com.google.cloud.dataflow.sdk.runners.dataflow.AssignWindows;
import com.google.cloud.dataflow.sdk.runners.dataflow.DataflowAggregatorTransforms;
-import com.google.cloud.dataflow.sdk.runners.dataflow.PubsubIOTranslator;
import com.google.cloud.dataflow.sdk.runners.dataflow.ReadTranslator;
import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat;
import com.google.cloud.dataflow.sdk.runners.worker.IsmFormat.IsmRecord;
@@ -117,6 +118,7 @@
import com.google.cloud.dataflow.sdk.util.WindowedValue.FullWindowedValueCoder;
import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PBegin;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded;
import com.google.cloud.dataflow.sdk.values.PCollectionList;
@@ -176,6 +178,7 @@
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
+import javax.annotation.Nullable;
/**
* A {@link PipelineRunner} that executes the operations in the
@@ -338,33 +341,46 @@ public static DataflowPipelineRunner fromOptions(PipelineOptions options) {
this.pcollectionsRequiringIndexedFormat = new HashSet<>();
this.ptransformViewsWithNonDeterministicKeyCoders = new HashSet<>();
+    ImmutableMap.Builder<Class<?>, Class<?>> builder = ImmutableMap.<Class<?>, Class<?>>builder();
if (options.isStreaming()) {
-      overrides = ImmutableMap.<Class<?>, Class<?>>builder()
- .put(Combine.GloballyAsSingletonView.class, StreamingCombineGloballyAsSingletonView.class)
- .put(Create.Values.class, StreamingCreate.class)
- .put(View.AsMap.class, StreamingViewAsMap.class)
- .put(View.AsMultimap.class, StreamingViewAsMultimap.class)
- .put(View.AsSingleton.class, StreamingViewAsSingleton.class)
- .put(View.AsList.class, StreamingViewAsList.class)
- .put(View.AsIterable.class, StreamingViewAsIterable.class)
- .put(Write.Bound.class, StreamingWrite.class)
- .put(PubsubIO.Write.Bound.class, StreamingPubsubIOWrite.class)
- .put(Read.Unbounded.class, StreamingUnboundedRead.class)
- .put(Read.Bounded.class, UnsupportedIO.class)
- .put(AvroIO.Read.Bound.class, UnsupportedIO.class)
- .put(AvroIO.Write.Bound.class, UnsupportedIO.class)
- .put(BigQueryIO.Read.Bound.class, UnsupportedIO.class)
- .put(TextIO.Read.Bound.class, UnsupportedIO.class)
- .put(TextIO.Write.Bound.class, UnsupportedIO.class)
- .put(Window.Bound.class, AssignWindows.class)
- .build();
+ builder.put(Combine.GloballyAsSingletonView.class,
+ StreamingCombineGloballyAsSingletonView.class);
+ builder.put(Create.Values.class, StreamingCreate.class);
+ builder.put(View.AsMap.class, StreamingViewAsMap.class);
+ builder.put(View.AsMultimap.class, StreamingViewAsMultimap.class);
+ builder.put(View.AsSingleton.class, StreamingViewAsSingleton.class);
+ builder.put(View.AsList.class, StreamingViewAsList.class);
+ builder.put(View.AsIterable.class, StreamingViewAsIterable.class);
+ builder.put(Write.Bound.class, StreamingWrite.class);
+ builder.put(Read.Unbounded.class, StreamingUnboundedRead.class);
+ builder.put(Read.Bounded.class, UnsupportedIO.class);
+ builder.put(AvroIO.Read.Bound.class, UnsupportedIO.class);
+ builder.put(AvroIO.Write.Bound.class, UnsupportedIO.class);
+ builder.put(BigQueryIO.Read.Bound.class, UnsupportedIO.class);
+ builder.put(TextIO.Read.Bound.class, UnsupportedIO.class);
+ builder.put(TextIO.Write.Bound.class, UnsupportedIO.class);
+ builder.put(Window.Bound.class, AssignWindows.class);
+ // In streaming mode must use either the custom Pubsub unbounded source/sink or
+ // defer to Windmill's built-in implementation.
+ builder.put(PubsubIO.Read.Bound.PubsubBoundedReader.class, UnsupportedIO.class);
+ builder.put(PubsubIO.Write.Bound.PubsubBoundedWriter.class, UnsupportedIO.class);
+ if (options.getExperiments() == null
+ || !options.getExperiments().contains("enable_custom_pubsub_source")) {
+ builder.put(PubsubUnboundedSource.class, StreamingPubsubIORead.class);
+ }
+ if (options.getExperiments() == null
+ || !options.getExperiments().contains("enable_custom_pubsub_sink")) {
+ builder.put(PubsubUnboundedSink.class, StreamingPubsubIOWrite.class);
+ }
} else {
-      ImmutableMap.Builder<Class<?>, Class<?>> builder = ImmutableMap.<Class<?>, Class<?>>builder();
builder.put(Read.Unbounded.class, UnsupportedIO.class);
builder.put(Window.Bound.class, AssignWindows.class);
builder.put(Write.Bound.class, BatchWrite.class);
builder.put(AvroIO.Write.Bound.class, BatchAvroIOWrite.class);
builder.put(TextIO.Write.Bound.class, BatchTextIOWrite.class);
+ // In batch mode must use the custom Pubsub bounded source/sink.
+ builder.put(PubsubUnboundedSource.class, UnsupportedIO.class);
+ builder.put(PubsubUnboundedSink.class, UnsupportedIO.class);
if (options.getExperiments() == null
|| !options.getExperiments().contains("disable_ism_side_input")) {
builder.put(View.AsMap.class, BatchViewAsMap.class);
@@ -381,8 +397,8 @@ public static DataflowPipelineRunner fromOptions(PipelineOptions options) {
|| !options.getExperiments().contains("enable_custom_bigquery_sink")) {
builder.put(BigQueryIO.Write.Bound.class, BatchBigQueryIOWrite.class);
}
- overrides = builder.build();
}
+ overrides = builder.build();
}
/**
@@ -2535,27 +2551,104 @@ protected String getKindString() {
}
}
+ // ================================================================================
+ // PubsubIO translations
+ // ================================================================================
+
/**
- * Specialized implementation for
- * {@link com.google.cloud.dataflow.sdk.io.PubsubIO.Write PubsubIO.Write} for the
- * Dataflow runner in streaming mode.
- *
- * For internal use only. Subject to change at any time.
- *
- *
- * Public so the {@link PubsubIOTranslator} can access.
+ * Suppress application of {@link PubsubUnboundedSource#apply} in streaming mode so that we
+ * can instead defer to Windmill's implementation.
+ */
+  private static class StreamingPubsubIORead<T> extends PTransform<PBegin, PCollection<T>> {
+    private final PubsubUnboundedSource<T> transform;
+
+ /**
+ * Builds an instance of this class from the overridden transform.
+ */
+ public StreamingPubsubIORead(
+        DataflowPipelineRunner runner, PubsubUnboundedSource<T> transform) {
+ this.transform = transform;
+ }
+
+    PubsubUnboundedSource<T> getOverriddenTransform() {
+ return transform;
+ }
+
+ @Override
+    public PCollection<T> apply(PBegin input) {
+      return PCollection.<T>createPrimitiveOutputInternal(
+ input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
+ .setCoder(transform.getElementCoder());
+ }
+
+ @Override
+ protected String getKindString() {
+ return "StreamingPubsubIORead";
+ }
+
+ static {
+ DataflowPipelineTranslator.registerTransformTranslator(
+ StreamingPubsubIORead.class, new StreamingPubsubIOReadTranslator());
+ }
+ }
+
+ /**
+ * Rewrite {@link StreamingPubsubIORead} to the appropriate internal node.
+ */
+ private static class StreamingPubsubIOReadTranslator implements
+      TransformTranslator<StreamingPubsubIORead> {
+ @Override
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public void translate(
+ StreamingPubsubIORead transform,
+ TranslationContext context) {
+ translateTyped(transform, context);
+ }
+
+    private <T> void translateTyped(
+        StreamingPubsubIORead<T> transform,
+ TranslationContext context) {
+ Preconditions.checkState(context.getPipelineOptions().isStreaming(),
+ "StreamingPubsubIORead is only for streaming pipelines.");
+      PubsubUnboundedSource<T> overriddenTransform = transform.getOverriddenTransform();
+ context.addStep(transform, "ParallelRead");
+ context.addInput(PropertyNames.FORMAT, "pubsub");
+ if (overriddenTransform.getTopic() != null) {
+ context.addInput(PropertyNames.PUBSUB_TOPIC,
+ overriddenTransform.getTopic().getV1Beta1Path());
+ }
+ if (overriddenTransform.getSubscription() != null) {
+ context.addInput(
+ PropertyNames.PUBSUB_SUBSCRIPTION,
+ overriddenTransform.getSubscription().getV1Beta1Path());
+ }
+ if (overriddenTransform.getTimestampLabel() != null) {
+ context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL,
+ overriddenTransform.getTimestampLabel());
+ }
+ if (overriddenTransform.getIdLabel() != null) {
+ context.addInput(PropertyNames.PUBSUB_ID_LABEL, overriddenTransform.getIdLabel());
+ }
+ context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform));
+ }
+ }
+
+ /**
+ * Suppress application of {@link PubsubUnboundedSink#apply} in streaming mode so that we
+ * can instead defer to Windmill's implementation.
*/
-  public static class StreamingPubsubIOWrite<T> extends PTransform<PCollection<T>, PDone> {
-    private final PubsubIO.Write.Bound<T> transform;
+  private static class StreamingPubsubIOWrite<T> extends PTransform<PCollection<T>, PDone> {
+    private final PubsubUnboundedSink<T> transform;
/**
* Builds an instance of this class from the overridden transform.
*/
public StreamingPubsubIOWrite(
-        DataflowPipelineRunner runner, PubsubIO.Write.Bound<T> transform) {
+        DataflowPipelineRunner runner, PubsubUnboundedSink<T> transform) {
this.transform = transform;
}
-    public PubsubIO.Write.Bound<T> getOverriddenTransform() {
+    PubsubUnboundedSink<T> getOverriddenTransform() {
return transform;
}
@@ -2568,8 +2661,51 @@ public PDone apply(PCollection input) {
protected String getKindString() {
return "StreamingPubsubIOWrite";
}
+
+ static {
+ DataflowPipelineTranslator.registerTransformTranslator(
+ StreamingPubsubIOWrite.class, new StreamingPubsubIOWriteTranslator());
+ }
}
+ /**
+ * Rewrite {@link StreamingPubsubIOWrite} to the appropriate internal node.
+ */
+ private static class StreamingPubsubIOWriteTranslator implements
+      TransformTranslator<StreamingPubsubIOWrite> {
+
+ @Override
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public void translate(
+ StreamingPubsubIOWrite transform,
+ TranslationContext context) {
+ translateTyped(transform, context);
+ }
+
+    private <T> void translateTyped(
+        StreamingPubsubIOWrite<T> transform,
+ TranslationContext context) {
+ Preconditions.checkState(context.getPipelineOptions().isStreaming(),
+ "StreamingPubsubIOWrite is only for streaming pipelines.");
+      PubsubUnboundedSink<T> overriddenTransform = transform.getOverriddenTransform();
+ context.addStep(transform, "ParallelWrite");
+ context.addInput(PropertyNames.FORMAT, "pubsub");
+ context.addInput(PropertyNames.PUBSUB_TOPIC, overriddenTransform.getTopic().getV1Beta1Path());
+ if (overriddenTransform.getTimestampLabel() != null) {
+ context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL,
+ overriddenTransform.getTimestampLabel());
+ }
+ if (overriddenTransform.getIdLabel() != null) {
+ context.addInput(PropertyNames.PUBSUB_ID_LABEL, overriddenTransform.getIdLabel());
+ }
+ context.addEncodingInput(
+ WindowedValue.getValueOnlyCoder(overriddenTransform.getElementCoder()));
+ context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform));
+ }
+ }
+
+ // ================================================================================
+
/**
* Specialized implementation for
* {@link com.google.cloud.dataflow.sdk.io.Read.Unbounded Read.Unbounded} for the
@@ -3111,11 +3247,14 @@ public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inp
}
/**
- * Specialized expansion for unsupported IO transforms that throws an error.
+ * Specialized expansion for unsupported IO transforms and DoFns that throws an error.
*/
private static class UnsupportedIO<InputT extends PInput, OutputT extends POutput>
    extends PTransform<InputT, OutputT> {
+ @Nullable
private PTransform<?, ?> transform;
+ @Nullable
+  private DoFn<?, ?> doFn;
/**
* Builds an instance of this class from the overridden transform.
@@ -3173,13 +3312,51 @@ public UnsupportedIO(DataflowPipelineRunner runner, TextIO.Write.Bound> transf
this.transform = transform;
}
+ /**
+ * Builds an instance of this class from the overridden doFn.
+ */
+ @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply()
+ public UnsupportedIO(DataflowPipelineRunner runner,
+        PubsubIO.Read.Bound<?>.PubsubBoundedReader doFn) {
+ this.doFn = doFn;
+ }
+
+ /**
+ * Builds an instance of this class from the overridden doFn.
+ */
+ @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply()
+ public UnsupportedIO(DataflowPipelineRunner runner,
+        PubsubIO.Write.Bound<?>.PubsubBoundedWriter doFn) {
+ this.doFn = doFn;
+ }
+
+ /**
+ * Builds an instance of this class from the overridden transform.
+ */
+ @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply()
+    public UnsupportedIO(DataflowPipelineRunner runner, PubsubUnboundedSource<?> transform) {
+ this.transform = transform;
+ }
+
+ /**
+ * Builds an instance of this class from the overridden transform.
+ */
+ @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply()
+    public UnsupportedIO(DataflowPipelineRunner runner, PubsubUnboundedSink<?> transform) {
+ this.transform = transform;
+ }
+
+
@Override
public OutputT apply(InputT input) {
String mode = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming()
? "streaming" : "batch";
+ String name =
+ transform == null
+ ? approximateSimpleName(doFn.getClass())
+ : approximatePTransformName(transform.getClass());
throw new UnsupportedOperationException(
- String.format("The DataflowPipelineRunner in %s mode does not support %s.",
- mode, approximatePTransformName(transform.getClass())));
+ String.format("The DataflowPipelineRunner in %s mode does not support %s.", mode, name));
}
}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java
index fb477d04c6..28392a48da 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java
@@ -42,12 +42,10 @@
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.CoderException;
import com.google.cloud.dataflow.sdk.coders.IterableCoder;
-import com.google.cloud.dataflow.sdk.io.PubsubIO;
import com.google.cloud.dataflow.sdk.io.Read;
import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
import com.google.cloud.dataflow.sdk.options.StreamingOptions;
import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.GroupByKeyAndSortValuesOnly;
-import com.google.cloud.dataflow.sdk.runners.dataflow.PubsubIOTranslator;
import com.google.cloud.dataflow.sdk.runners.dataflow.ReadTranslator;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.Combine;
@@ -1050,12 +1048,6 @@ private void translateHelper(
///////////////////////////////////////////////////////////////////////////
// IO Translation.
- registerTransformTranslator(
- PubsubIO.Read.Bound.class, new PubsubIOTranslator.ReadTranslator());
- registerTransformTranslator(
- DataflowPipelineRunner.StreamingPubsubIOWrite.class,
- new PubsubIOTranslator.WriteTranslator());
-
registerTransformTranslator(Read.Bounded.class, new ReadTranslator());
}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java
deleted file mode 100644
index 8b066ab065..0000000000
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/PubsubIOTranslator.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (C) 2015 Google Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"); you may not
- * use this file except in compliance with the License. You may obtain a copy of
- * the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations under
- * the License.
- */
-
-package com.google.cloud.dataflow.sdk.runners.dataflow;
-
-import com.google.cloud.dataflow.sdk.io.PubsubIO;
-import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner;
-import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator;
-import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext;
-import com.google.cloud.dataflow.sdk.util.PropertyNames;
-import com.google.cloud.dataflow.sdk.util.WindowedValue;
-
-/**
- * Pubsub transform support code for the Dataflow backend.
- */
-public class PubsubIOTranslator {
-
- /**
- * Implements PubsubIO Read translation for the Dataflow backend.
- */
-  public static class ReadTranslator<T> implements TransformTranslator<PubsubIO.Read.Bound<T>> {
- @Override
- @SuppressWarnings({"rawtypes", "unchecked"})
- public void translate(
- PubsubIO.Read.Bound transform,
- TranslationContext context) {
- translateReadHelper(transform, context);
- }
-
-    private <T> void translateReadHelper(
-        PubsubIO.Read.Bound<T> transform,
- TranslationContext context) {
- if (!context.getPipelineOptions().isStreaming()) {
- throw new IllegalArgumentException(
- "PubsubIO.Read can only be used with the Dataflow streaming runner.");
- }
-
- context.addStep(transform, "ParallelRead");
- context.addInput(PropertyNames.FORMAT, "pubsub");
- if (transform.getTopic() != null) {
- context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic().asV1Beta1Path());
- }
- if (transform.getSubscription() != null) {
- context.addInput(
- PropertyNames.PUBSUB_SUBSCRIPTION, transform.getSubscription().asV1Beta1Path());
- }
- if (transform.getTimestampLabel() != null) {
- context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, transform.getTimestampLabel());
- }
- if (transform.getIdLabel() != null) {
- context.addInput(PropertyNames.PUBSUB_ID_LABEL, transform.getIdLabel());
- }
- context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform));
- }
- }
-
- /**
- * Implements PubsubIO Write translation for the Dataflow backend.
- */
-  public static class WriteTranslator<T>
-      implements TransformTranslator<DataflowPipelineRunner.StreamingPubsubIOWrite<T>> {
-
- @Override
- @SuppressWarnings({"rawtypes", "unchecked"})
- public void translate(
- DataflowPipelineRunner.StreamingPubsubIOWrite transform,
- TranslationContext context) {
- translateWriteHelper(transform, context);
- }
-
-    private <T> void translateWriteHelper(
-        DataflowPipelineRunner.StreamingPubsubIOWrite<T> customTransform,
- TranslationContext context) {
- if (!context.getPipelineOptions().isStreaming()) {
- throw new IllegalArgumentException(
- "PubsubIO.Write is non-primitive for the Dataflow batch runner.");
- }
-
-      PubsubIO.Write.Bound<T> transform = customTransform.getOverriddenTransform();
-
- context.addStep(customTransform, "ParallelWrite");
- context.addInput(PropertyNames.FORMAT, "pubsub");
- context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic().asV1Beta1Path());
- if (transform.getTimestampLabel() != null) {
- context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, transform.getTimestampLabel());
- }
- if (transform.getIdLabel() != null) {
- context.addInput(PropertyNames.PUBSUB_ID_LABEL, transform.getIdLabel());
- }
- context.addEncodingInput(WindowedValue.getValueOnlyCoder(transform.getCoder()));
- context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(customTransform));
- }
- }
-}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BucketingFunction.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BucketingFunction.java
new file mode 100644
index 0000000000..36efcce94c
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/BucketingFunction.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Keep track of the minimum/maximum/sum of a set of timestamped long values.
+ * For efficiency, bucket values by their timestamp.
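+ *
+ * Example (illustrative): with bucketWidthMs = 1000 and a Min combine function,
+ * add(1234, 5) and add(1900, 3) land in the same bucket [1000, 2000), so get() == 3.
+ * remove(1234) drops one sample; the bucket's combined value is only discarded once
+ * the bucket holds no samples at all.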
+ */
+public class BucketingFunction {
+ private static class Bucket {
+ private int numSamples;
+ private long combinedValue;
+
+ public Bucket(BucketingFunction outer) {
+ numSamples = 0;
+ combinedValue = outer.function.identity();
+ }
+
+ public void add(BucketingFunction outer, long value) {
+ combinedValue = outer.function.apply(combinedValue, value);
+ numSamples++;
+ }
+
+ public boolean remove() {
+ numSamples--;
+ checkState(numSamples >= 0, "Lost count of samples");
+ return numSamples == 0;
+ }
+
+ public long get() {
+ return combinedValue;
+ }
+ }
+
+ /**
+ * How large a time interval to fit within each bucket.
+ */
+ private final long bucketWidthMs;
+
+ /**
+ * How many buckets are considered 'significant'?
+ */
+ private final int numSignificantBuckets;
+
+ /**
+ * How many samples are considered 'significant'?
+ */
+ private final int numSignificantSamples;
+
+ /**
+ * Function for combining sample values.
+ */
+ private final Combine.BinaryCombineLongFn function;
+
+ /**
+ * Active buckets.
+ */
+  private final Map<Long, Bucket> buckets;
+
+ public BucketingFunction(
+ long bucketWidthMs,
+ int numSignificantBuckets,
+ int numSignificantSamples,
+ Combine.BinaryCombineLongFn function) {
+ this.bucketWidthMs = bucketWidthMs;
+ this.numSignificantBuckets = numSignificantBuckets;
+ this.numSignificantSamples = numSignificantSamples;
+ this.function = function;
+ this.buckets = new HashMap<>();
+ }
+
+ /**
+ * Which bucket key corresponds to {@code timeMsSinceEpoch}.
+ */
+ private long key(long timeMsSinceEpoch) {
+ return timeMsSinceEpoch - (timeMsSinceEpoch % bucketWidthMs);
+ }
+
+ /**
+ * Add one sample of {@code value} (to bucket) at {@code timeMsSinceEpoch}.
+ */
+ public void add(long timeMsSinceEpoch, long value) {
+ long key = key(timeMsSinceEpoch);
+ Bucket bucket = buckets.get(key);
+ if (bucket == null) {
+ bucket = new Bucket(this);
+ buckets.put(key, bucket);
+ }
+ bucket.add(this, value);
+ }
+
+ /**
+ * Remove one sample (from bucket) at {@code timeMsSinceEpoch}.
+ */
+ public void remove(long timeMsSinceEpoch) {
+ long key = key(timeMsSinceEpoch);
+ Bucket bucket = buckets.get(key);
+ if (bucket == null) {
+ return;
+ }
+ if (bucket.remove()) {
+ buckets.remove(key);
+ }
+ }
+
+ /**
+ * Return the (bucketized) combined value of all samples.
+ */
+ public long get() {
+ long result = function.identity();
+ for (Bucket bucket : buckets.values()) {
+ result = function.apply(result, bucket.get());
+ }
+ return result;
+ }
+
+ /**
+   * Is the current result 'significant'? I.e., is it drawn from enough buckets
+ * or from enough samples?
+ */
+ public boolean isSignificant() {
+ if (buckets.size() >= numSignificantBuckets) {
+ return true;
+ }
+ int totalSamples = 0;
+ for (Bucket bucket : buckets.values()) {
+ totalSamples += bucket.numSamples;
+ }
+ return totalSamples >= numSignificantSamples;
+ }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MovingFunction.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MovingFunction.java
new file mode 100644
index 0000000000..04ec8db6d3
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/MovingFunction.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import java.util.Arrays;
+
+/**
+ * Keep track of the moving minimum/maximum/sum of sampled long values. The minimum/maximum/sum
+ * is over at most the last {@link #samplePeriodMs}, and is updated every
+ * {@link #sampleUpdateMs}.
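+ *
+ * Example (illustrative): samplePeriodMs = 60000 and sampleUpdateMs = 5000 give 12
+ * rotating buckets. add(now, v) combines v into the current 5s bucket, get(now)
+ * combines all buckets, and buckets older than ~60s have been overwritten so they no
+ * longer contribute.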
+ */
+public class MovingFunction {
+ /**
+ * How far back to retain samples, in ms.
+ */
+ private final long samplePeriodMs;
+
+ /**
+ * How frequently to update the moving function, in ms.
+ */
+ private final long sampleUpdateMs;
+
+ /**
+ * How many buckets are considered 'significant'?
+ */
+ private final int numSignificantBuckets;
+
+ /**
+ * How many samples are considered 'significant'?
+ */
+ private final int numSignificantSamples;
+
+ /**
+ * Function for combining sample values.
+ */
+ private final Combine.BinaryCombineLongFn function;
+
+ /**
+ * Minimum/maximum/sum of all values per bucket.
+ */
+ private final long[] buckets;
+
+ /**
+ * How many samples have been added to each bucket.
+ */
+ private final int[] numSamples;
+
+ /**
+ * Time of start of current bucket.
+ */
+ private long currentMsSinceEpoch;
+
+ /**
+ * Index of bucket corresponding to above timestamp, or -1 if no entries.
+ */
+ private int currentIndex;
+
+ public MovingFunction(long samplePeriodMs, long sampleUpdateMs,
+ int numSignificantBuckets, int numSignificantSamples,
+ Combine.BinaryCombineLongFn function) {
+ this.samplePeriodMs = samplePeriodMs;
+ this.sampleUpdateMs = sampleUpdateMs;
+ this.numSignificantBuckets = numSignificantBuckets;
+ this.numSignificantSamples = numSignificantSamples;
+ this.function = function;
+ int n = (int) (samplePeriodMs / sampleUpdateMs);
+ buckets = new long[n];
+ Arrays.fill(buckets, function.identity());
+ numSamples = new int[n];
+ Arrays.fill(numSamples, 0);
+ currentMsSinceEpoch = -1;
+ currentIndex = -1;
+ }
+
+ /**
+ * Flush stale values.
+ */
+ private void flush(long nowMsSinceEpoch) {
+ checkArgument(nowMsSinceEpoch >= 0, "Only positive timestamps supported");
+ if (currentIndex < 0) {
+ currentMsSinceEpoch = nowMsSinceEpoch - (nowMsSinceEpoch % sampleUpdateMs);
+ currentIndex = 0;
+ }
+ checkArgument(nowMsSinceEpoch >= currentMsSinceEpoch, "Attempting to move backwards");
+ int newBuckets =
+ Math.min((int) ((nowMsSinceEpoch - currentMsSinceEpoch) / sampleUpdateMs),
+ buckets.length);
+ while (newBuckets > 0) {
+ currentIndex = (currentIndex + 1) % buckets.length;
+ buckets[currentIndex] = function.identity();
+ numSamples[currentIndex] = 0;
+ newBuckets--;
+ currentMsSinceEpoch += sampleUpdateMs;
+ }
+ }
+
+ /**
+ * Add {@code value} at {@code nowMsSinceEpoch}.
+ */
+ public void add(long nowMsSinceEpoch, long value) {
+ flush(nowMsSinceEpoch);
+ buckets[currentIndex] = function.apply(buckets[currentIndex], value);
+ numSamples[currentIndex]++;
+ }
+
+ /**
+ * Return the minimum/maximum/sum of all retained values within {@link #samplePeriodMs}
+ * of {@code nowMsSinceEpoch}.
+ */
+ public long get(long nowMsSinceEpoch) {
+ flush(nowMsSinceEpoch);
+ long result = function.identity();
+ for (int i = 0; i < buckets.length; i++) {
+ result = function.apply(result, buckets[i]);
+ }
+ return result;
+ }
+
+ /**
+   * Is the current result 'significant'? I.e., is it drawn from enough buckets
+ * or from enough samples?
+ */
+ public boolean isSignificant() {
+ int totalSamples = 0;
+ int activeBuckets = 0;
+ for (int i = 0; i < buckets.length; i++) {
+ totalSamples += numSamples[i];
+ if (numSamples[i] > 0) {
+ activeBuckets++;
+ }
+ }
+ return activeBuckets >= numSignificantBuckets || totalSamples >= numSignificantSamples;
+ }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubApiaryClient.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubApiaryClient.java
new file mode 100644
index 0000000000..f7fa7832ee
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubApiaryClient.java
@@ -0,0 +1,301 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.api.services.pubsub.Pubsub;
+import com.google.api.services.pubsub.Pubsub.Builder;
+import com.google.api.services.pubsub.model.AcknowledgeRequest;
+import com.google.api.services.pubsub.model.ListSubscriptionsResponse;
+import com.google.api.services.pubsub.model.ListTopicsResponse;
+import com.google.api.services.pubsub.model.ModifyAckDeadlineRequest;
+import com.google.api.services.pubsub.model.PublishRequest;
+import com.google.api.services.pubsub.model.PublishResponse;
+import com.google.api.services.pubsub.model.PubsubMessage;
+import com.google.api.services.pubsub.model.PullRequest;
+import com.google.api.services.pubsub.model.PullResponse;
+import com.google.api.services.pubsub.model.ReceivedMessage;
+import com.google.api.services.pubsub.model.Subscription;
+import com.google.api.services.pubsub.model.Topic;
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
+import com.google.cloud.hadoop.util.ChainingHttpRequestInitializer;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableList;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import javax.annotation.Nullable;
+
+/**
+ * A Pubsub client using Apiary.
+ */
+public class PubsubApiaryClient extends PubsubClient {
+
+ private static class PubsubApiaryClientFactory implements PubsubClientFactory {
+ @Override
+ public PubsubClient newClient(
+ @Nullable String timestampLabel, @Nullable String idLabel, DataflowPipelineOptions options)
+ throws IOException {
+ Pubsub pubsub = new Builder(
+ Transport.getTransport(),
+ Transport.getJsonFactory(),
+ new ChainingHttpRequestInitializer(
+ options.getGcpCredential(),
+ // Do not log 404. It clutters the output and is possibly even required by the caller.
+ new RetryHttpRequestInitializer(ImmutableList.of(404))))
+ .setRootUrl(options.getPubsubRootUrl())
+ .setApplicationName(options.getAppName())
+ .setGoogleClientRequestInitializer(options.getGoogleApiTrace())
+ .build();
+ return new PubsubApiaryClient(timestampLabel, idLabel, pubsub);
+ }
+
+ @Override
+ public String getKind() {
+ return "Apiary";
+ }
+ }
+
+ /**
+ * Factory for creating Pubsub clients using Apiary transport.
+ */
+ public static final PubsubClientFactory FACTORY = new PubsubApiaryClientFactory();
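+
+  // Illustrative use (client, topic, messages and options are hypothetical):
+  //   try (PubsubClient client = PubsubApiaryClient.FACTORY.newClient(null, null, options)) {
+  //     int published = client.publish(topic, messages);
+  //   }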
+
+ /**
+ * Label to use for custom timestamps, or {@literal null} if should use Pubsub publish time
+ * instead.
+ */
+ @Nullable
+ private final String timestampLabel;
+
+ /**
+ * Label to use for custom ids, or {@literal null} if should use Pubsub provided ids.
+ */
+ @Nullable
+ private final String idLabel;
+
+ /**
+ * Underlying Apiary client.
+ */
+ private Pubsub pubsub;
+
+ @VisibleForTesting
+ PubsubApiaryClient(
+ @Nullable String timestampLabel,
+ @Nullable String idLabel,
+ Pubsub pubsub) {
+ this.timestampLabel = timestampLabel;
+ this.idLabel = idLabel;
+ this.pubsub = pubsub;
+ }
+
+ @Override
+ public void close() {
+ // Nothing to close.
+ }
+
+ @Override
+  public int publish(TopicPath topic, List<OutgoingMessage> outgoingMessages)
+ throws IOException {
+    List<PubsubMessage> pubsubMessages = new ArrayList<>(outgoingMessages.size());
+ for (OutgoingMessage outgoingMessage : outgoingMessages) {
+ PubsubMessage pubsubMessage = new PubsubMessage().encodeData(outgoingMessage.elementBytes);
+
+      Map<String, String> attributes = pubsubMessage.getAttributes();
+ if ((timestampLabel != null || idLabel != null) && attributes == null) {
+ attributes = new TreeMap<>();
+ pubsubMessage.setAttributes(attributes);
+ }
+
+ if (timestampLabel != null) {
+ attributes.put(timestampLabel, String.valueOf(outgoingMessage.timestampMsSinceEpoch));
+ }
+
+ if (idLabel != null && !Strings.isNullOrEmpty(outgoingMessage.recordId)) {
+ attributes.put(idLabel, outgoingMessage.recordId);
+ }
+
+ pubsubMessages.add(pubsubMessage);
+ }
+ PublishRequest request = new PublishRequest().setMessages(pubsubMessages);
+ PublishResponse response = pubsub.projects()
+ .topics()
+ .publish(topic.getPath(), request)
+ .execute();
+ return response.getMessageIds().size();
+ }
+
+ @Override
+  public List<IncomingMessage> pull(
+ long requestTimeMsSinceEpoch,
+ SubscriptionPath subscription,
+ int batchSize,
+ boolean returnImmediately) throws IOException {
+ PullRequest request = new PullRequest()
+ .setReturnImmediately(returnImmediately)
+ .setMaxMessages(batchSize);
+ PullResponse response = pubsub.projects()
+ .subscriptions()
+ .pull(subscription.getPath(), request)
+ .execute();
+ if (response.getReceivedMessages() == null || response.getReceivedMessages().size() == 0) {
+ return ImmutableList.of();
+ }
+    List<IncomingMessage> incomingMessages = new ArrayList<>(response.getReceivedMessages().size());
+ for (ReceivedMessage message : response.getReceivedMessages()) {
+ PubsubMessage pubsubMessage = message.getMessage();
+      @Nullable Map<String, String> attributes = pubsubMessage.getAttributes();
+
+ // Payload.
+ byte[] elementBytes = pubsubMessage.decodeData();
+
+ // Timestamp.
+ long timestampMsSinceEpoch =
+ extractTimestamp(timestampLabel, message.getMessage().getPublishTime(), attributes);
+
+ // Ack id.
+ String ackId = message.getAckId();
+ checkState(!Strings.isNullOrEmpty(ackId));
+
+ // Record id, if any.
+ @Nullable String recordId = null;
+ if (idLabel != null && attributes != null) {
+ recordId = attributes.get(idLabel);
+ }
+ if (Strings.isNullOrEmpty(recordId)) {
+ // Fall back to the Pubsub provided message id.
+ recordId = pubsubMessage.getMessageId();
+ }
+
+ incomingMessages.add(new IncomingMessage(elementBytes, timestampMsSinceEpoch,
+ requestTimeMsSinceEpoch, ackId, recordId));
+ }
+
+ return incomingMessages;
+ }
+
+ @Override
+  public void acknowledge(SubscriptionPath subscription, List<String> ackIds) throws IOException {
+ AcknowledgeRequest request = new AcknowledgeRequest().setAckIds(ackIds);
+ pubsub.projects()
+ .subscriptions()
+ .acknowledge(subscription.getPath(), request)
+ .execute(); // ignore Empty result.
+ }
+
+ @Override
+ public void modifyAckDeadline(
+      SubscriptionPath subscription, List<String> ackIds, int deadlineSeconds)
+ throws IOException {
+ ModifyAckDeadlineRequest request =
+ new ModifyAckDeadlineRequest().setAckIds(ackIds)
+ .setAckDeadlineSeconds(deadlineSeconds);
+ pubsub.projects()
+ .subscriptions()
+ .modifyAckDeadline(subscription.getPath(), request)
+ .execute(); // ignore Empty result.
+ }
+
+ @Override
+ public void createTopic(TopicPath topic) throws IOException {
+ pubsub.projects()
+ .topics()
+ .create(topic.getPath(), new Topic())
+ .execute(); // ignore Topic result.
+ }
+
+ @Override
+ public void deleteTopic(TopicPath topic) throws IOException {
+ pubsub.projects()
+ .topics()
+ .delete(topic.getPath())
+ .execute(); // ignore Empty result.
+ }
+
+ @Override
+  public List<TopicPath> listTopics(ProjectPath project) throws IOException {
+ ListTopicsResponse response = pubsub.projects()
+ .topics()
+ .list(project.getPath())
+ .execute();
+ if (response.getTopics() == null || response.getTopics().isEmpty()) {
+ return ImmutableList.of();
+ }
+    List<TopicPath> topics = new ArrayList<>(response.getTopics().size());
+ for (Topic topic : response.getTopics()) {
+ topics.add(topicPathFromPath(topic.getName()));
+ }
+ return topics;
+ }
+
+ @Override
+ public void createSubscription(
+ TopicPath topic, SubscriptionPath subscription,
+ int ackDeadlineSeconds) throws IOException {
+ Subscription request = new Subscription()
+ .setTopic(topic.getPath())
+ .setAckDeadlineSeconds(ackDeadlineSeconds);
+ pubsub.projects()
+ .subscriptions()
+ .create(subscription.getPath(), request)
+ .execute(); // ignore Subscription result.
+ }
+
+ @Override
+ public void deleteSubscription(SubscriptionPath subscription) throws IOException {
+ pubsub.projects()
+ .subscriptions()
+ .delete(subscription.getPath())
+ .execute(); // ignore Empty result.
+ }
+
+ @Override
+  public List<SubscriptionPath> listSubscriptions(ProjectPath project, TopicPath topic)
+ throws IOException {
+ ListSubscriptionsResponse response = pubsub.projects()
+ .subscriptions()
+ .list(project.getPath())
+ .execute();
+ if (response.getSubscriptions() == null || response.getSubscriptions().isEmpty()) {
+ return ImmutableList.of();
+ }
+    List<SubscriptionPath> subscriptions = new ArrayList<>(response.getSubscriptions().size());
+ for (Subscription subscription : response.getSubscriptions()) {
+ if (subscription.getTopic().equals(topic.getPath())) {
+ subscriptions.add(subscriptionPathFromPath(subscription.getName()));
+ }
+ }
+ return subscriptions;
+ }
+
+ @Override
+ public int ackDeadlineSeconds(SubscriptionPath subscription) throws IOException {
+ Subscription response = pubsub.projects().subscriptions().get(subscription.getPath()).execute();
+ return response.getAckDeadlineSeconds();
+ }
+
+ @Override
+ public boolean isEOF() {
+ return false;
+ }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubClient.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubClient.java
new file mode 100644
index 0000000000..46eaf006c1
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubClient.java
@@ -0,0 +1,544 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.api.client.util.DateTime;
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
+import com.google.common.base.Objects;
+import com.google.common.base.Strings;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import javax.annotation.Nullable;
+
+/**
+ * An (abstract) helper class for talking to Pubsub via an underlying transport.
+ */
+public abstract class PubsubClient implements Closeable {
+ /**
+ * Factory for creating clients.
+ */
+ public interface PubsubClientFactory extends Serializable {
+ /**
+ * Construct a new Pubsub client. It should be closed via {@link #close} in order
+ * to ensure tidy cleanup of underlying netty resources (or use the try-with-resources
+ * construct). Uses {@code options} to derive pubsub endpoints and application credentials.
+ * If non-{@literal null}, use {@code timestampLabel} and {@code idLabel} to store custom
+ * timestamps/ids within message metadata.
+ */
+ PubsubClient newClient(
+ @Nullable String timestampLabel,
+ @Nullable String idLabel,
+ DataflowPipelineOptions options) throws IOException;
+
+ /**
+     * Return the display name for this factory. E.g. "Apiary", "gRPC".
+ */
+ String getKind();
+ }
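+
+  // Illustrative usage only (these names are defined in this patch; "options" is a
+  // DataflowPipelineOptions and "topic" a TopicPath assumed to be in scope):
+  //
+  //   try (PubsubClient client = PubsubGrpcClient.FACTORY.newClient(null, null, options)) {
+  //     client.createTopic(topic);
+  //   }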
+
+ /**
+ * Return timestamp as ms-since-unix-epoch corresponding to {@code timestamp}.
+ * Return {@literal null} if no timestamp could be found. Throw {@link IllegalArgumentException}
+ * if timestamp cannot be recognized.
+ */
+ @Nullable
+ private static Long asMsSinceEpoch(@Nullable String timestamp) {
+ if (Strings.isNullOrEmpty(timestamp)) {
+ return null;
+ }
+ try {
+      // Try parsing as milliseconds since epoch. Long.parseLong cannot parse an RFC 3339
+      // string, so a NumberFormatException (a subclass of IllegalArgumentException) here
+      // means we should fall back to RFC 3339 below.
+ return Long.parseLong(timestamp);
+ } catch (IllegalArgumentException e1) {
+      // Try parsing as an RFC 3339 string. DateTime.parseRfc3339 will throw an
+      // IllegalArgumentException if parsing fails, which the caller should handle.
+ return DateTime.parseRfc3339(timestamp).getValue();
+ }
+ }
+
+ /**
+ * Return the timestamp (in ms since unix epoch) to use for a Pubsub message with {@code
+ * attributes} and {@code pubsubTimestamp}.
+ * If {@code timestampLabel} is non-{@literal null} then the message attributes must contain
+ * that label, and the value of that label will be taken as the timestamp.
+ * Otherwise the timestamp will be taken from the Pubsub publish timestamp {@code
+ * pubsubTimestamp}. Throw {@link IllegalArgumentException} if the timestamp cannot be
+ * recognized as a ms-since-unix-epoch or RFC3339 time.
+ *
+ * @throws IllegalArgumentException
+ */
+ protected static long extractTimestamp(
+ @Nullable String timestampLabel,
+ @Nullable String pubsubTimestamp,
+      @Nullable Map<String, String> attributes) {
+ Long timestampMsSinceEpoch;
+ if (Strings.isNullOrEmpty(timestampLabel)) {
+ timestampMsSinceEpoch = asMsSinceEpoch(pubsubTimestamp);
+ checkArgument(timestampMsSinceEpoch != null,
+ "Cannot interpret PubSub publish timestamp: %s",
+ pubsubTimestamp);
+ } else {
+ String value = attributes == null ? null : attributes.get(timestampLabel);
+ checkArgument(value != null,
+ "PubSub message is missing a value for timestamp label %s",
+ timestampLabel);
+ timestampMsSinceEpoch = asMsSinceEpoch(value);
+ checkArgument(timestampMsSinceEpoch != null,
+ "Cannot interpret value of label %s as timestamp: %s",
+ timestampLabel, value);
+ }
+ return timestampMsSinceEpoch;
+ }
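+
+  // For example (values illustrative): with timestampLabel == null the publish timestamp
+  // "1446162101123" is interpreted as ms-since-epoch, while with timestampLabel == "ts" and
+  // attributes == {"ts": "2015-10-29T23:41:41Z"} the RFC 3339 value is parsed instead.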
+
+ /**
+ * Path representing a cloud project id.
+ */
+ public static class ProjectPath implements Serializable {
+ private final String path;
+
+ ProjectPath(String path) {
+ this.path = path;
+ }
+
+ public String getPath() {
+ return path;
+ }
+
+ public String getId() {
+ String[] splits = path.split("/");
+      checkState(splits.length == 2, "Malformed project path %s", path);
+ return splits[1];
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+      ProjectPath that = (ProjectPath) o;
+      return path.equals(that.path);
+    }
+
+ @Override
+ public int hashCode() {
+ return path.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return path;
+ }
+ }
+
+ public static ProjectPath projectPathFromPath(String path) {
+ return new ProjectPath(path);
+ }
+
+ public static ProjectPath projectPathFromId(String projectId) {
+ return new ProjectPath(String.format("projects/%s", projectId));
+ }
+
+ /**
+ * Path representing a Pubsub subscription.
+ */
+ public static class SubscriptionPath implements Serializable {
+ private final String path;
+
+ SubscriptionPath(String path) {
+ this.path = path;
+ }
+
+ public String getPath() {
+ return path;
+ }
+
+ public String getName() {
+ String[] splits = path.split("/");
+ checkState(splits.length == 4, "Malformed subscription path %s", path);
+ return splits[3];
+ }
+
+ public String getV1Beta1Path() {
+ String[] splits = path.split("/");
+ checkState(splits.length == 4, "Malformed subscription path %s", path);
+ return String.format("/subscriptions/%s/%s", splits[1], splits[3]);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ SubscriptionPath that = (SubscriptionPath) o;
+ return path.equals(that.path);
+ }
+
+ @Override
+ public int hashCode() {
+ return path.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return path;
+ }
+ }
+
+ public static SubscriptionPath subscriptionPathFromPath(String path) {
+ return new SubscriptionPath(path);
+ }
+
+ public static SubscriptionPath subscriptionPathFromName(
+ String projectId, String subscriptionName) {
+ return new SubscriptionPath(String.format("projects/%s/subscriptions/%s",
+ projectId, subscriptionName));
+ }
+
+ /**
+ * Path representing a Pubsub topic.
+ */
+ public static class TopicPath implements Serializable {
+ private final String path;
+
+ TopicPath(String path) {
+ this.path = path;
+ }
+
+ public String getPath() {
+ return path;
+ }
+
+ public String getName() {
+ String[] splits = path.split("/");
+ checkState(splits.length == 4, "Malformed topic path %s", path);
+ return splits[3];
+ }
+
+ public String getV1Beta1Path() {
+ String[] splits = path.split("/");
+ checkState(splits.length == 4, "Malformed topic path %s", path);
+ return String.format("/topics/%s/%s", splits[1], splits[3]);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ TopicPath topicPath = (TopicPath) o;
+ return path.equals(topicPath.path);
+ }
+
+ @Override
+ public int hashCode() {
+ return path.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return path;
+ }
+ }
+
+ public static TopicPath topicPathFromPath(String path) {
+ return new TopicPath(path);
+ }
+
+ public static TopicPath topicPathFromName(String projectId, String topicName) {
+ return new TopicPath(String.format("projects/%s/topics/%s", projectId, topicName));
+ }
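+
+  // Illustrative: topicPathFromName("myproject", "mytopic") yields the full resource path
+  // "projects/myproject/topics/mytopic"; on the result, getName() is "mytopic" and
+  // getV1Beta1Path() is "/topics/myproject/mytopic".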
+
+ /**
+ * A message to be sent to Pubsub.
+ * NOTE: This class is {@link Serializable} only to support the {@link PubsubTestClient}.
+ * Java serialization is never used for non-test clients.
+ */
+ public static class OutgoingMessage implements Serializable {
+ /**
+ * Underlying (encoded) element.
+ */
+ public final byte[] elementBytes;
+
+ /**
+ * Timestamp for element (ms since epoch).
+ */
+ public final long timestampMsSinceEpoch;
+
+ /**
+ * If using an id label, the record id to associate with this record's metadata so the receiver
+ * can reject duplicates. Otherwise {@literal null}.
+ */
+ @Nullable
+ public final String recordId;
+
+ public OutgoingMessage(
+ byte[] elementBytes, long timestampMsSinceEpoch, @Nullable String recordId) {
+ this.elementBytes = elementBytes;
+ this.timestampMsSinceEpoch = timestampMsSinceEpoch;
+ this.recordId = recordId;
+ }
+
+ @Override
+ public String toString() {
+ return String.format("OutgoingMessage(%db, %dms)",
+ elementBytes.length, timestampMsSinceEpoch);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ OutgoingMessage that = (OutgoingMessage) o;
+
+ return timestampMsSinceEpoch == that.timestampMsSinceEpoch
+ && Arrays.equals(elementBytes, that.elementBytes)
+ && Objects.equal(recordId, that.recordId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(Arrays.hashCode(elementBytes), timestampMsSinceEpoch, recordId);
+ }
+ }
+
+ /**
+ * A message received from Pubsub.
+   *
+   * NOTE: This class is {@link Serializable} only to support the {@link PubsubTestClient}.
+ * Java serialization is never used for non-test clients.
+ */
+ public static class IncomingMessage implements Serializable {
+ /**
+ * Underlying (encoded) element.
+ */
+ public final byte[] elementBytes;
+
+ /**
+     * Timestamp for element (ms since epoch). Either Pubsub's publish time,
+ * or the custom timestamp associated with the message.
+ */
+ public final long timestampMsSinceEpoch;
+
+ /**
+ * Timestamp (in system time) at which we requested the message (ms since epoch).
+ */
+ public final long requestTimeMsSinceEpoch;
+
+ /**
+ * Id to pass back to Pubsub to acknowledge receipt of this message.
+ */
+ public final String ackId;
+
+ /**
+ * Id to pass to the runner to distinguish this message from all others.
+ */
+ public final String recordId;
+
+ public IncomingMessage(
+ byte[] elementBytes,
+ long timestampMsSinceEpoch,
+ long requestTimeMsSinceEpoch,
+ String ackId,
+ String recordId) {
+ this.elementBytes = elementBytes;
+ this.timestampMsSinceEpoch = timestampMsSinceEpoch;
+ this.requestTimeMsSinceEpoch = requestTimeMsSinceEpoch;
+ this.ackId = ackId;
+ this.recordId = recordId;
+ }
+
+ public IncomingMessage withRequestTime(long requestTimeMsSinceEpoch) {
+ return new IncomingMessage(elementBytes, timestampMsSinceEpoch, requestTimeMsSinceEpoch,
+ ackId, recordId);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("IncomingMessage(%db, %dms)",
+ elementBytes.length, timestampMsSinceEpoch);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ IncomingMessage that = (IncomingMessage) o;
+
+ return timestampMsSinceEpoch == that.timestampMsSinceEpoch
+ && requestTimeMsSinceEpoch == that.requestTimeMsSinceEpoch
+ && ackId.equals(that.ackId)
+ && recordId.equals(that.recordId)
+ && Arrays.equals(elementBytes, that.elementBytes);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(Arrays.hashCode(elementBytes), timestampMsSinceEpoch,
+ requestTimeMsSinceEpoch,
+ ackId, recordId);
+ }
+ }
+
+ /**
+ * Publish {@code outgoingMessages} to Pubsub {@code topic}. Return number of messages
+ * published.
+ *
+ * @throws IOException
+ */
+  public abstract int publish(TopicPath topic, List<OutgoingMessage> outgoingMessages)
+ throws IOException;
+
+ /**
+ * Request the next batch of up to {@code batchSize} messages from {@code subscription}.
+ * Return the received messages, or empty collection if none were available. Does not
+ * wait for messages to arrive if {@code returnImmediately} is {@literal true}.
+ * Returned messages will record their request time as {@code requestTimeMsSinceEpoch}.
+ *
+ * @throws IOException
+ */
+  public abstract List<IncomingMessage> pull(
+ long requestTimeMsSinceEpoch,
+ SubscriptionPath subscription,
+ int batchSize,
+ boolean returnImmediately)
+ throws IOException;
+
+ /**
+   * Acknowledge messages from {@code subscription} with {@code ackIds}.
+ *
+ * @throws IOException
+ */
+  public abstract void acknowledge(SubscriptionPath subscription, List<String> ackIds)
+ throws IOException;
+
+ /**
+ * Modify the ack deadline for messages from {@code subscription} with {@code ackIds} to
+ * be {@code deadlineSeconds} from now.
+ *
+ * @throws IOException
+ */
+ public abstract void modifyAckDeadline(
+      SubscriptionPath subscription, List<String> ackIds,
+ int deadlineSeconds) throws IOException;
+
+ /**
+ * Create {@code topic}.
+ *
+ * @throws IOException
+ */
+ public abstract void createTopic(TopicPath topic) throws IOException;
+
+  /**
+ * Delete {@code topic}.
+ *
+ * @throws IOException
+ */
+ public abstract void deleteTopic(TopicPath topic) throws IOException;
+
+ /**
+ * Return a list of topics for {@code project}.
+ *
+ * @throws IOException
+ */
+  public abstract List<TopicPath> listTopics(ProjectPath project) throws IOException;
+
+ /**
+ * Create {@code subscription} to {@code topic}.
+ *
+ * @throws IOException
+ */
+ public abstract void createSubscription(
+ TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException;
+
+ /**
+ * Create a random subscription for {@code topic}. Return the {@link SubscriptionPath}. It
+ * is the responsibility of the caller to later delete the subscription.
+ *
+ * @throws IOException
+ */
+ public SubscriptionPath createRandomSubscription(
+ ProjectPath project, TopicPath topic, int ackDeadlineSeconds) throws IOException {
+ // Create a randomized subscription derived from the topic name.
+ String subscriptionName = topic.getName() + "_beam_" + ThreadLocalRandom.current().nextLong();
+ SubscriptionPath subscription =
+ PubsubClient.subscriptionPathFromName(project.getId(), subscriptionName);
+ createSubscription(topic, subscription, ackDeadlineSeconds);
+ return subscription;
+ }
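+
+  // Illustrative: for topic "projects/p/topics/t" this creates a subscription such as
+  // "projects/p/subscriptions/t_beam_<random long>", which the caller must later delete.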
+
+ /**
+ * Delete {@code subscription}.
+ *
+ * @throws IOException
+ */
+ public abstract void deleteSubscription(SubscriptionPath subscription) throws IOException;
+
+ /**
+ * Return a list of subscriptions for {@code topic} in {@code project}.
+ *
+ * @throws IOException
+ */
+  public abstract List<SubscriptionPath> listSubscriptions(ProjectPath project, TopicPath topic)
+ throws IOException;
+
+ /**
+ * Return the ack deadline, in seconds, for {@code subscription}.
+ *
+ * @throws IOException
+ */
+ public abstract int ackDeadlineSeconds(SubscriptionPath subscription) throws IOException;
+
+ /**
+   * Return {@literal true} if {@link #pull} will always return an empty list. Actual clients
+ * will return {@literal false}. Test clients may return {@literal true} to signal that all
+ * expected messages have been pulled and the test may complete.
+ */
+ public abstract boolean isEOF();
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubGrpcClient.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubGrpcClient.java
new file mode 100644
index 0000000000..05ebfa3883
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubGrpcClient.java
@@ -0,0 +1,444 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.auth.oauth2.GoogleCredentials;
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.Timestamp;
+import com.google.pubsub.v1.AcknowledgeRequest;
+import com.google.pubsub.v1.DeleteSubscriptionRequest;
+import com.google.pubsub.v1.DeleteTopicRequest;
+import com.google.pubsub.v1.GetSubscriptionRequest;
+import com.google.pubsub.v1.ListSubscriptionsRequest;
+import com.google.pubsub.v1.ListSubscriptionsResponse;
+import com.google.pubsub.v1.ListTopicsRequest;
+import com.google.pubsub.v1.ListTopicsResponse;
+import com.google.pubsub.v1.ModifyAckDeadlineRequest;
+import com.google.pubsub.v1.PublishRequest;
+import com.google.pubsub.v1.PublishResponse;
+import com.google.pubsub.v1.PublisherGrpc;
+import com.google.pubsub.v1.PublisherGrpc.PublisherBlockingStub;
+import com.google.pubsub.v1.PubsubMessage;
+import com.google.pubsub.v1.PullRequest;
+import com.google.pubsub.v1.PullResponse;
+import com.google.pubsub.v1.ReceivedMessage;
+import com.google.pubsub.v1.SubscriberGrpc;
+import com.google.pubsub.v1.SubscriberGrpc.SubscriberBlockingStub;
+import com.google.pubsub.v1.Subscription;
+import com.google.pubsub.v1.Topic;
+
+import io.grpc.Channel;
+import io.grpc.ClientInterceptors;
+import io.grpc.ManagedChannel;
+import io.grpc.auth.ClientAuthInterceptor;
+import io.grpc.netty.GrpcSslContexts;
+import io.grpc.netty.NegotiationType;
+import io.grpc.netty.NettyChannelBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import javax.annotation.Nullable;
+
+/**
+ * A helper class for talking to Pubsub via gRPC.
+ *
+ * CAUTION: Currently uses the application default credentials and does not respect any
+ * credentials-related arguments in {@link DataflowPipelineOptions}.
+ */
+public class PubsubGrpcClient extends PubsubClient {
+ private static final String PUBSUB_ADDRESS = "pubsub.googleapis.com";
+ private static final int PUBSUB_PORT = 443;
+ // Will be needed when credentials are correctly constructed and scoped.
+ @SuppressWarnings("unused")
+  private static final List<String> PUBSUB_SCOPES =
+ Collections.singletonList("https://www.googleapis.com/auth/pubsub");
+ private static final int LIST_BATCH_SIZE = 1000;
+
+ private static final int DEFAULT_TIMEOUT_S = 15;
+
+ private static class PubsubGrpcClientFactory implements PubsubClientFactory {
+ @Override
+ public PubsubClient newClient(
+ @Nullable String timestampLabel, @Nullable String idLabel, DataflowPipelineOptions options)
+ throws IOException {
+ ManagedChannel channel = NettyChannelBuilder
+ .forAddress(PUBSUB_ADDRESS, PUBSUB_PORT)
+ .negotiationType(NegotiationType.TLS)
+ .sslContext(GrpcSslContexts.forClient().ciphers(null).build())
+ .build();
+ // TODO: GcpOptions needs to support building com.google.auth.oauth2.Credentials from the
+ // various command line options. It currently only supports the older
+ // com.google.api.client.auth.oauth2.Credentials.
+ GoogleCredentials credentials = GoogleCredentials.getApplicationDefault();
+ return new PubsubGrpcClient(timestampLabel,
+ idLabel,
+ DEFAULT_TIMEOUT_S,
+ channel,
+ credentials,
+ null /* publisher stub */,
+ null /* subscriber stub */);
+ }
+
+ @Override
+ public String getKind() {
+ return "Grpc";
+ }
+ }
+
+ /**
+   * Factory for creating Pubsub clients using gRPC transport.
+ */
+ public static final PubsubClientFactory FACTORY = new PubsubGrpcClientFactory();
+
+ /**
+ * Timeout for grpc calls (in s).
+ */
+ private final int timeoutSec;
+
+ /**
+ * Underlying netty channel, or {@literal null} if closed.
+ */
+ @Nullable
+ private ManagedChannel publisherChannel;
+
+ /**
+ * Credentials determined from options and environment.
+ */
+ private final GoogleCredentials credentials;
+
+ /**
+ * Label to use for custom timestamps, or {@literal null} if should use Pubsub publish time
+ * instead.
+ */
+ @Nullable
+ private final String timestampLabel;
+
+ /**
+ * Label to use for custom ids, or {@literal null} if should use Pubsub provided ids.
+ */
+ @Nullable
+ private final String idLabel;
+
+ /**
+ * Cached stubs, or null if not cached.
+ */
+ @Nullable
+ private PublisherGrpc.PublisherBlockingStub cachedPublisherStub;
+  @Nullable
+  private SubscriberGrpc.SubscriberBlockingStub cachedSubscriberStub;
+
+ @VisibleForTesting
+ PubsubGrpcClient(
+ @Nullable String timestampLabel,
+ @Nullable String idLabel,
+ int timeoutSec,
+ ManagedChannel publisherChannel,
+ GoogleCredentials credentials,
+ PublisherGrpc.PublisherBlockingStub cachedPublisherStub,
+ SubscriberGrpc.SubscriberBlockingStub cachedSubscriberStub) {
+ this.timestampLabel = timestampLabel;
+ this.idLabel = idLabel;
+ this.timeoutSec = timeoutSec;
+ this.publisherChannel = publisherChannel;
+ this.credentials = credentials;
+ this.cachedPublisherStub = cachedPublisherStub;
+ this.cachedSubscriberStub = cachedSubscriberStub;
+ }
+
+ /**
+ * Gracefully close the underlying netty channel.
+ */
+ @Override
+ public void close() {
+ if (publisherChannel == null) {
+ // Already closed.
+ return;
+ }
+ // Can gc the underlying stubs.
+ cachedPublisherStub = null;
+ cachedSubscriberStub = null;
+ // Mark the client as having been closed before going further
+ // in case we have an exception from the channel.
+ ManagedChannel publisherChannel = this.publisherChannel;
+ this.publisherChannel = null;
+ // Gracefully shutdown the channel.
+ publisherChannel.shutdown();
+ if (timeoutSec > 0) {
+ try {
+ publisherChannel.awaitTermination(timeoutSec, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ // Ignore.
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
+
+ /**
+   * Return channel with interceptor for attaching credentials.
+ */
+ private Channel newChannel() throws IOException {
+ checkState(publisherChannel != null, "PubsubGrpcClient has been closed");
+ ClientAuthInterceptor interceptor =
+ new ClientAuthInterceptor(credentials, Executors.newSingleThreadExecutor());
+ return ClientInterceptors.intercept(publisherChannel, interceptor);
+ }
+
+ /**
+ * Return a stub for making a publish request with a timeout.
+ */
+ private PublisherBlockingStub publisherStub() throws IOException {
+ if (cachedPublisherStub == null) {
+ cachedPublisherStub = PublisherGrpc.newBlockingStub(newChannel());
+ }
+ if (timeoutSec > 0) {
+ return cachedPublisherStub.withDeadlineAfter(timeoutSec, TimeUnit.SECONDS);
+ } else {
+ return cachedPublisherStub;
+ }
+ }
+
+ /**
+ * Return a stub for making a subscribe request with a timeout.
+ */
+ private SubscriberBlockingStub subscriberStub() throws IOException {
+ if (cachedSubscriberStub == null) {
+ cachedSubscriberStub = SubscriberGrpc.newBlockingStub(newChannel());
+ }
+ if (timeoutSec > 0) {
+ return cachedSubscriberStub.withDeadlineAfter(timeoutSec, TimeUnit.SECONDS);
+ } else {
+ return cachedSubscriberStub;
+ }
+ }
+
+ @Override
+  public int publish(TopicPath topic, List<OutgoingMessage> outgoingMessages)
+ throws IOException {
+ PublishRequest.Builder request = PublishRequest.newBuilder()
+ .setTopic(topic.getPath());
+ for (OutgoingMessage outgoingMessage : outgoingMessages) {
+ PubsubMessage.Builder message =
+ PubsubMessage.newBuilder()
+ .setData(ByteString.copyFrom(outgoingMessage.elementBytes));
+
+ if (timestampLabel != null) {
+ message.getMutableAttributes()
+ .put(timestampLabel, String.valueOf(outgoingMessage.timestampMsSinceEpoch));
+ }
+
+ if (idLabel != null && !Strings.isNullOrEmpty(outgoingMessage.recordId)) {
+ message.getMutableAttributes().put(idLabel, outgoingMessage.recordId);
+ }
+
+ request.addMessages(message);
+ }
+
+ PublishResponse response = publisherStub().publish(request.build());
+ return response.getMessageIdsCount();
+ }
+
+ @Override
+  public List<IncomingMessage> pull(
+ long requestTimeMsSinceEpoch,
+ SubscriptionPath subscription,
+ int batchSize,
+ boolean returnImmediately) throws IOException {
+ PullRequest request = PullRequest.newBuilder()
+ .setSubscription(subscription.getPath())
+ .setReturnImmediately(returnImmediately)
+ .setMaxMessages(batchSize)
+ .build();
+ PullResponse response = subscriberStub().pull(request);
+ if (response.getReceivedMessagesCount() == 0) {
+ return ImmutableList.of();
+ }
+    List<IncomingMessage> incomingMessages =
+        new ArrayList<>(response.getReceivedMessagesCount());
+ for (ReceivedMessage message : response.getReceivedMessagesList()) {
+ PubsubMessage pubsubMessage = message.getMessage();
+      @Nullable Map<String, String> attributes = pubsubMessage.getAttributes();
+
+ // Payload.
+ byte[] elementBytes = pubsubMessage.getData().toByteArray();
+
+ // Timestamp.
+ String pubsubTimestampString = null;
+ Timestamp timestampProto = pubsubMessage.getPublishTime();
+ if (timestampProto != null) {
+        // The timestamp proto is (seconds, nanos) since epoch; convert to ms since epoch.
+        pubsubTimestampString = String.valueOf(timestampProto.getSeconds() * 1000L
+                                               + timestampProto.getNanos() / 1000000L);
+ }
+ long timestampMsSinceEpoch =
+ extractTimestamp(timestampLabel, pubsubTimestampString, attributes);
+
+ // Ack id.
+ String ackId = message.getAckId();
+ checkState(!Strings.isNullOrEmpty(ackId));
+
+ // Record id, if any.
+ @Nullable String recordId = null;
+ if (idLabel != null && attributes != null) {
+ recordId = attributes.get(idLabel);
+ }
+ if (Strings.isNullOrEmpty(recordId)) {
+ // Fall back to the Pubsub provided message id.
+ recordId = pubsubMessage.getMessageId();
+ }
+
+ incomingMessages.add(new IncomingMessage(elementBytes, timestampMsSinceEpoch,
+ requestTimeMsSinceEpoch, ackId, recordId));
+ }
+ return incomingMessages;
+ }
+
+ @Override
+  public void acknowledge(SubscriptionPath subscription, List<String> ackIds)
+ throws IOException {
+ AcknowledgeRequest request = AcknowledgeRequest.newBuilder()
+ .setSubscription(subscription.getPath())
+ .addAllAckIds(ackIds)
+ .build();
+ subscriberStub().acknowledge(request); // ignore Empty result.
+ }
+
+ @Override
+ public void modifyAckDeadline(
+      SubscriptionPath subscription, List<String> ackIds, int deadlineSeconds)
+ throws IOException {
+ ModifyAckDeadlineRequest request =
+ ModifyAckDeadlineRequest.newBuilder()
+ .setSubscription(subscription.getPath())
+ .addAllAckIds(ackIds)
+ .setAckDeadlineSeconds(deadlineSeconds)
+ .build();
+ subscriberStub().modifyAckDeadline(request); // ignore Empty result.
+ }
+
+ @Override
+ public void createTopic(TopicPath topic) throws IOException {
+ Topic request = Topic.newBuilder()
+ .setName(topic.getPath())
+ .build();
+ publisherStub().createTopic(request); // ignore Topic result.
+ }
+
+ @Override
+ public void deleteTopic(TopicPath topic) throws IOException {
+ DeleteTopicRequest request = DeleteTopicRequest.newBuilder()
+ .setTopic(topic.getPath())
+ .build();
+ publisherStub().deleteTopic(request); // ignore Empty result.
+ }
+
+ @Override
+  public List<TopicPath> listTopics(ProjectPath project) throws IOException {
+ ListTopicsRequest.Builder request =
+ ListTopicsRequest.newBuilder()
+ .setProject(project.getPath())
+ .setPageSize(LIST_BATCH_SIZE);
+ ListTopicsResponse response = publisherStub().listTopics(request.build());
+ if (response.getTopicsCount() == 0) {
+ return ImmutableList.of();
+ }
+    List<TopicPath> topics = new ArrayList<>(response.getTopicsCount());
+ while (true) {
+ for (Topic topic : response.getTopicsList()) {
+ topics.add(topicPathFromPath(topic.getName()));
+ }
+ if (response.getNextPageToken().isEmpty()) {
+ break;
+ }
+ request.setPageToken(response.getNextPageToken());
+ response = publisherStub().listTopics(request.build());
+ }
+ return topics;
+ }
+
+ @Override
+ public void createSubscription(
+ TopicPath topic, SubscriptionPath subscription,
+ int ackDeadlineSeconds) throws IOException {
+ Subscription request = Subscription.newBuilder()
+ .setTopic(topic.getPath())
+ .setName(subscription.getPath())
+ .setAckDeadlineSeconds(ackDeadlineSeconds)
+ .build();
+ subscriberStub().createSubscription(request); // ignore Subscription result.
+ }
+
+ @Override
+ public void deleteSubscription(SubscriptionPath subscription) throws IOException {
+ DeleteSubscriptionRequest request =
+ DeleteSubscriptionRequest.newBuilder()
+ .setSubscription(subscription.getPath())
+ .build();
+ subscriberStub().deleteSubscription(request); // ignore Empty result.
+ }
+
+ @Override
+  public List<SubscriptionPath> listSubscriptions(ProjectPath project, TopicPath topic)
+ throws IOException {
+ ListSubscriptionsRequest.Builder request =
+ ListSubscriptionsRequest.newBuilder()
+ .setProject(project.getPath())
+ .setPageSize(LIST_BATCH_SIZE);
+ ListSubscriptionsResponse response = subscriberStub().listSubscriptions(request.build());
+ if (response.getSubscriptionsCount() == 0) {
+ return ImmutableList.of();
+ }
+    List<SubscriptionPath> subscriptions = new ArrayList<>(response.getSubscriptionsCount());
+ while (true) {
+ for (Subscription subscription : response.getSubscriptionsList()) {
+ if (subscription.getTopic().equals(topic.getPath())) {
+ subscriptions.add(subscriptionPathFromPath(subscription.getName()));
+ }
+ }
+ if (response.getNextPageToken().isEmpty()) {
+ break;
+ }
+ request.setPageToken(response.getNextPageToken());
+ response = subscriberStub().listSubscriptions(request.build());
+ }
+ return subscriptions;
+ }
+
+ @Override
+ public int ackDeadlineSeconds(SubscriptionPath subscription) throws IOException {
+ GetSubscriptionRequest request =
+ GetSubscriptionRequest.newBuilder()
+ .setSubscription(subscription.getPath())
+ .build();
+ Subscription response = subscriberStub().getSubscription(request);
+ return response.getAckDeadlineSeconds();
+ }
+
+ @Override
+ public boolean isEOF() {
+ return false;
+ }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java
new file mode 100644
index 0000000000..df8c1a72e6
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java
@@ -0,0 +1,403 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.api.client.util.Clock;
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
+
+/**
+ * A (partial) implementation of {@link PubsubClient} for use by unit tests. Only suitable for
+ * testing {@link #publish}, {@link #pull}, {@link #acknowledge} and {@link #modifyAckDeadline}
+ * methods. Relies on statics to mimic the Pubsub service, though we try to hide that.
+ */
+public class PubsubTestClient extends PubsubClient {
+ /**
+ * Mimic the state of the simulated Pubsub 'service'.
+ *
+ * Note that the {@link PubsubTestClientFactory} is serialized/deserialized even when running
+ * test pipelines. Meanwhile it is valid for multiple {@link PubsubTestClient}s to be created
+ * from the same client factory and run in parallel. Thus we can't enforce aliasing of the
+ * following data structures over all clients and must resort to a static.
+ */
+ private static class State {
+ /**
+     * True if the service has been primed for a test but not yet validated.
+ */
+ boolean isActive;
+
+ /**
+ * Publish mode only: Only publish calls for this topic are allowed.
+ */
+ @Nullable
+ TopicPath expectedTopic;
+
+ /**
+     * Publish mode only: Messages yet to be seen in a {@link #publish} call.
+ */
+ @Nullable
+    Set<OutgoingMessage> remainingExpectedOutgoingMessages;
+
+ /**
+ * Publish mode only: Messages which should throw when first sent to simulate transient publish
+ * failure.
+ */
+ @Nullable
+    Set<OutgoingMessage> remainingFailingOutgoingMessages;
+
+ /**
+ * Pull mode only: Clock from which to get current time.
+ */
+ @Nullable
+ Clock clock;
+
+ /**
+ * Pull mode only: Only pull calls for this subscription are allowed.
+ */
+ @Nullable
+ SubscriptionPath expectedSubscription;
+
+ /**
+     * Pull mode only: ACK timeout (in seconds) to simulate.
+ */
+ int ackTimeoutSec;
+
+ /**
+ * Pull mode only: Messages waiting to be received by a {@link #pull} call.
+ */
+ @Nullable
+    List<IncomingMessage> remainingPendingIncomingMessages;
+
+ /**
+ * Pull mode only: Messages which have been returned from a {@link #pull} call and
+ * not yet ACKed by an {@link #acknowledge} call.
+ */
+ @Nullable
+    Map<String, IncomingMessage> pendingAckIncomingMessages;
+
+ /**
+     * Pull mode only: When the above messages are due to have their ACK deadlines
+     * expire (ms since epoch).
+ */
+ @Nullable
+    Map<String, Long> ackDeadline;
+ }
+
+ private static final State STATE = new State();
+
+ /** Closing the factory will validate all expected messages were processed. */
+ public interface PubsubTestClientFactory extends PubsubClientFactory, Closeable {
+ }
+
+ /**
+ * Return a factory for testing publishers. Only one factory may be in-flight at a time.
+ * The factory must be closed when the test is complete, at which point final validation will
+ * occur.
+ */
+ public static PubsubTestClientFactory createFactoryForPublish(
+ final TopicPath expectedTopic,
+      final Iterable<OutgoingMessage> expectedOutgoingMessages,
+      final Iterable<OutgoingMessage> failingOutgoingMessages) {
+ synchronized (STATE) {
+ checkState(!STATE.isActive, "Test still in flight");
+ STATE.expectedTopic = expectedTopic;
+ STATE.remainingExpectedOutgoingMessages = Sets.newHashSet(expectedOutgoingMessages);
+ STATE.remainingFailingOutgoingMessages = Sets.newHashSet(failingOutgoingMessages);
+ STATE.isActive = true;
+ }
+ return new PubsubTestClientFactory() {
+ @Override
+ public PubsubClient newClient(
+ @Nullable String timestampLabel, @Nullable String idLabel,
+ DataflowPipelineOptions options)
+ throws IOException {
+ return new PubsubTestClient();
+ }
+
+ @Override
+ public String getKind() {
+ return "PublishTest";
+ }
+
+ @Override
+ public void close() {
+ synchronized (STATE) {
+ checkState(STATE.isActive, "No test still in flight");
+ checkState(STATE.remainingExpectedOutgoingMessages.isEmpty(),
+ "Still waiting for %s messages to be published",
+ STATE.remainingExpectedOutgoingMessages.size());
+ STATE.isActive = false;
+ STATE.remainingExpectedOutgoingMessages = null;
+ }
+ }
+ };
+ }
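+
+  // A minimal publish-mode sketch ("topic", "expected" and "failing" are assumed to be
+  // prepared by the test):
+  //
+  //   try (PubsubTestClientFactory factory =
+  //       PubsubTestClient.createFactoryForPublish(topic, expected, failing)) {
+  //     // ... run the code under test, publishing via factory.newClient(...) ...
+  //   }  // close() validates that every expected message was published.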
+
+ /**
+ * Return a factory for testing subscribers. Only one factory may be in-flight at a time.
+   * The factory must be closed when the test is complete, at which point final validation
+   * will occur.
+ */
+ public static PubsubTestClientFactory createFactoryForPull(
+ final Clock clock,
+ final SubscriptionPath expectedSubscription,
+ final int ackTimeoutSec,
+      final Iterable<IncomingMessage> expectedIncomingMessages) {
+ synchronized (STATE) {
+ checkState(!STATE.isActive, "Test still in flight");
+ STATE.clock = clock;
+ STATE.expectedSubscription = expectedSubscription;
+ STATE.ackTimeoutSec = ackTimeoutSec;
+ STATE.remainingPendingIncomingMessages = Lists.newArrayList(expectedIncomingMessages);
+ STATE.pendingAckIncomingMessages = new HashMap<>();
+ STATE.ackDeadline = new HashMap<>();
+ STATE.isActive = true;
+ }
+ return new PubsubTestClientFactory() {
+ @Override
+ public PubsubClient newClient(
+ @Nullable String timestampLabel, @Nullable String idLabel,
+ DataflowPipelineOptions options)
+ throws IOException {
+ return new PubsubTestClient();
+ }
+
+ @Override
+ public String getKind() {
+ return "PullTest";
+ }
+
+ @Override
+ public void close() {
+ synchronized (STATE) {
+ checkState(STATE.isActive, "No test still in flight");
+ checkState(STATE.remainingPendingIncomingMessages.isEmpty(),
+ "Still waiting for %s messages to be pulled",
+ STATE.remainingPendingIncomingMessages.size());
+ checkState(STATE.pendingAckIncomingMessages.isEmpty(),
+ "Still waiting for %s messages to be ACKed",
+ STATE.pendingAckIncomingMessages.size());
+ checkState(STATE.ackDeadline.isEmpty(),
+ "Still waiting for %s messages to be ACKed",
+ STATE.ackDeadline.size());
+ STATE.isActive = false;
+ STATE.remainingPendingIncomingMessages = null;
+ STATE.pendingAckIncomingMessages = null;
+ STATE.ackDeadline = null;
+ }
+ }
+ };
+ }
+
+ /**
+ * Return true if in pull mode.
+ */
+ private boolean inPullMode() {
+ checkState(STATE.isActive, "No test is active");
+ return STATE.expectedSubscription != null;
+ }
+
+ /**
+ * Return true if in publish mode.
+ */
+ private boolean inPublishMode() {
+ checkState(STATE.isActive, "No test is active");
+ return STATE.expectedTopic != null;
+ }
+
+ /**
+   * For pull mode only:
+   * Track the progression of time according to the {@link Clock} passed to
+   * {@link #createFactoryForPull}. This simulates Pubsub expiring outstanding ACKs.
+ */
+ public void advance() {
+ synchronized (STATE) {
+ checkState(inPullMode(), "Can only advance in pull mode");
+      // Any messages whose ACKs timed out are available for re-pulling.
+      Iterator<Map.Entry<String, Long>> deadlineItr = STATE.ackDeadline.entrySet().iterator();
+ while (deadlineItr.hasNext()) {
+        Map.Entry<String, Long> entry = deadlineItr.next();
+ if (entry.getValue() <= STATE.clock.currentTimeMillis()) {
+ STATE.remainingPendingIncomingMessages.add(
+ STATE.pendingAckIncomingMessages.remove(entry.getKey()));
+ deadlineItr.remove();
+ }
+ }
+ }
+ }
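+
+  // Illustrative: a message pulled at t = 0 with ackTimeoutSec == 60 is moved back to the
+  // pending list by advance() once the test clock passes t = 60,000 ms, and so may be
+  // pulled (and eventually ACKed) again.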
+
+ @Override
+ public void close() {
+ }
+
+ @Override
+ public int publish(
+      TopicPath topic, List<OutgoingMessage> outgoingMessages) throws IOException {
+ synchronized (STATE) {
+ checkState(inPublishMode(), "Can only publish in publish mode");
+ checkState(topic.equals(STATE.expectedTopic), "Topic %s does not match expected %s", topic,
+ STATE.expectedTopic);
+ for (OutgoingMessage outgoingMessage : outgoingMessages) {
+ if (STATE.remainingFailingOutgoingMessages.remove(outgoingMessage)) {
+ throw new RuntimeException("Simulating failure for " + outgoingMessage);
+ }
+ checkState(STATE.remainingExpectedOutgoingMessages.remove(outgoingMessage),
+ "Unexpected outgoing message %s", outgoingMessage);
+ }
+ return outgoingMessages.size();
+ }
+ }
+
+ @Override
+  public List<IncomingMessage> pull(
+ long requestTimeMsSinceEpoch, SubscriptionPath subscription, int batchSize,
+ boolean returnImmediately) throws IOException {
+ synchronized (STATE) {
+ checkState(inPullMode(), "Can only pull in pull mode");
+ long now = STATE.clock.currentTimeMillis();
+ checkState(requestTimeMsSinceEpoch == now,
+ "Simulated time %s does not match request time %s", now, requestTimeMsSinceEpoch);
+ checkState(subscription.equals(STATE.expectedSubscription),
+ "Subscription %s does not match expected %s", subscription,
+ STATE.expectedSubscription);
+ checkState(returnImmediately, "Pull only supported if returning immediately");
+
+      List<IncomingMessage> incomingMessages = new ArrayList<>();
+      Iterator<IncomingMessage> pendItr = STATE.remainingPendingIncomingMessages.iterator();
+ while (pendItr.hasNext()) {
+ IncomingMessage incomingMessage = pendItr.next();
+ pendItr.remove();
+ IncomingMessage incomingMessageWithRequestTime =
+ incomingMessage.withRequestTime(requestTimeMsSinceEpoch);
+ incomingMessages.add(incomingMessageWithRequestTime);
+ STATE.pendingAckIncomingMessages.put(incomingMessageWithRequestTime.ackId,
+ incomingMessageWithRequestTime);
+ STATE.ackDeadline.put(incomingMessageWithRequestTime.ackId,
+ requestTimeMsSinceEpoch + STATE.ackTimeoutSec * 1000);
+ if (incomingMessages.size() >= batchSize) {
+ break;
+ }
+ }
+ return incomingMessages;
+ }
+ }
+
+ @Override
+ public void acknowledge(
+ SubscriptionPath subscription,
+      List<String> ackIds) throws IOException {
+ synchronized (STATE) {
+ checkState(inPullMode(), "Can only acknowledge in pull mode");
+ checkState(subscription.equals(STATE.expectedSubscription),
+ "Subscription %s does not match expected %s", subscription,
+ STATE.expectedSubscription);
+
+ for (String ackId : ackIds) {
+ checkState(STATE.ackDeadline.remove(ackId) != null,
+ "No message with ACK id %s is waiting for an ACK", ackId);
+ checkState(STATE.pendingAckIncomingMessages.remove(ackId) != null,
+ "No message with ACK id %s is waiting for an ACK", ackId);
+ }
+ }
+ }
+
+ @Override
+ public void modifyAckDeadline(
+      SubscriptionPath subscription, List<String> ackIds, int deadlineSeconds)
+      throws IOException {
+ synchronized (STATE) {
+ checkState(inPullMode(), "Can only modify ack deadline in pull mode");
+ checkState(subscription.equals(STATE.expectedSubscription),
+ "Subscription %s does not match expected %s", subscription,
+ STATE.expectedSubscription);
+
+ for (String ackId : ackIds) {
+ if (deadlineSeconds > 0) {
+ checkState(STATE.ackDeadline.remove(ackId) != null,
+ "No message with ACK id %s is waiting for an ACK", ackId);
+ checkState(STATE.pendingAckIncomingMessages.containsKey(ackId),
+ "No message with ACK id %s is waiting for an ACK", ackId);
+ STATE.ackDeadline.put(ackId, STATE.clock.currentTimeMillis() + deadlineSeconds * 1000);
+ } else {
+ checkState(STATE.ackDeadline.remove(ackId) != null,
+ "No message with ACK id %s is waiting for an ACK", ackId);
+ IncomingMessage message = STATE.pendingAckIncomingMessages.remove(ackId);
+ checkState(message != null, "No message with ACK id %s is waiting for an ACK", ackId);
+ STATE.remainingPendingIncomingMessages.add(message);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void createTopic(TopicPath topic) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void deleteTopic(TopicPath topic) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+  public List<TopicPath> listTopics(ProjectPath project) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void createSubscription(
+ TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void deleteSubscription(SubscriptionPath subscription) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+  public List<SubscriptionPath> listSubscriptions(
+ ProjectPath project, TopicPath topic) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int ackDeadlineSeconds(SubscriptionPath subscription) throws IOException {
+ synchronized (STATE) {
+ return STATE.ackTimeoutSec;
+ }
+ }
+
+ @Override
+ public boolean isEOF() {
+ synchronized (STATE) {
+ checkState(inPullMode(), "Can only check EOF in pull mode");
+ return STATE.remainingPendingIncomingMessages.isEmpty();
+ }
+ }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java
index d792753524..75a93b3589 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/Transport.java
@@ -112,7 +112,11 @@ private static ApiComponents apiComponentsFromUrl(String urlString) {
*
* Note: this client's endpoint is not modified by the
* {@link DataflowPipelineDebugOptions#getApiRootUrl()} option.
+ *
+ * @deprecated Use an appropriate
+ * {@link com.google.cloud.dataflow.sdk.util.PubsubClient.PubsubClientFactory}
*/
+ @Deprecated
public static Pubsub.Builder
newPubsubClient(DataflowPipelineOptions options) {
return new Pubsub.Builder(getTransport(), getJsonFactory(),
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java
index dfe5da3457..bbba56187d 100644
--- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubIOTest.java
@@ -22,26 +22,19 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
-import com.google.api.client.testing.http.FixedClock;
-import com.google.api.client.util.Clock;
-import com.google.api.services.pubsub.model.PubsubMessage;
import com.google.cloud.dataflow.sdk.transforms.display.DataflowDisplayDataEvaluator;
import com.google.cloud.dataflow.sdk.transforms.display.DisplayData;
import com.google.cloud.dataflow.sdk.transforms.display.DisplayDataEvaluator;
import org.joda.time.Duration;
-import org.joda.time.Instant;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-import java.util.HashMap;
import java.util.Set;
-import javax.annotation.Nullable;
-
/**
* Tests for PubsubIO Read and Write transforms.
*/
@@ -92,154 +85,6 @@ public void testTopicValidationTooLong() throws Exception {
.toString());
}
- /**
- * Helper function that creates a {@link PubsubMessage} with the given timestamp registered as
- * an attribute with the specified label.
- *
- *
- * If {@code label} is {@code null}, then the attributes are {@code null}.
- *
- *
- * Else, if {@code timestamp} is {@code null}, then attributes are present but have no key for
- * the label.
- */
- private static PubsubMessage messageWithTimestamp(
- @Nullable String label, @Nullable String timestamp) {
- PubsubMessage message = new PubsubMessage();
- if (label == null) {
- message.setAttributes(null);
- return message;
- }
-
-    message.setAttributes(new HashMap<String, String>());
-
- if (timestamp == null) {
- return message;
- }
-
- message.getAttributes().put(label, timestamp);
- return message;
- }
-
- /**
- * Helper function that parses the given string to a timestamp through the PubSubIO plumbing.
- */
- private static Instant parseTimestamp(@Nullable String timestamp) {
- PubsubMessage message = messageWithTimestamp("mylabel", timestamp);
- return PubsubIO.assignMessageTimestamp(message, "mylabel", Clock.SYSTEM);
- }
-
- @Test
- public void noTimestampLabelReturnsNow() {
- final long time = 987654321L;
- Instant timestamp = PubsubIO.assignMessageTimestamp(
- messageWithTimestamp(null, null), null, new FixedClock(time));
-
- assertEquals(new Instant(time), timestamp);
- }
-
- @Test
- public void timestampLabelWithNullAttributesThrowsError() {
- PubsubMessage message = messageWithTimestamp(null, null);
- thrown.expect(RuntimeException.class);
- thrown.expectMessage("PubSub message is missing a timestamp in label: myLabel");
-
- PubsubIO.assignMessageTimestamp(message, "myLabel", Clock.SYSTEM);
- }
-
- @Test
- public void timestampLabelSetWithMissingAttributeThrowsError() {
- PubsubMessage message = messageWithTimestamp("notMyLabel", "ignored");
- thrown.expect(RuntimeException.class);
- thrown.expectMessage("PubSub message is missing a timestamp in label: myLabel");
-
- PubsubIO.assignMessageTimestamp(message, "myLabel", Clock.SYSTEM);
- }
-
- @Test
- public void timestampLabelParsesMillisecondsSinceEpoch() {
- Long millis = 1446162101123L;
- assertEquals(new Instant(millis), parseTimestamp(millis.toString()));
- }
-
- @Test
- public void timestampLabelParsesRfc3339Seconds() {
- String rfc3339 = "2015-10-29T23:41:41Z";
- assertEquals(Instant.parse(rfc3339), parseTimestamp(rfc3339));
- }
-
- @Test
- public void timestampLabelParsesRfc3339Tenths() {
- String rfc3339tenths = "2015-10-29T23:41:41.1Z";
- assertEquals(Instant.parse(rfc3339tenths), parseTimestamp(rfc3339tenths));
- }
-
- @Test
- public void timestampLabelParsesRfc3339Hundredths() {
- String rfc3339hundredths = "2015-10-29T23:41:41.12Z";
- assertEquals(Instant.parse(rfc3339hundredths), parseTimestamp(rfc3339hundredths));
- }
-
- @Test
- public void timestampLabelParsesRfc3339Millis() {
- String rfc3339millis = "2015-10-29T23:41:41.123Z";
- assertEquals(Instant.parse(rfc3339millis), parseTimestamp(rfc3339millis));
- }
-
- @Test
- public void timestampLabelParsesRfc3339Micros() {
- String rfc3339micros = "2015-10-29T23:41:41.123456Z";
- assertEquals(Instant.parse(rfc3339micros), parseTimestamp(rfc3339micros));
- // Note: micros part 456/1000 is dropped.
- assertEquals(Instant.parse("2015-10-29T23:41:41.123Z"), parseTimestamp(rfc3339micros));
- }
-
- @Test
- public void timestampLabelParsesRfc3339MicrosRounding() {
- String rfc3339micros = "2015-10-29T23:41:41.123999Z";
- assertEquals(Instant.parse(rfc3339micros), parseTimestamp(rfc3339micros));
- // Note: micros part 999/1000 is dropped, not rounded up.
- assertEquals(Instant.parse("2015-10-29T23:41:41.123Z"), parseTimestamp(rfc3339micros));
- }
-
- @Test
- public void timestampLabelWithInvalidFormatThrowsError() {
- thrown.expect(NumberFormatException.class);
- parseTimestamp("not-a-timestamp");
- }
-
- @Test
- public void timestampLabelWithInvalidFormat2ThrowsError() {
- thrown.expect(NumberFormatException.class);
- parseTimestamp("null");
- }
-
- @Test
- public void timestampLabelWithInvalidFormat3ThrowsError() {
- thrown.expect(NumberFormatException.class);
- parseTimestamp("2015-10");
- }
-
- @Test
- public void timestampLabelParsesRfc3339WithSmallYear() {
- // Google and JodaTime agree on dates after 1582-10-15, when the Gregorian Calendar was adopted
- // This is therefore a "small year" until this difference is reconciled.
- String rfc3339SmallYear = "1582-10-15T01:23:45.123Z";
- assertEquals(Instant.parse(rfc3339SmallYear), parseTimestamp(rfc3339SmallYear));
- }
-
- @Test
- public void timestampLabelParsesRfc3339WithLargeYear() {
- // Year 9999 in range.
- String rfc3339LargeYear = "9999-10-29T23:41:41.123999Z";
- assertEquals(Instant.parse(rfc3339LargeYear), parseTimestamp(rfc3339LargeYear));
- }
-
- @Test
- public void timestampLabelRfc3339WithTooLargeYearThrowsError() {
- thrown.expect(NumberFormatException.class);
- // Year 10000 out of range.
- parseTimestamp("10000-10-29T23:41:41.123999Z");
- }
-
@Test
public void testReadDisplayData() {
String topic = "projects/project/topics/topic";
@@ -292,7 +137,8 @@ public void testPrimitiveWriteDisplayData() {
@Test
public void testPrimitiveReadDisplayData() {
DisplayDataEvaluator evaluator = DataflowDisplayDataEvaluator.create();
-    PubsubIO.Read.Bound<String> read = PubsubIO.Read.topic("projects/project/topics/topic");
+    PubsubIO.Read.Bound<String> read = PubsubIO.Read.topic("projects/project/topics/topic")
+ .maxNumRecords(1);
Set<DisplayData> displayData = evaluator.displayDataForPrimitiveTransforms(read);
assertThat("PubsubIO.Read should include the topic in its primitive display data",
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSinkTest.java
new file mode 100644
index 0000000000..ef95a643f9
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSinkTest.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.io;
+
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.io.PubsubUnboundedSink.RecordIdMethod;
+import com.google.cloud.dataflow.sdk.testing.CoderProperties;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.util.PubsubClient;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.OutgoingMessage;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath;
+import com.google.cloud.dataflow.sdk.util.PubsubTestClient;
+import com.google.cloud.dataflow.sdk.util.PubsubTestClient.PubsubTestClientFactory;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.hash.Hashing;
+
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Test PubsubUnboundedSink.
+ */
+@RunWith(JUnit4.class)
+public class PubsubUnboundedSinkTest {
+  private static final TopicPath TOPIC =
+      PubsubClient.topicPathFromName("testProject", "testTopic");
+ private static final String DATA = "testData";
+ private static final long TIMESTAMP = 1234L;
+ private static final String TIMESTAMP_LABEL = "timestamp";
+ private static final String ID_LABEL = "id";
+ private static final int NUM_SHARDS = 10;
+
+  private static class Stamp extends DoFn<String, String> {
+ @Override
+ public void processElement(ProcessContext c) {
+ c.outputWithTimestamp(c.element(), new Instant(TIMESTAMP));
+ }
+ }
+
+ private String getRecordId(String data) {
+ return Hashing.murmur3_128().hashBytes(data.getBytes()).toString();
+ }
+
+ @Test
+ public void saneCoder() throws Exception {
+ OutgoingMessage message = new OutgoingMessage(DATA.getBytes(), TIMESTAMP, getRecordId(DATA));
+ CoderProperties.coderDecodeEncodeEqual(PubsubUnboundedSink.CODER, message);
+ CoderProperties.coderSerializable(PubsubUnboundedSink.CODER);
+ }
+
+ @Test
+ public void sendOneMessage() throws IOException {
+    List<OutgoingMessage> outgoing =
+ ImmutableList.of(new OutgoingMessage(DATA.getBytes(), TIMESTAMP, getRecordId(DATA)));
+ int batchSize = 1;
+ int batchBytes = 1;
+ try (PubsubTestClientFactory factory =
+ PubsubTestClient.createFactoryForPublish(TOPIC, outgoing,
+            ImmutableList.<OutgoingMessage>of())) {
+      PubsubUnboundedSink<String> sink =
+ new PubsubUnboundedSink<>(factory, TOPIC, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL,
+ NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2),
+ RecordIdMethod.DETERMINISTIC);
+ TestPipeline p = TestPipeline.create();
+ p.apply(Create.of(ImmutableList.of(DATA)))
+ .apply(ParDo.of(new Stamp()))
+ .apply(sink);
+ p.run();
+ }
+ // The PubsubTestClientFactory will assert fail on close if the actual published
+ // message does not match the expected publish message.
+ }
+
+ @Test
+ public void sendMoreThanOneBatchByNumMessages() throws IOException {
+    List<OutgoingMessage> outgoing = new ArrayList<>();
+    List<String> data = new ArrayList<>();
+ int batchSize = 2;
+ int batchBytes = 1000;
+ for (int i = 0; i < batchSize * 10; i++) {
+ String str = String.valueOf(i);
+ outgoing.add(new OutgoingMessage(str.getBytes(), TIMESTAMP, getRecordId(str)));
+ data.add(str);
+ }
+ try (PubsubTestClientFactory factory =
+ PubsubTestClient.createFactoryForPublish(TOPIC, outgoing,
+            ImmutableList.<OutgoingMessage>of())) {
+      PubsubUnboundedSink<String> sink =
+ new PubsubUnboundedSink<>(factory, TOPIC, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL,
+ NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2),
+ RecordIdMethod.DETERMINISTIC);
+ TestPipeline p = TestPipeline.create();
+ p.apply(Create.of(data))
+ .apply(ParDo.of(new Stamp()))
+ .apply(sink);
+ p.run();
+ }
+ // The PubsubTestClientFactory will assert fail on close if the actual published
+ // message does not match the expected publish message.
+ }
+
+ @Test
+ public void sendMoreThanOneBatchByByteSize() throws IOException {
+    List<OutgoingMessage> outgoing = new ArrayList<>();
+    List<String> data = new ArrayList<>();
+ int batchSize = 100;
+ int batchBytes = 10;
+ int n = 0;
+ while (n < batchBytes * 10) {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < batchBytes; i++) {
+ sb.append(String.valueOf(n));
+ }
+ String str = sb.toString();
+ outgoing.add(new OutgoingMessage(str.getBytes(), TIMESTAMP, getRecordId(str)));
+ data.add(str);
+ n += str.length();
+ }
+ try (PubsubTestClientFactory factory =
+ PubsubTestClient.createFactoryForPublish(TOPIC, outgoing,
+            ImmutableList.<OutgoingMessage>of())) {
+      PubsubUnboundedSink<String> sink =
+ new PubsubUnboundedSink<>(factory, TOPIC, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL,
+ NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2),
+ RecordIdMethod.DETERMINISTIC);
+ TestPipeline p = TestPipeline.create();
+ p.apply(Create.of(data))
+ .apply(ParDo.of(new Stamp()))
+ .apply(sink);
+ p.run();
+ }
+ // The PubsubTestClientFactory will assert fail on close if the actual published
+ // message does not match the expected publish message.
+ }
+
+  // TODO: We would like to test that failed Pubsub publish calls cause the already assigned
+  // (and random) record ids to be reused. However that can't be done without the test runner
+  // supporting retrying bundles.
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java
new file mode 100644
index 0000000000..a0b05a5357
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java
@@ -0,0 +1,323 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.io;
+
+import static junit.framework.TestCase.assertFalse;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.google.api.client.util.Clock;
+
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.io.PubsubUnboundedSource.PubsubCheckpoint;
+import com.google.cloud.dataflow.sdk.io.PubsubUnboundedSource.PubsubReader;
+import com.google.cloud.dataflow.sdk.io.PubsubUnboundedSource.PubsubSource;
+import com.google.cloud.dataflow.sdk.testing.CoderProperties;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.util.CoderUtils;
+import com.google.cloud.dataflow.sdk.util.PubsubClient;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.IncomingMessage;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.SubscriptionPath;
+import com.google.cloud.dataflow.sdk.util.PubsubTestClient;
+import com.google.cloud.dataflow.sdk.util.PubsubTestClient.PubsubTestClientFactory;
+
+import com.google.common.collect.ImmutableList;
+
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * Tests {@link PubsubUnboundedSource}.
+ */
+@RunWith(JUnit4.class)
+public class PubsubUnboundedSourceTest {
+ private static final SubscriptionPath SUBSCRIPTION =
+ PubsubClient.subscriptionPathFromName("testProject", "testSubscription");
+ private static final String DATA = "testData";
+ private static final long TIMESTAMP = 1234L;
+ private static final long REQ_TIME = 6373L;
+ private static final String TIMESTAMP_LABEL = "timestamp";
+ private static final String ID_LABEL = "id";
+ private static final String ACK_ID = "testAckId";
+ private static final String RECORD_ID = "testRecordId";
+ private static final int ACK_TIMEOUT_S = 60;
+
+ private AtomicLong now;
+ private Clock clock;
+ private PubsubTestClientFactory factory;
+ private PubsubSource<String> primSource;
+
+ private void setupOneMessage(Iterable<IncomingMessage> incoming) {
+ now = new AtomicLong(REQ_TIME);
+ clock = new Clock() {
+ @Override
+ public long currentTimeMillis() {
+ return now.get();
+ }
+ };
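+ // Both the test client and the source share this manually advanced clock, letting
+ // tests move time forward deterministically to trigger ACK expiry and extension.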
+ factory = PubsubTestClient.createFactoryForPull(clock, SUBSCRIPTION, ACK_TIMEOUT_S, incoming);
+ PubsubUnboundedSource<String> source =
+ new PubsubUnboundedSource<>(clock, factory, null, null, SUBSCRIPTION, StringUtf8Coder.of(),
+ TIMESTAMP_LABEL, ID_LABEL);
+ primSource = new PubsubSource<>(source);
+ }
+
+ private void setupOneMessage() {
+ setupOneMessage(ImmutableList.of(
+ new IncomingMessage(DATA.getBytes(), TIMESTAMP, 0, ACK_ID, RECORD_ID)));
+ }
+
+ @After
+ public void after() throws IOException {
+ factory.close();
+ now = null;
+ clock = null;
+ primSource = null;
+ factory = null;
+ }
+
+ @Test
+ public void checkpointCoderIsSane() throws Exception {
+ setupOneMessage(ImmutableList.of());
+ CoderProperties.coderSerializable(primSource.getCheckpointMarkCoder());
+ // Since we only serialize/deserialize the 'notYetReadIds', and we don't want to make
+ // equals on checkpoints ignore those fields, we'll test serialization and deserialization
+ // of checkpoints in multipleReaders below.
+ }
+
+ @Test
+ public void readOneMessage() throws IOException {
+ setupOneMessage();
+ TestPipeline p = TestPipeline.create();
+ PubsubReader<String> reader = primSource.createReader(p.getOptions(), null);
+ PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient();
+ // Read one message.
+ assertTrue(reader.start());
+ assertEquals(DATA, reader.getCurrent());
+ assertFalse(reader.advance());
+ // ACK the message.
+ PubsubCheckpoint<String> checkpoint = reader.getCheckpointMark();
+ checkpoint.finalizeCheckpoint();
+ reader.close();
+ }
+
+ @Test
+ public void timeoutAckAndRereadOneMessage() throws IOException {
+ setupOneMessage();
+ TestPipeline p = TestPipeline.create();
+ PubsubReader<String> reader = primSource.createReader(p.getOptions(), null);
+ PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient();
+ assertTrue(reader.start());
+ assertEquals(DATA, reader.getCurrent());
+ // Let the ACK deadline for the above expire.
+ now.addAndGet(65 * 1000);
+ pubsubClient.advance();
+ // We'll now receive the same message again.
+ assertTrue(reader.advance());
+ assertEquals(DATA, reader.getCurrent());
+ assertFalse(reader.advance());
+ // Now ACK the message.
+ PubsubCheckpoint<String> checkpoint = reader.getCheckpointMark();
+ checkpoint.finalizeCheckpoint();
+ reader.close();
+ }
+
+ @Test
+ public void extendAck() throws IOException {
+ setupOneMessage();
+ TestPipeline p = TestPipeline.create();
+ PubsubReader<String> reader = primSource.createReader(p.getOptions(), null);
+ PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient();
+ // Pull the first message but don't take a checkpoint for it.
+ assertTrue(reader.start());
+ assertEquals(DATA, reader.getCurrent());
+ // Extend the ack
+ now.addAndGet(55 * 1000);
+ pubsubClient.advance();
+ assertFalse(reader.advance());
+ // Extend the ack again
+ now.addAndGet(25 * 1000);
+ pubsubClient.advance();
+ assertFalse(reader.advance());
+ // Now ACK the message.
+ PubsubCheckpoint<String> checkpoint = reader.getCheckpointMark();
+ checkpoint.finalizeCheckpoint();
+ reader.close();
+ }
+
+ @Test
+ public void timeoutAckExtensions() throws IOException {
+ setupOneMessage();
+ TestPipeline p = TestPipeline.create();
+ PubsubReader<String> reader = primSource.createReader(p.getOptions(), null);
+ PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient();
+ // Pull the first message but don't take a checkpoint for it.
+ assertTrue(reader.start());
+ assertEquals(DATA, reader.getCurrent());
+ // Extend the ack.
+ now.addAndGet(55 * 1000);
+ pubsubClient.advance();
+ assertFalse(reader.advance());
+ // Let the ack expire.
+ for (int i = 0; i < 3; i++) {
+ now.addAndGet(25 * 1000);
+ pubsubClient.advance();
+ assertFalse(reader.advance());
+ }
+ // Wait for resend.
+ now.addAndGet(25 * 1000);
+ pubsubClient.advance();
+ // Reread the same message.
+ assertTrue(reader.advance());
+ assertEquals(DATA, reader.getCurrent());
+ // Now ACK the message.
+ PubsubCheckpoint<String> checkpoint = reader.getCheckpointMark();
+ checkpoint.finalizeCheckpoint();
+ reader.close();
+ }
+
+ @Test
+ public void multipleReaders() throws IOException {
+ List<IncomingMessage> incoming = new ArrayList<>();
+ for (int i = 0; i < 2; i++) {
+ String data = String.format("data_%d", i);
+ String ackid = String.format("ackid_%d", i);
+ incoming.add(new IncomingMessage(data.getBytes(), TIMESTAMP, 0, ackid, RECORD_ID));
+ }
+ setupOneMessage(incoming);
+ TestPipeline p = TestPipeline.create();
+ PubsubReader<String> reader = primSource.createReader(p.getOptions(), null);
+ PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient();
+ // Consume two messages, only read one.
+ assertTrue(reader.start());
+ assertEquals("data_0", reader.getCurrent());
+
+ // Grab checkpoint.
+ PubsubCheckpoint<String> checkpoint = reader.getCheckpointMark();
+ checkpoint.finalizeCheckpoint();
+ assertEquals(1, checkpoint.notYetReadIds.size());
+ assertEquals("ackid_1", checkpoint.notYetReadIds.get(0));
+
+ // Read second message.
+ assertTrue(reader.advance());
+ assertEquals("data_1", reader.getCurrent());
+
+ // Restore from checkpoint.
+ byte[] checkpointBytes =
+ CoderUtils.encodeToByteArray(primSource.getCheckpointMarkCoder(), checkpoint);
+ checkpoint = CoderUtils.decodeFromByteArray(primSource.getCheckpointMarkCoder(),
+ checkpointBytes);
+ assertEquals(1, checkpoint.notYetReadIds.size());
+ assertEquals("ackid_1", checkpoint.notYetReadIds.get(0));
+
+ // Re-read second message.
+ reader = primSource.createReader(p.getOptions(), checkpoint);
+ assertTrue(reader.start());
+ assertEquals("data_1", reader.getCurrent());
+
+ // We are done.
+ assertFalse(reader.advance());
+
+ // ACK final message.
+ checkpoint = reader.getCheckpointMark();
+ checkpoint.finalizeCheckpoint();
+ reader.close();
+ }
+
+ private long messageNumToTimestamp(int messageNum) {
+ return TIMESTAMP + messageNum * 100;
+ }
+
+ @Test
+ public void readManyMessages() throws IOException {
+ Map<String, Integer> dataToMessageNum = new HashMap<>();
+
+ final int m = 97;
+ final int n = 10000;
+ List<IncomingMessage> incoming = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ // Make the message timestamps slightly out of order by reversing each block of m messages.
+ int messageNum = ((i / m) * m) + (m - 1) - (i % m);
+ String data = String.format("data_%d", messageNum);
+ dataToMessageNum.put(data, messageNum);
+ String recid = String.format("recordid_%d", messageNum);
+ String ackId = String.format("ackid_%d", messageNum);
+ incoming.add(new IncomingMessage(data.getBytes(), messageNumToTimestamp(messageNum), 0,
+ ackId, recid));
+ }
+ setupOneMessage(incoming);
+
+ TestPipeline p = TestPipeline.create();
+ PubsubReader<String> reader = primSource.createReader(p.getOptions(), null);
+ PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient();
+
+ for (int i = 0; i < n; i++) {
+ if (i == 0) {
+ assertTrue(reader.start());
+ } else {
+ assertTrue(reader.advance());
+ }
+ // We'll checkpoint and ack within the 2min limit.
+ now.addAndGet(30);
+ pubsubClient.advance();
+ String data = reader.getCurrent();
+ Integer messageNum = dataToMessageNum.remove(data);
+ // No duplicate messages.
+ assertNotNull(messageNum);
+ // Preserve timestamp.
+ assertEquals(new Instant(messageNumToTimestamp(messageNum)), reader.getCurrentTimestamp());
+ // Preserve record id.
+ String recid = String.format("recordid_%d", messageNum);
+ assertArrayEquals(recid.getBytes(), reader.getCurrentRecordId());
+
+ if (i % 1000 == 999) {
+ // Estimated watermark can never get ahead of actual outstanding messages.
+ long watermark = reader.getWatermark().getMillis();
+ long minOutstandingTimestamp = Long.MAX_VALUE;
+ for (Integer outstandingMessageNum : dataToMessageNum.values()) {
+ minOutstandingTimestamp =
+ Math.min(minOutstandingTimestamp, messageNumToTimestamp(outstandingMessageNum));
+ }
+ assertThat(watermark, lessThanOrEqualTo(minOutstandingTimestamp));
+ // Ack messages, but only every other finalization.
+ PubsubCheckpoint<String> checkpoint = reader.getCheckpointMark();
+ if (i % 2000 == 1999) {
+ checkpoint.finalizeCheckpoint();
+ }
+ }
+ }
+ // We are done.
+ assertFalse(reader.advance());
+ // We saw each message exactly once.
+ assertTrue(dataToMessageNum.isEmpty());
+ reader.close();
+ }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BucketingFunctionTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BucketingFunctionTest.java
new file mode 100644
index 0000000000..130b0bd39b
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/BucketingFunctionTest.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests {@link BucketingFunction}.
+ */
+@RunWith(JUnit4.class)
+public class BucketingFunctionTest {
+
+ private static final long BUCKET_WIDTH = 10;
+ private static final int SIGNIFICANT_BUCKETS = 2;
+ private static final int SIGNIFICANT_SAMPLES = 10;
+
+ private static final Combine.BinaryCombineLongFn SUM =
+ new Combine.BinaryCombineLongFn() {
+ @Override
+ public long apply(long left, long right) {
+ return left + right;
+ }
+
+ @Override
+ public long identity() {
+ return 0;
+ }
+ };
+
+ private BucketingFunction newFunc() {
+ return new BucketingFunction(BUCKET_WIDTH, SIGNIFICANT_BUCKETS, SIGNIFICANT_SAMPLES, SUM);
+ }
+
+ @Test
+ public void significantSamples() {
+ BucketingFunction f = newFunc();
+ assertFalse(f.isSignificant());
+ for (int i = 0; i < SIGNIFICANT_SAMPLES - 1; i++) {
+ f.add(0, 0);
+ assertFalse(f.isSignificant());
+ }
+ f.add(0, 0);
+ assertTrue(f.isSignificant());
+ }
+
+ @Test
+ public void significantBuckets() {
+ BucketingFunction f = newFunc();
+ assertFalse(f.isSignificant());
+ f.add(0, 0);
+ assertFalse(f.isSignificant());
+ f.add(BUCKET_WIDTH, 0);
+ assertTrue(f.isSignificant());
+ }
+
+ @Test
+ public void sum() {
+ BucketingFunction f = newFunc();
+ for (int i = 0; i < 100; i++) {
+ f.add(i, i);
+ assertEquals(((i + 1) * i) / 2, f.get());
+ }
+ }
+
+ @Test
+ public void movingSum() {
+ BucketingFunction f = newFunc();
+ int lost = 0;
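+ // Removed samples only leave the sum once their entire bucket retires, so the
+ // expected total drops in steps of BUCKET_WIDTH.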
+ for (int i = 0; i < 200; i++) {
+ f.add(i, 1);
+ if (i >= 100) {
+ f.remove(i - 100);
+ if (i % BUCKET_WIDTH == BUCKET_WIDTH - 1) {
+ lost += BUCKET_WIDTH;
+ }
+ }
+ assertEquals(i + 1 - lost, f.get());
+ }
+ }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MovingFunctionTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MovingFunctionTest.java
new file mode 100644
index 0000000000..998d2ba0dc
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/MovingFunctionTest.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests {@link MovingFunction}.
+ */
+@RunWith(JUnit4.class)
+public class MovingFunctionTest {
+
+ private static final long SAMPLE_PERIOD = 100;
+ private static final long SAMPLE_UPDATE = 10;
+ private static final int SIGNIFICANT_BUCKETS = 2;
+ private static final int SIGNIFICANT_SAMPLES = 10;
+
+ private static final Combine.BinaryCombineLongFn SUM =
+ new Combine.BinaryCombineLongFn() {
+ @Override
+ public long apply(long left, long right) {
+ return left + right;
+ }
+
+ @Override
+ public long identity() {
+ return 0;
+ }
+ };
+
+ private MovingFunction newFunc() {
+ return new MovingFunction(SAMPLE_PERIOD, SAMPLE_UPDATE, SIGNIFICANT_BUCKETS,
+ SIGNIFICANT_SAMPLES, SUM);
+ }
+
+ @Test
+ public void significantSamples() {
+ MovingFunction f = newFunc();
+ assertFalse(f.isSignificant());
+ for (int i = 0; i < SIGNIFICANT_SAMPLES - 1; i++) {
+ f.add(0, 0);
+ assertFalse(f.isSignificant());
+ }
+ f.add(0, 0);
+ assertTrue(f.isSignificant());
+ }
+
+ @Test
+ public void significantBuckets() {
+ MovingFunction f = newFunc();
+ assertFalse(f.isSignificant());
+ f.add(0, 0);
+ assertFalse(f.isSignificant());
+ f.add(SAMPLE_UPDATE, 0);
+ assertTrue(f.isSignificant());
+ }
+
+ @Test
+ public void sum() {
+ MovingFunction f = newFunc();
+ for (int i = 0; i < SAMPLE_PERIOD; i++) {
+ f.add(i, i);
+ assertEquals(((i + 1) * i) / 2, f.get(i));
+ }
+ }
+
+ @Test
+ public void movingSum() {
+ MovingFunction f = newFunc();
+ int lost = 0;
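+ // Samples age out of the moving window one SAMPLE_UPDATE-wide slot at a time, so
+ // the expected total drops in steps of SAMPLE_UPDATE.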
+ for (int i = 0; i < SAMPLE_PERIOD * 2; i++) {
+ f.add(i, 1);
+ if (i >= SAMPLE_PERIOD) {
+ if (i % SAMPLE_UPDATE == 0) {
+ lost += SAMPLE_UPDATE;
+ }
+ }
+ assertEquals(i + 1 - lost, f.get(i));
+ }
+ }
+
+ @Test
+ public void jumpingSum() {
+ MovingFunction f = newFunc();
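+ // Two samples at opposite edges of the window: both are visible just inside it,
+ // only the second once the window slides partway past, and neither once the
+ // window has moved fully beyond them.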
+ f.add(0, 1);
+ f.add(SAMPLE_PERIOD - 1, 1);
+ assertEquals(2, f.get(SAMPLE_PERIOD - 1));
+ assertEquals(1, f.get(SAMPLE_PERIOD + 3 * SAMPLE_UPDATE));
+ assertEquals(0, f.get(SAMPLE_PERIOD * 2));
+ }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PubsubApiaryClientTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PubsubApiaryClientTest.java
new file mode 100644
index 0000000000..f329347684
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PubsubApiaryClientTest.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.api.services.pubsub.Pubsub;
+import com.google.api.services.pubsub.model.PublishRequest;
+import com.google.api.services.pubsub.model.PublishResponse;
+import com.google.api.services.pubsub.model.PubsubMessage;
+import com.google.api.services.pubsub.model.PullRequest;
+import com.google.api.services.pubsub.model.PullResponse;
+import com.google.api.services.pubsub.model.ReceivedMessage;
+
+import com.google.cloud.dataflow.sdk.util.PubsubClient.IncomingMessage;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.OutgoingMessage;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.SubscriptionPath;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Tests for {@link PubsubApiaryClient}.
+ */
+public class PubsubApiaryClientTest {
+ private Pubsub mockPubsub;
+ private PubsubClient client;
+
+ private static final TopicPath TOPIC =
+ PubsubClient.topicPathFromName("testProject", "testTopic");
+ private static final SubscriptionPath SUBSCRIPTION =
+ PubsubClient.subscriptionPathFromName("testProject", "testSubscription");
+ private static final long REQ_TIME = 1234L;
+ private static final long PUB_TIME = 3456L;
+ private static final long MESSAGE_TIME = 6789L;
+ private static final String TIMESTAMP_LABEL = "timestamp";
+ private static final String ID_LABEL = "id";
+ private static final String MESSAGE_ID = "testMessageId";
+ private static final String DATA = "testData";
+ private static final String RECORD_ID = "testRecordId";
+ private static final String ACK_ID = "testAckId";
+
+ @Before
+ public void setup() throws IOException {
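+ // RETURNS_DEEP_STUBS lets the chained projects().subscriptions().pull(...) and
+ // projects().topics().publish(...) calls below be stubbed with a single when(...).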
+ mockPubsub = Mockito.mock(Pubsub.class, Mockito.RETURNS_DEEP_STUBS);
+ client = new PubsubApiaryClient(TIMESTAMP_LABEL, ID_LABEL, mockPubsub);
+ }
+
+ @After
+ public void teardown() throws IOException {
+ client.close();
+ client = null;
+ mockPubsub = null;
+ }
+
+ @Test
+ public void pullOneMessage() throws IOException {
+ String expectedSubscription = SUBSCRIPTION.getPath();
+ PullRequest expectedRequest =
+ new PullRequest().setReturnImmediately(true).setMaxMessages(10);
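+ // The expected message carries the custom timestamp and record-id attributes that
+ // the client should surface instead of the Pubsub publish time and message id.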
+ PubsubMessage expectedPubsubMessage = new PubsubMessage()
+ .setMessageId(MESSAGE_ID)
+ .encodeData(DATA.getBytes())
+ .setPublishTime(String.valueOf(PUB_TIME))
+ .setAttributes(
+ ImmutableMap.of(TIMESTAMP_LABEL, String.valueOf(MESSAGE_TIME),
+ ID_LABEL, RECORD_ID));
+ ReceivedMessage expectedReceivedMessage =
+ new ReceivedMessage().setMessage(expectedPubsubMessage)
+ .setAckId(ACK_ID);
+ PullResponse expectedResponse =
+ new PullResponse().setReceivedMessages(ImmutableList.of(expectedReceivedMessage));
+ Mockito.when(mockPubsub.projects()
+ .subscriptions()
+ .pull(expectedSubscription, expectedRequest)
+ .execute())
+ .thenReturn(expectedResponse);
+ List<IncomingMessage> actualMessages = client.pull(REQ_TIME, SUBSCRIPTION, 10, true);
+ assertEquals(1, actualMessages.size());
+ IncomingMessage actualMessage = actualMessages.get(0);
+ assertEquals(ACK_ID, actualMessage.ackId);
+ assertEquals(DATA, new String(actualMessage.elementBytes));
+ assertEquals(RECORD_ID, actualMessage.recordId);
+ assertEquals(REQ_TIME, actualMessage.requestTimeMsSinceEpoch);
+ assertEquals(MESSAGE_TIME, actualMessage.timestampMsSinceEpoch);
+ }
+
+ @Test
+ public void publishOneMessage() throws IOException {
+ String expectedTopic = TOPIC.getPath();
+ PubsubMessage expectedPubsubMessage = new PubsubMessage()
+ .encodeData(DATA.getBytes())
+ .setAttributes(
+ ImmutableMap.of(TIMESTAMP_LABEL, String.valueOf(MESSAGE_TIME),
+ ID_LABEL, RECORD_ID));
+ PublishRequest expectedRequest = new PublishRequest()
+ .setMessages(ImmutableList.of(expectedPubsubMessage));
+ PublishResponse expectedResponse = new PublishResponse()
+ .setMessageIds(ImmutableList.of(MESSAGE_ID));
+ Mockito.when(mockPubsub.projects()
+ .topics()
+ .publish(expectedTopic, expectedRequest)
+ .execute())
+ .thenReturn(expectedResponse);
+ OutgoingMessage actualMessage = new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME, RECORD_ID);
+ int n = client.publish(TOPIC, ImmutableList.of(actualMessage));
+ assertEquals(1, n);
+ }
+}
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PubsubClientTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PubsubClientTest.java
new file mode 100644
index 0000000000..44ed022f36
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/PubsubClientTest.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.cloud.dataflow.sdk.util;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.dataflow.sdk.util.PubsubClient.ProjectPath;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.SubscriptionPath;
+import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.util.Map;
+
+/**
+ * Tests for helper classes and methods in {@link PubsubClient}.
+ */
+public class PubsubClientTest {
+ @Rule
+ public ExpectedException thrown = ExpectedException.none();
+
+ //
+ // Timestamp handling
+ //
+
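+ // Helper: extracts the timestamp stored under the attribute key "myLabel".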
+ private long parse(String timestamp) {
+ Map<String, String> map = ImmutableMap.of("myLabel", timestamp);
+ return PubsubClient.extractTimestamp("myLabel", null, map);
+ }
+
+ private void roundTripRfc339(String timestamp) {
+ assertEquals(Instant.parse(timestamp).getMillis(), parse(timestamp));
+ }
+
+ private void truncatedRfc339(String timestamp, String truncatedTimestamp) {
+ assertEquals(Instant.parse(truncatedTimestamp).getMillis(), parse(timestamp));
+ }
+
+ @Test
+ public void noTimestampLabelReturnsPubsubPublish() {
+ final long time = 987654321L;
+ long timestamp = PubsubClient.extractTimestamp(null, String.valueOf(time), null);
+ assertEquals(time, timestamp);
+ }
+
+ @Test
+ public void noTimestampLabelAndInvalidPubsubPublishThrowsError() {
+ thrown.expect(NumberFormatException.class);
+ PubsubClient.extractTimestamp(null, "not-a-date", null);
+ }
+
+ @Test
+ public void timestampLabelWithNullAttributesThrowsError() {
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("PubSub message is missing a value for timestamp label myLabel");
+ PubsubClient.extractTimestamp("myLabel", null, null);
+ }
+
+ @Test
+ public void timestampLabelSetWithMissingAttributeThrowsError() {
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("PubSub message is missing a value for timestamp label myLabel");
+ Map