perElement());
+
+ return wordCounts;
+ }
+ }
+
+ /**
+ * Options supported by {@link WordCount}.
+ *
+ * <p>Concept #4: Defining your own configuration options. Here, you can add your own arguments
+ * to be processed by the command-line parser, and specify default values for them. You can then
+ * access the options' values in your pipeline code.
+ *
+ * <p>Inherits standard configuration options.
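+ *
+ * <p>For example, a hypothetical invocation (the flag values below are
+ * illustrative placeholders, not part of this change) could supply:
+ * {@code --inputFile=gs://my-bucket/input.txt --output=/tmp/wordcounts}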
+ */
+ public static interface WordCountOptions extends PipelineOptions {
+ @Description("Path of the file to read from")
+ @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt")
+ String getInputFile();
+ void setInputFile(String value);
+
+ @Description("Path of the file to write to")
+ String getOutput();
+ void setOutput(String value);
+ }
+
+ public static void main(String[] args) {
+ WordCountOptions options = PipelineOptionsFactory.fromArgs(args).withValidation()
+ .as(WordCountOptions.class);
+ Pipeline p = Pipeline.create(options);
+
+ // Concepts #2 and #3: Our pipeline applies the composite CountWords transform, and passes the
+ // static FormatAsTextFn() to the MapElements transform.
+ // TODO: remove withoutValidation once possible
+ p.apply("ReadLines", TextIO.Read.from(options.getInputFile()).withoutValidation())
+ .apply(new CountWords())
+ .apply(MapElements.via(new FormatAsTextFn()))
+ .apply("WriteCounts", TextIO.Write.to(options.getOutput()));
+
+ p.run();
+ }
+}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
index 6a3edd7e3cd5..6f5ce5e49bfb 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
@@ -21,19 +21,14 @@
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertThat;
+import org.apache.beam.runners.spark.examples.WordCount;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.transforms.Aggregator;
-import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.Sum;
-import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import com.google.common.collect.ImmutableSet;
@@ -48,7 +43,6 @@
import java.util.Arrays;
import java.util.List;
import java.util.Set;
-import java.util.regex.Pattern;
/**
* Simple word count test.
@@ -68,7 +62,8 @@ public void testInMem() throws Exception {
Pipeline p = Pipeline.create(options);
PCollection<String> inputWords = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder
.of()));
- PCollection<String> output = inputWords.apply(new CountWords());
+ PCollection<String> output = inputWords.apply(new WordCount.CountWords())
+ .apply(MapElements.via(new WordCount.FormatAsTextFn()));
PAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET);
@@ -86,7 +81,8 @@ public void testOutputFile() throws Exception {
Pipeline p = Pipeline.create(options);
PCollection<String> inputWords = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder
.of()));
- PCollection<String> output = inputWords.apply(new CountWords());
+ PCollection<String> output = inputWords.apply(new WordCount.CountWords())
+ .apply(MapElements.via(new WordCount.FormatAsTextFn()));
File outputFile = testFolder.newFile();
output.apply("WriteCounts", TextIO.Write.to(outputFile.getAbsolutePath()).withoutSharding());
@@ -97,64 +93,4 @@ public void testOutputFile() throws Exception {
assertThat(Sets.newHashSet(FileUtils.readLines(outputFile)),
containsInAnyOrder(EXPECTED_COUNT_SET.toArray()));
}
-
- /**
- * A DoFn that tokenizes lines of text into individual words.
- */
- static class ExtractWordsFn extends DoFn<String, String> {
- private static final Pattern WORD_BOUNDARY = Pattern.compile("[^a-zA-Z']+");
- private final Aggregator<Long, Long> emptyLines =
- createAggregator("emptyLines", new Sum.SumLongFn());
-
- @Override
- public void processElement(ProcessContext c) {
- // Split the line into words.
- String[] words = WORD_BOUNDARY.split(c.element());
-
- // Keep track of the number of lines without any words encountered while tokenizing.
- // This aggregator is visible in the monitoring UI when run using DataflowRunner.
- if (words.length == 0) {
- emptyLines.addValue(1L);
- }
-
- // Output each word encountered into the output PCollection.
- for (String word : words) {
- if (!word.isEmpty()) {
- c.output(word);
- }
- }
- }
- }
-
- /**
- * A DoFn that converts a Word and Count into a printable string.
- */
- private static class FormatCountsFn extends DoFn<KV<String, Long>, String> {
- @Override
- public void processElement(ProcessContext c) {
- c.output(c.element().getKey() + ": " + c.element().getValue());
- }
- }
-
- /**
- * A {@link PTransform} counting words.
- */
- public static class CountWords extends PTransform<PCollection<String>, PCollection<String>> {
- @Override
- public PCollection<String> apply(PCollection<String> lines) {
-
- // Convert lines of text into individual words.
- PCollection<String> words = lines.apply(
- ParDo.of(new ExtractWordsFn()));
-
- // Count the number of times each word occurs.
- PCollection<KV<String, Long>> wordCounts =
- words.apply(Count.perElement());
-
- // Format each word and count into a printable string.
-
- return wordCounts.apply(ParDo.of(new FormatCountsFn()));
- }
-
- }
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/TfIdfTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/TfIdfTest.java
index df78338d4269..d1f8d125bdad 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/TfIdfTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/TfIdfTest.java
@@ -18,18 +18,30 @@
package org.apache.beam.runners.spark;
-import org.apache.beam.examples.complete.TfIdf;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringDelegateCoder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Keys;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.RemoveDuplicates;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.join.CoGbkResult;
+import org.apache.beam.sdk.transforms.join.CoGroupByKey;
+import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.net.URI;
import java.util.Arrays;
@@ -52,7 +64,7 @@ public void testTfIdf() throws Exception {
KV.of(new URI("x"), "a b c d"),
KV.of(new URI("y"), "a b c"),
KV.of(new URI("z"), "a m n")))
- .apply(new TfIdf.ComputeTfIdf());
+ .apply(new ComputeTfIdf());
PCollection<String> words = wordToUriAndTfIdf
.apply(Keys.create())
@@ -64,4 +76,187 @@ public void testTfIdf() throws Exception {
res.close();
}
+ /**
+ * Duplicated here to avoid a dependency on beam-examples.
+ */
+ public static class ComputeTfIdf
+ extends PTransform<PCollection<KV<URI, String>>, PCollection<KV<String, KV<URI, Double>>>> {
+ public ComputeTfIdf() { }
+
+ @Override
+ public PCollection<KV<String, KV<URI, Double>>> apply(
+ PCollection<KV<URI, String>> uriToContent) {
+
+ // Compute the total number of documents, and
+ // prepare this singleton PCollectionView for
+ // use as a side input.
+ final PCollectionView<Long> totalDocuments =
+ uriToContent
+ .apply("GetURIs", Keys.create())
+ .apply("RemoveDuplicateDocs", RemoveDuplicates.create())
+ .apply(Count.globally())
+ .apply(View.asSingleton());
+
+ // Create a collection of pairs mapping a URI to each
+ // of the words in the document associated with that URI.
+ PCollection<KV<URI, String>> uriToWords = uriToContent
+ .apply("SplitWords", ParDo.of(
+ new DoFn<KV<URI, String>, KV<URI, String>>() {
+ @Override
+ public void processElement(ProcessContext c) {
+ URI uri = c.element().getKey();
+ String line = c.element().getValue();
+ for (String word : line.split("\\W+")) {
+ // Log INFO messages when the word "love" is found.
+ if (word.toLowerCase().equals("love")) {
+ LOG.info("Found {}", word.toLowerCase());
+ }
+
+ if (!word.isEmpty()) {
+ c.output(KV.of(uri, word.toLowerCase()));
+ }
+ }
+ }
+ }));
+
+ // Compute a mapping from each word to the total
+ // number of documents in which it appears.
+ PCollection<KV<String, Long>> wordToDocCount = uriToWords
+ .apply("RemoveDuplicateWords", RemoveDuplicates.<KV<URI, String>>create())
+ .apply(Values.create())
+ .apply("CountDocs", Count.perElement());
+
+ // Compute a mapping from each URI to the total
+ // number of words in the document associated with that URI.
+ PCollection<KV<URI, Long>> uriToWordTotal = uriToWords
+ .apply("GetURIs2", Keys.create())
+ .apply("CountWords", Count.perElement());
+
+ // Count, for each (URI, word) pair, the number of
+ // occurrences of that word in the document associated
+ // with the URI.
+ PCollection<KV<KV<URI, String>, Long>> uriAndWordToCount = uriToWords
+ .apply("CountWordDocPairs", Count.<KV<URI, String>>perElement());
+
+ // Convert the above collection, a mapping from
+ // (URI, word) pairs to counts, into an isomorphic mapping
+ // from URI to (word, count) pairs, to prepare for a join
+ // by the URI key.
+ PCollection<KV<URI, KV<String, Long>>> uriToWordAndCount = uriAndWordToCount
+ .apply("ShiftKeys", ParDo.of(
+ new DoFn<KV<KV<URI, String>, Long>, KV<URI, KV<String, Long>>>() {
+ @Override
+ public void processElement(ProcessContext c) {
+ URI uri = c.element().getKey().getKey();
+ String word = c.element().getKey().getValue();
+ Long occurrences = c.element().getValue();
+ c.output(KV.of(uri, KV.of(word, occurrences)));
+ }
+ }));
+
+ // Prepare to join the mapping of URI to (word, count) pairs with
+ // the mapping of URI to total word counts, by associating
+ // each of the input PCollection<KV<URI, ...>> with
+ // a tuple tag. Each input must have the same key type, URI
+ // in this case. The type parameter of the tuple tag matches
+ // the types of the values for each collection.
+ final TupleTag<Long> wordTotalsTag = new TupleTag<Long>();
+ final TupleTag<KV<String, Long>> wordCountsTag = new TupleTag<KV<String, Long>>();
+ KeyedPCollectionTuple<URI> coGbkInput = KeyedPCollectionTuple
+ .of(wordTotalsTag, uriToWordTotal)
+ .and(wordCountsTag, uriToWordAndCount);
+
+ // Perform a CoGroupByKey (a sort of pre-join) on the prepared
+ // inputs. This yields a mapping from URI to a CoGbkResult
+ // (CoGroupByKey Result). The CoGbkResult is a mapping
+ // from the above tuple tags to the values in each input
+ // associated with a particular URI. In this case, each
+ // KV<URI, CoGbkResult> groups a URI with the total number of
+ // words in that document as well as all the (word, count)
+ // pairs for particular words.
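+ // As an illustration (using the documents from the test above): for
+ // document x ("a b c d"), the CoGbkResult would pair the word total 4
+ // with the per-word counts (a, 1), (b, 1), (c, 1) and (d, 1).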
+ PCollection<KV<URI, CoGbkResult>> uriToWordAndCountAndTotal = coGbkInput
+ .apply("CoGroupByUri", CoGroupByKey.create());
+
+ // Compute a mapping from each word to a (URI, term frequency)
+ // pair for each URI. A word's term frequency for a document
+ // is simply the number of times that word occurs in the document
+ // divided by the total number of words in the document.
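+ // For instance, each word in document x ("a b c d") occurs once among
+ // four words, giving a term frequency of 0.25 for that document.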
+ PCollection<KV<String, KV<URI, Double>>> wordToUriAndTf = uriToWordAndCountAndTotal
+ .apply("ComputeTermFrequencies", ParDo.of(
+ new DoFn<KV<URI, CoGbkResult>, KV<String, KV<URI, Double>>>() {
+ @Override
+ public void processElement(ProcessContext c) {
+ URI uri = c.element().getKey();
+ Long wordTotal = c.element().getValue().getOnly(wordTotalsTag);
+
+ for (KV<String, Long> wordAndCount
+ : c.element().getValue().getAll(wordCountsTag)) {
+ String word = wordAndCount.getKey();
+ Long wordCount = wordAndCount.getValue();
+ Double termFrequency = wordCount.doubleValue() / wordTotal.doubleValue();
+ c.output(KV.of(word, KV.of(uri, termFrequency)));
+ }
+ }
+ }));
+
+ // Compute a mapping from each word to its document frequency.
+ // A word's document frequency in a corpus is the number of
+ // documents in which the word appears divided by the total
+ // number of documents in the corpus. Note how the total number of
+ // documents is passed as a side input; the same value is
+ // presented to each invocation of the DoFn.
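+ // For instance, "a" above appears in all three documents (df = 1),
+ // while "m" appears only in document z (df = 1/3).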
+ PCollection<KV<String, Double>> wordToDf = wordToDocCount
+ .apply("ComputeDocFrequencies", ParDo
+ .withSideInputs(totalDocuments)
+ .of(new DoFn<KV<String, Long>, KV<String, Double>>() {
+ @Override
+ public void processElement(ProcessContext c) {
+ String word = c.element().getKey();
+ Long documentCount = c.element().getValue();
+ Long documentTotal = c.sideInput(totalDocuments);
+ Double documentFrequency = documentCount.doubleValue()
+ / documentTotal.doubleValue();
+
+ c.output(KV.of(word, documentFrequency));
+ }
+ }));
+
+ // Join the term frequency and document frequency
+ // collections, each keyed on the word.
+ final TupleTag<KV<URI, Double>> tfTag = new TupleTag<KV<URI, Double>>();
+ final TupleTag<Double> dfTag = new TupleTag<Double>();
+ PCollection<KV<String, CoGbkResult>> wordToUriAndTfAndDf = KeyedPCollectionTuple
+ .of(tfTag, wordToUriAndTf)
+ .and(dfTag, wordToDf)
+ .apply(CoGroupByKey.create());
+
+ // Compute a mapping from each word to a (URI, TF-IDF) score
+ // for each URI. There are a variety of definitions of TF-IDF
+ // ("term frequency - inverse document frequency") score;
+ // here we use a basic version that is the term frequency
+ // multiplied by the log of the inverse document frequency.
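+ // Note that under this definition a word appearing in every document
+ // (df = 1) scores log(1/1) = 0, so ubiquitous words like "a" above
+ // contribute nothing regardless of their term frequency.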
+ return wordToUriAndTfAndDf
+ .apply("ComputeTfIdf", ParDo.of(
+ new DoFn<KV<String, CoGbkResult>, KV<String, KV<URI, Double>>>() {
+ @Override
+ public void processElement(ProcessContext c) {
+ String word = c.element().getKey();
+ Double df = c.element().getValue().getOnly(dfTag);
+
+ for (KV<URI, Double> uriAndTf : c.element().getValue().getAll(tfTag)) {
+ URI uri = uriAndTf.getKey();
+ Double tf = uriAndTf.getValue();
+ Double tfIdf = tf * Math.log(1 / df);
+ c.output(KV.of(word, KV.of(uri, tfIdf)));
+ }
+ }
+ }));
+ }
+
+ // Instantiate the Logger, keyed by convention on the containing
+ // class (in this case ComputeTfIdf).
+ private static final Logger LOG = LoggerFactory.getLogger(ComputeTfIdf.class);
+ }
+
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
index b4268d6127c1..36d8b67fb725 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
@@ -21,10 +21,10 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import org.apache.beam.examples.WordCount;
import org.apache.beam.runners.spark.EvaluationResult;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.runners.spark.examples.WordCount;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.TextIO;
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java
index 043d506d9247..b70e090ddae7 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java
@@ -19,14 +19,15 @@
package org.apache.beam.runners.spark.translation;
import org.apache.beam.runners.spark.EvaluationResult;
-import org.apache.beam.runners.spark.SimpleWordCountTest;
import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.runners.spark.examples.WordCount;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
@@ -64,7 +65,8 @@ public void testFixed() throws Exception {
PCollection<String> windowedWords =
inputWords.apply(Window.into(FixedWindows.of(Duration.standardMinutes(1))));
- PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords());
+ PCollection<String> output = windowedWords.apply(new WordCount.CountWords())
+ .apply(MapElements.via(new WordCount.FormatAsTextFn()));
PAssert.that(output).containsInAnyOrder(EXPECTED_FIXED_SEPARATE_COUNT_SET);
@@ -85,7 +87,8 @@ public void testFixed2() throws Exception {
PCollection<String> windowedWords = inputWords
.apply(Window.into(FixedWindows.of(Duration.standardMinutes(5))));
- PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords());
+ PCollection<String> output = windowedWords.apply(new WordCount.CountWords())
+ .apply(MapElements.via(new WordCount.FormatAsTextFn()));
PAssert.that(output).containsInAnyOrder(EXPECTED_FIXED_SAME_COUNT_SET);
@@ -108,7 +111,8 @@ public void testSliding() throws Exception {
.apply(Window.into(SlidingWindows.of(Duration.standardMinutes(2))
.every(Duration.standardMinutes(1))));
- PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords());
+ PCollection<String> output = windowedWords.apply(new WordCount.CountWords())
+ .apply(MapElements.via(new WordCount.FormatAsTextFn()));
PAssert.that(output).containsInAnyOrder(EXPECTED_SLIDING_COUNT_SET);
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
index 75a702b87791..75ab2745807d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
@@ -19,14 +19,15 @@
import org.apache.beam.runners.spark.EvaluationResult;
-import org.apache.beam.runners.spark.SimpleWordCountTest;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.SparkStreamingPipelineOptions;
+import org.apache.beam.runners.spark.examples.WordCount;
import org.apache.beam.runners.spark.io.CreateStream;
import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.PCollection;
@@ -65,7 +66,8 @@ public void testRun() throws Exception {
PCollection<String> windowedWords = inputWords
.apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))));
- PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords());
+ PCollection<String> output = windowedWords.apply(new WordCount.CountWords())
+ .apply(MapElements.via(new WordCount.FormatAsTextFn()));
PAssertStreaming.assertContents(output, EXPECTED_COUNTS);
EvaluationResult res = SparkRunner.create(options).run(p);