diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e018af35cb18d..d1961eb004041 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -387,6 +387,14 @@ abstract class RDD[T: ClassTag](
       preservesPartitioning = true)
   }
 
+  /**
+   * Return a new RDD containing only the elements that satisfy a predicate.
+   * This is an alias for filter so that RDDs can be used in for comprehensions without causing the
+   * compiler to complain.
+   */
+  @inline
+  final def withFilter(f: T => Boolean): RDD[T] = filter(f)
+
   /**
    * Return a new RDD containing the distinct elements in this RDD.
    */
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ad56715656c85..b838cfb903e6d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -70,6 +70,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(!nums.isEmpty())
     assert(nums.max() === 4)
     assert(nums.min() === 1)
+    assert((for (n <- nums if n > 2) yield n).collect().toList === List(3, 4))
     val partitionSums = nums.mapPartitions(iter => Iterator(iter.sum))
     assert(partitionSums.collect().toList === List(3, 7))
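
For context on what the patch enables: a guarded for comprehension over an RDD desugars to a withFilter call, so it only typechecks once this alias exists. A minimal standalone sketch of the behavior the new test exercises; the local-mode SparkContext setup and the WithFilterDemo object name are illustrative assumptions, not part of the patch:

    import org.apache.spark.{SparkConf, SparkContext}

    object WithFilterDemo {
      def main(args: Array[String]): Unit = {
        // Local-mode context for illustration only; any SparkContext works.
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("withFilter-demo"))
        val nums = sc.parallelize(Seq(1, 2, 3, 4), 2)

        // The guard below desugars to nums.withFilter(n => n > 2).map(n => n);
        // with the alias in place it forwards to filter, so the comprehension
        // compiles and stays a lazy RDD transformation.
        val big = for (n <- nums if n > 2) yield n
        assert(big.collect().toList == List(3, 4))

        sc.stop()
      }
    }

Because withFilter simply delegates to filter, the comprehension above is equivalent to nums.filter(_ > 2) and carries no extra overhead (the @inline annotation makes the forwarding call a candidate for inlining).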