diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0e8b398fc6b97..ca1a0ad077f53 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1888,7 +1888,7 @@ def toJSON(self, use_unicode=False):
         rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
         return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
 
-    def saveAsParquetFile(self, path):
+    def saveAsParquetFile(self, path, overwrite=False):
         """Save the contents as a Parquet file, preserving the schema.
 
         Files that are written out using this method can be read back in as
@@ -1903,7 +1903,7 @@ def saveAsParquetFile(self, path):
         >>> sorted(srdd2.collect()) == sorted(srdd.collect())
         True
         """
-        self._jschema_rdd.saveAsParquetFile(path)
+        self._jschema_rdd.saveAsParquetFile(path, overwrite)
 
     def registerTempTable(self, name):
         """Registers this RDD as a temporary table using the given name.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index b2262e5e6efb6..ab9d0302341b5 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -297,7 +297,8 @@ package object dsl {
 
   object plans {  // scalastyle:ignore
     implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions {
-      def writeToFile(path: String) = WriteToFile(path, logicalPlan)
+      def writeToFile(path: String, overwrite: Boolean) =
+        WriteToFile(path, logicalPlan, overwrite)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0b9f01cbae9ea..d6afd028e5730 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -126,7 +126,8 @@ case class CreateTableAsSelect[T](
 
 case class WriteToFile(
     path: String,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    overwrite: Boolean) extends UnaryNode {
   override def output = child.output
 }
 
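For illustration, the Catalyst DSL change above can be exercised as follows. This is a minimal sketch, assuming the patch is applied; `existingPlan` and `markForWrite` are hypothetical stand-ins, not part of the patch:

    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, WriteToFile}

    // Wrap an existing plan in the new three-argument WriteToFile node;
    // `overwrite` must now be passed explicitly through the DSL.
    def markForWrite(existingPlan: LogicalPlan): WriteToFile =
      existingPlan.writeToFile("/tmp/parquet_out", overwrite = true)
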
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index 3cf9209465b76..5f9a62b86fbb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -68,12 +68,29 @@ private[sql] trait SchemaRDDLike {
   /**
    * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
    * are written out using this method can be read back in as a SchemaRDD using the `parquetFile`
-   * function.
+   * function. An exception is raised if the specified path already exists.
    *
+   * @param path The destination path.
    * @group schema
    */
   def saveAsParquetFile(path: String): Unit = {
-    sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+    saveAsParquetFile(path, false)
+  }
+
+  /**
+   * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
+   * are written out using this method can be read back in as a SchemaRDD using the `parquetFile`
+   * function.
+   *
+   * @param path The destination path.
+   * @param overwrite If false, an exception is raised when the path already exists;
+   *                  otherwise the path is created.
+   *                  If true, the path is created when missing, and an existing path is
+   *                  overwritten (it is deleted and the new output written in its place).
+   * @group schema
+   */
+  def saveAsParquetFile(path: String, overwrite: Boolean): Unit = {
+    sqlContext.executePlan(WriteToFile(path, logicalPlan, overwrite)).toRdd
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 99b6611d3bbcf..12b20ef2af5ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -205,11 +205,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object ParquetOperations extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // TODO: need to support writing to other types of files. Unify the below code paths.
-      case logical.WriteToFile(path, child) =>
+      case logical.WriteToFile(path, child, overwrite) =>
         val relation =
-          ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext)
-        // Note: overwrite=false because otherwise the metadata we just created will be deleted
-        InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil
+          ParquetRelation.createEmpty(
+            path,
+            child.output,
+            overwrite,
+            sparkContext.hadoopConfiguration,
+            sqlContext)
+        InsertIntoParquetTable(relation, planLater(child), overwrite) :: Nil
       case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
         InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
       case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
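For reference, the user-facing effect of the new overload (a minimal sketch; `sqlContext` and the input path `/data/people` are assumed to exist for illustration):

    val people = sqlContext.parquetFile("/data/people")

    // Unchanged default: fails if the destination already exists.
    people.saveAsParquetFile("/tmp/people_out")

    // New: removes an existing /tmp/people_out, then writes the new output.
    people.saveAsParquetFile("/tmp/people_out", overwrite = true)
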
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index b237a07c72d07..7acf31dbb8acc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -151,15 +151,20 @@ private[sql] object ParquetRelation {
    *
    * @param pathString The directory the Parquetfile will be stored in.
    * @param attributes The schema of the relation.
+   * @param overwrite Whether to overwrite an existing file path:
+   *                  If false, an exception is raised when the path already exists;
+   *                  otherwise a new file path is created.
+   *                  If true, an existing path is removed before the output is written.
    * @param conf A configuration to be used.
+   * @param sqlContext The SQLContext to use.
    * @return An empty ParquetRelation.
    */
   def createEmpty(pathString: String,
                   attributes: Seq[Attribute],
-                  allowExisting: Boolean,
+                  overwrite: Boolean,
                   conf: Configuration,
                   sqlContext: SQLContext): ParquetRelation = {
-    val path = checkPath(pathString, allowExisting, conf)
+    val path = createPath(pathString, overwrite, conf)
     conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse(
       sqlContext.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name())
     ParquetRelation.enableLogForwarding()
@@ -169,7 +174,7 @@ private[sql] object ParquetRelation {
     }
   }
 
-  private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = {
+  private def createPath(pathStr: String, overwrite: Boolean, conf: Configuration): Path = {
     if (pathStr == null) {
       throw new IllegalArgumentException("Unable to create ParquetRelation: path is null")
     }
@@ -179,9 +184,23 @@ private[sql] object ParquetRelation {
       throw new IllegalArgumentException(
         s"Unable to create ParquetRelation: incorrectly formatted path $pathStr")
     }
+
     val path = origPath.makeQualified(fs)
-    if (!allowExisting && fs.exists(path)) {
-      sys.error(s"File $pathStr already exists.")
+    val pathExisted = fs.exists(path)
+
+    if (pathExisted) {
+      if (overwrite) {
+        try {
+          fs.delete(path, true)
+        } catch {
+          case e: IOException =>
+            throw new IOException(s"Unable to clear output directory $path", e)
+        }
+      } else {
+        sys.error(s"File $path already exists.")
+      }
+    } else {
+      fs.mkdirs(path)
     }
 
     if (fs.exists(path) &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index a5fe2e8da2840..04c388963a473 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -850,6 +850,46 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     assert(result2(0)(1) === "the answer")
   }
 
+  test("Test overwrite") {
+    // Create a temp dir, then delete it, so the test starts from a path that does not exist.
+    val tmpdir = Utils.createTempDir()
+    Utils.deleteRecursively(tmpdir)
+    val result1 = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD
+    val result2 = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD
+
+    // Path does not exist, overwrite = false: the write succeeds.
+    result1.saveAsParquetFile(tmpdir.toString, overwrite = false)
+    parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerTempTable("tmpcopy")
+    val tmpdata1 = sql("SELECT * FROM tmpcopy").collect()
+    assert(tmpdata1.size === 2)  // the contents of testNestedDir1
+
+    // Path exists, overwrite = true: the old output is replaced.
+    result2.saveAsParquetFile(tmpdir.toString, overwrite = true)
+    parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerTempTable("tmpcopy")
+    val tmpdata2 = sql("SELECT * FROM tmpcopy").collect()
+    assert(tmpdata2.size === 1)  // the contents of testNestedDir4
+
+    // Path exists, overwrite = false: the write fails.
+    intercept[Exception] {
+      result1.saveAsParquetFile(tmpdir.toString, overwrite = false)
+    }
+
+    Utils.deleteRecursively(tmpdir)
+    // Path does not exist, overwrite = true: the write succeeds.
+    result2.saveAsParquetFile(tmpdir.toString, overwrite = true)
+    parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerTempTable("tmpcopy")
+    val tmpdata3 = sql("SELECT * FROM tmpcopy").collect()
+    assert(tmpdata3.size === 1)  // the contents of testNestedDir4
+
+    Utils.deleteRecursively(tmpdir)
+  }
+
   test("Writing out Addressbook and reading it back in") {
     // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME
     // has no effect in this test case
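The path handling in createPath can be reproduced in isolation with the plain Hadoop FileSystem API. This is a standalone sketch of the same delete-or-fail semantics, not part of the patch; the helper name `preparePath` is illustrative:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}

    def preparePath(pathStr: String, overwrite: Boolean, conf: Configuration): Path = {
      val origPath = new Path(pathStr)
      val fs = origPath.getFileSystem(conf)
      val path = origPath.makeQualified(fs)
      if (fs.exists(path)) {
        // Recursive delete, mirroring fs.delete(path, true) in the patch.
        if (overwrite) fs.delete(path, true)
        else sys.error(s"File $path already exists.")
      } else {
        fs.mkdirs(path)
      }
      path
    }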