From 28cb0fe78f38bb2aa794166fe5ae4f82b925b52d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 27 Mar 2014 13:58:41 +0800 Subject: [PATCH 1/8] add whole text files reader --- .../org/apache/spark/mllib/MLContext.scala | 68 ++++++++++++ .../input/WholeTextFileInputFormat.scala | 47 ++++++++ .../input/WholeTextFileRecordReader.scala | 72 ++++++++++++ .../WholeTextFileRecordReaderSuite.scala | 103 ++++++++++++++++++ 4 files changed, 290 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala new file mode 100644 index 0000000000000..d6c6910c0ed84 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib + +import org.apache.spark.mllib.input.WholeTextFileInputFormat +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext + +/** + * Extra functions available on SparkContext of mllib through an implicit conversion. Import + * `org.apache.spark.mllib.MLContext._` at the top of your program to use these functions. + */ +class MLContext(self: SparkContext) { + + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + *

For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, + * + *

then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + */ + def wholeTextFile(path: String): RDD[(String, String)] = { + self.newAPIHadoopFile( + path, + classOf[WholeTextFileInputFormat], + classOf[String], + classOf[String]) + } +} + +/** + * The MLContext object contains a number of implicit conversions and parameters for use with + * various mllib features. + */ +object MLContext { + implicit def sparkContextToMLContext(sc: SparkContext) = new MLContext(sc) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala new file mode 100644 index 0000000000000..28133618e3c10 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.input + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat +import org.apache.hadoop.mapreduce.RecordReader +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit + +/** + * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for + * reading whole text files. Each file is read as key-value pair, where the key is the file path and + * the value is the entire content of file. + */ + +private[mllib] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + + override def createRecordReader( + split: InputSplit, + context: TaskAttemptContext): RecordReader[String, String] = { + + new CombineFileRecordReader[String, String]( + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala new file mode 100644 index 0000000000000..1fc668810332b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.input + +import com.google.common.io.{ByteStreams, Closeables} + +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.RecordReader +import org.apache.hadoop.mapreduce.TaskAttemptContext + +/** + * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file + * out in a key-value pair, where the key is the file path and the value is the entire content of + * the file. + */ +private[mllib] class WholeTextFileRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, String] { + + private val path = split.getPath(index) + private val fs = path.getFileSystem(context.getConfiguration) + + // True means the current file has been processed, then skip it. + private var processed = false + + private val key = path.toString + private var value: String = null + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = fs.open(path) + val innerBuffer = ByteStreams.toByteArray(fileIn) + + value = new Text(innerBuffer).toString + Closeables.close(fileIn, false) + + processed = true + true + } else { + false + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala new file mode 100644 index 0000000000000..c79355fd26c6f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.input + +import java.io.DataOutputStream +import java.io.File +import java.io.FileOutputStream + +import scala.collection.immutable.IndexedSeq + +import com.google.common.io.Files + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.hadoop.io.Text + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.MLContext._ + +/** + * Tests the correctness of + * [[org.apache.spark.mllib.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary + * directory is created as fake input. Temporal storage would be deleted in the end. + */ +class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { + private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = { + val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName")) + out.write(contents, 0, contents.length) + out.close() + } + + /** + * This code will test the behaviors of WholeTextFileRecordReader based on local disk. There are + * three aspects to check: + * 1) Whether all files are read; + * 2) Whether paths are read correctly; + * 3) Does the contents be the same. + */ + test("Correctness of WholeTextFileRecordReader.") { + + val dir = Files.createTempDir() + println(s"Local disk address is ${dir.toString}.") + + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents) + } + + val res = sc.wholeTextFile(dir.toString).collect() + + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") + + for ((filename, contents) <- res) { + val shortName = filename.split('/').last + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } + + dir.delete() + } +} + +/** + * Files to be tested are defined here. 
+ */ +object WholeTextFileRecordReaderSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} From a1f1e7eb58b927e6c3ab98c12ea344c95b157eb6 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 27 Mar 2014 15:07:16 +0800 Subject: [PATCH 2/8] add two extra spaces --- .../apache/spark/mllib/input/WholeTextFileInputFormat.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala index 28133618e3c10..825a004cb088b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala @@ -40,8 +40,8 @@ private[mllib] class WholeTextFileInputFormat extends CombineFileInputFormat[Str context: TaskAttemptContext): RecordReader[String, String] = { new CombineFileRecordReader[String, String]( - split.asInstanceOf[CombineFileSplit], - context, - classOf[WholeTextFileRecordReader]) + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader]) } } From 6bdf2c2ecc44390ab7e909fe583f69f24a5a1fa2 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 1 Apr 2014 02:10:42 +0800 Subject: [PATCH 3/8] test for small local file system block size --- .../spark/mllib/input/WholeTextFileRecordReaderSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala index c79355fd26c6f..9954d3fbb2680 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala @@ -43,6 +43,9 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { sc = new SparkContext("local", "test") + + // Set the block size of local file system to test whether files are split right or not. ++ sc.hadoopConfiguration.setLong("fs.local.block.size", 32) } override def afterAll() { From d792cee57008b60bc80bdb996418fdb06927af83 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 1 Apr 2014 07:04:20 +0800 Subject: [PATCH 4/8] remove the typo character "+" --- .../spark/mllib/input/WholeTextFileRecordReaderSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala index 9954d3fbb2680..25622ddcd1fd8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala @@ -45,7 +45,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { sc = new SparkContext("local", "test") // Set the block size of local file system to test whether files are split right or not. 
-+ sc.hadoopConfiguration.setLong("fs.local.block.size", 32) + sc.hadoopConfiguration.setLong("fs.local.block.size", 32) } override def afterAll() { From cc97dca3658607ad95f5425bfd7b7488ea002176 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 2 Apr 2014 11:54:35 +0800 Subject: [PATCH 5/8] move whole text file API to Spark core --- .../scala/org/apache/spark/SparkContext.scala | 32 ++++ .../spark/api/java/JavaSparkContext.scala | 25 +++ .../input/WholeTextFileInputFormat.scala | 4 +- .../input/WholeTextFileRecordReader.scala | 4 +- .../WholeTextFileRecordReaderSuite.scala | 7 +- .../org/apache/spark/mllib/util/MLUtils.scala | 162 ------------------ 6 files changed, 64 insertions(+), 170 deletions(-) rename {mllib/src/main/scala/org/apache/spark/mllib => core/src/main/scala/org/apache/spark}/input/WholeTextFileInputFormat.scala (94%) rename {mllib/src/main/scala/org/apache/spark/mllib => core/src/main/scala/org/apache/spark}/input/WholeTextFileRecordReader.scala (96%) rename {mllib/src/test/scala/org/apache/spark/mllib => core/src/test/scala/org/apache/spark}/input/WholeTextFileRecordReaderSuite.scala (93%) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b23accbbb9410..cdda667ebb74c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -37,6 +37,7 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.input.WholeTextFileInputFormat import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ @@ -371,6 +372,37 @@ class SparkContext( minSplits).map(pair => pair._2.toString) } + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + *

For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, + * + *

then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + */ + def wholeTextFiles(path: String): RDD[(String, String)] = { + newAPIHadoopFile( + path, + classOf[WholeTextFileInputFormat], + classOf[String], + classOf[String]) + } + /** * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e531a57aced31..42345149b1b8e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -154,6 +154,31 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork */ def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits) + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + *

For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, + * + *

then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + */ + def wholeTextFiles(path: String): JavaRDD[(String, String)] = sc.wholeTextFiles(path) + /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala similarity index 94% rename from mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala rename to core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 825a004cb088b..4887fb6b84eb2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.input +package org.apache.spark.input import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit @@ -32,7 +32,7 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit * the value is the entire content of file. */ -private[mllib] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { +private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { override protected def isSplitable(context: JobContext, file: Path): Boolean = false override def createRecordReader( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala similarity index 96% rename from mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala rename to core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 1fc668810332b..c3dabd2e79995 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.input +package org.apache.spark.input import com.google.common.io.{ByteStreams, Closeables} @@ -30,7 +30,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext * out in a key-value pair, where the key is the file path and the value is the entire content of * the file. */ -private[mllib] class WholeTextFileRecordReader( +private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala similarity index 93% rename from mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala rename to core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 25622ddcd1fd8..09e35bfc8f85f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.mllib.input +package org.apache.spark.input import java.io.DataOutputStream import java.io.File @@ -31,11 +31,10 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.Text import org.apache.spark.SparkContext -import org.apache.spark.mllib.MLContext._ /** * Tests the correctness of - * [[org.apache.spark.mllib.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary + * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { @@ -74,7 +73,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { createNativeFile(dir, filename, contents) } - val res = sc.wholeTextFile(dir.toString).collect() + val res = sc.wholeTextFiles(dir.toString).collect() assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, "Number of files read out does not fit with the actual value.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala deleted file mode 100644 index 08cd9ab05547b..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.util - -import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ - -import org.jblas.DoubleMatrix - -import org.apache.spark.mllib.regression.LabeledPoint - -import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} - -/** - * Helper methods to load, save and pre-process data used in ML Lib. - */ -object MLUtils { - - private[util] lazy val EPSILON = { - var eps = 1.0 - while ((1.0 + (eps / 2.0)) != 1.0) { - eps /= 2.0 - } - eps - } - - /** - * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.split(',') - val label = parts(0).toDouble - val features = parts(1).trim().split(' ').map(_.toDouble) - LabeledPoint(label, features) - } - } - - /** - * Save labeled data to a file. The data format used here is - * , ... 
- * where , are feature values in Double and is the corresponding label as Double. - * - * @param data An RDD of LabeledPoints containing data to be saved. - * @param dir Directory to save the data. - */ - def saveLabeledData(data: RDD[LabeledPoint], dir: String) { - val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) - dataStr.saveAsTextFile(dir) - } - - /** - * Utility function to compute mean and standard deviation on a given dataset. - * - * @param data - input data set whose statistics are computed - * @param nfeatures - number of features - * @param nexamples - number of examples in input dataset - * - * @return (yMean, xColMean, xColSd) - Tuple consisting of - * yMean - mean of the labels - * xColMean - Row vector with mean for every column (or feature) of the input data - * xColSd - Row vector standard deviation for every column (or feature) of the input data. - */ - def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): - (Double, DoubleMatrix, DoubleMatrix) = { - val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples - - // NOTE: We shuffle X by column here to compute column sum and sum of squares. - val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => - val nCols = labeledPoint.features.length - // Traverse over every column and emit (col, value, value^2) - Iterator.tabulate(nCols) { i => - (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) - } - }.reduceByKey { case(x1, x2) => - (x1._1 + x2._1, x1._2 + x2._2) - } - val xColSumsMap = xColSumSq.collectAsMap() - - val xColMean = DoubleMatrix.zeros(nfeatures, 1) - val xColSd = DoubleMatrix.zeros(nfeatures, 1) - - // Compute mean and unbiased variance using column sums - var col = 0 - while (col < nfeatures) { - xColMean.put(col, xColSumsMap(col)._1 / nexamples) - val variance = - (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples - xColSd.put(col, math.sqrt(variance)) - col += 1 - } - - (yMean, xColMean, xColSd) - } - - /** - * Returns the squared Euclidean distance between two vectors. The following formula will be used - * if it does not introduce too much numerical error: - *

-   *   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.
-   * 
- * When both vector norms are given, this is faster than computing the squared distance directly, - * especially when one of the vectors is a sparse vector. - * - * @param v1 the first vector - * @param norm1 the norm of the first vector, non-negative - * @param v2 the second vector - * @param norm2 the norm of the second vector, non-negative - * @param precision desired relative precision for the squared distance - * @return squared distance between v1 and v2 within the specified precision - */ - private[mllib] def fastSquaredDistance( - v1: BV[Double], - norm1: Double, - v2: BV[Double], - norm2: Double, - precision: Double = 1e-6): Double = { - val n = v1.size - require(v2.size == n) - require(norm1 >= 0.0 && norm2 >= 0.0) - val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 - val normDiff = norm1 - norm2 - var sqDist = 0.0 - val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) - if (precisionBound1 < precision) { - sqDist = sumSquaredNorm - 2.0 * v1.dot(v2) - } else if (v1.isInstanceOf[BSV[Double]] || v2.isInstanceOf[BSV[Double]]) { - val dot = v1.dot(v2) - sqDist = math.max(sumSquaredNorm - 2.0 * dot, 0.0) - val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dot)) / (sqDist + EPSILON) - if (precisionBound2 > precision) { - sqDist = breezeSquaredDistance(v1, v2) - } - } else { - sqDist = breezeSquaredDistance(v1, v2) - } - sqDist - } -} From 01745eefdd98ec1afbfa730f8892b466789e38f1 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 2 Apr 2014 11:59:55 +0800 Subject: [PATCH 6/8] fix deletion error --- .../org/apache/spark/mllib/MLContext.scala | 68 -------- .../org/apache/spark/mllib/util/MLUtils.scala | 162 ++++++++++++++++++ 2 files changed, 162 insertions(+), 68 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala deleted file mode 100644 index d6c6910c0ed84..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib - -import org.apache.spark.mllib.input.WholeTextFileInputFormat -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext - -/** - * Extra functions available on SparkContext of mllib through an implicit conversion. Import - * `org.apache.spark.mllib.MLContext._` at the top of your program to use these functions. 
- */ -class MLContext(self: SparkContext) { - - /** - * Read a directory of text files from HDFS, a local file system (available on all nodes), or any - * Hadoop-supported file system URI. Each file is read as a single record and returned in a - * key-value pair, where the key is the path of each file, the value is the content of each file. - * - *

For example, if you have the following files: - * {{{ - * hdfs://a-hdfs-path/part-00000 - * hdfs://a-hdfs-path/part-00001 - * ... - * hdfs://a-hdfs-path/part-nnnnn - * }}} - * - * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, - * - *

then `rdd` contains - * {{{ - * (a-hdfs-path/part-00000, its content) - * (a-hdfs-path/part-00001, its content) - * ... - * (a-hdfs-path/part-nnnnn, its content) - * }}} - */ - def wholeTextFile(path: String): RDD[(String, String)] = { - self.newAPIHadoopFile( - path, - classOf[WholeTextFileInputFormat], - classOf[String], - classOf[String]) - } -} - -/** - * The MLContext object contains a number of implicit conversions and parameters for use with - * various mllib features. - */ -object MLContext { - implicit def sparkContextToMLContext(sc: SparkContext) = new MLContext(sc) -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala new file mode 100644 index 0000000000000..08cd9ab05547b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ + +import org.jblas.DoubleMatrix + +import org.apache.spark.mllib.regression.LabeledPoint + +import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} + +/** + * Helper methods to load, save and pre-process data used in ML Lib. + */ +object MLUtils { + + private[util] lazy val EPSILON = { + var eps = 1.0 + while ((1.0 + (eps / 2.0)) != 1.0) { + eps /= 2.0 + } + eps + } + + /** + * Load labeled data from a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.split(',') + val label = parts(0).toDouble + val features = parts(1).trim().split(' ').map(_.toDouble) + LabeledPoint(label, features) + } + } + + /** + * Save labeled data to a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param data An RDD of LabeledPoints containing data to be saved. + * @param dir Directory to save the data. + */ + def saveLabeledData(data: RDD[LabeledPoint], dir: String) { + val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) + dataStr.saveAsTextFile(dir) + } + + /** + * Utility function to compute mean and standard deviation on a given dataset. 
+ * + * @param data - input data set whose statistics are computed + * @param nfeatures - number of features + * @param nexamples - number of examples in input dataset + * + * @return (yMean, xColMean, xColSd) - Tuple consisting of + * yMean - mean of the labels + * xColMean - Row vector with mean for every column (or feature) of the input data + * xColSd - Row vector standard deviation for every column (or feature) of the input data. + */ + def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): + (Double, DoubleMatrix, DoubleMatrix) = { + val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples + + // NOTE: We shuffle X by column here to compute column sum and sum of squares. + val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => + val nCols = labeledPoint.features.length + // Traverse over every column and emit (col, value, value^2) + Iterator.tabulate(nCols) { i => + (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) + } + }.reduceByKey { case(x1, x2) => + (x1._1 + x2._1, x1._2 + x2._2) + } + val xColSumsMap = xColSumSq.collectAsMap() + + val xColMean = DoubleMatrix.zeros(nfeatures, 1) + val xColSd = DoubleMatrix.zeros(nfeatures, 1) + + // Compute mean and unbiased variance using column sums + var col = 0 + while (col < nfeatures) { + xColMean.put(col, xColSumsMap(col)._1 / nexamples) + val variance = + (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples + xColSd.put(col, math.sqrt(variance)) + col += 1 + } + + (yMean, xColMean, xColSd) + } + + /** + * Returns the squared Euclidean distance between two vectors. The following formula will be used + * if it does not introduce too much numerical error: + *

+   *   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.
+   * 
+ * When both vector norms are given, this is faster than computing the squared distance directly, + * especially when one of the vectors is a sparse vector. + * + * @param v1 the first vector + * @param norm1 the norm of the first vector, non-negative + * @param v2 the second vector + * @param norm2 the norm of the second vector, non-negative + * @param precision desired relative precision for the squared distance + * @return squared distance between v1 and v2 within the specified precision + */ + private[mllib] def fastSquaredDistance( + v1: BV[Double], + norm1: Double, + v2: BV[Double], + norm2: Double, + precision: Double = 1e-6): Double = { + val n = v1.size + require(v2.size == n) + require(norm1 >= 0.0 && norm2 >= 0.0) + val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 + val normDiff = norm1 - norm2 + var sqDist = 0.0 + val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) + if (precisionBound1 < precision) { + sqDist = sumSquaredNorm - 2.0 * v1.dot(v2) + } else if (v1.isInstanceOf[BSV[Double]] || v2.isInstanceOf[BSV[Double]]) { + val dot = v1.dot(v2) + sqDist = math.max(sumSquaredNorm - 2.0 * dot, 0.0) + val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dot)) / (sqDist + EPSILON) + if (precisionBound2 > precision) { + sqDist = breezeSquaredDistance(v1, v2) + } + } else { + sqDist = breezeSquaredDistance(v1, v2) + } + sqDist + } +} From 0af3faf0dd02925556657bbe395fcea1d37cb928 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 3 Apr 2014 10:57:47 +0800 Subject: [PATCH 7/8] add JavaAPI test --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/api/java/JavaSparkContext.scala | 5 ++-- .../java/org/apache/spark/JavaAPISuite.java | 30 +++++++++++++++++-- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cdda667ebb74c..3e66d8c4d8367 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -385,7 +385,7 @@ class SparkContext( * hdfs://a-hdfs-path/part-nnnnn * }}} * - * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, + * Do `val rdd = sparkContext.wholeTextFile("hdfs://a-hdfs-path")`, * *

then `rdd` contains * {{{ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 42345149b1b8e..a34ccb3781803 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -167,7 +167,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * hdfs://a-hdfs-path/part-nnnnn * }}} * - * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, + * Do `JavaPairRDD rdd = context.wholeTextFiles("hdfs://a-hdfs-path")`, * *

then `rdd` contains * {{{ @@ -177,7 +177,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * (a-hdfs-path/part-nnnnn, its content) * }}} */ - def wholeTextFiles(path: String): JavaRDD[(String, String)] = sc.wholeTextFiles(path) + def wholeTextFiles(path: String): JavaPairRDD[String, String] = + new JavaPairRDD(sc.wholeTextFiles(path)) /** Get an RDD for a Hadoop SequenceFile with given key and value types. * diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c6b65c7348ae0..2372f2d9924a1 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -17,9 +17,7 @@ package org.apache.spark; -import java.io.File; -import java.io.IOException; -import java.io.Serializable; +import java.io.*; import java.util.*; import scala.Tuple2; @@ -599,6 +597,32 @@ public void textFiles() throws IOException { Assert.assertEquals(expected, readRDD.collect()); } + @Test + public void wholeTextFiles() throws IOException { + byte[] content1 = "spark is easy to use.\n".getBytes(); + byte[] content2 = "spark is also easy to use.\n".getBytes(); + + File tempDir = Files.createTempDir(); + String tempDirName = tempDir.getAbsolutePath(); + DataOutputStream ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00000")); + ds.write(content1); + ds.close(); + ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00001")); + ds.write(content2); + ds.close(); + + HashMap container = new HashMap(); + container.put(tempDirName+"/part-00000", new Text(content1).toString()); + container.put(tempDirName+"/part-00001", new Text(content2).toString()); + + JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName); + List> result = readRDD.collect(); + + for (Tuple2 res : result) { + Assert.assertEquals(res._2(), container.get(res._1())); + } + } + @Test public void textFilesCompressed() throws IOException { File tempDir = Files.createTempDir(); From 7191be602b1108a4e00d72ab7d2b5b34b8c6c508 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 3 Apr 2014 11:05:52 +0800 Subject: [PATCH 8/8] refine comments --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 ++ .../scala/org/apache/spark/api/java/JavaSparkContext.scala | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e66d8c4d8367..28a865c0ad3b5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -394,6 +394,8 @@ class SparkContext( * ... * (a-hdfs-path/part-nnnnn, its content) * }}} + * + * @note Small files are perferred, large file is also allowable, but may cause bad performance. 
*/ def wholeTextFiles(path: String): RDD[(String, String)] = { newAPIHadoopFile( diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index a34ccb3781803..6cbdeac58d5e2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -167,7 +167,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * hdfs://a-hdfs-path/part-nnnnn * }}} * - * Do `JavaPairRDD rdd = context.wholeTextFiles("hdfs://a-hdfs-path")`, + * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`, * *

then `rdd` contains
    * {{{
    *   (a-hdfs-path/part-00000, its content)
    *   (a-hdfs-path/part-00001, its content)
    *   ...
    *   (a-hdfs-path/part-nnnnn, its content)
    * }}}
+   *
+   * @note Small files are preferred; large files are also allowed, but may cause bad performance.
    */
   def wholeTextFiles(path: String): JavaPairRDD[String, String] =
     new JavaPairRDD(sc.wholeTextFiles(path))
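
Usage note (not part of the patches above): a minimal, self-contained sketch of how the `wholeTextFiles` API added by this series could be called from Scala. The object name, the input path `hdfs://a-hdfs-path` (taken from the scaladoc example), and the line-counting logic are illustrative placeholders only.

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._ // pair-RDD implicits such as mapValues

    object WholeTextFilesExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "wholeTextFiles-example")

        // Each file under the (placeholder) directory becomes one (path, content) record.
        val files = sc.wholeTextFiles("hdfs://a-hdfs-path")

        // Illustrative per-file aggregation: count the lines of every file.
        val linesPerFile = files.mapValues(_.count(_ == '\n'))
        linesPerFile.collect().foreach { case (path, n) =>
          println(s"$path: $n lines")
        }

        sc.stop()
      }
    }

Because every element of the returned RDD pairs a file path with that file's entire content, per-file computations like the one above need no extra grouping step.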