streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

@@ -97,6 +97,8 @@ abstract class DStream[T: ClassTag] (
   private[streaming] val mustCheckpoint = false
   private[streaming] var checkpointDuration: Duration = null
   private[streaming] val checkpointData = new DStreamCheckpointData(this)
+  @transient
+  private var restoredFromCheckpointData = false

   // Reference to whole DStream graph
   private[streaming] var graph: DStreamGraph = null
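
The flag is @transient, so it is not persisted in the checkpoint: a DStream deserialized during recovery starts with restoredFromCheckpointData = false and therefore performs the restore exactly once per restart.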
@@ -507,11 +509,14 @@ abstract class DStream[T: ClassTag] (
    * override the updateCheckpointData() method would also need to override this method.
    */
   private[streaming] def restoreCheckpointData() {
-    // Create RDDs from the checkpoint data
-    logInfo("Restoring checkpoint data")
-    checkpointData.restore()
-    dependencies.foreach(_.restoreCheckpointData())
-    logInfo("Restored checkpoint data")
+    if (!restoredFromCheckpointData) {
+      // Create RDDs from the checkpoint data
+      logInfo("Restoring checkpoint data")
+      checkpointData.restore()
+      dependencies.foreach(_.restoreCheckpointData())
+      restoredFromCheckpointData = true
+      logInfo("Restored checkpoint data")
+    }
   }

   @throws(classOf[IOException])
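
The guard makes restoreCheckpointData() idempotent. Recovery walks the DStream graph starting from every registered output stream, so a DStream shared by several outputs would otherwise have its checkpoint data restored once per output. Below is a minimal, self-contained Scala sketch of this visit-once pattern; the Node class and RestoreOnceSketch object are hypothetical stand-ins, not Spark code:

// Visit-once restore over an acyclic graph, mirroring the
// restoredFromCheckpointData guard added above.
object RestoreOnceSketch {
  class Node(val name: String, val deps: Seq[Node]) {
    var restoreCount = 0            // stands in for checkpointData.restore() side effects
    private var restored = false    // plays the role of restoredFromCheckpointData
    def restore(): Unit = {
      if (!restored) {
        restoreCount += 1
        deps.foreach(_.restore())   // recurse into dependencies, as the real method does
        restored = true
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val input = new Node("input", Nil)
    val mapped = new Node("mapped", Seq(input))
    // Three outputs over the same lineage, as with the three output
    // operations registered in the test below.
    val outputs = Seq.fill(3)(new Node("output", Seq(mapped)))
    outputs.foreach(_.restore())
    assert(mapped.restoreCount == 1) // without the guard this would be 3
    println(s"mapped restored ${mapped.restoreCount} time(s)")
  }
}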
streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala

@@ -17,7 +17,7 @@

 package org.apache.spark.streaming

-import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}

 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 import scala.reflect.ClassTag
@@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
-import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils}

+/**
+ * An input stream that records the number of times restore() is invoked.
+ */
+private[streaming]
+class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) {
+  protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
+  override def start(): Unit = { }
+  override def stop(): Unit = { }
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1)))
+  private[streaming]
+  class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
+    @transient
+    var restoredTimes = 0
+    override def restore() {
+      restoredTimes += 1
+      super.restore()
+    }
+  }
+}
+
 /**
  * A trait that can be mixed in to get methods for testing DStream operations under
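
Note that restoredTimes in the helper above is @transient: it is reset to 0 when the checkpointed graph is deserialized, so a count observed after restart reflects only the restore() calls made during recovery.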
@@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
   }

-  private def generateOutput[V: ClassTag](
+  protected def generateOutput[V: ClassTag](
       ssc: StreamingContext,
       targetBatchTime: Time,
       checkpointDir: String,
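
generateOutput is widened from private to protected so that the new test in CheckpointSuite, which mixes in DStreamCheckpointTester, can call it directly to run the first context and write the checkpoint.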
@@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
     }
   }

+  test("DStreamCheckpointData.restore invoking times") {
+    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      ssc.checkpoint(checkpointDir)
+      val inputDStream = new CheckpointInputDStream(ssc)
+      val checkpointData = inputDStream.checkpointData
+      val mappedDStream = inputDStream.map(_ + 100)
+      val outputStream = new TestOutputStreamWithPartitions(mappedDStream)
+      outputStream.register()
+      // register two more output operations on the same lineage
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      assert(checkpointData.restoredTimes === 0)
+      val batchDurationMillis = ssc.progressListener.batchDuration
+      generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true)
+      assert(checkpointData.restoredTimes === 0)
+    }
+    logInfo("*********** RESTARTING ************")
+    withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
+      val checkpointData =
+        ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData
+      assert(checkpointData.restoredTimes === 1)
+      ssc.start()
+      ssc.stop()
+      assert(checkpointData.restoredTimes === 1)
+    }
+  }
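
Without the fix, recovery would invoke restore() once per registered output stream, three times here (one TestOutputStreamWithPartitions plus two foreachRDD calls), so the final assertions would observe restoredTimes === 3 rather than 1.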

   // This tests whether Spark can deserialize an array object
   // refer to SPARK-5569
   test("recovery from checkpoint contains array object") {