From c0194744378c2b1827aa6bb3a6109b3549710a44 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 10 Mar 2016 15:02:37 -0800 Subject: [PATCH 01/46] First draft of StateStore --- .../streaming/state/StateStore.scala | 451 ++++++++++++++++++ .../state/StateStoreCoordinator.scala | 84 ++++ .../streaming/state/StateStoreRDD.scala | 76 +++ .../streaming/state/StateStoreSuite.scala | 232 +++++++++ 4 files changed, 843 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala new file mode 100644 index 0000000000000..feed7e97eb63c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.{Timer, TimerTask} + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.Random +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.util.{RpcUtils, CompletionIterator, Utils} +import org.apache.spark.{SparkEnv, Logging, SparkConf} + +case class StateStoreId(operatorId: Long, partitionId: Int) + +private[state] object StateStore extends Logging { + + sealed trait Update + case class ValueUpdated(key: InternalRow, value: InternalRow) extends Update + case class KeyRemoved(key: InternalRow) extends Update + + private val loadedStores = new mutable.HashMap[StateStoreId, StateStore]() + private val managementTimer = new Timer("StateStore Timer", true) + @volatile private var managementTask: TimerTask = null + + def get(storeId: StateStoreId, directory: String): StateStore = { + val store = loadedStores.synchronized { + startIfNeeded() + loadedStores.getOrElseUpdate(storeId, new StateStore(storeId, directory)) + } + reportActiveInstance(storeId) + store + } + + def clearAll(): Unit = loadedStores.synchronized { + loadedStores.clear() + if (managementTask != null) { + managementTask.cancel() + managementTask = null + } + } + + private def remove(storeId: StateStoreId): Unit = { + loadedStores.remove(storeId) + } + + private def reportActiveInstance(storeId: StateStoreId): Unit = { + val host = SparkEnv.get.blockManager.blockManagerId.host + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + askCoordinator[Boolean](ReportActiveInstance(storeId, host, executorId)) + } + + private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + askCoordinator[Boolean](VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) + } + + private def askCoordinator[T: ClassTag](message: StateStoreCoordinatorMessage): Option[T] = { + try { + val env = SparkEnv.get + if (env != null) { + val coordinatorRef = RpcUtils.makeDriverRef("StateStoreCoordinator", env.conf, env.rpcEnv) + Some(coordinatorRef.askWithRetry[T](message)) + } else { + None + } + } catch { + case NonFatal(e) => + clearAll() + None + } + } + + private def startIfNeeded(): Unit = loadedStores.synchronized { + if (managementTask == null) { + managementTask = new TimerTask { + override def run(): Unit = { manageFiles() } + } + managementTimer.schedule(managementTask, 10000, 10000) + } + } + + private def manageFiles(): Unit = { + loadedStores.synchronized { loadedStores.values.toSeq }.foreach { store => + try { + store.manageFiles() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up store ${store.id}") + } + } + } +} + +private[sql] class StateStore( + val id: StateStoreId, + val directory: String, + numBatchesToRetain: Int = 2 + ) extends Logging { + type MapType = mutable.HashMap[InternalRow, InternalRow] + + import StateStore._ + + private val storeMaps = new mutable.HashMap[Long, MapType] + private val baseDir = new Path(directory, s"${id.operatorId}/${id.partitionId.toString}") + private val fs = baseDir.getFileSystem(new Configuration()) + private val serializer = new KryoSerializer(new 
SparkConf) + + @volatile private var uncommittedDelta: UncommittedDelta = null + + initialize() + + private[state] def startUpdates(version: Long): Unit = synchronized { + require(version >= 0) + if (uncommittedDelta != null) { + cancelUpdates() + } + val newMap = new MapType() + if (version > 0) { + val oldMap = loadMap(version - 1) + newMap ++= oldMap + } + uncommittedDelta = new UncommittedDelta(version, newMap) + } + + def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { + verify(uncommittedDelta != null, "Cannot update data before calling newVersion()") + uncommittedDelta.update(key, updateFunc) + } + + def remove(condition: InternalRow => Boolean): Unit = { + verify(uncommittedDelta != null, "Cannot remove data before calling newVersion()") + uncommittedDelta.remove(condition) + } + + def commitUpdates(): Unit = { + verify(uncommittedDelta != null, "Cannot commit data before calling newVersion()") + uncommittedDelta.commit() + uncommittedDelta = null + } + + def cancelUpdates(): Unit = { + verify(uncommittedDelta != null, "Cannot commit data before calling newVersion()") + uncommittedDelta.cancel() + uncommittedDelta = null + } + + def hasCommitted: Boolean = { + uncommittedDelta == null + } + + def getAll(): Iterator[InternalRow] = synchronized { + verify(uncommittedDelta == null, "Cannot getAll() before committing") + val lastVersion = fetchFiles().lastOption.map(_.version) + lastVersion.map(loadMap) match { + case Some(map) => + map.iterator.map { case (key, value) => new JoinedRow(key, value) } + case None => + Iterator.empty + } + } + + override def toString(): String = { + s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" + } + + // Private methods + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + + private class UncommittedDelta(val version: Long, val map: MapType) { + private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") + private val tempDeltaFileStream = + serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) + + def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { + verify(uncommittedDelta != null, "Cannot call update() before calling startUpdates()") + val value = updateFunc(uncommittedDelta.map.get(key)) + uncommittedDelta.map.put(key, value) + tempDeltaFileStream.writeObject(ValueUpdated(key, value)) + } + + def remove(condition: InternalRow => Boolean): Unit = { + verify(uncommittedDelta != null, "Cannot call remove() before calling startUpdates()") + val keyIter = uncommittedDelta.map.keysIterator + while (keyIter.hasNext) { + val key = keyIter.next + if (condition(key)) { + uncommittedDelta.map.remove(key) + tempDeltaFileStream.writeObject(KeyRemoved(key)) + } + } + } + + def commit(): Unit = { + try { + tempDeltaFileStream.close() + val deltaFile = new Path(baseDir, s"${uncommittedDelta.version}.delta") + StateStore.this.synchronized { + fs.rename(tempDeltaFile, deltaFile) + println("Written " + deltaFile) + storeMaps.put(version, map) + } + } catch { + case NonFatal(e) => + throw new IllegalStateException( + s"Error committing version ${uncommittedDelta.version} into $this", e) + } + } + + def cancel(): Unit = { + tempDeltaFileStream.close() + fs.delete(tempDeltaFile, true) + } + } + + + private[state] def getInternalMap(version: Long): Option[MapType] = synchronized { + storeMaps.get(version) + } + + private def initialize(): Unit = { + if (!fs.exists(baseDir)) { + fs.mkdirs(baseDir) + } else { + 
if (!fs.isDirectory(baseDir)) { + throw new IllegalStateException(s"Cannot use $directory for storing state data as" + + s"$baseDir already exists and is not a directory") + } + } + } + + private def loadMap(version: Long): MapType = { + if (version < 0) return new MapType + println(s"Loading version $version") + synchronized { storeMaps.get(version) }.getOrElse { + val mapFromFile = readSnapshotFile(version).getOrElse { + val prevMap = loadMap(version - 1) + val deltaUpdates = readDeltaFile(version) + println(s"Reading delta for $version") + val newMap = new MapType() + newMap ++= prevMap + newMap.sizeHint(prevMap.size) + while (deltaUpdates.hasNext) { + deltaUpdates.next match { + case ValueUpdated(key, value) => newMap.put(key, value) + case KeyRemoved(key) => newMap.remove(key) + } + } + println("Map = " + newMap.toSeq.mkString(", ")) + newMap + } + storeMaps.put(version, mapFromFile) + mapFromFile + } + } + + private def readDeltaFile(version: Long): Iterator[Update] = { + val fileToRead = deltaFile(version) + if (!fs.exists(fileToRead)) { + throw new IllegalStateException( + s"Cannot read delta file for version $version of $this: $fileToRead does not exist") + } + val deser = serializer.newInstance() + var deserStream: DeserializationStream = null + deserStream = deser.deserializeStream(fs.open(fileToRead)) + val iter = deserStream.asIterator.asInstanceOf[Iterator[Update]] + CompletionIterator[Update, Iterator[Update]](iter, { deserStream.close() }) + } + + private def writeSnapshotFile(version: Long, map: MapType): Unit = { + val fileToWrite = snapshotFile(version) + val ser = serializer.newInstance() + var outputStream: SerializationStream = null + Utils.tryWithSafeFinally { + outputStream = ser.serializeStream(fs.create(fileToWrite, false)) + outputStream.writeAll(map.iterator) + } { + if (outputStream != null) outputStream.close() + } + } + + private def readSnapshotFile(version: Long): Option[MapType] = { + val fileToRead = snapshotFile(version) + if (!fs.exists(fileToRead)) return None + + val deser = serializer.newInstance() + val map = new MapType() + var deserStream: DeserializationStream = null + + try { + deserStream = deser.deserializeStream(fs.open(fileToRead)) + val iter = deserStream.asIterator.asInstanceOf[Iterator[(InternalRow, InternalRow)]] + while(iter.hasNext) { + map += iter.next() + } + Some(map) + } finally { + if (deserStream != null) deserStream.close() + } + } + + private[state] def manageFiles(): Unit = { + doSnapshot() + cleanup() + } + + private def doSnapshot(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + synchronized { + storeMaps.get(lastVersion) + } match { + case Some(map) => + if (deltaFilesForLastVersion.size > 10) { + writeSnapshotFile(lastVersion, map) + } + case None => + // The last map is not loaded, probably some other instance is incharge + } + + } + } catch { + case NonFatal(e) => + logWarning(s"Error doing snapshots for $this") + } + } + + private[state] def cleanup(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - numBatchesToRetain + if (earliestVersionToRetain >= 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + synchronized { + storeMaps.keys.filter(_ < earliestVersionToRetain).foreach(storeMaps.remove) + } + files.filter(_.version < 
earliestFileToRetain.version).foreach { f => + fs.delete(f.path, true) + } + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this") + } + } + + private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + val deltaBatchIds = (snapshotFile.version + 1) to version + + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + } + verify( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version: ${deltaFiles.mkString(",")}" + ) + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + private def fetchFiles(): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fs.listStatus(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + println(s"Fetching files in $baseDir") + files.foreach { status => + val path = status.getPath + //println(s"\tTesting file $path") + val nameParts = path.getName.split("\\.") + println(s"\t${path.getName}, ${nameParts.mkString(",")}") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + //println(s"\tFound file $path") + versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + logWarning(s"Could not identify file $path") + } + } else { + println("\tIgnoring") + } + } + versionToFiles.values.toSeq.sortBy(_.version) + } + + private def deltaFile(version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + private def snapshotFile(version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala new file mode 100644 index 0000000000000..39380f4d46065 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint + +private sealed trait StateStoreCoordinatorMessage extends Serializable +private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) + extends StateStoreCoordinatorMessage +private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) + extends StateStoreCoordinatorMessage +private object StopCoordinator extends StateStoreCoordinatorMessage + + +class StateStoreCoordinator(rpcEnv: RpcEnv) { + private val coordinatorRef = rpcEnv.setupEndpoint( + "StateStoreCoordinator", new StateStoreCoordinatorEndpoint(rpcEnv, this)) + private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + + def reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Unit = { + instances.synchronized { instances.put(storeId, ExecutorCacheTaskLocation(host, executorId)) } + } + + def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { + instances.synchronized { + instances.get(storeId).forall(_.executorId == executorId) + } + } + + def getLocation(storeId: StateStoreId): Option[String] = { + instances.synchronized { instances.get(storeId).map(_.toString) } + } + + def makeInstancesInactive(operatorIds: Set[Long]): Unit = { + instances.synchronized { + val instancesToRemove = + instances.keys.filter(id => operatorIds.contains(id.operatorId)).toSeq + instances --= instancesToRemove + } + } +} + +private[spark] object StateStoreCoordinator { + + private[spark] class StateStoreCoordinatorEndpoint( + override val rpcEnv: RpcEnv, coordinator: StateStoreCoordinator) + extends RpcEndpoint with Logging { + + override def receive: PartialFunction[Any, Unit] = { + case StopCoordinator => + logInfo("StateStoreCoordinator stopped!") + stop() + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case ReportActiveInstance(id, host, executorId) => + coordinator.reportActiveInstance(id, host, executorId) + case VerifyIfInstanceActive(id, executor) => + context.reply(coordinator.verifyIfInstanceActive(id, executor)) + } + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala new file mode 100644 index 0000000000000..8c5e408ed45e3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils +import org.apache.spark.{Partition, TaskContext} + +/** + * Created by tdas on 3/9/16. + */ +class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( + dataRDD: RDD[INPUT], + storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], + operatorId: Long, + newStoreVersion: Long, + storeDirectory: String, + storeCoordinator: StateStoreCoordinator + ) + extends RDD[OUTPUT](dataRDD) { + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + override def getPreferredLocations(partition: Partition): Seq[String] = { + storeCoordinator.getLocation( + StateStoreId(operatorId, partition.index)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[OUTPUT] = { + var store: StateStore = null + + Utils.tryWithSafeFinally { + StateStore.get( + StateStoreId(operatorId, partition.index), + storeDirectory + ) + val inputIter = dataRDD.compute(partition, ctxt) + store.startUpdates(newStoreVersion) + val outputIter = storeUpdateFunction(store, inputIter) + assert(store.hasCommitted) + outputIter + } { + if (store != null) store.cancelUpdates() + } + } +} + +object StateStoreRDD { + implicit def withStateStores[INPUT: ClassTag, OUTPUT: ClassTag]( + dataRDD: RDD[INPUT], + storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], + operatorId: Long, + newStoreVersion: Long, + storeDirectory: String, + storeCoordinator: StateStoreCoordinator + ): RDD[OUTPUT] = { + new StateStoreRDD( + dataRDD, storeUpdateFunction, operatorId, newStoreVersion, storeDirectory, storeCoordinator) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala new file mode 100644 index 0000000000000..3d3680f9bee76 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File + +import scala.collection.mutable +import scala.util.Random + +import org.apache.hadoop.fs.Path +import org.scalatest.{PrivateMethodTester, BeforeAndAfter} + +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.util.Utils + +class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { + type MapType = mutable.HashMap[InternalRow, InternalRow] + + private val tempDir = Utils.createTempDir().toString + + after { + StateStore.clearAll() + } + + test("startUpdates, update, remove, commitUpdates") { + val store = newStore() + + // Verify state before starting a new set of updates + assert(store.getAll().isEmpty) + assert(store.getInternalMap(0).isEmpty) + intercept[IllegalStateException] { + store.update(null, null) + } + intercept[IllegalStateException] { + store.remove(_ => true) + } + intercept[IllegalStateException] { + store.commitUpdates() + } + intercept[IllegalStateException] { + store.cancelUpdates() + } + + // Verify states after starting updates + store.startUpdates(0) + intercept[IllegalStateException] { + store.getAll() + } + update(store, "a", 1) + intercept[IllegalStateException] { + store.getAll() + } + + // Make updates and commit + update(store, "b", 2) + update(store, "aa", 3) + remove(store, _.startsWith("a")) + store.commitUpdates() + + // Very state after committing + assert(getData(store) === Set("b" -> 2)) + assertMap(store.getInternalMap(0), Map("b" -> 2)) + assert(fileExists(store, 0, isSnapshot = false)) + assert(store.getInternalMap(1).isEmpty) + + // Reload store from the directory + val reloadedStore = new StateStore(store.id, store.directory) + assert(getData(reloadedStore) === Set("b" -> 2)) + + // New updates to the reload store with new version, and does not change old version + reloadedStore.startUpdates(1) + update(reloadedStore, "c", 4) + reloadedStore.commitUpdates() + assert(getData(reloadedStore) === Set("b" -> 2, "c" -> 4)) + assertMap(reloadedStore.getInternalMap(0), Map("b" -> 2)) + assertMap(reloadedStore.getInternalMap(1), Map("b" -> 2, "c" -> 4)) + assert(fileExists(reloadedStore, 1, isSnapshot = false)) + } + + test("cancelUpdates") { + val store = newStore() + store.startUpdates(0) + update(store, "a", 1) + store.commitUpdates() + assert(getData(store) === Set("a" -> 1)) + + // cancelUpdates should not change the data + store.startUpdates(1) + update(store, "b", 1) + store.cancelUpdates() + assert(getData(store) === Set("a" -> 1)) + + // Calling startUpdates again should cancel previous updates + store.startUpdates(1) + update(store, "b", 1) + store.startUpdates(1) + update(store, "c", 1) + store.commitUpdates() + assert(getData(store) === Set("a" -> 1, "c" -> 1)) + } + + test("startUpdates with unexpected versions") { + val store = newStore() + + intercept[IllegalArgumentException] { + store.startUpdates(-1) + } + + // Prepare some data in the stoer + store.startUpdates(0) + update(store, "a", 1) + store.commitUpdates() + assert(getData(store) === Set("a" -> 1)) + + intercept[IllegalStateException] { + store.startUpdates(2) + } + + // Update store version with some data + println("here") + store.startUpdates(1) + update(store, "b", 1) + store.commitUpdates() + println("xyz") + assert(getData(store) === Set("a" -> 1, "b" -> 1)) + println("bla") 
+ + assert(getData(new StateStore(store.id, store.directory)) === Set("a" -> 1, "b" -> 1)) + + // Overwrite the version with other data + store.startUpdates(1) + update(store, "c", 1) + store.commitUpdates() + assert(getData(store) === Set("a" -> 1, "c" -> 1)) + assert(getData(new StateStore(store.id, store.directory)) === Set("a" -> 1, "c" -> 1)) + } + + def getData(store: StateStore): Set[(String, Int)] = { + store.getAll.map(unwrapKeyValue).toSet + } + + def assertMap( + testMapOption: Option[MapType], + expectedMap: Map[String, Int]): Unit = { + assert(testMapOption.nonEmpty, "no map present") + val convertedMap = testMapOption.get.map(unwrapKeyValue) + assert(convertedMap === expectedMap) + } + + def fileExists(store: StateStore, version: Long, isSnapshot: Boolean): Boolean = { + val method = PrivateMethod[Path]('baseDir) + val basePath = store invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.exists + } + + def storeLoaded(storeId: StateStoreId): Boolean = { + val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) + val loadedStores = StateStore invokePrivate method() + loadedStores.contains(storeId) + } + + def unloadStore(storeId: StateStoreId): Boolean = { + val method = PrivateMethod('remove) + StateStore invokePrivate method(storeId) + } + + def newStore(opId: Long = Random.nextLong, partition: Int = 0): StateStore = { + new StateStore( + StateStoreId(opId, partition), + Utils.createDirectory(tempDir, Random.nextString(5)).toString) + } + + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.remove(row => condition(unwrapKey(row))) + } + + private def update(store: StateStore, key: String, value: Int): Unit = { + store.update(wrapKey(key), _ => wrapValue(value)) + } + + private def increment(store: StateStore, key: String): Unit = { + val keyRow = new GenericInternalRow(Array(key).asInstanceOf[Array[Any]]) + store.update(keyRow, oldRow => { + val oldValue = oldRow.map(unwrapValue).getOrElse(0) + wrapValue(oldValue + 1) + }) + } + + private def wrapValue(i: Int): InternalRow = { + new GenericInternalRow(Array[Any](i)) + } + + private def wrapKey(s: String): InternalRow = { + new GenericInternalRow(Array[Any](UTF8String.fromString(s))) + } + + private def unwrapKey(row: InternalRow): String = { + row.asInstanceOf[GenericInternalRow].getString(0) + } + + private def unwrapValue(row: InternalRow): Int = { + row.asInstanceOf[GenericInternalRow].getInt(0) + } + + private def unwrapKeyValue(row: (InternalRow, InternalRow)): (String, Int) = { + (unwrapKey(row._1), unwrapValue(row._2)) + } + + private def unwrapKeyValue(row: InternalRow): (String, Int) = { + (row.getString(0), row.getInt(1)) + } +} From 4f8dade98ab2720ffab78a6ace567b9705f1b488 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 10 Mar 2016 17:50:50 -0800 Subject: [PATCH 02/46] Updated tests --- .../streaming/state/StateStore.scala | 94 ++++++------ .../streaming/state/StateStoreRDD.scala | 4 +- .../streaming/state/StateStoreSuite.scala | 134 ++++++++++++++---- 3 files changed, 163 insertions(+), 69 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index feed7e97eb63c..edc1025580736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -117,22 +117,28 @@ private[state] object StateStore extends Logging { private[sql] class StateStore( val id: StateStoreId, val directory: String, - numBatchesToRetain: Int = 2 + numBatchesToRetain: Int = 2, + maxDeltaChainForSnapshots: Int = 10 ) extends Logging { type MapType = mutable.HashMap[InternalRow, InternalRow] import StateStore._ - private val storeMaps = new mutable.HashMap[Long, MapType] + private val loadedMaps = new mutable.HashMap[Long, MapType] private val baseDir = new Path(directory, s"${id.operatorId}/${id.partitionId.toString}") private val fs = baseDir.getFileSystem(new Configuration()) private val serializer = new KryoSerializer(new SparkConf) - @volatile private var uncommittedDelta: UncommittedDelta = null + @volatile private var uncommittedDelta: UncommittedUpdates = null initialize() - private[state] def startUpdates(version: Long): Unit = synchronized { + /** + * Prepare for updates to create a new `version` of the map. The store ensure that updates + * are made on the `version - 1` of the store data. If `version` already exists, it will + * be overwritten when the updates are committed. + */ + private[state] def prepareForUpdates(version: Long): Unit = synchronized { require(version >= 0) if (uncommittedDelta != null) { cancelUpdates() @@ -142,54 +148,65 @@ private[sql] class StateStore( val oldMap = loadMap(version - 1) newMap ++= oldMap } - uncommittedDelta = new UncommittedDelta(version, newMap) + uncommittedDelta = new UncommittedUpdates(version, newMap) } + /** Update the value of a key using the `updateFunc` */ def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { verify(uncommittedDelta != null, "Cannot update data before calling newVersion()") uncommittedDelta.update(key, updateFunc) } + /** Remove keys that satisfy the following condition */ def remove(condition: InternalRow => Boolean): Unit = { verify(uncommittedDelta != null, "Cannot remove data before calling newVersion()") uncommittedDelta.remove(condition) } + /** Commit all the updates that have been made to the store. */ def commitUpdates(): Unit = { verify(uncommittedDelta != null, "Cannot commit data before calling newVersion()") - uncommittedDelta.commit() + uncommittedDelta.commitAndWriteDeltaFile() uncommittedDelta = null } + /** Cancel all the updates that have been made to the store. */ def cancelUpdates(): Unit = { verify(uncommittedDelta != null, "Cannot commit data before calling newVersion()") uncommittedDelta.cancel() uncommittedDelta = null } - def hasCommitted: Boolean = { - uncommittedDelta == null - } + /** + * Get all the data of the latest version of the store. + * Note that this will look up the files to determined the latest known version. 
+ */ def getAll(): Iterator[InternalRow] = synchronized { - verify(uncommittedDelta == null, "Cannot getAll() before committing") - val lastVersion = fetchFiles().lastOption.map(_.version) - lastVersion.map(loadMap) match { - case Some(map) => - map.iterator.map { case (key, value) => new JoinedRow(key, value) } - case None => - Iterator.empty - } + verify(uncommittedDelta == null, "Cannot getAll() while there are uncommitted updates") + val versionsInFiles = fetchFiles().map(_.version).toSet + val versionsLoaded = loadedMaps.keySet + val allKnownVersions = versionsInFiles ++ versionsLoaded + if (allKnownVersions.nonEmpty) { + loadMap(allKnownVersions.max) + .iterator + .map { case (key, value) => new JoinedRow(key, value) } + } else Iterator.empty + } + + private[state] def hasUncommittedUpdates: Boolean = { + uncommittedDelta != null } override def toString(): String = { s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" } - // Private methods + // Internal classes and methods + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - private class UncommittedDelta(val version: Long, val map: MapType) { + private class UncommittedUpdates(val version: Long, val map: MapType) { private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private val tempDeltaFileStream = serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) @@ -213,14 +230,13 @@ private[sql] class StateStore( } } - def commit(): Unit = { + def commitAndWriteDeltaFile(): Unit = { try { tempDeltaFileStream.close() val deltaFile = new Path(baseDir, s"${uncommittedDelta.version}.delta") StateStore.this.synchronized { fs.rename(tempDeltaFile, deltaFile) - println("Written " + deltaFile) - storeMaps.put(version, map) + loadedMaps.put(version, map) } } catch { case NonFatal(e) => @@ -235,30 +251,31 @@ private[sql] class StateStore( } } - - private[state] def getInternalMap(version: Long): Option[MapType] = synchronized { - storeMaps.get(version) - } - private def initialize(): Unit = { if (!fs.exists(baseDir)) { fs.mkdirs(baseDir) } else { if (!fs.isDirectory(baseDir)) { - throw new IllegalStateException(s"Cannot use $directory for storing state data as" + - s"$baseDir already exists and is not a directory") + throw new IllegalStateException( + s"Cannot use $directory for storing state data as" + + s"$baseDir already exists and is not a directory") } } } + private[state] def getAll(version: Long): Iterator[InternalRow] = synchronized { + loadMap(version) + .iterator + .map { case (key, value) => new JoinedRow(key, value) } + } + + private def loadMap(version: Long): MapType = { if (version < 0) return new MapType - println(s"Loading version $version") - synchronized { storeMaps.get(version) }.getOrElse { + synchronized { loadedMaps.get(version) }.getOrElse { val mapFromFile = readSnapshotFile(version).getOrElse { val prevMap = loadMap(version - 1) val deltaUpdates = readDeltaFile(version) - println(s"Reading delta for $version") val newMap = new MapType() newMap ++= prevMap newMap.sizeHint(prevMap.size) @@ -268,10 +285,9 @@ private[sql] class StateStore( case KeyRemoved(key) => newMap.remove(key) } } - println("Map = " + newMap.toSeq.mkString(", ")) newMap } - storeMaps.put(version, mapFromFile) + loadedMaps.put(version, mapFromFile) mapFromFile } } @@ -334,10 +350,10 @@ private[sql] class StateStore( val deltaFilesForLastVersion = filesForVersion(files, lastVersion).filter(_.isSnapshot == false) synchronized { - storeMaps.get(lastVersion) + 
loadedMaps.get(lastVersion) } match { case Some(map) => - if (deltaFilesForLastVersion.size > 10) { + if (deltaFilesForLastVersion.size > maxDeltaChainForSnapshots) { writeSnapshotFile(lastVersion, map) } case None => @@ -359,7 +375,7 @@ private[sql] class StateStore( if (earliestVersionToRetain >= 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head synchronized { - storeMaps.keys.filter(_ < earliestVersionToRetain).foreach(storeMaps.remove) + loadedMaps.keys.filter(_ < earliestVersionToRetain).foreach(loadedMaps.remove) } files.filter(_.version < earliestFileToRetain.version).foreach { f => fs.delete(f.path, true) @@ -408,19 +424,15 @@ private[sql] class StateStore( Seq.empty } val versionToFiles = new mutable.HashMap[Long, StoreFile] - println(s"Fetching files in $baseDir") files.foreach { status => val path = status.getPath - //println(s"\tTesting file $path") val nameParts = path.getName.split("\\.") - println(s"\t${path.getName}, ${nameParts.mkString(",")}") if (nameParts.size == 2) { val version = nameParts(0).toLong nameParts(1).toLowerCase match { case "delta" => // ignore the file otherwise, snapshot file already exists for that batch id if (!versionToFiles.contains(version)) { - //println(s"\tFound file $path") versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) } case "snapshot" => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 8c5e408ed45e3..d5439ec15e803 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -51,9 +51,9 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( storeDirectory ) val inputIter = dataRDD.compute(partition, ctxt) - store.startUpdates(newStoreVersion) + store.prepareForUpdates(newStoreVersion) val outputIter = storeUpdateFunction(store, inputIter) - assert(store.hasCommitted) + assert(!store.hasUncommittedUpdates) outputIter } { if (store != null) store.cancelUpdates() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 3d3680f9bee76..56efbaf7031ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -45,7 +45,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Verify state before starting a new set of updates assert(store.getAll().isEmpty) - assert(store.getInternalMap(0).isEmpty) + assert(!store.hasUncommittedUpdates) intercept[IllegalStateException] { store.update(null, null) } @@ -60,10 +60,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Verify states after starting updates - store.startUpdates(0) + store.prepareForUpdates(0) intercept[IllegalStateException] { store.getAll() } + assert(store.hasUncommittedUpdates) update(store, "a", 1) intercept[IllegalStateException] { store.getAll() @@ -75,43 +76,51 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth remove(store, _.startsWith("a")) store.commitUpdates() - // Very state after committing + // Verify state after committing + 
assert(!store.hasUncommittedUpdates) assert(getData(store) === Set("b" -> 2)) - assertMap(store.getInternalMap(0), Map("b" -> 2)) assert(fileExists(store, 0, isSnapshot = false)) - assert(store.getInternalMap(1).isEmpty) + + // Trying to get newer versions should fail + intercept[Exception] { + getData(store, 1) + } + + intercept[Exception] { + getDataFromFiles(store, 1) + } // Reload store from the directory val reloadedStore = new StateStore(store.id, store.directory) assert(getData(reloadedStore) === Set("b" -> 2)) // New updates to the reload store with new version, and does not change old version - reloadedStore.startUpdates(1) + reloadedStore.prepareForUpdates(1) update(reloadedStore, "c", 4) reloadedStore.commitUpdates() assert(getData(reloadedStore) === Set("b" -> 2, "c" -> 4)) - assertMap(reloadedStore.getInternalMap(0), Map("b" -> 2)) - assertMap(reloadedStore.getInternalMap(1), Map("b" -> 2, "c" -> 4)) + assert(getData(reloadedStore, version = 0) === Set("b" -> 2)) + assert(getData(reloadedStore, version = 1) === Set("b" -> 2, "c" -> 4)) assert(fileExists(reloadedStore, 1, isSnapshot = false)) } test("cancelUpdates") { val store = newStore() - store.startUpdates(0) + store.prepareForUpdates(0) update(store, "a", 1) store.commitUpdates() assert(getData(store) === Set("a" -> 1)) // cancelUpdates should not change the data - store.startUpdates(1) + store.prepareForUpdates(1) update(store, "b", 1) store.cancelUpdates() assert(getData(store) === Set("a" -> 1)) // Calling startUpdates again should cancel previous updates - store.startUpdates(1) + store.prepareForUpdates(1) update(store, "b", 1) - store.startUpdates(1) + store.prepareForUpdates(1) update(store, "c", 1) store.commitUpdates() assert(getData(store) === Set("a" -> 1, "c" -> 1)) @@ -121,40 +130,108 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store = newStore() intercept[IllegalArgumentException] { - store.startUpdates(-1) + store.prepareForUpdates(-1) } // Prepare some data in the stoer - store.startUpdates(0) + store.prepareForUpdates(0) update(store, "a", 1) store.commitUpdates() assert(getData(store) === Set("a" -> 1)) intercept[IllegalStateException] { - store.startUpdates(2) + store.prepareForUpdates(2) } // Update store version with some data - println("here") - store.startUpdates(1) + store.prepareForUpdates(1) update(store, "b", 1) store.commitUpdates() - println("xyz") assert(getData(store) === Set("a" -> 1, "b" -> 1)) - println("bla") - assert(getData(new StateStore(store.id, store.directory)) === Set("a" -> 1, "b" -> 1)) + assert(getDataFromFiles(store) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data - store.startUpdates(1) + store.prepareForUpdates(1) update(store, "c", 1) store.commitUpdates() assert(getData(store) === Set("a" -> 1, "c" -> 1)) - assert(getData(new StateStore(store.id, store.directory)) === Set("a" -> 1, "c" -> 1)) + assert(getDataFromFiles(store) === Set("a" -> 1, "c" -> 1)) + } + + test("snapshotting") { + val store = newStore(maxDeltaChainForSnapshots = 5) + + var currentVersion = -1 + def updateVersionTo(targetVersion: Int): Unit = { + for (i <- currentVersion + 1 to targetVersion) { + store.prepareForUpdates(i) + update(store, "a", i) + store.commitUpdates() + } + } + + updateVersionTo(2) + require(getData(store) === Set("a" -> 2)) + store.manageFiles() + assert(getDataFromFiles(store) === Set("a" -> 2)) + for (i <- 0 to 2) { + assert(fileExists(store, i, isSnapshot = false)) // all delta files present + 
assert(!fileExists(store, i, isSnapshot = true)) // no snapshot files present + } + + // After version 6, snapshotting should generate one snapshot file + updateVersionTo(6) + require(getData(store) === Set("a" -> 6), "Store not updated correctly") + store.manageFiles() // do snapshot + assert(getData(store) === Set("a" -> 6), "manageFiles() messed up the data") + assert(getDataFromFiles(store) === Set("a" -> 6)) + + val snapshotVersion = (0 to 6).find(version => fileExists(store, version, isSnapshot = true)) + assert(snapshotVersion.nonEmpty, "Snapshot file not generated") + + + // After version 20, snapshotting should generate newer snapshot files + updateVersionTo(20) + require(getData(store) === Set("a" -> 20), "Store not updated correctly") + store.manageFiles() // do snapshot + assert(getData(store) === Set("a" -> 20), "manageFiles() messed up the data") + assert(getDataFromFiles(store) === Set("a" -> 20)) + + val latestSnapshotVersion = (0 to 20).filter(version => + fileExists(store, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "No snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "Newer snapshot not generated") + + } + + test("cleaning") { + val store = newStore(maxDeltaChainForSnapshots = 5) + + for (i <- 0 to 20) { + store.prepareForUpdates(i) + update(store, "a", i) + store.commitUpdates() + } + require(getData(store) === Set("a" -> 20), "Store not updated correctly") + store.manageFiles() // do cleanup + assert(fileExists(store, 0, isSnapshot = false)) + + assert(getDataFromFiles(store, 20) === Set("a" -> 20)) + assert(getDataFromFiles(store, 19) === Set("a" -> 19)) + } + + def getData(store: StateStore, version: Int = -1): Set[(String, Int)] = { + if (version < 0) { + store.getAll.map(unwrapKeyValue).toSet + } else { + store.getAll(version).map(unwrapKeyValue).toSet + } + } - def getData(store: StateStore): Set[(String, Int)] = { - store.getAll.map(unwrapKeyValue).toSet + def getDataFromFiles(store: StateStore, version: Int = -1): Set[(String, Int)] = { + getData(new StateStore(store.id, store.directory), version) } def assertMap( @@ -184,10 +261,15 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore invokePrivate method(storeId) } - def newStore(opId: Long = Random.nextLong, partition: Int = 0): StateStore = { + def newStore( + opId: Long = Random.nextLong, + partition: Int = 0, + maxDeltaChainForSnapshots: Int = 10 + ): StateStore = { new StateStore( StateStoreId(opId, partition), - Utils.createDirectory(tempDir, Random.nextString(5)).toString) + Utils.createDirectory(tempDir, Random.nextString(5)).toString, + maxDeltaChainForSnapshots = maxDeltaChainForSnapshots) } def remove(store: StateStore, condition: String => Boolean): Unit = { From f417bde8bd65572e78125c7aa900221fa0bc4d7c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 10 Mar 2016 20:24:39 -0800 Subject: [PATCH 03/46] Added basic unit test for StateStoreRDD --- .../streaming/state/StateStore.scala | 7 +- .../state/StateStoreCoordinator.scala | 5 +- .../streaming/state/StateStoreRDD.scala | 23 ++----- .../execution/streaming/state/package.scala | 38 +++++++++++ .../streaming/state/StateStoreRDDSuite.scala | 64 +++++++++++++++++++ .../streaming/state/StateStoreSuite.scala | 31 ++++----- 6 files changed, 128 insertions(+), 40 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala create mode 100644 
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index edc1025580736..626c4e0fe4120 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -172,9 +172,10 @@ private[sql] class StateStore( /** Cancel all the updates that have been made to the store. */ def cancelUpdates(): Unit = { - verify(uncommittedDelta != null, "Cannot commit data before calling newVersion()") - uncommittedDelta.cancel() - uncommittedDelta = null + if (uncommittedDelta != null) { + uncommittedDelta.cancel() + uncommittedDelta = null + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 39380f4d46065..62d94b9614bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -37,8 +37,9 @@ class StateStoreCoordinator(rpcEnv: RpcEnv) { "StateStoreCoordinator", new StateStoreCoordinatorEndpoint(rpcEnv, this)) private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] - def reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Unit = { + def reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Boolean = { instances.synchronized { instances.put(storeId, ExecutorCacheTaskLocation(host, executorId)) } + true } def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { @@ -74,7 +75,7 @@ private[spark] object StateStoreCoordinator { override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => - coordinator.reportActiveInstance(id, host, executorId) + context.reply(coordinator.reportActiveInstance(id, host, executorId)) case VerifyIfInstanceActive(id, executor) => context.reply(coordinator.verifyIfInstanceActive(id, executor)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index d5439ec15e803..55154b5cb190b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -32,21 +32,22 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( operatorId: Long, newStoreVersion: Long, storeDirectory: String, - storeCoordinator: StateStoreCoordinator - ) - extends RDD[OUTPUT](dataRDD) { + storeCoordinator: StateStoreCoordinator) extends RDD[OUTPUT](dataRDD) { override protected def getPartitions: Array[Partition] = dataRDD.partitions override def getPreferredLocations(partition: Partition): Seq[String] = { + Seq.empty + /* storeCoordinator.getLocation( StateStoreId(operatorId, partition.index)).toSeq + */ } override def compute(partition: Partition, ctxt: TaskContext): Iterator[OUTPUT] = { var store: StateStore = null Utils.tryWithSafeFinally { - StateStore.get( + store = 
StateStore.get( StateStoreId(operatorId, partition.index), storeDirectory ) @@ -60,17 +61,3 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( } } } - -object StateStoreRDD { - implicit def withStateStores[INPUT: ClassTag, OUTPUT: ClassTag]( - dataRDD: RDD[INPUT], - storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], - operatorId: Long, - newStoreVersion: Long, - storeDirectory: String, - storeCoordinator: StateStoreCoordinator - ): RDD[OUTPUT] = { - new StateStoreRDD( - dataRDD, storeUpdateFunction, operatorId, newStoreVersion, storeDirectory, storeCoordinator) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala new file mode 100644 index 0000000000000..bcfc7dc309446 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD + +package object state { + + implicit class StateStoreOps[INPUT: ClassTag](dataRDD: RDD[INPUT]) { + def withStateStores[OUTPUT: ClassTag]( + storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], + operatorId: Long, + newStoreVersion: Long, + storeDirectory: String, + storeCoordinator: StateStoreCoordinator + ): RDD[OUTPUT] = { + new StateStoreRDD( + dataRDD, storeUpdateFunction, operatorId, newStoreVersion, storeDirectory, storeCoordinator) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala new file mode 100644 index 0000000000000..e291ad951bb92 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.util.Random + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + + +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with SharedSparkContext { + private var tempDir = Utils.createTempDir().toString + + import StateStoreSuite._ + + after { + StateStore.clearAll() + } + + test("StateStoreRDD") { + + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update(wrapKey(s), oldRow => { + val oldValue = oldRow.map(unwrapValue).getOrElse(0) + wrapValue(oldValue + 1) + }) + } + store.commitUpdates() + store.getAll().map(unwrapKeyValue) + } + val opId = 0 + val rdd1 = makeRDD(Seq("a", "b", "a")) + .withStateStores(increment, opId, newStoreVersion = 0, path, null) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + val rdd2 = makeRDD(Seq("a", "c")) + .withStateStores(increment, opId, newStoreVersion = 1, path, null) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + } + + private def makeRDD(seq: Seq[String]): RDD[String] = { + sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 56efbaf7031ef..e48a4ee438e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -23,17 +23,19 @@ import scala.collection.mutable import scala.util.Random import org.apache.hadoop.fs.Path -import org.scalatest.{PrivateMethodTester, BeforeAndAfter} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[InternalRow, InternalRow] + import StateStoreSuite._ + private val tempDir = Utils.createTempDir().toString after { @@ -279,36 +281,31 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth private def update(store: StateStore, key: String, value: Int): Unit = { store.update(wrapKey(key), _ => wrapValue(value)) } +} - private def increment(store: StateStore, key: String): Unit = { - val keyRow = new GenericInternalRow(Array(key).asInstanceOf[Array[Any]]) - store.update(keyRow, oldRow => { - val oldValue = oldRow.map(unwrapValue).getOrElse(0) - wrapValue(oldValue + 1) - }) - } +private[state] object StateStoreSuite { - private def wrapValue(i: Int): InternalRow = { + def wrapValue(i: Int): InternalRow = { new GenericInternalRow(Array[Any](i)) } - private def wrapKey(s: String): InternalRow = { + def wrapKey(s: String): InternalRow = { new GenericInternalRow(Array[Any](UTF8String.fromString(s))) } - private def 
unwrapKey(row: InternalRow): String = { + def unwrapKey(row: InternalRow): String = { row.asInstanceOf[GenericInternalRow].getString(0) } - private def unwrapValue(row: InternalRow): Int = { + def unwrapValue(row: InternalRow): Int = { row.asInstanceOf[GenericInternalRow].getInt(0) } - private def unwrapKeyValue(row: (InternalRow, InternalRow)): (String, Int) = { + def unwrapKeyValue(row: (InternalRow, InternalRow)): (String, Int) = { (unwrapKey(row._1), unwrapValue(row._2)) } - private def unwrapKeyValue(row: InternalRow): (String, Int) = { + def unwrapKeyValue(row: InternalRow): (String, Int) = { (row.getString(0), row.getInt(1)) } -} +} \ No newline at end of file From d8cee54ad183cf899031fb762bba032c6a9cffd5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 10 Mar 2016 20:39:16 -0800 Subject: [PATCH 04/46] Style fix --- .../spark/sql/execution/streaming/state/StateStore.scala | 6 ++---- .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 2 +- .../sql/execution/streaming/state/StateStoreSuite.scala | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 626c4e0fe4120..97b005400c03d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -27,11 +27,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow -import org.apache.spark.util.{RpcUtils, CompletionIterator, Utils} -import org.apache.spark.{SparkEnv, Logging, SparkConf} +import org.apache.spark.util.{CompletionIterator, RpcUtils, Utils} case class StateStoreId(operatorId: Long, partitionId: Int) @@ -441,8 +441,6 @@ private[sql] class StateStore( case _ => logWarning(s"Could not identify file $path") } - } else { - println("\tIgnoring") } } versionToFiles.values.toSeq.sortBy(_.version) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 55154b5cb190b..76e79d962de6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.state import scala.reflect.ClassTag +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import org.apache.spark.{Partition, TaskContext} /** * Created by tdas on 3/9/16. 
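As an aside on the row helpers moved into the StateStoreSuite companion object above: StateStore.update takes a key plus an Option[InternalRow] => InternalRow function, so the increment used by both suites is just "unwrap the old value if present, add one, wrap it again". A minimal, self-contained sketch of that shape (the object name is illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.unsafe.types.UTF8String

    object UpsertSketch {
      // Same shapes as the shared test helpers above.
      def wrapKey(s: String): InternalRow =
        new GenericInternalRow(Array[Any](UTF8String.fromString(s)))
      def wrapValue(i: Int): InternalRow =
        new GenericInternalRow(Array[Any](i))
      def unwrapValue(row: InternalRow): Int =
        row.getInt(0)

      // The update function handed to StateStore.update: a missing key starts at 0,
      // an existing value is incremented by one.
      val increment: Option[InternalRow] => InternalRow =
        oldRow => wrapValue(oldRow.map(unwrapValue).getOrElse(0) + 1)
    }

A call site then reads store.update(wrapKey("a"), UpsertSketch.increment), which is exactly the shape of the increment closure in StateStoreRDDSuite.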
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index e48a4ee438e76..dd13cd026343f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -308,4 +308,4 @@ private[state] object StateStoreSuite { def unwrapKeyValue(row: InternalRow): (String, Int) = { (row.getString(0), row.getInt(1)) } -} \ No newline at end of file +} From 7d74c67f999d53e9f02cf9ed5d2c8d2e951144ca Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 11 Mar 2016 18:55:16 -0800 Subject: [PATCH 05/46] Fixed test --- .../spark/sql/execution/streaming/state/StateStoreSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index dd13cd026343f..8d95329758966 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -57,9 +57,6 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth intercept[IllegalStateException] { store.commitUpdates() } - intercept[IllegalStateException] { - store.cancelUpdates() - } // Verify states after starting updates store.prepareForUpdates(0) From c5dd06159ae88cc0830f680a60ad29944115051c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Mar 2016 02:10:37 -0700 Subject: [PATCH 06/46] Fixed versioning in StateStoreRDD, and made store updates thread-safe --- .../streaming/state/StateStore.scala | 179 ++++++++++++------ .../streaming/state/StateStoreRDD.scala | 1 + .../streaming/state/StateStoreRDDSuite.scala | 54 +++--- .../streaming/state/StateStoreSuite.scala | 27 ++- 4 files changed, 168 insertions(+), 93 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 97b005400c03d..0fd8d6d8f1edc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -37,9 +37,9 @@ case class StateStoreId(operatorId: Long, partitionId: Int) private[state] object StateStore extends Logging { - sealed trait Update - case class ValueUpdated(key: InternalRow, value: InternalRow) extends Update - case class KeyRemoved(key: InternalRow) extends Update + sealed trait StoreUpdate + case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate + case class KeyRemoved(key: InternalRow) extends StoreUpdate private val loadedStores = new mutable.HashMap[StateStoreId, StateStore]() private val managementTimer = new Timer("StateStore Timer", true) @@ -50,6 +50,7 @@ private[state] object StateStore extends Logging { startIfNeeded() loadedStores.getOrElseUpdate(storeId, new StateStore(storeId, directory)) } + println(s"Got state store $storeId") reportActiveInstance(storeId) store } @@ -129,7 +130,13 @@ private[sql] class StateStore( private val fs = baseDir.getFileSystem(new Configuration()) private val serializer = new KryoSerializer(new SparkConf) - 
@volatile private var uncommittedDelta: UncommittedUpdates = null + /** + * Thread local variable to keep track of updates so that if there multiple speculative tasks + * in the same executor trying to update the same store, the updates are thread-safe. + */ + private val uncommittedDelta = new ThreadLocal[UncommittedUpdates]() { + override def initialValue(): UncommittedUpdates = new UncommittedUpdates() + } initialize() @@ -139,64 +146,44 @@ private[sql] class StateStore( * be overwritten when the updates are committed. */ private[state] def prepareForUpdates(version: Long): Unit = synchronized { - require(version >= 0) - if (uncommittedDelta != null) { - cancelUpdates() - } + require(version >= 0, "Version cannot be less than 0") val newMap = new MapType() if (version > 0) { - val oldMap = loadMap(version - 1) - newMap ++= oldMap + newMap ++= loadMap(version - 1) } - uncommittedDelta = new UncommittedUpdates(version, newMap) + uncommittedDelta.get.prepare(version, newMap) } - /** Update the value of a key using the `updateFunc` */ + /** Update the value of a key using the value generated by the update function */ def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { - verify(uncommittedDelta != null, "Cannot update data before calling newVersion()") - uncommittedDelta.update(key, updateFunc) + uncommittedDelta.get.update(key, updateFunc) } - /** Remove keys that satisfy the following condition */ + /** Remove keys that match the following condition */ def remove(condition: InternalRow => Boolean): Unit = { - verify(uncommittedDelta != null, "Cannot remove data before calling newVersion()") - uncommittedDelta.remove(condition) + uncommittedDelta.get.remove(condition) } /** Commit all the updates that have been made to the store. */ def commitUpdates(): Unit = { - verify(uncommittedDelta != null, "Cannot commit data before calling newVersion()") - uncommittedDelta.commitAndWriteDeltaFile() - uncommittedDelta = null + uncommittedDelta.get.commit() } /** Cancel all the updates that have been made to the store. */ def cancelUpdates(): Unit = { - if (uncommittedDelta != null) { - uncommittedDelta.cancel() - uncommittedDelta = null - } + uncommittedDelta.get.reset() } - /** - * Get all the data of the latest version of the store. - * Note that this will look up the files to determined the latest known version. 
- */ + def lastCommittedData(): Iterator[InternalRow] = { + uncommittedDelta.get.lastCommittedData() + } - def getAll(): Iterator[InternalRow] = synchronized { - verify(uncommittedDelta == null, "Cannot getAll() while there are uncommitted updates") - val versionsInFiles = fetchFiles().map(_.version).toSet - val versionsLoaded = loadedMaps.keySet - val allKnownVersions = versionsInFiles ++ versionsLoaded - if (allKnownVersions.nonEmpty) { - loadMap(allKnownVersions.max) - .iterator - .map { case (key, value) => new JoinedRow(key, value) } - } else Iterator.empty + def lastCommittedUpdates(): Iterator[StoreUpdate] = { + uncommittedDelta.get.lastCommittedUpdates() } private[state] def hasUncommittedUpdates: Boolean = { - uncommittedDelta != null + uncommittedDelta.get.hasUncommittedUpdates } override def toString(): String = { @@ -207,48 +194,96 @@ private[sql] class StateStore( private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - private class UncommittedUpdates(val version: Long, val map: MapType) { - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private val tempDeltaFileStream = - serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) + private class UncommittedUpdates { + + trait State + case object INITIALIZED extends State + case object PREPARED extends State + case object COMMITTED extends State + + private var finalDeltaFile: Path = null + private var tempDeltaFile: Path = null + private var tempDeltaFileStream: SerializationStream = null + private var updatedMap: MapType = null + private var updateVersion: Long = -1 + private var state: State = INITIALIZED + + /** + * Prepare the set updates to be made to the state store by setting the version and the initial + * map on which to apply the updates. + */ + def prepare(version: Long, map: MapType): Unit = { + reset() + updateVersion = version + updatedMap = map + finalDeltaFile = deltaFile(updateVersion) + tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") + tempDeltaFileStream = serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) + state = PREPARED + } + /** Update the value of a key using the value generated by the update function */ def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { - verify(uncommittedDelta != null, "Cannot call update() before calling startUpdates()") - val value = updateFunc(uncommittedDelta.map.get(key)) - uncommittedDelta.map.put(key, value) + verify(state == PREPARED, "Cannot call update() before calling startUpdates()") + val value = updateFunc(updatedMap.get(key)) + updatedMap.put(key, value) tempDeltaFileStream.writeObject(ValueUpdated(key, value)) } + /** Remove keys that match the following condition */ def remove(condition: InternalRow => Boolean): Unit = { - verify(uncommittedDelta != null, "Cannot call remove() before calling startUpdates()") - val keyIter = uncommittedDelta.map.keysIterator + verify(state == PREPARED, "Cannot call remove() before calling startUpdates()") + val keyIter = updatedMap.keysIterator while (keyIter.hasNext) { val key = keyIter.next if (condition(key)) { - uncommittedDelta.map.remove(key) + updatedMap.remove(key) tempDeltaFileStream.writeObject(KeyRemoved(key)) } } } - def commitAndWriteDeltaFile(): Unit = { + /** Commit all the updates that have been made to the store. 
*/ + def commit(): Unit = { + verify(state == PREPARED, "Cannot call commitUpdates() before calling prepareForUpdates()") try { tempDeltaFileStream.close() - val deltaFile = new Path(baseDir, s"${uncommittedDelta.version}.delta") + StateStore.this.synchronized { - fs.rename(tempDeltaFile, deltaFile) - loadedMaps.put(version, map) + fs.rename(tempDeltaFile, finalDeltaFile) + loadedMaps.put(updateVersion, updatedMap) } + state = COMMITTED } catch { case NonFatal(e) => + state = INITIALIZED throw new IllegalStateException( - s"Error committing version ${uncommittedDelta.version} into $this", e) + s"Error committing version $updateVersion into $this", e) + } + } + + def reset(): Unit = { + if (tempDeltaFileStream != null) { + tempDeltaFileStream.close() + } + if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { + fs.delete(tempDeltaFile, true) } + state = INITIALIZED + } + + def lastCommittedData(): Iterator[InternalRow] = { + verify(state == COMMITTED, "Cannot get iterator of data before calling commitUpdate()") + StateStore.this.iterator(updateVersion) + } + + def lastCommittedUpdates(): Iterator[StoreUpdate] = { + verify(state == COMMITTED, "Cannot get iterator of updates before calling commitUpdate()") + readDeltaFile(finalDeltaFile) } - def cancel(): Unit = { - tempDeltaFileStream.close() - fs.delete(tempDeltaFile, true) + def hasUncommittedUpdates: Boolean = { + state == PREPARED } } @@ -264,7 +299,22 @@ private[sql] class StateStore( } } - private[state] def getAll(version: Long): Iterator[InternalRow] = synchronized { + /** + * Get all the data of the latest version of the store. + * Note that this will look up the files to determined the latest known version. + */ + private[state] def latestIterator(): Iterator[InternalRow] = synchronized { + val versionsInFiles = fetchFiles().map(_.version).toSet + val versionsLoaded = loadedMaps.keySet + val allKnownVersions = versionsInFiles ++ versionsLoaded + if (allKnownVersions.nonEmpty) { + loadMap(allKnownVersions.max) + .iterator + .map { case (key, value) => new JoinedRow(key, value) } + } else Iterator.empty + } + + private[state] def iterator(version: Long): Iterator[InternalRow] = synchronized { loadMap(version) .iterator .map { case (key, value) => new JoinedRow(key, value) } @@ -293,17 +343,20 @@ private[sql] class StateStore( } } - private def readDeltaFile(version: Long): Iterator[Update] = { - val fileToRead = deltaFile(version) + private def readDeltaFile(version: Long): Iterator[StoreUpdate] = { + readDeltaFile(deltaFile(version)) + } + + private def readDeltaFile (fileToRead: Path): Iterator[StoreUpdate] = { if (!fs.exists(fileToRead)) { throw new IllegalStateException( - s"Cannot read delta file for version $version of $this: $fileToRead does not exist") + s"Cannot read delta file $fileToRead of $this: $fileToRead does not exist") } val deser = serializer.newInstance() var deserStream: DeserializationStream = null deserStream = deser.deserializeStream(fs.open(fileToRead)) - val iter = deserStream.asIterator.asInstanceOf[Iterator[Update]] - CompletionIterator[Update, Iterator[Update]](iter, { deserStream.close() }) + val iter = deserStream.asIterator.asInstanceOf[Iterator[StoreUpdate]] + CompletionIterator[StoreUpdate, Iterator[StoreUpdate]](iter, { deserStream.close() }) } private def writeSnapshotFile(version: Long, map: MapType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 76e79d962de6e..90c704f22fdf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -47,6 +47,7 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( var store: StateStore = null Utils.tryWithSafeFinally { + println(s"Getting store for version $newStoreVersion") store = StateStore.get( StateStoreId(operatorId, partition.index), storeDirectory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index e291ad951bb92..a9173b5f663db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql.execution.streaming.state import scala.util.Random -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { -class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with SharedSparkContext { + private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) private var tempDir = Utils.createTempDir().toString import StateStoreSuite._ @@ -35,30 +37,36 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with SharedSp StateStore.clearAll() } - test("StateStoreRDD") { - - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update(wrapKey(s), oldRow => { - val oldValue = oldRow.map(unwrapValue).getOrElse(0) - wrapValue(oldValue + 1) - }) + test("versioning and immuability") { + withSpark(new SparkContext(conf)) { sc => + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + wrapKey(s), oldRow => { + val oldValue = oldRow.map(unwrapValue).getOrElse(0) + wrapValue(oldValue + 1) + }) + } + store.commitUpdates() + store.lastCommittedData().map(unwrapKeyValue) } - store.commitUpdates() - store.getAll().map(unwrapKeyValue) - } - val opId = 0 - val rdd1 = makeRDD(Seq("a", "b", "a")) - .withStateStores(increment, opId, newStoreVersion = 0, path, null) - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")) + .withStateStores(increment, opId, newStoreVersion = 0, path, null) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) - val rdd2 = makeRDD(Seq("a", "c")) - .withStateStores(increment, opId, newStoreVersion = 1, path, null) - assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")) + .withStateStores(increment, opId, newStoreVersion = 1, path, null) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" 
-> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } } - private def makeRDD(seq: Seq[String]): RDD[String] = { + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 8d95329758966..9be2553b0ae69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -46,8 +46,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store = newStore() // Verify state before starting a new set of updates - assert(store.getAll().isEmpty) + assert(store.latestIterator().isEmpty) assert(!store.hasUncommittedUpdates) + intercept[IllegalStateException] { + store.lastCommittedData() + } intercept[IllegalStateException] { store.update(null, null) } @@ -58,16 +61,25 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth store.commitUpdates() } - // Verify states after starting updates + // Verify states after preparing for updates + intercept[IllegalArgumentException] { + store.prepareForUpdates(-1) + } store.prepareForUpdates(0) intercept[IllegalStateException] { - store.getAll() + store.lastCommittedData() + } + intercept[IllegalStateException] { + store.prepareForUpdates(1) } assert(store.hasUncommittedUpdates) + + // Verify state after updating update(store, "a", 1) intercept[IllegalStateException] { - store.getAll() + store.lastCommittedData() } + assert(store.latestIterator().isEmpty) // Make updates and commit update(store, "b", 2) @@ -116,9 +128,10 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth store.cancelUpdates() assert(getData(store) === Set("a" -> 1)) - // Calling startUpdates again should cancel previous updates + // Calling prepareForUpdates again should cancel previous updates store.prepareForUpdates(1) update(store, "b", 1) + store.prepareForUpdates(1) update(store, "c", 1) store.commitUpdates() @@ -222,9 +235,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def getData(store: StateStore, version: Int = -1): Set[(String, Int)] = { if (version < 0) { - store.getAll.map(unwrapKeyValue).toSet + store.latestIterator.map(unwrapKeyValue).toSet } else { - store.getAll(version).map(unwrapKeyValue).toSet + store.iterator(version).map(unwrapKeyValue).toSet } } From 7adca7004085a463fd1768708f01a7a01775cea8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Mar 2016 02:15:31 -0700 Subject: [PATCH 07/46] Style fixes --- .../apache/spark/sql/execution/streaming/state/StateStore.scala | 1 - .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 0fd8d6d8f1edc..3abb89cf48c43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -50,7 +50,6 @@ private[state] object StateStore extends 
Logging { startIfNeeded() loadedStores.getOrElseUpdate(storeId, new StateStore(storeId, directory)) } - println(s"Got state store $storeId") reportActiveInstance(storeId) store } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 90c704f22fdf1..76e79d962de6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -47,7 +47,6 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( var store: StateStore = null Utils.tryWithSafeFinally { - println(s"Getting store for version $newStoreVersion") store = StateStore.get( StateStoreId(operatorId, partition.index), storeDirectory From a0ba498185286a921ef303dbc087dad5d226fa41 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Mar 2016 03:13:09 -0700 Subject: [PATCH 08/46] Added docs --- .../streaming/state/StateStore.scala | 243 +++++++++++------- .../streaming/state/StateStoreRDDSuite.scala | 2 +- 2 files changed, 158 insertions(+), 87 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 3abb89cf48c43..b97de87e47f92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -35,84 +35,36 @@ import org.apache.spark.util.{CompletionIterator, RpcUtils, Utils} case class StateStoreId(operatorId: Long, partitionId: Int) -private[state] object StateStore extends Logging { - - sealed trait StoreUpdate - case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate - case class KeyRemoved(key: InternalRow) extends StoreUpdate - - private val loadedStores = new mutable.HashMap[StateStoreId, StateStore]() - private val managementTimer = new Timer("StateStore Timer", true) - @volatile private var managementTask: TimerTask = null - - def get(storeId: StateStoreId, directory: String): StateStore = { - val store = loadedStores.synchronized { - startIfNeeded() - loadedStores.getOrElseUpdate(storeId, new StateStore(storeId, directory)) - } - reportActiveInstance(storeId) - store - } - - def clearAll(): Unit = loadedStores.synchronized { - loadedStores.clear() - if (managementTask != null) { - managementTask.cancel() - managementTask = null - } - } - - private def remove(storeId: StateStoreId): Unit = { - loadedStores.remove(storeId) - } - - private def reportActiveInstance(storeId: StateStoreId): Unit = { - val host = SparkEnv.get.blockManager.blockManagerId.host - val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - askCoordinator[Boolean](ReportActiveInstance(storeId, host, executorId)) - } - - private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { - val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - askCoordinator[Boolean](VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) - } - - private def askCoordinator[T: ClassTag](message: StateStoreCoordinatorMessage): Option[T] = { - try { - val env = SparkEnv.get - if (env != null) { - val coordinatorRef = RpcUtils.makeDriverRef("StateStoreCoordinator", env.conf, env.rpcEnv) - Some(coordinatorRef.askWithRetry[T](message)) - } else { - None 
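The companion object being rearranged here drives its snapshotting and cleanup from a daemon java.util.Timer. A small self-contained sketch of that scheduling pattern, with illustrative names and the same 10 second period that startIfNeeded() hard-codes:

    import java.util.{Timer, TimerTask}

    object MaintenanceTimerSketch {
      // Daemon timer, so a forgotten task never keeps the JVM alive.
      private val timer = new Timer("StateStore Maintenance", true)
      @volatile private var task: TimerTask = null

      def start(runPass: () => Unit): Unit = synchronized {
        if (task == null) {
          task = new TimerTask { override def run(): Unit = runPass() }
          // First pass after 10 seconds, then every 10 seconds.
          timer.schedule(task, 10000, 10000)
        }
      }

      def stop(): Unit = synchronized {
        if (task != null) {
          task.cancel()
          task = null
        }
      }
    }

In the real code the periodic task body is manageFiles(), which snapshots and cleans every loaded store.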
- } - } catch { - case NonFatal(e) => - clearAll() - None - } - } - - private def startIfNeeded(): Unit = loadedStores.synchronized { - if (managementTask == null) { - managementTask = new TimerTask { - override def run(): Unit = { manageFiles() } - } - managementTimer.schedule(managementTask, 10000, 10000) - } - } - - private def manageFiles(): Unit = { - loadedStores.synchronized { loadedStores.values.toSeq }.foreach { store => - try { - store.manageFiles() - } catch { - case NonFatal(e) => - logWarning(s"Error performing snapshot and cleaning up store ${store.id}") - } - } - } -} +/** + * A versioned key-value store which can be used to store streaming state data. All data is + * backed by a file system. All updates to the store has to be done in sets transactionally, and + * each set of updates increments the store's version. These versions can be used to re-execute the + * updates (by retries in RDD operations) on the correct version of the store, and regenerate + * the store version. + * + * Usage: + * To update the data in the state store, the following order of operations are needed. + * + * - val store = StateStore.get(operatorId, partitionId) // to get the right store + * - store.prepareForUpdates(newVersion) // must be called for doing any update + * - store.update(...) + * - store.remove(...) + * - store.commitUpdates() // commits all the updates to made with version number + * - store.lastCommittedData() // key-value data after last commit as an iterator + * - store.lastCommittedUpdates() // updates made in the last as an iterator + * + * Concurrency model: + * All updates made after prepareForUpdates() are local to the thread. So concurrent attempts + * from multiple threads will create multiple sets of updates that need to be committed separately. + * + * Fault-tolerance model: + * - Every set of updates is written to a delta file before committing. + * - The state store is responsible for managing, collapsing and cleaning up of delta files. + * - Multiple attempts to commit the same version of updates must have the same updates. + * - Background management of files ensures that last versions of the store is always recoverable + * to ensure re-executed RDD operations re-apply updates on the correct past version of the + * store. + */ private[sql] class StateStore( val id: StateStoreId, @@ -144,26 +96,35 @@ private[sql] class StateStore( * are made on the `version - 1` of the store data. If `version` already exists, it will * be overwritten when the updates are committed. */ - private[state] def prepareForUpdates(version: Long): Unit = synchronized { - require(version >= 0, "Version cannot be less than 0") + private[state] def prepareForUpdates(newVersion: Long): Unit = synchronized { + require(newVersion >= 0, "Version cannot be less than 0") val newMap = new MapType() - if (version > 0) { - newMap ++= loadMap(version - 1) + if (newVersion > 0) { + newMap ++= loadMap(newVersion - 1) } - uncommittedDelta.get.prepare(version, newMap) + uncommittedDelta.get.prepare(newVersion, newMap) } - /** Update the value of a key using the value generated by the update function */ + /** + * Update the value of a key using the value generated by the update function. + * This can be called only after prepareForUpdates() has been called in the same thread. 
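To make the ordering described in the class comment above concrete, here is a minimal sketch of one task thread driving a store through a single version. It assumes a test-style body inside this package, with org.apache.spark.util.Utils and the wrapKey/wrapValue/unwrapKeyValue helpers from StateStoreSuite in scope, and it constructs the store directly with a temporary directory rather than going through StateStore.get:

    val store = new StateStore(StateStoreId(operatorId = 0, partitionId = 0), Utils.createTempDir().toString)

    store.prepareForUpdates(0)                       // begin building version 0 in this thread
    store.update(wrapKey("a"), _ => wrapValue(1))    // upsert a key
    store.remove(_ => false)                         // key predicate; nothing matches here
    store.commitUpdates()                            // writes 0.delta and publishes the new map

    val data = store.lastCommittedData().map(unwrapKeyValue).toSet   // Set("a" -> 1)
    val changes = store.lastCommittedUpdates().toSeq                 // the ValueUpdated record for "a"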
+ */ def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { uncommittedDelta.get.update(key, updateFunc) } - /** Remove keys that match the following condition */ + /** + * Remove keys that match the following condition. + * This can be called only after prepareForUpdates() has been called in the current thread. + */ def remove(condition: InternalRow => Boolean): Unit = { uncommittedDelta.get.remove(condition) } - /** Commit all the updates that have been made to the store. */ + /** + * Commit all the updates that have been made to the store. + * This can be called only after prepareForUpdates() has been called in the current thread. + */ def commitUpdates(): Unit = { uncommittedDelta.get.commit() } @@ -173,14 +134,26 @@ private[sql] class StateStore( uncommittedDelta.get.reset() } + /** + * Iterator of store data after a set of updates have been committed. + * This can be called only after commitUpdates() has been called in the current thread. + */ def lastCommittedData(): Iterator[InternalRow] = { uncommittedDelta.get.lastCommittedData() } + /** + * Iterator of the updates that have been committed. + * This can be called only after commitUpdates() has been called in the current thread. + */ def lastCommittedUpdates(): Iterator[StoreUpdate] = { uncommittedDelta.get.lastCommittedUpdates() } + + /** + * Whether there are updates made in the current thread that have not been committed yet. + */ private[state] def hasUncommittedUpdates: Boolean = { uncommittedDelta.get.hasUncommittedUpdates } @@ -193,6 +166,7 @@ private[sql] class StateStore( private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + /** Class representing all the data related to a set of updates to the store */ private class UncommittedUpdates { trait State @@ -395,6 +369,7 @@ private[sql] class StateStore( cleanup() } + /** Perform a snapshot of the store to allow delta files to be consolidated */ private def doSnapshot(): Unit = { try { val files = fetchFiles() @@ -420,6 +395,11 @@ private[sql] class StateStore( } } + /** + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. + */ private[state] def cleanup(): Unit = { try { val files = fetchFiles() @@ -512,3 +492,94 @@ private[sql] class StateStore( } } } + + +/** + * Companion object to [[StateStore]] that provides helper methods to create and retrive stores + * by their unique ids. + */ +private[state] object StateStore extends Logging { + + sealed trait StoreUpdate + + case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate + + case class KeyRemoved(key: InternalRow) extends StoreUpdate + + private val loadedStores = new mutable.HashMap[StateStoreId, StateStore]() + private val managementTimer = new Timer("StateStore Timer", true) + @volatile private var managementTask: TimerTask = null + + /** Get or create a store associated with the id. 
*/ + def get(storeId: StateStoreId, directory: String): StateStore = { + val store = loadedStores.synchronized { + startIfNeeded() + loadedStores.getOrElseUpdate(storeId, new StateStore(storeId, directory)) + } + reportActiveInstance(storeId) + store + } + + def clearAll(): Unit = loadedStores.synchronized { + loadedStores.clear() + if (managementTask != null) { + managementTask.cancel() + managementTask = null + } + } + + private def remove(storeId: StateStoreId): Unit = { + loadedStores.remove(storeId) + } + + private def reportActiveInstance(storeId: StateStoreId): Unit = { + val host = SparkEnv.get.blockManager.blockManagerId.host + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + askCoordinator[Boolean](ReportActiveInstance(storeId, host, executorId)) + } + + private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + askCoordinator[Boolean](VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) + } + + private def askCoordinator[T: ClassTag](message: StateStoreCoordinatorMessage): Option[T] = { + try { + val env = SparkEnv.get + if (env != null) { + val coordinatorRef = RpcUtils.makeDriverRef("StateStoreCoordinator", env.conf, env.rpcEnv) + Some(coordinatorRef.askWithRetry[T](message)) + } else { + None + } + } catch { + case NonFatal(e) => + clearAll() + None + } + } + + private def startIfNeeded(): Unit = loadedStores.synchronized { + if (managementTask == null) { + managementTask = new TimerTask { + override def run(): Unit = { + manageFiles() + } + } + managementTimer.schedule(managementTask, 10000, 10000) + } + } + + private def manageFiles(): Unit = { + loadedStores.synchronized { + loadedStores.values.toSeq + }.foreach { store => + try { + store.manageFiles() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up store ${store.id}") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index a9173b5f663db..69f69cf3eeedc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -21,10 +21,10 @@ import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { From bee673c794cda3484acf5a8933b8243a49eab27a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Mar 2016 16:56:16 -0700 Subject: [PATCH 09/46] Refactored for new design --- .../state/HDFSBackedStateStoreProvider.scala | 445 ++++++++++++++++ .../streaming/state/StateStore.scala | 490 ++---------------- .../streaming/state/StateStoreRDD.scala | 8 +- .../streaming/state/StateStoreRDDSuite.scala | 4 +- .../streaming/state/StateStoreSuite.scala | 219 ++++---- 5 files changed, 593 insertions(+), 573 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala new file mode 100644 index 0000000000000..a70cf5d4271c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.{TimerTask, Timer} + +import scala.collection.mutable +import scala.collection.mutable.HashMap +import scala.util.Random +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.streaming.state.StateStore._ +import org.apache.spark.util.{Utils, CompletionIterator} +import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, KryoSerializer} +import org.apache.spark.sql.catalyst.InternalRow + + +/** + * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed + * by files in a HDFS-compatible file system. All updates to the store has to be done in sets + * transactionally, and each set of updates increments the store's version. These versions can + * be used to re-execute the updates (by retries in RDD operations) on the correct version of + * the store, and regenerate the store version. + * + * Usage: + * To update the data in the state store, the following order of operations are needed. + * + * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store + * - store.update(...) + * - store.remove(...) + * - store.commit() // commits all the updates to made with version number + * - store.iterator() // key-value data after last commit as an iterator + * - store.updates() // updates made in the last as an iterator + * + * Fault-tolerance model: + * - Every set of updates is written to a delta file before committing. + * - The state store is responsible for managing, collapsing and cleaning up of delta files. + * - Multiple attempts to commit the same version of updates must have the same updates. + * - Background management of files ensures that last versions of the store is always recoverable + * to ensure re-executed RDD operations re-apply updates on the correct past version of the + * store. 
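To make the order of operations listed above concrete, a minimal sketch of one commit cycle against this provider. It assumes a test-style body inside this package, with org.apache.spark.util.Utils and the StateStoreSuite row helpers in scope, and it constructs the provider directly even though the documented path is StateStore.get:

    val provider = new HDFSBackedStateStoreProvider(
      StateStoreId(operatorId = 0, partitionId = 0), Utils.createTempDir().toString)

    val store = provider.getStore(0)                 // base the new updates on version 0
    store.update(wrapKey("a"), _ => wrapValue(1))
    store.remove(_ => false)
    val newVersion = store.commit()                  // writes the delta file and returns the new version

    val rows = store.iterator()                      // key-value data after the commit
    val changes = store.updates()                    // the StoreUpdate records just written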
+ */ +class HDFSBackedStateStoreProvider( + val id: StateStoreId, + val directory: String, + numBatchesToRetain: Int = 2, + maxDeltaChainForSnapshots: Int = 10 + ) extends StateStoreProvider with Logging { + type MapType = mutable.HashMap[InternalRow, InternalRow] + + import StateStore._ + + + class HDFSBackedStateStore( val version: Long, mapToUpdate: MapType) + extends StateStore { + + trait STATE + case object UPDATING extends STATE + case object COMMITTED extends STATE + case object CANCELLED extends STATE + + private val newVersion = version + 1 + private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") + private val tempDeltaFileStream = + serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) + + @volatile private var state: STATE = UPDATING + @volatile private var finalDeltaFile: Path = null + + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + + /** Update the value of a key using the value generated by the update function */ + override def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { + verify(state == UPDATING, "Cannot update after already committed or cancelled") + val value = updateFunc(mapToUpdate.get(key)) + mapToUpdate.put(key, value) + tempDeltaFileStream.writeObject(ValueUpdated(key, value)) + } + + /** Remove keys that match the following condition */ + override def remove(condition: InternalRow => Boolean): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + val keyIter = mapToUpdate.keysIterator + while (keyIter.hasNext) { + val key = keyIter.next + if (condition(key)) { + mapToUpdate.remove(key) + tempDeltaFileStream.writeObject(KeyRemoved(key)) + } + } + } + + /** Commit all the updates that have been made to the store. */ + override def commit(): Long = { + verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + + try { + tempDeltaFileStream.close() + finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + state = COMMITTED + newVersion + } catch { + case NonFatal(e) => + throw new IllegalStateException( + s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + } + } + + /** Cancel all the updates made on this store. This store will not be usable any more. */ + override def cancel(): Unit = { + state = CANCELLED + if (tempDeltaFileStream != null) { + tempDeltaFileStream.close() + } + if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { + fs.delete(tempDeltaFile, true) + } + } + + /** + * Get an iterator of all the store data. This can be called only after committing the + * updates. + */ + override def iterator(): Iterator[InternalRow] = { + verify(state == COMMITTED, "Cannot get iterator of store data before comitting") + HDFSBackedStateStoreProvider.this.iterator(version) + } + + /** + * Get an iterator of all the updates made to the store in the current version. + * This can be called only after committing the updates. + */ + override def updates(): Iterator[StoreUpdate] = { + verify(state == COMMITTED, "Cannot get iterator of updates before committing") + readDeltaFile(finalDeltaFile) + } + + /** + * Whether all updates have been committed + */ + override def hasCommitted: Boolean = { + state == COMMITTED + } + } + + /** Get the state store for making updates to create a new `version` of the store. 
*/ + override def getStore(version: Long): StateStore = synchronized { + require(version >= 0, "Version cannot be less than 0") + val newMap = new MapType() + if (version > 0) { + newMap ++= loadMap(version) + } + new HDFSBackedStateStore(version, newMap) + } + + override def manage(): Unit = { + try { + doSnapshot() + cleanup() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up $this") + } + } + + override def toString(): String = { + s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" + } + + /* Internal classes and methods */ + + private val loadedMaps = new mutable.HashMap[Long, MapType] + private val baseDir = new Path(directory, s"${id.operatorId}/${id.partitionId.toString}") + private val fs = baseDir.getFileSystem(new Configuration()) + private val serializer = new KryoSerializer(new SparkConf) + + initialize() + + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + + /** Commit a set of updates to the store with the given new version */ + private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + synchronized { + val finalDeltaFile = deltaFile(newVersion) + fs.rename(tempDeltaFile, finalDeltaFile) + loadedMaps.put(newVersion, map) + finalDeltaFile + } + } + + /** + * Get iterator of all the data of the latest version of the store. + * Note that this will look up the files to determined the latest known version. + */ + private[state] def latestIterator(): Iterator[InternalRow] = synchronized { + val versionsInFiles = fetchFiles().map(_.version).toSet + val versionsLoaded = loadedMaps.keySet + val allKnownVersions = versionsInFiles ++ versionsLoaded + if (allKnownVersions.nonEmpty) { + loadMap(allKnownVersions.max) + .iterator + .map { case (key, value) => new JoinedRow(key, value) } + } else Iterator.empty + } + + /** Get iterator of a specific version of the store */ + private[state] def iterator(version: Long): Iterator[InternalRow] = synchronized { + loadMap(version) + .iterator + .map { case (key, value) => new JoinedRow(key, value) } + } + + /** Initialize the store provider */ + private def initialize(): Unit = { + if (!fs.exists(baseDir)) { + fs.mkdirs(baseDir) + } else { + if (!fs.isDirectory(baseDir)) { + throw new IllegalStateException( + s"Cannot use $directory for storing state data as" + + s"$baseDir already exists and is not a directory") + } + } + } + + /** Load the required version of the map data from the backing files */ + private def loadMap(version: Long): MapType = { + if (version < 0) return new MapType + synchronized { + loadedMaps.get(version) + }.getOrElse { + val mapFromFile = readSnapshotFile(version).getOrElse { + val prevMap = loadMap(version - 1) + val deltaUpdates = readDeltaFile(version) + val newMap = new MapType() + newMap ++= prevMap + newMap.sizeHint(prevMap.size) + while (deltaUpdates.hasNext) { + deltaUpdates.next match { + case ValueUpdated(key, value) => newMap.put(key, value) + case KeyRemoved(key) => newMap.remove(key) + } + } + newMap + } + loadedMaps.put(version, mapFromFile) + mapFromFile + } + } + + private def readDeltaFile(version: Long): Iterator[StoreUpdate] = { + readDeltaFile(deltaFile(version)) + } + + private def readDeltaFile(fileToRead: Path): Iterator[StoreUpdate] = { + if (!fs.exists(fileToRead)) { + throw new IllegalStateException( + s"Cannot read delta file $fileToRead of $this: $fileToRead does not exist") + } + val deser = serializer.newInstance() + var deserStream: DeserializationStream = null 
+ deserStream = deser.deserializeStream(fs.open(fileToRead)) + val iter = deserStream.asIterator.asInstanceOf[Iterator[StoreUpdate]] + CompletionIterator[StoreUpdate, Iterator[StoreUpdate]]( + iter, { + deserStream.close() + }) + } + + private def writeSnapshotFile(version: Long, map: MapType): Unit = { + val fileToWrite = snapshotFile(version) + val ser = serializer.newInstance() + var outputStream: SerializationStream = null + Utils.tryWithSafeFinally { + outputStream = ser.serializeStream(fs.create(fileToWrite, false)) + outputStream.writeAll(map.iterator) + } { + if (outputStream != null) outputStream.close() + } + } + + private def readSnapshotFile(version: Long): Option[MapType] = { + val fileToRead = snapshotFile(version) + if (!fs.exists(fileToRead)) return None + + val deser = serializer.newInstance() + val map = new MapType() + var deserStream: DeserializationStream = null + + try { + deserStream = deser.deserializeStream(fs.open(fileToRead)) + val iter = deserStream.asIterator.asInstanceOf[Iterator[(InternalRow, InternalRow)]] + while (iter.hasNext) { + map += iter.next() + } + Some(map) + } finally { + if (deserStream != null) deserStream.close() + } + } + + + /** Perform a snapshot of the store to allow delta files to be consolidated */ + private def doSnapshot(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + synchronized { + loadedMaps.get(lastVersion) + } match { + case Some(map) => + if (deltaFilesForLastVersion.size > maxDeltaChainForSnapshots) { + writeSnapshotFile(lastVersion, map) + } + case None => + // The last map is not loaded, probably some other instance is incharge + } + + } + } catch { + case NonFatal(e) => + logWarning(s"Error doing snapshots for $this") + } + } + + /** + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. 
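As background for the cleanup logic just below, a pure-Scala sketch of the file-layout reasoning behind cleanup() and filesForVersion(): a version is rebuilt from the most recent snapshot at or below it plus the delta files written after that snapshot, so cleanup() only has to keep the files that the retained versions still need. The values and names here are hypothetical:

    case class StoreFileSketch(version: Long, isSnapshot: Boolean)

    // Files needed to rebuild `version`: the latest snapshot at or below it, then the later deltas.
    def filesNeededFor(sortedFiles: Seq[StoreFileSketch], version: Long): Seq[StoreFileSketch] = {
      val snapshot = sortedFiles.filter(f => f.isSnapshot && f.version <= version).lastOption
      val minDelta = snapshot.map(_.version).getOrElse(-1L)
      val deltas = sortedFiles.filter(f => !f.isSnapshot && f.version > minDelta && f.version <= version)
      snapshot.toSeq ++ deltas
    }

    val files = Seq(
      StoreFileSketch(1, isSnapshot = false),   // 1.delta
      StoreFileSketch(2, isSnapshot = false),   // 2.delta
      StoreFileSketch(3, isSnapshot = true),    // 3.snapshot, a full copy of version 3
      StoreFileSketch(4, isSnapshot = false))   // 4.delta

    // Version 4 is rebuilt from 3.snapshot plus 4.delta; once no retained version is
    // older than the snapshot, the earlier delta files become eligible for deletion.
    assert(filesNeededFor(files, 4) ==
      Seq(StoreFileSketch(3, isSnapshot = true), StoreFileSketch(4, isSnapshot = false)))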
+ */ + private[state] def cleanup(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - numBatchesToRetain + if (earliestVersionToRetain >= 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + synchronized { + loadedMaps.keys.filter(_ < earliestVersionToRetain).foreach(loadedMaps.remove) + } + files.filter(_.version < earliestFileToRetain.version).foreach { f => + fs.delete(f.path, true) + } + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this") + } + } + + private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + val deltaBatchIds = (snapshotFile.version + 1) to version + + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + } + verify( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version: ${deltaFiles.mkString(",")}" + ) + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + private def fetchFiles(): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fs.listStatus(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + files.foreach { status => + val path = status.getPath + val nameParts = path.getName.split("\\.") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + logWarning(s"Could not identify file $path") + } + } + } + versionToFiles.values.toSeq.sortBy(_.version) + } + + private def deltaFile(version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + private def snapshotFile(version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index b97de87e47f92..4a4ed427c53c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -21,476 +21,75 @@ import java.util.{Timer, TimerTask} import scala.collection.mutable import scala.reflect.ClassTag -import scala.util.Random import scala.util.control.NonFatal -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} - -import org.apache.spark.{Logging, SparkConf, SparkEnv} -import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} import 
org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.JoinedRow -import org.apache.spark.util.{CompletionIterator, RpcUtils, Utils} +import org.apache.spark.util.RpcUtils +import org.apache.spark.{Logging, SparkEnv} +/** Unique identifier for a [[StateStore]] */ case class StateStoreId(operatorId: Long, partitionId: Int) /** - * A versioned key-value store which can be used to store streaming state data. All data is - * backed by a file system. All updates to the store has to be done in sets transactionally, and - * each set of updates increments the store's version. These versions can be used to re-execute the - * updates (by retries in RDD operations) on the correct version of the store, and regenerate - * the store version. - * - * Usage: - * To update the data in the state store, the following order of operations are needed. - * - * - val store = StateStore.get(operatorId, partitionId) // to get the right store - * - store.prepareForUpdates(newVersion) // must be called for doing any update - * - store.update(...) - * - store.remove(...) - * - store.commitUpdates() // commits all the updates to made with version number - * - store.lastCommittedData() // key-value data after last commit as an iterator - * - store.lastCommittedUpdates() // updates made in the last as an iterator - * - * Concurrency model: - * All updates made after prepareForUpdates() are local to the thread. So concurrent attempts - * from multiple threads will create multiple sets of updates that need to be committed separately. - * - * Fault-tolerance model: - * - Every set of updates is written to a delta file before committing. - * - The state store is responsible for managing, collapsing and cleaning up of delta files. - * - Multiple attempts to commit the same version of updates must have the same updates. - * - Background management of files ensures that last versions of the store is always recoverable - * to ensure re-executed RDD operations re-apply updates on the correct past version of the - * store. + * Base trait for a versioned key-value store used for streaming aggregations */ - -private[sql] class StateStore( - val id: StateStoreId, - val directory: String, - numBatchesToRetain: Int = 2, - maxDeltaChainForSnapshots: Int = 10 - ) extends Logging { - type MapType = mutable.HashMap[InternalRow, InternalRow] +trait StateStore { import StateStore._ - private val loadedMaps = new mutable.HashMap[Long, MapType] - private val baseDir = new Path(directory, s"${id.operatorId}/${id.partitionId.toString}") - private val fs = baseDir.getFileSystem(new Configuration()) - private val serializer = new KryoSerializer(new SparkConf) + /** Unique identifier of the store */ + def id: StateStoreId - /** - * Thread local variable to keep track of updates so that if there multiple speculative tasks - * in the same executor trying to update the same store, the updates are thread-safe. - */ - private val uncommittedDelta = new ThreadLocal[UncommittedUpdates]() { - override def initialValue(): UncommittedUpdates = new UncommittedUpdates() - } - - initialize() - - /** - * Prepare for updates to create a new `version` of the map. The store ensure that updates - * are made on the `version - 1` of the store data. If `version` already exists, it will - * be overwritten when the updates are committed. 
- */ - private[state] def prepareForUpdates(newVersion: Long): Unit = synchronized { - require(newVersion >= 0, "Version cannot be less than 0") - val newMap = new MapType() - if (newVersion > 0) { - newMap ++= loadMap(newVersion - 1) - } - uncommittedDelta.get.prepare(newVersion, newMap) - } + /** Version of the data in this store before committing updates. */ + def version: Long /** * Update the value of a key using the value generated by the update function. * This can be called only after prepareForUpdates() has been called in the same thread. */ - def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { - uncommittedDelta.get.update(key, updateFunc) - } + def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit /** * Remove keys that match the following condition. * This can be called only after prepareForUpdates() has been called in the current thread. */ - def remove(condition: InternalRow => Boolean): Unit = { - uncommittedDelta.get.remove(condition) - } + def remove(condition: InternalRow => Boolean): Unit /** * Commit all the updates that have been made to the store. * This can be called only after prepareForUpdates() has been called in the current thread. */ - def commitUpdates(): Unit = { - uncommittedDelta.get.commit() - } + def commit(): Long /** Cancel all the updates that have been made to the store. */ - def cancelUpdates(): Unit = { - uncommittedDelta.get.reset() - } + def cancel(): Unit /** * Iterator of store data after a set of updates have been committed. * This can be called only after commitUpdates() has been called in the current thread. */ - def lastCommittedData(): Iterator[InternalRow] = { - uncommittedDelta.get.lastCommittedData() - } + def iterator(): Iterator[InternalRow] /** * Iterator of the updates that have been committed. * This can be called only after commitUpdates() has been called in the current thread. */ - def lastCommittedUpdates(): Iterator[StoreUpdate] = { - uncommittedDelta.get.lastCommittedUpdates() - } - - - /** - * Whether there are updates made in the current thread that have not been committed yet. - */ - private[state] def hasUncommittedUpdates: Boolean = { - uncommittedDelta.get.hasUncommittedUpdates - } - - override def toString(): String = { - s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" - } - - // Internal classes and methods - - private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - - /** Class representing all the data related to a set of updates to the store */ - private class UncommittedUpdates { - - trait State - case object INITIALIZED extends State - case object PREPARED extends State - case object COMMITTED extends State - - private var finalDeltaFile: Path = null - private var tempDeltaFile: Path = null - private var tempDeltaFileStream: SerializationStream = null - private var updatedMap: MapType = null - private var updateVersion: Long = -1 - private var state: State = INITIALIZED - - /** - * Prepare the set updates to be made to the state store by setting the version and the initial - * map on which to apply the updates. 
- */ - def prepare(version: Long, map: MapType): Unit = { - reset() - updateVersion = version - updatedMap = map - finalDeltaFile = deltaFile(updateVersion) - tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - tempDeltaFileStream = serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) - state = PREPARED - } - - /** Update the value of a key using the value generated by the update function */ - def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { - verify(state == PREPARED, "Cannot call update() before calling startUpdates()") - val value = updateFunc(updatedMap.get(key)) - updatedMap.put(key, value) - tempDeltaFileStream.writeObject(ValueUpdated(key, value)) - } - - /** Remove keys that match the following condition */ - def remove(condition: InternalRow => Boolean): Unit = { - verify(state == PREPARED, "Cannot call remove() before calling startUpdates()") - val keyIter = updatedMap.keysIterator - while (keyIter.hasNext) { - val key = keyIter.next - if (condition(key)) { - updatedMap.remove(key) - tempDeltaFileStream.writeObject(KeyRemoved(key)) - } - } - } - - /** Commit all the updates that have been made to the store. */ - def commit(): Unit = { - verify(state == PREPARED, "Cannot call commitUpdates() before calling prepareForUpdates()") - try { - tempDeltaFileStream.close() - - StateStore.this.synchronized { - fs.rename(tempDeltaFile, finalDeltaFile) - loadedMaps.put(updateVersion, updatedMap) - } - state = COMMITTED - } catch { - case NonFatal(e) => - state = INITIALIZED - throw new IllegalStateException( - s"Error committing version $updateVersion into $this", e) - } - } - - def reset(): Unit = { - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { - fs.delete(tempDeltaFile, true) - } - state = INITIALIZED - } - - def lastCommittedData(): Iterator[InternalRow] = { - verify(state == COMMITTED, "Cannot get iterator of data before calling commitUpdate()") - StateStore.this.iterator(updateVersion) - } - - def lastCommittedUpdates(): Iterator[StoreUpdate] = { - verify(state == COMMITTED, "Cannot get iterator of updates before calling commitUpdate()") - readDeltaFile(finalDeltaFile) - } - - def hasUncommittedUpdates: Boolean = { - state == PREPARED - } - } - - private def initialize(): Unit = { - if (!fs.exists(baseDir)) { - fs.mkdirs(baseDir) - } else { - if (!fs.isDirectory(baseDir)) { - throw new IllegalStateException( - s"Cannot use $directory for storing state data as" + - s"$baseDir already exists and is not a directory") - } - } - } + def updates(): Iterator[StoreUpdate] /** - * Get all the data of the latest version of the store. - * Note that this will look up the files to determined the latest known version. 
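The recovery model implemented by the file-management code nearby can be summarized as: the data for version v is the newest snapshot at or below v, plus a replay of every later delta file up to v. A schematic sketch of that rule (illustrative only; `readSnapshot` and `readDelta` stand in for the provider's file readers, with `None` denoting a key removal):

// Schematic only: reconstruct version `v` from the latest snapshot plus delta replay.
def reconstruct(
    v: Long,
    readSnapshot: Long => Option[Map[String, Int]],
    readDelta: Long => Seq[(String, Option[Int])]): Map[String, Int] = {
  val snapshotVersion = (v to 1L by -1L).find(ver => readSnapshot(ver).isDefined).getOrElse(0L)
  var state = readSnapshot(snapshotVersion).getOrElse(Map.empty[String, Int])
  for (ver <- snapshotVersion + 1 to v; (key, valueOption) <- readDelta(ver)) {
    state = valueOption match {
      case Some(value) => state + (key -> value)  // ValueUpdated
      case None => state - key                    // KeyRemoved
    }
  }
  state
}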
+ * Whether all updates have been committed */ - private[state] def latestIterator(): Iterator[InternalRow] = synchronized { - val versionsInFiles = fetchFiles().map(_.version).toSet - val versionsLoaded = loadedMaps.keySet - val allKnownVersions = versionsInFiles ++ versionsLoaded - if (allKnownVersions.nonEmpty) { - loadMap(allKnownVersions.max) - .iterator - .map { case (key, value) => new JoinedRow(key, value) } - } else Iterator.empty - } - - private[state] def iterator(version: Long): Iterator[InternalRow] = synchronized { - loadMap(version) - .iterator - .map { case (key, value) => new JoinedRow(key, value) } - } - - - private def loadMap(version: Long): MapType = { - if (version < 0) return new MapType - synchronized { loadedMaps.get(version) }.getOrElse { - val mapFromFile = readSnapshotFile(version).getOrElse { - val prevMap = loadMap(version - 1) - val deltaUpdates = readDeltaFile(version) - val newMap = new MapType() - newMap ++= prevMap - newMap.sizeHint(prevMap.size) - while (deltaUpdates.hasNext) { - deltaUpdates.next match { - case ValueUpdated(key, value) => newMap.put(key, value) - case KeyRemoved(key) => newMap.remove(key) - } - } - newMap - } - loadedMaps.put(version, mapFromFile) - mapFromFile - } - } - - private def readDeltaFile(version: Long): Iterator[StoreUpdate] = { - readDeltaFile(deltaFile(version)) - } - - private def readDeltaFile (fileToRead: Path): Iterator[StoreUpdate] = { - if (!fs.exists(fileToRead)) { - throw new IllegalStateException( - s"Cannot read delta file $fileToRead of $this: $fileToRead does not exist") - } - val deser = serializer.newInstance() - var deserStream: DeserializationStream = null - deserStream = deser.deserializeStream(fs.open(fileToRead)) - val iter = deserStream.asIterator.asInstanceOf[Iterator[StoreUpdate]] - CompletionIterator[StoreUpdate, Iterator[StoreUpdate]](iter, { deserStream.close() }) - } - - private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val ser = serializer.newInstance() - var outputStream: SerializationStream = null - Utils.tryWithSafeFinally { - outputStream = ser.serializeStream(fs.create(fileToWrite, false)) - outputStream.writeAll(map.iterator) - } { - if (outputStream != null) outputStream.close() - } - } - - private def readSnapshotFile(version: Long): Option[MapType] = { - val fileToRead = snapshotFile(version) - if (!fs.exists(fileToRead)) return None - - val deser = serializer.newInstance() - val map = new MapType() - var deserStream: DeserializationStream = null - - try { - deserStream = deser.deserializeStream(fs.open(fileToRead)) - val iter = deserStream.asIterator.asInstanceOf[Iterator[(InternalRow, InternalRow)]] - while(iter.hasNext) { - map += iter.next() - } - Some(map) - } finally { - if (deserStream != null) deserStream.close() - } - } - - private[state] def manageFiles(): Unit = { - doSnapshot() - cleanup() - } - - /** Perform a snapshot of the store to allow delta files to be consolidated */ - private def doSnapshot(): Unit = { - try { - val files = fetchFiles() - if (files.nonEmpty) { - val lastVersion = files.last.version - val deltaFilesForLastVersion = - filesForVersion(files, lastVersion).filter(_.isSnapshot == false) - synchronized { - loadedMaps.get(lastVersion) - } match { - case Some(map) => - if (deltaFilesForLastVersion.size > maxDeltaChainForSnapshots) { - writeSnapshotFile(lastVersion, map) - } - case None => - // The last map is not loaded, probably some other instance is incharge - } - - } - } catch { - case 
NonFatal(e) => - logWarning(s"Error doing snapshots for $this") - } - } - - /** - * Clean up old snapshots and delta files that are not needed any more. It ensures that last - * few versions of the store can be recovered from the files, so re-executed RDD operations - * can re-apply updates on the past versions of the store. - */ - private[state] def cleanup(): Unit = { - try { - val files = fetchFiles() - if (files.nonEmpty) { - val earliestVersionToRetain = files.last.version - numBatchesToRetain - if (earliestVersionToRetain >= 0) { - val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head - synchronized { - loadedMaps.keys.filter(_ < earliestVersionToRetain).foreach(loadedMaps.remove) - } - files.filter(_.version < earliestFileToRetain.version).foreach { f => - fs.delete(f.path, true) - } - } - } - } catch { - case NonFatal(e) => - logWarning(s"Error cleaning up files for $this") - } - } + def hasCommitted: Boolean +} - private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { - require(version >= 0) - require(allFiles.exists(_.version == version)) - val latestSnapshotFileBeforeVersion = allFiles - .filter(_.isSnapshot == true) - .takeWhile(_.version <= version) - .lastOption +trait StateStoreProvider { - val deltaBatchFiles = latestSnapshotFileBeforeVersion match { - case Some(snapshotFile) => - val deltaBatchIds = (snapshotFile.version + 1) to version + /** Get the store with the existing version. */ + def getStore(version: Long): StateStore - val deltaFiles = allFiles.filter { file => - file.version > snapshotFile.version && file.version <= version - } - verify( - deltaFiles.size == version - snapshotFile.version, - s"Unexpected list of delta files for version $version: ${deltaFiles.mkString(",")}" - ) - deltaFiles - - case None => - allFiles.takeWhile(_.version <= version) - } - latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles - } - - private def fetchFiles(): Seq[StoreFile] = { - val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) - } catch { - case _: java.io.FileNotFoundException => - Seq.empty - } - val versionToFiles = new mutable.HashMap[Long, StoreFile] - files.foreach { status => - val path = status.getPath - val nameParts = path.getName.split("\\.") - if (nameParts.size == 2) { - val version = nameParts(0).toLong - nameParts(1).toLowerCase match { - case "delta" => - // ignore the file otherwise, snapshot file already exists for that batch id - if (!versionToFiles.contains(version)) { - versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) - } - case "snapshot" => - versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) - case _ => - logWarning(s"Could not identify file $path") - } - } - } - versionToFiles.values.toSeq.sortBy(_.version) - } - - private def deltaFile(version: Long): Path = { - new Path(baseDir, s"$version.delta") - } - - private def snapshotFile(version: Long): Path = { - new Path(baseDir, s"$version.snapshot") - } - - private def verify(condition: => Boolean, msg: String): Unit = { - if (!condition) { - throw new IllegalStateException(msg) - } - } + /** Optional method for providers to allow for background management */ + def manage(): Unit = { } } @@ -500,24 +99,25 @@ private[sql] class StateStore( */ private[state] object StateStore extends Logging { - sealed trait StoreUpdate + val MANAGEMENT_TASK_INTERVAL_SECS = 60 + sealed trait StoreUpdate case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate - case class 
KeyRemoved(key: InternalRow) extends StoreUpdate - private val loadedStores = new mutable.HashMap[StateStoreId, StateStore]() + + private val loadedStores = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val managementTimer = new Timer("StateStore Timer", true) @volatile private var managementTask: TimerTask = null /** Get or create a store associated with the id. */ - def get(storeId: StateStoreId, directory: String): StateStore = { - val store = loadedStores.synchronized { + def get(storeId: StateStoreId, directory: String, version: Long): StateStore = { + val storeProvider = loadedStores.synchronized { startIfNeeded() - loadedStores.getOrElseUpdate(storeId, new StateStore(storeId, directory)) + loadedStores.getOrElseUpdate(storeId, new HDFSBackedStateStoreProvider(storeId, directory)) } reportActiveInstance(storeId) - store + storeProvider.getStore(version) } def clearAll(): Unit = loadedStores.synchronized { @@ -563,23 +163,21 @@ private[state] object StateStore extends Logging { if (managementTask == null) { managementTask = new TimerTask { override def run(): Unit = { - manageFiles() + loadedStores.synchronized { loadedStores.values.toSeq }.foreach { store => + try { + store.manage() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up $store") + } + } } } - managementTimer.schedule(managementTask, 10000, 10000) - } - } - - private def manageFiles(): Unit = { - loadedStores.synchronized { - loadedStores.values.toSeq - }.foreach { store => - try { - store.manageFiles() - } catch { - case NonFatal(e) => - logWarning(s"Error performing snapshot and cleaning up store ${store.id}") - } + managementTimer.schedule( + managementTask, + MANAGEMENT_TASK_INTERVAL_SECS * 1000, + MANAGEMENT_TASK_INTERVAL_SECS * 1000) } } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 76e79d962de6e..8d3c24e594ac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -49,15 +49,15 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( Utils.tryWithSafeFinally { store = StateStore.get( StateStoreId(operatorId, partition.index), - storeDirectory + storeDirectory, + newStoreVersion - 1 ) val inputIter = dataRDD.compute(partition, ctxt) - store.prepareForUpdates(newStoreVersion) val outputIter = storeUpdateFunction(store, inputIter) - assert(!store.hasUncommittedUpdates) + assert(store.hasCommitted) outputIter } { - if (store != null) store.cancelUpdates() + if (store != null) store.cancel() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 69f69cf3eeedc..780784db98682 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -48,8 +48,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn wrapValue(oldValue + 1) }) } - store.commitUpdates() - store.lastCommittedData().map(unwrapKeyValue) + store.commit() + store.iterator().map(unwrapKeyValue) } val opId = 0 val rdd1 = makeRDD(sc, Seq("a", "b", "a")) diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 9be2553b0ae69..ebb7b4d6eccf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -42,208 +42,182 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore.clearAll() } - test("startUpdates, update, remove, commitUpdates") { - val store = newStore() + test("update, remove, commit") { + val provider = newStoreProvider() // Verify state before starting a new set of updates - assert(store.latestIterator().isEmpty) - assert(!store.hasUncommittedUpdates) - intercept[IllegalStateException] { - store.lastCommittedData() - } - intercept[IllegalStateException] { - store.update(null, null) - } - intercept[IllegalStateException] { - store.remove(_ => true) - } - intercept[IllegalStateException] { - store.commitUpdates() - } + assert(provider.latestIterator().isEmpty) - // Verify states after preparing for updates - intercept[IllegalArgumentException] { - store.prepareForUpdates(-1) - } - store.prepareForUpdates(0) + val store = provider.getStore(0) + assert(!store.hasCommitted) intercept[IllegalStateException] { - store.lastCommittedData() + store.iterator() } intercept[IllegalStateException] { - store.prepareForUpdates(1) + store.updates() } - assert(store.hasUncommittedUpdates) // Verify state after updating update(store, "a", 1) intercept[IllegalStateException] { - store.lastCommittedData() + store.iterator() } - assert(store.latestIterator().isEmpty) + intercept[IllegalStateException] { + store.updates() + } + assert(provider.latestIterator().isEmpty) - // Make updates and commit + // Make updates, commit and then verify state update(store, "b", 2) update(store, "aa", 3) remove(store, _.startsWith("a")) - store.commitUpdates() + assert(store.commit() === 1) - // Verify state after committing - assert(!store.hasUncommittedUpdates) - assert(getData(store) === Set("b" -> 2)) - assert(fileExists(store, 0, isSnapshot = false)) + assert(store.hasCommitted) + assert(store.iterator() === Set("b" -> 2)) + assert(store.updates() === Set("b" -> 2)) + assert(provider.latestIterator() === Set("b" -> 2)) + assert(fileExists(provider, version = 1, isSnapshot = false)) + assert(getDataFromFiles(provider) === Set("b" -> 2)) // Trying to get newer versions should fail intercept[Exception] { - getData(store, 1) + provider.getStore(2) } - intercept[Exception] { - getDataFromFiles(store, 1) + getDataFromFiles(provider, 2) } - // Reload store from the directory - val reloadedStore = new StateStore(store.id, store.directory) - assert(getData(reloadedStore) === Set("b" -> 2)) - - // New updates to the reload store with new version, and does not change old version - reloadedStore.prepareForUpdates(1) + // New updates to the reloaded store with new version, and does not change old version + val reloadedStore = new HDFSBackedStateStoreProvider(store.id, provider.directory).getStore(1) update(reloadedStore, "c", 4) - reloadedStore.commitUpdates() - assert(getData(reloadedStore) === Set("b" -> 2, "c" -> 4)) - assert(getData(reloadedStore, version = 0) === Set("b" -> 2)) - assert(getData(reloadedStore, version = 1) === Set("b" -> 2, "c" -> 4)) - assert(fileExists(reloadedStore, 1, isSnapshot = false)) + assert(reloadedStore.commit() === 2) 
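The fileExists assertions in these tests rely on the provider's naming scheme: one `<version>.delta` file per commit under `<directory>/<operatorId>/<partitionId>/`, plus an occasional `<version>.snapshot` written by the background management task. A small sketch of the expected paths (illustrative only; `checkpointDir`, `opId` and `partId` are placeholders):

import org.apache.hadoop.fs.Path

// Illustrative: where the data committed as `version` is expected to land on disk.
def expectedFiles(checkpointDir: String, opId: Long, partId: Int, version: Long): Seq[Path] = {
  val baseDir = new Path(checkpointDir, s"$opId/$partId")
  Seq(
    new Path(baseDir, s"$version.delta"),    // written by every successful commit
    new Path(baseDir, s"$version.snapshot")  // written only when manage() chooses to snapshot
  )
}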
+ assert(reloadedStore.iterator() === Set("b" -> 2, "c" -> 4)) + assert(reloadedStore.updates() === Set("c" -> 4)) + assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) + assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) } - test("cancelUpdates") { - val store = newStore() - store.prepareForUpdates(0) + test("cancel") { + val provider = newStoreProvider() + val store = provider.getStore(0) update(store, "a", 1) - store.commitUpdates() - assert(getData(store) === Set("a" -> 1)) - - // cancelUpdates should not change the data - store.prepareForUpdates(1) - update(store, "b", 1) - store.cancelUpdates() - assert(getData(store) === Set("a" -> 1)) - - // Calling prepareForUpdates again should cancel previous updates - store.prepareForUpdates(1) - update(store, "b", 1) - - store.prepareForUpdates(1) - update(store, "c", 1) - store.commitUpdates() - assert(getData(store) === Set("a" -> 1, "c" -> 1)) + store.commit() + assert(store.iterator() === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + update(store1, "b", 1) + store1.cancel() + assert(getDataFromFiles(provider) === Set("a" -> 1)) } - test("startUpdates with unexpected versions") { - val store = newStore() + test("getStore with unexpected versions") { + val provider = newStoreProvider() intercept[IllegalArgumentException] { - store.prepareForUpdates(-1) + provider.getStore(-1) } // Prepare some data in the stoer - store.prepareForUpdates(0) + val store = provider.getStore(0) update(store, "a", 1) - store.commitUpdates() - assert(getData(store) === Set("a" -> 1)) + assert(store.commit() === 1) + assert(store.iterator() === Set("a" -> 1)) intercept[IllegalStateException] { - store.prepareForUpdates(2) + provider.getStore(2) } // Update store version with some data - store.prepareForUpdates(1) + provider.getStore(1) update(store, "b", 1) - store.commitUpdates() - assert(getData(store) === Set("a" -> 1, "b" -> 1)) - - assert(getDataFromFiles(store) === Set("a" -> 1, "b" -> 1)) + assert(store.commit() === 2) + assert(store.iterator() === Set("a" -> 1, "b" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data - store.prepareForUpdates(1) + provider.getStore(1) update(store, "c", 1) - store.commitUpdates() - assert(getData(store) === Set("a" -> 1, "c" -> 1)) - assert(getDataFromFiles(store) === Set("a" -> 1, "c" -> 1)) + assert(store.commit() === 2) + assert(store.iterator() === Set("a" -> 1, "c" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) } test("snapshotting") { - val store = newStore(maxDeltaChainForSnapshots = 5) + val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) var currentVersion = -1 def updateVersionTo(targetVersion: Int): Unit = { for (i <- currentVersion + 1 to targetVersion) { - store.prepareForUpdates(i) + val store = provider.getStore(i - 1) update(store, "a", i) - store.commitUpdates() + store.commit() + } + currentVersion = targetVersion } updateVersionTo(2) - require(getData(store) === Set("a" -> 2)) - store.manageFiles() - assert(getDataFromFiles(store) === Set("a" -> 2)) + require(getDataFromFiles(provider) === Set("a" -> 2)) + provider.manage() // should not generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 2)) for (i <- 0 to 2) { - assert(fileExists(store, i, isSnapshot = false)) // all delta files present - assert(!fileExists(store, i, 
isSnapshot = true)) // no snapshot files present + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present } // After version 6, snapshotting should generate one snapshot file updateVersionTo(6) - require(getData(store) === Set("a" -> 6), "Store not updated correctly") - store.manageFiles() // do snapshot - assert(getData(store) === Set("a" -> 6), "manageFiles() messed up the data") - assert(getDataFromFiles(store) === Set("a" -> 6)) + require(getDataFromFiles(provider) === Set("a" -> 6), "Store not updated correctly") + provider.manage() // should generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 6), "snapshotting messed up the data") + assert(getDataFromFiles(provider) === Set("a" -> 6)) - val snapshotVersion = (0 to 6).find(version => fileExists(store, version, isSnapshot = true)) + val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) assert(snapshotVersion.nonEmpty, "Snapshot file not generated") // After version 20, snapshotting should generate newer snapshot files updateVersionTo(20) - require(getData(store) === Set("a" -> 20), "Store not updated correctly") - store.manageFiles() // do snapshot - assert(getData(store) === Set("a" -> 20), "manageFiles() messed up the data") - assert(getDataFromFiles(store) === Set("a" -> 20)) + require(getDataFromFiles(provider) === Set("a" -> 20), "Store not updated correctly") + provider.manage() // do snapshot + assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + assert(getDataFromFiles(provider) === Set("a" -> 20)) val latestSnapshotVersion = (0 to 20).filter(version => - fileExists(store, version, isSnapshot = true)).lastOption + fileExists(provider, version, isSnapshot = true)).lastOption assert(latestSnapshotVersion.nonEmpty, "No snapshot file found") assert(latestSnapshotVersion.get > snapshotVersion.get, "Newer snapshot not generated") } test("cleaning") { - val store = newStore(maxDeltaChainForSnapshots = 5) + val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) for (i <- 0 to 20) { - store.prepareForUpdates(i) + val store = provider.getStore(i) update(store, "a", i) - store.commitUpdates() + store.commit() } - require(getData(store) === Set("a" -> 20), "Store not updated correctly") - store.manageFiles() // do cleanup - assert(fileExists(store, 0, isSnapshot = false)) + require(provider.latestIterator() === Set("a" -> 20), "Store not updated correctly") + provider.manage() // do cleanup + assert(fileExists(provider, 0, isSnapshot = false)) - assert(getDataFromFiles(store, 20) === Set("a" -> 20)) - assert(getDataFromFiles(store, 19) === Set("a" -> 19)) + assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) + assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) } - def getData(store: StateStore, version: Int = -1): Set[(String, Int)] = { + def getDataFromFiles( + provider: HDFSBackedStateStoreProvider, + version: Int = -1): Set[(String, Int)] = { + val reloadedProvider = new HDFSBackedStateStoreProvider(provider.id, provider.directory) if (version < 0) { - store.latestIterator.map(unwrapKeyValue).toSet + reloadedProvider.latestIterator.map(unwrapKeyValue).toSet } else { - store.iterator(version).map(unwrapKeyValue).toSet + reloadedProvider.iterator(version).map(unwrapKeyValue).toSet } - - } - - def getDataFromFiles(store: StateStore, version: Int = -1): Set[(String, Int)] = { - getData(new StateStore(store.id, 
store.directory), version) } def assertMap( @@ -254,9 +228,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(convertedMap === expectedMap) } - def fileExists(store: StateStore, version: Long, isSnapshot: Boolean): Boolean = { + def fileExists( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Boolean = { val method = PrivateMethod[Path]('baseDir) - val basePath = store invokePrivate method() + val basePath = provider invokePrivate method() val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" val filePath = new File(basePath.toString, fileName) filePath.exists @@ -273,12 +250,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore invokePrivate method(storeId) } - def newStore( + def newStoreProvider( opId: Long = Random.nextLong, partition: Int = 0, maxDeltaChainForSnapshots: Int = 10 - ): StateStore = { - new StateStore( + ): HDFSBackedStateStoreProvider = { + new HDFSBackedStateStoreProvider( StateStoreId(opId, partition), Utils.createDirectory(tempDir, Random.nextString(5)).toString, maxDeltaChainForSnapshots = maxDeltaChainForSnapshots) From 22d7e6639125a507467766c45393e36b3ee92f3f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Mar 2016 18:49:05 -0700 Subject: [PATCH 10/46] Fixed a lot of things --- .../state/HDFSBackedStateStoreProvider.scala | 19 ++--- .../streaming/state/StateStoreRDD.scala | 11 ++- .../execution/streaming/state/package.scala | 6 +- .../streaming/state/StateStoreRDDSuite.scala | 50 +++++++++++-- .../streaming/state/StateStoreSuite.scala | 71 ++++++++++--------- 5 files changed, 100 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index a70cf5d4271c5..acfe4affa0935 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -143,7 +143,7 @@ class HDFSBackedStateStoreProvider( */ override def iterator(): Iterator[InternalRow] = { verify(state == COMMITTED, "Cannot get iterator of store data before comitting") - HDFSBackedStateStoreProvider.this.iterator(version) + HDFSBackedStateStoreProvider.this.iterator(newVersion) } /** @@ -245,10 +245,8 @@ class HDFSBackedStateStoreProvider( /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { - if (version < 0) return new MapType - synchronized { - loadedMaps.get(version) - }.getOrElse { + if (version <= 0) return new MapType + synchronized { loadedMaps.get(version) }.getOrElse { val mapFromFile = readSnapshotFile(version).getOrElse { val prevMap = loadMap(version - 1) val deltaUpdates = readDeltaFile(version) @@ -328,9 +326,7 @@ class HDFSBackedStateStoreProvider( val lastVersion = files.last.version val deltaFilesForLastVersion = filesForVersion(files, lastVersion).filter(_.isSnapshot == false) - synchronized { - loadedMaps.get(lastVersion) - } match { + synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => if (deltaFilesForLastVersion.size > maxDeltaChainForSnapshots) { writeSnapshotFile(lastVersion, map) @@ -342,7 +338,7 @@ class HDFSBackedStateStoreProvider( } } catch { case NonFatal(e) => - logWarning(s"Error doing 
snapshots for $this") + logWarning(s"Error doing snapshots for $this", e) } } @@ -356,7 +352,7 @@ class HDFSBackedStateStoreProvider( val files = fetchFiles() if (files.nonEmpty) { val earliestVersionToRetain = files.last.version - numBatchesToRetain - if (earliestVersionToRetain >= 0) { + if (earliestVersionToRetain > 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head synchronized { loadedMaps.keys.filter(_ < earliestVersionToRetain).foreach(loadedMaps.remove) @@ -368,7 +364,7 @@ class HDFSBackedStateStoreProvider( } } catch { case NonFatal(e) => - logWarning(s"Error cleaning up files for $this") + logWarning(s"Error cleaning up files for $this", e) } } @@ -380,7 +376,6 @@ class HDFSBackedStateStoreProvider( .filter(_.isSnapshot == true) .takeWhile(_.version <= version) .lastOption - val deltaBatchFiles = latestSnapshotFileBeforeVersion match { case Some(snapshotFile) => val deltaBatchIds = (snapshotFile.version + 1) to version diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 8d3c24e594ac4..80e497455c993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -30,10 +30,12 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( dataRDD: RDD[INPUT], storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], operatorId: Long, - newStoreVersion: Long, + storeVersion: Long, storeDirectory: String, storeCoordinator: StateStoreCoordinator) extends RDD[OUTPUT](dataRDD) { + val nextVersion = storeVersion + 1 + override protected def getPartitions: Array[Partition] = dataRDD.partitions override def getPreferredLocations(partition: Partition): Seq[String] = { Seq.empty @@ -47,11 +49,8 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( var store: StateStore = null Utils.tryWithSafeFinally { - store = StateStore.get( - StateStoreId(operatorId, partition.index), - storeDirectory, - newStoreVersion - 1 - ) + val storeId = StateStoreId(operatorId, partition.index) + store = StateStore.get(storeId, storeDirectory, storeVersion) val inputIter = dataRDD.compute(partition, ctxt) val outputIter = storeUpdateFunction(store, inputIter) assert(store.hasCommitted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index bcfc7dc309446..e5ed5a41b949d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -27,12 +27,12 @@ package object state { def withStateStores[OUTPUT: ClassTag]( storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], operatorId: Long, - newStoreVersion: Long, + storeVersion: Long, storeDirectory: String, storeCoordinator: StateStoreCoordinator - ): RDD[OUTPUT] = { + ): StateStoreRDD[INPUT, OUTPUT] = { new StateStoreRDD( - dataRDD, storeUpdateFunction, operatorId, newStoreVersion, storeDirectory, storeCoordinator) + dataRDD, storeUpdateFunction, operatorId, storeVersion, storeDirectory, storeCoordinator) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 780784db98682..969323ffc0c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.File +import java.nio.file.Files + import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -29,7 +32,8 @@ import org.apache.spark.util.Utils class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) - private var tempDir = Utils.createTempDir().toString + private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + println(tempDir) import StateStoreSuite._ @@ -37,7 +41,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn StateStore.clearAll() } - test("versioning and immuability") { + override def afterAll(): Unit = { + super.afterAll() + Utils.deleteRecursively(new File(tempDir)) + } + + test("versioning and immutability") { withSpark(new SparkContext(conf)) { sc => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val increment = (store: StateStore, iter: Iterator[String]) => { @@ -53,12 +62,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val opId = 0 val rdd1 = makeRDD(sc, Seq("a", "b", "a")) - .withStateStores(increment, opId, newStoreVersion = 0, path, null) + .withStateStores(increment, opId, storeVersion = 0, path, null) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(sc, Seq("a", "c")) - .withStateStores(increment, opId, newStoreVersion = 1, path, null) + .withStateStores(increment, opId, storeVersion = 1, path, null) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
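The pattern in the test above generalizes to a driver-side loop that advances the store version by one per micro-batch: the RDD built with storeVersion = i reads version i and commits version i + 1. A hedged sketch of such a loop (illustrative only; it assumes the implicit `withStateStores` syntax from the state package object and an `increment` function like the one used in these tests):

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.streaming.state._

// Illustrative driver-side loop: batch i reads store version i and commits version i + 1.
def runBatches(
    sc: SparkContext,
    path: String,
    increment: (StateStore, Iterator[String]) => Iterator[(String, Int)],
    batches: Seq[Seq[String]]): Unit = {
  batches.zipWithIndex.foreach { case (batch, i) =>
    val results = sc.makeRDD(batch, 2)
      .groupBy(x => x)
      .flatMap(_._2)
      .withStateStores(increment, 0, i, path, null)  // operatorId = 0, storeVersion = i
      .collect()                                     // forces every partition to commit i + 1
    // `results` now holds the state committed as version i + 1.
  }
}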
@@ -66,7 +75,40 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } } + test("recovering from files") { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD(sc: SparkContext, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { + makeRDD(sc, Seq("a")).withStateStores(increment, opId, storeVersion, path, null) + } + + // Generate RDDs and state store data + withSpark(new SparkContext(conf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(conf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } + } + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + + private val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + wrapKey(s), oldRow => { + val oldValue = oldRow.map(unwrapValue).getOrElse(0) + wrapValue(oldValue + 1) + }) + } + store.commit() + store.iterator().map(unwrapKeyValue) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index ebb7b4d6eccf4..0793df28747e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -74,9 +74,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(store.commit() === 1) assert(store.hasCommitted) - assert(store.iterator() === Set("b" -> 2)) - assert(store.updates() === Set("b" -> 2)) - assert(provider.latestIterator() === Set("b" -> 2)) + assert(unwrapToSet(store.iterator()) === Set("b" -> 2)) + assert(unwrapToSet(provider.latestIterator()) === Set("b" -> 2)) assert(fileExists(provider, version = 1, isSnapshot = false)) assert(getDataFromFiles(provider) === Set("b" -> 2)) @@ -92,8 +91,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedStore = new HDFSBackedStateStoreProvider(store.id, provider.directory).getStore(1) update(reloadedStore, "c", 4) assert(reloadedStore.commit() === 2) - assert(reloadedStore.iterator() === Set("b" -> 2, "c" -> 4)) - assert(reloadedStore.updates() === Set("c" -> 4)) + assert(unwrapToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) @@ -104,7 +102,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store = provider.getStore(0) update(store, "a", 1) store.commit() - assert(store.iterator() === Set("a" -> 1)) + assert(unwrapToSet(store.iterator()) === Set("a" -> 1)) // cancelUpdates should not change the data in the files val store1 = provider.getStore(1) @@ -124,87 +122,92 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store = provider.getStore(0) update(store, "a", 1) assert(store.commit() === 1) - assert(store.iterator() === Set("a" -> 1)) + assert(unwrapToSet(store.iterator()) === Set("a" -> 1)) intercept[IllegalStateException] { 
provider.getStore(2) } // Update store version with some data - provider.getStore(1) - update(store, "b", 1) - assert(store.commit() === 2) - assert(store.iterator() === Set("a" -> 1, "b" -> 1)) + val store1 = provider.getStore(1) + update(store1, "b", 1) + assert(store1.commit() === 2) + assert(unwrapToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data - provider.getStore(1) - update(store, "c", 1) - assert(store.commit() === 2) - assert(store.iterator() === Set("a" -> 1, "c" -> 1)) + val store2 = provider.getStore(1) + update(store2, "c", 1) + assert(store2.commit() === 2) + assert(unwrapToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) } test("snapshotting") { val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) - var currentVersion = -1 + var currentVersion = 0 def updateVersionTo(targetVersion: Int): Unit = { for (i <- currentVersion + 1 to targetVersion) { - val store = provider.getStore(i - 1) + val store = provider.getStore(currentVersion) update(store, "a", i) store.commit() - + currentVersion += 1 } - currentVersion = targetVersion + require(currentVersion === targetVersion) } + updateVersionTo(2) require(getDataFromFiles(provider) === Set("a" -> 2)) provider.manage() // should not generate snapshot files assert(getDataFromFiles(provider) === Set("a" -> 2)) - for (i <- 0 to 2) { + + for (i <- 1 to currentVersion) { assert(fileExists(provider, i, isSnapshot = false)) // all delta files present assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present } // After version 6, snapshotting should generate one snapshot file updateVersionTo(6) - require(getDataFromFiles(provider) === Set("a" -> 6), "Store not updated correctly") + require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") provider.manage() // should generate snapshot files assert(getDataFromFiles(provider) === Set("a" -> 6), "snapshotting messed up the data") assert(getDataFromFiles(provider) === Set("a" -> 6)) val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) - assert(snapshotVersion.nonEmpty, "Snapshot file not generated") - + assert(snapshotVersion.nonEmpty, "snapshot file not generated") // After version 20, snapshotting should generate newer snapshot files updateVersionTo(20) - require(getDataFromFiles(provider) === Set("a" -> 20), "Store not updated correctly") + require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") provider.manage() // do snapshot assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") assert(getDataFromFiles(provider) === Set("a" -> 20)) val latestSnapshotVersion = (0 to 20).filter(version => fileExists(provider, version, isSnapshot = true)).lastOption - assert(latestSnapshotVersion.nonEmpty, "No snapshot file found") - assert(latestSnapshotVersion.get > snapshotVersion.get, "Newer snapshot not generated") + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") } test("cleaning") { val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) - for (i <- 0 to 20) { - val store = provider.getStore(i) + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) update(store, "a", i) store.commit() + provider.manage() // do cleanup } - 
require(provider.latestIterator() === Set("a" -> 20), "Store not updated correctly") - provider.manage() // do cleanup - assert(fileExists(provider, 0, isSnapshot = false)) + require( + unwrapToSet(provider.latestIterator()) === Set("a" -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + // last couple of versions should be retrievable assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) } @@ -214,7 +217,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth version: Int = -1): Set[(String, Int)] = { val reloadedProvider = new HDFSBackedStateStoreProvider(provider.id, provider.directory) if (version < 0) { - reloadedProvider.latestIterator.map(unwrapKeyValue).toSet + reloadedProvider.latestIterator().map(unwrapKeyValue).toSet } else { reloadedProvider.iterator(version).map(unwrapKeyValue).toSet } @@ -295,4 +298,8 @@ private[state] object StateStoreSuite { def unwrapKeyValue(row: InternalRow): (String, Int) = { (row.getString(0), row.getInt(1)) } + + def unwrapToSet(iterator: Iterator[InternalRow]): Set[(String, Int)] = { + iterator.map(unwrapKeyValue).toSet + } } From 34ae7ffcb3544743e90a6441c78fd9002e839794 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Mar 2016 18:54:39 -0700 Subject: [PATCH 11/46] Fixed style --- .../state/HDFSBackedStateStoreProvider.scala | 12 ++++-------- .../sql/execution/streaming/state/StateStore.scala | 2 +- .../streaming/state/StateStoreRDDSuite.scala | 1 - 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index acfe4affa0935..1ea813390f988 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,22 +17,18 @@ package org.apache.spark.sql.execution.streaming.state -import java.util.{TimerTask, Timer} - import scala.collection.mutable -import scala.collection.mutable.HashMap import scala.util.Random import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.spark.sql.catalyst.expressions.JoinedRow -import org.apache.spark.sql.execution.streaming.state.StateStore._ -import org.apache.spark.util.{Utils, CompletionIterator} -import org.apache.spark.{SparkConf, Logging} -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, KryoSerializer} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.util.{CompletionIterator, Utils} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 4a4ed427c53c5..473227c91b740 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ 
-23,9 +23,9 @@ import scala.collection.mutable import scala.reflect.ClassTag import scala.util.control.NonFatal +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.util.RpcUtils -import org.apache.spark.{Logging, SparkEnv} /** Unique identifier for a [[StateStore]] */ case class StateStoreId(operatorId: Long, partitionId: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 969323ffc0c72..9c42116811c7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -33,7 +33,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString - println(tempDir) import StateStoreSuite._ From 13c29a2e8a3a6b2bed0bc14c62ed1be4033b3e55 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 00:38:43 -0700 Subject: [PATCH 12/46] Fixed updates iterator --- .../state/HDFSBackedStateStoreProvider.scala | 30 ++++++-- .../streaming/state/StateStore.scala | 10 +-- .../streaming/state/StateStoreSuite.scala | 69 ++++++++++++++++++- 3 files changed, 96 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 1ea813390f988..e0473ec56ab94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -79,6 +79,7 @@ class HDFSBackedStateStoreProvider( private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private val tempDeltaFileStream = serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) + private val allUpdates = new mutable.HashMap[InternalRow, StoreUpdate] @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null @@ -88,8 +89,19 @@ class HDFSBackedStateStoreProvider( /** Update the value of a key using the value generated by the update function */ override def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { verify(state == UPDATING, "Cannot update after already committed or cancelled") - val value = updateFunc(mapToUpdate.get(key)) + val oldValueOption = mapToUpdate.get(key) + val value = updateFunc(oldValueOption) mapToUpdate.put(key, value) + allUpdates.get(key) match { + case Some(ValueAdded(_, _)) => + allUpdates.put(key, ValueAdded(key, value)) + case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + allUpdates.put(key, ValueUpdated(key, value)) + case None => + val update = + if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + allUpdates.put(key, update) + } tempDeltaFileStream.writeObject(ValueUpdated(key, value)) } @@ -101,6 +113,14 @@ class HDFSBackedStateStoreProvider( val key = keyIter.next if (condition(key)) { mapToUpdate.remove(key) + allUpdates.get(key) match { + case Some(ValueUpdated(_, _)) | 
None => + allUpdates.put(key, KeyRemoved(key)) + case Some(ValueAdded(_, _)) => + allUpdates.remove(key) + case Some(KeyRemoved(_)) => + // Remove already in update map, no need to change + } tempDeltaFileStream.writeObject(KeyRemoved(key)) } } @@ -148,7 +168,7 @@ class HDFSBackedStateStoreProvider( */ override def updates(): Iterator[StoreUpdate] = { verify(state == COMMITTED, "Cannot get iterator of updates before committing") - readDeltaFile(finalDeltaFile) + allUpdates.valuesIterator } /** @@ -251,6 +271,7 @@ class HDFSBackedStateStoreProvider( newMap.sizeHint(prevMap.size) while (deltaUpdates.hasNext) { deltaUpdates.next match { + case ValueAdded(key, value) => newMap.put(key, value) case ValueUpdated(key, value) => newMap.put(key, value) case KeyRemoved(key) => newMap.remove(key) } @@ -263,10 +284,7 @@ class HDFSBackedStateStoreProvider( } private def readDeltaFile(version: Long): Iterator[StoreUpdate] = { - readDeltaFile(deltaFile(version)) - } - - private def readDeltaFile(fileToRead: Path): Iterator[StoreUpdate] = { + val fileToRead = deltaFile(version) if (!fs.exists(fileToRead)) { throw new IllegalStateException( s"Cannot read delta file $fileToRead of $this: $fileToRead does not exist") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 473227c91b740..2fbd40f3c79db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -92,6 +92,11 @@ trait StateStoreProvider { def manage(): Unit = { } } +sealed trait StoreUpdate +case class ValueAdded(key: InternalRow, value: InternalRow) extends StoreUpdate +case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate +case class KeyRemoved(key: InternalRow) extends StoreUpdate + /** * Companion object to [[StateStore]] that provides helper methods to create and retrive stores @@ -101,11 +106,6 @@ private[state] object StateStore extends Logging { val MANAGEMENT_TASK_INTERVAL_SECS = 60 - sealed trait StoreUpdate - case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate - case class KeyRemoved(key: InternalRow) extends StoreUpdate - - private val loadedStores = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val managementTimer = new Timer("StateStore Timer", true) @volatile private var managementTask: TimerTask = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 0793df28747e3..9e35dd2022a4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -42,7 +42,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore.clearAll() } - test("update, remove, commit") { + test("update, remove, commit, and all data iterator") { val provider = newStoreProvider() // Verify state before starting a new set of updates @@ -97,6 +97,57 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) } + test("updates iterator with all combos of updates and removes") { + val 
provider = newStoreProvider() + var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { + val store = provider.getStore(currentVersion) + body(store) + currentVersion += 1 + } + + // New data should be seen in updates as value added, even if they had multiple updates + withStore { store => + update(store, "a", 1) + update(store, "aa", 1) + update(store, "aa", 2) + store.commit() + assert(unwrapUpdates(store.updates()) === Set(Added("a", 1), Added("aa", 2))) + assert(unwrapToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) + } + + // Multiple updates to same key should be collapsed in the updates as a single value update + // Keys that have not been updated should not appear in the updates + withStore { store => + update(store, "a", 4) + update(store, "a", 6) + store.commit() + assert(unwrapUpdates(store.updates()) === Set(Updated("a", 6))) + assert(unwrapToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Keys added, updated and finally removed before commit should not appear in updates + withStore { store => + update(store, "b", 4) // Added, finally removed + update(store, "bb", 5) // Added, updated, finally removed + update(store, "bb", 6) + remove(store, _.startsWith("b")) + store.commit() + assert(unwrapUpdates(store.updates()) === Set.empty) + assert(unwrapToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Removed data should be seen in updates as a key removed + // Removed, but re-added data should be seen in updates as a value update + withStore { store => + remove(store, _.startsWith("a")) + update(store, "a", 10) + store.commit() + assert(unwrapUpdates(store.updates()) === Set(Updated("a", 10), Removed("aa"))) + assert(unwrapToSet(store.iterator()) === Set("a" -> 10)) + } + } + test("cancel") { val provider = newStoreProvider() val store = provider.getStore(0) @@ -157,7 +208,6 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth require(currentVersion === targetVersion) } - updateVersionTo(2) require(getDataFromFiles(provider) === Set("a" -> 2)) provider.manage() // should not generate snapshot files @@ -275,6 +325,13 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth private[state] object StateStoreSuite { + /** Trait and class mirroring [[StoreUpdate]] for testing */ + trait TestUpdate + case class Added(key: String, value: Int) extends TestUpdate + case class Updated(key: String, value: Int) extends TestUpdate + case class Removed(key: String) extends TestUpdate + + def wrapValue(i: Int): InternalRow = { new GenericInternalRow(Array[Any](i)) } @@ -302,4 +359,12 @@ private[state] object StateStoreSuite { def unwrapToSet(iterator: Iterator[InternalRow]): Set[(String, Int)] = { iterator.map(unwrapKeyValue).toSet } + + def unwrapUpdates(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { + iterator.map { _ match { + case ValueAdded(key, value) => Added(unwrapKey(key), unwrapValue(value)) + case ValueUpdated(key, value) => Updated(unwrapKey(key), unwrapValue(value)) + case KeyRemoved(key) => Removed(unwrapKey(key)) + }}.toSet + } } From d5e2b10ad11b22ea020dd3f3bb13a1a64aa33b50 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 14:36:31 -0700 Subject: [PATCH 13/46] Added unit tests for Coordinator --- .../streaming/state/StateStore.scala | 88 +++++++------ .../state/StateStoreCoordinator.scala | 45 +++++-- .../state/StateStoreCoordinatorSuite.scala | 116 ++++++++++++++++++ .../streaming/state/StateStoreRDDSuite.scala | 2 +- 
.../streaming/state/StateStoreSuite.scala | 5 +- 5 files changed, 199 insertions(+), 57 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 2fbd40f3c79db..cb23daa9446fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -106,78 +106,86 @@ private[state] object StateStore extends Logging { val MANAGEMENT_TASK_INTERVAL_SECS = 60 - private val loadedStores = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val managementTimer = new Timer("StateStore Timer", true) @volatile private var managementTask: TimerTask = null /** Get or create a store associated with the id. */ def get(storeId: StateStoreId, directory: String, version: Long): StateStore = { - val storeProvider = loadedStores.synchronized { + val storeProvider = loadedProviders.synchronized { startIfNeeded() - loadedStores.getOrElseUpdate(storeId, new HDFSBackedStateStoreProvider(storeId, directory)) + val provider = loadedProviders.getOrElseUpdate( + storeId, new HDFSBackedStateStoreProvider(storeId, directory)) + reportActiveInstance(storeId) + provider } - reportActiveInstance(storeId) storeProvider.getStore(version) } - def clearAll(): Unit = loadedStores.synchronized { - loadedStores.clear() + /** Unload and stop all state store provider */ + def stop(): Unit = loadedProviders.synchronized { + loadedProviders.clear() if (managementTask != null) { managementTask.cancel() managementTask = null + logInfo("StateStore stopped") } } - private def remove(storeId: StateStoreId): Unit = { - loadedStores.remove(storeId) + private def startIfNeeded(): Unit = loadedProviders.synchronized { + if (managementTask == null) { + managementTask = new TimerTask { + override def run(): Unit = { manageAll() } + } + val periodMs = MANAGEMENT_TASK_INTERVAL_SECS * 1000 + managementTimer.schedule(managementTask, periodMs, periodMs) + logInfo("StateStore started") + } + } + + /** + * Execute background management task in all the loaded store providers if they are still + * the active instances according to the coordinator. 
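For concreteness, the "active instance" bookkeeping that this management path depends on amounts to: the last executor to report a store id owns it, and any previous owner should unload the provider on its next management tick. A minimal sketch of that hand-off, assuming direct access to the driver-side coordinator (executors reach it through the RPC `ask` helper shown below):

// Illustrative only: ownership of a store id moves to whichever executor reported last.
def handOffExample(coordinator: StateStoreCoordinator, id: StateStoreId): Unit = {
  coordinator.reportActiveInstance(id, "host1", "exec1")
  assert(coordinator.verifyIfInstanceActive(id, "exec1"))   // exec1 currently owns the store

  coordinator.reportActiveInstance(id, "host2", "exec2")    // e.g. a retried task on exec2
  assert(!coordinator.verifyIfInstanceActive(id, "exec1"))  // exec1 unloads on its next tick
  assert(coordinator.verifyIfInstanceActive(id, "exec2"))
}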
+ */ + private def manageAll(): Unit = { + loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => + try { + if (verifyIfInstanceActive(id)) { + provider.manage() + } else { + remove(id) + logInfo(s"Unloaded $provider") + } + } catch { + case NonFatal(e) => + logWarning(s"Error managing $provider") + } + } + } + + private def remove(storeId: StateStoreId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeId) } private def reportActiveInstance(storeId: StateStoreId): Unit = { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - askCoordinator[Boolean](ReportActiveInstance(storeId, host, executorId)) + askCoordinator(ReportActiveInstance(storeId, host, executorId)) } private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - askCoordinator[Boolean](VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) + askCoordinator(VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) } - private def askCoordinator[T: ClassTag](message: StateStoreCoordinatorMessage): Option[T] = { + private def askCoordinator(message: StateStoreCoordinatorMessage): Option[Boolean] = { try { - val env = SparkEnv.get - if (env != null) { - val coordinatorRef = RpcUtils.makeDriverRef("StateStoreCoordinator", env.conf, env.rpcEnv) - Some(coordinatorRef.askWithRetry[T](message)) - } else { - None - } + StateStoreCoordinator.ask(message) } catch { case NonFatal(e) => - clearAll() + logWarning("Error communicating") None } } - - private def startIfNeeded(): Unit = loadedStores.synchronized { - if (managementTask == null) { - managementTask = new TimerTask { - override def run(): Unit = { - loadedStores.synchronized { loadedStores.values.toSeq }.foreach { store => - try { - store.manage() - } catch { - case NonFatal(e) => - logWarning(s"Error performing snapshot and cleaning up $store") - } - } - } - } - managementTimer.schedule( - managementTask, - MANAGEMENT_TASK_INTERVAL_SECS * 1000, - MANAGEMENT_TASK_INTERVAL_SECS * 1000) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 62d94b9614bde..d21b07bbe7629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.streaming.state import scala.collection.mutable +import scala.reflect.ClassTag -import org.apache.spark.Logging +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint @@ -34,7 +36,7 @@ private object StopCoordinator extends StateStoreCoordinatorMessage class StateStoreCoordinator(rpcEnv: RpcEnv) { private val coordinatorRef = rpcEnv.setupEndpoint( - "StateStoreCoordinator", new StateStoreCoordinatorEndpoint(rpcEnv, this)) + StateStoreCoordinator.endpointName, new StateStoreCoordinatorEndpoint(rpcEnv, this)) private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] def 
reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Boolean = { @@ -44,7 +46,10 @@ class StateStoreCoordinator(rpcEnv: RpcEnv) { def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { instances.synchronized { - instances.get(storeId).forall(_.executorId == executorId) + instances.get(storeId) match { + case Some(location) => location.executorId == executorId + case None => false + } } } @@ -54,30 +59,44 @@ class StateStoreCoordinator(rpcEnv: RpcEnv) { def makeInstancesInactive(operatorIds: Set[Long]): Unit = { instances.synchronized { - val instancesToRemove = + val storeIdsToRemove = instances.keys.filter(id => operatorIds.contains(id.operatorId)).toSeq - instances --= instancesToRemove + instances --= storeIdsToRemove } } + + def stop(): Unit = { + coordinatorRef.askWithRetry[Boolean](StopCoordinator) + } } -private[spark] object StateStoreCoordinator { +private[sql] object StateStoreCoordinator { + + private val endpointName = "StateStoreCoordinator" - private[spark] class StateStoreCoordinatorEndpoint( + private class StateStoreCoordinatorEndpoint( override val rpcEnv: RpcEnv, coordinator: StateStoreCoordinator) extends RpcEndpoint with Logging { - override def receive: PartialFunction[Any, Unit] = { - case StopCoordinator => - logInfo("StateStoreCoordinator stopped!") - stop() - } - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => context.reply(coordinator.reportActiveInstance(id, host, executorId)) case VerifyIfInstanceActive(id, executor) => context.reply(coordinator.verifyIfInstanceActive(id, executor)) + case StopCoordinator => + // Stop before replying to ensure that endpoint name has been deregistered + stop() + context.reply(true) + } + } + + def ask(message: StateStoreCoordinatorMessage): Option[Boolean] = { + val env = SparkEnv.get + if (env != null) { + val coordinatorRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + Some(coordinatorRef.askWithRetry[Boolean](message)) + } else { + None } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala new file mode 100644 index 0000000000000..8f0462e41d44e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { + + test("report, verify, getLocation") { + withCoordinator { coordinator => + val id = StateStoreId(0, 0) + + assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinator.getLocation(id) === None) + + assert(coordinator.reportActiveInstance(id, "hostX", "exec1") === true) + assert(coordinator.verifyIfInstanceActive(id, "exec1") === true) + assert(coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + + assert(coordinator.reportActiveInstance(id, "hostX", "exec2") === true) + assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinator.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } + } + + test("make inactive") { + withCoordinator { coordinator => + val id1 = StateStoreId(0, 0) + val id2 = StateStoreId(1, 0) + val id3 = StateStoreId(0, 1) + val host = "hostX" + val exec = "exec1" + + assert(coordinator.reportActiveInstance(id1, host, exec) === true) + assert(coordinator.reportActiveInstance(id2, host, exec) === true) + assert(coordinator.reportActiveInstance(id3, host, exec) === true) + + assert(coordinator.verifyIfInstanceActive(id1, exec) === true) + assert(coordinator.verifyIfInstanceActive(id2, exec) === true) + assert(coordinator.verifyIfInstanceActive(id3, exec) === true) + + coordinator.makeInstancesInactive(Set(0)) + + assert(coordinator.verifyIfInstanceActive(id1, exec) === false) + assert(coordinator.verifyIfInstanceActive(id2, exec) === true) + assert(coordinator.verifyIfInstanceActive(id3, exec) === false) + + assert(coordinator.getLocation(id1) === None) + assert( + coordinator.getLocation(id2) === + Some(ExecutorCacheTaskLocation(host, exec).toString)) + assert(coordinator.getLocation(id3) === None) + + coordinator.makeInstancesInactive(Set(1)) + assert(coordinator.verifyIfInstanceActive(id2, exec) === false) + assert(coordinator.getLocation(id2) === None) + } + } + + test("communication") { + withCoordinator { coordinator => + import StateStoreCoordinator._ + val id = StateStoreId(0, 0) + val host = "hostX" + + val ref = RpcUtils.makeDriverRef("StateStoreCoordinator", sc.env.conf, sc.env.rpcEnv) + + assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(false)) + + ask(ReportActiveInstance(id, host, "exec1")) + assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(true)) + assert( + coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation(host, "exec1").toString)) + + ask(ReportActiveInstance(id, host, "exec2")) + assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(false)) + assert(ask(VerifyIfInstanceActive(id, "exec2")) === Some(true)) + assert( + coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation(host, "exec2").toString)) + } + } + + private def withCoordinator(body: StateStoreCoordinator => Unit): Unit = { + var coordinator: StateStoreCoordinator = null + try { + coordinator = new StateStoreCoordinator(sc.env.rpcEnv) + body(coordinator) + } finally { + if (coordinator != null) coordinator.stop() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 9c42116811c7e..15a6cbb4e3828 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -37,7 +37,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn import StateStoreSuite._ after { - StateStore.clearAll() + StateStore.stop() } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 9e35dd2022a4c..01dd625f12d86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -39,7 +39,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth private val tempDir = Utils.createTempDir().toString after { - StateStore.clearAll() + StateStore.stop() } test("update, remove, commit, and all data iterator") { @@ -325,13 +325,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth private[state] object StateStoreSuite { - /** Trait and class mirroring [[StoreUpdate]] for testing */ + /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ trait TestUpdate case class Added(key: String, value: Int) extends TestUpdate case class Updated(key: String, value: Int) extends TestUpdate case class Removed(key: String) extends TestUpdate - def wrapValue(i: Int): InternalRow = { new GenericInternalRow(Array[Any](i)) } From f5660d23ce80f84e4d28235805d1c64060aad382 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 16:28:15 -0700 Subject: [PATCH 14/46] Added unit tests and fixed scala style --- .../streaming/state/StateStore.scala | 41 +++++++++++-------- .../state/StateStoreCoordinator.scala | 13 ++++-- .../state/StateStoreCoordinatorSuite.scala | 4 +- .../streaming/state/StateStoreSuite.scala | 29 +++++++++++++ 4 files changed, 65 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index cb23daa9446fc..7c4e0692082b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -112,6 +112,7 @@ private[state] object StateStore extends Logging { /** Get or create a store associated with the id. 
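The withCoordinator helper above, like the withSpark and quietly helpers used later in these suites, is an instance of the loan pattern: acquire a resource, lend it to a body, and release it in a finally block so teardown runs even when an assertion throws. A generic, hedged sketch of that shape; the names here are illustrative, not part of the patch:

// Loan-pattern helper in the spirit of withCoordinator: the caller only ever touches
// the resource inside the body, and release always runs.
def withResource[R, A](acquire: => R)(release: R => Unit)(body: R => A): A = {
  val resource = acquire
  try body(resource)
  finally release(resource)
}

// Example with an ordinary closeable resource.
val firstLine = withResource(scala.io.Source.fromString("hello\nworld"))(_.close()) { src =>
  src.getLines().next()
}
assert(firstLine == "hello")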
*/ def get(storeId: StateStoreId, directory: String, version: Long): StateStore = { + require(version >= 0) val storeProvider = loadedProviders.synchronized { startIfNeeded() val provider = loadedProviders.getOrElseUpdate( @@ -122,6 +123,10 @@ private[state] object StateStore extends Logging { storeProvider.getStore(version) } + def remove(storeId: StateStoreId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeId) + } + /** Unload and stop all state store provider */ def stop(): Unit = loadedProviders.synchronized { loadedProviders.clear() @@ -137,7 +142,14 @@ private[state] object StateStore extends Logging { managementTask = new TimerTask { override def run(): Unit = { manageAll() } } - val periodMs = MANAGEMENT_TASK_INTERVAL_SECS * 1000 + val periodMs = Option(SparkEnv.get).map(_.conf) match { + case Some(conf) => + conf.getTimeAsMs( + "spark.sql.streaming.stateStore.managementInterval", + s"${MANAGEMENT_TASK_INTERVAL_SECS}s") + case None => + MANAGEMENT_TASK_INTERVAL_SECS * 1000 + } managementTimer.schedule(managementTask, periodMs, periodMs) logInfo("StateStore started") } @@ -163,28 +175,25 @@ private[state] object StateStore extends Logging { } } - private def remove(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId) - } - private def reportActiveInstance(storeId: StateStoreId): Unit = { - val host = SparkEnv.get.blockManager.blockManagerId.host - val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - askCoordinator(ReportActiveInstance(storeId, host, executorId)) + try { + val host = SparkEnv.get.blockManager.blockManagerId.host + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + StateStoreCoordinator.ask(ReportActiveInstance(storeId, host, executorId)) + } catch { + case NonFatal(e) => + logWarning(s"Error reporting active instance of $storeId") + } } private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { - val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - askCoordinator(VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) - } - - private def askCoordinator(message: StateStoreCoordinatorMessage): Option[Boolean] = { try { - StateStoreCoordinator.ask(message) + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + StateStoreCoordinator.ask(VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) } catch { case NonFatal(e) => - logWarning("Error communicating") - None + logWarning(s"Error verifying active instance of $storeId") + false } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d21b07bbe7629..42528e40d8327 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -18,32 +18,35 @@ package org.apache.spark.sql.execution.streaming.state import scala.collection.mutable -import scala.reflect.ClassTag -import org.apache.spark.util.RpcUtils import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.util.RpcUtils import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint +/** Trait representing all messages 
to [[StateStoreCoordinator]] */ private sealed trait StateStoreCoordinatorMessage extends Serializable + private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) extends StateStoreCoordinatorMessage private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) extends StateStoreCoordinatorMessage private object StopCoordinator extends StateStoreCoordinatorMessage - +/** Class for coordinating instances of [[StateStore]]s loaded in the cluster */ class StateStoreCoordinator(rpcEnv: RpcEnv) { private val coordinatorRef = rpcEnv.setupEndpoint( StateStoreCoordinator.endpointName, new StateStoreCoordinatorEndpoint(rpcEnv, this)) private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + /** Report active instance of a state store in an executor */ def reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Boolean = { instances.synchronized { instances.put(storeId, ExecutorCacheTaskLocation(host, executorId)) } true } + /** Verify whether the given executor has the active instance of a state store */ def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { instances.synchronized { instances.get(storeId) match { @@ -53,11 +56,13 @@ class StateStoreCoordinator(rpcEnv: RpcEnv) { } } + /** Get the location of the state store */ def getLocation(storeId: StateStoreId): Option[String] = { instances.synchronized { instances.get(storeId).map(_.toString) } } - def makeInstancesInactive(operatorIds: Set[Long]): Unit = { + /** Deactivate instances related to a set of operator */ + def deactivateInstances(operatorIds: Set[Long]): Unit = { instances.synchronized { val storeIdsToRemove = instances.keys.filter(id => operatorIds.contains(id.operatorId)).toSeq diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 8f0462e41d44e..297fa85b7194f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -61,7 +61,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinator.verifyIfInstanceActive(id2, exec) === true) assert(coordinator.verifyIfInstanceActive(id3, exec) === true) - coordinator.makeInstancesInactive(Set(0)) + coordinator.deactivateInstances(Set(0)) assert(coordinator.verifyIfInstanceActive(id1, exec) === false) assert(coordinator.verifyIfInstanceActive(id2, exec) === true) @@ -73,7 +73,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinator.getLocation(id3) === None) - coordinator.makeInstancesInactive(Set(1)) + coordinator.deactivateInstances(Set(1)) assert(coordinator.verifyIfInstanceActive(id2, exec) === false) assert(coordinator.getLocation(id2) === None) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 01dd625f12d86..335b43efc0c51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -262,6 +262,35 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) } + test("StateStore.get") { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(0, 0) + + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, dir, -1) + } + intercept[IllegalStateException] { + StateStore.get(storeId, dir, 1) + } + + // Increase version of the store + val store0 = StateStore.get(storeId, dir, 0) + assert(store0.version === 0) + update(store0, "a", 1) + store0.commit() + + assert(StateStore.get(storeId, dir, 1).version == 1) + assert(StateStore.get(storeId, dir, 0).version == 0) + + // Verify that you can remove the store and still reload and use it + StateStore.remove(storeId) + val store1 = StateStore.get(storeId, dir, 1) + update(store1, "a", 2) + assert(store1.commit() === 2) + assert(unwrapToSet(store1.iterator()) === Set("a" -> 2)) + } + def getDataFromFiles( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { From d31368394c78a2b1e46c871267a46660e880d4ff Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 16:54:42 -0700 Subject: [PATCH 15/46] Added unit test for preferred location --- .../state/StateStoreCoordinator.scala | 2 + .../state/StateStoreCoordinatorSuite.scala | 14 ++- .../streaming/state/StateStoreRDDSuite.scala | 100 ++++++++++++------ 3 files changed, 78 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 42528e40d8327..904bc247d7c9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -34,6 +34,7 @@ private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: Str extends StateStoreCoordinatorMessage private object StopCoordinator extends StateStoreCoordinatorMessage + /** Class for coordinating instances of [[StateStore]]s loaded in the cluster */ class StateStoreCoordinator(rpcEnv: RpcEnv) { private val coordinatorRef = rpcEnv.setupEndpoint( @@ -75,6 +76,7 @@ class StateStoreCoordinator(rpcEnv: RpcEnv) { } } + private[sql] object StateStoreCoordinator { private val endpointName = "StateStoreCoordinator" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 297fa85b7194f..08488cc55b759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.util.RpcUtils -import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.{SparkContext, SharedSparkContext, SparkFunSuite} class StateStoreCoordinatorSuite extends SparkFunSuite with 
SharedSparkContext { + import StateStoreCoordinatorSuite._ + test("report, verify, getLocation") { - withCoordinator { coordinator => + withCoordinator(sc) { coordinator => val id = StateStoreId(0, 0) assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) @@ -46,7 +48,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } test("make inactive") { - withCoordinator { coordinator => + withCoordinator(sc) { coordinator => val id1 = StateStoreId(0, 0) val id2 = StateStoreId(1, 0) val id3 = StateStoreId(0, 1) @@ -80,7 +82,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } test("communication") { - withCoordinator { coordinator => + withCoordinator(sc) { coordinator => import StateStoreCoordinator._ val id = StateStoreId(0, 0) val host = "hostX" @@ -103,8 +105,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, "exec2").toString)) } } +} - private def withCoordinator(body: StateStoreCoordinator => Unit): Unit = { +object StateStoreCoordinatorSuite { + def withCoordinator(sc: SparkContext)(body: StateStoreCoordinator => Unit): Unit = { var coordinator: StateStoreCoordinator = null try { coordinator = new StateStoreCoordinator(sc.env.rpcEnv) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 15a6cbb4e3828..971817d4aaaed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -25,6 +25,8 @@ import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -35,6 +37,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString import StateStoreSuite._ + import StateStoreCoordinatorSuite._ after { StateStore.stop() @@ -46,52 +49,83 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - withSpark(new SparkContext(conf)) { sc => - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - wrapKey(s), oldRow => { - val oldValue = oldRow.map(unwrapValue).getOrElse(0) - wrapValue(oldValue + 1) - }) + quietly { + withSpark(new SparkContext(conf)) { sc => + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + wrapKey(s), oldRow => { + val oldValue = oldRow.map(unwrapValue).getOrElse(0) + wrapValue(oldValue + 1) + }) + } + store.commit() + store.iterator().map(unwrapKeyValue) } - store.commit() - store.iterator().map(unwrapKeyValue) + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")) + .withStateStores(increment, opId, storeVersion = 0, path, null) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // 
Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")) + .withStateStores(increment, opId, storeVersion = 1, path, null) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) } + } + } + + test("recovering from files") { + quietly { val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")) - .withStateStores(increment, opId, storeVersion = 0, path, null) - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + makeRDD(sc, Seq("a")).withStateStores(increment, opId, storeVersion, path, null) + } - // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")) - .withStateStores(increment, opId, storeVersion = 1, path, null) - assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + // Generate RDDs and state store data + withSpark(new SparkContext(conf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } - // Make sure the previous RDD still has the same data. - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + // With a new context, try using the earlier state store data + withSpark(new SparkContext(conf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } } } - test("recovering from files") { + test("preferred locations using StateStoreCoordinator") { val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - def makeStoreRDD(sc: SparkContext, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - makeRDD(sc, Seq("a")).withStateStores(increment, opId, storeVersion, path, null) - } - - // Generate RDDs and state store data withSpark(new SparkContext(conf)) { sc => - for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) - } - } + withCoordinator(sc) { coordinator => + coordinator.reportActiveInstance(StateStoreId(opId, 0), "host1", "exec1") + coordinator.reportActiveInstance(StateStoreId(opId, 1), "host2", "exec2") - // With a new context, try using the earlier state store data - withSpark(new SparkContext(conf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + val rdd = makeRDD(sc, Seq("a", "b", "a")) + .withStateStores(increment, opId, storeVersion = 0, path, coordinator) + require(rdd.partitions.size === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) != + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) != + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + } } } From 7ea847cc59ebd9f1322a20cd31805ac7f753377d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 16:56:11 -0700 Subject: [PATCH 16/46] Fixed style --- .../execution/streaming/state/StateStoreCoordinator.scala | 4 ++-- .../streaming/state/StateStoreCoordinatorSuite.scala | 2 +- .../sql/execution/streaming/state/StateStoreRDDSuite.scala | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 904bc247d7c9a..18856a19228db 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.streaming.state import scala.collection.mutable -import org.apache.spark.{SparkEnv, Logging} -import org.apache.spark.util.RpcUtils +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint +import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ private sealed trait StateStoreCoordinatorMessage extends Serializable diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 08488cc55b759..fa6217da0abc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.streaming.state +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.util.RpcUtils -import org.apache.spark.{SparkContext, SharedSparkContext, SparkFunSuite} class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 971817d4aaaed..a84a1d20a2405 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -25,10 +25,10 @@ import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { @@ -36,8 +36,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString - import StateStoreSuite._ import StateStoreCoordinatorSuite._ + import StateStoreSuite._ after { StateStore.stop() From b8b4632b5ed9be181456da4f53156a0cbebebaf3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 17:34:10 -0700 Subject: [PATCH 17/46] Added unit test for StateStore background management --- .../streaming/state/StateStoreSuite.scala | 90 +++++++++++++------ 1 file changed, 64 insertions(+), 26 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 335b43efc0c51..9344668a58669 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -24,16 +24,21 @@ import scala.util.Random import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[InternalRow, InternalRow] + import StateStoreCoordinatorSuite._ import StateStoreSuite._ private val tempDir = Utils.createTempDir().toString @@ -263,32 +268,65 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("StateStore.get") { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString - val storeId = StateStoreId(0, 0) - - // Verify that trying to get incorrect versions throw errors - intercept[IllegalArgumentException] { - StateStore.get(storeId, dir, -1) - } - intercept[IllegalStateException] { - StateStore.get(storeId, dir, 1) - } + quietly { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(0, 0) - // Increase version of the store - val store0 = StateStore.get(storeId, dir, 0) - assert(store0.version === 0) - update(store0, "a", 1) - store0.commit() + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, dir, -1) + } + intercept[IllegalStateException] { + StateStore.get(storeId, dir, 1) + } - assert(StateStore.get(storeId, dir, 1).version == 1) - assert(StateStore.get(storeId, dir, 0).version == 0) + // Increase version of the store + val store0 = StateStore.get(storeId, dir, 0) + assert(store0.version === 0) + update(store0, "a", 1) + store0.commit() + + assert(StateStore.get(storeId, dir, 1).version == 1) + assert(StateStore.get(storeId, dir, 0).version == 0) + + // Verify that you can remove the store and still reload and use it + StateStore.remove(storeId) + val store1 = StateStore.get(storeId, dir, 1) + update(store1, "a", 2) + assert(store1.commit() === 2) + assert(unwrapToSet(store1.iterator()) === Set("a" -> 2)) + } + } - // Verify that you can remove the store and still reload and use it - StateStore.remove(storeId) - val store1 = StateStore.get(storeId, dir, 1) - update(store1, "a", 2) - assert(store1.commit() === 2) - assert(unwrapToSet(store1.iterator()) === Set("a" -> 2)) + test("background management") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.sql.streaming.stateStore.managementInterval", "10ms") + val storeId = StateStoreId(0, 0) + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + quietly { + withSpark(new SparkContext(conf)) { sc => + withCoordinator(sc) { 
coordinator => + for (i <- 1 to 20) { + val store = StateStore.get(storeId, dir, i - 1) + update(store, "a", i) + store.commit() + } + + val provider = new HDFSBackedStateStoreProvider(storeId, dir) + + eventually(timeout(4 seconds)) { + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + + val snapshotVersions = (0 to 20).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") + } + } + } + } } def getDataFromFiles( @@ -303,8 +341,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } def assertMap( - testMapOption: Option[MapType], - expectedMap: Map[String, Int]): Unit = { + testMapOption: Option[MapType], + expectedMap: Map[String, Int]): Unit = { assert(testMapOption.nonEmpty, "no map present") val convertedMap = testMapOption.get.map(unwrapKeyValue) assert(convertedMap === expectedMap) From e89f4d0d1a45d5a4d8c59a433cf55ce5d19010da Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 17:47:25 -0700 Subject: [PATCH 18/46] Updated unit test to test instance unloading in state store --- .../streaming/state/StateStore.scala | 4 +++ .../streaming/state/StateStoreSuite.scala | 32 +++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7c4e0692082b4..fceca4107bef7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -127,6 +127,10 @@ private[state] object StateStore extends Logging { loadedProviders.remove(storeId) } + def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeId) + } + /** Unload and stop all state store provider */ def stop(): Unit = loadedProviders.synchronized { loadedProviders.clear() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 9344668a58669..d8cace66107b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -276,6 +276,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth intercept[IllegalArgumentException] { StateStore.get(storeId, dir, -1) } + assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store + intercept[IllegalStateException] { StateStore.get(storeId, dir, 1) } @@ -291,7 +293,10 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Verify that you can remove the store and still reload and use it StateStore.remove(storeId) + assert(!StateStore.isLoaded(storeId)) + val store1 = StateStore.get(storeId, dir, 1) + assert(StateStore.isLoaded(storeId)) update(store1, "a", 2) assert(store1.commit() === 2) assert(unwrapToSet(store1.iterator()) === Set("a" -> 2)) @@ -303,8 +308,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth .setMaster("local") .setAppName("test") .set("spark.sql.streaming.stateStore.managementInterval", "10ms") - val storeId = StateStoreId(0, 0) + 
val opId = 0 + val storeId = StateStoreId(opId, 0) val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val provider = new HDFSBackedStateStoreProvider(storeId, dir) + quietly { withSpark(new SparkContext(conf)) { sc => withCoordinator(sc) { coordinator => @@ -314,16 +322,34 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth store.commit() } - val provider = new HDFSBackedStateStoreProvider(storeId, dir) - + // Background management should clean up and generate snapshots eventually(timeout(4 seconds)) { + // Earliest delta file should get cleaned up assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + // Some snapshots should have been generated val snapshotVersions = (0 to 20).filter { version => fileExists(provider, version, isSnapshot = true) } assert(snapshotVersions.nonEmpty, "no snapshot file found") } + + // If driver decides to deactivate all instances of the store, then this instance + // should be unloaded + coordinator.deactivateInstances(Set(opId)) + eventually(timeout(4 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, dir, 20) + assert(StateStore.isLoaded(storeId)) + + // If some other executor loads the store, then this instance should be unloaded + coordinator.reportActiveInstance(storeId, "other-host", "other-exec") + eventually(timeout(4 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } } } } From b5e242132ae98342113f27da4824aaf06485e291 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 18:09:48 -0700 Subject: [PATCH 19/46] Updated unit test --- .../sql/execution/streaming/state/StateStoreSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index d8cace66107b4..dfea71bb4190a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -350,8 +350,17 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth eventually(timeout(4 seconds)) { assert(!StateStore.isLoaded(storeId)) } + + // Reload the store and verify + StateStore.get(storeId, dir, 20) + assert(StateStore.isLoaded(storeId)) } } + + // Verify if instance is unloaded if SparkContext is stopped + eventually(timeout(4 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } } } From 76dd988fcac508609be3f32754284bc07524e481 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 18:28:04 -0700 Subject: [PATCH 20/46] Fixed StateStoreRDD unit test --- .../streaming/state/StateStoreRDD.scala | 18 ++++++++---------- .../execution/streaming/state/package.scala | 2 +- .../streaming/state/StateStoreRDDSuite.scala | 12 ++++++------ 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 80e497455c993..9a1ff0a9f7b0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -24,25 +24,23 @@ import org.apache.spark.rdd.RDD import 
org.apache.spark.util.Utils /** - * Created by tdas on 3/9/16. - */ + * An RDD that allows computations to be executed against [[StateStore]]s. It + * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as + * preferred locations. + */ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( dataRDD: RDD[INPUT], storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], operatorId: Long, storeVersion: Long, storeDirectory: String, - storeCoordinator: StateStoreCoordinator) extends RDD[OUTPUT](dataRDD) { - - val nextVersion = storeVersion + 1 + storeCoordinator: Option[StateStoreCoordinator]) extends RDD[OUTPUT](dataRDD) { override protected def getPartitions: Array[Partition] = dataRDD.partitions + override def getPreferredLocations(partition: Partition): Seq[String] = { - Seq.empty - /* - storeCoordinator.getLocation( - StateStoreId(operatorId, partition.index)).toSeq - */ + val storeId = StateStoreId(operatorId, partition.index) + storeCoordinator.flatMap(_.getLocation(storeId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[OUTPUT] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index e5ed5a41b949d..2bf6d6c906157 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -29,7 +29,7 @@ package object state { operatorId: Long, storeVersion: Long, storeDirectory: String, - storeCoordinator: StateStoreCoordinator + storeCoordinator: Option[StateStoreCoordinator] = None ): StateStoreRDD[INPUT, OUTPUT] = { new StateStoreRDD( dataRDD, storeUpdateFunction, operatorId, storeVersion, storeDirectory, storeCoordinator) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index a84a1d20a2405..112bfa6586aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -65,12 +65,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val opId = 0 val rdd1 = makeRDD(sc, Seq("a", "b", "a")) - .withStateStores(increment, opId, storeVersion = 0, path, null) + .withStateStores(increment, opId, storeVersion = 0, path) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(sc, Seq("a", "c")) - .withStateStores(increment, opId, storeVersion = 1, path, null) + .withStateStores(increment, opId, storeVersion = 1, path) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
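The getPreferredLocations override above is only a lookup: the partition index plus the operator id form a StateStoreId, and the coordinator, when one is supplied, answers with the executor that last reported loading that store, or nothing at all. A standalone sketch of that mapping, with a plain Map standing in for the coordinator's registry; all names and location strings are illustrative:

// Illustrative only: a Map plays the role of the coordinator's instance registry.
final case class StoreId(operatorId: Long, partitionId: Int)

val knownLocations: Map[StoreId, String] = Map(
  StoreId(0, 0) -> "executor_host1_exec1",
  StoreId(0, 1) -> "executor_host2_exec2")

def preferredLocations(operatorId: Long, partitionIndex: Int): Seq[String] =
  knownLocations.get(StoreId(operatorId, partitionIndex)).toSeq  // empty Seq => no preference

assert(preferredLocations(0, 1) == Seq("executor_host2_exec2"))
assert(preferredLocations(0, 7).isEmpty)

Returning an empty Seq is the neutral answer, leaving the scheduler free to place the task anywhere.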
@@ -88,7 +88,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc: SparkContext, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - makeRDD(sc, Seq("a")).withStateStores(increment, opId, storeVersion, path, null) + makeRDD(sc, Seq("a")).withStateStores(increment, opId, storeVersion, path) } // Generate RDDs and state store data @@ -115,15 +115,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinator.reportActiveInstance(StateStoreId(opId, 1), "host2", "exec2") val rdd = makeRDD(sc, Seq("a", "b", "a")) - .withStateStores(increment, opId, storeVersion = 0, path, coordinator) + .withStateStores(increment, opId, storeVersion = 0, path, Some(coordinator)) require(rdd.partitions.size === 2) assert( - rdd.preferredLocations(rdd.partitions(0)) != + rdd.preferredLocations(rdd.partitions(0)) === Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) assert( - rdd.preferredLocations(rdd.partitions(1)) != + rdd.preferredLocations(rdd.partitions(1)) === Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) } } From 81238189f560f5df846b36e5cb5e50532f6f82a3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 15 Mar 2016 20:41:37 -0700 Subject: [PATCH 21/46] Added docs --- .../state/HDFSBackedStateStoreProvider.scala | 17 ++++++++++++++--- .../execution/streaming/state/StateStore.scala | 6 +----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index e0473ec56ab94..b3bcf55ffbe1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -33,7 +33,7 @@ import org.apache.spark.util.{CompletionIterator, Utils} /** * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed - * by files in a HDFS-compatible file system. All updates to the store has to be done in sets + * by files in a HDFS-compatible file system. All updates to the store has to be done in sets * transactionally, and each set of updates increments the store's version. These versions can * be used to re-execute the updates (by retries in RDD operations) on the correct version of * the store, and regenerate the store version. @@ -56,7 +56,7 @@ import org.apache.spark.util.{CompletionIterator, Utils} * to ensure re-executed RDD operations re-apply updates on the correct past version of the * store. 
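The recovery scheme this provider comment describes reduces to a simple file-selection rule: take the newest snapshot at or below the requested version, then replay every delta after it up to that version. A small standalone sketch of that rule; StoreFile here is a stand-in type, and the real filesForVersion adds extra validation:

// Illustrative snapshot-plus-deltas selection: the snapshot holds a full copy of some
// older version, and the deltas bring it forward to the requested one.
final case class StoreFile(version: Long, isSnapshot: Boolean)

def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = {
  val snapshot = allFiles
    .filter(f => f.isSnapshot && f.version <= version)
    .sortBy(_.version).lastOption
  val replayFrom = snapshot.map(_.version).getOrElse(0L)
  val deltas = allFiles
    .filter(f => !f.isSnapshot && f.version > replayFrom && f.version <= version)
    .sortBy(_.version)
  snapshot.toSeq ++ deltas
}

val files = Seq(StoreFile(1, false), StoreFile(2, false), StoreFile(3, true),
  StoreFile(4, false), StoreFile(5, false))
assert(filesForVersion(files, 5) ==
  Seq(StoreFile(3, true), StoreFile(4, false), StoreFile(5, false)))
assert(filesForVersion(files, 2) == Seq(StoreFile(1, false), StoreFile(2, false)))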
*/ -class HDFSBackedStateStoreProvider( +private[state] class HDFSBackedStateStoreProvider( val id: StateStoreId, val directory: String, numBatchesToRetain: Int = 2, @@ -66,10 +66,11 @@ class HDFSBackedStateStoreProvider( import StateStore._ - + /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ class HDFSBackedStateStore( val version: Long, mapToUpdate: MapType) extends StateStore { + /** Trait and classes representing the internal state of the store */ trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE @@ -94,10 +95,14 @@ class HDFSBackedStateStoreProvider( mapToUpdate.put(key, value) allUpdates.get(key) match { case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added already, keep it marked as added allUpdates.put(key, ValueAdded(key, value)) case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + // Value existed in prev version and updated/removed, mark it as updated allUpdates.put(key, ValueUpdated(key, value)) case None => + // There was no prior update, so mark this as added or updated according to its presence + // in previous version. val update = if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) allUpdates.put(key, update) @@ -113,10 +118,13 @@ class HDFSBackedStateStoreProvider( val key = keyIter.next if (condition(key)) { mapToUpdate.remove(key) + allUpdates.get(key) match { case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed allUpdates.put(key, KeyRemoved(key)) case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates allUpdates.remove(key) case Some(KeyRemoved(_)) => // Remove already in update map, no need to change @@ -189,6 +197,7 @@ class HDFSBackedStateStoreProvider( new HDFSBackedStateStore(version, newMap) } + /** Manage backing files, including creating snapshots and cleaning up old files */ override def manage(): Unit = { try { doSnapshot() @@ -382,6 +391,7 @@ class HDFSBackedStateStoreProvider( } } + /** Files needed to recover the given version of the store */ private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { require(version >= 0) require(allFiles.exists(_.version == version)) @@ -409,6 +419,7 @@ class HDFSBackedStateStoreProvider( latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles } + /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { fs.listStatus(baseDir) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index fceca4107bef7..841acc184e569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -20,12 +20,10 @@ package org.apache.spark.sql.execution.streaming.state import java.util.{Timer, TimerTask} import scala.collection.mutable -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.util.RpcUtils /** Unique identifier for a [[StateStore]] */ case class StateStoreId(operatorId: Long, partitionId: Int) @@ -35,8 +33,6 @@ case class StateStoreId(operatorId: Long, partitionId: Int) 
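The update and remove branches commented above implement a small per-key state machine: within one commit each key collapses to at most one net change, and how a new operation is recorded depends on what was recorded before and on whether the key existed in the previous version. A standalone sketch of those rules with plain types; ChangeTracker and its cases are illustrative, not the provider's actual classes:

import scala.collection.mutable

// Illustrative per-key collapse of updates within a single commit.
sealed trait Net
case class Added(value: Int) extends Net
case class Updated(value: Int) extends Net
case object Removed extends Net

class ChangeTracker(existedBefore: String => Boolean) {
  private val net = new mutable.HashMap[String, Net]()

  def update(key: String, value: Int): Unit = net.get(key) match {
    case Some(Added(_))                   => net(key) = Added(value)   // still new in this version
    case Some(Updated(_)) | Some(Removed) => net(key) = Updated(value)
    case None =>
      net(key) = if (existedBefore(key)) Updated(value) else Added(value)
  }

  def remove(key: String): Unit = net.get(key) match {
    case Some(Added(_)) => net.remove(key)     // never visible outside this version
    case Some(Removed)  => ()                  // already recorded as removed
    case _              => net(key) = Removed  // existed in the previous version
  }

  def result: Map[String, Net] = net.toMap
}

val tracker = new ChangeTracker(existedBefore = Set("a", "aa"))
tracker.update("a", 6)                          // plain update of a pre-existing key
tracker.update("b", 4); tracker.update("b", 5)  // added then updated ...
tracker.remove("b")                             // ... then removed: nets out to nothing
tracker.remove("aa")                            // removal of a pre-existing key
assert(tracker.result == Map("a" -> Updated(6), "aa" -> Removed))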
*/ trait StateStore { - import StateStore._ - /** Unique identifier of the store */ def id: StateStoreId @@ -104,7 +100,7 @@ case class KeyRemoved(key: InternalRow) extends StoreUpdate */ private[state] object StateStore extends Logging { - val MANAGEMENT_TASK_INTERVAL_SECS = 60 + private val MANAGEMENT_TASK_INTERVAL_SECS = 60 private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val managementTimer = new Timer("StateStore Timer", true) From 2fb5b85962c3cacd8764292a0d8922e88c1bed7e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Mar 2016 13:13:09 -0700 Subject: [PATCH 22/46] Minor fixes --- .../scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- .../sql/execution/streaming/state/StateStoreRDD.scala | 11 ++++++----- .../spark/sql/execution/streaming/state/package.scala | 11 ++++++----- .../streaming/state/StateStoreRDDSuite.scala | 2 ++ 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index e27d2e6c94f7b..cf4ad81e2bcd7 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkEnv, SparkException} /** * A cleaner that renders closures serializable if they can be done so safely. */ -private[spark] object ClosureCleaner extends Logging { +private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 9a1ff0a9f7b0b..09884970c1f91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -28,13 +28,14 @@ import org.apache.spark.util.Utils * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as * preferred locations. 
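In the StateStoreRDD rework just below, the coordinator handle becomes a @transient constructor field and the update function is passed through SparkContext.clean before the RDD is built; both changes are about controlling what gets serialized into the task closure. A hedged, standalone sketch of why the @transient marker matters; the TaskSpec and DriverHandle types are illustrative:

import java.io._

// A driver-side handle that is deliberately not Serializable: shipping it with the task
// would throw NotSerializableException, so it must stay out of the serialized form.
class DriverHandle { override def toString = "live-handle" }

class TaskSpec(val operatorId: Long, val storeVersion: Long, val directory: String,
    handle0: DriverHandle) extends Serializable {
  @transient val handle: DriverHandle = handle0   // skipped by Java serialization
}

def roundTrip(spec: TaskSpec): TaskSpec = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeObject(spec)
  out.close()
  new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject().asInstanceOf[TaskSpec]
}

val shipped = roundTrip(new TaskSpec(0L, 1L, "/tmp/state", new DriverHandle))
assert(shipped.handle == null)               // the transient field did not make the trip
assert(shipped.directory == "/tmp/state")    // the plain fields did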
*/ -class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( - dataRDD: RDD[INPUT], - storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], +class StateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], operatorId: Long, storeVersion: Long, storeDirectory: String, - storeCoordinator: Option[StateStoreCoordinator]) extends RDD[OUTPUT](dataRDD) { + @transient private val storeCoordinator: Option[StateStoreCoordinator]) + extends RDD[U](dataRDD) { override protected def getPartitions: Array[Partition] = dataRDD.partitions @@ -43,7 +44,7 @@ class StateStoreRDD[INPUT: ClassTag, OUTPUT: ClassTag]( storeCoordinator.flatMap(_.getLocation(storeId)).toSeq } - override def compute(partition: Partition, ctxt: TaskContext): Iterator[OUTPUT] = { + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null Utils.tryWithSafeFinally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 2bf6d6c906157..e772deeae7795 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -23,16 +23,17 @@ import org.apache.spark.rdd.RDD package object state { - implicit class StateStoreOps[INPUT: ClassTag](dataRDD: RDD[INPUT]) { - def withStateStores[OUTPUT: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[INPUT]) => Iterator[OUTPUT], + implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { + def withStateStores[U: ClassTag]( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], operatorId: Long, storeVersion: Long, storeDirectory: String, storeCoordinator: Option[StateStoreCoordinator] = None - ): StateStoreRDD[INPUT, OUTPUT] = { + ): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( - dataRDD, storeUpdateFunction, operatorId, storeVersion, storeDirectory, storeCoordinator) + dataRDD, cleanedF, operatorId, storeVersion, storeDirectory, storeCoordinator) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 112bfa6586aa0..82788d77579cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -125,6 +125,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn assert( rdd.preferredLocations(rdd.partitions(1)) === Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() } } } From dee7a0ef0b4f00c2c69e131863903284836d31c6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Mar 2016 15:58:17 -0700 Subject: [PATCH 23/46] Updated store to UnsafeRow instead of InternalRow --- .../state/HDFSBackedStateStoreProvider.scala | 29 +++++++---------- .../streaming/state/StateStore.scala | 14 ++++---- .../streaming/state/StateStoreSuite.scala | 32 +++++++++---------- 3 files changed, 33 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index b3bcf55ffbe1a..e0b255485f298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -26,8 +26,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.{Logging, SparkConf} import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.{CompletionIterator, Utils} @@ -62,9 +61,7 @@ private[state] class HDFSBackedStateStoreProvider( numBatchesToRetain: Int = 2, maxDeltaChainForSnapshots: Int = 10 ) extends StateStoreProvider with Logging { - type MapType = mutable.HashMap[InternalRow, InternalRow] - - import StateStore._ + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ class HDFSBackedStateStore( val version: Long, mapToUpdate: MapType) @@ -80,7 +77,7 @@ private[state] class HDFSBackedStateStoreProvider( private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private val tempDeltaFileStream = serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) - private val allUpdates = new mutable.HashMap[InternalRow, StoreUpdate] + private val allUpdates = new mutable.HashMap[UnsafeRow, StoreUpdate] @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null @@ -88,7 +85,7 @@ private[state] class HDFSBackedStateStoreProvider( override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id /** Update the value of a key using the value generated by the update function */ - override def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit = { + override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { verify(state == UPDATING, "Cannot update after already committed or cancelled") val oldValueOption = mapToUpdate.get(key) val value = updateFunc(oldValueOption) @@ -111,7 +108,7 @@ private[state] class HDFSBackedStateStoreProvider( } /** Remove keys that match the following condition */ - override def remove(condition: InternalRow => Boolean): Unit = { + override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or cancelled") val keyIter = mapToUpdate.keysIterator while (keyIter.hasNext) { @@ -165,7 +162,7 @@ private[state] class HDFSBackedStateStoreProvider( * Get an iterator of all the store data. This can be called only after committing the * updates. */ - override def iterator(): Iterator[InternalRow] = { + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { verify(state == COMMITTED, "Cannot get iterator of store data before comitting") HDFSBackedStateStoreProvider.this.iterator(newVersion) } @@ -237,22 +234,18 @@ private[state] class HDFSBackedStateStoreProvider( * Get iterator of all the data of the latest version of the store. * Note that this will look up the files to determined the latest known version. 
*/ - private[state] def latestIterator(): Iterator[InternalRow] = synchronized { + private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet val versionsLoaded = loadedMaps.keySet val allKnownVersions = versionsInFiles ++ versionsLoaded if (allKnownVersions.nonEmpty) { - loadMap(allKnownVersions.max) - .iterator - .map { case (key, value) => new JoinedRow(key, value) } + loadMap(allKnownVersions.max).iterator } else Iterator.empty } /** Get iterator of a specific version of the store */ - private[state] def iterator(version: Long): Iterator[InternalRow] = synchronized { - loadMap(version) - .iterator - .map { case (key, value) => new JoinedRow(key, value) } + private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + loadMap(version).iterator } /** Initialize the store provider */ @@ -330,7 +323,7 @@ private[state] class HDFSBackedStateStoreProvider( try { deserStream = deser.deserializeStream(fs.open(fileToRead)) - val iter = deserStream.asIterator.asInstanceOf[Iterator[(InternalRow, InternalRow)]] + val iter = deserStream.asIterator.asInstanceOf[Iterator[(UnsafeRow, UnsafeRow)]] while (iter.hasNext) { map += iter.next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 841acc184e569..98f313b869931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -22,8 +22,8 @@ import java.util.{Timer, TimerTask} import scala.collection.mutable import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.sql.catalyst.InternalRow /** Unique identifier for a [[StateStore]] */ case class StateStoreId(operatorId: Long, partitionId: Int) @@ -43,13 +43,13 @@ trait StateStore { * Update the value of a key using the value generated by the update function. * This can be called only after prepareForUpdates() has been called in the same thread. */ - def update(key: InternalRow, updateFunc: Option[InternalRow] => InternalRow): Unit + def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit /** * Remove keys that match the following condition. * This can be called only after prepareForUpdates() has been called in the current thread. */ - def remove(condition: InternalRow => Boolean): Unit + def remove(condition: UnsafeRow => Boolean): Unit /** * Commit all the updates that have been made to the store. @@ -64,7 +64,7 @@ trait StateStore { * Iterator of store data after a set of updates have been committed. * This can be called only after commitUpdates() has been called in the current thread. */ - def iterator(): Iterator[InternalRow] + def iterator(): Iterator[(UnsafeRow, UnsafeRow)] /** * Iterator of the updates that have been committed. 
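A minimal sketch of a caller driving the UnsafeRow-based API above, assuming a `store: StateStore` obtained elsewhere; the keyRow/valueRow helpers are hypothetical and mirror the wrapKey/wrapValue test utilities changed further down in this patch:

    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
    import org.apache.spark.sql.execution.streaming.state.StateStore
    import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    def exampleUpdate(store: StateStore): Unit = {
      // Projections convert external values into UnsafeRow keys/values; copy() detaches
      // the result from the projection's reused buffer.
      val keyProj = UnsafeProjection.create(Array[DataType](StringType))
      val valueProj = UnsafeProjection.create(Array[DataType](IntegerType))
      def keyRow(s: String): UnsafeRow =
        keyProj(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()
      def valueRow(i: Int): UnsafeRow =
        valueProj(new GenericInternalRow(Array[Any](i))).copy()

      // Increment the count stored against "a", drop keys starting with "b", then commit.
      store.update(keyRow("a"), old => valueRow(old.map(_.getInt(0)).getOrElse(0) + 1))
      store.remove(key => key.getUTF8String(0).toString.startsWith("b"))
      val newVersion = store.commit()
      // After commit(), both the data and the collapsed updates can be iterated.
      store.iterator().foreach { case (k, v) =>
        println(s"v$newVersion: ${k.getUTF8String(0)} -> ${v.getInt(0)}")
      }
    }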
@@ -89,9 +89,9 @@ trait StateStoreProvider { } sealed trait StoreUpdate -case class ValueAdded(key: InternalRow, value: InternalRow) extends StoreUpdate -case class ValueUpdated(key: InternalRow, value: InternalRow) extends StoreUpdate -case class KeyRemoved(key: InternalRow) extends StoreUpdate +case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +case class KeyRemoved(key: UnsafeRow) extends StoreUpdate /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index dfea71bb4190a..de9d61f77fb59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -29,14 +29,14 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { - type MapType = mutable.HashMap[InternalRow, InternalRow] + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ import StateStoreSuite._ @@ -433,31 +433,29 @@ private[state] object StateStoreSuite { case class Updated(key: String, value: Int) extends TestUpdate case class Removed(key: String) extends TestUpdate - def wrapValue(i: Int): InternalRow = { - new GenericInternalRow(Array[Any](i)) + def wrapKey(s: String): UnsafeRow = { + val projection = UnsafeProjection.create(Array[DataType](StringType)) + projection.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))) } - def wrapKey(s: String): InternalRow = { - new GenericInternalRow(Array[Any](UTF8String.fromString(s))) + def wrapValue(i: Int): UnsafeRow = { + val projection = UnsafeProjection.create(Array[DataType](IntegerType)) + projection.apply(new GenericInternalRow(Array[Any](i))) } - def unwrapKey(row: InternalRow): String = { - row.asInstanceOf[GenericInternalRow].getString(0) + def unwrapKey(row: UnsafeRow): String = { + row.getUTF8String(0).toString } - def unwrapValue(row: InternalRow): Int = { - row.asInstanceOf[GenericInternalRow].getInt(0) + def unwrapValue(row: UnsafeRow): Int = { + row.getInt(0) } - def unwrapKeyValue(row: (InternalRow, InternalRow)): (String, Int) = { + def unwrapKeyValue(row: (UnsafeRow, UnsafeRow)): (String, Int) = { (unwrapKey(row._1), unwrapValue(row._2)) } - def unwrapKeyValue(row: InternalRow): (String, Int) = { - (row.getString(0), row.getInt(1)) - } - - def unwrapToSet(iterator: Iterator[InternalRow]): Set[(String, Int)] = { + def unwrapToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { iterator.map(unwrapKeyValue).toSet } From 15e178059d86b8284da4c7e095f70d63c1b7acbe Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Mar 2016 20:04:54 -0700 Subject: [PATCH 24/46] Added custom serialization for delta and added 
schema to state stores --- .../state/HDFSBackedStateStoreProvider.scala | 106 +++++++++++++++--- .../streaming/state/StateStore.scala | 13 ++- .../streaming/state/StateStoreRDD.scala | 5 +- .../execution/streaming/state/package.scala | 14 ++- .../streaming/state/StateStoreRDDSuite.scala | 34 +++--- .../streaming/state/StateStoreSuite.scala | 102 +++++++++-------- 6 files changed, 189 insertions(+), 85 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index e0b255485f298..195a971f369af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,17 +17,21 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.{DataInputStream, DataOutputStream} + import scala.collection.mutable import scala.util.Random import scala.util.control.NonFatal +import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.{Logging, SparkConf} import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** @@ -58,6 +62,8 @@ import org.apache.spark.util.{CompletionIterator, Utils} private[state] class HDFSBackedStateStoreProvider( val id: StateStoreId, val directory: String, + keySchema: StructType, + valueSchema: StructType, numBatchesToRetain: Int = 2, maxDeltaChainForSnapshots: Int = 10 ) extends StateStoreProvider with Logging { @@ -75,8 +81,8 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private val tempDeltaFileStream = - serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) + private val tempDeltaFileStream = fs.create(tempDeltaFile, true) + // serializer.newInstance().serializeStream(fs.create(tempDeltaFile, true)) private val allUpdates = new mutable.HashMap[UnsafeRow, StoreUpdate] @volatile private var state: STATE = UPDATING @@ -104,7 +110,8 @@ private[state] class HDFSBackedStateStoreProvider( if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) allUpdates.put(key, update) } - tempDeltaFileStream.writeObject(ValueUpdated(key, value)) + // tempDeltaFileStream.writeObject(ValueUpdated(key, value)) + writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) } /** Remove keys that match the following condition */ @@ -126,7 +133,8 @@ private[state] class HDFSBackedStateStoreProvider( case Some(KeyRemoved(_)) => // Remove already in update map, no need to change } - tempDeltaFileStream.writeObject(KeyRemoved(key)) + // tempDeltaFileStream.writeObject(KeyRemoved(key)) + writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) } } } @@ -136,7 +144,7 @@ private[state] class HDFSBackedStateStoreProvider( verify(state == UPDATING, "Cannot commit again after already committed or cancelled") try { - tempDeltaFileStream.close() + finalizeDeltaFile(tempDeltaFileStream) finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) state = 
COMMITTED newVersion @@ -267,10 +275,11 @@ private[state] class HDFSBackedStateStoreProvider( synchronized { loadedMaps.get(version) }.getOrElse { val mapFromFile = readSnapshotFile(version).getOrElse { val prevMap = loadMap(version - 1) - val deltaUpdates = readDeltaFile(version) val newMap = new MapType() - newMap ++= prevMap newMap.sizeHint(prevMap.size) + newMap ++= prevMap + /* + val deltaUpdates = readDeltaFile(version) while (deltaUpdates.hasNext) { deltaUpdates.next match { case ValueAdded(key, value) => newMap.put(key, value) @@ -278,6 +287,8 @@ private[state] class HDFSBackedStateStoreProvider( case KeyRemoved(key) => newMap.remove(key) } } + */ + updateFromDeltaFile(version, newMap) newMap } loadedMaps.put(version, mapFromFile) @@ -285,20 +296,79 @@ private[state] class HDFSBackedStateStoreProvider( } } - private def readDeltaFile(version: Long): Iterator[StoreUpdate] = { + private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { + + def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + + def writeRemove(key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) + } + + update match { + case ValueAdded(key, value) => + writeUpdate(key, value) + case ValueUpdated(key, value) => + writeUpdate(key, value) + case KeyRemoved(key) => + writeRemove(key) + } + } + + private def finalizeDeltaFile(output: DataOutputStream): Unit = { + output.writeInt(-1) // Write this magic number to signify end of file + output.close() + } + + private def updateFromDeltaFile(version: Long, map: MapType): Unit = { val fileToRead = deltaFile(version) if (!fs.exists(fileToRead)) { throw new IllegalStateException( - s"Cannot read delta file $fileToRead of $this: $fileToRead does not exist") + s"Error reading delta file $fileToRead of $this: $fileToRead does not exist") + } + var input: DataInputStream = null + try { + input = fs.open(fileToRead) + var eof = false + + while(!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new Exception( + s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + map.remove(keyRow) + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + } finally { + if (input != null) input.close() } - val deser = serializer.newInstance() - var deserStream: DeserializationStream = null - deserStream = deser.deserializeStream(fs.open(fileToRead)) - val iter = deserStream.asIterator.asInstanceOf[Iterator[StoreUpdate]] - CompletionIterator[StoreUpdate, Iterator[StoreUpdate]]( - iter, { - deserStream.close() - }) } private def writeSnapshotFile(version: Long, map: MapType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 98f313b869931..41158f9e10dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -22,7 +22,11 @@ import java.util.{Timer, TimerTask} import scala.collection.mutable import scala.util.control.NonFatal +import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} + import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, SparkEnv} /** Unique identifier for a [[StateStore]] */ @@ -107,12 +111,17 @@ private[state] object StateStore extends Logging { @volatile private var managementTask: TimerTask = null /** Get or create a store associated with the id. */ - def get(storeId: StateStoreId, directory: String, version: Long): StateStore = { + def get( + storeId: StateStoreId, + directory: String, + keySchema: StructType, + valueSchema: StructType, + version: Long): StateStore = { require(version >= 0) val storeProvider = loadedProviders.synchronized { startIfNeeded() val provider = loadedProviders.getOrElseUpdate( - storeId, new HDFSBackedStateStoreProvider(storeId, directory)) + storeId, new HDFSBackedStateStoreProvider(storeId, directory, keySchema, valueSchema)) reportActiveInstance(storeId) provider } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 09884970c1f91..53500e9b3debc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import scala.reflect.ClassTag +import org.apache.spark.sql.types.StructType import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -34,6 +35,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( operatorId: Long, storeVersion: Long, storeDirectory: String, + keySchema: StructType, + valueSchema: StructType, @transient private val storeCoordinator: Option[StateStoreCoordinator]) extends RDD[U](dataRDD) { @@ -49,7 +52,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( Utils.tryWithSafeFinally { val storeId = StateStoreId(operatorId, partition.index) - store = StateStore.get(storeId, storeDirectory, storeVersion) + store = StateStore.get(storeId, storeDirectory, keySchema, valueSchema, storeVersion) val inputIter = dataRDD.compute(partition, ctxt) val outputIter = storeUpdateFunction(store, inputIter) assert(store.hasCommitted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index e772deeae7795..64f80252d9056 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -20,20 +20,30 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: 
RDD[T]) { - def withStateStores[U: ClassTag]( + def mapPartitionWithStateStore[U: ClassTag]( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], operatorId: Long, storeVersion: Long, storeDirectory: String, + keySchema: StructType, + valueSchema: StructType, storeCoordinator: Option[StateStoreCoordinator] = None ): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( - dataRDD, cleanedF, operatorId, storeVersion, storeDirectory, storeCoordinator) + dataRDD, + cleanedF, + operatorId, + storeVersion, + storeDirectory, + keySchema, + valueSchema, + storeCoordinator) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 82788d77579cf..ac5407930eafa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD @@ -35,6 +36,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) import StateStoreCoordinatorSuite._ import StateStoreSuite._ @@ -55,22 +58,22 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => store.update( - wrapKey(s), oldRow => { - val oldValue = oldRow.map(unwrapValue).getOrElse(0) - wrapValue(oldValue + 1) + keyToRow(s), oldRow => { + val oldValue = oldRow.map(rowToValue).getOrElse(0) + valueToRow(oldValue + 1) }) } store.commit() - store.iterator().map(unwrapKeyValue) + store.iterator().map(rowsToKeyValue) } val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")) - .withStateStores(increment, opId, storeVersion = 0, path) + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, opId, storeVersion = 0, path, keySchema, valueSchema) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")) - .withStateStores(increment, opId, storeVersion = 1, path) + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, opId, storeVersion = 1, path, keySchema, valueSchema) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
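For the schema-aware mapPartitionWithStateStore introduced here, a minimal end-to-end sketch modeled on the test above; `sc`, the checkpoint `path`, and the keyToRow/valueToRow/rowToValue/rowsToKeyValue helpers are assumed to be in scope (they come from the suites), so treat this as an illustration rather than part of the patch:

    import org.apache.spark.sql.execution.streaming.state._
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val keySchema = StructType(Seq(StructField("key", StringType, true)))
    val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))

    // Count occurrences of each string, keeping the running counts in the state store.
    val increment = (store: StateStore, iter: Iterator[String]) => {
      iter.foreach { s =>
        store.update(keyToRow(s), old => valueToRow(old.map(rowToValue).getOrElse(0) + 1))
      }
      store.commit()
      store.iterator().map(rowsToKeyValue)
    }

    val opId = 0L
    // storeVersion 0 creates the stores; storeVersion 1 builds on what version 0 committed.
    val batch0 = sc.parallelize(Seq("a", "b", "a"), 2).mapPartitionWithStateStore(
      increment, opId, 0, path, keySchema, valueSchema)
    val batch1 = sc.parallelize(Seq("a", "c"), 2).mapPartitionWithStateStore(
      increment, opId, 1, path, keySchema, valueSchema)
    // batch1.collect().toSet would then be Set("a" -> 3, "b" -> 1, "c" -> 1), as in the test.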
@@ -88,7 +91,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc: SparkContext, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - makeRDD(sc, Seq("a")).withStateStores(increment, opId, storeVersion, path) + makeRDD(sc, Seq("a")).mapPartitionWithStateStore( + increment, opId, storeVersion, path, keySchema, valueSchema) } // Generate RDDs and state store data @@ -114,8 +118,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinator.reportActiveInstance(StateStoreId(opId, 0), "host1", "exec1") coordinator.reportActiveInstance(StateStoreId(opId, 1), "host2", "exec2") - val rdd = makeRDD(sc, Seq("a", "b", "a")) - .withStateStores(increment, opId, storeVersion = 0, path, Some(coordinator)) + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, opId, storeVersion = 0, path, keySchema, valueSchema, Some(coordinator)) require(rdd.partitions.size === 2) assert( @@ -138,12 +142,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => store.update( - wrapKey(s), oldRow => { - val oldValue = oldRow.map(unwrapValue).getOrElse(0) - wrapValue(oldValue + 1) + keyToRow(s), oldRow => { + val oldValue = oldRow.map(rowToValue).getOrElse(0) + valueToRow(oldValue + 1) }) } store.commit() - store.iterator().map(unwrapKeyValue) + store.iterator().map(rowsToKeyValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index de9d61f77fb59..4213819ea93fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -42,6 +42,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth import StateStoreSuite._ private val tempDir = Utils.createTempDir().toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) after { StateStore.stop() @@ -79,9 +81,10 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(store.commit() === 1) assert(store.hasCommitted) - assert(unwrapToSet(store.iterator()) === Set("b" -> 2)) - assert(unwrapToSet(provider.latestIterator()) === Set("b" -> 2)) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) assert(fileExists(provider, version = 1, isSnapshot = false)) + assert(getDataFromFiles(provider) === Set("b" -> 2)) // Trying to get newer versions should fail @@ -93,10 +96,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // New updates to the reloaded store with new version, and does not change old version - val 
reloadedStore = new HDFSBackedStateStoreProvider(store.id, provider.directory).getStore(1) + val reloadedProvider = new HDFSBackedStateStoreProvider( + store.id, provider.directory, keySchema, valueSchema) + val reloadedStore = reloadedProvider.getStore(1) update(reloadedStore, "c", 4) assert(reloadedStore.commit() === 2) - assert(unwrapToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) @@ -117,8 +122,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth update(store, "aa", 1) update(store, "aa", 2) store.commit() - assert(unwrapUpdates(store.updates()) === Set(Added("a", 1), Added("aa", 2))) - assert(unwrapToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) + assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) + assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) } // Multiple updates to same key should be collapsed in the updates as a single value update @@ -127,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth update(store, "a", 4) update(store, "a", 6) store.commit() - assert(unwrapUpdates(store.updates()) === Set(Updated("a", 6))) - assert(unwrapToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) } // Keys added, updated and finally removed before commit should not appear in updates @@ -138,8 +143,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth update(store, "bb", 6) remove(store, _.startsWith("b")) store.commit() - assert(unwrapUpdates(store.updates()) === Set.empty) - assert(unwrapToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + assert(updatesToSet(store.updates()) === Set.empty) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) } // Removed data should be seen in updates as a key removed @@ -148,8 +153,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth remove(store, _.startsWith("a")) update(store, "a", 10) store.commit() - assert(unwrapUpdates(store.updates()) === Set(Updated("a", 10), Removed("aa"))) - assert(unwrapToSet(store.iterator()) === Set("a" -> 10)) + assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) + assert(rowsToSet(store.iterator()) === Set("a" -> 10)) } } @@ -158,7 +163,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store = provider.getStore(0) update(store, "a", 1) store.commit() - assert(unwrapToSet(store.iterator()) === Set("a" -> 1)) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) // cancelUpdates should not change the data in the files val store1 = provider.getStore(1) @@ -178,7 +183,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store = provider.getStore(0) update(store, "a", 1) assert(store.commit() === 1) - assert(unwrapToSet(store.iterator()) === Set("a" -> 1)) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) intercept[IllegalStateException] { provider.getStore(2) @@ -188,14 +193,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store1 = provider.getStore(1) 
update(store1, "b", 1) assert(store1.commit() === 2) - assert(unwrapToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data val store2 = provider.getStore(1) update(store2, "c", 1) assert(store2.commit() === 2) - assert(unwrapToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) + assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) } @@ -257,7 +262,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth provider.manage() // do cleanup } require( - unwrapToSet(provider.latestIterator()) === Set("a" -> 20), + rowsToSet(provider.latestIterator()) === Set("a" -> 20), "store not updated correctly") assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted @@ -274,32 +279,32 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Verify that trying to get incorrect versions throw errors intercept[IllegalArgumentException] { - StateStore.get(storeId, dir, -1) + StateStore.get(storeId, dir, keySchema, valueSchema, -1) } assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store intercept[IllegalStateException] { - StateStore.get(storeId, dir, 1) + StateStore.get(storeId, dir, keySchema, valueSchema, 1) } // Increase version of the store - val store0 = StateStore.get(storeId, dir, 0) + val store0 = StateStore.get(storeId, dir, keySchema, valueSchema, 0) assert(store0.version === 0) update(store0, "a", 1) store0.commit() - assert(StateStore.get(storeId, dir, 1).version == 1) - assert(StateStore.get(storeId, dir, 0).version == 0) + assert(StateStore.get(storeId, dir, keySchema, valueSchema, 1).version == 1) + assert(StateStore.get(storeId, dir, keySchema, valueSchema, 0).version == 0) // Verify that you can remove the store and still reload and use it StateStore.remove(storeId) assert(!StateStore.isLoaded(storeId)) - val store1 = StateStore.get(storeId, dir, 1) + val store1 = StateStore.get(storeId, dir, keySchema, valueSchema, 1) assert(StateStore.isLoaded(storeId)) update(store1, "a", 2) assert(store1.commit() === 2) - assert(unwrapToSet(store1.iterator()) === Set("a" -> 2)) + assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) } } @@ -311,13 +316,13 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val opId = 0 val storeId = StateStoreId(opId, 0) val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString - val provider = new HDFSBackedStateStoreProvider(storeId, dir) + val provider = new HDFSBackedStateStoreProvider(storeId, dir, keySchema, valueSchema) quietly { withSpark(new SparkContext(conf)) { sc => withCoordinator(sc) { coordinator => for (i <- 1 to 20) { - val store = StateStore.get(storeId, dir, i - 1) + val store = StateStore.get(storeId, dir, keySchema, valueSchema, i - 1) update(store, "a", i) store.commit() } @@ -342,7 +347,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, dir, 20) + StateStore.get(storeId, dir, keySchema, valueSchema, 20) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -352,7 +357,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store 
and verify - StateStore.get(storeId, dir, 20) + StateStore.get(storeId, dir, keySchema, valueSchema, 20) assert(StateStore.isLoaded(storeId)) } } @@ -367,11 +372,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def getDataFromFiles( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = new HDFSBackedStateStoreProvider(provider.id, provider.directory) + val reloadedProvider = new HDFSBackedStateStoreProvider( + provider.id, provider.directory, keySchema, valueSchema) if (version < 0) { - reloadedProvider.latestIterator().map(unwrapKeyValue).toSet + reloadedProvider.latestIterator().map(rowsToKeyValue).toSet } else { - reloadedProvider.iterator(version).map(unwrapKeyValue).toSet + reloadedProvider.iterator(version).map(rowsToKeyValue).toSet } } @@ -379,7 +385,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth testMapOption: Option[MapType], expectedMap: Map[String, Int]): Unit = { assert(testMapOption.nonEmpty, "no map present") - val convertedMap = testMapOption.get.map(unwrapKeyValue) + val convertedMap = testMapOption.get.map(rowsToKeyValue) assert(convertedMap === expectedMap) } @@ -413,15 +419,17 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth new HDFSBackedStateStoreProvider( StateStoreId(opId, partition), Utils.createDirectory(tempDir, Random.nextString(5)).toString, + StructType(Seq(StructField("key", StringType, true))), + StructType(Seq(StructField("value", IntegerType, true))), maxDeltaChainForSnapshots = maxDeltaChainForSnapshots) } def remove(store: StateStore, condition: String => Boolean): Unit = { - store.remove(row => condition(unwrapKey(row))) + store.remove(row => condition(rowToKey(row))) } private def update(store: StateStore, key: String, value: Int): Unit = { - store.update(wrapKey(key), _ => wrapValue(value)) + store.update(keyToRow(key), _ => valueToRow(value)) } } @@ -433,37 +441,37 @@ private[state] object StateStoreSuite { case class Updated(key: String, value: Int) extends TestUpdate case class Removed(key: String) extends TestUpdate - def wrapKey(s: String): UnsafeRow = { + def keyToRow(s: String): UnsafeRow = { val projection = UnsafeProjection.create(Array[DataType](StringType)) projection.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))) } - def wrapValue(i: Int): UnsafeRow = { + def valueToRow(i: Int): UnsafeRow = { val projection = UnsafeProjection.create(Array[DataType](IntegerType)) projection.apply(new GenericInternalRow(Array[Any](i))) } - def unwrapKey(row: UnsafeRow): String = { + def rowToKey(row: UnsafeRow): String = { row.getUTF8String(0).toString } - def unwrapValue(row: UnsafeRow): Int = { + def rowToValue(row: UnsafeRow): Int = { row.getInt(0) } - def unwrapKeyValue(row: (UnsafeRow, UnsafeRow)): (String, Int) = { - (unwrapKey(row._1), unwrapValue(row._2)) + def rowsToKeyValue(row: (UnsafeRow, UnsafeRow)): (String, Int) = { + (rowToKey(row._1), rowToValue(row._2)) } - def unwrapToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { - iterator.map(unwrapKeyValue).toSet + def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { + iterator.map(rowsToKeyValue).toSet } - def unwrapUpdates(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { + def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { iterator.map { _ match { - case ValueAdded(key, value) => Added(unwrapKey(key), unwrapValue(value)) - case 
ValueUpdated(key, value) => Updated(unwrapKey(key), unwrapValue(value)) - case KeyRemoved(key) => Removed(unwrapKey(key)) + case ValueAdded(key, value) => Added(rowToKey(key), rowToValue(value)) + case ValueUpdated(key, value) => Updated(rowToKey(key), rowToValue(value)) + case KeyRemoved(key) => Removed(rowToKey(key)) }}.toSet } } From b0bd0430591f1663998b8cf629bcf54b6bf12bfc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Mar 2016 20:10:56 -0700 Subject: [PATCH 25/46] state-store-squashed --- .../state/HDFSBackedStateStoreProvider.scala | 560 ++++++++++++++++++ .../streaming/state/StateStore.scala | 213 +++++++ .../state/StateStoreCoordinator.scala | 111 ++++ .../streaming/state/StateStoreRDD.scala | 69 +++ .../execution/streaming/state/package.scala | 49 ++ .../state/StateStoreCoordinatorSuite.scala | 120 ++++ .../streaming/state/StateStoreRDDSuite.scala | 153 +++++ .../streaming/state/StateStoreSuite.scala | 490 +++++++++++++++ 8 files changed, 1765 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala new file mode 100644 index 0000000000000..434451c4d2454 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -0,0 +1,560 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.{DataInputStream, DataOutputStream} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random +import scala.util.control.NonFatal + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.compress.Lz4Codec + +import org.apache.spark.io.{CompressionCodec, LZ4CompressionCodec} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + + +/** + * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed + * by files in a HDFS-compatible file system. All updates to the store has to be done in sets + * transactionally, and each set of updates increments the store's version. These versions can + * be used to re-execute the updates (by retries in RDD operations) on the correct version of + * the store, and regenerate the store version. + * + * Usage: + * To update the data in the state store, the following order of operations are needed. + * + * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store + * - store.update(...) + * - store.remove(...) + * - store.commit() // commits all the updates to made with version number + * - store.iterator() // key-value data after last commit as an iterator + * - store.updates() // updates made in the last as an iterator + * + * Fault-tolerance model: + * - Every set of updates is written to a delta file before committing. + * - The state store is responsible for managing, collapsing and cleaning up of delta files. + * - Multiple attempts to commit the same version of updates may overwrite each other. + * Consistency guarantees depend on whether multiple attempts have the same updates and + * the overwrite semantics of underlying file system. + * - Background maintenance of files ensures that last versions of the store is always recoverable + * to ensure re-executed RDD operations re-apply updates on the correct past version of the + * store. 
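To make the retry semantics above concrete, a brief sketch (the provider construction and the keyRow/valueRow helpers are assumed; the scenario mirrors the "overwrite the version" case exercised in StateStoreSuite):

    // A task computes version 1 on top of version 0 and commits it.
    val attempt1 = provider.getStore(0)
    attempt1.update(keyRow("a"), _ => valueRow(1))
    attempt1.commit()                      // the temp delta is renamed to 1.delta

    // A re-executed task starts again from version 0; its commit overwrites the earlier
    // attempt's delta for version 1, subject to the file system's rename semantics.
    val attempt2 = provider.getStore(0)
    attempt2.update(keyRow("a"), _ => valueRow(2))
    attempt2.commit()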
+ */ +private[state] class HDFSBackedStateStoreProvider( + val id: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + sparkConf: SparkConf, + hadoopConf: Configuration + ) extends StateStoreProvider with Logging { + + import HDFSBackedStateStoreProvider._ + + type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] + + /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ + class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType) + extends StateStore { + + /** Trait and classes representing the internal state of the store */ + trait STATE + case object UPDATING extends STATE + case object COMMITTED extends STATE + case object CANCELLED extends STATE + + private val newVersion = version + 1 + private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") + private val tempDeltaFileStream = new DataOutputStream( + new LZ4CompressionCodec(new SparkConf).compressedOutputStream(fs.create(tempDeltaFile, true))) + + // private val tempDeltaFileStream = fs.create(tempDeltaFile, true) + private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() + + @volatile private var state: STATE = UPDATING + @volatile private var finalDeltaFile: Path = null + + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + + /** Update the value of a key using the value generated by the update function */ + override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot update after already committed or cancelled") + val oldValueOption = Option(mapToUpdate.get(key)) + val value = updateFunc(oldValueOption) + mapToUpdate.put(key, value) + + Option(allUpdates.get(key)) match { + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added already, keep it marked as added + allUpdates.put(key, ValueAdded(key, value)) + case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + // Value existed in prev version and updated/removed, mark it as updated + allUpdates.put(key, ValueUpdated(key, value)) + case None => + // There was no prior update, so mark this as added or updated according to its presence + // in previous version. + val update = + if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + allUpdates.put(key, update) + } + writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) + } + + /** Remove keys that match the following condition */ + override def remove(condition: UnsafeRow => Boolean): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + val keyIter = mapToUpdate.keySet().iterator() + while (keyIter.hasNext) { + val key = keyIter.next + if (condition(key)) { + keyIter.remove() + + Option(allUpdates.get(key)) match { + case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed + allUpdates.put(key, KeyRemoved(key)) + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates + allUpdates.remove(key) + case Some(KeyRemoved(_)) => + // Remove already in update map, no need to change + } + writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) + } + } + } + + /** Commit all the updates that have been made to the store. 
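The allUpdates bookkeeping above collapses repeated operations on a key into a single StoreUpdate per version; as a reading aid (not new behavior), the transitions implemented by update() and remove() are:

    // previously recorded for key  | operation | recorded afterwards
    // -----------------------------+-----------+------------------------------------------
    // nothing, key absent in v-1   | update    | ValueAdded(key, value)
    // nothing, key present in v-1  | update    | ValueUpdated(key, value)
    // ValueAdded                   | update    | ValueAdded(key, newValue)  (stays "added")
    // ValueUpdated or KeyRemoved   | update    | ValueUpdated(key, newValue)
    // nothing or ValueUpdated      | remove    | KeyRemoved(key)
    // ValueAdded                   | remove    | entry dropped (net effect is no change)
    // KeyRemoved                   | remove    | KeyRemoved(key), unchanged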
*/ + override def commit(): Long = { + verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + + try { + finalizeDeltaFile(tempDeltaFileStream) + finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + state = COMMITTED + newVersion + } catch { + case NonFatal(e) => + throw new IllegalStateException( + s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + } + } + + /** Cancel all the updates made on this store. This store will not be usable any more. */ + override def cancel(): Unit = { + state = CANCELLED + if (tempDeltaFileStream != null) { + tempDeltaFileStream.close() + } + if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { + fs.delete(tempDeltaFile, true) + } + logInfo("Canceled ") + } + + /** + * Get an iterator of all the store data. This can be called only after committing the + * updates. + */ + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { + verify(state == COMMITTED, "Cannot get iterator of store data before comitting") + HDFSBackedStateStoreProvider.this.iterator(newVersion) + } + + /** + * Get an iterator of all the updates made to the store in the current version. + * This can be called only after committing the updates. + */ + override def updates(): Iterator[StoreUpdate] = { + verify(state == COMMITTED, "Cannot get iterator of updates before committing") + allUpdates.values().asScala.toIterator + } + + /** + * Whether all updates have been committed + */ + override def hasCommitted: Boolean = { + state == COMMITTED + } + } + + /** Get the state store for making updates to create a new `version` of the store. */ + override def getStore(version: Long): StateStore = synchronized { + require(version >= 0, "Version cannot be less than 0") + val newMap = new MapType() + if (version > 0) { + val time = System.nanoTime() + // newMap ++= loadMap(version) + newMap.putAll(loadMap(version)) + } + new HDFSBackedStateStore(version, newMap) + } + + /** Manage backing files, including creating snapshots and cleaning up old files */ + override def doMaintenance(): Unit = { + try { + doSnapshot() + cleanup() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up $this") + } + } + + override def toString(): String = { + s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" + } + + /* Internal classes and methods */ + + private val loadedMaps = new mutable.HashMap[Long, MapType] + private val baseDir = new Path(id.rootLocation, s"${id.operatorId}/${id.partitionId.toString}") + private val fs = baseDir.getFileSystem(hadoopConf) + private val serializer = new KryoSerializer(sparkConf) + private val minBatchesToRetain = sparkConf.getInt( + MIN_BATCHES_TO_RETAIN_CONF, DEFAULT_MIN_BATCHES_TO_RETAIN) + private val maxDeltaChainForSnapshots = sparkConf.getInt( + MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, DEFAULT_MAX_DELTA_CHAIN_FOR_SNAPSHOTS) + + initialize() + + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + + /** Commit a set of updates to the store with the given new version */ + private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + synchronized { + val finalDeltaFile = deltaFile(newVersion) + fs.rename(tempDeltaFile, finalDeltaFile) + loadedMaps.put(newVersion, map) + finalDeltaFile + } + } + + /** + * Get iterator of all the data of the latest version of the store. + * Note that this will look up the files to determined the latest known version. 
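The file lookup mentioned above is easier to picture with the on-disk layout the provider writes; an illustrative sketch (version numbers are hypothetical; see deltaFile, snapshotFile and fetchFiles further down):

    // <rootLocation>/<operatorId>/<partitionId>/
    //   1.delta, 2.delta, ...       one delta file per committed version
    //   10.snapshot                 full key-value map at version 10, written by doSnapshot()
    //   11.delta, 12.delta, ...     version 12 = 10.snapshot + 11.delta + 12.delta
    //
    // Each delta record is framed as keySize(Int), keyBytes, valueSize(Int), valueBytes;
    // valueSize == -1 encodes a key removal, and a trailing keySize == -1 marks end-of-file.
    // Delta streams are LZ4-compressed, while snapshot files go through the Kryo serializer.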
+ */ + private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + val versionsInFiles = fetchFiles().map(_.version).toSet + val versionsLoaded = loadedMaps.keySet + val allKnownVersions = versionsInFiles ++ versionsLoaded + if (allKnownVersions.nonEmpty) { + loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + } + } else Iterator.empty + } + + /** Get iterator of a specific version of the store */ + private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + loadMap(version).entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + } + } + + /** Initialize the store provider */ + private def initialize(): Unit = { + if (!fs.exists(baseDir)) { + fs.mkdirs(baseDir) + } else { + if (!fs.isDirectory(baseDir)) { + throw new IllegalStateException( + s"Cannot use ${id.rootLocation} for storing state data as" + + s"$baseDir already exists and is not a directory") + } + } + } + + /** Load the required version of the map data from the backing files */ + private def loadMap(version: Long): MapType = { + if (version <= 0) return new MapType + synchronized { loadedMaps.get(version) }.getOrElse { + val mapFromFile = readSnapshotFile(version).getOrElse { + val prevMap = loadMap(version - 1) + val newMap = new MapType(prevMap) + newMap.putAll(prevMap) + updateFromDeltaFile(version, newMap) + newMap + } + loadedMaps.put(version, mapFromFile) + mapFromFile + } + } + + private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { + + def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + + def writeRemove(key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) + } + + update match { + case ValueAdded(key, value) => + writeUpdate(key, value) + case ValueUpdated(key, value) => + writeUpdate(key, value) + case KeyRemoved(key) => + writeRemove(key) + } + } + + private def finalizeDeltaFile(output: DataOutputStream): Unit = { + output.writeInt(-1) // Write this magic number to signify end of file + output.close() + } + + private def updateFromDeltaFile(version: Long, map: MapType): Unit = { + val fileToRead = deltaFile(version) + if (!fs.exists(fileToRead)) { + throw new IllegalStateException( + s"Error reading delta file $fileToRead of $this: $fileToRead does not exist") + } + var input: DataInputStream = null + try { + input = new DataInputStream( + new LZ4CompressionCodec(new SparkConf).compressedInputStream(fs.open(fileToRead))) + var eof = false + + while(!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new Exception( + s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + map.remove(keyRow) + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, 
valueRow) + } + } + } + } finally { + if (input != null) input.close() + } + } + + private def writeSnapshotFile(version: Long, map: MapType): Unit = { + val fileToWrite = snapshotFile(version) + val ser = serializer.newInstance() + var outputStream: SerializationStream = null + Utils.tryWithSafeFinally { + outputStream = ser.serializeStream(fs.create(fileToWrite, false)) + outputStream.writeAll(map.entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + }) + } { + if (outputStream != null) outputStream.close() + } + } + + private def readSnapshotFile(version: Long): Option[MapType] = { + val fileToRead = snapshotFile(version) + if (!fs.exists(fileToRead)) return None + + val deser = serializer.newInstance() + val map = new MapType() + var deserStream: DeserializationStream = null + + try { + deserStream = deser.deserializeStream(fs.open(fileToRead)) + val iter = deserStream.asIterator.asInstanceOf[Iterator[(UnsafeRow, UnsafeRow)]] + while (iter.hasNext) { + // map += iter.next() + val (key, value) = iter.next() + map.put(key, value) + } + Some(map) + } finally { + if (deserStream != null) deserStream.close() + } + } + + + /** Perform a snapshot of the store to allow delta files to be consolidated */ + private def doSnapshot(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + synchronized { loadedMaps.get(lastVersion) } match { + case Some(map) => + if (deltaFilesForLastVersion.size > maxDeltaChainForSnapshots) { + writeSnapshotFile(lastVersion, map) + } + case None => + // The last map is not loaded, probably some other instance is incharge + } + + } + } catch { + case NonFatal(e) => + logWarning(s"Error doing snapshots for $this", e) + } + } + + /** + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. 
+ */ + private[state] def cleanup(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - minBatchesToRetain + if (earliestVersionToRetain > 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + synchronized { + val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq + mapsToRemove.foreach(loadedMaps.remove) + } + files.filter(_.version < earliestFileToRetain.version).foreach { f => + fs.delete(f.path, true) + } + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this", e) + } + } + + /** Files needed to recover the given version of the store */ + private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + val deltaBatchIds = (snapshotFile.version + 1) to version + + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + } + verify( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version: ${deltaFiles.mkString(",")}" + ) + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + /** Fetch all the files that back the store */ + private def fetchFiles(): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fs.listStatus(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + files.foreach { status => + val path = status.getPath + val nameParts = path.getName.split("\\.") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + logWarning(s"Could not identify file $path") + } + } + } + versionToFiles.values.toSeq.sortBy(_.version) + } + + private def compressStream(outputStream: DataOutputStream): DataOutputStream = { + val compressed = new LZ4CompressionCodec(new SparkConf).compressedOutputStream(outputStream) + new DataOutputStream(compressed) + } + + private def compressStream(inputStream: DataInputStream): DataInputStream = { + val compressed = new LZ4CompressionCodec(new SparkConf).compressedInputStream(inputStream) + new DataInputStream(compressed) + } + + private def deltaFile(version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + private def snapshotFile(version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } +} + +private[state] object HDFSBackedStateStoreProvider { + val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF = "spark.sql.streaming.stateStore.maxDeltaChain" + val DEFAULT_MAX_DELTA_CHAIN_FOR_SNAPSHOTS = 10 + + val MIN_BATCHES_TO_RETAIN_CONF = "spark.sql.streaming.stateStore.minBatchesToRetain" + val 
DEFAULT_MIN_BATCHES_TO_RETAIN = 2 +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala new file mode 100644 index 0000000000000..8205f4b80d9c1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.{Timer, TimerTask} + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, Logging, SparkEnv} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType + + +/** Unique identifier for a [[StateStore]] */ +case class StateStoreId(rootLocation: String, operatorId: Long, partitionId: Int) + +/** + * Base trait for a versioned key-value store used for streaming aggregations + */ +trait StateStore { + + /** Unique identifier of the store */ + def id: StateStoreId + + /** Version of the data in this store before committing updates. */ + def version: Long + + /** + * Update the value of a key using the value generated by the update function. + * This can be called only after prepareForUpdates() has been called in the same thread. + */ + def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit + + /** + * Remove keys that match the following condition. + * This can be called only after prepareForUpdates() has been called in the current thread. + */ + def remove(condition: UnsafeRow => Boolean): Unit + + /** + * Commit all the updates that have been made to the store. + * This can be called only after prepareForUpdates() has been called in the current thread. + */ + def commit(): Long + + /** Cancel all the updates that have been made to the store. */ + def cancel(): Unit + + /** + * Iterator of store data after a set of updates have been committed. + * This can be called only after commitUpdates() has been called in the current thread. + */ + def iterator(): Iterator[(UnsafeRow, UnsafeRow)] + + /** + * Iterator of the updates that have been committed. + * This can be called only after commitUpdates() has been called in the current thread. + */ + def updates(): Iterator[StoreUpdate] + + /** + * Whether all updates have been committed + */ + def hasCommitted: Boolean +} + + +trait StateStoreProvider { + + /** Get the store with the existing version. 
*/ + def getStore(version: Long): StateStore + + /** Optional method for providers to allow for background management */ + def doMaintenance(): Unit = { } +} + +sealed trait StoreUpdate +case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +case class KeyRemoved(key: UnsafeRow) extends StoreUpdate + + +/** + * Companion object to [[StateStore]] that provides helper methods to create and retrive stores + * by their unique ids. + */ +private[state] object StateStore extends Logging { + + val MANAGEMENT_TASK_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" + val MANAGEMENT_TASK_INTERVAL_SECS = 60 + + private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val managementTimer = new Timer("StateStore Timer", true) + @volatile private var managementTask: TimerTask = null + + /** Get or create a store associated with the id. */ + def get( + storeId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + version: Long, + hadoopConf: Configuration + ): StateStore = { + require(version >= 0) + val storeProvider = loadedProviders.synchronized { + startIfNeeded() + val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val provider = loadedProviders.getOrElseUpdate( + storeId, + new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, sparkConf, hadoopConf)) + reportActiveInstance(storeId) + provider + } + storeProvider.getStore(version) + } + + def remove(storeId: StateStoreId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeId) + } + + def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeId) + } + + /** Unload and stop all state store provider */ + def stop(): Unit = loadedProviders.synchronized { + loadedProviders.clear() + if (managementTask != null) { + managementTask.cancel() + managementTask = null + logInfo("StateStore stopped") + } + } + + private def startIfNeeded(): Unit = loadedProviders.synchronized { + if (managementTask == null) { + managementTask = new TimerTask { + override def run(): Unit = { manageAll() } + } + val periodMs = Option(SparkEnv.get).map(_.conf) match { + case Some(conf) => + conf.getTimeAsMs( + "spark.sql.streaming.stateStore.managementInterval", + s"${MANAGEMENT_TASK_INTERVAL_SECS}s") + case None => + MANAGEMENT_TASK_INTERVAL_SECS * 1000 + } + managementTimer.schedule(managementTask, periodMs, periodMs) + logInfo("StateStore started") + } + } + + /** + * Execute background management task in all the loaded store providers if they are still + * the active instances according to the coordinator. 
+ */ + private def manageAll(): Unit = { + loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => + try { + if (verifyIfInstanceActive(id)) { + provider.doMaintenance() + } else { + remove(id) + logInfo(s"Unloaded $provider") + } + } catch { + case NonFatal(e) => + logWarning(s"Error managing $provider") + } + } + } + + private def reportActiveInstance(storeId: StateStoreId): Unit = { + try { + val host = SparkEnv.get.blockManager.blockManagerId.host + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + StateStoreCoordinator.ask(ReportActiveInstance(storeId, host, executorId)) + } catch { + case NonFatal(e) => + logWarning(s"Error reporting active instance of $storeId") + } + } + + private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { + try { + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + StateStoreCoordinator.ask(VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) + } catch { + case NonFatal(e) => + logWarning(s"Error verifying active instance of $storeId") + false + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala new file mode 100644 index 0000000000000..fa208ce00d1b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint +import org.apache.spark.util.RpcUtils + +/** Trait representing all messages to [[StateStoreCoordinator]] */ +private sealed trait StateStoreCoordinatorMessage extends Serializable + +private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) + extends StateStoreCoordinatorMessage +private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) + extends StateStoreCoordinatorMessage +private object StopCoordinator extends StateStoreCoordinatorMessage + + +/** Class for coordinating instances of [[StateStore]]s loaded in the cluster */ +class StateStoreCoordinator(rpcEnv: RpcEnv) { + private val coordinatorRef = rpcEnv.setupEndpoint( + StateStoreCoordinator.endpointName, new StateStoreCoordinatorEndpoint(rpcEnv, this)) + private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + + /** Report active instance of a state store in an executor */ + def reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Boolean = { + instances.synchronized { instances.put(storeId, ExecutorCacheTaskLocation(host, executorId)) } + true + } + + /** Verify whether the given executor has the active instance of a state store */ + def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { + instances.synchronized { + instances.get(storeId) match { + case Some(location) => location.executorId == executorId + case None => false + } + } + } + + /** Get the location of the state store */ + def getLocation(storeId: StateStoreId): Option[String] = { + instances.synchronized { instances.get(storeId).map(_.toString) } + } + + /** Deactivate instances related to a set of operator */ + def deactivateInstances(storeRootLocation: String): Unit = { + instances.synchronized { + val storeIdsToRemove = + instances.keys.filter(_.rootLocation == storeRootLocation).toSeq + instances --= storeIdsToRemove + } + } + + def stop(): Unit = { + coordinatorRef.askWithRetry[Boolean](StopCoordinator) + } +} + + +private[sql] object StateStoreCoordinator { + + private val endpointName = "StateStoreCoordinator" + + private class StateStoreCoordinatorEndpoint( + override val rpcEnv: RpcEnv, coordinator: StateStoreCoordinator) + extends RpcEndpoint with Logging { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case ReportActiveInstance(id, host, executorId) => + context.reply(coordinator.reportActiveInstance(id, host, executorId)) + case VerifyIfInstanceActive(id, executor) => + context.reply(coordinator.verifyIfInstanceActive(id, executor)) + case StopCoordinator => + // Stop before replying to ensure that endpoint name has been deregistered + stop() + context.reply(true) + } + } + + def ask(message: StateStoreCoordinatorMessage): Option[Boolean] = { + val env = SparkEnv.get + if (env != null) { + val coordinatorRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + Some(coordinatorRef.askWithRetry[Boolean](message)) + } else { + None + } + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala new file mode 100644 index 0000000000000..ee4cf9928cc54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.sql.types.StructType +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * An RDD that allows computations to be executed against [[StateStore]]s. It + * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as + * preferred locations. + */ +class StateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + storeRootLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + @transient private val storeCoordinator: Option[StateStoreCoordinator]) + extends RDD[U](dataRDD) { + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val confBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(dataRDD.context.hadoopConfiguration)) + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val storeId = StateStoreId(storeRootLocation, operatorId, partition.index) + storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + var store: StateStore = null + + Utils.tryWithSafeFinally { + val storeId = StateStoreId(storeRootLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + val outputIter = storeUpdateFunction(store, inputIter) + assert(store.hasCommitted) + outputIter + } { + if (store != null) store.cancel() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala new file mode 100644 index 0000000000000..8819008dbfa93 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType + +package object state { + + implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { + def mapPartitionWithStateStore[U: ClassTag]( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + storeRootLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeCoordinator: Option[StateStoreCoordinator] = None + ): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new StateStoreRDD( + dataRDD, + cleanedF, + storeRootLocation, + operatorId, + storeVersion, + keySchema, + valueSchema, + storeCoordinator) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala new file mode 100644 index 0000000000000..80278bdf2fed5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.util.RpcUtils + +class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { + + import StateStoreCoordinatorSuite._ + + test("report, verify, getLocation") { + withCoordinator(sc) { coordinator => + val id = StateStoreId("x", 0, 0) + + assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinator.getLocation(id) === None) + + assert(coordinator.reportActiveInstance(id, "hostX", "exec1") === true) + assert(coordinator.verifyIfInstanceActive(id, "exec1") === true) + assert(coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + + assert(coordinator.reportActiveInstance(id, "hostX", "exec2") === true) + assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinator.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } + } + + test("make inactive") { + withCoordinator(sc) { coordinator => + val id1 = StateStoreId("x", 0, 0) + val id2 = StateStoreId("y", 1, 0) + val id3 = StateStoreId("x", 0, 1) + val host = "hostX" + val exec = "exec1" + + assert(coordinator.reportActiveInstance(id1, host, exec) === true) + assert(coordinator.reportActiveInstance(id2, host, exec) === true) + assert(coordinator.reportActiveInstance(id3, host, exec) === true) + + assert(coordinator.verifyIfInstanceActive(id1, exec) === true) + assert(coordinator.verifyIfInstanceActive(id2, exec) === true) + assert(coordinator.verifyIfInstanceActive(id3, exec) === true) + + coordinator.deactivateInstances("x") + + assert(coordinator.verifyIfInstanceActive(id1, exec) === false) + assert(coordinator.verifyIfInstanceActive(id2, exec) === true) + assert(coordinator.verifyIfInstanceActive(id3, exec) === false) + + assert(coordinator.getLocation(id1) === None) + assert( + coordinator.getLocation(id2) === + Some(ExecutorCacheTaskLocation(host, exec).toString)) + assert(coordinator.getLocation(id3) === None) + + coordinator.deactivateInstances("y") + assert(coordinator.verifyIfInstanceActive(id2, exec) === false) + assert(coordinator.getLocation(id2) === None) + } + } + + test("communication") { + withCoordinator(sc) { coordinator => + import StateStoreCoordinator._ + val id = StateStoreId("x", 0, 0) + val host = "hostX" + + val ref = RpcUtils.makeDriverRef("StateStoreCoordinator", sc.env.conf, sc.env.rpcEnv) + + assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(false)) + + ask(ReportActiveInstance(id, host, "exec1")) + assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(true)) + assert( + coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation(host, "exec1").toString)) + + ask(ReportActiveInstance(id, host, "exec2")) + assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(false)) + assert(ask(VerifyIfInstanceActive(id, "exec2")) === Some(true)) + assert( + coordinator.getLocation(id) === + Some(ExecutorCacheTaskLocation(host, "exec2").toString)) + } + } +} + +object StateStoreCoordinatorSuite { + def withCoordinator(sc: SparkContext)(body: StateStoreCoordinator => Unit): Unit = { + var coordinator: StateStoreCoordinator = null + try { + coordinator = new StateStoreCoordinator(sc.env.rpcEnv) + body(coordinator) + } finally { + if (coordinator != null) coordinator.stop() + } + } +} 
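For reference, the versioned lifecycle that the StateStore trait above defines (get a store at a known version, apply updates, commit to produce the next version, then iterate the committed data) can be sketched end to end as below. This is an illustrative sketch only: the object name, the /tmp path, and the sample key are made up, while the calls themselves (StateStoreId, StateStore.get, update, commit, iterator) mirror the API added in this patch.

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object StateStoreUsageSketch {
  // Projections to build UnsafeRow keys and values, as done in StateStoreSuite.
  private val keyProj = UnsafeProjection.create(Array[DataType](StringType))
  private val valueProj = UnsafeProjection.create(Array[DataType](IntegerType))

  def main(args: Array[String]): Unit = {
    val keySchema = StructType(Seq(StructField("key", StringType, true)))
    val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
    // Hypothetical location; in a streaming query this would come from the checkpoint directory.
    val storeId = StateStoreId("/tmp/state-store-sketch", operatorId = 0L, partitionId = 0)

    // Load version 0 of the store (empty on first use).
    val store = StateStore.get(storeId, keySchema, valueSchema, version = 0L, new Configuration)

    // Increment the count for key "a"; a missing key starts at 0.
    val key = keyProj(new GenericInternalRow(Array[Any](UTF8String.fromString("a")))).copy()
    store.update(key, oldValue => {
      val oldCount = oldValue.map(_.getInt(0)).getOrElse(0)
      valueProj(new GenericInternalRow(Array[Any](oldCount + 1))).copy()
    })

    // Committing produces version 1; the committed data can then be iterated.
    val newVersion = store.commit()
    store.iterator().foreach { case (k, v) =>
      println(s"version $newVersion: ${k.getUTF8String(0)} -> ${v.getInt(0)}")
    }
  }
}

Inside a streaming job the same sequence is driven per partition by StateStoreRDD.compute, which looks up the store for its StateStoreId, runs storeUpdateFunction over the partition's input iterator, and requires that the function committed the store before returning.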
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala new file mode 100644 index 0000000000000..1f81c0bed6bd0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.nio.file.Files + +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.util.Utils + +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + + private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) + private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + import StateStoreCoordinatorSuite._ + import StateStoreSuite._ + + after { + StateStore.stop() + } + + override def afterAll(): Unit = { + super.afterAll() + Utils.deleteRecursively(new File(tempDir)) + } + + test("versioning and immutability") { + quietly { + withSpark(new SparkContext(conf)) { sc => + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + test("recovering from files") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + makeRDD(sc, Seq("a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion, keySchema, valueSchema) + } + + // Generate RDDs and state store data + withSpark(new SparkContext(conf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(conf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } + } + } + + test("preferred locations using StateStoreCoordinator") { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + withSpark(new SparkContext(conf)) { sc => + withCoordinator(sc) { coordinator => + coordinator.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinator.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema, Some(coordinator)) + require(rdd.partitions.size === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() + } + } + } + + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { + sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) + } + + private val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala new file mode 100644 index 0000000000000..a59fd370be3c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File + +import scala.collection.mutable +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + + import HDFSBackedStateStoreProvider._ + import StateStoreCoordinatorSuite._ + import StateStoreSuite._ + + private val tempDir = Utils.createTempDir().toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + after { + StateStore.stop() + } + + test("update, remove, commit, and all data iterator") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator().isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + + // Verify state after updating + update(store, "a", 1) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + assert(provider.latestIterator().isEmpty) + + // Make updates, commit and then verify state + update(store, "b", 2) + update(store, "aa", 3) + remove(store, _.startsWith("a")) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) + assert(fileExists(provider, version = 1, isSnapshot = false)) + + assert(getDataFromFiles(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getDataFromFiles(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = new HDFSBackedStateStoreProvider( + store.id, keySchema, valueSchema, new SparkConf, new Configuration) + val reloadedStore = reloadedProvider.getStore(1) + update(reloadedStore, "c", 4) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) + assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) + } + + test("updates iterator with all combos of updates and removes") { + val provider = newStoreProvider() + var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { + val store = provider.getStore(currentVersion) + body(store) + currentVersion += 1 + } + + // New data should be seen in updates as value added, even if they had multiple updates + withStore { store => + update(store, "a", 
1) + update(store, "aa", 1) + update(store, "aa", 2) + store.commit() + assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) + assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) + } + + // Multiple updates to same key should be collapsed in the updates as a single value update + // Keys that have not been updated should not appear in the updates + withStore { store => + update(store, "a", 4) + update(store, "a", 6) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Keys added, updated and finally removed before commit should not appear in updates + withStore { store => + update(store, "b", 4) // Added, finally removed + update(store, "bb", 5) // Added, updated, finally removed + update(store, "bb", 6) + remove(store, _.startsWith("b")) + store.commit() + assert(updatesToSet(store.updates()) === Set.empty) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Removed data should be seen in updates as a key removed + // Removed, but re-added data should be seen in updates as a value update + withStore { store => + remove(store, _.startsWith("a")) + update(store, "a", 10) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) + assert(rowsToSet(store.iterator()) === Set("a" -> 10)) + } + } + + test("cancel") { + val provider = newStoreProvider() + val store = provider.getStore(0) + update(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + update(store1, "b", 1) + store1.cancel() + assert(getDataFromFiles(provider) === Set("a" -> 1)) + } + + test("getStore with unexpected versions") { + val provider = newStoreProvider() + + intercept[IllegalArgumentException] { + provider.getStore(-1) + } + + // Prepare some data in the stoer + val store = provider.getStore(0) + update(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + intercept[IllegalStateException] { + provider.getStore(2) + } + + // Update store version with some data + val store1 = provider.getStore(1) + update(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) + + // Overwrite the version with other data + val store2 = provider.getStore(1) + update(store2, "c", 1) + assert(store2.commit() === 2) + assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) + } + + test("snapshotting") { + val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) + + var currentVersion = 0 + def updateVersionTo(targetVersion: Int): Unit = { + for (i <- currentVersion + 1 to targetVersion) { + val store = provider.getStore(currentVersion) + update(store, "a", i) + store.commit() + currentVersion += 1 + } + require(currentVersion === targetVersion) + } + + updateVersionTo(2) + require(getDataFromFiles(provider) === Set("a" -> 2)) + provider.doMaintenance() // should not generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 2)) + + for (i <- 1 to currentVersion) { + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + } + + // 
After version 6, snapshotting should generate one snapshot file + updateVersionTo(6) + require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + provider.doMaintenance() // should generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 6), "snapshotting messed up the data") + assert(getDataFromFiles(provider) === Set("a" -> 6)) + + val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) + assert(snapshotVersion.nonEmpty, "snapshot file not generated") + + // After version 20, snapshotting should generate newer snapshot files + updateVersionTo(20) + require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + provider.doMaintenance() // do snapshot + assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + assert(getDataFromFiles(provider) === Set("a" -> 20)) + + val latestSnapshotVersion = (0 to 20).filter(version => + fileExists(provider, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + } + + test("cleaning") { + val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) + + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) + update(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + require( + rowsToSet(provider.latestIterator()) === Set("a" -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + + // last couple of versions should be retrievable + assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) + assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + } + + test("StateStore.get") { + quietly { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, keySchema, valueSchema, -1, new Configuration) + } + assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store + + intercept[IllegalStateException] { + StateStore.get(storeId, keySchema, valueSchema, 1, new Configuration) + } + + // Increase version of the store + val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, new Configuration) + assert(store0.version === 0) + update(store0, "a", 1) + store0.commit() + + assert(StateStore.get(storeId, keySchema, valueSchema, 1, new Configuration).version == 1) + assert(StateStore.get(storeId, keySchema, valueSchema, 0, new Configuration).version == 0) + + // Verify that you can remove the store and still reload and use it + StateStore.remove(storeId) + assert(!StateStore.isLoaded(storeId)) + + val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, new Configuration) + assert(StateStore.isLoaded(storeId)) + update(store1, "a", 2) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + } + } + + test("background management") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.sql.streaming.stateStore.managementInterval", "10ms") + val opId = 0 + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, opId, 0) + val provider = new HDFSBackedStateStoreProvider( + storeId, keySchema, valueSchema, conf, new 
Configuration) + + quietly { + withSpark(new SparkContext(conf)) { sc => + withCoordinator(sc) { coordinator => + for (i <- 1 to 20) { + val store = StateStore.get(storeId, keySchema, valueSchema, i - 1, new Configuration) + update(store, "a", i) + store.commit() + } + + // Background management should clean up and generate snapshots + eventually(timeout(4 seconds)) { + // Earliest delta file should get cleaned up + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + + // Some snapshots should have been generated + val snapshotVersions = (0 to 20).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") + } + + // If driver decides to deactivate all instances of the store, then this instance + // should be unloaded + coordinator.deactivateInstances(dir) + eventually(timeout(4 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, new Configuration) + assert(StateStore.isLoaded(storeId)) + + // If some other executor loads the store, then this instance should be unloaded + coordinator.reportActiveInstance(storeId, "other-host", "other-exec") + eventually(timeout(4 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, new Configuration) + assert(StateStore.isLoaded(storeId)) + } + } + + // Verify if instance is unloaded if SparkContext is stopped + eventually(timeout(4 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + } + } + + def getDataFromFiles( + provider: HDFSBackedStateStoreProvider, + version: Int = -1): Set[(String, Int)] = { + val reloadedProvider = new HDFSBackedStateStoreProvider( + provider.id, keySchema, valueSchema, new SparkConf, new Configuration) + if (version < 0) { + reloadedProvider.latestIterator().map(rowsToStringInt).toSet + } else { + reloadedProvider.iterator(version).map(rowsToStringInt).toSet + } + } + + def assertMap( + testMapOption: Option[MapType], + expectedMap: Map[String, Int]): Unit = { + assert(testMapOption.nonEmpty, "no map present") + val convertedMap = testMapOption.get.map(rowsToStringInt) + assert(convertedMap === expectedMap) + } + + def fileExists( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Boolean = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.exists + } + + def storeLoaded(storeId: StateStoreId): Boolean = { + val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) + val loadedStores = StateStore invokePrivate method() + loadedStores.contains(storeId) + } + + def unloadStore(storeId: StateStoreId): Boolean = { + val method = PrivateMethod('remove) + StateStore invokePrivate method(storeId) + } + + def newStoreProvider( + opId: Long = Random.nextLong, + partition: Int = 0, + maxDeltaChainForSnapshots: Int = DEFAULT_MAX_DELTA_CHAIN_FOR_SNAPSHOTS + ): HDFSBackedStateStoreProvider = { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val sparkConf = new SparkConf() + .set(MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, maxDeltaChainForSnapshots.toString) + new HDFSBackedStateStoreProvider( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + sparkConf, + new 
Configuration()) + } + + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.remove(row => condition(rowToString(row))) + } + + private def update(store: StateStore, key: String, value: Int): Unit = { + store.update(stringToRow(key), _ => intToRow(value)) + } +} + +private[state] object StateStoreSuite { + + /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ + trait TestUpdate + case class Added(key: String, value: Int) extends TestUpdate + case class Updated(key: String, value: Int) extends TestUpdate + case class Removed(key: String) extends TestUpdate + + val strProj = UnsafeProjection.create(Array[DataType](StringType)) + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + + def stringToRow(s: String): UnsafeRow = { + strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy() + } + + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToString(row: UnsafeRow): String = { + row.getUTF8String(0).toString + } + + def rowToInt(row: UnsafeRow): Int = { + row.getInt(0) + } + + def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { + (rowToInt(row._1), rowToInt(row._2)) + } + + + def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = { + (rowToString(row._1), rowToInt(row._2)) + } + + def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } + + def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { + iterator.map { _ match { + case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) + case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) + case KeyRemoved(key) => Removed(rowToString(key)) + }}.toSet + } +} From 3fe34e90bf7820c73ede7901422a48b83c954f27 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 Mar 2016 15:58:45 -0700 Subject: [PATCH 26/46] Style fixes --- .../streaming/state/HDFSBackedStateStoreProvider.scala | 5 ++--- .../spark/sql/execution/streaming/state/StateStore.scala | 2 +- .../spark/sql/execution/streaming/state/StateStoreRDD.scala | 2 +- .../sql/execution/streaming/state/StateStoreRDDSuite.scala | 2 +- .../sql/execution/streaming/state/StateStoreSuite.scala | 3 +-- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 434451c4d2454..92bd4e1cb065c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -27,10 +27,9 @@ import scala.util.control.NonFatal import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.compress.Lz4Codec -import org.apache.spark.io.{CompressionCodec, LZ4CompressionCodec} import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType @@ -557,4 +556,4 @@ private[state] object HDFSBackedStateStoreProvider { val 
MIN_BATCHES_TO_RETAIN_CONF = "spark.sql.streaming.stateStore.minBatchesToRetain" val DEFAULT_MIN_BATCHES_TO_RETAIN = 2 -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 8205f4b80d9c1..7bc1c9a20f13e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, Logging, SparkEnv} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index ee4cf9928cc54..6ca394a30d635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.state import scala.reflect.ClassTag -import org.apache.spark.sql.types.StructType import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SerializableConfiguration, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 1f81c0bed6bd0..f417e806fe0be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -24,12 +24,12 @@ import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.Utils class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index a59fd370be3c9..758326de502db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -24,11 +24,10 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import 
org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} From 8cb0da8a70e415297f7ee1d25d25eae1cc095b53 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 Mar 2016 19:41:52 -0700 Subject: [PATCH 27/46] Some more cleanup --- .../spark/sql/ContinuousQueryManager.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 8 ++-- .../streaming/state/StateStore.scala | 37 ++++++++++--------- .../streaming/state/StateStoreSuite.scala | 4 +- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 0a156ea99a297..5226de27666aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution} import org.apache.spark.sql.util.ContinuousQueryListener diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 92bd4e1cb065c..2e7866643a805 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -234,9 +234,9 @@ private[state] class HDFSBackedStateStoreProvider( private val fs = baseDir.getFileSystem(hadoopConf) private val serializer = new KryoSerializer(sparkConf) private val minBatchesToRetain = sparkConf.getInt( - MIN_BATCHES_TO_RETAIN_CONF, DEFAULT_MIN_BATCHES_TO_RETAIN) + MIN_BATCHES_TO_RETAIN_CONF, MIN_BATCHES_TO_RETAIN_DEFAULT) private val maxDeltaChainForSnapshots = sparkConf.getInt( - MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, DEFAULT_MAX_DELTA_CHAIN_FOR_SNAPSHOTS) + MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT) initialize() @@ -552,8 +552,8 @@ private[state] class HDFSBackedStateStoreProvider( private[state] object HDFSBackedStateStoreProvider { val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF = "spark.sql.streaming.stateStore.maxDeltaChain" - val DEFAULT_MAX_DELTA_CHAIN_FOR_SNAPSHOTS = 10 + val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT = 10 val MIN_BATCHES_TO_RETAIN_CONF = "spark.sql.streaming.stateStore.minBatchesToRetain" - val DEFAULT_MIN_BATCHES_TO_RETAIN = 2 + val MIN_BATCHES_TO_RETAIN_DEFAULT = 2 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7bc1c9a20f13e..b97025ba478fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -99,13 +99,17 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate /** - * Companion object to [[StateStore]] that provides helper methods to create and retrive stores - * by 
their unique ids. + * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores + * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), + * it also runs a periodic background tasks to do maintenance on the loaded stores. For each + * store, tt uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of + * the store is the active instance. Accordingly, it either keeps it loaded and performance + * maintenance, or unloads the store. */ private[state] object StateStore extends Logging { - val MANAGEMENT_TASK_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" - val MANAGEMENT_TASK_INTERVAL_SECS = 60 + val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" + val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val managementTimer = new Timer("StateStore Timer", true) @@ -121,7 +125,7 @@ private[state] object StateStore extends Logging { ): StateStore = { require(version >= 0) val storeProvider = loadedProviders.synchronized { - startIfNeeded() + startMaintenanceIfNeeded() val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val provider = loadedProviders.getOrElseUpdate( storeId, @@ -150,21 +154,18 @@ private[state] object StateStore extends Logging { } } - private def startIfNeeded(): Unit = loadedProviders.synchronized { - if (managementTask == null) { + /** Start the periodic maintenance task if not already started and if Spark active */ + private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized { + val env = SparkEnv.get + if (managementTask == null && env != null) { managementTask = new TimerTask { - override def run(): Unit = { manageAll() } - } - val periodMs = Option(SparkEnv.get).map(_.conf) match { - case Some(conf) => - conf.getTimeAsMs( - "spark.sql.streaming.stateStore.managementInterval", - s"${MANAGEMENT_TASK_INTERVAL_SECS}s") - case None => - MANAGEMENT_TASK_INTERVAL_SECS * 1000 + override def run(): Unit = { doMaintenance() } } + val periodMs = env.conf.getTimeAsMs( + MAINTENANCE_INTERVAL_CONFIG, + s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") managementTimer.schedule(managementTask, periodMs, periodMs) - logInfo("StateStore started") + logInfo("StateStore maintenance timer started") } } @@ -172,7 +173,7 @@ private[state] object StateStore extends Logging { * Execute background management task in all the loaded store providers if they are still * the active instances according to the coordinator. 
*/ - private def manageAll(): Unit = { + private def doMaintenance(): Unit = { loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => try { if (verifyIfInstanceActive(id)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 758326de502db..7e847ab901ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -314,7 +314,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val conf = new SparkConf() .setMaster("local") .setAppName("test") - .set("spark.sql.streaming.stateStore.managementInterval", "10ms") + .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val storeId = StateStoreId(dir, opId, 0) @@ -417,7 +417,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def newStoreProvider( opId: Long = Random.nextLong, partition: Int = 0, - maxDeltaChainForSnapshots: Int = DEFAULT_MAX_DELTA_CHAIN_FOR_SNAPSHOTS + maxDeltaChainForSnapshots: Int = MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT ): HDFSBackedStateStoreProvider = { val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val sparkConf = new SparkConf() From 9fb9c43b0d8ebef0b437226698263f058c04cf18 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 Mar 2016 19:44:41 -0700 Subject: [PATCH 28/46] Reverting unnecessary changes --- core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- .../scala/org/apache/spark/sql/ContinuousQueryManager.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index cf4ad81e2bcd7..e27d2e6c94f7b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkEnv, SparkException} /** * A cleaner that renders closures serializable if they can be done so safely. 
*/ -private[spark] object ClosureCleaner extends Logging { +private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 5226de27666aa..0a156ea99a297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution} import org.apache.spark.sql.util.ContinuousQueryListener From e6b1fb3f5ec4cdd711f0f8048933e745d8bf3d81 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 Mar 2016 20:33:59 -0700 Subject: [PATCH 29/46] Fixed style --- .../spark/sql/execution/streaming/state/StateStoreSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 7e847ab901ead..0f575169adc8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -24,9 +24,9 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ From 2bd6cbd4a75ed4df752b323d35d3a86eaf74b241 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 Mar 2016 20:49:06 -0700 Subject: [PATCH 30/46] Fixed logging --- .../state/HDFSBackedStateStoreProvider.scala | 11 +++++------ .../sql/execution/streaming/state/StateStore.scala | 3 ++- .../streaming/state/StateStoreCoordinator.scala | 3 ++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 2e7866643a805..bd469ab95f23e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -28,7 +28,8 @@ import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -87,8 +88,7 @@ private[state] class 
HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private val tempDeltaFileStream = new DataOutputStream( - new LZ4CompressionCodec(new SparkConf).compressedOutputStream(fs.create(tempDeltaFile, true))) + private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) // private val tempDeltaFileStream = fs.create(tempDeltaFile, true) private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() @@ -344,8 +344,7 @@ private[state] class HDFSBackedStateStoreProvider( } var input: DataInputStream = null try { - input = new DataInputStream( - new LZ4CompressionCodec(new SparkConf).compressedInputStream(fs.open(fileToRead))) + input = decompressStream(fs.open(fileToRead)) var eof = false while(!eof) { @@ -530,7 +529,7 @@ private[state] class HDFSBackedStateStoreProvider( new DataOutputStream(compressed) } - private def compressStream(inputStream: DataInputStream): DataInputStream = { + private def decompressStream(inputStream: DataInputStream): DataInputStream = { val compressed = new LZ4CompressionCodec(new SparkConf).compressedInputStream(inputStream) new DataInputStream(compressed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index b97025ba478fc..2fc58a874d0a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -24,7 +24,8 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index fa208ce00d1b2..a324abff5292a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.streaming.state import scala.collection.mutable -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint From 25afe31080451c16e74bd17bf9f81e5b8aab55f5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 10:58:08 -0700 Subject: [PATCH 31/46] Updated serialization of snapshot files --- .../state/HDFSBackedStateStoreProvider.scala | 62 ++++++++++++++----- .../streaming/state/StateStoreSuite.scala | 26 ++++++-- 2 files changed, 67 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index bd469ab95f23e..94618f28e7966 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -206,7 +206,6 @@ private[state] class HDFSBackedStateStoreProvider( val newMap = new MapType() if (version > 0) { val time = System.nanoTime() - // newMap ++= loadMap(version) newMap.putAll(loadMap(version)) } new HDFSBackedStateStore(version, newMap) @@ -381,14 +380,22 @@ private[state] class HDFSBackedStateStoreProvider( private def writeSnapshotFile(version: Long, map: MapType): Unit = { val fileToWrite = snapshotFile(version) val ser = serializer.newInstance() - var outputStream: SerializationStream = null + var output: DataOutputStream = null Utils.tryWithSafeFinally { - outputStream = ser.serializeStream(fs.create(fileToWrite, false)) - outputStream.writeAll(map.entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) - }) + output = compressStream(fs.create(fileToWrite, false)) + val iter = map.entrySet().iterator() + while(iter.hasNext) { + val entry = iter.next() + val keyBytes = entry.getKey.getBytes() + val valueBytes = entry.getValue.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + output.writeInt(-1) } { - if (outputStream != null) outputStream.close() + if (output != null) output.close() } } @@ -398,19 +405,42 @@ private[state] class HDFSBackedStateStoreProvider( val deser = serializer.newInstance() val map = new MapType() - var deserStream: DeserializationStream = null + var input: DataInputStream = null try { - deserStream = deser.deserializeStream(fs.open(fileToRead)) - val iter = deserStream.asIterator.asInstanceOf[Iterator[(UnsafeRow, UnsafeRow)]] - while (iter.hasNext) { - // map += iter.next() - val (key, value) = iter.next() - map.put(key, value) + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while (!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new Exception( + s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + throw new Exception( + s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize") + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } } Some(map) } finally { - if (deserStream != null) deserStream.close() + if (input != null) input.close() } } @@ -429,7 +459,7 @@ private[state] class HDFSBackedStateStoreProvider( writeSnapshotFile(lastVersion, map) } case None => - // The last map is not loaded, probably some other instance is incharge + // The last map is not loaded, probably some other instance is incharge } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 0f575169adc8f..a864ab577a414 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -234,24 +234,27 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth updateVersionTo(6) require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") provider.doMaintenance() // should generate snapshot files - assert(getDataFromFiles(provider) === Set("a" -> 6), "snapshotting messed up the data") - assert(getDataFromFiles(provider) === Set("a" -> 6)) val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) assert(snapshotVersion.nonEmpty, "snapshot file not generated") + deleteFilesEarlierThanVersion(provider, snapshotVersion.get) + assert( + getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + "snapshotting messed up the data of the snapshotted version") + assert( + getDataFromFiles(provider) === Set("a" -> 6), + "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files updateVersionTo(20) require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot - assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") - assert(getDataFromFiles(provider) === Set("a" -> 20)) val latestSnapshotVersion = (0 to 20).filter(version => fileExists(provider, version, isSnapshot = true)).lastOption assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") - + assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") } test("cleaning") { @@ -403,6 +406,19 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth filePath.exists } + def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + for (version <- 0 until version.toInt) { + for (isSnapshot <- Seq(false, true)) { + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + if (filePath.exists) filePath.delete() + + } + } + } + def storeLoaded(storeId: StateStoreId): Boolean = { val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) val loadedStores = StateStore invokePrivate method() From 29c2af0c2328df0ff376f6b77b0760c6be129aee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 13:58:17 -0700 Subject: [PATCH 32/46] Updated StateStoreCoordinator lifecycle --- .../state/HDFSBackedStateStoreProvider.scala | 2 +- .../streaming/state/StateStore.scala | 32 ++++- .../state/StateStoreCoordinator.scala | 135 +++++++++++------- .../streaming/state/StateStoreRDD.scala | 2 +- .../execution/streaming/state/package.scala | 2 +- .../state/StateStoreCoordinatorSuite.scala | 93 ++++++------ .../streaming/state/StateStoreRDDSuite.scala | 45 +++--- .../streaming/state/StateStoreSuite.scala | 12 +- 8 files changed, 184 insertions(+), 139 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 
94618f28e7966..14f69707477ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec -import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, SerializationStream} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 2fc58a874d0a4..7a313ac9d6cd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -114,7 +114,9 @@ private[state] object StateStore extends Logging { private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val managementTimer = new Timer("StateStore Timer", true) + @volatile private var managementTask: TimerTask = null + @volatile private var _coordRef: StateStoreCoordinatorRef = null /** Get or create a store associated with the id. */ def get( @@ -131,7 +133,7 @@ private[state] object StateStore extends Logging { val provider = loadedProviders.getOrElseUpdate( storeId, new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, sparkConf, hadoopConf)) - reportActiveInstance(storeId) + reportActiveStoreInstance(storeId) provider } storeProvider.getStore(version) @@ -177,7 +179,7 @@ private[state] object StateStore extends Logging { private def doMaintenance(): Unit = { loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => try { - if (verifyIfInstanceActive(id)) { + if (verifyIfStoreInstanceActive(id)) { provider.doMaintenance() } else { remove(id) @@ -190,26 +192,44 @@ private[state] object StateStore extends Logging { } } - private def reportActiveInstance(storeId: StateStoreId): Unit = { + private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { try { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - StateStoreCoordinator.ask(ReportActiveInstance(storeId, host, executorId)) + coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) + logDebug(s"Reported that the loaded instance $storeId is active") } catch { case NonFatal(e) => logWarning(s"Error reporting active instance of $storeId") } } - private def verifyIfInstanceActive(storeId: StateStoreId): Boolean = { + private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { try { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - StateStoreCoordinator.ask(VerifyIfInstanceActive(storeId, executorId)).getOrElse(false) + val verified = + coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) + logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" ) + verified } catch { case NonFatal(e) => logWarning(s"Error verifying active instance of $storeId") false } } + + private def coordinatorRef: 
Option[StateStoreCoordinatorRef] = synchronized { + val env = SparkEnv.get + if (env != null) { + if (_coordRef == null) { + _coordRef = StateStoreCoordinatorRef(env) + } + logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + Some(_coordRef) + } else { + _coordRef = null + None + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index a324abff5292a..790eb1028ddfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -21,9 +21,8 @@ import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinator.StateStoreCoordinatorEndpoint import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ @@ -31,80 +30,108 @@ private sealed trait StateStoreCoordinatorMessage extends Serializable private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) extends StateStoreCoordinatorMessage + private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) extends StateStoreCoordinatorMessage -private object StopCoordinator extends StateStoreCoordinatorMessage +private case class GetLocation(storeId: StateStoreId) + extends StateStoreCoordinatorMessage -/** Class for coordinating instances of [[StateStore]]s loaded in the cluster */ -class StateStoreCoordinator(rpcEnv: RpcEnv) { - private val coordinatorRef = rpcEnv.setupEndpoint( - StateStoreCoordinator.endpointName, new StateStoreCoordinatorEndpoint(rpcEnv, this)) - private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] +private case class DeactivateInstances(storeRootLocation: String) + extends StateStoreCoordinatorMessage - /** Report active instance of a state store in an executor */ - def reportActiveInstance(storeId: StateStoreId, host: String, executorId: String): Boolean = { - instances.synchronized { instances.put(storeId, ExecutorCacheTaskLocation(host, executorId)) } - true - } +private object StopCoordinator + extends StateStoreCoordinatorMessage - /** Verify whether the given executor has the active instance of a state store */ - def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - instances.synchronized { - instances.get(storeId) match { - case Some(location) => location.executorId == executorId - case None => false + +private[sql] object StateStoreCoordinatorRef extends Logging { + + private val endpointName = "StateStoreCoordinator" + + def apply(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + try { + val coordinator = new StateStoreCoordinator() + val endpoint = new RpcEndpoint { + override val rpcEnv: RpcEnv = env.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopCoordinator => + stop() // Stop before replying to ensure that endpoint name has been deregistered + context.reply(true) + case message: StateStoreCoordinatorMessage => + 
context.reply(coordinator.process(message)) + } } + + val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, endpoint) + logInfo("Registered StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(coordinatorRef) + } catch { + case e: Exception => + logDebug("Retrieving exitsing StateStoreCoordinator endpoint") + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + new StateStoreCoordinatorRef(rpcEndpointRef) } } +} + +private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { + + private[state] def reportActiveInstance( + storeId: StateStoreId, + host: String, + executorId: String): Boolean = { + rpcEndpointRef.askWithRetry[Boolean](ReportActiveInstance(storeId, host, executorId)) + } + + /** Verify whether the given executor has the active instance of a state store */ + private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { + rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId)) + } /** Get the location of the state store */ - def getLocation(storeId: StateStoreId): Option[String] = { - instances.synchronized { instances.get(storeId).map(_.toString) } + private[state] def getLocation(storeId: StateStoreId): Option[String] = { + rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId)) } /** Deactivate instances related to a set of operator */ - def deactivateInstances(storeRootLocation: String): Unit = { - instances.synchronized { - val storeIdsToRemove = - instances.keys.filter(_.rootLocation == storeRootLocation).toSeq - instances --= storeIdsToRemove - } + private[state] def deactivateInstances(storeRootLocation: String): Unit = { + rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation)) } - def stop(): Unit = { - coordinatorRef.askWithRetry[Boolean](StopCoordinator) + private[state] def stop(): Unit = { + rpcEndpointRef.askWithRetry[Boolean](StopCoordinator) } } -private[sql] object StateStoreCoordinator { - - private val endpointName = "StateStoreCoordinator" - - private class StateStoreCoordinatorEndpoint( - override val rpcEnv: RpcEnv, coordinator: StateStoreCoordinator) - extends RpcEndpoint with Logging { +/** Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster */ +private class StateStoreCoordinator { + private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + def process(message: StateStoreCoordinatorMessage): Any = { + message match { case ReportActiveInstance(id, host, executorId) => - context.reply(coordinator.reportActiveInstance(id, host, executorId)) - case VerifyIfInstanceActive(id, executor) => - context.reply(coordinator.verifyIfInstanceActive(id, executor)) - case StopCoordinator => - // Stop before replying to ensure that endpoint name has been deregistered - stop() - context.reply(true) - } - } - - def ask(message: StateStoreCoordinatorMessage): Option[Boolean] = { - val env = SparkEnv.get - if (env != null) { - val coordinatorRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) - Some(coordinatorRef.askWithRetry[Boolean](message)) - } else { - None + instances.put(id, ExecutorCacheTaskLocation(host, executorId)) + true + + case VerifyIfInstanceActive(id, execId) => + instances.get(id) match { + case Some(location) => location.executorId == execId + case None => false + } + + case GetLocation(id) => + instances.get(id).map(_.toString) 
+ + case DeactivateInstances(loc) => + val storeIdsToRemove = + instances.keys.filter(_.rootLocation == loc).toSeq + instances --= storeIdsToRemove + true + + case _ => + throw new IllegalArgumentException("Cannot iden") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 6ca394a30d635..c96455c9561d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -37,7 +37,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( storeVersion: Long, keySchema: StructType, valueSchema: StructType, - @transient private val storeCoordinator: Option[StateStoreCoordinator]) + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) extends RDD[U](dataRDD) { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 8819008dbfa93..d754f7ab5a1f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -32,7 +32,7 @@ package object state { storeVersion: Long, keySchema: StructType, valueSchema: StructType, - storeCoordinator: Option[StateStoreCoordinator] = None + storeCoordinator: Option[StateStoreCoordinatorRef] = None ): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 80278bdf2fed5..89fa413468e1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,104 +17,93 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.util.RpcUtils +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import StateStoreCoordinatorSuite._ test("report, verify, getLocation") { - withCoordinator(sc) { coordinator => + withCoordinatorRef(sc) { coordinatorRef => val id = StateStoreId("x", 0, 0) - assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) - assert(coordinator.getLocation(id) === None) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.getLocation(id) === None) - assert(coordinator.reportActiveInstance(id, "hostX", "exec1") === true) - assert(coordinator.verifyIfInstanceActive(id, "exec1") === true) - assert(coordinator.getLocation(id) === + assert(coordinatorRef.reportActiveInstance(id, "hostX", "exec1") === true) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) + assert(coordinatorRef.getLocation(id) === Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) - 
assert(coordinator.reportActiveInstance(id, "hostX", "exec2") === true) - assert(coordinator.verifyIfInstanceActive(id, "exec1") === false) - assert(coordinator.verifyIfInstanceActive(id, "exec2") === true) + assert(coordinatorRef.reportActiveInstance(id, "hostX", "exec2") === true) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) assert( - coordinator.getLocation(id) === + coordinatorRef.getLocation(id) === Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) } } test("make inactive") { - withCoordinator(sc) { coordinator => + withCoordinatorRef(sc) { coordinatorRef => val id1 = StateStoreId("x", 0, 0) val id2 = StateStoreId("y", 1, 0) val id3 = StateStoreId("x", 0, 1) val host = "hostX" val exec = "exec1" - assert(coordinator.reportActiveInstance(id1, host, exec) === true) - assert(coordinator.reportActiveInstance(id2, host, exec) === true) - assert(coordinator.reportActiveInstance(id3, host, exec) === true) + assert(coordinatorRef.reportActiveInstance(id1, host, exec) === true) + assert(coordinatorRef.reportActiveInstance(id2, host, exec) === true) + assert(coordinatorRef.reportActiveInstance(id3, host, exec) === true) - assert(coordinator.verifyIfInstanceActive(id1, exec) === true) - assert(coordinator.verifyIfInstanceActive(id2, exec) === true) - assert(coordinator.verifyIfInstanceActive(id3, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) - coordinator.deactivateInstances("x") + coordinatorRef.deactivateInstances("x") - assert(coordinator.verifyIfInstanceActive(id1, exec) === false) - assert(coordinator.verifyIfInstanceActive(id2, exec) === true) - assert(coordinator.verifyIfInstanceActive(id3, exec) === false) + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false) - assert(coordinator.getLocation(id1) === None) + assert(coordinatorRef.getLocation(id1) === None) assert( - coordinator.getLocation(id2) === + coordinatorRef.getLocation(id2) === Some(ExecutorCacheTaskLocation(host, exec).toString)) - assert(coordinator.getLocation(id3) === None) + assert(coordinatorRef.getLocation(id3) === None) - coordinator.deactivateInstances("y") - assert(coordinator.verifyIfInstanceActive(id2, exec) === false) - assert(coordinator.getLocation(id2) === None) + coordinatorRef.deactivateInstances("y") + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) + assert(coordinatorRef.getLocation(id2) === None) } } - test("communication") { - withCoordinator(sc) { coordinator => - import StateStoreCoordinator._ - val id = StateStoreId("x", 0, 0) - val host = "hostX" + test("multiple references have same coordinator") { + withCoordinatorRef(sc) { coordRef1 => + val coordRef2 = StateStoreCoordinatorRef(sc.env) - val ref = RpcUtils.makeDriverRef("StateStoreCoordinator", sc.env.conf, sc.env.rpcEnv) - - assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(false)) + val id = StateStoreId("x", 0, 0) - ask(ReportActiveInstance(id, host, "exec1")) - assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(true)) + assert(coordRef1.reportActiveInstance(id, "hostX", "exec1") === true) + assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) assert( - coordinator.getLocation(id) 
=== - Some(ExecutorCacheTaskLocation(host, "exec1").toString)) + coordRef2.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) - ask(ReportActiveInstance(id, host, "exec2")) - assert(ask(VerifyIfInstanceActive(id, "exec1")) === Some(false)) - assert(ask(VerifyIfInstanceActive(id, "exec2")) === Some(true)) - assert( - coordinator.getLocation(id) === - Some(ExecutorCacheTaskLocation(host, "exec2").toString)) } } } object StateStoreCoordinatorSuite { - def withCoordinator(sc: SparkContext)(body: StateStoreCoordinator => Unit): Unit = { - var coordinator: StateStoreCoordinator = null + def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { + var coordinatorRef: StateStoreCoordinatorRef = null try { - coordinator = new StateStoreCoordinator(sc.env.rpcEnv) - body(coordinator) + coordinatorRef = StateStoreCoordinatorRef(sc.env) + body(coordinatorRef) } finally { - if (coordinator != null) coordinator.stop() + if (coordinatorRef != null) coordinatorRef.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index f417e806fe0be..7b1c54a6b85ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -110,27 +110,32 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("preferred locations using StateStoreCoordinator") { - val opId = 0 - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - - withSpark(new SparkContext(conf)) { sc => - withCoordinator(sc) { coordinator => - coordinator.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") - coordinator.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") - - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema, Some(coordinator)) - require(rdd.partitions.size === 2) - - assert( - rdd.preferredLocations(rdd.partitions(0)) === - Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) - - assert( - rdd.preferredLocations(rdd.partitions(1)) === - Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - rdd.collect() + withSpark(new SparkContext(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema, Some(coordinatorRef)) + require(rdd.partitions.length === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index a864ab577a414..f7dc080f8f81c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkEnv, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly @@ -254,6 +254,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth fileExists(provider, version, isSnapshot = true)).lastOption assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") } @@ -313,11 +315,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } } - test("background management") { + test("maintenance") { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + .set("spark.rpc.numRetries", "1") val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val storeId = StateStoreId(dir, opId, 0) @@ -326,7 +329,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth quietly { withSpark(new SparkContext(conf)) { sc => - withCoordinator(sc) { coordinator => + withCoordinatorRef(sc) { coordinator => for (i <- 1 to 20) { val store = StateStore.get(storeId, keySchema, valueSchema, i - 1, new Configuration) update(store, "a", i) @@ -369,7 +372,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Verify if instance is unloaded if SparkContext is stopped - eventually(timeout(4 seconds)) { + require(SparkEnv.get === null) + eventually(timeout(10 seconds)) { assert(!StateStore.isLoaded(storeId)) } } From 32c013918677cb2c93385fbfa6a38b4b8594bea1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 14:12:53 -0700 Subject: [PATCH 33/46] Added StateStoreCoordinator to ContinuousQueryManager --- .../spark/sql/ContinuousQueryManager.scala | 2 + .../execution/streaming/state/package.scala | 22 ++++++++ .../state/StateStoreCoordinatorSuite.scala | 2 +- .../streaming/state/StateStoreRDDSuite.scala | 50 +++++++++---------- .../streaming/state/StateStoreSuite.scala | 2 +- 5 files changed, 51 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 0a156ea99a297..24832e35866cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution} +import 
org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.util.ContinuousQueryListener /** @@ -33,6 +34,7 @@ import org.apache.spark.sql.util.ContinuousQueryListener @Experimental class ContinuousQueryManager(sqlContext: SQLContext) { + private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef(sqlContext.sparkContext.env) private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) private val activeQueries = new mutable.HashMap[String, ContinuousQuery] private val activeQueriesLock = new Object diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index d754f7ab5a1f8..d201e73b07f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -20,12 +20,34 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { + + /** Map each partition of a RDD along with data in a [[StateStore]]. */ def mapPartitionWithStateStore[U: ClassTag]( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + storeRootLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType + )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + mapPartitionWithStateStore( + storeUpdateFunction, + storeRootLocation, + operatorId, + storeVersion, + keySchema, + valueSchema, + Some(sqlContext.streams.stateStoreCoordinator)) + } + + /** Map each partition of a RDD along with data in a [[StateStore]]. 
*/ + private[state] def mapPartitionWithStateStore[U: ClassTag]( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], storeRootLocation: String, operatorId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 89fa413468e1c..3f42b0396af40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 7b1c54a6b85ba..5fd4512b8e02c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.Utils @@ -39,7 +40,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) - import StateStoreCoordinatorSuite._ import StateStoreSuite._ after { @@ -68,12 +68,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val opId = 0 val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + increment, path, opId, storeVersion = 0, keySchema, valueSchema, None) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) + increment, path, opId, storeVersion = 1, keySchema, valueSchema, None) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
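The implicit-SQLContext overload of mapPartitionWithStateStore added to package.scala above threads sqlContext.streams.stateStoreCoordinator into the StateStoreRDD, which is what gives the RDD its preferred locations. A minimal sketch of how that overload might be driven outside the test suite follows; StateStoreUsageSketch, countKeys and checkpointPath are made-up names for illustration only, and a real update function would also write its results into the store before committing.

// Illustrative sketch only (not part of this patch series): exercising the
// implicit-SQLContext overload of mapPartitionWithStateStore from package.scala.
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object StateStoreUsageSketch {
  val keySchema = StructType(Seq(StructField("key", StringType, true)))
  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))

  // Shape expected by mapPartitionWithStateStore: (StateStore, Iterator[T]) => Iterator[U].
  // This placeholder only counts keys and commits; a real function would also update the store.
  def countKeys(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
    val counts = iter.toSeq.groupBy(identity).mapValues(_.size).toSeq
    store.commit() // finalize version storeVersion + 1
    counts.iterator
  }

  def run(sc: SparkContext, checkpointPath: String): Unit = {
    // The implicit SQLContext supplies streams.stateStoreCoordinator for preferred locations
    implicit val sqlContext: SQLContext = new SQLContext(sc)
    val operatorId = 0L
    val rdd = sc.makeRDD(Seq("a", "b", "a"), 2).mapPartitionWithStateStore(
      countKeys, checkpointPath, operatorId, storeVersion = 0, keySchema, valueSchema)
    rdd.collect()
  }
}

The preferred-locations test in this suite follows the same pattern, after first registering executor locations with the coordinator via reportActiveInstance.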
@@ -92,7 +92,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { makeRDD(sc, Seq("a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion, keySchema, valueSchema) + increment, path, opId, storeVersion, keySchema, valueSchema, None) } // Generate RDDs and state store data @@ -115,27 +115,27 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSpark(new SparkContext(conf)) { sc => - withCoordinatorRef(sc) { coordinatorRef => - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === - Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema, Some(coordinatorRef)) - require(rdd.partitions.length === 2) - - assert( - rdd.preferredLocations(rdd.partitions(0)) === - Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) - - assert( - rdd.preferredLocations(rdd.partitions(1)) === - Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) - - rdd.collect() - } + implicit val sqlContext = new SQLContext(sc) + val coordinatorRef = sqlContext.streams.stateStoreCoordinator + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + require(rdd.partitions.length === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index f7dc080f8f81c..73d4812eb3680 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkEnv, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly From 3824053324aae89956601db060bfb77b013a2303 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 15:08:12 -0700 Subject: [PATCH 34/46] Updated logging --- .../state/HDFSBackedStateStoreProvider.scala | 19 ++-- .../streaming/state/StateStore.scala | 4 +- .../state/StateStoreCoordinator.scala | 87 
++++++++++--------- .../state/StateStoreCoordinatorSuite.scala | 14 +-- .../streaming/state/StateStoreSuite.scala | 16 ++-- 5 files changed, 77 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 14f69707477ad..de4da6864587f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -154,6 +154,7 @@ private[state] class HDFSBackedStateStoreProvider( finalizeDeltaFile(tempDeltaFileStream) finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) state = COMMITTED + logInfo(s"Committed version $newVersion for $this") newVersion } catch { case NonFatal(e) => @@ -208,7 +209,9 @@ private[state] class HDFSBackedStateStoreProvider( val time = System.nanoTime() newMap.putAll(loadMap(version)) } - new HDFSBackedStateStore(version, newMap) + val store = new HDFSBackedStateStore(version, newMap) + logInfo(s"Retrieved version $version of $this for update") + store } /** Manage backing files, including creating snapshots and cleaning up old files */ @@ -280,7 +283,7 @@ private[state] class HDFSBackedStateStoreProvider( } else { if (!fs.isDirectory(baseDir)) { throw new IllegalStateException( - s"Cannot use ${id.rootLocation} for storing state data as" + + s"Cannot use ${id.rootLocation} for storing state data for $this as" + s"$baseDir already exists and is not a directory") } } @@ -375,6 +378,7 @@ private[state] class HDFSBackedStateStoreProvider( } finally { if (input != null) input.close() } + logInfo(s"Read delta file for version $version of $this from $fileToRead") } private def writeSnapshotFile(version: Long, map: MapType): Unit = { @@ -397,6 +401,7 @@ private[state] class HDFSBackedStateStoreProvider( } { if (output != null) output.close() } + logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -438,6 +443,7 @@ private[state] class HDFSBackedStateStoreProvider( } } } + logInfo(s"Read snapshot file for version $version of $this from $fileToRead") Some(map) } finally { if (input != null) input.close() @@ -488,6 +494,7 @@ private[state] class HDFSBackedStateStoreProvider( files.filter(_.version < earliestFileToRetain.version).foreach { f => fs.delete(f.path, true) } + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") } } } catch { @@ -514,7 +521,7 @@ private[state] class HDFSBackedStateStoreProvider( } verify( deltaFiles.size == version - snapshotFile.version, - s"Unexpected list of delta files for version $version: ${deltaFiles.mkString(",")}" + s"Unexpected list of delta files for version $version for $this: $deltaFiles" ) deltaFiles @@ -547,11 +554,13 @@ private[state] class HDFSBackedStateStoreProvider( case "snapshot" => versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) case _ => - logWarning(s"Could not identify file $path") + logWarning(s"Could not identify file $path for $this") } } } - versionToFiles.values.toSeq.sortBy(_.version) + val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) + logDebug(s"Current set of files for $this: $storeFiles") + storeFiles } private def compressStream(outputStream: DataOutputStream): DataOutputStream = { 
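The snapshot files whose lifecycle is logged above use the simple length-prefixed framing introduced in the "Updated serialization of snapshot files" commit: for each record, key length, key bytes, value length, value bytes, with a single -1 length marking end-of-file. Below is a minimal round-trip of that framing on in-memory streams, leaving out the LZ4 compressStream/decompressStream wrappers and the UnsafeRow handling that the provider adds; SnapshotFormatSketch is a hypothetical name used only for this sketch.

// Illustrative round-trip of the length-prefixed framing used by
// writeSnapshotFile/readSnapshotFile, on plain in-memory byte streams.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object SnapshotFormatSketch {
  def writeRecords(records: Seq[(Array[Byte], Array[Byte])]): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    records.foreach { case (key, value) =>
      out.writeInt(key.length)   // key size
      out.write(key)             // key bytes
      out.writeInt(value.length) // value size
      out.write(value)           // value bytes
    }
    out.writeInt(-1)             // end-of-file marker, as in writeSnapshotFile
    out.close()
    buffer.toByteArray
  }

  def readRecords(bytes: Array[Byte]): Seq[(Array[Byte], Array[Byte])] = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    val records = Seq.newBuilder[(Array[Byte], Array[Byte])]
    var keySize = in.readInt()
    while (keySize != -1) {      // -1 signals EOF; any other negative size is corruption
      require(keySize >= 0, s"key size cannot be $keySize")
      val key = new Array[Byte](keySize)
      in.readFully(key)
      val valueSize = in.readInt()
      require(valueSize >= 0, s"value size cannot be $valueSize")
      val value = new Array[Byte](valueSize)
      in.readFully(value)
      records += key -> value
      keySize = in.readInt()
    }
    in.close()
    records.result()
  }
}

Running writeRecords and then readRecords returns the original key/value byte pairs, and a negative size other than -1 fails fast, mirroring the error handling in readSnapshotFile.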
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7a313ac9d6cd3..1f08b9d9ab2fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -150,11 +150,12 @@ private[state] object StateStore extends Logging { /** Unload and stop all state store provider */ def stop(): Unit = loadedProviders.synchronized { loadedProviders.clear() + _coordRef = null if (managementTask != null) { managementTask.cancel() managementTask = null - logInfo("StateStore stopped") } + logInfo("StateStore stopped") } /** Start the periodic maintenance task if not already started and if Spark active */ @@ -177,6 +178,7 @@ private[state] object StateStore extends Logging { * the active instances according to the coordinator. */ private def doMaintenance(): Unit = { + logDebug("Doing maintenance") loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => try { if (verifyIfStoreInstanceActive(id)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 790eb1028ddfc..fc7fafcce94cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -28,6 +28,7 @@ import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ private sealed trait StateStoreCoordinatorMessage extends Serializable +/** Classes representing messages */ private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) extends StateStoreCoordinatorMessage @@ -43,27 +44,19 @@ private case class DeactivateInstances(storeRootLocation: String) private object StopCoordinator extends StateStoreCoordinatorMessage - +/** Helper object used to create reference to [[StateStoreCoordinator]]. */ private[sql] object StateStoreCoordinatorRef extends Logging { private val endpointName = "StateStoreCoordinator" + /** + * Create a reference to a [[StateStoreCoordinator]], This can be called from driver as well as + * executors. 
+ */ def apply(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator() - val endpoint = new RpcEndpoint { - override val rpcEnv: RpcEnv = env.rpcEnv - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case StopCoordinator => - stop() // Stop before replying to ensure that endpoint name has been deregistered - context.reply(true) - case message: StateStoreCoordinatorMessage => - context.reply(coordinator.process(message)) - } - } - - val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, endpoint) + val coordinator = new StateStoreCoordinator(env.rpcEnv) + val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) } catch { @@ -75,13 +68,17 @@ private[sql] object StateStoreCoordinatorRef extends Logging { } } +/** + * Reference to a [[StateStoreCoordinator]] that can be used to coordinator instances of + * [[StateStore]]s across all the executors, and get their locations for job scheduling. + */ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( storeId: StateStoreId, host: String, - executorId: String): Boolean = { - rpcEndpointRef.askWithRetry[Boolean](ReportActiveInstance(storeId, host, executorId)) + executorId: String): Unit = { + rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) } /** Verify whether the given executor has the active instance of a state store */ @@ -105,34 +102,38 @@ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointR } -/** Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster */ -private class StateStoreCoordinator { +/** + * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, + * and get their locations for job scheduling. 
+ */ +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint { private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] - def process(message: StateStoreCoordinatorMessage): Any = { - message match { - case ReportActiveInstance(id, host, executorId) => - instances.put(id, ExecutorCacheTaskLocation(host, executorId)) - true - - case VerifyIfInstanceActive(id, execId) => - instances.get(id) match { - case Some(location) => location.executorId == execId - case None => false - } - - case GetLocation(id) => - instances.get(id).map(_.toString) - - case DeactivateInstances(loc) => - val storeIdsToRemove = - instances.keys.filter(_.rootLocation == loc).toSeq - instances --= storeIdsToRemove - true - - case _ => - throw new IllegalArgumentException("Cannot iden") - } + override def receive: PartialFunction[Any, Unit] = { + case ReportActiveInstance(id, host, executorId) => + instances.put(id, ExecutorCacheTaskLocation(host, executorId)) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case VerifyIfInstanceActive(id, execId) => + val response = instances.get(id) match { + case Some(location) => location.executorId == execId + case None => false + } + context.reply(response) + + case GetLocation(id) => + context.reply(instances.get(id).map(_.toString)) + + case DeactivateInstances(loc) => + val storeIdsToRemove = + instances.keys.filter(_.rootLocation == loc).toSeq + instances --= storeIdsToRemove + context.reply(true) + + case StopCoordinator => + stop() // Stop before replying to ensure that endpoint name has been deregistered + context.reply(true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 3f42b0396af40..d623ea094ca6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -31,12 +31,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) - assert(coordinatorRef.reportActiveInstance(id, "hostX", "exec1") === true) + coordinatorRef.reportActiveInstance(id, "hostX", "exec1") assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) assert(coordinatorRef.getLocation(id) === Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) - assert(coordinatorRef.reportActiveInstance(id, "hostX", "exec2") === true) + coordinatorRef.reportActiveInstance(id, "hostX", "exec2") assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) @@ -54,9 +54,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val host = "hostX" val exec = "exec1" - assert(coordinatorRef.reportActiveInstance(id1, host, exec) === true) - assert(coordinatorRef.reportActiveInstance(id2, host, exec) === true) - assert(coordinatorRef.reportActiveInstance(id3, host, exec) === true) + coordinatorRef.reportActiveInstance(id1, host, exec) + coordinatorRef.reportActiveInstance(id2, host, exec) + coordinatorRef.reportActiveInstance(id3, host, exec) assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) 
assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -80,13 +80,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } - test("multiple references have same coordinator") { + test("multiple references have same underlying coordinator") { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef(sc.env) val id = StateStoreId("x", 0, 0) - assert(coordRef1.reportActiveInstance(id, "hostX", "exec1") === true) + coordRef1.reportActiveInstance(id, "hostX", "exec1") assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) assert( coordRef2.getLocation(id) === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 73d4812eb3680..d94fdb0bd4465 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -329,15 +329,17 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth quietly { withSpark(new SparkContext(conf)) { sc => - withCoordinatorRef(sc) { coordinator => + withCoordinatorRef(sc) { coordinatorRef => for (i <- 1 to 20) { val store = StateStore.get(storeId, keySchema, valueSchema, i - 1, new Configuration) update(store, "a", i) store.commit() } - // Background management should clean up and generate snapshots - eventually(timeout(4 seconds)) { + assert(coordinatorRef.getLocation(storeId).nonEmpty) + + // Background maintenance should clean up and generate snapshots + eventually(timeout(10 seconds)) { // Earliest delta file should get cleaned up assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") @@ -350,8 +352,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // If driver decides to deactivate all instances of the store, then this instance // should be unloaded - coordinator.deactivateInstances(dir) - eventually(timeout(4 seconds)) { + coordinatorRef.deactivateInstances(dir) + eventually(timeout(10 seconds)) { assert(!StateStore.isLoaded(storeId)) } @@ -360,8 +362,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded - coordinator.reportActiveInstance(storeId, "other-host", "other-exec") - eventually(timeout(4 seconds)) { + coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + eventually(timeout(10 seconds)) { assert(!StateStore.isLoaded(storeId)) } From 534ad483ed021f0e707cc27ae8adadbd2b585d83 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 16:45:51 -0700 Subject: [PATCH 35/46] Minor updates --- .../state/HDFSBackedStateStoreProvider.scala | 6 +-- .../streaming/state/StateStore.scala | 26 ++++++---- .../state/StateStoreCoordinator.scala | 4 +- .../state/StateStoreCoordinatorSuite.scala | 44 ++++++++++------ .../streaming/state/StateStoreSuite.scala | 51 +++++++++++++++++-- 5 files changed, 97 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index de4da6864587f..5250e4b5cc2ba 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -214,7 +214,7 @@ private[state] class HDFSBackedStateStoreProvider( store } - /** Manage backing files, including creating snapshots and cleaning up old files */ + /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { try { doSnapshot() @@ -564,12 +564,12 @@ private[state] class HDFSBackedStateStoreProvider( } private def compressStream(outputStream: DataOutputStream): DataOutputStream = { - val compressed = new LZ4CompressionCodec(new SparkConf).compressedOutputStream(outputStream) + val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream) new DataOutputStream(compressed) } private def decompressStream(inputStream: DataInputStream): DataInputStream = { - val compressed = new LZ4CompressionCodec(new SparkConf).compressedInputStream(inputStream) + val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream) new DataInputStream(compressed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 1f08b9d9ab2fb..a63625e33fad9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -83,19 +83,23 @@ trait StateStore { def hasCommitted: Boolean } - +/** Trait representing a provider of a specific version of a [[StateStore]]. */ trait StateStoreProvider { /** Get the store with the existing version. */ def getStore(version: Long): StateStore - /** Optional method for providers to allow for background management */ + /** Optional method for providers to allow for background maintenance */ def doMaintenance(): Unit = { } } +/** Trait representing updates made to a [[StateStore]]. */ sealed trait StoreUpdate + case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate + case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate + case class KeyRemoved(key: UnsafeRow) extends StoreUpdate @@ -113,9 +117,9 @@ private[state] object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() - private val managementTimer = new Timer("StateStore Timer", true) + private val maintenanceTimer = new Timer("StateStore Timer", true) - @volatile private var managementTask: TimerTask = null + @volatile private var maintenanceTask: TimerTask = null @volatile private var _coordRef: StateStoreCoordinatorRef = null /** Get or create a store associated with the id. 
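As a small illustration of the sealed StoreUpdate hierarchy introduced above, a hypothetical helper (not part of the patch) that consumers of these updates might write; the match is exhaustive because the trait is sealed:

    def describeUpdate(update: StoreUpdate): String = update match {
      case ValueAdded(key, value)   => s"added $key -> $value"
      case ValueUpdated(key, value) => s"updated $key -> $value"
      case KeyRemoved(key)          => s"removed $key"
    }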
*/ @@ -151,9 +155,9 @@ private[state] object StateStore extends Logging { def stop(): Unit = loadedProviders.synchronized { loadedProviders.clear() _coordRef = null - if (managementTask != null) { - managementTask.cancel() - managementTask = null + if (maintenanceTask != null) { + maintenanceTask.cancel() + maintenanceTask = null } logInfo("StateStore stopped") } @@ -161,20 +165,20 @@ private[state] object StateStore extends Logging { /** Start the periodic maintenance task if not already started and if Spark active */ private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized { val env = SparkEnv.get - if (managementTask == null && env != null) { - managementTask = new TimerTask { + if (maintenanceTask == null && env != null) { + maintenanceTask = new TimerTask { override def run(): Unit = { doMaintenance() } } val periodMs = env.conf.getTimeAsMs( MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") - managementTimer.schedule(managementTask, periodMs, periodMs) + maintenanceTimer.schedule(maintenanceTask, periodMs, periodMs) logInfo("StateStore maintenance timer started") } } /** - * Execute background management task in all the loaded store providers if they are still + * Execute background maintenance task in all the loaded store providers if they are still * the active instances according to the coordinator. */ private def doMaintenance(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index fc7fafcce94cc..1d145f7e9e314 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -60,8 +60,8 @@ private[sql] object StateStoreCoordinatorRef extends Logging { logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) } catch { - case e: Exception => - logDebug("Retrieving exitsing StateStoreCoordinator endpoint") + case e: IllegalArgumentException => + logDebug("Retrieving existing StateStoreCoordinator endpoint") val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) new StateStoreCoordinatorRef(rpcEndpointRef) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index d623ea094ca6f..e93595aee2374 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming.state +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation @@ -32,17 +35,23 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.getLocation(id) === None) coordinatorRef.reportActiveInstance(id, "hostX", "exec1") - assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) - assert(coordinatorRef.getLocation(id) === - Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + 
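For reference, a sketch of how the maintenance period above is controlled; the key string is the value of MAINTENANCE_INTERVAL_CONFIG as it stands before a later rename in this series, and the 30s setting is just an example:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.sql.streaming.stateStore.maintenanceInterval", "30s")
    // The same lookup the maintenance code performs, with the 60 second default as fallback.
    val periodMs = conf.getTimeAsMs(
      "spark.sql.streaming.stateStore.maintenanceInterval", "60s")   // 30000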
eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } coordinatorRef.reportActiveInstance(id, "hostX", "exec2") - assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) - assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) - assert( - coordinatorRef.getLocation(id) === - Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } } } @@ -58,9 +67,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { coordinatorRef.reportActiveInstance(id2, host, exec) coordinatorRef.reportActiveInstance(id3, host, exec) - assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) - assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) - assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) + + } coordinatorRef.deactivateInstances("x") @@ -87,11 +99,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val id = StateStoreId("x", 0, 0) coordRef1.reportActiveInstance(id, "hostX", "exec1") - assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) - assert( - coordRef2.getLocation(id) === - Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + eventually(timeout(5 seconds)) { + assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordRef2.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index d94fdb0bd4465..81fe1bbacbf64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -279,6 +279,39 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) } + + test("corrupted file handling") { + val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) + for (i <- 1 to 6) { + val store = provider.getStore(i - 1) + update(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + val snapshotVersion = (0 to 10).find( version => + fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) + + // Corrupt snapshot file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + corruptFile(provider, snapshotVersion, isSnapshot = true) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion) + } + + // Corrupt delta file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + corruptFile(provider, 
snapshotVersion - 1, isSnapshot = false) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + + // Delete delta file and verify that it throws error + deleteFilesEarlierThanVersion(provider, snapshotVersion) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + } + test("StateStore.get") { quietly { val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString @@ -335,8 +368,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth update(store, "a", i) store.commit() } - - assert(coordinatorRef.getLocation(storeId).nonEmpty) + eventually(timeout(10 seconds)) { + assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + } // Background maintenance should clean up and generate snapshots eventually(timeout(10 seconds)) { @@ -420,11 +454,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" val filePath = new File(basePath.toString, fileName) if (filePath.exists) filePath.delete() - } } } + def corruptFile( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.delete() + filePath.createNewFile() + } + def storeLoaded(storeId: StateStoreId): Boolean = { val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) val loadedStores = StateStore invokePrivate method() From 19a60a65b77b73cf69cde961c034d04ad17df436 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 16:54:34 -0700 Subject: [PATCH 36/46] Addressed comments --- .../state/HDFSBackedStateStoreProvider.scala | 15 ++++---- .../streaming/state/StateStore.scala | 19 +++++------ .../state/StateStoreCoordinator.scala | 2 +- .../streaming/state/StateStoreRDD.scala | 6 ++-- .../execution/streaming/state/package.scala | 34 +++++++++---------- 5 files changed, 38 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 5250e4b5cc2ba..8bc6fd99cb036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -90,7 +90,6 @@ private[state] class HDFSBackedStateStoreProvider( private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - // private val tempDeltaFileStream = fs.create(tempDeltaFile, true) private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() @volatile private var state: STATE = UPDATING @@ -98,7 +97,11 @@ private[state] class HDFSBackedStateStoreProvider( override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - /** Update the value of a key using the value generated by the update function */ + /** + * Update the value of a key using the value generated by the update function. + * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous + * versions of the store data. 
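To make the note above concrete, a sketch of an update that builds a fresh row rather than reusing the retrieved one; keyRow and valueProjection (an UnsafeProjection over a single-int value schema) are assumed to exist in the caller and are not part of this patch:

    // Assumes: import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
    store.update(keyRow, {
      case Some(oldValue) =>
        // Read what is needed from the old row, then build (and copy) a brand-new row.
        valueProjection(new GenericInternalRow(Array[Any](oldValue.getInt(0) + 1))).copy()
      case None =>
        valueProjection(new GenericInternalRow(Array[Any](1))).copy()
    })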
+ */ override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { verify(state == UPDATING, "Cannot update after already committed or cancelled") val oldValueOption = Option(mapToUpdate.get(key)) @@ -232,9 +235,9 @@ private[state] class HDFSBackedStateStoreProvider( /* Internal classes and methods */ private val loadedMaps = new mutable.HashMap[Long, MapType] - private val baseDir = new Path(id.rootLocation, s"${id.operatorId}/${id.partitionId.toString}") + private val baseDir = + new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") private val fs = baseDir.getFileSystem(hadoopConf) - private val serializer = new KryoSerializer(sparkConf) private val minBatchesToRetain = sparkConf.getInt( MIN_BATCHES_TO_RETAIN_CONF, MIN_BATCHES_TO_RETAIN_DEFAULT) private val maxDeltaChainForSnapshots = sparkConf.getInt( @@ -283,7 +286,7 @@ private[state] class HDFSBackedStateStoreProvider( } else { if (!fs.isDirectory(baseDir)) { throw new IllegalStateException( - s"Cannot use ${id.rootLocation} for storing state data for $this as" + + s"Cannot use ${id.checkpointLocation} for storing state data for $this as" + s"$baseDir already exists and is not a directory") } } @@ -383,7 +386,6 @@ private[state] class HDFSBackedStateStoreProvider( private def writeSnapshotFile(version: Long, map: MapType): Unit = { val fileToWrite = snapshotFile(version) - val ser = serializer.newInstance() var output: DataOutputStream = null Utils.tryWithSafeFinally { output = compressStream(fs.create(fileToWrite, false)) @@ -408,7 +410,6 @@ private[state] class HDFSBackedStateStoreProvider( val fileToRead = snapshotFile(version) if (!fs.exists(fileToRead)) return None - val deser = serializer.newInstance() val map = new MapType() var input: DataInputStream = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a63625e33fad9..b2f3eaab7c70e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType /** Unique identifier for a [[StateStore]] */ -case class StateStoreId(rootLocation: String, operatorId: Long, partitionId: Int) +case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) /** * Base trait for a versioned key-value store used for streaming aggregations @@ -46,13 +46,13 @@ trait StateStore { /** * Update the value of a key using the value generated by the update function. - * This can be called only after prepareForUpdates() has been called in the same thread. + * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous + * versions of the store data. */ def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit /** * Remove keys that match the following condition. - * This can be called only after prepareForUpdates() has been called in the current thread. */ def remove(condition: UnsafeRow => Boolean): Unit @@ -108,7 +108,7 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), * it also runs a periodic background tasks to do maintenance on the loaded stores. 
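To make the executor-side flow concrete, a small sketch of direct use of this object, using the signature as it stands at this commit (before StateStoreConf is threaded through); the checkpoint path and single-column schemas are invented for illustration:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.types._

    val keySchema = StructType(Seq(StructField("key", StringType)))
    val valueSchema = StructType(Seq(StructField("value", IntegerType)))
    val storeId = StateStoreId("/tmp/state-checkpoint", operatorId = 0, partitionId = 0)

    // Loads (or reuses) the provider for this id, reports the instance to the coordinator,
    // and starts the background maintenance task if it is not already running.
    val store = StateStore.get(storeId, keySchema, valueSchema, 0L /* version */, new Configuration)
    // ... apply changes through store.update(...) / store.remove(...), then:
    val newVersion = store.commit()   // returns the newly committed version (here, 1)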
For each * store, tt uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of - * the store is the active instance. Accordingly, it either keeps it loaded and performance + * the store is the active instance. Accordingly, it either keeps it loaded and performs * maintenance, or unloads the store. */ private[state] object StateStore extends Logging { @@ -124,12 +124,11 @@ private[state] object StateStore extends Logging { /** Get or create a store associated with the id. */ def get( - storeId: StateStoreId, - keySchema: StructType, - valueSchema: StructType, - version: Long, - hadoopConf: Configuration - ): StateStore = { + storeId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + version: Long, + hadoopConf: Configuration): StateStore = { require(version >= 0) val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 1d145f7e9e314..2036a0a1f5d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -127,7 +127,7 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndp case DeactivateInstances(loc) => val storeIdsToRemove = - instances.keys.filter(_.rootLocation == loc).toSeq + instances.keys.filter(_.checkpointLocation == loc).toSeq instances --= storeIdsToRemove context.reply(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index c96455c9561d7..06c1628e66cee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], - storeRootLocation: String, + checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, @@ -47,7 +47,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(storeRootLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) storeCoordinator.flatMap(_.getLocation(storeId)).toSeq } @@ -55,7 +55,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( var store: StateStore = null Utils.tryWithSafeFinally { - val storeId = StateStoreId(storeRootLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) store = StateStore.get( storeId, keySchema, valueSchema, storeVersion, confBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index d201e73b07f22..65d2cc490e5e4 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -29,16 +29,16 @@ package object state { /** Map each partition of a RDD along with data in a [[StateStore]]. */ def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], - storeRootLocation: String, - operatorId: Long, - storeVersion: Long, - keySchema: StructType, - valueSchema: StructType - )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType + )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { mapPartitionWithStateStore( storeUpdateFunction, - storeRootLocation, + checkpointLocation, operatorId, storeVersion, keySchema, @@ -48,19 +48,19 @@ package object state { /** Map each partition of a RDD along with data in a [[StateStore]]. */ private[state] def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], - storeRootLocation: String, - operatorId: Long, - storeVersion: Long, - keySchema: StructType, - valueSchema: StructType, - storeCoordinator: Option[StateStoreCoordinatorRef] = None - ): StateStoreRDD[T, U] = { + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeCoordinator: Option[StateStoreCoordinatorRef] = None + ): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( dataRDD, cleanedF, - storeRootLocation, + checkpointLocation, operatorId, storeVersion, keySchema, From 5b7cf538cf4b6485ddee102568e931d66f9c1a75 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 19:37:32 -0700 Subject: [PATCH 37/46] Added StateStoreConf to address comments --- .../state/HDFSBackedStateStoreProvider.scala | 20 ++------ .../streaming/state/StateStore.scala | 6 +-- .../streaming/state/StateStoreConf.scala | 48 +++++++++++++++++++ .../streaming/state/StateStoreRDD.scala | 3 +- .../execution/streaming/state/package.scala | 6 ++- .../streaming/state/StateStoreRDDSuite.scala | 18 +++---- .../streaming/state/StateStoreSuite.scala | 42 +++++++++------- 7 files changed, 98 insertions(+), 45 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 8bc6fd99cb036..4680593bb5d76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -28,10 +28,9 @@ import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec -import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.expressions.UnsafeRow 
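For orientation, a rough sketch of how the enriched RDD method above ends up being called, patterned on StateStoreRDDSuite; sc, the checkpoint path and the single-column schemas are placeholders, and the per-key UnsafeRow update logic is deliberately elided:

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.execution.streaming.state._
    import org.apache.spark.sql.types._

    implicit val sqlContext = new SQLContext(sc)      // sc: an existing SparkContext
    val keySchema = StructType(Seq(StructField("key", StringType)))
    val valueSchema = StructType(Seq(StructField("value", IntegerType)))
    val checkpointLocation = "/tmp/state-checkpoint"  // made-up path

    // Count occurrences per key; the UnsafeRow plumbing inside the loop is elided.
    val increment: (StateStore, Iterator[String]) => Iterator[(String, Int)] =
      (store, iter) => {
        iter.foreach { s =>
          () // look up the previous count for s and write back count + 1 via store.update(...)
        }
        store.commit()
        Iterator.empty // the real test emits the (key, count) pairs held by the store
      }

    val counts = sc.makeRDD(Seq("a", "b", "a")).mapPartitionWithStateStore(
      increment, checkpointLocation, 0L /* operatorId */, 0L /* storeVersion */,
      keySchema, valueSchema)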
import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -68,12 +67,10 @@ private[state] class HDFSBackedStateStoreProvider( val id: StateStoreId, keySchema: StructType, valueSchema: StructType, - sparkConf: SparkConf, + storeConf: StateStoreConf, hadoopConf: Configuration ) extends StateStoreProvider with Logging { - import HDFSBackedStateStoreProvider._ - type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ @@ -238,10 +235,7 @@ private[state] class HDFSBackedStateStoreProvider( private val baseDir = new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") private val fs = baseDir.getFileSystem(hadoopConf) - private val minBatchesToRetain = sparkConf.getInt( - MIN_BATCHES_TO_RETAIN_CONF, MIN_BATCHES_TO_RETAIN_DEFAULT) - private val maxDeltaChainForSnapshots = sparkConf.getInt( - MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT) + private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) initialize() @@ -462,7 +456,7 @@ private[state] class HDFSBackedStateStoreProvider( filesForVersion(files, lastVersion).filter(_.isSnapshot == false) synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => - if (deltaFilesForLastVersion.size > maxDeltaChainForSnapshots) { + if (deltaFilesForLastVersion.size > storeConf.maxDeltaChainForSnapshots) { writeSnapshotFile(lastVersion, map) } case None => @@ -485,7 +479,7 @@ private[state] class HDFSBackedStateStoreProvider( try { val files = fetchFiles() if (files.nonEmpty) { - val earliestVersionToRetain = files.last.version - minBatchesToRetain + val earliestVersionToRetain = files.last.version - storeConf.minBatchesToRetain if (earliestVersionToRetain > 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head synchronized { @@ -590,9 +584,5 @@ private[state] class HDFSBackedStateStoreProvider( } private[state] object HDFSBackedStateStoreProvider { - val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF = "spark.sql.streaming.stateStore.maxDeltaChain" - val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT = 10 - val MIN_BATCHES_TO_RETAIN_CONF = "spark.sql.streaming.stateStore.minBatchesToRetain" - val MIN_BATCHES_TO_RETAIN_DEFAULT = 2 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index b2f3eaab7c70e..84758538d0fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType @@ -128,14 +128,14 @@ private[state] object StateStore extends Logging { keySchema: StructType, valueSchema: StructType, version: Long, + storeConf: StateStoreConf, hadoopConf: Configuration): StateStore = { require(version >= 0) val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() - val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val provider = loadedProviders.getOrElseUpdate( storeId, - new HDFSBackedStateStoreProvider(storeId, keySchema, 
valueSchema, sparkConf, hadoopConf)) + new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) reportActiveStoreInstance(storeId) provider } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala new file mode 100644 index 0000000000000..fb896b910a654 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.execution.streaming.state.StateStoreConf._ +import org.apache.spark.sql.internal.SQLConf + +/** A class that contains configuration parameters for [[StateStore]]s. */ +private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { + + def this() = this(new SQLConf) + + val maxDeltaChainForSnapshots = conf.getConfString( + StateStoreConf.MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, + MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT.toString).toInt + + val minBatchesToRetain = conf.getConfString( + MIN_BATCHES_TO_RETAIN_CONF, + MIN_BATCHES_TO_RETAIN_DEFAULT.toString).toInt +} + +private[state] object StateStoreConf { + + val empty = new StateStoreConf() + + val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF = "spark.sql.streaming.stateStore.maxDeltaChain" + val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT = 10 + + val MIN_BATCHES_TO_RETAIN_CONF = "spark.sql.streaming.stateStore.minBatchesToRetain" + val MIN_BATCHES_TO_RETAIN_DEFAULT = 2 +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 06c1628e66cee..3318660895195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -37,6 +37,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( storeVersion: Long, keySchema: StructType, valueSchema: StructType, + storeConf: StateStoreConf, @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) extends RDD[U](dataRDD) { @@ -57,7 +58,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( Utils.tryWithSafeFinally { val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, confBroadcast.value.value) + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) val outputIter = 
storeUpdateFunction(store, inputIter) assert(store.hasCommitted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 65d2cc490e5e4..b249e37921f09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -36,6 +36,7 @@ package object state { keySchema: StructType, valueSchema: StructType )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + mapPartitionWithStateStore( storeUpdateFunction, checkpointLocation, @@ -43,6 +44,7 @@ package object state { storeVersion, keySchema, valueSchema, + new StateStoreConf(sqlContext.conf), Some(sqlContext.streams.stateStoreCoordinator)) } @@ -54,7 +56,8 @@ package object state { storeVersion: Long, keySchema: StructType, valueSchema: StructType, - storeCoordinator: Option[StateStoreCoordinatorRef] = None + storeConf: StateStoreConf, + storeCoordinator: Option[StateStoreCoordinatorRef] ): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( @@ -65,6 +68,7 @@ package object state { storeVersion, keySchema, valueSchema, + storeConf, storeCoordinator) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 5fd4512b8e02c..a5de46026072c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.Utils class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { - private val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getCanonicalName) + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) @@ -53,7 +53,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { quietly { - withSpark(new SparkContext(conf)) { sc => + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContet = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => @@ -68,12 +69,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val opId = 0 val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema, None) + increment, path, opId, storeVersion = 0, keySchema, valueSchema) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema, None) + increment, path, opId, storeVersion = 1, keySchema, valueSchema) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the 
same data. @@ -91,19 +92,20 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc: SparkContext, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) makeRDD(sc, Seq("a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion, keySchema, valueSchema, None) + increment, path, opId, storeVersion, keySchema, valueSchema) } // Generate RDDs and state store data - withSpark(new SparkContext(conf)) { sc => + withSpark(new SparkContext(sparkConf)) { sc => for (i <- 1 to 20) { require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } // With a new context, try using the earlier state store data - withSpark(new SparkContext(conf)) { sc => + withSpark(new SparkContext(sparkConf)) { sc => assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) } } @@ -114,7 +116,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - withSpark(new SparkContext(conf)) { sc => + withSpark(new SparkContext(sparkConf)) { sc => implicit val sqlContext = new SQLContext(sc) val coordinatorRef = sqlContext.streams.stateStoreCoordinator coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 81fe1bbacbf64..6f14ee6ab6062 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -39,7 +40,6 @@ import org.apache.spark.util.Utils class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] - import HDFSBackedStateStoreProvider._ import StateStoreCoordinatorSuite._ import StateStoreSuite._ @@ -99,7 +99,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // New updates to the reloaded store with new version, and does not change old version val reloadedProvider = new HDFSBackedStateStoreProvider( - store.id, keySchema, valueSchema, new SparkConf, new Configuration) + store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) update(reloadedStore, "c", 4) assert(reloadedStore.commit() === 2) @@ -316,31 +316,34 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth quietly { val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + // Verify that trying to get incorrect versions throw errors intercept[IllegalArgumentException] { - StateStore.get(storeId, keySchema, valueSchema, -1, new Configuration) + 
StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) } assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store intercept[IllegalStateException] { - StateStore.get(storeId, keySchema, valueSchema, 1, new Configuration) + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) } // Increase version of the store - val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, new Configuration) + val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) assert(store0.version === 0) update(store0, "a", 1) store0.commit() - assert(StateStore.get(storeId, keySchema, valueSchema, 1, new Configuration).version == 1) - assert(StateStore.get(storeId, keySchema, valueSchema, 0, new Configuration).version == 0) + assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) + assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) // Verify that you can remove the store and still reload and use it StateStore.remove(storeId) assert(!StateStore.isLoaded(storeId)) - val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, new Configuration) + val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) update(store1, "a", 2) assert(store1.commit() === 2) @@ -357,14 +360,17 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val storeId = StateStoreId(dir, opId, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() val provider = new HDFSBackedStateStoreProvider( - storeId, keySchema, valueSchema, conf, new Configuration) + storeId, keySchema, valueSchema, storeConf, hadoopConf) quietly { withSpark(new SparkContext(conf)) { sc => withCoordinatorRef(sc) { coordinatorRef => for (i <- 1 to 20) { - val store = StateStore.get(storeId, keySchema, valueSchema, i - 1, new Configuration) + val store = StateStore.get( + storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) update(store, "a", i) store.commit() } @@ -392,7 +398,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, new Configuration) + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -402,7 +408,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, new Configuration) + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) } } @@ -419,7 +425,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { val reloadedProvider = new HDFSBackedStateStoreProvider( - provider.id, keySchema, valueSchema, new SparkConf, new Configuration) + provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { @@ -484,16 +490,18 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def newStoreProvider( 
opId: Long = Random.nextLong, partition: Int = 0, - maxDeltaChainForSnapshots: Int = MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT + maxDeltaChainForSnapshots: Int = StateStoreConf.MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT ): HDFSBackedStateStoreProvider = { val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString - val sparkConf = new SparkConf() - .set(MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, maxDeltaChainForSnapshots.toString) + val sqlConf = new SQLConf() + sqlConf.setConfString( + StateStoreConf.MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, + maxDeltaChainForSnapshots.toString) new HDFSBackedStateStoreProvider( StateStoreId(dir, opId, partition), keySchema, valueSchema, - sparkConf, + new StateStoreConf(sqlConf), new Configuration()) } From f4f383803e45aaea182b09c4e3c4b64415b9461e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Mar 2016 20:26:50 -0700 Subject: [PATCH 38/46] Updated StateStoreConf --- .../state/HDFSBackedStateStoreProvider.scala | 4 ++-- .../streaming/state/StateStoreConf.scala | 18 ++++-------------- .../apache/spark/sql/internal/SQLConf.scala | 15 +++++++++++++++ .../streaming/state/StateStoreSuite.scala | 12 +++++------- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 4680593bb5d76..cb57b3184f946 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -456,7 +456,7 @@ private[state] class HDFSBackedStateStoreProvider( filesForVersion(files, lastVersion).filter(_.isSnapshot == false) synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => - if (deltaFilesForLastVersion.size > storeConf.maxDeltaChainForSnapshots) { + if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) { writeSnapshotFile(lastVersion, map) } case None => @@ -479,7 +479,7 @@ private[state] class HDFSBackedStateStoreProvider( try { val files = fetchFiles() if (files.nonEmpty) { - val earliestVersionToRetain = files.last.version - storeConf.minBatchesToRetain + val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index fb896b910a654..ea995c5a9c93a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.sql.execution.streaming.state.StateStoreConf._ import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. 
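A short sketch, mirroring what newStoreProvider in the test suite does, of how these settings now flow in through SQLConf entries instead of ad hoc SparkConf keys; the value 5 is arbitrary:

    import org.apache.spark.sql.internal.SQLConf

    val sqlConf = new SQLConf()
    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, 5)
    val storeConf = new StateStoreConf(sqlConf)
    // storeConf.maxDeltasForSnapshot == 5; minVersionsToRetain keeps its default of 2.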
*/ @@ -25,24 +24,15 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend def this() = this(new SQLConf) - val maxDeltaChainForSnapshots = conf.getConfString( - StateStoreConf.MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, - MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT.toString).toInt + import SQLConf._ - val minBatchesToRetain = conf.getConfString( - MIN_BATCHES_TO_RETAIN_CONF, - MIN_BATCHES_TO_RETAIN_DEFAULT.toString).toInt + val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + + val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } private[state] object StateStoreConf { - val empty = new StateStoreConf() - - val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF = "spark.sql.streaming.stateStore.maxDeltaChain" - val MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT = 10 - - val MIN_BATCHES_TO_RETAIN_CONF = "spark.sql.streaming.stateStore.minBatchesToRetain" - val MIN_BATCHES_TO_RETAIN_DEFAULT = 2 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 473cde56fdd34..74be2e928da60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -514,6 +514,21 @@ object SQLConf { doc = "When true, the planner will try to find out duplicated exchanges and re-use them", isPublic = false) + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf( + "spark.sql.streaming.stateStore.minDeltasForSnapshot", + defaultValue = Some(10), + doc = "Minimum number of state store delta files that needs to be generated before they " + + "consolidated into snapshots.", + isPublic = false + ) + + val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf( + "spark.sql.streaming.stateStore.minBatchesToRetain", + defaultValue = Some(2), + doc = "Minimum number of versions of a state store's data to retain after cleaning.", + isPublic = false + ) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6f14ee6ab6062..3aeba25f4ea0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -207,7 +207,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("snapshotting") { - val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) + val provider = newStoreProvider(minDeltasForSnapshot = 5) var currentVersion = 0 def updateVersionTo(targetVersion: Int): Unit = { @@ -260,7 +260,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("cleaning") { - val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) + val provider = newStoreProvider(minDeltasForSnapshot = 5) for (i <- 1 to 20) { val store = provider.getStore(i - 1) @@ -281,7 +281,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("corrupted file handling") { - val provider = newStoreProvider(maxDeltaChainForSnapshots = 5) + val provider = newStoreProvider(minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) update(store, "a", i) @@ -490,13 +490,11 @@ class StateStoreSuite extends 
SparkFunSuite with BeforeAndAfter with PrivateMeth def newStoreProvider( opId: Long = Random.nextLong, partition: Int = 0, - maxDeltaChainForSnapshots: Int = StateStoreConf.MAX_DELTA_CHAIN_FOR_SNAPSHOTS_DEFAULT + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get ): HDFSBackedStateStoreProvider = { val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val sqlConf = new SQLConf() - sqlConf.setConfString( - StateStoreConf.MAX_DELTA_CHAIN_FOR_SNAPSHOTS_CONF, - maxDeltaChainForSnapshots.toString) + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) new HDFSBackedStateStoreProvider( StateStoreId(dir, opId, partition), keySchema, From 502e5a5ee4b3b0abbe1abdd07b03dc4e3070b5e8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 10:39:45 -0700 Subject: [PATCH 39/46] Addressed comments --- .../spark/sql/execution/streaming/state/StateStore.scala | 3 +-- .../spark/sql/execution/streaming/state/StateStoreConf.scala | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 84758538d0fc9..dc01d43068849 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -58,7 +58,6 @@ trait StateStore { /** * Commit all the updates that have been made to the store. - * This can be called only after prepareForUpdates() has been called in the current thread. */ def commit(): Long @@ -113,7 +112,7 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate */ private[state] object StateStore extends Logging { - val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" + val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval" val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index ea995c5a9c93a..cca22a0af823f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -35,4 +35,3 @@ private[state] object StateStoreConf { val empty = new StateStoreConf() } - From 24fb325a44fb63cb7021d10b44f516a2d40d02c6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 11:00:28 -0700 Subject: [PATCH 40/46] Minor style fix --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f5874accb57e2..863a876afe9c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -529,15 +529,13 @@ object SQLConf { defaultValue = Some(10), doc = "Minimum number of state store delta files that needs to be generated before they " + "consolidated into snapshots.", - isPublic = false - ) + isPublic = false) val 
STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf( "spark.sql.streaming.stateStore.minBatchesToRetain", defaultValue = Some(2), doc = "Minimum number of versions of a state store's data to retain after cleaning.", - isPublic = false - ) + isPublic = false) val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation", defaultValue = None, From 4752d73af2b7257e0160d1eac81ee9b9279d3109 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 11:03:34 -0700 Subject: [PATCH 41/46] Renamed for consistency --- .../spark/sql/execution/streaming/state/StateStore.scala | 8 +++++--- .../sql/execution/streaming/state/StateStoreSuite.scala | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dc01d43068849..4a59353b5ea95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -141,15 +141,17 @@ private[state] object StateStore extends Logging { storeProvider.getStore(version) } - def remove(storeId: StateStoreId): Unit = loadedProviders.synchronized { + /** Unload a state store provider */ + def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { loadedProviders.remove(storeId) } + /** Whether a state store provider is loaded or not */ def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { loadedProviders.contains(storeId) } - /** Unload and stop all state store provider */ + /** Unload and stop all state store providers */ def stop(): Unit = loadedProviders.synchronized { loadedProviders.clear() _coordRef = null @@ -186,7 +188,7 @@ private[state] object StateStore extends Logging { if (verifyIfStoreInstanceActive(id)) { provider.doMaintenance() } else { - remove(id) + unload(id) logInfo(s"Unloaded $provider") } } catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 3aeba25f4ea0f..22b2f4f75d39e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -340,7 +340,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) // Verify that you can remove the store and still reload and use it - StateStore.remove(storeId) + StateStore.unload(storeId) assert(!StateStore.isLoaded(storeId)) val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) From 756762a044ce9a2735950ed27384151502ce020d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 11:44:28 -0700 Subject: [PATCH 42/46] Updated comment --- .../streaming/state/HDFSBackedStateStoreProvider.scala | 2 +- .../apache/spark/sql/execution/streaming/state/StateStore.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index cb57b3184f946..3280c7e6a54d1 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -146,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider( } } - /** Commit all the updates that have been made to the store. */ + /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { verify(state == UPDATING, "Cannot commit again after already committed or cancelled") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 4a59353b5ea95..70be417145350 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -57,7 +57,7 @@ trait StateStore { def remove(condition: UnsafeRow => Boolean): Unit /** - * Commit all the updates that have been made to the store. + * Commit all the updates that have been made to the store, and return the new version. */ def commit(): Long From 63fad922f166c7fdaf21783f89d01e3dec13c0f6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 11:56:29 -0700 Subject: [PATCH 43/46] Replaced Timer with Scheduled Executor --- .../streaming/state/StateStore.scala | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 70be417145350..cff6272fe760d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.concurrent.{TimeUnit, ScheduledFuture} import java.util.{Timer, TimerTask} import scala.collection.mutable @@ -28,6 +29,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils /** Unique identifier for a [[StateStore]] */ @@ -117,8 +119,10 @@ private[state] object StateStore extends Logging { private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() private val maintenanceTimer = new Timer("StateStore Timer", true) + private val maintenanceTaskExecutor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task") - @volatile private var maintenanceTask: TimerTask = null + @volatile private var maintenanceTask: ScheduledFuture[_] = null @volatile private var _coordRef: StateStoreCoordinatorRef = null /** Get or create a store associated with the id. 
*/ @@ -156,7 +160,7 @@ private[state] object StateStore extends Logging { loadedProviders.clear() _coordRef = null if (maintenanceTask != null) { - maintenanceTask.cancel() + maintenanceTask.cancel(false) maintenanceTask = null } logInfo("StateStore stopped") @@ -166,14 +170,14 @@ private[state] object StateStore extends Logging { private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized { val env = SparkEnv.get if (maintenanceTask == null && env != null) { - maintenanceTask = new TimerTask { + val periodMs = env.conf.getTimeAsMs( + MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") + val runnable = new Runnable { override def run(): Unit = { doMaintenance() } } - val periodMs = env.conf.getTimeAsMs( - MAINTENANCE_INTERVAL_CONFIG, - s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") - maintenanceTimer.schedule(maintenanceTask, periodMs, periodMs) - logInfo("StateStore maintenance timer started") + maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate( + runnable, periodMs, periodMs, TimeUnit.MILLISECONDS) + logInfo("State Store maintenance task started") } } From b147f599362ca4a6851f2b2d2140330bcdc6de0c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 12:34:22 -0700 Subject: [PATCH 44/46] Style fix --- .../spark/sql/execution/streaming/state/StateStore.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index cff6272fe760d..f1c196b20eeb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.streaming.state -import java.util.concurrent.{TimeUnit, ScheduledFuture} -import java.util.{Timer, TimerTask} +import java.util.Timer +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable import scala.util.control.NonFatal From 819ca17412b60c6afeaa7ac0624abf80ebd422b9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 17:09:52 -0700 Subject: [PATCH 45/46] Fixed coordinator bug and added distributed test --- .../spark/sql/ContinuousQueryManager.scala | 3 +- .../state/HDFSBackedStateStoreProvider.scala | 3 -- .../streaming/state/StateStore.scala | 6 ++-- .../state/StateStoreCoordinator.scala | 10 ++++-- .../state/StateStoreCoordinatorSuite.scala | 4 +-- .../streaming/state/StateStoreRDDSuite.scala | 32 +++++++++++++++++++ 6 files changed, 48 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 5ad243b75066a..465feeb60412f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.util.ContinuousQueryListener @Experimental class ContinuousQueryManager(sqlContext: SQLContext) { - private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef(sqlContext.sparkContext.env) + private[sql] val stateStoreCoordinator = + StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env) private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) private val activeQueries = new 
mutable.HashMap[String, ContinuousQuery] private val activeQueriesLock = new Object diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3280c7e6a54d1..5e8d13bb014c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -583,6 +583,3 @@ private[state] class HDFSBackedStateStoreProvider( } } -private[state] object HDFSBackedStateStoreProvider { - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index f1c196b20eeb7..ca5c864d9e993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.ThreadUtils /** Unique identifier for a [[StateStore]] */ case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) + /** * Base trait for a versioned key-value store used for streaming aggregations */ @@ -84,6 +85,7 @@ trait StateStore { def hasCommitted: Boolean } + /** Trait representing a provider of a specific version of a [[StateStore]]. */ trait StateStoreProvider { @@ -94,6 +96,7 @@ trait StateStoreProvider { def doMaintenance(): Unit = { } } + /** Trait representing updates made to a [[StateStore]]. */ sealed trait StoreUpdate @@ -118,7 +121,6 @@ private[state] object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() - private val maintenanceTimer = new Timer("StateStore Timer", true) private val maintenanceTaskExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task") @@ -232,7 +234,7 @@ private[state] object StateStore extends Logging { val env = SparkEnv.get if (env != null) { if (_coordRef == null) { - _coordRef = StateStoreCoordinatorRef(env) + _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 2036a0a1f5d51..5aa0636850255 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -53,7 +53,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { * Create a reference to a [[StateStoreCoordinator]], This can be called from driver as well as * executors. 
*/ - def apply(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { val coordinator = new StateStoreCoordinator(env.rpcEnv) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) @@ -61,11 +61,17 @@ private[sql] object StateStoreCoordinatorRef extends Logging { new StateStoreCoordinatorRef(coordinatorRef) } catch { case e: IllegalArgumentException => - logDebug("Retrieving existing StateStoreCoordinator endpoint") val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + logDebug("Retrieved existing StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(rpcEndpointRef) } } + + def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + logDebug("Retrieved existing StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(rpcEndpointRef) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index e93595aee2374..c99c2f505f3e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -94,7 +94,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("multiple references have same underlying coordinator") { withCoordinatorRef(sc) { coordRef1 => - val coordRef2 = StateStoreCoordinatorRef(sc.env) + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) val id = StateStoreId("x", 0, 0) @@ -114,7 +114,7 @@ object StateStoreCoordinatorSuite { def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { var coordinatorRef: StateStoreCoordinatorRef = null try { - coordinatorRef = StateStoreCoordinatorRef(sc.env) + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) body(coordinatorRef) } finally { if (coordinatorRef != null) coordinatorRef.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index a5de46026072c..24cec30fa335c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -142,6 +142,38 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } } + test("distributed test") { + quietly { + withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => + implicit val sqlContet = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + store.update( + stringToRow(s), oldRow => { + val oldValue = oldRow.map(rowToInt).getOrElse(0) + intToRow(oldValue + 1) + }) + } + store.commit() + store.iterator().map(rowsToStringInt) + } + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 0, keySchema, valueSchema) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next 
version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( + increment, path, opId, storeVersion = 1, keySchema, valueSchema) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } From 70cc7b11c0d502d4a741ccdcab5e7aa006b5b311 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Mar 2016 17:48:54 -0700 Subject: [PATCH 46/46] Addressed more comments --- .../streaming/state/HDFSBackedStateStoreProvider.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 5e8d13bb014c3..ee015baf3fae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream} +import java.io.{DataInputStream, DataOutputStream, IOException} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -206,7 +206,6 @@ private[state] class HDFSBackedStateStoreProvider( require(version >= 0, "Version cannot be less than 0") val newMap = new MapType() if (version > 0) { - val time = System.nanoTime() newMap.putAll(loadMap(version)) } val store = new HDFSBackedStateStore(version, newMap) @@ -351,7 +350,7 @@ private[state] class HDFSBackedStateStoreProvider( if (keySize == -1) { eof = true } else if (keySize < 0) { - throw new Exception( + throw new IOException( s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") } else { val keyRowBuffer = new Array[Byte](keySize) @@ -416,7 +415,7 @@ private[state] class HDFSBackedStateStoreProvider( if (keySize == -1) { eof = true } else if (keySize < 0) { - throw new Exception( + throw new IOException( s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize") } else { val keyRowBuffer = new Array[Byte](keySize) @@ -427,7 +426,7 @@ private[state] class HDFSBackedStateStoreProvider( val valueSize = input.readInt() if (valueSize < 0) { - throw new Exception( + throw new IOException( s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize") } else { val valueRowBuffer = new Array[Byte](valueSize)
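
The following illustrative sketch is not part of the patches above. PATCH 43/46 replaces the java.util.Timer used for periodic snapshot-and-cleanup work with a daemon ScheduledExecutorService obtained from ThreadUtils.newDaemonSingleThreadScheduledExecutor, and its stop() path cancels the returned ScheduledFuture with cancel(false) so an in-flight maintenance run can finish. The sketch shows that scheduling shape in isolation using only the JDK; the object name, the start/stop helpers, and the hand-rolled ThreadFactory are illustrative stand-ins rather than Spark APIs from these patches, and doMaintenance is only a placeholder for the per-provider work.

import java.util.concurrent.{Executors, ScheduledFuture, ThreadFactory, TimeUnit}

object MaintenanceSchedulingSketch {
  // Single daemon worker thread, standing in for
  // ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task").
  private val executor = Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "state-store-maintenance-task")
      t.setDaemon(true)
      t
    }
  })

  @volatile private var maintenanceTask: ScheduledFuture[_] = null

  // Placeholder for the real work (snapshotting delta files, unloading inactive providers).
  private def doMaintenance(): Unit = println("running maintenance")

  def start(periodMs: Long): Unit = synchronized {
    if (maintenanceTask == null) {
      val runnable = new Runnable { override def run(): Unit = doMaintenance() }
      // Fixed-rate scheduling with an initial delay of one period, mirroring the patch.
      maintenanceTask =
        executor.scheduleAtFixedRate(runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
    }
  }

  def stop(): Unit = synchronized {
    if (maintenanceTask != null) {
      maintenanceTask.cancel(false)  // false: do not interrupt a run already in progress
      maintenanceTask = null
    }
  }

  def main(args: Array[String]): Unit = {
    start(periodMs = 1000)
    Thread.sleep(3500)  // expect roughly three maintenance runs
    stop()
    executor.shutdown()
  }
}

One practical difference behind the change: an uncaught exception from a TimerTask kills the shared Timer thread, whereas a scheduled executor keeps its worker alive (an uncaught exception still suppresses further runs of that particular periodic task, which is why the doMaintenance loop in these patches catches NonFatal errors per provider).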